# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np


class DatasetInfo:
    """Parsed view of a pose-dataset description dictionary.

    The constructor reads the keys ``dataset_name``, ``paper_info``,
    ``keypoint_info``, ``skeleton_info``, ``joint_weights`` and ``sigmas``
    from ``dataset_info`` and derives lookup tables, flip mappings and
    visualization colors from them.

    Args:
        dataset_info (dict): raw dataset description. ``keypoint_info`` and
            ``skeleton_info`` are mappings from an id to a per-entry dict.
    """

    def __init__(self, dataset_info):
        raw = dataset_info
        self._dataset_info = raw
        self.dataset_name = raw["dataset_name"]
        self.paper_info = raw["paper_info"]
        self.keypoint_info = raw["keypoint_info"]
        self.skeleton_info = raw["skeleton_info"]

        # Stored as a column: a trailing axis of length 1 is appended.
        weights = np.array(raw["joint_weights"], dtype=np.float32)
        self.joint_weights = weights[:, np.newaxis]

        self.sigmas = np.array(raw["sigmas"])

        # Keypoints go first: the skeleton parser needs keypoint_name2id.
        self._parse_keypoint_info()
        self._parse_skeleton_info()

    def _parse_skeleton_info(self):
        """Derive link attributes from ``self.skeleton_info``.

        Sets the following attributes:

        - link_num (int): number of links.
        - skeleton (list((2,))): list of links (id).
        - skeleton_name (list((2,))): list of links (name).
        - pose_link_color (np.ndarray): the color of the link for
            visualization; ``[255, 128, 0]`` when a link has no "color".
        """
        name_to_id = self.keypoint_name2id
        entries = [entry for _, entry in self.skeleton_info.items()]

        self.link_num = len(entries)
        self.skeleton_name = [entry["link"] for entry in entries]
        self.skeleton = [
            [name_to_id[link[0]], name_to_id[link[1]]]
            for link in self.skeleton_name
        ]
        self.pose_link_color = np.array(
            [entry.get("color", [255, 128, 0]) for entry in entries]
        )

    def _parse_keypoint_info(self):
        """Derive keypoint attributes from ``self.keypoint_info``.

        Sets the following attributes:

        - keypoint_num (int): number of keypoints.
        - keypoint_id2name (dict): mapping keypoint id to keypoint name.
        - keypoint_name2id (dict): mapping keypoint name to keypoint id.
        - upper_body_ids (list): ids of keypoints whose "type" is "upper".
        - lower_body_ids (list): ids of keypoints whose "type" is "lower".
        - flip_index (list): list of flip index (id)
        - flip_pairs (list((2,))): list of flip pairs (id)
        - flip_index_name (list): list of flip index (name)
        - flip_pairs_name (list((2,))): list of flip pairs (name)
        - pose_kpt_color (np.ndarray): the color of the keypoint for
            visualization; ``[255, 128, 0]`` when a keypoint has no "color".
        """
        id_to_name = {}
        name_to_id = {}
        colors = []
        upper_ids = []
        lower_ids = []
        mirrored_names = []
        mirrored_pairs = []

        for kpt_id, entry in self.keypoint_info.items():
            name = entry["name"]
            id_to_name[kpt_id] = name
            name_to_id[name] = kpt_id
            colors.append(entry.get("color", [255, 128, 0]))

            body_part = entry.get("type", "")
            if body_part == "upper":
                upper_ids.append(kpt_id)
            elif body_part == "lower":
                lower_ids.append(kpt_id)

            partner = entry.get("swap", "")
            if partner != name and partner != "":
                mirrored_names.append(partner)
                # Record a pair only once: skip it when the reversed pair
                # was already added while visiting the partner keypoint.
                if [partner, name] not in mirrored_pairs:
                    mirrored_pairs.append([name, partner])
            else:
                # No distinct partner: the keypoint maps to itself.
                mirrored_names.append(name)

        self.keypoint_num = len(self.keypoint_info)
        self.keypoint_id2name = id_to_name
        self.keypoint_name2id = name_to_id
        self.upper_body_ids = upper_ids
        self.lower_body_ids = lower_ids
        self.flip_index_name = mirrored_names
        self.flip_pairs_name = mirrored_pairs

        self.flip_pairs = [
            [name_to_id[first], name_to_id[second]]
            for first, second in mirrored_pairs
        ]
        self.flip_index = [name_to_id[name] for name in mirrored_names]
        self.pose_kpt_color = np.array(colors)


def get_dataset_info_by_name(dataset):
    """Return the ``DatasetInfo`` registered for a dataset class name.

    The matching description module under ``.pose_datasets`` is imported
    lazily, so only the selected dataset's module is loaded.

    Args:
        dataset (str): dataset class name, e.g. ``"TopDownCocoDataset"``.
            Names are matched exactly.

    Returns:
        DatasetInfo: parsed description of the selected dataset.

    Raises:
        NotImplementedError: if ``dataset`` has no description module wired
            up here. This covers unknown names as well as names that were
            previously listed without an implementation (e.g.
            ``"TopDownAicDataset"``, ``"TopDownMpiiTrbDataset"``, the
            ``Face*`` datasets and several ``Animal*`` datasets).
    """
    if dataset in (
        "TopDownCocoDataset",
        "BottomUpCocoDataset",
        "TopDownOCHumanDataset",
        "AnimalMacaqueDataset",
    ):
        from .pose_datasets.coco import dataset_info
    elif dataset == "TopDownCocoWholeBodyDataset":
        from .pose_datasets.coco_wholebody import dataset_info
    elif dataset == "TopDownMpiiDataset":
        from .pose_datasets.mpii import dataset_info
    elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
        from .pose_datasets.onehand10k import dataset_info
    elif dataset == "InterHand2DDataset":
        from .pose_datasets.interhand2d import dataset_info
    elif dataset == "AnimalPoseDataset":
        # Exact comparison: a substring test (`in "AnimalPoseDataset"`)
        # would also accept "", "Animal", "Pose", and so on.
        from .pose_datasets.animalpose import dataset_info
    else:
        # Fail with a clear error instead of reaching the return statement
        # with no description bound.
        raise NotImplementedError(
            '{} not supported for now.'.format(dataset)
        )
    return DatasetInfo(dataset_info)
