"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.

VCR dataset
"""
import copy
import json
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import concat
from .data import (
    DetectFeatTxtTokDataset,
    TxtTokLmdb,
    DetectFeatLmdb,
    TxtLmdb,
    get_ids_and_lens,
    pad_tensors,
    ImageLmdbGroup,
    get_gather_index,
    ConcatDatasetWithLens,
)
from imix.data.builder import DATASETS
import logging
from imix.data.utils.stream import ItemFeature

logger = logging.getLogger(__name__)


class VcrTxtTokLmdb(TxtTokLmdb):

    def __init__(self, db_dir, max_txt_len=120, task='qa,qar'):
        assert task == 'qa' or task == 'qar' or task == 'qa,qar', \
            "VCR only support the following tasks: 'qa', 'qar' or 'qa,qar'"
        self.task = task
        if task == 'qa,qar':
            id2len_task = 'qar'
        else:
            id2len_task = task
        if max_txt_len == -1:
            self.id2len = json.load(open(f'{db_dir}/id2len_{id2len_task}.json'))
        else:
            self.id2len = {
                id_: len_
                for id_, len_ in json.load(open(f'{db_dir}/id2len_{id2len_task}.json')).items() if len_ <= max_txt_len
            }

        self.db_dir = db_dir
        self.db = TxtLmdb(db_dir, readonly=True)
        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
        self.cls_ = meta['CLS']
        self.sep = meta['SEP']
        self.mask = meta['MASK']
        self.v_range = meta['v_range']


class VcrDetectFeatTxtTokDataset(DetectFeatTxtTokDataset):

    def __init__(self, txt_db, img_db_gt=None, img_db=None):
        assert not (img_db_gt is None and img_db is None), \
            'img_db_gt and img_db cannot all be None'
        assert isinstance(txt_db, VcrTxtTokLmdb)
        assert img_db_gt is None or isinstance(img_db_gt, DetectFeatLmdb)
        assert img_db is None or isinstance(img_db, DetectFeatLmdb)
        self.txt_db = txt_db
        self.img_db = img_db
        self.img_db_gt = img_db_gt
        self.task = self.txt_db.task
        txt_lens, self.ids = get_ids_and_lens(txt_db)

        txt2img = txt_db.txt2img

        if self.img_db and self.img_db_gt:
            self.lens = [
                tl + self.img_db_gt.name2nbb[txt2img[id_][0]] + self.img_db.name2nbb[txt2img[id_][1]]
                for tl, id_ in zip(txt_lens, self.ids)
            ]
        elif self.img_db:
            self.lens = [tl + self.img_db.name2nbb[txt2img[id_][1]] for tl, id_ in zip(txt_lens, self.ids)]
        else:
            self.lens = [tl + self.img_db_gt.name2nbb[txt2img[id_][0]] for tl, id_ in zip(txt_lens, self.ids)]

    def _get_img_feat(self, fname_gt, fname):
        if self.img_db and self.img_db_gt:
            img_feat_gt, bb_gt = self.img_db_gt[fname_gt]
            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5] * bb_gt[:, 5:]], dim=-1)

            img_feat, bb = self.img_db[fname]
            img_bb = torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1)

            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
            num_bb = img_feat.size(0)
        elif self.img_db:
            img_feat, bb = self.img_db[fname]
            img_bb = torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        elif self.img_db_gt:
            img_feat, bb = self.img_db_gt[fname_gt]
            img_bb = torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb


def load_img_feat(db_list, all_img_dbs, opts):
    db_ = db_list.split(';')
    assert len(db_) <= 2, 'More than two img_dbs found'
    gt_db_path, db_path = '', ''
    for d in db_:
        if 'gt' in d:
            gt_db_path = d
        else:
            db_path = d
    if gt_db_path != '':
        img_db_gt = DetectFeatLmdb(gt_db_path, -1, opts.max_bb, opts.min_bb, 100, opts.compressed_db)
        all_img_dbs.path2imgdb[gt_db_path] = img_db_gt
    else:
        img_db_gt = None
    img_db = all_img_dbs[db_path] if db_path != '' else None
    all_img_dbs.path2imgdb[db_path] = img_db
    return img_db, img_db_gt


@DATASETS.register_module()
class UNITER_VcrDataset(VcrDetectFeatTxtTokDataset):

    def __init__(self, **kwargs):
        cls_name = self.__class__.__name__
        logger.info('start loading {}'.format(cls_name))

        # load DBs and image dirs
        opts = ItemFeature(kwargs['datacfg'])
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, opts.compressed_db)
        train_or_val = kwargs['train_or_val']
        assert train_or_val is not None
        if train_or_val:  # train
            train_datasets = []
            for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
                img_db, img_db_gt = load_img_feat(img_path, all_img_dbs, opts)
                qa_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task='qa')
                qar_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task='qar')
                train_datasets.append(VcrTrainDataset(qa_txt_db, img_db_gt=img_db_gt, img_db=img_db))
                train_datasets.append(VcrTrainDataset(qar_txt_db, img_db_gt=img_db_gt, img_db=img_db))

                self.dataset = ConcatDatasetWithLens(train_datasets)

        else:

            val_img_db, val_img_db_gt = load_img_feat(opts.val_img_db, all_img_dbs, opts)
            val_txt_db = VcrTxtTokLmdb(opts.val_txt_db, -1)
            self.dataset = VcrEvalDataset('val', val_txt_db, img_db=val_img_db, img_db_gt=val_img_db_gt)

        self.collate_fn = vcr_collate if train_or_val else vcr_eval_collate
        logger.info('load {} successfully'.format(cls_name))
        logger.info('Num examples = %d', len(self.dataset))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        return self.dataset[item]


class VcrTrainDataset(VcrDetectFeatTxtTokDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.task != 'qa,qar', \
            'loading training dataset with each task separately'
        self.collate_fn = vcr_collate

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0] * len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        if self.task == 'qar':
            input_ids_rs = txt_dump['input_ids_rs']
            answer_label = txt_dump['qa_target']
            assert answer_label >= 0, 'answer_label < 0'
            input_ids_gt_a = [self.txt_db.sep] + copy.deepcopy(input_ids_as[answer_label])
            type_ids_gt_a = [2] * len(input_ids_gt_a)
            type_ids_q += type_ids_gt_a
            input_ids_q += input_ids_gt_a
            input_ids_for_choices = input_ids_rs
        else:
            input_ids_for_choices = input_ids_as
        return input_ids_q, input_ids_for_choices, type_ids_q

    def __getitem__(self, i):
        """[[txt, img1],

        [txt, img2]]
        """
        example = super().__getitem__(i)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'][0], example['img_fname'][1])
        input_ids_q, input_ids_for_choices, type_ids_q = self._get_input_ids(example)
        label = example['%s_target' % (self.task)]

        outs = []
        for index, input_ids_a in enumerate(input_ids_for_choices):
            if index == label:
                target = torch.tensor([1]).long()
            else:
                target = torch.tensor([0]).long()
            input_ids = [self.txt_db.cls_] + copy.deepcopy(input_ids_q) + \
                        [self.txt_db.sep] + input_ids_a + [self.txt_db.sep]
            # type_id
            # 0 -- question
            # 1 -- region
            # 2 -- answer
            # 3 -- rationale
            type_id_for_choice = 3 if type_ids_q[-1] == 2 else 2
            txt_type_ids = [0] + type_ids_q + [type_id_for_choice] * (len(input_ids_a) + 2)
            attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
            input_ids = torch.tensor(input_ids)
            txt_type_ids = torch.tensor(txt_type_ids)

            outs.append((input_ids, txt_type_ids, img_feat, img_pos_feat, attn_masks, target))

        return tuple(outs)


def vcr_collate(inputs):
    (input_ids, txt_type_ids, img_feats, img_pos_feats, attn_masks, targets) = map(list, unzip(concat(inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    txt_type_ids = pad_sequence(txt_type_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.stack(targets, dim=0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {
        'input_ids': input_ids,
        'txt_type_ids': txt_type_ids,
        'position_ids': position_ids,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'attn_masks': attn_masks,
        'gather_index': gather_index,
        'targets': targets
    }
    return batch


class VcrEvalDataset(VcrDetectFeatTxtTokDataset):

    def __init__(self, split, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.split = split
        assert self.task == 'qa,qar', \
            'loading evaluation dataset with two tasks together'

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_for_choices = []
        type_ids_for_choices = []
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0] * len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        input_ids_rs = txt_dump['input_ids_rs']
        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = \
                [self.txt_db.cls_] + copy.deepcopy(input_ids_q) + [self.txt_db.sep] + input_ids_a + [self.txt_db.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2] * (len(input_ids_a) + 2)
            input_ids_for_choices.append(curr_input_ids_qa)
            type_ids_for_choices.append(curr_type_ids_qa)

        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = \
                [self.txt_db.cls_] + copy.deepcopy(input_ids_q) + [self.txt_db.sep] + input_ids_a + [self.txt_db.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2] * (len(input_ids_a) + 1)
            if (self.split == 'val' and index == txt_dump['qa_target']) or self.split == 'test':
                for input_ids_r in input_ids_rs:
                    curr_input_ids_qar = copy.deepcopy(curr_input_ids_qa) + input_ids_r + [self.txt_db.sep]
                    curr_type_ids_qar = copy.deepcopy(curr_type_ids_qa) + [3] * (len(input_ids_r) + 2)
                    input_ids_for_choices.append(curr_input_ids_qar)
                    type_ids_for_choices.append(curr_type_ids_qar)

        return input_ids_for_choices, type_ids_for_choices

    def __getitem__(self, i):
        qid = self.ids[i]
        example = super().__getitem__(i)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'][0], example['img_fname'][1])

        input_ids_for_choices, type_ids_for_choices = self._get_input_ids(example)
        qa_target = torch.tensor([int(example['qa_target'])])
        qar_target = torch.tensor([int(example['qar_target'])])

        outs = []
        for index, input_ids in enumerate(input_ids_for_choices):
            attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

            input_ids = torch.tensor(input_ids)
            txt_type_ids = torch.tensor(type_ids_for_choices[index])

            outs.append((input_ids, txt_type_ids, img_feat, img_pos_feat, attn_masks))

        return tuple(outs), qid, qa_target, qar_target

    # for test
    # def __len__(self):
    #     return 64


def vcr_eval_collate(inputs):
    (input_ids, txt_type_ids, img_feats, img_pos_feats,
     attn_masks) = map(list, unzip(concat(outs for outs, _, _, _ in inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    txt_type_ids = pad_sequence(txt_type_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    qa_targets = torch.stack([t for _, _, t, _ in inputs], dim=0)
    qar_targets = torch.stack([t for _, _, _, t in inputs], dim=0)
    qids = [id_ for _, id_, _, _ in inputs]

    return {
        'qids': qids,
        'input_ids': input_ids,
        'txt_type_ids': txt_type_ids,
        'position_ids': position_ids,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'attn_masks': attn_masks,
        'gather_index': gather_index,
        'qa_targets': qa_targets,
        'qar_targets': qar_targets
    }
