import json
from importlib import import_module
import matplotlib.pyplot as plt
import numpy as np
import torch
from ..run_utils.utils import convert_pytorch_checkpoint


####
class InferManager(object):
    """Set up a hovernet model for inference and hold its type metadata.

    Every keyword argument given to the constructor is stored as an instance
    attribute. The attributes read by this class are:

    * ``method``: dict providing ``method["model_args"]`` (keyword arguments
      for the model factory, including ``"nr_types"``) and
      ``method["model_path"]`` (checkpoint file to load).
    * ``nr_gpus``: number of GPUs; a value above zero places the model on CUDA.
    * ``type_info_path`` (optional): JSON file mapping each type id to
      ``[name, colour]``. Defaults to ``None``.

    After construction the instance exposes ``run_step`` (callable applied to
    one input batch), ``post_proc_func``, ``nr_types`` and ``type_info_dict``.
    """

    def __init__(self, **kwargs):
        """Store the settings, load the model and build ``type_info_dict``.

        Raises:
            AssertionError: if the JSON at ``type_info_path`` lacks an entry
                for one of the ids in ``range(nr_types)``.
        """
        self.run_step = None
        # Optional setting: default it so the checks below cannot fail with
        # AttributeError when the caller does not supply it.
        self.type_info_path = None
        for variable, value in kwargs.items():
            setattr(self, variable, value)
        self.__load_model()
        self.nr_types = self.method["model_args"]["nr_types"]

        # Default type info (name and colour), used when the model has no
        # type classification.
        self.type_info_dict = {
            None: ["no label", [0, 0, 0]],
        }

        if self.nr_types is not None and self.type_info_path is not None:
            # Context manager so the file handle is always released.
            with open(self.type_info_path, "r") as handle:
                raw_type_info = json.load(handle)
            # JSON object keys are strings; convert them to integer type ids.
            self.type_info_dict = {
                int(k): (v[0], tuple(v[1])) for k, v in raw_type_info.items()
            }
            # Availability check. Raised explicitly (same exception type as
            # before) so it is not stripped when Python runs with -O.
            for k in range(self.nr_types):
                if k not in self.type_info_dict:
                    raise AssertionError(
                        "Not detect type_id=%d defined in json." % k
                    )

        if self.nr_types is not None and self.type_info_path is None:
            cmap = plt.get_cmap("hot")
            # NOTE(review): the ids are passed to the colormap as integers;
            # presumably this indexes adjacent lookup-table entries and gives
            # very similar colours for consecutive ids - confirm intended.
            colour_list = np.arange(self.nr_types, dtype=np.int32)
            colour_list = (cmap(colour_list)[..., :3] * 255).astype(np.uint8)
            # should be compatible out of the box wrt qupath
            self.type_info_dict = {
                k: (str(k), tuple(v)) for k, v in enumerate(colour_list)
            }

    def __load_model(self):
        """Create the model, load the checkpoint and define
        associated run steps to process each data batch.

        Sets ``self.run_step`` and ``self.post_proc_func``.
        """
        model_desc = import_module(".segmentation.hovernet.models.hovernet.net_desc", package="tumourkit")
        net = model_desc.create_model(**self.method["model_args"])

        # Decide the device once, so checkpoint loading, model placement and
        # the flag handed to the inference step can never disagree.
        use_gpu = self.nr_gpus > 0
        map_location = None if use_gpu else torch.device('cpu')
        # NOTE: torch.load unpickles the file; the checkpoint must come from a
        # trusted source.
        checkpoint = torch.load(self.method["model_path"], map_location=map_location)
        saved_state_dict = convert_pytorch_checkpoint(checkpoint["desc"])

        net.load_state_dict(saved_state_dict, strict=True)
        net = torch.nn.DataParallel(net)
        if use_gpu:
            net = net.to("cuda")

        module_lib = import_module(".segmentation.hovernet.models.hovernet.run_desc", package="tumourkit")
        run_step = module_lib.infer_step
        use_cpu = not use_gpu
        self.run_step = lambda input_batch: run_step(input_batch, net, use_cpu=use_cpu)

        module_lib = import_module(".segmentation.hovernet.models.hovernet.post_proc", package="tumourkit")
        self.post_proc_func = module_lib.process

    def __save_json(self, path, old_dict, mag=None):
        """Write per-instance information to ``path`` as JSON.

        Args:
            path: destination file; it is overwritten.
            old_dict: mapping of instance id to a dict of instance fields.
                NumPy arrays and NumPy scalars among the field values are
                converted to plain Python objects.
            mag: value stored under the ``"mag"`` key of the output.

        Returns:
            The converted mapping, keyed by ``int(inst_id)``, that was written
            under the ``"nuc"`` key.
        """
        new_dict = {}
        for inst_id, inst_info in old_dict.items():
            new_inst_info = {}
            for info_name, info_value in inst_info.items():
                # convert to jsonable
                if isinstance(info_value, np.ndarray):
                    info_value = info_value.tolist()
                elif isinstance(info_value, np.generic):
                    # NumPy scalars (e.g. np.int64) are rejected by json.dump.
                    info_value = info_value.item()
                new_inst_info[info_name] = info_value
            new_dict[int(inst_id)] = new_inst_info

        json_dict = {"mag": mag, "nuc": new_dict}  # to sync the format protocol
        with open(path, "w") as handle:
            json.dump(json_dict, handle)
        return new_dict
