# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch
from torch.utils.data import Dataset
import numpy as np

from transformers.tokenization_bert import BertTokenizer
from ._image_features_reader import ImageFeaturesH5Reader
import _pickle as cPickle

import xml.etree.ElementTree as ET


def iou(anchors, gt_boxes):
    """
    Pairwise intersection-over-union between two sets of boxes.

    anchors: (N, 4) torch tensor of boxes
    gt_boxes: (K, 4) torch tensor of boxes
    returns: (N, K) tensor whose [i, j] entry is the overlap between
             anchors[i] and gt_boxes[j]

    Columns 0/2 are treated as the low/high coordinate on one axis and
    columns 1/3 as the low/high coordinate on the other. Extents are
    measured inclusively (high - low + 1).
    """
    num_anchors = anchors.size(0)
    num_gt = gt_boxes.size(0)

    def _inclusive_area(box_set):
        width = box_set[:, 2] - box_set[:, 0] + 1
        height = box_set[:, 3] - box_set[:, 1] + 1
        return width * height

    # Shaped so that the two area vectors broadcast to an (N, K) grid.
    anchor_area = _inclusive_area(anchors).view(num_anchors, 1)
    gt_area = _inclusive_area(gt_boxes).view(1, num_gt)

    # Lay both box sets out on a common (N, K, 4) grid.
    grid_anchors = anchors.view(num_anchors, 1, 4).expand(num_anchors, num_gt, 4)
    grid_gt = gt_boxes.view(1, num_gt, 4).expand(num_anchors, num_gt, 4)

    low_x = torch.max(grid_anchors[:, :, 0], grid_gt[:, :, 0])
    high_x = torch.min(grid_anchors[:, :, 2], grid_gt[:, :, 2])
    low_y = torch.max(grid_anchors[:, :, 1], grid_gt[:, :, 1])
    high_y = torch.min(grid_anchors[:, :, 3], grid_gt[:, :, 3])

    # Disjoint boxes give a negative extent; treat that as no overlap.
    inter_w = (high_x - low_x + 1).clamp(min=0)
    inter_h = (high_y - low_y + 1).clamp(min=0)

    intersection = inter_w * inter_h
    union = anchor_area + gt_area - intersection

    return intersection / union


def assert_eq(real, expected):
    """Assert that ``real`` equals ``expected``; the failure message shows both."""
    compared = (real, expected)
    assert real == expected, '%s (true) vs %s (expected)' % compared


def get_sentence_data(fn):
    """
    Parses a sentence file from the Flickr30K Entities dataset
    input:
      fn - full file path to the sentence file to parse

    output:
      a list with one dictionary per non-empty line of the file, each
      with the following fields:
          sentence - the line's words joined by single spaces, with the
                     phrase markup removed
          phrases - a list of dictionaries, one per closed phrase, with
                    the following fields:
                      phrase - the text of the annotated phrase
                      first_word_index - the position of the first word of
                                         the phrase in the sentence
                      phrase_id - an identifier for this phrase
                      phrase_type - a list of the coarse categories this
                                    phrase belongs to

    A phrase opens at a token starting with '[' (which carries the
    '/'-separated id and categories and is not itself a word) and closes
    at the next token ending with ']'. A phrase that is opened but never
    closed contributes its words to the sentence but is not reported.
    """
    with open(fn, 'r') as handle:
        raw_lines = handle.read().split('\n')

    annotations = []
    for line in raw_lines:
        if not line:
            continue

        words = []
        phrase_entries = []
        pending = None  # metadata of the phrase currently being read, if any
        phrase_words = []
        for token in line.split():
            if pending is None:
                if token[0] != '[':
                    words.append(token)
                    continue

                # Opening tag: second field holds the id after a 3-character
                # prefix, remaining fields are the categories.
                fields = token.split('/')
                pending = {
                    'first_word_index': len(words),
                    'phrase_id': fields[1][3:],
                    'phrase_type': fields[2:],
                }
                continue

            closes_phrase = token[-1] == ']'
            if closes_phrase:
                token = token[:-1]

            phrase_words.append(token)
            words.append(token)

            if closes_phrase:
                phrase_entries.append({
                    'first_word_index': pending['first_word_index'],
                    'phrase': ' '.join(phrase_words),
                    'phrase_id': pending['phrase_id'],
                    'phrase_type': pending['phrase_type'],
                })
                pending = None
                phrase_words = []

        annotations.append({'sentence': ' '.join(words), 'phrases': phrase_entries})

    return annotations


def get_annotations(fn):
    """
    Parses the xml files in the Flickr30K Entities dataset
    input:
      fn - full file path to the annotations file to parse
    output:
      dictionary with the following fields:
          scene - list of identifiers which were annotated as
                  pertaining to the whole scene
          nobox - list of identifiers which were annotated as
                  not being visible in the image
          boxes - a dictionary where the fields are identifiers
                  and the values are its list of boxes in the
                  [xmin ymin xmax ymax] format (each coordinate is
                  the value from the file minus one)
      plus one integer entry per child of the first <size> element,
      keyed by that child's tag
    """
    root = ET.parse(fn).getroot()
    size_node = root.findall('size')[0]

    anno_info = {'boxes': {}, 'scene': [], 'nobox': []}
    for dimension in size_node:
        anno_info[dimension.tag] = int(dimension.text)

    coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    for obj in root.findall('object'):
        bndboxes = obj.findall('bndbox')
        for name_node in obj.findall('name'):
            box_id = name_node.text

            if len(bndboxes) > 0:
                # Only the first <bndbox> of the object is read.
                first_box = bndboxes[0]
                box_list = anno_info['boxes'].setdefault(box_id, [])
                box_list.append([int(first_box.findall(tag)[0].text) - 1 for tag in coord_tags])
                continue

            # No box: record the identifier under whichever flags are set.
            if int(obj.findall('nobndbox')[0].text) > 0:
                anno_info['nobox'].append(box_id)

            if int(obj.findall('scene')[0].text) > 0:
                anno_info['scene'].append(box_id)

    return anno_info


class FlickrGroundingDataset(Dataset):
    """Phrase-grounding dataset built from Flickr30K Entities annotations.

    Every entry pairs one annotated phrase with the first box recorded for
    that phrase's identifier. ``__getitem__`` returns padded region features
    and boxes for the entry's image together with a per-region IoU target
    computed against that reference box.
    """

    # Folder holding ``<split>.txt``, ``Annotations/`` and ``Sentences/``.
    # NOTE: this is an absolute path and is NOT derived from ``dataroot``.
    FLICKR30K_ROOT = '/checkpoint/vedanuj/datasets/flickr30k'

    def __init__(
        self,
        task: str,
        dataroot: str,
        annotations_jsonpath: str,
        split: str,
        image_features_reader: ImageFeaturesH5Reader,
        gt_image_features_reader: ImageFeaturesH5Reader,
        tokenizer: BertTokenizer,
        bert_model,
        clean_datasets,
        padding_index: int = 0,
        max_seq_length: int = 20,
        max_region_num: int = 60,
    ):
        """Build the entries, or load them from the on-disk cache.

        Args:
            task: task name; only used as part of the cache file name.
            dataroot: directory containing the ``cache`` sub-folder.
            annotations_jsonpath: accepted for interface compatibility;
                not used by this class.
            split: split name; selects ``<split>.txt`` and enables the
                ground-truth-box mixing in ``__getitem__`` when ``'train'``.
            image_features_reader: indexable by image id, yielding
                ``(features, num_boxes, boxes, boxes_ori)``.
            gt_image_features_reader: same protocol; only read for the
                ``'train'`` split.
            tokenizer: must provide ``encode`` and
                ``add_special_tokens_single_sentence``.
            bert_model: model name; if it contains ``'roberta'`` the cache
                file name gets a ``roberta`` component.
            clean_datasets: if truthy, train images listed in
                ``<dataroot>/cache/flickr_test_ids.npy`` are skipped and the
                cache file name gets a ``_cleaned`` suffix.
            padding_index: token id used to pad captions.
            max_seq_length: padded caption length.
            max_region_num: padded number of image regions.
        """
        self.split = split
        self.num_labels = 1
        self._image_features_reader = image_features_reader
        self._gt_image_features_reader = gt_image_features_reader
        self._tokenizer = tokenizer

        self._padding_index = padding_index
        self._max_seq_length = max_seq_length
        self.dataroot = dataroot
        self.max_region_num = max_region_num

        clean_train = '_cleaned' if clean_datasets else ''

        name_parts = [task, split]
        if 'roberta' in bert_model:
            name_parts.append('roberta')
        name_parts += [str(max_seq_length), str(max_region_num)]

        cache_dir = os.path.join(dataroot, 'cache')
        cache_path = os.path.join(cache_dir, '_'.join(name_parts) + clean_train + '.pkl')

        if os.path.exists(cache_path):
            # The cache fully replaces the entries, so the raw annotation
            # files are not parsed at all in this case.
            print('loading entries from %s' % (cache_path))
            # NOTE: unpickling can execute arbitrary code; only point this at
            # cache files produced by this class.
            with open(cache_path, 'rb') as cache_file:
                self.entries = cPickle.load(cache_file)
        else:
            self.entries = self._load_annotations(clean_datasets)
            self.tokenize()
            self.tensorize()
            os.makedirs(cache_dir, exist_ok=True)
            with open(cache_path, 'wb') as cache_file:
                cPickle.dump(self.entries, cache_file)

    def _load_annotations(self, clean_datasets):
        """Collect one entry per annotated phrase that has at least one box.

        Returns a list of dicts with keys ``caption`` (phrase text),
        ``sent_id`` (the phrase id), ``image_id`` (int) and ``refBox``
        (the first box stored for that phrase id).
        """
        remove_ids = set()
        if clean_datasets:
            raw_ids = np.load(os.path.join(self.dataroot, 'cache', 'flickr_test_ids.npy'))
            # A set keeps the per-image membership test below constant-time.
            remove_ids = {int(x) for x in raw_ids}

        with open(os.path.join(self.FLICKR30K_ROOT, '%s.txt' % self.split), 'r') as f:
            images = f.read().splitlines()

        # Image removal is only applied to the training split.
        filter_removed = self.split == 'train'

        entries = []
        for img in images:
            if filter_removed and int(img) in remove_ids:
                continue

            annotation = get_annotations(
                os.path.join(self.FLICKR30K_ROOT, 'Annotations', img + '.xml'))
            sentences = get_sentence_data(
                os.path.join(self.FLICKR30K_ROOT, 'Sentences', img + '.txt'))

            boxes_by_id = annotation['boxes']
            for sent in sentences:
                for phrase in sent['phrases']:
                    phrase_key = str(phrase['phrase_id'])
                    if phrase_key not in boxes_by_id:
                        continue
                    entries.append({
                        'caption': phrase['phrase'],
                        'sent_id': phrase['phrase_id'],
                        'image_id': int(img),
                        # Only the first box of the phrase is the reference.
                        'refBox': boxes_by_id[phrase_key][0],
                    })

        return entries

    def tokenize(self):
        """Tokenizes the captions.

        Adds ``token``, ``input_mask`` and ``segment_ids`` lists of length
        ``max_seq_length`` to each entry of the dataset. Token ids are padded
        at the end with ``padding_index``; the mask and the segment ids are
        padded at the end with 0 so padded positions are always masked out.
        """
        for entry in self.entries:
            tokens = self._tokenizer.encode(entry['caption'])
            # Reserve two positions for the special tokens added next
            # (assumes the tokenizer adds exactly two — TODO confirm).
            tokens = tokens[:self._max_seq_length - 2]
            tokens = self._tokenizer.add_special_tokens_single_sentence(tokens)

            segment_ids = [0] * len(tokens)
            input_mask = [1] * len(tokens)

            pad_len = self._max_seq_length - len(tokens)
            if pad_len > 0:
                tokens = tokens + [self._padding_index] * pad_len
                input_mask = input_mask + [0] * pad_len
                segment_ids = segment_ids + [0] * pad_len

            assert_eq(len(tokens), self._max_seq_length)
            entry['token'] = tokens
            entry['input_mask'] = input_mask
            entry['segment_ids'] = segment_ids

    def tensorize(self):
        """Convert the tokenized lists of every entry into torch tensors."""
        for entry in self.entries:
            for key in ('token', 'input_mask', 'segment_ids'):
                entry[key] = torch.from_numpy(np.array(entry[key]))

    def __getitem__(self, index):
        """Return the model inputs for entry ``index``.

        Returns a tuple ``(features, spatials, image_mask, caption, target,
        input_mask, segment_ids, co_attention_mask, image_id)`` where the
        first three and ``target`` are padded to ``max_region_num`` rows and
        ``target`` holds each region's IoU with the entry's reference box
        (values below 0.5 are zeroed for the ``'train'`` split only).
        """
        entry = self.entries[index]

        image_id = entry['image_id']
        ref_box = entry['refBox']

        features, num_boxes, boxes, boxes_ori = self._image_features_reader[image_id]

        boxes_ori = boxes_ori[:num_boxes]
        boxes = boxes[:num_boxes]
        features = features[:num_boxes]

        is_train = self.split == 'train'
        if is_train:
            gt_features, gt_num_boxes, gt_boxes, gt_boxes_ori = self._gt_image_features_reader[image_id]

            # Row 0 of the ground-truth reader output is skipped
            # (presumably a whole-image region — confirm against the reader).
            gt_boxes_ori = gt_boxes_ori[1:gt_num_boxes]
            gt_boxes = gt_boxes[1:gt_num_boxes]
            gt_features = gt_features[1:gt_num_boxes]

            # Detected regions first, ground-truth regions appended after.
            mix_boxes_ori = np.concatenate((boxes_ori, gt_boxes_ori), axis=0)
            mix_boxes = np.concatenate((boxes, gt_boxes), axis=0)
            mix_features = np.concatenate((features, gt_features), axis=0)
            mix_num_boxes = min(int(num_boxes + int(gt_num_boxes) - 1), self.max_region_num)
        else:
            mix_boxes_ori = boxes_ori
            mix_boxes = boxes
            mix_features = features
            mix_num_boxes = min(int(num_boxes), self.max_region_num)

        # Overlap of every region with the reference box, shape (regions, 1).
        mix_target = iou(
            torch.tensor(mix_boxes_ori[:, :4]).float(),
            torch.tensor([ref_box]).float(),
        )
        if is_train:
            mix_target[mix_target < 0.5] = 0

        # 1 for real regions, 0 for padding rows.
        image_mask = [1] * mix_num_boxes + [0] * (self.max_region_num - mix_num_boxes)

        mix_boxes_pad = np.zeros((self.max_region_num, 5))
        mix_features_pad = np.zeros((self.max_region_num, 2048))

        mix_boxes_pad[:mix_num_boxes] = mix_boxes[:mix_num_boxes]
        mix_features_pad[:mix_num_boxes] = mix_features[:mix_num_boxes]

        features = torch.tensor(mix_features_pad).float()
        image_mask = torch.tensor(image_mask).long()
        spatials = torch.tensor(mix_boxes_pad).float()

        target = torch.zeros((self.max_region_num, 1)).float()
        target[:mix_num_boxes] = mix_target[:mix_num_boxes]

        co_attention_mask = torch.zeros((self.max_region_num, self._max_seq_length))

        caption = entry['token']
        input_mask = entry['input_mask']
        segment_ids = entry['segment_ids']

        return (
            features,
            spatials,
            image_mask,
            caption,
            target,
            input_mask,
            segment_ids,
            co_attention_mask,
            image_id,
        )

    def __len__(self):
        """Number of phrase entries in the dataset."""
        return len(self.entries)
