# Copyright (c) 2021 Wenet Community. (authors: Binbin Zhang)
#               2023 Wenet Community. (authors: Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import logging
import os
import random
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import librosa
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

from chunkformer.text.base_tokenizer import BaseTokenizer

# Configure the buffer size used by torchaudio's sox utilities.
torchaudio.utils.sox_utils.set_buffer_size(16500)

# File extensions that are treated as audio files.
AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}


# Only let records of level INFO and above through the "langid" logger.
logging.getLogger("langid").setLevel(logging.INFO)

# Best-effort CPU probe executed at import time: on HiSilicon (Kunpeng) CPUs the
# number of torch threads is limited to one. Any failure is only logged.
try:
    cpu_info = os.popen("lscpu | grep 'Vendor ID'").read()
    # 0x48 --> HiSilicon
    if cpu_info.rstrip().split(" ")[-1] == "0x48":
        # NOTE (MengqingCao): set number of threads in the subprocesses to 1
        # Why? There may be some operators utilizing multi-threads in processor,
        # causing possibly deadlock in Kunpeng.
        # Similar issue in PyTorch: https://github.com/pytorch/pytorch/issues/45198
        torch.set_num_threads(1)
except Exception as ex:
    # Adjacent string literals are used instead of backslash continuations
    # inside one literal, which would embed the source indentation in the
    # emitted message.
    logging.warning(
        "Failed to set number of thread in Kunpeng, "
        "this may cause segmentfault while dataloading, "
        "ignore this warning if you are not using Kunpeng: %s",
        ex,
    )


class UrlOpenError(Exception):
    """Error raised when a data URL or local path cannot be opened.

    Args:
        msg: human readable description; this is what ``str(error)`` returns.
        *args: additional positional values stored on the exception.
    """

    def __init__(self, msg: str, *args: object) -> None:
        # Forward ``msg`` to ``Exception`` so it becomes part of ``self.args``.
        # pickle and copy rebuild exceptions as ``cls(*self.args)``; with an
        # empty ``args`` that call fails because ``msg`` is a required argument.
        super().__init__(msg, *args)
        self.err_msg = msg

    def __str__(self) -> str:
        return self.err_msg


def parse_json(elem):
    """Decode the JSON text in ``elem["line"]`` and tag it with its source file.

    Args:
        elem: mapping with a ``line`` entry (JSON text) and a ``file_name``
            entry.

    Returns:
        A new dict holding the decoded fields plus ``file_name``.
    """
    record = json.loads(elem["line"])
    record["file_name"] = elem["file_name"]
    return dict(record)


def parse_url(elem):
    """Open the resource named by ``elem["line"]`` and attach it as a stream.

    Local paths (no scheme) and ``file://`` URLs are opened directly in binary
    mode. Every other scheme is fetched by a ``wget`` child process whose
    stdout becomes the stream.

    Args:
        elem: dict with ``file_name`` and ``line`` entries; ``line`` is the
            path or URL to open.

    Returns:
        The same dict, updated with ``stream`` (binary file object) and, for
        remote resources, ``process`` (the ``Popen`` handle of ``wget``).

    Raises:
        UrlOpenError: if the file cannot be opened or ``wget`` cannot be
            started; the original exception is chained as the cause.
    """
    from urllib.parse import unquote

    # Validate the container type before probing its keys.
    assert isinstance(elem, dict)
    assert "file_name" in elem
    assert "line" in elem
    url = elem["line"]
    try:
        pr = urlparse(url)
        if pr.scheme == "" or pr.scheme == "file":
            # local file: ``open`` does not understand ``file://`` URLs, so for
            # that scheme use the (percent-decoded) path component instead of
            # the raw URL.
            local_path = unquote(pr.path) if pr.scheme == "file" else url
            stream = open(local_path, "rb")
        else:
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            # An argument list (no shell) keeps URLs containing characters such
            # as ``&``, ``;`` or spaces intact and prevents them from being
            # interpreted as shell syntax.
            process = Popen(["wget", "-q", "-O", "-", url], stdout=PIPE)
            elem.update(process=process)
            stream = process.stdout
        elem.update(stream=stream)
        return elem
    except Exception as ex:
        err_msg = "Failed to open {}".format(url)
        raise UrlOpenError(err_msg) from ex


def parse_speaker(sample, speaker_dict):
    """Replace the speaker entry of a sample by its id from ``speaker_dict``.
    Inplace operation.

    Args:
        sample: dict with a ``speaker`` entry
        speaker_dict: mapping from speaker entry to id

    Returns:
        The same sample; speakers missing from ``speaker_dict`` are mapped to 0.
    """
    assert "speaker" in sample
    sample["speaker"] = speaker_dict.get(sample["speaker"], 0)
    return sample


def decode_wav(sample):
    """Decode the audio stored under ``sample["wav"]`` into a waveform.
    Inplace operation.

    ``sample["wav"]`` is either a string, which is opened as a file path and
    read fully, or a bytes-like object already holding the encoded audio.
    When the sample has a ``start`` entry it must also have ``end``; both are
    multiplied by the sample rate to select the frames to load.

    Args:
        sample: {key, wav, [start, end], ...}

    Returns:
        {key, wav, sample_rate, ...} where wav is the loaded waveform
    """
    assert "key" in sample
    assert "wav" in sample
    audio = sample["wav"]
    if isinstance(audio, str):
        with open(audio, "rb") as fin:
            audio = fin.read()

    has_segment = "start" in sample
    if has_segment:
        assert "end" in sample

    with io.BytesIO(audio) as buf:
        if has_segment:
            sample_rate = torchaudio.info(buf).sample_rate
            first_frame = int(sample["start"] * sample_rate)
            last_frame = int(sample["end"] * sample_rate)
            # ``info`` consumed part of the buffer; rewind before loading.
            buf.seek(0)
            waveform, _ = torchaudio.load(
                buf, num_frames=last_frame - first_frame, frame_offset=first_frame
            )
        else:
            waveform, sample_rate = torchaudio.load(buf)

    # Drop the encoded audio, then store the decoded waveform under the same key.
    del sample["wav"]
    sample["wav"] = waveform
    sample["sample_rate"] = sample_rate
    return sample


def singal_channel(sample, channel=0):
    """Keep a single channel of the waveform.
    Inplace operation.

    Args:
        sample: {key, wav, label, sample_rate}
        channel: index of the channel to keep, must be smaller than the
            number of channels (first dimension of wav)

    Returns:
        {key, wav, label, sample_rate}; a waveform that already has exactly
        one channel is left untouched.
    """
    assert "wav" in sample
    wav = sample["wav"]
    num_channels = wav.size(0)
    assert channel < num_channels
    # Re-add the leading dimension that indexing a single channel removes.
    sample["wav"] = wav if num_channels == 1 else wav[channel, :].unsqueeze(0)
    return sample


def resample(sample, resample_rate=16000):
    """Convert the waveform to ``resample_rate``.
    Inplace operation.

    Args:
        sample: {key, wav, label, sample_rate}
        resample_rate: target sample rate

    Returns:
        {key, wav, label, sample_rate}; returned unchanged when the sample is
        already at ``resample_rate``.
    """
    assert "sample_rate" in sample
    assert "wav" in sample
    source_rate = sample["sample_rate"]
    audio = sample["wav"]
    if source_rate == resample_rate:
        return sample

    sample["sample_rate"] = resample_rate
    resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=resample_rate)
    sample["wav"] = resampler(audio)
    return sample


def speed_perturb(sample, speeds=None):
    """Randomly change the speed of the waveform.
    Inplace operation.

    Args:
        sample: {key, wav, label, sample_rate}
        speeds(List[float]): candidate speed factors, one is drawn at random;
            defaults to [0.9, 1.0, 1.1]

    Returns:
        {key, wav, label, sample_rate}; the waveform is untouched when the
        drawn factor is 1.0.
    """
    speeds = [0.9, 1.0, 1.1] if speeds is None else speeds
    assert "sample_rate" in sample
    assert "wav" in sample
    rate = sample["sample_rate"]
    audio = sample["wav"]
    factor = random.choice(speeds)
    if factor == 1.0:
        return sample

    # "speed" is followed by a "rate" effect set to the sample's own rate.
    effects = [["speed", str(factor)], ["rate", str(rate)]]
    perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(audio, rate, effects)
    sample["wav"] = perturbed
    return sample


def compute_fbank(
    sample, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0, window_type="povey"
):
    """Compute kaldi-style fbank features and store them under ``feat``.

    Args:
        sample: {key, wav, sample_rate, ...}
        num_mel_bins: forwarded to ``kaldi.fbank``
        frame_length: forwarded to ``kaldi.fbank``
        frame_shift: forwarded to ``kaldi.fbank``
        dither: forwarded to ``kaldi.fbank``
        window_type: forwarded to ``kaldi.fbank``

    Returns:
        {key, feat, wav, sample_rate, ...}
    """
    for required in ("sample_rate", "wav", "key"):
        assert required in sample

    fbank_options = dict(
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
        sample_frequency=sample["sample_rate"],
        window_type=window_type,
    )
    # Scale the waveform by 2**15 before feature extraction.
    scaled = sample["wav"] * (1 << 15)
    sample["feat"] = kaldi.fbank(scaled, **fbank_options)
    return sample


def compute_w2vbert_fbank(sample, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0):
    """Extract Pretrain w2vbert(4.5M hours) fbank

    The fbank features from ``compute_fbank`` are normalized along dim 0 by
    subtracting the mean and dividing by the standard deviation.

    Args:
        sample: {key, wav, sample_rate, ...}

    Returns:
        {key, feat, wav, sample_rate, ...}
    """
    sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift, dither)
    feat = sample["feat"]
    std, mean = torch.std_mean(feat, dim=0)
    sample["feat"] = (feat - mean) / std
    return sample


def sort_by_feats(sample):
    """Sort key: size of the first dimension of ``sample["feat"]``."""
    assert "feat" in sample
    feat = sample["feat"]
    assert isinstance(feat, torch.Tensor)
    return feat.size(0)


def feats_length_fn(sample: dict) -> int:
    """Return the size of the first dimension of ``sample["feat"]``."""
    assert "feat" in sample
    return sample["feat"].size(0)


def compute_mfcc(
    sample,
    num_mel_bins=23,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
    num_ceps=40,
    high_freq=0.0,
    low_freq=20.0,
):
    """Compute kaldi-style mfcc features and store them under ``feat``.

    Args:
        sample: {key, wav, sample_rate, ...}
        num_mel_bins: forwarded to ``kaldi.mfcc``
        frame_length: forwarded to ``kaldi.mfcc``
        frame_shift: forwarded to ``kaldi.mfcc``
        dither: forwarded to ``kaldi.mfcc``
        num_ceps: forwarded to ``kaldi.mfcc``
        high_freq: forwarded to ``kaldi.mfcc``
        low_freq: forwarded to ``kaldi.mfcc``

    Returns:
        {key, wav, feat, sample_rate, ...}
    """
    assert "wav" in sample
    assert "key" in sample
    mfcc_options = dict(
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        num_ceps=num_ceps,
        high_freq=high_freq,
        low_freq=low_freq,
        sample_frequency=sample["sample_rate"],
    )
    # Scale the waveform by 2**15 before feature extraction.
    scaled = sample["wav"] * (1 << 15)
    sample["feat"] = kaldi.mfcc(scaled, **mfcc_options)
    return sample


def compute_log_mel_spectrogram(
    sample,
    n_fft=400,
    hop_length=160,
    num_mel_bins=80,
    padding=0,
    pad_or_trim: bool = False,
    max_duration: int = 30,
):
    """Extract log mel spectrogram, modified from openai-whisper, see:
    - https://github.com/openai/whisper/blob/main/whisper/audio.py
    - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040

    Args:
        sample: {key, wav, sample_rate, ...}
        n_fft: FFT size, also the length of the Hann window
        hop_length: hop between STFT frames
        num_mel_bins: number of mel filters
        padding: number of zeros appended to the waveform when positive
        pad_or_trim: when True, zero-pad or cut the waveform to exactly
            ``max_duration * sample_rate`` samples (original whisper style)
        max_duration: only used when pad_or_trim is True

    Returns:
        {key, feat, wav, sample_rate, ...}
    """
    assert "sample_rate" in sample
    assert "wav" in sample
    assert "key" in sample
    rate = sample["sample_rate"]
    audio = sample["wav"].squeeze(0)  # (channel=1, sample) -> (sample,)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    if pad_or_trim:
        target_len = max_duration * rate
        missing = target_len - audio.size(0)
        if missing > 0:
            audio = F.pad(audio, (0, missing))
        else:
            audio = audio[:target_len]

    spectrum = torch.stft(
        audio, n_fft, hop_length, window=torch.hann_window(n_fft), return_complex=True
    )
    # Power spectrum with the last STFT frame dropped.
    power = spectrum[..., :-1].abs() ** 2

    mel_basis = torch.from_numpy(librosa.filters.mel(sr=rate, n_fft=n_fft, n_mels=num_mel_bins))
    mel = mel_basis @ power

    # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
    log_mel = torch.clamp(mel, min=1e-10).log10()
    # Clip everything more than 8.0 below the maximum, then rescale.
    floor = log_mel.max() - 8.0
    log_mel = (torch.maximum(log_mel, floor) + 4.0) / 4.0
    sample["feat"] = log_mel.transpose(0, 1)
    return sample


def tokenize(sample, tokenizer: BaseTokenizer):
    """Tokenize the transcript of a sample.
    Inplace operation

    Args:
        sample: {key, wav, txt, sample_rate, ...}
        tokenizer: object whose ``tokenize(text)`` returns a
            ``(tokens, label)`` pair

    Returns:
        {key, wav, txt, tokens, label, sample_rate, ...}
    """
    assert "txt" in sample
    sample["tokens"], sample["label"] = tokenizer.tokenize(sample["txt"])
    return sample


def filter(
    sample,
    max_length=10240,
    min_length=10,
    token_max_length=200,
    token_min_length=1,
    min_output_input_ratio=0.0005,
    max_output_input_ratio=1,
):
    """Decide whether a sample is kept, based on audio and label length.

    The audio length is measured in frames, assuming 100 frames per second
    of waveform.

    Args:
        sample: {key, wav, label, sample_rate, ...}
        max_length: drop utterance which is greater than max_length(10ms)
        min_length: drop utterance which is less than min_length(10ms)
        token_max_length: drop utterance whose label is longer than
            token_max_length, especially when use char unit for
            english modeling
        token_min_length: drop utterance whose label is shorter than
            token_min_length
        min_output_input_ratio: minimal ratio of
            token_length / feats_length(10ms)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length(10ms)

    Returns:
        bool: True to keep, False to filter
    """
    assert "sample_rate" in sample
    assert "wav" in sample
    # sample['wav'] is torch.Tensor, we have 100 frames every second
    frames = sample["wav"].size(1) / sample["sample_rate"] * 100
    if frames < min_length or frames > max_length:
        return False

    # Samples without a label are only checked for audio length.
    if "label" not in sample:
        return True

    num_tokens = len(sample["label"])
    if num_tokens < token_min_length or num_tokens > token_max_length:
        return False

    # The ratio is only defined for a non-zero number of frames.
    if frames != 0:
        ratio = num_tokens / frames
        if ratio < min_output_input_ratio:
            return False
        if ratio > max_output_input_ratio:
            return False
    return True


def spec_aug(sample, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """Zero out random time and frequency bands of the features.
    The masks are applied to a copy of ``feat``, which replaces the entry
    in the sample.

    Args:
        sample: {key, feat, ...}
        num_t_mask: number of time mask to apply
        num_f_mask: number of freq mask to apply
        max_t: max width of time mask
        max_f: max width of freq mask
        max_w: max width of time warp (not used by this implementation)

    Returns
        {key, feat, ....}
    """
    assert "feat" in sample
    feat = sample["feat"]
    assert isinstance(feat, torch.Tensor)
    masked = feat.clone().detach()
    num_frames, num_bins = masked.size(0), masked.size(1)

    def _random_span(extent, max_width):
        # Draw the start first and the width second, clip the end to extent.
        begin = random.randint(0, extent - 1)
        width = random.randint(1, max_width)
        return begin, min(extent, begin + width)

    # time mask
    for _ in range(num_t_mask):
        lo, hi = _random_span(num_frames, max_t)
        masked[lo:hi, :] = 0
    # freq mask
    for _ in range(num_f_mask):
        lo, hi = _random_span(num_bins, max_f)
        masked[:, lo:hi] = 0

    sample["feat"] = masked
    return sample


def spec_sub(sample, max_t=20, num_t_sub=3):
    """Replace random time spans by an earlier span of the same features.
    The substitution is applied to a copy of ``feat``, which replaces the
    entry in the sample.
    ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

    Args:
        sample: {key, feat, ...}
        max_t: max width of time substitute
        num_t_sub: number of time substitute to apply

    Returns
        {key, feat, ...}
    """
    assert "feat" in sample
    source = sample["feat"]
    assert isinstance(source, torch.Tensor)
    result = source.clone().detach()
    num_frames = result.size(0)
    for _ in range(num_t_sub):
        begin = random.randint(0, num_frames - 1)
        width = random.randint(1, max_t)
        stop = min(num_frames, begin + width)
        # only substitute the earlier time chosen randomly for current time
        shift = random.randint(0, begin)
        # The replacement is always read from the unmodified features.
        result[begin:stop, :] = source[begin - shift : stop - shift, :]
    sample["feat"] = result
    return sample


def spec_trim(sample, max_t=20):
    """Randomly cut frames off the end of the features.
    ref: TrimTail [https://arxiv.org/abs/2211.00522]

    A trim length is drawn from [1, max_t]; the features are only trimmed
    when that length is smaller than half the number of frames.

    Args:
        sample: {key, feat, label}
        max_t: max width of length trimming

    Returns:
        {key, feat, label}
    """
    assert "feat" in sample
    feat = sample["feat"]
    assert isinstance(feat, torch.Tensor)
    num_frames = feat.size(0)
    cut = random.randint(1, max_t)
    if cut >= num_frames / 2:
        return sample
    sample["feat"] = feat.clone().detach()[: num_frames - cut]
    return sample


def padding(data, pad_feat=True):
    """Padding the data into training data

    Samples are ordered by number of feature frames, longest first.

    Automatically detects and supports both ASR and classification tasks,
    based on the first sample of the batch.
    - ASR data: has "label" key (token IDs)
    - Classification data: has "{task}_label" keys (e.g., "gender_label", "emotion_label")

    Args:
        data: List[{key, feat, label, wav, ...}] for ASR or
              List[{key, feat, {task}_label, ...}] for classification
        pad_feat: Whether to pad features; when False, "feats" is the list
            of sorted, unpadded feature tensors

    Returns:
        Batched dictionary.
        ASR: keys, feats, target (padded with -1), feats_lengths,
        target_lengths, pcm, pcm_length and, if the first sample has it,
        speaker.
        Classification: keys, feats, feats_lengths and one int64 tensor per
        "{task}_label" key.
    """
    assert isinstance(data, list)
    assert len(data) > 0, "Empty batch"

    # Sort the batch by number of frames, longest first.
    feats_lengths = torch.tensor([x["feat"].size(0) for x in data], dtype=torch.int32)
    order = torch.argsort(feats_lengths, descending=True)
    feats_lengths = feats_lengths[order]
    samples = [data[i] for i in order.tolist()]

    sorted_keys = [s["key"] for s in samples]
    sorted_feats = [s["feat"] for s in samples]
    # Only build the padded tensor when it is actually returned.
    if pad_feat:
        feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0)
    else:
        feats = sorted_feats

    # Detect data type: ASR has "label" key, classification has "*_label" keys
    first = data[0]
    if "label" in first:
        # ASR-specific batching
        labels = [torch.tensor(s["label"], dtype=torch.int64) for s in samples]
        wavs = [s["wav"].squeeze(0) for s in samples]
        batch = {
            "keys": sorted_keys,
            "feats": feats,
            "target": pad_sequence(labels, batch_first=True, padding_value=-1),
            "feats_lengths": feats_lengths,
            "target_lengths": torch.tensor([x.size(0) for x in labels], dtype=torch.int32),
            "pcm": pad_sequence(wavs, batch_first=True, padding_value=0),
            "pcm_length": torch.tensor([x.size(0) for x in wavs], dtype=torch.int32),
        }
        if "speaker" in first:
            batch["speaker"] = torch.tensor([s["speaker"] for s in samples], dtype=torch.int32)
    else:
        # Classification-specific batching
        batch = {
            "keys": sorted_keys,
            "feats": feats,
            "feats_lengths": feats_lengths,
        }
        # Add one tensor per classification label key of the first sample
        # (e.g., gender_label, emotion_label, etc.)
        for label_key in [k for k in first if k.endswith("_label")]:
            batch[label_key] = torch.tensor([s[label_key] for s in samples], dtype=torch.int64)
    return batch


class DynamicBatchWindow:
    """Stateful predicate telling when a dynamic batch is full.

    The instance remembers the longest sample seen since the last time it
    reported a full batch.
    """

    def __init__(self, max_frames_in_batch=12000):
        self.max_frames_in_batch = max_frames_in_batch
        self.longest_frames = 0

    def __call__(self, sample, buffer_size):
        """Check whether adding ``sample`` exceeds the frame budget.

        Args:
            sample: dict with a ``feat`` tensor
            buffer_size: number of samples already buffered

        Returns:
            True when ``longest_frames * (buffer_size + 1)`` exceeds
            ``max_frames_in_batch``; the longest length is then reset to the
            length of ``sample``. False otherwise.
        """
        assert isinstance(sample, dict)
        assert "feat" in sample
        feat = sample["feat"]
        assert isinstance(feat, torch.Tensor)
        num_frames = feat.size(0)
        if num_frames > self.longest_frames:
            self.longest_frames = num_frames
        # Size of the batch if every sample were padded to the longest one.
        padded_total = self.longest_frames * (buffer_size + 1)
        if padded_total <= self.max_frames_in_batch:
            return False
        self.longest_frames = num_frames
        return True


def parse_classification_labels(sample: dict, tasks: list) -> dict:
    """Convert the ``{task}_label`` entries of a sample to ``int``.
    Inplace operation.

    Args:
        sample: Dictionary containing label information
        tasks: List of task names to parse (e.g., ['gender', 'emotion', 'region'])

    Returns:
        The same sample with the requested labels converted to int

    Raises:
        KeyError: if the label of one of the tasks is missing from the sample

    Example:
        Input sample: {'key': 'utt1', 'gender_label': '0', 'emotion_label': '3'}
        tasks: ['gender', 'emotion']
        Output: {'key': 'utt1', 'gender_label': 0, 'emotion_label': 3}
    """
    for task in tasks:
        label_key = f"{task}_label"
        if label_key not in sample:
            raise KeyError(f"Label {label_key} not found in sample {sample}")
        sample[label_key] = int(sample[label_key])
    return sample
