import ast
import json
from typing import Dict, List, Optional, Tuple, Type

import nltk
import numpy as np
import pandas as pd
import torch
import torch.optim
from torch.utils.data import DataLoader, Dataset, Subset, random_split


def tokenize(text: str, language: str = 'spanish'):
    """
    Split ``text`` into sentences, then each sentence into word tokens, and
    yield the tokens lazily, in order.

    ``language`` is forwarded to both NLTK tokenizers; it defaults to
    ``'spanish'``, the value this function always used before.

        >>> list(tokenize('I cat. I dog.'))
        ['I', 'cat', '.', 'I', 'dog', '.']
    """
    for sentence in nltk.sent_tokenize(text, language=language):
        yield from nltk.word_tokenize(sentence, language=language)


def transform_reports_to_idxs(reports) -> Tuple[pd.Series, Dict[str, int]]:
    """
    Build a vocabulary ``{word: index}`` and map every report to a list of
    word indices.

    Index 0 is reserved for the empty string ``''`` (the padding value used
    elsewhere in this module); every other word gets the next free index in
    order of first appearance.

    reports: pd.Series of report strings.  Missing entries (NaN / None)
        become empty index lists.
    Returns (Series of index lists aligned with ``reports.index``,
        vocabulary dict).
    """
    dictionary: Dict[str, int] = {'': 0}  # {word: index}; 0 is padding
    reports_transformed = []
    for report in reports:
        # pd.isna recognises every missing-value marker (np.nan, float('nan'),
        # None).  An identity test against np.nan only matches that single
        # object, so other NaNs would be passed to the tokenizer and crash.
        if pd.isna(report):
            idxs = []
        else:
            idxs = [
                dictionary.setdefault(word, len(dictionary))
                for word in tokenize(report)]
        reports_transformed.append(idxs)
    return pd.Series(reports_transformed, index=reports.index), dictionary


class PadChestReports(Dataset):
    """
    Dataset of free-text reports encoded as fixed-length LongTensors of word
    indices.

    The vocabulary is rebuilt from ``reports`` for every instance, so word
    indices are only meaningful relative to the instance that produced them.
    Index 0 (the empty string) is used as padding.
    """
    def __init__(self, reports: pd.Series):
        super().__init__()
        self.reports = reports
        self.reports_transformed, self.dictionary = \
            transform_reports_to_idxs(reports)
        self.reverse_dictionary = {v: k for k, v in self.dictionary.items()}
        assert len(self.reverse_dictionary) == len(self.dictionary)
        # default=0 keeps an empty series from raising ValueError in max()
        self.max_report_length = max(
            (len(tokens) for tokens in self.reports_transformed), default=0)

    def __len__(self):
        return len(self.reports)

    def __getitem__(self, report_idx):
        """Return report number ``report_idx`` (positional) as a padded tensor."""
        word_idxs = self.reports_transformed.iloc[report_idx]
        return self.random_pad_zeros(word_idxs)

    def random_pad_zeros(self, word_idxs):
        """
        Scatter ``word_idxs`` into a zero LongTensor of length
        ``max_report_length``, keeping their relative order.

        Padding at random locations (rather than only at the end) ensures
        gradients for positions at the end of long reports are not biased by
        being trained with fewer updates than the beginning of reports.  It
        also makes every report the same length.
        """
        output = torch.zeros(self.max_report_length, dtype=torch.long)
        # distinct random slots, sorted so the words keep their original order
        token_locations = np.random.choice(
            self.max_report_length, len(word_idxs), replace=False)
        token_locations.sort()
        output[token_locations] = torch.tensor(word_idxs, dtype=output.dtype)
        return output

    def reverse_transform(self, embedding):
        """
        Convert a sequence of word indices back into a space-joined report,
        dropping padding (the empty-string word).

        Accepts any iterable of integer-like values: a 1-D tensor, a numpy
        array or a plain list.
        """
        words = (self.reverse_dictionary[int(num)] for num in embedding)
        return ' '.join(word for word in words if word != '')


class PadChestLabels(Dataset):
    """
    Dataset of multi-label targets encoded as "many hot" LongTensors.

    Each entry of ``labels`` is a string holding a Python list literal of
    class names (e.g. ``"['a', 'b']"``); missing entries mean "no labels".
    Class indices are assigned in order of first appearance, so they are only
    meaningful relative to the instance that built them.
    """
    def __init__(self, labels: pd.Series):
        super().__init__()
        self.labels_unmodified = labels
        # The strings are Python list reprs, so parse them as Python literals.
        # Swapping ' for " and using json.loads breaks on any label that itself
        # contains a quote character (e.g. "it's").
        labels = labels.fillna('[]').apply(ast.literal_eval)\
            .apply(lambda names: [name.strip() for name in names])
        # give each distinct class name the next free index
        self._lookup_class_to_idx: Dict[str, int] = {}
        for row_names in labels:
            for name in row_names:
                self._lookup_class_to_idx.setdefault(
                    name, len(self._lookup_class_to_idx))
        self._reverse_lookup: Dict[int, str] = {
            v: k for k, v in self._lookup_class_to_idx.items()}
        assert len(self._lookup_class_to_idx) == len(self._reverse_lookup)
        self.num_classes = len(self._lookup_class_to_idx)
        # convert to "many hot" (replace each row of labels with binary vector)
        new_labels = []
        for row_names in labels:
            new_row = torch.zeros(self.num_classes, dtype=torch.long)
            # explicit long dtype so rows without labels index with an empty
            # (but valid) index tensor
            hot = torch.tensor(
                [self._lookup_class_to_idx[name] for name in row_names],
                dtype=torch.long)
            new_row[hot] = 1
            new_labels.append(new_row)
        self.labels = pd.Series(new_labels, index=labels.index)
        # header[i] is the class name of column i of every many-hot vector
        self.header = [self._reverse_lookup[i] for i in range(self.num_classes)]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, report_idx: int):
        """Return the many-hot vector of sample ``report_idx`` (positional)."""
        return self.labels.iloc[report_idx]

    def reverse_transform(self, labels: torch.Tensor) -> List[str]:
        """Convert from many hot vector to a list of "hot" class names, in
        column order."""
        hot_idxs = torch.arange(self.num_classes)[labels != 0]
        return [self._reverse_lookup[idx.item()] for idx in hot_idxs]

    def get_intra_class_balancing_weight(self):
        """
        For each binary class, find the positive class weight that balances the
        positive and negative samples.  For each class, the value computed is:
            max(num_pos, num_neg) / num_pos

        NOTE: this weight should be computed on the training (or validation)
        set, not the test set.

        Returns the "pos_weight" (float tensor of shape (1, num_classes)) to be
        passed to a BCE loss function so that positive class is given as much
        importance as the negative class.

        Example:
            BCE(pos_weight=X)  where X is output of this function
        """
        # Stack into an explicit (N, num_classes) matrix instead of relying on
        # pandas to sum an object Series of tensors.
        if len(self.labels):
            many_hot = torch.stack(list(self.labels))
        else:
            many_hot = torch.zeros((0, self.num_classes), dtype=torch.long)
        num_pos = many_hot.sum(0).reshape(1, -1)
        num_neg = many_hot.shape[0] - num_pos
        # num_pos >= 1 for every class: a class only exists in the lookup
        # because some row of these same labels contains it.
        return torch.maximum(num_pos, num_neg) / num_pos


class PadChestImages(Dataset):
    """
    Placeholder for an image dataset; not implemented yet.

    Constructing an instance always raises NotImplementedError.
    """

    def __init__(self, dir, image_ids):
        super().__init__()
        # TODO: presumably should load the images named by ``image_ids`` from
        # the directory ``dir``.
        message = 'todo future'
        raise NotImplementedError(message)


class Combined(Dataset):
    """
    Zip several equal-length datasets together.

    ``Combined(a, b)[i]`` is ``[a[i], b[i]]``.  With no datasets at all the
    result is an empty dataset.

    Raises ValueError if the datasets do not all have the same length.
    """
    def __init__(self, *dsets):
        super().__init__()
        lengths = {len(dset) for dset in dsets}
        # explicit raise rather than assert, so the check survives `python -O`
        if len(lengths) > 1:
            raise ValueError(
                f'all datasets must have the same length, got {sorted(lengths)}')
        self.dsets = dsets

    def __getitem__(self, idx):
        return [dset[idx] for dset in self.dsets]

    def __len__(self):
        # dsets may be empty (e.g. PadChest with every data kind disabled)
        return len(self.dsets[0]) if self.dsets else 0


class PadChest(Dataset):
    """https://b2drop.bsc.es/index.php/s/BIMCV-PadChest-FULL

    Select different kinds of data from the PadChest dataset:

        - Obtain labels, reports, and/or images of the PadChest dataset.
        - Define a `row_selector` to subset the available data to only patients
          of interest.  It is applied to the whole dataset, before splitting
          into train/val/test.

        - Select either the train/val/test sets (roughly 70% / 10% / 20% of
          the selected rows).  The split uses a fixed seed, so it is
          persistent as long as the row_selector is persistent.

    ``dset[i]`` is a list holding, in this order, whichever of the report,
    the labels and the image were requested.
    """
    _PARTITIONS = ('all', 'train', 'val', 'test')
    _DEFAULT_CSV_NAME = 'PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv.gz'

    def __init__(
            self, fp_image_dir: Optional[str] = None,
            fp_label_csv: Optional[str] = None,
            labels=True, reports=True, images=False,
            row_selector=lambda df: df,
            train_val_test: str = 'all'):
        """
        fp_image_dir: PadChest directory; when ``fp_label_csv`` is None the
            label CSV is looked up inside it.
        fp_label_csv: path of the label CSV (read with ``pd.read_csv``).
        labels, reports, images: which kinds of data each item contains.
        row_selector: DataFrame -> DataFrame filter applied before splitting.
        train_val_test: 'all', 'train', 'val' or 'test'.

        Raises ValueError if ``train_val_test`` is not one of those values, or
        if neither ``fp_label_csv`` nor ``fp_image_dir`` is given.
        """
        super().__init__()
        # Validate before any I/O.  An unknown value used to silently select
        # every row.
        if train_val_test not in self._PARTITIONS:
            raise ValueError(
                f'train_val_test must be one of {self._PARTITIONS}, '
                f'got {train_val_test!r}')
        if fp_label_csv is None:
            if fp_image_dir is None:
                raise ValueError('need fp_label_csv or fp_image_dir')
            fp_label_csv = f'{fp_image_dir}/{self._DEFAULT_CSV_NAME}'
        df = row_selector(pd.read_csv(fp_label_csv))

        self.partition = train_val_test
        df = self._select_partition(df, train_val_test)
        self.df = df

        dsets = []
        if reports:
            self.reports = PadChestReports(df['Report'])
            dsets.append(self.reports)
        if labels:
            self.labels = PadChestLabels(df['Labels'])
            dsets.append(self.labels)
        if images:
            self.images = PadChestImages(fp_image_dir, df['ImageID'])
            dsets.append(self.images)
        # TODO: localizations and labels_localizations are not supported yet.
        self.dset = Combined(*dsets)

    @staticmethod
    def _select_partition(df, train_val_test):
        """Return the rows of ``df`` belonging to the requested partition."""
        if train_val_test == 'all':
            return df
        n_rows = len(df)
        n_train = int(.7 * n_rows)
        n_val = int(.1 * n_rows)
        n_test = n_rows - n_train - n_val
        # Fixed seed: the same selected rows always give the same split.
        train, val, test = random_split(
            df.index, [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(420))
        chosen = {'train': train, 'val': val, 'test': test}[train_val_test]
        # Positional selection: identical to label lookup for a unique index,
        # and it does not duplicate rows when the index has repeated labels.
        return df.iloc[chosen.indices]

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, idx):
        # Subset passes tensor indices (see random_subsample_normal)
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
        return self.dset[idx]

    def random_subsample_normal(self, n=500):
        """
        Return a Subset keeping every sample whose labels do not mention
        'normal', plus about ``n`` of the samples that do.

        Each 'normal' sample is kept independently with probability
        n / num_normal, drawn from the global NumPy RNG, so ``n`` is only the
        expected number kept.
        """
        # na=False: rows with missing labels count as "not normal" instead of
        # putting NaN into the boolean mask.
        mask_all_normal_samples = self.df['Labels'].str.contains(
            'normal', na=False)
        num_normal = int(mask_all_normal_samples.sum())
        p = n / num_normal if num_normal else 0.0
        mask_some_normal_samples = (
            mask_all_normal_samples
            & (np.random.uniform(0, 1, mask_all_normal_samples.shape) <= p)
        )
        mask = (~mask_all_normal_samples) | mask_some_normal_samples
        return Subset(self, torch.arange(mask.shape[0])[mask.values])


def test_dset():
    """
    Smoke-test PadChest against the label CSV in the working directory.

    Needs ``PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv.gz`` in the
    current directory; only rows with MethodLabel == 'Physician' are used.
    """
    dset = PadChest(
        fp_label_csv='PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv.gz',
        labels=True, reports=True,
        row_selector=lambda df: df[df['MethodLabel'] == 'Physician'],
        train_val_test='all'
    )
    print('test that report embedding is valid for patient 0')
    report_embedding, _ = dset[0]
    assert dset.reports.reverse_transform(report_embedding) \
        == dset.df.iloc[0]['Report'].strip()
    print('test that labels are valid for patient 1')
    _, labels = dset[1]
    assert dset.df.iloc[1]['Labels'] == \
        json.dumps(dset.labels.reverse_transform(labels)).replace('"', "'")
    print('test random_subsample_normal')
    # The number of 'normal' rows kept is random (binomial around 500, std
    # dev. about 22), so seed the global NumPy RNG that
    # random_subsample_normal draws from to make the tolerance check
    # deterministic.
    np.random.seed(0)
    num_normal = int(
        dset.df['Labels'].str.contains('normal', na=False).sum())
    dset2 = dset.random_subsample_normal(500)
    # expected size: every non-normal row plus ~500 of the normal ones
    expected = len(dset) - num_normal + 500
    assert abs(len(dset2) - expected) < 50


if __name__ == '__main__':
    # Running this file directly executes the smoke test test_dset, which
    # reads the PadChest label CSV from the working directory.
    test_dset()
