# -*- coding: utf-8 -*-
import logging
import math
from typing import List
from alfred.dl.torch.common import print_tensor
import numpy as np

import torch
import torch.nn.functional as F
from torch import nn

from detectron2.layers import ShapeSpec, batched_nms, cat, paste_masks_in_image
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.utils.logger import log_first_n
from fvcore.nn import sigmoid_focal_loss_jit

from yolov7.utils.solov2_utils import imrescale, center_of_mass, point_nms, mask_nms, matrix_nms
from ..head.solov2_head import SOLOv2InsHead, SOLOv2MaskHead
from ..loss.loss import dice_loss, FocalLoss
from alfred.utils.log import logger
from alfred.vis.image.det import visualize_det_cv2_part, visualize_det_cv2_fancy
from alfred.vis.image.mask import label2color_mask, vis_bitmasks

__all__ = ["SOLOv2"]


@META_ARCH_REGISTRY.register()
class SOLOv2(nn.Module):
    """
    SOLOv2 model. Creates FPN backbone, instance branch for kernels and categories prediction,
    mask branch for unified mask features.
    Calculates and applies proper losses to class and masks.

    When ``cfg.MODEL.ONNX_EXPORT`` is set, :meth:`forward` takes a raw image tensor that is
    permuted with ``(0, 3, 1, 2)`` (channels-last input) and returns the plain outputs of
    :meth:`inference_single_image_onnx` instead of detectron2 ``Instances``.
    """

    def __init__(self, cfg):
        super().__init__()

        # get the device of the model
        self.device = torch.device(cfg.MODEL.DEVICE)

        self.scale_ranges = cfg.MODEL.SOLOV2.FPN_SCALE_RANGES
        self.strides = cfg.MODEL.SOLOV2.FPN_INSTANCE_STRIDES
        self.sigma = cfg.MODEL.SOLOV2.SIGMA
        # Instance parameters.
        self.num_classes = cfg.MODEL.SOLOV2.NUM_CLASSES
        self.num_kernels = cfg.MODEL.SOLOV2.NUM_KERNELS
        self.num_grids = cfg.MODEL.SOLOV2.NUM_GRIDS

        self.instance_in_features = cfg.MODEL.SOLOV2.INSTANCE_IN_FEATURES
        self.instance_strides = cfg.MODEL.SOLOV2.FPN_INSTANCE_STRIDES
        # = fpn.
        self.instance_in_channels = cfg.MODEL.SOLOV2.INSTANCE_IN_CHANNELS
        self.instance_channels = cfg.MODEL.SOLOV2.INSTANCE_CHANNELS

        # Mask parameters.
        self.mask_on = cfg.MODEL.MASK_ON
        self.mask_in_features = cfg.MODEL.SOLOV2.MASK_IN_FEATURES
        self.mask_in_channels = cfg.MODEL.SOLOV2.MASK_IN_CHANNELS
        self.mask_channels = cfg.MODEL.SOLOV2.MASK_CHANNELS
        self.num_masks = cfg.MODEL.SOLOV2.NUM_MASKS

        # Inference parameters.
        self.max_before_nms = cfg.MODEL.SOLOV2.NMS_PRE
        self.score_threshold = cfg.MODEL.SOLOV2.SCORE_THR
        self.update_threshold = cfg.MODEL.SOLOV2.UPDATE_THR
        self.mask_threshold = cfg.MODEL.SOLOV2.MASK_THR
        self.max_per_img = cfg.MODEL.SOLOV2.MAX_PER_IMG
        self.nms_kernel = cfg.MODEL.SOLOV2.NMS_KERNEL
        self.nms_sigma = cfg.MODEL.SOLOV2.NMS_SIGMA
        self.nms_type = cfg.MODEL.SOLOV2.NMS_TYPE

        # build the backbone.
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()

        # build the ins head.
        instance_shapes = [backbone_shape[f]
                           for f in self.instance_in_features]
        logger.info('instance_shapes: {}'.format(instance_shapes))
        self.ins_head = SOLOv2InsHead(cfg, instance_shapes)

        # build the mask head.
        mask_shapes = [backbone_shape[f] for f in self.mask_in_features]
        self.mask_head = SOLOv2MaskHead(cfg, mask_shapes)

        # loss
        self.ins_loss_weight = cfg.MODEL.SOLOV2.LOSS.DICE_WEIGHT
        self.focal_loss_alpha = cfg.MODEL.SOLOV2.LOSS.FOCAL_ALPHA
        self.focal_loss_gamma = cfg.MODEL.SOLOV2.LOSS.FOCAL_GAMMA
        self.focal_loss_weight = cfg.MODEL.SOLOV2.LOSS.FOCAL_WEIGHT

        # image transform: per-channel (x - mean) / std with (3, 1, 1) statistics.
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(
            self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(
            self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.to(self.device)

        # add for onnx export
        self.onnx_export = cfg.MODEL.ONNX_EXPORT
        # number of top-scoring grid cells kept by the ONNX inference path.
        self.fixed_output_num = 100
        # when True, the ONNX path returns detectron2 Instances for visualization.
        self.export_vis = False
        self.iter = 0

    def update_iter(self, i):
        """Record the current training iteration."""
        self.iter = i

    def preprocess_input(self, x):
        """Permute a channels-last batch to (N, C, H, W) and normalize it (ONNX path)."""
        x = x.permute(0, 3, 1, 2)
        x = self.normalizer(x)
        return x

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DetectionTransform` .
                Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:
                image: Tensor, image in (C, H, W) format.
                instances: Instances
                Other information that's included in the original dicts, such as:
                    "height", "width" (int): the output resolution of the model, used in inference.
                        See :meth:`postprocess` for details.
            In ONNX export mode ``batched_inputs`` is instead an image tensor that is
            permuted with ``(0, 3, 1, 2)`` before use.

        Returns:
            losses (dict[str: Tensor]): mapping from a named loss to a tensor
                storing the loss. Used during training only.
            In inference: a list of ``{"instances": Instances}`` (one per image), or the
            output of :meth:`inference_onnx` in ONNX export mode.

        Raises:
            ValueError: in training mode when the inputs carry no ground truth.
        """
        gt_instances = None
        if self.onnx_export:
            logger.warning('exporting onnx...')
            assert isinstance(batched_inputs, (torch.Tensor, list)), \
                'onnx export, batched_inputs only needs image tensor'
            images = self.preprocess_input(batched_inputs)
            # The permuted raw input is only used downstream for its spatial size.
            batched_inputs = batched_inputs.permute(0, 3, 1, 2)
            logger.info('batched_inputs: {}'.format(batched_inputs.shape))
            features = self.backbone(images)
        else:
            images = self.preprocess_image(batched_inputs)
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(
                    self.device) for x in batched_inputs]
            elif "targets" in batched_inputs[0]:
                log_first_n(
                    logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
                )
                gt_instances = [x["targets"].to(
                    self.device) for x in batched_inputs]
            features = self.backbone(images.tensor)

        # ins branch
        ins_features = [features[f] for f in self.instance_in_features]
        ins_features = self.split_feats(ins_features)
        cate_pred, kernel_pred = self.ins_head(ins_features)

        # mask branch
        mask_features = [features[f] for f in self.mask_in_features]
        mask_pred = self.mask_head(mask_features)

        if self.training:
            if gt_instances is None:
                raise ValueError(
                    "SOLOv2 training requires 'instances' in batched_inputs")
            mask_feat_size = mask_pred.size()[-2:]
            targets = self.get_ground_truth(gt_instances, mask_feat_size)
            return self.loss(cate_pred, kernel_pred, mask_pred, targets)

        # point nms, shared by both inference paths.
        cate_pred = [point_nms(cate_p.sigmoid(), kernel=2).permute(0, 2, 3, 1)
                     for cate_p in cate_pred]
        if self.onnx_export:
            return self.inference_onnx(
                cate_pred, kernel_pred, mask_pred, images, batched_inputs)
        return self.inference(
            cate_pred, kernel_pred, mask_pred, images.image_sizes, batched_inputs)

    def preprocess_image(self, batched_inputs):
        """
        Normalize and batch the input images.

        Moves every ``x["image"]`` to the model device, applies the pixel mean/std
        normalization and pads the batch to ``backbone.size_divisibility``.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(
            images, self.backbone.size_divisibility)
        return images

    @torch.no_grad()
    def get_ground_truth(self, gt_instances, mask_feat_size=None):
        """
        Build targets for every image in the batch.

        Returns four lists indexed [image][level], as produced by
        :meth:`get_ground_truth_single`.
        """
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = [], [], [], []
        for img_idx in range(len(gt_instances)):
            cur_ins_label_list, cur_cate_label_list, \
                cur_ins_ind_label_list, cur_grid_order_list = \
                self.get_ground_truth_single(img_idx, gt_instances,
                                             mask_feat_size=mask_feat_size)
            ins_label_list.append(cur_ins_label_list)
            cate_label_list.append(cur_cate_label_list)
            ins_ind_label_list.append(cur_ins_ind_label_list)
            grid_order_list.append(cur_grid_order_list)
        return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list

    def get_ground_truth_single(self, img_idx, gt_instances, mask_feat_size):
        """
        Build SOLOv2 targets for one image, one entry per FPN level.

        On each level, objects whose sqrt(box area) lies in that level's scale range
        are assigned to the grid cells around their mask centre of mass. The cells are
        limited to a ``sigma``-scaled box and to at most one cell from the centre cell.

        Returns:
            Four lists with one item per level:
              * ins_label: uint8 tensor (num_pos, H_f, W_f) of downsampled masks;
              * cate_label: int64 tensor (num_grid, num_grid); background cells hold
                ``self.num_classes``;
              * ins_ind_label: bool tensor (num_grid ** 2,) marking positive cells;
              * grid_order: list of flattened positive cell indices aligned with ins_label.
        """
        gt_bboxes_raw = gt_instances[img_idx].gt_boxes.tensor
        gt_labels_raw = gt_instances[img_idx].gt_classes
        gt_masks_raw = gt_instances[img_idx].gt_masks.tensor
        # Read the device from the tensor itself; indexing [0] fails on images
        # without any annotation.
        device = gt_labels_raw.device

        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
            gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        # GT masks are downsampled by this factor to match the mask feature map;
        # the same factor maps feature size back to input size for grid assignment.
        output_stride = 4
        upsampled_size = (mask_feat_size[0] * output_stride,
                          mask_feat_size[1] * output_stride)

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        grid_order_list = []
        for (lower_bound, upper_bound), stride, num_grid \
                in zip(self.scale_ranges, self.strides, self.num_grids):

            hit_indices = ((gt_areas >= lower_bound) & (
                gt_areas <= upper_bound)).nonzero().flatten()
            num_ins = len(hit_indices)

            ins_label = []
            grid_order = []
            # every cell starts as background (label == num_classes).
            cate_label = torch.full(
                [num_grid, num_grid], self.num_classes, dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros(
                [num_grid ** 2], dtype=torch.bool, device=device)

            if num_ins == 0:
                ins_label = torch.zeros(
                    [0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices, ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # mass center
            center_ws, center_hs = center_of_mass(gt_masks)
            valid_mask_flags = gt_masks.sum(dim=-1).sum(dim=-1) > 0

            gt_masks = gt_masks.permute(1, 2, 0).to(
                dtype=torch.uint8).cpu().numpy()
            gt_masks = imrescale(gt_masks, scale=1./output_stride)
            if len(gt_masks.shape) == 2:
                # a single mask comes back without the channel axis.
                gt_masks = gt_masks[..., None]
            gt_masks = torch.from_numpy(gt_masks).to(
                dtype=torch.uint8, device=device).permute(2, 0, 1)
            for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                coord_w = int(
                    (center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int(
                    (center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(
                    0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(
                    num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(
                    0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(
                    num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                # at most one cell away from the centre cell.
                top = max(top_box, coord_h-1)
                down = min(down_box, coord_h+1)
                left = max(coord_w-1, left_box)
                right = min(right_box, coord_w+1)

                cate_label[top:(down+1), left:(right+1)] = gt_label
                for i in range(top, down+1):
                    for j in range(left, right+1):
                        label = int(i * num_grid + j)

                        # pad the downsampled mask into a feature-sized canvas.
                        cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8,
                                                    device=device)
                        cur_ins_label[:seg_mask.shape[0],
                                      :seg_mask.shape[1]] = seg_mask
                        ins_label.append(cur_ins_label)
                        ins_ind_label[label] = True
                        grid_order.append(label)
            if len(ins_label) == 0:
                ins_label = torch.zeros(
                    [0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
            else:
                ins_label = torch.stack(ins_label, 0)
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            grid_order_list.append(grid_order)
        return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list

    def loss(self, cate_preds, kernel_preds, ins_pred, targets):
        """
        Compute the SOLOv2 training losses.

        Args:
            cate_preds: per-level category logits from the instance head.
            kernel_preds: per-level dynamic-kernel predictions from the instance head.
            ins_pred: unified mask features from the mask head (assumed (N, E, H, W)).
            targets: the tuple returned by :meth:`get_ground_truth`.

        Returns:
            dict with two entries:
              * ``loss_ins``: weighted dice loss on the dynamically generated masks;
              * ``loss_cate``: weighted sigmoid focal loss on the grid categories,
                divided by (number of positive cells + 1).
        """
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = targets
        # ins: concatenate per-image labels for every level.
        ins_labels = [torch.cat([ins_labels_level_img
                                 for ins_labels_level_img in ins_labels_level], 0)
                      for ins_labels_level in zip(*ins_label_list)]

        # keep only the kernels predicted at positive grid cells.
        kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
                         for kernel_preds_level_img, grid_orders_level_img in
                         zip(kernel_preds_level, grid_orders_level)]
                        for kernel_preds_level, grid_orders_level in zip(kernel_preds, zip(*grid_order_list))]
        # generate masks: 1x1 dynamic convolution of the mask features per image.
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):

                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(
                    cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_pred_list, ins_labels):
            if input is None:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))

        if loss_ins:
            loss_ins_mean = torch.cat(loss_ins).mean()
        else:
            # No positive sample anywhere in the batch: torch.cat([]) would raise.
            # Use a zero that stays attached to the graph so gradients still flow.
            loss_ins_mean = ins_pred.sum() * 0.
        loss_ins = loss_ins_mean * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        # prepare one_hot (background cells, label == num_classes, stay all-zero).
        pos_inds = torch.nonzero(
            flatten_cate_labels != self.num_classes).squeeze(1)

        flatten_cate_labels_oh = torch.zeros_like(flatten_cate_preds)
        flatten_cate_labels_oh[pos_inds, flatten_cate_labels[pos_inds]] = 1

        loss_cate = self.focal_loss_weight * sigmoid_focal_loss_jit(flatten_cate_preds, flatten_cate_labels_oh,
                                                                    gamma=self.focal_loss_gamma,
                                                                    alpha=self.focal_loss_alpha,
                                                                    reduction="sum") / (num_ins + 1)
        return {'loss_ins': loss_ins,
                'loss_cate': loss_cate}

    @staticmethod
    def split_feats(feats):
        """
        Resample the five instance-branch features.

        Level 0 is downsampled by 0.5, level 4 is resized to level 3's spatial size,
        and levels 1-3 are passed through unchanged.
        """
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def inference(self, pred_cates, pred_kernels, pred_masks, cur_sizes, images):
        """
        Run per-image post-processing for regular (non-ONNX) inference.

        Args:
            pred_cates: per-level category scores after point NMS, channels-last.
            pred_kernels: per-level kernel predictions, channels-first.
            pred_masks: unified mask features for the batch.
            cur_sizes: padded-batch ``image_sizes`` (input resolution per image).
            images: the original input dicts; their "height"/"width" give the output size.

        Returns:
            list of ``{"instances": Instances}``, one per image.
        """
        assert len(pred_cates) == len(pred_kernels)

        results = []
        num_ins_levels = len(pred_cates)
        for img_idx in range(len(images)):
            # image size.
            ori_img = images[img_idx]
            height, width = ori_img["height"], ori_img["width"]
            ori_size = (height, width)

            # prediction.
            pred_cate = [pred_cates[i][img_idx].view(-1, self.num_classes).detach()
                         for i in range(num_ins_levels)]
            pred_kernel = [pred_kernels[i][img_idx].permute(1, 2, 0).view(-1, self.num_kernels).detach()
                           for i in range(num_ins_levels)]
            pred_mask = pred_masks[img_idx, ...].unsqueeze(0)

            pred_cate = torch.cat(pred_cate, dim=0)
            pred_kernel = torch.cat(pred_kernel, dim=0)

            # inference for single image.
            result = self.inference_single_image(pred_cate, pred_kernel, pred_mask,
                                                 cur_sizes[img_idx], ori_size)
            results.append({"instances": result})
        return results

    def inference_onnx(self, pred_cates, pred_kernels, pred_masks, cur_sizes, images):
        """
        ONNX-export inference; only the first image of the batch is processed.

        Args:
            pred_cates: per-level category scores after point NMS, channels-last.
            pred_kernels: per-level kernel predictions, channels-first.
            pred_masks: unified mask features for the batch.
            cur_sizes: unused; kept for signature compatibility.
            images: image tensor whose ``shape[2:]`` gives the (H, W) used to size outputs.

        Returns:
            whatever :meth:`inference_single_image_onnx` returns.
        """
        assert len(pred_cates) == len(pred_kernels)
        cur_size = images.shape[2:]
        num_ins_levels = len(pred_cates)

        # prediction.
        img_idx = 0
        pred_cate = [pred_cates[i][img_idx].view(-1, self.num_classes).detach()
                     for i in range(num_ins_levels)]
        pred_kernel = [pred_kernels[i][img_idx].permute(1, 2, 0).view(-1, self.num_kernels).detach()
                       for i in range(num_ins_levels)]
        pred_mask = pred_masks[img_idx, ...].unsqueeze(0)

        pred_cate = torch.cat(pred_cate, dim=0)  # (total grid cells, num_classes)
        pred_kernel = torch.cat(pred_kernel, dim=0)  # (total grid cells, num_kernels)

        # inference for single image.
        return self.inference_single_image_onnx(
            pred_cate, pred_kernel, pred_mask, cur_size=cur_size, vis=self.export_vis)

    def inference_single_image(
            self, cate_preds, kernel_preds, seg_preds, cur_size, ori_size
    ):
        """
        Post-process the predictions of one image into detectron2 ``Instances``.

        Steps: score threshold, dynamic 1x1 convolution into masks, small-mask filtering
        by stride, mask scoring, pre-NMS top-k, matrix/mask NMS, top-k, and resizing to
        ``ori_size``. Boxes are derived from the final binary masks.

        Args:
            cate_preds: (num_cells, num_classes) category scores.
            kernel_preds: (num_cells, num_kernels) kernel predictions.
            seg_preds: (1, E, H_f, W_f) mask features.
            cur_size: (h, w) of the un-padded input image.
            ori_size: (height, width) of the requested output.

        Returns:
            Instances with ``scores``, ``pred_classes``, ``pred_masks``, ``pred_boxes``.
        """
        # overall info.
        h, w = cur_size
        f_h, f_w = seg_preds.size()[-2:]
        ratio = math.ceil(h/f_h)
        upsampled_size_out = (int(f_h*ratio), int(f_w*ratio))

        # process.
        inds = (cate_preds > self.score_threshold)
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            results = Instances(ori_size)
            results.scores = torch.tensor([])
            results.pred_classes = torch.tensor([])
            results.pred_masks = torch.tensor([])
            results.pred_boxes = Boxes(torch.tensor([]))
            return results

        # cate_labels & kernel_preds
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector: the FPN stride of each grid cell, used as a minimum mask area.
        size_trans = cate_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.num_grids)
        strides[:size_trans[0]] *= self.instance_strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]
                    ] *= self.instance_strides[ind_]
        strides = strides[inds[:, 0]]

        # mask encoding.
        N, I = kernel_preds.shape
        kernel_preds = kernel_preds.view(N, I, 1, 1)

        seg_preds = F.conv2d(seg_preds, kernel_preds,
                             stride=1).squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > self.mask_threshold
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            results = Instances(ori_size)
            results.scores = torch.tensor([])
            results.pred_classes = torch.tensor([])
            results.pred_masks = torch.tensor([])
            results.pred_boxes = Boxes(torch.tensor([]))
            return results

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # mask scoring.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_before_nms:
            sort_inds = sort_inds[:self.max_before_nms]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        if self.nms_type == "matrix":
            # matrix nms & filter.
            cate_scores = matrix_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                                     sigma=self.nms_sigma, kernel=self.nms_kernel)
            keep = cate_scores >= self.update_threshold
        elif self.nms_type == "mask":
            # original mask nms.
            keep = mask_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                            nms_thr=self.mask_threshold)
        else:
            raise NotImplementedError

        if keep.sum() == 0:
            results = Instances(ori_size)
            results.scores = torch.tensor([])
            results.pred_classes = torch.tensor([])
            results.pred_masks = torch.tensor([])
            results.pred_boxes = Boxes(torch.tensor([]))
            return results

        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_per_img:
            sort_inds = sort_inds[:self.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # reshape to original size: upsample, crop away padding, resize to output.
        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_size,
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > self.mask_threshold

        results = Instances(ori_size)
        results.pred_classes = cate_labels
        results.scores = cate_scores
        results.pred_masks = seg_masks

        # get bbox from mask. A mask can be empty after the crop/resize above;
        # it keeps an all-zero box instead of crashing on min() of an empty tensor.
        # Boxes live on the same device as the other fields.
        pred_boxes = torch.zeros(seg_masks.size(0), 4, device=seg_masks.device)
        for i in range(seg_masks.size(0)):
            ys, xs = torch.where(seg_masks[i])
            if xs.numel() > 0:
                pred_boxes[i] = torch.stack(
                    [xs.min(), ys.min(), xs.max(), ys.max()]).float()
        results.pred_boxes = Boxes(pred_boxes)
        return results

    def inference_single_image_onnx(
            self, cate_preds, kernel_preds, seg_preds, cur_size=None, ori_size=None, vis=False
    ):
        """
        Export-friendly post-processing for one image with a fixed number of candidates.

        The top ``self.fixed_output_num`` cells by max class score are kept. Their masks
        are built with a matmul instead of conv2d, scored, and NMS is applied. Instead of
        indexing with ``keep``, a keep-flag tensor is returned, which keeps output shapes
        static.

        Args:
            cate_preds: (num_cells, num_classes) category scores.
            kernel_preds: (num_cells, num_kernels) kernel predictions.
            seg_preds: (1, E, H_f, W_f) mask features.
            cur_size: (H, W) of the network input.
            ori_size: unused.
            vis: when True, filter by ``keep`` and return ``{"instances": Instances}``.

        Returns:
            ``{"instances": Instances}`` if ``vis``, otherwise the tuple
            ``(seg_masks, cate_scores, cate_labels, keep)``.
        """
        assert not torch.isnan(seg_preds).any(), 'seg_preds contains nan'
        ori_h, ori_w = cur_size
        upsampled_size_out = (ori_h, ori_w)

        cate_scores, cate_preds = torch.max(cate_preds, dim=-1)
        # keep only the top fixed_output_num cells by score
        cate_scores, inds = torch.topk(cate_scores, k=self.fixed_output_num)
        cate_labels = cate_preds[inds]
        cate_labels = cate_labels.long()

        kernel_preds = kernel_preds[inds]

        # trans vector.
        size_trans = cate_labels.new_tensor(
            self.num_grids).pow(2).cumsum(0)  # just grids pow
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.num_grids)
        strides[:size_trans[0]] *= self.instance_strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]
                    ] *= self.instance_strides[ind_]
        strides = strides[inds]

        # mask encoding (matmul form of the 1x1 dynamic convolution).
        N, I = kernel_preds.shape
        kernel_preds = kernel_preds.view(N, I, 1, 1)

        sh = torch.tensor(seg_preds.shape)
        seg_preds = seg_preds.view(sh[0], sh[1], -1)
        sh_kernel = torch.tensor(kernel_preds.shape)
        kernel_preds = kernel_preds.view(sh_kernel[0], sh_kernel[1])

        seg_preds = torch.matmul(kernel_preds, seg_preds)
        seg_preds = seg_preds.view(sh[0], N, sh[2], sh[3])
        seg_preds = seg_preds.squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > torch.tensor(self.mask_threshold).float()
        seg_masks = seg_masks.long()  # convert booleans to ints
        sum_masks = seg_masks.sum((1, 2)).float() + \
            0.1  # incase sum_masks contains 0

        seg_scores = ((seg_preds * seg_masks.float()
                       ).sum((1, 2)) + 0.1) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        _, sort_inds = torch.sort(cate_scores, descending=True)
        if len(sort_inds) > self.max_before_nms:
            sort_inds = sort_inds[:self.max_before_nms]
        seg_masks = seg_masks[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        if self.nms_type == "matrix":
            # matrix nms & filter.
            cate_scores = matrix_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                                     sigma=self.nms_sigma, kernel=self.nms_kernel)
            keep = cate_scores >= torch.tensor(
                self.update_threshold, dtype=torch.float)
        elif self.nms_type == "mask":
            # original mask nms.
            keep = mask_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                            nms_thr=self.mask_threshold)
        else:
            raise NotImplementedError

        # keep can also return as matrix_nms output
        seg_preds = seg_preds[sort_inds, :, :]

        # sort and keep top_k
        if vis:
            # only need it for vis
            seg_preds = seg_preds[keep, :, :]
            cate_scores = cate_scores[keep]
            cate_labels = cate_labels[keep]

            sort_inds = torch.argsort(cate_scores, descending=True)
            if len(sort_inds) > self.max_per_img:
                sort_inds = sort_inds[:self.max_per_img]
            seg_preds = seg_preds[sort_inds, :, :]
            cate_scores = cate_scores[sort_inds]
            cate_labels = cate_labels[sort_inds]

            # reshape to original size.
            seg_masks = F.interpolate(seg_preds.unsqueeze(0),
                                      size=upsampled_size_out,
                                      mode='bilinear')[:, :, :ori_h, :ori_w].squeeze(0)
            seg_masks = seg_masks > self.mask_threshold
            logger.info(seg_masks.shape)

            results = Instances(upsampled_size_out)
            results.pred_classes = cate_labels
            results.scores = cate_scores
            results.pred_masks = seg_masks

            # get bbox from mask; empty masks keep an all-zero box.
            pred_boxes = torch.zeros(seg_masks.size(0), 4)
            for i in range(seg_masks.size(0)):
                mask = seg_masks[i].squeeze()
                ys, xs = torch.where(mask)
                if ys.shape[0] > 0 and xs.shape[0] > 0:
                    pred_boxes[i] = torch.tensor(
                        [xs.min(), ys.min(), xs.max(), ys.max()]).float()
            results.pred_boxes = Boxes(pred_boxes)
            return {"instances": results}
        else:
            # NOTE(review): output size floor (736, 992) and 0.6 scale are hard-coded
            # for the exported model -- confirm they match the deployment target.
            seg_masks = F.interpolate(seg_preds.unsqueeze(0),
                                      size=(max(int(ori_h*0.6), 736),
                                            max(int(ori_w*0.6), 992)),
                                      mode='bilinear')
            seg_masks = seg_masks > torch.tensor(self.mask_threshold).float()
            seg_masks = seg_masks.float()
            keep = keep.long()
            return seg_masks, cate_scores, cate_labels, keep

