from typing import Dict, List, Optional, Tuple, Union

import torch

from chunkformer.modules.ctc import CTC
from chunkformer.modules.decoder import BiTransformerDecoder, TransformerDecoder
from chunkformer.modules.encoder import ChunkFormerEncoder
from chunkformer.modules.label_smoothing_loss import LabelSmoothingLoss
from chunkformer.modules.search import (
    DecodeResult,
    attention_beam_search,
    attention_rescoring,
    ctc_greedy_search,
    ctc_prefix_beam_search,
)
from chunkformer.transducer.search.greedy_search import greedy_search
from chunkformer.utils.common import (
    IGNORE_ID,
    add_sos_eos,
    pad_sequence,
    reverse_pad_list,
    th_accuracy,
)
from chunkformer.utils.context_graph import ContextGraph
from chunkformer.utils.mask import make_pad_mask


class ASRModel(torch.nn.Module):
    """CTC-attention hybrid Encoder-Decoder model.

    Combines a ChunkFormer encoder with an optional CTC head and an optional
    attention decoder (left-to-right or bidirectional). During training the
    two losses are mixed as
    ``ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder: ChunkFormerEncoder,
        decoder: Optional[Union[TransformerDecoder, BiTransformerDecoder]] = None,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: Optional[dict] = None,
        apply_non_blank_embedding: bool = False,
    ):
        """
        Args:
            vocab_size: size of the output vocabulary. Unless ``special_tokens``
                provides ``"<sos>"`` / ``"<eos>"``, id ``vocab_size - 1`` is
                used for both sos and eos.
            encoder: acoustic encoder.
            decoder: optional attention decoder (uni- or bidirectional).
            ctc: optional CTC head.
            ctc_weight: weight of the CTC loss in [0, 1].
            ignore_id: padding id ignored by the attention loss.
            reverse_weight: weight of the right-to-left decoder loss.
            lsm_weight: label smoothing factor of the attention loss.
            length_normalized_loss: forwarded to LabelSmoothingLoss as
                ``normalize_length``.
            special_tokens: optional mapping of special-token strings to ids.
            apply_non_blank_embedding: feed the decoder only the encoder frames
                whose CTC top-1 token is not blank.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = (
            vocab_size - 1
            if special_tokens is None
            else special_tokens.get("<sos>", vocab_size - 1)
        )
        self.eos = (
            vocab_size - 1
            if special_tokens is None
            else special_tokens.get("<eos>", vocab_size - 1)
        )
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.apply_non_blank_embedding = apply_non_blank_embedding

        self.encoder = encoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

    @torch.jit.unused
    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss.

        Args:
            batch: training batch providing ``"feats"``, ``"feats_lengths"``,
                ``"target"`` and ``"target_lengths"`` tensors.
            device: device the batch tensors are moved to.

        Returns:
            Dict with ``"loss"``, ``"loss_att"``, ``"loss_ctc"`` and
            ``"th_accuracy"``; a branch that is disabled yields ``None``.
        """
        speech = batch["feats"].to(device)
        speech_lengths = batch["feats_lengths"].to(device)
        text = batch["target"].to(device)
        text_lengths = batch["target_lengths"].to(device)

        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. CTC branch
        if self.ctc_weight != 0.0 and self.ctc is not None:
            loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        else:
            loss_ctc, ctc_probs = None, None

        # 2b. Attention-decoder branch
        # use non blank (token level) embedding for decoder
        if self.apply_non_blank_embedding:
            assert self.ctc_weight != 0
            assert ctc_probs is not None
            # Pass the valid lengths so padding frames are never kept.
            encoder_out, encoder_mask = self.filter_blank_embedding(
                ctc_probs, encoder_out, encoder_out_lens
            )
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(
                encoder_out,
                encoder_mask,
                text,
                text_lengths,
            )
        else:
            loss_att = None
            acc_att = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "th_accuracy": acc_att,
        }

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Delegate weight tying/cloning to the decoder (decoder required)."""
        assert self.decoder is not None, "decoder is not initialized!"
        self.decoder.tie_or_clone_weights(jit_mode)

    @torch.jit.unused
    def _forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the CTC head; returns ``(loss_ctc, ctc_probs)``."""
        assert self.ctc is not None, "ctc is not initialized!"
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        return loss_ctc, ctc_probs

    def filter_blank_embedding(
        self,
        ctc_probs: torch.Tensor,
        encoder_out: torch.Tensor,
        encoder_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep only encoder frames whose CTC top-1 token is not blank.

        Args:
            ctc_probs: CTC scores, (B, T, vocab); only the argmax over the
                last dimension is used.
            encoder_out: encoder output, (B, T, D).
            encoder_lens: optional number of valid frames per utterance, (B,).
                When given, padding frames are never selected, even if their
                CTC argmax happens to be non-blank.
            blank_id: id of the CTC blank token.

        Returns:
            Tuple of the filtered encoder output (B, T', D), zero-padded to
            the longest kept sequence, and its boolean mask (B, 1, T').
        """
        batch_size = encoder_out.size(0)
        maxlen = encoder_out.size(1)
        keep = torch.argmax(ctc_probs, dim=2) != blank_id  # (B, T)
        if encoder_lens is not None:
            keep = keep & ~make_pad_mask(encoder_lens, maxlen).to(keep.device)
        keep = keep.to(encoder_out.device)

        # Boolean-mask indexing keeps frame order and yields a well-formed
        # (0, D) tensor when every frame is blank (an empty index list built
        # with torch.tensor([]) would be float and break index_select).
        select_encoder_out_list = [encoder_out[i][keep[i]] for i in range(batch_size)]
        select_encoder_out: torch.Tensor = pad_sequence(
            select_encoder_out_list, batch_first=True, padding_value=0
        ).to(encoder_out.device)
        xs_lens = keep.sum(dim=1)
        T = select_encoder_out.size(1)
        encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        return select_encoder_out, encoder_mask

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Optional[Dict[str, List[str]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the (optionally bidirectional) attention loss and accuracy.

        Returns:
            ``(loss_att, acc_att)``. The right-to-left loss only contributes
            when ``self.reverse_weight > 0``.
        """
        assert self.decoder is not None, "decoder is not initialized!"
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # reverse the seq, used for right to left decoder
        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, self.ignore_id)
        # 1. Forward decoder
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad, self.reverse_weight
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        # Zero placeholder on the same device/dtype as the forward loss.
        r_loss_att = loss_att.new_zeros(())
        if self.reverse_weight > 0.0:
            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
        loss_att = loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att

    def _forward_encoder(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        chunk_size: int = -1,
        left_context_size: int = -1,
        right_context_size: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder, chunk by chunk when simulating streaming.

        ``encoder.forward_chunk_by_chunk`` is used only when
        ``simulate_streaming`` is set and ``chunk_size > 0``; otherwise the
        encoder is called directly with the same chunk/context settings.

        Returns:
            ``(encoder_out, encoder_mask)`` as produced by the encoder.
        """
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                speech_lengths,
                chunk_size=chunk_size,
                left_context_size=left_context_size,
                right_context_size=right_context_size,
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                chunk_size=chunk_size,
                left_context_size=left_context_size,
                right_context_size=right_context_size,
            )
        return encoder_out, encoder_mask

    # The same interface just like whisper
    # see https://github.com/openai/whisper/blob/main/whisper/model.py#L287
    def embed_audio(
        self,
        mel: torch.Tensor,
        mel_len: torch.Tensor,
        chunk_size: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode features; returns ``(encoder_out, encoder_mask)``."""
        encoder_out, encoder_mask = self.encoder(mel, mel_len, chunk_size)
        return encoder_out, encoder_mask

    @torch.jit.unused
    def ctc_logprobs(
        self, encoder_out: torch.Tensor, blank_penalty: float = 0.0, blank_id: int = 0
    ):
        """Return CTC log-probabilities for ``encoder_out``.

        When ``blank_penalty > 0`` it is subtracted from the blank logit
        before the log-softmax, discouraging blank emissions.
        """
        assert self.ctc is not None, "ctc is not initialized!"
        if blank_penalty > 0.0:
            logits = self.ctc.ctc_lo(encoder_out)
            logits[:, :, blank_id] -= blank_penalty
            ctc_probs = logits.log_softmax(dim=2)
        else:
            ctc_probs = self.ctc.log_softmax(encoder_out)

        return ctc_probs

    def decode(
        self,
        methods: List[str],
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        chunk_size: int = -1,
        left_context_size: int = -1,
        right_context_size: int = -1,
        ctc_weight: float = 0.0,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        context_graph: Optional[ContextGraph] = None,
        blank_id: int = 0,
        blank_penalty: float = 0.0,
        length_penalty: float = 0.0,
        infos: Optional[Dict[str, List[str]]] = None,
    ) -> Dict[str, List[DecodeResult]]:
        """Decode input speech with one or more decoding methods.

        Args:
            methods (List[str]): decoding methods to run, any of
                   * attention
                   * ctc_greedy_search
                   * ctc_prefix_beam_search
                   * attention_rescoring
                   * rnnt_greedy_search
                For the first four, see https://arxiv.org/pdf/2102.01547.pdf
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            chunk_size (int): decoding chunk size for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            left_context_size (int): left context forwarded to the encoder
            right_context_size (int): right context forwarded to the encoder
            ctc_weight (float): ctc score weight used by attention_rescoring
            simulate_streaming (bool): run the encoder chunk by chunk
                (only effective when chunk_size > 0)
            reverse_weight (float): right to left decoder weight
            context_graph (ContextGraph): optional biasing graph for
                ctc_prefix_beam_search
            blank_id (int): CTC blank token id
            blank_penalty (float): subtracted from the blank logit when > 0
            length_penalty (float): forwarded to attention_beam_search
            infos: forwarded to attention_beam_search / attention_rescoring

        Returns: dict mapping each requested method name to its results
        """
        assert len(speech) == len(speech_lengths)
        assert chunk_size != 0
        encoder_out, encoder_mask = self._forward_encoder(
            speech,
            speech_lengths,
            chunk_size,
            left_context_size,
            right_context_size,
            simulate_streaming,
        )
        encoder_lens = encoder_mask.squeeze(1).sum(1)

        # CTC posteriors are only needed by CTC-based methods; computing them
        # lazily lets attention-only decoding run on models without a CTC head.
        ctc_probs: Optional[torch.Tensor] = None
        if any(
            method in methods
            for method in ("ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring")
        ):
            ctc_probs = self.ctc_logprobs(encoder_out, blank_penalty, blank_id)

        results: Dict[str, List[DecodeResult]] = {}
        if "attention" in methods:
            results["attention"] = attention_beam_search(
                self, encoder_out, encoder_mask, beam_size, length_penalty, infos
            )
        if "ctc_greedy_search" in methods:
            assert ctc_probs is not None
            results["ctc_greedy_search"] = ctc_greedy_search(ctc_probs, encoder_lens, blank_id)
        if "ctc_prefix_beam_search" in methods:
            assert ctc_probs is not None
            results["ctc_prefix_beam_search"] = ctc_prefix_beam_search(
                ctc_probs, encoder_lens, beam_size, context_graph, blank_id
            )
        if "attention_rescoring" in methods:
            assert ctc_probs is not None
            # attention_rescoring depends on ctc_prefix_beam_search nbest
            if "ctc_prefix_beam_search" in results:
                ctc_prefix_result = results["ctc_prefix_beam_search"]
            else:
                ctc_prefix_result = ctc_prefix_beam_search(
                    ctc_probs, encoder_lens, beam_size, context_graph, blank_id
                )
            # Use separate locals so later methods still see the unfiltered
            # encoder output, and so the lengths match the filtered frames.
            rescore_out, rescore_lens = encoder_out, encoder_lens
            if self.apply_non_blank_embedding:
                rescore_out, rescore_mask = self.filter_blank_embedding(
                    ctc_probs, encoder_out, encoder_lens, blank_id
                )
                rescore_lens = rescore_mask.squeeze(1).sum(1)
            results["attention_rescoring"] = attention_rescoring(
                self,
                ctc_prefix_result,
                rescore_out,
                rescore_lens,
                ctc_weight,
                reverse_weight,
                infos,
            )
        if "rnnt_greedy_search" in methods:
            results["rnnt_greedy_search"] = [
                DecodeResult(hyp) for hyp in greedy_search(self, encoder_out, encoder_lens)
            ]
        return results

    @torch.jit.export
    def subsampling_rate(self) -> int:
        """Export interface for c++ call, return subsampling_rate of the
        model
        """
        return int(self.encoder.embed.subsampling_rate)

    @torch.jit.export
    def right_context(self) -> int:
        """Export interface for c++ call, return right_context of the model"""
        return self.encoder.embed.right_context

    @torch.jit.export
    def sos_symbol(self) -> int:
        """Export interface for c++ call, return sos symbol id of the model"""
        return self.sos

    @torch.jit.export
    def eos_symbol(self) -> int:
        """Export interface for c++ call, return eos symbol id of the model"""
        return self.eos

    @torch.jit.export
    def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
        """Export interface for c++ call, apply linear transform and log
            softmax before ctc
        Args:
            xs (torch.Tensor): encoder output

        Returns:
            torch.Tensor: activation before ctc

        """
        assert self.ctc is not None, "ctc is not initialized!"
        return self.ctc.log_softmax(xs)

    @torch.jit.export
    def is_bidirectional_decoder(self) -> bool:
        """Export interface for c++ call.

        Returns:
            bool: True if the decoder has a ``right_decoder`` attribute
            (i.e. a right-to-left branch), False otherwise.
        """
        if hasattr(self.decoder, "right_decoder"):
            return True
        else:
            return False

    @torch.jit.export
    def forward_attention_decoder(
        self,
        hyps: torch.Tensor,
        hyps_lens: torch.Tensor,
        encoder_out: torch.Tensor,
        reverse_weight: float = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Export interface for c++ call, forward decoder with multiple
            hypotheses from ctc prefix beam search and one encoder output.
        Args:
            hyps (torch.Tensor): hyps from ctc prefix beam search, already
                padded with sos at the beginning
            hyps_lens (torch.Tensor): length of each hyp in hyps (including
                the leading sos)
            encoder_out (torch.Tensor): corresponding encoder output; the
                batch dimension must be 1
            reverse_weight: right-to-left decoder weight; the reversed
                hypotheses are built internally from ``hyps``, and the
                right-to-left decoder is used when > 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: log-softmax outputs of the
            left-to-right and right-to-left decoders
        """
        assert self.decoder is not None, "decoder is not initialized!"
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = torch.ones(
            num_hyps, 1, encoder_out.size(1), dtype=torch.bool, device=encoder_out.device
        )

        # input for right to left decoder
        # this hyps_lens has count <sos> token, we need minus it.
        r_hyps_lens = hyps_lens - 1
        # this hyps has included <sos> token, so it should be
        # convert the original hyps.
        r_hyps = hyps[:, 1:]
        #   >>> r_hyps
        #   >>> tensor([[ 1,  2,  3],
        #   >>>         [ 9,  8,  4],
        #   >>>         [ 2, -1, -1]])
        #   >>> r_hyps_lens
        #   >>> tensor([3, 3, 1])

        # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
        #   in `reverse_pad_list` thus we have to refine the below code.
        #   Issue: https://github.com/wenet-e2e/wenet/issues/1113
        # Equal to:
        #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
        #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
        max_len = torch.max(r_hyps_lens)
        index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
        seq_len_expand = r_hyps_lens.unsqueeze(1)
        seq_mask = seq_len_expand > index_range  # (beam, max_len)
        #   >>> seq_mask
        #   >>> tensor([[ True,  True,  True],
        #   >>>         [ True,  True,  True],
        #   >>>         [ True, False, False]])
        index = (seq_len_expand - 1) - index_range  # (beam, max_len)
        #   >>> index
        #   >>> tensor([[ 2,  1,  0],
        #   >>>         [ 2,  1,  0],
        #   >>>         [ 0, -1, -2]])
        index = index * seq_mask
        #   >>> index
        #   >>> tensor([[2, 1, 0],
        #   >>>         [2, 1, 0],
        #   >>>         [0, 0, 0]])
        r_hyps = torch.gather(r_hyps, 1, index)
        #   >>> r_hyps
        #   >>> tensor([[3, 2, 1],
        #   >>>         [4, 8, 9],
        #   >>>         [2, 2, 2]])
        r_hyps = torch.where(seq_mask, r_hyps, self.eos)
        #   >>> r_hyps
        #   >>> tensor([[3, 2, 1],
        #   >>>         [4, 8, 9],
        #   >>>         [2, eos, eos]])
        r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
        #   >>> r_hyps
        #   >>> tensor([[sos, 3, 2, 1],
        #   >>>         [sos, 4, 8, 9],
        #   >>>         [sos, 2, eos, eos]])

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight
        )  # (num_hyps, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)

        # right to left decoder may be not used during decoding process,
        # which depends on reverse_weight param.
        # r_decoder_out will be 0.0, if reverse_weight is 0.0
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        return decoder_out, r_decoder_out
