#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy, math
import numpy as np
from tqdm import tqdm
from loguru import logger
import torch
import torch.nn as nn
from transformers import BertConfig, BertTokenizer
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, BertTokenizer, BertPreTrainedModel

from ..trainer import Trainer
from ...utils import softmax, acc_and_f1
from ...utils.multiprocesses import barrier_leader_process, barrier_member_processes


def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise.

    GELU(x) = x * Phi(x), where Phi is the standard normal CDF, computed as
    0.5 * (1 + erf(x / sqrt(2))).  OpenAI GPT uses the tanh approximation
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))), which gives
    slightly different results.  See https://arxiv.org/abs/1606.08415
    """
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf


def swish(x):
    """Apply the Swish activation element-wise: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return x * gate


# Activation names accepted in ``config.hidden_act`` mapped to their callables.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)


#  try:
#      from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
#  except ImportError:
#      logger.info("Better speed can be achieved with apex installed from " +
#                  "https://www.github.com/nvidia/apex .")
#
class BertLayerNorm(nn.Module):
    """Layer normalization over the last dimension, TensorFlow style.

    Computes ``weight * (x - mean) / sqrt(var + eps) + bias`` where mean and
    the biased (population) variance are taken over the last dimension, and
    epsilon sits inside the square root.
    """
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias


class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects ``hidden_states`` to queries, keys and values, splits them into
    ``config.num_attention_heads`` heads, attends with an additive mask, and
    returns the per-head context vectors concatenated back to
    ``all_head_size`` (no output projection here; see BertSelfOutput).

    Raises:
        ValueError: if ``hidden_size`` is not divisible by the head count.
    """
    def __init__(self, config):
        super().__init__()
        n_heads = config.num_attention_heads
        if config.hidden_size % n_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = n_heads
        self.attention_head_size = int(config.hidden_size / n_heads)
        self.all_head_size = n_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        *leading, _ = x.size()
        split = x.view(*leading, self.num_attention_heads,
                       self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores scaled by 1/sqrt(head_size).  The additive mask is
        # precomputed once for all layers in BertModel.forward().
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(
            self.attention_head_size)
        scores = scores + attention_mask

        probs = scores.softmax(dim=-1)
        # Dropping out attention probabilities drops entire tokens to attend
        # to; this follows the original Transformer paper.
        probs = self.dropout(probs)

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        return context.view(*context.size()[:-2], self.all_head_size)


class BertSelfOutput(nn.Module):
    """Output stage of the attention sub-layer.

    Applies dense (hidden_size -> hidden_size), dropout, adds the residual
    ``input_tensor`` and normalizes with BertLayerNorm.
    """
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


class BertAttention(nn.Module):
    """Attention sub-layer: multi-head self-attention followed by its output
    projection, residual connection and LayerNorm."""
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        context = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward expansion: dense hidden_size -> intermediate_size, then
    the configured activation.

    ``config.hidden_act`` is either a key of ``ACT2FN`` (looked up by name)
    or a callable that is used as-is.
    """
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Output stage of the feed-forward sub-layer.

    Applies dense (intermediate_size -> hidden_size), dropout, adds the
    residual ``input_tensor`` and normalizes with BertLayerNorm.
    """
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_branch = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(residual_branch + input_tensor)


class BertLayer(nn.Module):
    """One encoder layer: the attention sub-layer, then the feed-forward
    sub-layer (intermediate expansion + output projection with residual)."""
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attended = self.attention(hidden_states, attention_mask)
        return self.output(self.intermediate(attended), attended)


class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm and
    dropout.

    Word embeddings use ``padding_idx=0``.  Position ids are
    ``0 .. seq_len - 1`` for every row; ``token_type_ids`` default to zeros.
    """
    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden,
                                            padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  hidden)

        # "LayerNorm" keeps the TensorFlow variable name (not snake_case) so
        # that TensorFlow checkpoints can be loaded.
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        positions = torch.arange(input_ids.size(1),
                                 dtype=torch.long,
                                 device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))


class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules, each a deep
    copy of one freshly constructed layer."""
    def __init__(self, config):
        super().__init__()
        template = BertLayer(config)
        self.layer = nn.ModuleList(
            copy.deepcopy(template) for _ in range(config.num_hidden_layers))

    def forward(self,
                hidden_states,
                attention_mask,
                output_all_encoded_layers=True):
        """Run the layers in order.

        Returns a list with every layer's output when
        ``output_all_encoded_layers`` is true, otherwise a one-element list
        holding the final hidden states.
        """
        per_layer = []
        for block in self.layer:
            hidden_states = block(hidden_states, attention_mask)
            if output_all_encoded_layers:
                per_layer.append(hidden_states)
        return per_layer if output_all_encoded_layers else [hidden_states]


class BertPooler(nn.Module):
    """Pool a sequence into one vector: take the first token's hidden state,
    apply a dense layer (hidden_size -> hidden_size) and tanh."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))


#  class BertPreTrainedModel(nn.Module):
#      """ An abstract class to handle weights initialization and
#      a simple interface for downloading and loading pretrained models.
#      """
#      def __init__(self, config, *inputs, **kwargs):
#          super(BertPreTrainedModel, self).__init__()
#          if not isinstance(config, BertConfig):
#              raise ValueError(
#                  "Parameter config in `{}(config)` should be an instance of "
#                  "class `BertConfig`. "
#                  "To create a model from a Google pretrained model use "
#                  "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
#                      self.__class__.__name__, self.__class__.__name__))
#          self.config = config
#
#      def init_bert_weights(self, module):
#          """ Initialize the weights.
#          """
#          if isinstance(module, (nn.Linear, nn.Embedding)):
#              # Slightly different from the TF version which uses
#              # truncated_normal for initialization
#              # cf https://github.com/pytorch/pytorch/pull/5617
#              module.weight.data.normal_(mean=0.0,
#                                         std=self.config.initializer_range)
#          elif isinstance(module, BertLayerNorm):
#              module.bias.data.zero_()
#              module.weight.data.fill_(1.0)
#          if isinstance(module, nn.Linear) and module.bias is not None:
#              module.bias.data.zero_()
#          return


class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a BertConfig instance with the configuration used to build the model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional torch.LongTensor of shape
            [batch_size, sequence_length] with token type indices selected in
            [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 to a
            `sentence B` token (see the BERT paper). Defaults to all zeros.
        `attention_mask`: an optional torch.LongTensor of shape
            [batch_size, sequence_length] with values in [0, 1]: 1 for tokens
            to attend to, 0 for padding in batches of varying-length
            sentences. Defaults to all ones.
        `output_all_encoded_layers`: boolean controlling the content of the
            `encoded_layers` output described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers`:
            - `True`: a list with the hidden states at the end of every
              encoder layer (e.g. 12 for BERT-base, 24 for BERT-large), each of
              shape [batch_size, sequence_length, hidden_size];
            - `False`: only the last layer's hidden states, a single tensor of
              shape [batch_size, sequence_length, hidden_size].
        `pooled_output`: a tensor of shape [batch_size, hidden_size] produced
            by the pooler (dense + tanh) from the hidden state of the first
            token (`CLS`).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Turn the [batch, seq] 0/1 mask into an additive bias of shape
        # [batch, 1, 1, seq] that broadcasts over heads and query positions:
        # 1 -> 0.0 (attend), 0 -> -10000.0 (masked).  Added to the raw scores
        # before softmax, this effectively removes masked positions.
        param_dtype = next(self.parameters()).dtype  # fp16 compatibility
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(
            dtype=param_dtype)
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(
            embedding_output,
            additive_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        last_hidden = encoded_layers[-1]
        pooled_output = self.pooler(last_hidden)
        if output_all_encoded_layers:
            return encoded_layers, pooled_output
        return last_hidden, pooled_output

    def init_bert_weights(self, module):
        """Initialize one submodule (intended for ``self.apply``).

        Linear/Embedding weights are drawn from
        N(0, config.initializer_range); Linear biases are zeroed;
        BertLayerNorm gets weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal rather than TF's truncated normal,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class BertSentenceEmbedding(BertPreTrainedModel):
    """BERT model for sentence embedding.

    The pooled BERT output is passed through a linear layer and tanh; the
    result is the sentence embedding, meant to be compared with cosine
    similarity (see ``calcSim``).

    Params:
        `config`: a BertConfig instance with the configuration used to build
            the model.
        `emb_size`: size of the sentence embedding vector.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `attention_mask`: a torch.LongTensor of shape
            [batch_size, sequence_length] with values in [0, 1]: 1 for tokens
            to attend to, 0 for padding in batches of varying-length
            sentences.

    Outputs:
        The sentence embeddings, a tensor of shape [batch_size, emb_size]
        with values in [-1, 1] (tanh output).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

    config = BertConfig(vocab_size=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    emb_size = 1024

    model = BertSentenceEmbedding(config, emb_size)
    emb = model(input_ids, input_mask)
    ```
    """
    def __init__(self, config, emb_size):
        super().__init__(config)
        self.emb_size = emb_size
        self.bert = BertModel(config)
        self.emb = nn.Linear(config.hidden_size, emb_size)
        self.activation = nn.Tanh()
        self.cos_fn = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.apply(self.init_bert_weights)

    def calcSim(self, emb1, emb2):
        """Cosine similarity of two embedding batches along dim 1 (eps=1e-6)."""
        return self.cos_fn(emb1, emb2)

    def forward(self, input_ids, attention_mask):
        _, pooled = self.bert(input_ids,
                              None,
                              attention_mask,
                              output_all_encoded_layers=False)
        return self.activation(self.emb(pooled))

    def init_bert_weights(self, module):
        """Initialize one submodule (intended for ``self.apply``).

        Linear/Embedding weights are drawn from
        N(0, config.initializer_range); Linear biases are zeroed;
        BertLayerNorm gets weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal rather than TF's truncated normal,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class ClassReport(object):
    """Print (and return) a per-class precision / recall / F1 report.

    Implemented with numpy: the previous version called
    ``classification_report``, a name this module never imports, so every
    call failed with NameError.

    Args:
        target_names: optional display names, one per reported label.
            Defaults to ``str(label)`` for each label.
        labels: optional sequence of class ids to report on. Defaults to the
            sorted union of ids present in the targets and predictions.
    """
    def __init__(self, target_names=None, labels=None):
        self.target_names = target_names
        self.labels = labels

    @staticmethod
    def _build_report(y_true, y_pred, labels=None, target_names=None,
                      digits=2):
        """Return a text table of precision, recall, F1 and support per
        label, followed by macro and support-weighted averages.

        A ratio whose denominator is zero (no predictions, or no true
        samples, for a label) is reported as 0.0.

        Raises:
            ValueError: if ``target_names`` and ``labels`` differ in length.
        """
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        if labels is None:
            labels = np.unique(np.concatenate((y_true, y_pred)))
        labels = list(labels)
        if target_names is None:
            target_names = [str(label) for label in labels]
        target_names = list(target_names)
        if len(target_names) != len(labels):
            raise ValueError(
                f"Number of classes, {len(labels)}, does not match size of "
                f"target_names, {len(target_names)}")

        rows = []
        for label in labels:
            true_mask = y_true == label
            pred_mask = y_pred == label
            tp = int(np.sum(true_mask & pred_mask))
            n_pred = int(np.sum(pred_mask))
            support = int(np.sum(true_mask))
            precision = tp / n_pred if n_pred else 0.0
            recall = tp / support if support else 0.0
            denom = precision + recall
            f1 = 2 * precision * recall / denom if denom else 0.0
            rows.append((precision, recall, f1, support))

        stats = np.asarray(rows, dtype=float).reshape(-1, 4)
        total = int(stats[:, 3].sum())
        macro = stats[:, :3].mean(axis=0) if rows else np.zeros(3)
        weighted = ((stats[:, :3] * stats[:, 3:4]).sum(axis=0) / total
                    if total else np.zeros(3))

        width = max([len(name) for name in target_names] +
                    [len("weighted avg")])

        def fmt_row(name, p, r, f, s):
            return (f"{name:>{width}} {p:>9.{digits}f} {r:>9.{digits}f} "
                    f"{f:>9.{digits}f} {s:>9}")

        lines = [
            f"{'':>{width}} {'precision':>9} {'recall':>9} "
            f"{'f1-score':>9} {'support':>9}",
            "",
        ]
        lines.extend(
            fmt_row(name, p, r, f, s)
            for name, (p, r, f, s) in zip(target_names, rows))
        lines.append("")
        lines.append(fmt_row("macro avg", *macro, total))
        lines.append(fmt_row("weighted avg", *weighted, total))
        return "\n".join(lines)

    def __call__(self, output, target):
        """Report on ``output`` scores (argmax over dim 1 gives the predicted
        class) against integer ``target`` ids; prints and returns the text."""
        _, y_pred = torch.max(output.data, 1)
        report = self._build_report(target.cpu().numpy(),
                                    y_pred.cpu().numpy(),
                                    labels=self.labels,
                                    target_names=self.target_names)
        print('\n\nclassify_report:\n', report)
        return report


# Registry: model_type -> (config class, model class, tokenizer class).
MODEL_CLASSES = dict(
    bert=(AutoConfig, BertSentenceEmbedding, BertTokenizer),
)


def load_pretrained_tokenizer(args):
    """Load the tokenizer registered in MODEL_CLASSES for ``args.model_type``.

    The tokenizer is read from ``args.model_path`` with
    ``args.do_lower_case``; a falsy ``args.cache_dir`` is passed as None.
    """
    _, _, tokenizer_class = MODEL_CLASSES[args.model_type]
    return tokenizer_class.from_pretrained(
        args.model_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir or None,
    )


def load_pretrained_model(args):
    """Load the config and model registered in MODEL_CLASSES for
    ``args.model_type``.

    Reads ``args.model_type``, ``args.model_path``, ``args.num_labels`` and
    ``args.cache_dir`` (a falsy value is passed as None). ``from_tf`` is set
    when the model path contains ".ckpt".

    The sentence-embedding size comes from ``args.emb_size`` when that
    attribute exists and is not None; otherwise it defaults to 1024, the
    value this function used to hard-code.

    Returns:
        The instantiated model.
    """
    # Non-leader processes wait here so that only the first process in
    # distributed training downloads the model & vocab.
    barrier_member_processes(args)

    config_class, model_class, _ = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    config = config_class.from_pretrained(
        args.model_path,
        num_labels=args.num_labels,
        cache_dir=cache_dir,
    )
    logger.info(f"model_path: {args.model_path}")
    logger.info(f"config:{config}")

    emb_size = getattr(args, "emb_size", None)
    if emb_size is None:
        emb_size = 1024
    model = model_class.from_pretrained(
        args.model_path,
        from_tf=bool(".ckpt" in args.model_path),
        config=config,
        cache_dir=cache_dir,
        emb_size=emb_size)

    # The leader reaches its barrier only after loading, which releases the
    # waiting processes (presumably they then load from the local cache).
    barrier_leader_process(args)

    return model


class SenembTrainer(Trainer):
    """Trainer hooks for the sentence-embedding model on sentence pairs.

    Train/eval batches must unpack as
    ``(input_ids1, input_mask1, input_ids2, input_mask2, label_ids)``.  The
    loss is the mean absolute difference between the cosine similarity of
    the two sentence embeddings and ``label_ids``; during evaluation a pair
    is predicted as 1 when that similarity exceeds 0.5, otherwise 0.
    """
    def __init__(self):
        super(SenembTrainer, self).__init__()
        self.logits = None
        self.preds = None
        self.probs = None

    @staticmethod
    def _cosine_l1_loss(v1, v2, label):
        """Mean |cos(v1, v2) - label|, cosine taken along dim 1 (eps=1e-6).

        ``label`` is cast to the similarity's dtype and reshaped to its
        shape, so integer/float64 labels or [batch, 1]-shaped labels match
        the prediction exactly instead of relying on type promotion or
        silently broadcasting into a [batch, batch] difference.
        """
        sim = nn.functional.cosine_similarity(v1, v2, dim=1, eps=1e-6)
        target = label.to(sim.dtype).reshape_as(sim)
        return nn.functional.l1_loss(sim, target, reduction='mean')

    def on_train_step(self, args, model, step, batch):
        """Embed both sentences of each pair and compute the cosine L1 loss.

        Returns:
            ``(loss, embeddings1, embeddings2)``.
        """
        input_ids1, input_mask1, input_ids2, input_mask2, label_ids = batch
        logits1 = model(input_ids1, input_mask1)
        logits2 = model(input_ids2, input_mask2)
        loss = self._cosine_l1_loss(logits1, logits2, label_ids)
        return (loss, logits1, logits2)

    def on_eval_start(self, args, eval_dataset):
        """Reset the accumulators filled by ``on_eval_step``."""
        self.logits = None
        self.preds = None
        self.out_label_ids = None
        self.results = {}
        self.probs = None

    def on_eval_step(self, args, model, step, batch):
        """Accumulate pair similarities and labels; update running metrics.

        Returns:
            ``((eval_loss,), results)`` where ``results`` holds the
            ``acc_and_f1`` metrics over every pair seen so far in this pass.
        """
        label_ids = batch[-1]
        eval_loss, logits1, logits2 = self.on_train_step(
            args, model, step, batch)

        sims = nn.functional.cosine_similarity(logits1, logits2, dim=1,
                                               eps=1e-6)
        sims = sims.detach().cpu().numpy()
        labels = label_ids.detach().cpu().numpy()
        if self.logits is None:
            self.logits = sims
            self.out_label_ids = labels
        else:
            self.logits = np.concatenate((self.logits, sims), axis=0)
            self.out_label_ids = np.concatenate((self.out_label_ids, labels),
                                                axis=0)

        # Vectorized threshold at 0.5: prediction is 1 above it, else 0; the
        # "probability" is the similarity, or 1 - similarity for predicted 0.
        above = self.logits > 0.5
        self.preds = above.astype(np.int64)
        self.probs = np.where(above, self.logits, 1 - self.logits)

        result = acc_and_f1(self.preds, self.out_label_ids)
        self.results.update(result)

        return (eval_loss, ), self.results

    def on_eval_end(self, args, eval_dataset):
        """Log the accumulated evaluation metrics and return them."""
        logger.info(f"  Num examples = {len(eval_dataset)}")
        logger.info(f"  Batch size = {args.eval_batch_size}")
        logger.info("******** Eval results ********")
        for key, value in self.results.items():
            logger.info(f" dev: {key} = {value:.4f}")

        return self.results

    def on_predict_start(self, args, test_dataset):
        """Reset the prediction accumulators."""
        self.logits = None
        self.preds = None
        self.probs = None

    def on_predict_step(self, args, test_dataset, step, model, inputs,
                        outputs):
        """Append ``outputs[1]`` of this step to ``self.logits``."""
        # NOTE(review): treats outputs[1] as class logits and on_predict_end
        # applies argmax/softmax (classification-style handling); confirm this
        # matches what the sentence-embedding model actually returns.
        _, logits = outputs[:2]
        batch_logits = logits.detach().cpu().numpy()
        if self.logits is None:
            self.logits = batch_logits
        else:
            self.logits = np.concatenate((self.logits, batch_logits), axis=0)

    def on_predict_end(self, args, test_dataset):
        """Derive argmax predictions and softmax probabilities from logits."""
        self.preds = np.argmax(self.logits, axis=1)
        self.probs = softmax(self.logits)
        logger.debug(self.preds)
        logger.debug(self.probs)

    def encode(self, args, model, encode_dataset):
        """Return the model outputs for ``encode_dataset`` stacked along axis 0.

        Batches must unpack as ``(input_ids, input_mask)``; the dataloader is
        requested with ``keep_order=True``.  Returns None when the dataloader
        yields no batches.
        """
        encode_dataloader = self.generate_dataloader(args,
                                                     encode_dataset,
                                                     batch_size=1,
                                                     keep_order=True)

        model.eval()
        # Collect per-batch arrays and concatenate once, instead of the
        # quadratic copy of repeated np.append.
        chunks = []
        with torch.no_grad():
            for batch in encode_dataloader:
                input_ids, input_mask = tuple(t.to(args.device) for t in batch)
                embeddings = model(input_ids, input_mask)
                chunks.append(embeddings.detach().cpu().numpy())
        return np.concatenate(chunks, axis=0) if chunks else None
