import json
import cv2 as cv
import os
import os.path as osp
from tqdm import tqdm
from PIL import Image
import base64
import numpy as np
from labelme.label_file import LabelFile, PY2, QT4, utils, io
import shutil as sh

# Version string of this labelme data-processing helper module.
__version__ = "0.1.0"

def encodeImage(imarr, ext=".jpg"):
    """Encode an image as a base64 string (the form stored in labelme ``imageData``).

    @See Also LabelFile.load_image_file

    The EXIF orientation is applied first; the image is then re-encoded as
    JPEG when ``ext`` is ``.jpg``/``.jpeg`` (case-insensitive) and as PNG
    otherwise (always PNG when running on PY2 with Qt4).

    Arguments:
        imarr(np.ndarray | PIL.Image.Image | str): image array, PIL image or
            path of an image file
        ext(str): extension that decides the encoding format

    Return:
        (str) base64 encoded image

    Raises:
        ValueError: if ``imarr`` is of an unsupported type
    """
    if isinstance(imarr, Image.Image):
        image = imarr
    elif isinstance(imarr, np.ndarray):
        image = Image.fromarray(imarr)
    elif isinstance(imarr, str):
        # only ``Image`` is imported from PIL, so use it directly
        image = Image.open(imarr)
    else:
        raise ValueError(f"Unsupported Type: {type(imarr)}")
    # apply orientation to image according to exif
    image_pil = utils.apply_exif_orientation(image)

    if PY2 and QT4:
        fmt = "PNG"
    elif ext.lower() in (".jpg", ".jpeg"):
        fmt = "JPEG"
    else:
        fmt = "PNG"

    with io.BytesIO() as f:
        image_pil.save(f, format=fmt)
        imageData = f.getvalue()
    return base64.b64encode(imageData).decode("utf-8")


def calcY(line, x):
    """Evaluate the straight line ``y = slope * x + intercept`` at ``x``.

    Arguments:
        line(sequence): ``(slope, intercept)``, e.g. the coefficients returned
            by ``np.polyfit(xs, ys, 1)``
        x(number | np.ndarray): abscissa

    Return:
        the ordinate ``y``
    """
    slope, intercept = line[0], line[1]
    return slope * x + intercept

import copy
from PIL import ImageEnhance

class LabelmeModifier(object):
    """
    Augment (resize, rotate, shift) images that have already been annotated
    with `labelme`, keeping the labelme json annotations in sync.

    Arguments:
        root(str): root directory that contains the raw images and the json
            files generated by `labelme`
        dst_dir(str): output directory (created if it does not exist)
        image_suffix(str): suffix of the images to pick up, e.g. ".jpg"
        debug(bool): process only the first item and print/show extra info
        **kwargs: accepted for backward compatibility and ignored
    """

    def __init__(self, root, dst_dir, image_suffix=".jpg", debug=False, **kwargs):
        self.root = root
        self.dst_dir = dst_dir
        self.debug = debug
        self.image_suffix = image_suffix
        files = sorted(os.listdir(root))
        # Match on the suffix only: a substring test would also pick up
        # names such as "a.jpg.json" or "a.jpg.bak".
        self.images = [file for file in files if file.endswith(image_suffix)]
        self.jsonfiles = [file for file in files if file.endswith(".json")]
        self.names = [img[: len(img) - len(image_suffix)] for img in self.images]

        print(f"Found files at {root} images ({image_suffix}): {len(self.images)}, jsons: {len(self.jsonfiles)}")

        os.makedirs(self.dst_dir, exist_ok=True)

    def readJsonFile(self, file):
        """
        Load ``file`` (relative to ``root``) as json.

        Return:
            dict | None: the parsed content, or None if the file does not exist

        Raises:
            ValueError: if the file cannot be read or parsed
        """
        loc = osp.join(self.root, file)
        if not osp.exists(loc):
            return None
        try:
            with open(loc, "r") as f:
                return json.load(f)
        except Exception as e:
            raise ValueError(f"Exception for {file}: {e}") from e

    @staticmethod
    def _clampPoint(x, y, iw, ih):
        """Clamp ``(x, y)`` into the image rectangle ``[0, iw] x [0, ih]``; return ``[x, y]`` as floats."""
        return [float(min(max(x, 0), iw)), float(min(max(y, 0), ih))]

    def _saveResult(self, newName, newImage, jsoncontent, saveImageData=True):
        """
        Save ``newImage`` as ``dst_dir/<newName><image_suffix>`` and, when
        ``jsoncontent`` is given, write it as ``dst_dir/<newName>.json``.

        ``imagePath`` is pointed at the saved image (labelme resolves it
        relative to the json file) and ``imageData`` is either the base64 of
        the saved image or None, depending on ``saveImageData``.
        """
        savedLoc = osp.join(self.dst_dir, f"{newName}{self.image_suffix}")
        newImage.save(savedLoc)
        if jsoncontent is None:
            return
        jsoncontent["imagePath"] = osp.basename(savedLoc)
        if saveImageData:
            imageData = LabelFile.load_image_file(savedLoc)
            jsoncontent["imageData"] = base64.b64encode(imageData).decode("utf-8")
        else:
            jsoncontent["imageData"] = None
        with open(osp.join(self.dst_dir, f"{newName}.json"), "w") as f:
            f.write(json.dumps(jsoncontent, indent=True))

    def _resize(self, name, size, jsoncontent):
        """
        Note: **imageData in jsoncontent is not modified, it should be modified by the last step.**

        Arguments:
            name(str): filename stem without path location
            size(tuple): (W, H) of the resized image
            jsoncontent(dict | None): corresponding labelme json content, modified in place

        Return:
            tuple: (newImage, jsoncontent)
        """
        image = Image.open(osp.join(self.root, f"{name}{self.image_suffix}"))
        iw, ih = image.size
        wscale = size[0] / iw
        hscale = size[1] / ih
        newImage = image.resize(size)

        if jsoncontent is None:
            if self.debug:
                print(f"Not Found json file for {name}")
            return newImage, None
        jsoncontent["imageWidth"] = size[0]
        jsoncontent["imageHeight"] = size[1]
        if self.debug:
            newImage.show(f"Resized Image {name} {wscale} {hscale}")
        for shape in jsoncontent["shapes"]:
            for point in shape["points"]:
                point[0] *= wscale
                point[1] *= hscale

        if self.debug:
            print("Name: ", name)
        return newImage, jsoncontent

    def resize(self, size, saveImageData=False):
        """
        Resize every image (and its annotations) into ``dst_dir``.

        If cropping is needed before resizing, use cropResize instead.

        Arguments:
            size(tuple): W X H
            saveImageData(bool): embed the resized image as base64 ``imageData``
                in the json (otherwise ``imageData`` is set to None)
        """
        count = 0
        for name in tqdm(self.names):
            jsoncontent = self.readJsonFile(f"{name}.json")
            newImage, jsoncontent = self._resize(name, size, jsoncontent)
            self._saveResult(name, newImage, jsoncontent, saveImageData=saveImageData)
            count += 1
            if self.debug:
                break

        print(f"Done for {count}")

    def _rotate(self, name, jsoncontent, degree, center, boarderValue, scale=1.0):
        """
        Rotate one image and its annotation points.

        Arguments:
            name(str): image name (file stem, without suffix)
            jsoncontent(dict | None): labelme json content, points modified in place
            degree(float): rotation angle in degrees, as for cv.getRotationMatrix2D
            center(tuple | None): (X, Y) rotation center; image center if None
            boarderValue: fill value for pixels outside the source image
            scale(float): isotropic scale applied together with the rotation

        Return:
            PIL.Image.Image: the rotated image (same size as the input)

        See also torchutils.rotateImage
        """
        image = Image.open(osp.join(self.root, f"{name}{self.image_suffix}"))
        iw, ih = image.size
        if center is None:
            center = (iw // 2, ih // 2)

        M = cv.getRotationMatrix2D(center, degree, scale)
        rotated = cv.warpAffine(np.asarray(image), M, (iw, ih), borderValue=boarderValue)

        if jsoncontent is not None:
            for shape in jsoncontent["shapes"]:
                newPoints = []
                for point in shape["points"]:
                    # apply the same 2x3 affine matrix used for the pixels
                    x, y = M[:, :2] @ np.asarray(point, dtype=float) + M[:, 2]
                    # points rotated out of the canvas are pulled back onto its border
                    newPoints.append(self._clampPoint(x, y, iw, ih))
                shape["points"] = newPoints
        return Image.fromarray(rotated)

    def rotate(self, degs, center=None, boarderValue = (0, 0, 0), scale=1.0, **kwargs):
        """
        Write one rotated copy of every annotated image per angle in ``degs``.

        Output files are named ``<name>-r<deg>-s<scale>-<index:04>``.

        Arguments:
            degs: list or iterable (float)
            center(tuple | None): if all the images have the same size, it can be applied when rotating;
                if image sizes differ from each other, leave it None to let the program choose the center
                of each image
            boarderValue: fill value for uncovered pixels
            scale(float): isotropic scale applied with the rotation
        """
        count = 0
        if kwargs.get("start"):
            print("[Warning] api changed, using `degs` instead")
        # materialize once: a generator would be exhausted after the first image
        degs = list(degs)
        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            for rc, deg in enumerate(degs):
                jsoncontent = copy.deepcopy(orijsoncontent)
                newImage = self._rotate(name, jsoncontent, deg, center, boarderValue, scale=scale)
                newName = f"{name}-r{deg}-s{scale}-{rc:04}"
                self._saveResult(newName, newImage, jsoncontent)
                count += 1

                if self.debug:
                    break
            if self.debug:
                break

        print(f"Done Rotation for {count}")

    def _shift(self, name, jsoncontent, xy, boarderValue):
        """
        Translate one image by ``xy`` pixels and move its annotation points along.

        Arguments:
            name(str): image name (file stem, without suffix)
            jsoncontent(dict | None): labelme json content, points modified in place
            xy(tuple): (dx, dy) translation in image coordinates (y axis points down)
            boarderValue: currently unused; uncovered pixels get PIL's default fill

        Return:
            PIL.Image.Image: the translated image
        """
        image = Image.open(osp.join(self.root, f"{name}{self.image_suffix}"))
        iw, ih = image.size
        # rotate(0, translate=...) is a pure translation
        shifted = image.rotate(0, translate=xy)

        if jsoncontent is not None:
            for shape in jsoncontent["shapes"]:
                shape["points"] = [
                    self._clampPoint(point[0] + xy[0], point[1] + xy[1], iw, ih)
                    for point in shape["points"]
                ]
        return shifted

    def shift(self, maxpixel=10, step = 10, direction = (1, 1), boarderValue = (0, 0, 0)) :
        """
        Write translated copies of every annotated image.

        The shift amounts are maxpixel, maxpixel - step, ... (> 0), e.g. 10 and 5
        if maxpixel is 10 and step is 5. ``direction`` is multiplied by each
        amount; in image coordinates (y axis pointing down) (1, 1) moves the
        content towards the bottom-right corner.

        Output files are named ``<name>-t<dx>-<dy>-<pixel>-<index:04>``.
        """
        count = 0
        assert direction[0] != 0 or direction[1] != 0
        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            # range() with a negative step already stops before reaching 0
            for rc, pixel in enumerate(range(maxpixel, 0, -step)):
                xy = (direction[0] * pixel, direction[1] * pixel)
                jsoncontent = copy.deepcopy(orijsoncontent)
                newImage = self._shift(name, jsoncontent, xy, boarderValue)
                newName = f"{name}-t{direction[0]}-{direction[1]}-{pixel}-{rc:04}"
                self._saveResult(newName, newImage, jsoncontent)
                count += 1

                if self.debug:
                    break
            if self.debug:
                break

        print(f"Done Shift for {count}")

    def cropResize(self, leftTop, cropSize, newSize = None):
        """
        Arguments:
            leftTop(tuple): (X, Y)

            cropSize(tuple): W X H, cropped size.
                (cropSize + leftTop) should be less than original_size at W & H,
                if not, the crop stops at the original image's right-bottom corner

            newSize(tuple): W X H, rescaled size after cropping; if None the
                crop is not resized
        """
        raise NotImplementedError("Not Impletmented Yet")

    def copy(self):
        """Copy every image and, when present, its json file from ``root`` to ``dst_dir`` unchanged."""
        for name in tqdm(self.names):
            imname = f"{name}{self.image_suffix}"
            sh.copy(osp.join(self.root, imname), osp.join(self.dst_dir, imname))
            jsonname = f"{name}.json"
            jsonloc = osp.join(self.root, jsonname)
            # unannotated images have no json to copy
            if osp.exists(jsonloc):
                sh.copy(jsonloc, osp.join(self.dst_dir, jsonname))

    def output(self, destdir):
        """
        Placeholder, not implemented: currently does nothing.

        Arguments:
            destdir(str): dest directory
        """
        pass

    def polyfitSlope(self, labelType=None):
        """
        Remove every shape labelled ``labelType`` from the json files in ``root``,
        rewriting them in place.

        TODO: the original draft also fitted a line (``np.polyfit``, degree 1)
        through 5 contour points of the ``SphenoidBone`` shape starting after its
        lowest point, and appended it as a thin polygon labelled ``labelType``.
        That step was commented out and is not implemented here.

        Arguments:
            labelType(str): label of the shapes to remove

        Raises:
            ValueError: if ``labelType`` is not given or a json file cannot be parsed
        """
        if labelType is None:
            raise ValueError("labelType must be given")
        for jsonfile in tqdm(self.jsonfiles):
            content = self.readJsonFile(jsonfile)
            if content is None:
                continue
            content["shapes"] = [shape for shape in content["shapes"] if shape["label"] != labelType]
            try:
                # serialize before opening so a failure cannot truncate the file
                data = json.dumps(content, indent=True)
                with open(osp.join(self.root, jsonfile), "w") as f:
                    f.write(data)
            except (OSError, TypeError, ValueError) as e:
                print(f"Failed to write {jsonfile}: {e}")

class LabelmeLabelModifier(LabelmeModifier):
    """Rename labelme labels, or limit how many shapes of a label are kept per image.
    """
    def removeLabels(self, labels):
        """
        Remove every shape whose label is in ``labels`` (if it exists).

        Arguments:
            labels(list(str))
        """
        assert type(labels) is list
        # keeping 0 shapes of a label removes it
        self._renameShapes({label: {"num": 0} for label in labels})

    @staticmethod
    def _endsWithIndex(label):
        """Return True if the text after the last "-" of ``label`` (or the whole label) parses as an int."""
        try:
            int(label[label.rfind("-") + 1:])
        except ValueError:
            return False
        return True

    def _renameShapes(self, labelShapes):
        """Rename / filter the shapes of every json file and write them to ``dst_dir``.

        Shapes whose label is not a key of ``labelShapes`` are kept unchanged.
        The resulting shapes are sorted by label. The image is copied too when
        ``root`` differs from ``dst_dir``.

        Arguments:
            labelShapes(dict):
            ```
            {
                "old_label_name": {
                    "label": "new_label_name", // omitted or empty: keep the old name
                    "num": 1000, // max shapes of this label kept per image, default 1000
                    "shape_idx": False, // append "-<index>" to the new name, default False
                }
            }
            ```
        """
        count = 0

        copyImage = (self.root != self.dst_dir)

        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            new_shapes = []
            # shapes kept so far in this image, counted separately for each old label
            reserve_counts = {}
            for shape in orijsoncontent["shapes"]:
                oldName = shape["label"]
                if oldName not in labelShapes:
                    # labels that are not listed are kept unchanged
                    new_shapes.append(shape)
                    continue
                prop = labelShapes[oldName]
                new_label_name = prop.get("label") or oldName
                num_reserved = prop.get("num", 1000)
                reserved = reserve_counts.get(oldName, 0)
                if reserved >= num_reserved:
                    continue
                nshape = copy.deepcopy(shape)
                # a name that already ends with "-<int>" is not numbered again
                if prop.get("shape_idx", False) and not self._endsWithIndex(new_label_name):
                    nshape["label"] = new_label_name + f"-{reserved}"
                else:
                    nshape["label"] = new_label_name
                new_shapes.append(nshape)
                reserve_counts[oldName] = reserved + 1
            orijsoncontent["shapes"] = sorted(new_shapes, key=lambda x: x['label'])
            with open(osp.join(self.dst_dir, f"{name}.json"), "w") as f:
                f.write(json.dumps(orijsoncontent, indent=True))
            if copyImage:
                sh.copy(osp.join(self.root, f"{name}{self.image_suffix}"), self.dst_dir)
            count += 1
        print(f"Done rename shapes: {count}")

    def renameShapes(self, shapename, new_shapename=None, num_reserved=None, shape_idxes=None):
        """
        Arguments:
            shapename (list(str)) | dict: old labels, or a full mapping (see _renameShapes)
            new_shapename (list(str)): new labels, defaults to the old ones
            num_reserved  list(Number): max number of instances kept per image, default 1000
            shape_idxes list(Boolean): whether to append indexes for different instances in one picture

        Example:

        use one dict:

        ```py
        lm = LabelmeLabelModifier(src_dir, dst_dir, image_suffix=".png")
        labels = {
            "Person": {
                "label": "Human",
                "num": 100,
                "shape_idx": True,
            }
        }
        lm.renameShapes(labels)
        ```

        or use multiple lists instead:

        ```py
        lm = LabelmeLabelModifier(src_dir, dst_dir, image_suffix=".png")
        ori_label = ["Person"]
        new_label = ["Human"]
        nums = [100]
        lm.renameShapes(ori_label, new_label, nums)
        ```

        """
        if isinstance(shapename, str):
            raise ValueError("String is not supported anymore. Please Use List instead.")
        if isinstance(shapename, dict):
            self._renameShapes(shapename)
            return
        assert isinstance(shapename, list), " shapename must be either List or Dict"

        if new_shapename is None:
            new_shapename = shapename
        slen = len(shapename)
        if shape_idxes is None:
            shape_idxes = [False] * slen
        if num_reserved is None:
            num_reserved = [1000] * slen
        assert len(new_shapename) == len(shapename) == len(shape_idxes) == len(num_reserved)
        shapes = {
            name: {"label": new_name, "num": num, "shape_idx": idx}
            for name, new_name, num, idx in zip(shapename, new_shapename, num_reserved, shape_idxes)
        }
        self._renameShapes(shapes)
        


def moveImages(src, dst):
    """Move every non-json file under ``src`` that has no ``<stem>.json`` next to it into ``dst``.

    ``src`` is walked recursively and images are matched to json files in the
    same directory. All moved files land directly in ``dst``, so two files
    with the same name in different sub-directories may collide; the second
    move may then fail.

    Arguments:
        src(str | os.PathLike): directory to scan
        dst(str | os.PathLike): target directory (should already exist)
    """
    labelled = set()  # (directory, stem) pairs that have a json file
    candidates = []   # (directory, filename, stem) of every other file
    for root, _dirs, files in os.walk(src):
        for file in files:
            if file.endswith(".json"):
                labelled.add((root, file[: -len(".json")]))
            else:
                candidates.append((root, file, osp.splitext(file)[0]))

    dst = str(dst)
    for root, file, stem in candidates:
        if (root, stem) in labelled:
            continue
        sh.move(osp.join(root, file), dst)


import argparse
import glob
import os
import os.path as osp
import sys

import imgviz
import numpy as np

import labelme

class Labelme2Vocor():
    """Convert a directory of labelme json files into a VOC-style semantic segmentation dataset.

    Layout created under ``output_dir``: ``JPEGImages/``, ``SegmentationClass/``,
    ``SegmentationClassPNG/`` and, unless ``noviz``, ``SegmentationClassVisualization/``.
    """

    def __init__(self, input_dir, output_dir, labels, noviz=False, debug=False):
        """
        Arguments:
            input_dir(str): directory containing the labelme ``*.json`` files
            output_dir(str): dataset directory; **deleted and recreated if it exists**
            labels(str): labels text file, one class per line; the first
                non-empty line must be ``__ignore__`` and the second ``_background_``
            noviz(bool): do not write visualization images
            debug(bool): print the shapes of the first json file and stop
        """
        if osp.exists(output_dir):
            print("Output directory already exists: ", output_dir, " Deleting...")
            import shutil as sh
            sh.rmtree(output_dir)
        os.makedirs(output_dir)
        os.makedirs(osp.join(output_dir, "JPEGImages"))
        os.makedirs(osp.join(output_dir, "SegmentationClass"))
        os.makedirs(osp.join(output_dir, "SegmentationClassPNG"))

        if not noviz:
            os.makedirs(osp.join(output_dir, "SegmentationClassVisualization"))
        print("Creating dataset:", output_dir)

        self.input_dir = input_dir
        self.output_dir = output_dir
        self.labels = labels
        self.noviz = noviz
        self.debug = debug
        # filled by getClasses()
        self.class_names = None
        self.class_name_to_id = None

    def getClasses(self):
        """Read the labels file, set ``class_names`` / ``class_name_to_id`` and write ``class_names.txt``.

        The id of a class is its line index minus one, so ``__ignore__`` is -1
        and ``_background_`` is 0; blank lines are skipped but still count.
        """
        class_names = []
        class_name_to_id = {}
        with open(self.labels) as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            class_id = i - 1  # starts with -1
            class_name = line.strip()
            if len(class_name) == 0:
                continue
            class_name_to_id[class_name] = class_id
            if class_id == -1:
                assert class_name == "__ignore__"
                continue
            elif class_id == 0:
                assert class_name == "_background_"
            class_names.append(class_name)

        class_names = tuple(class_names)
        self.class_names = class_names
        self.class_name_to_id = class_name_to_id
        print("class_names:", class_names)
        out_class_names_file = osp.join(self.output_dir, "class_names.txt")

        with open(out_class_names_file, "w") as f:
            f.write("\n".join(class_names))
        print("Saved class_names:", out_class_names_file)

    def output(self, ordered_keys=None):
        """
        Convert every ``*.json`` of ``input_dir`` (calls getClasses first if needed).

        Arguments:
            ordered_keys(list(str) | None): when given, only shapes with these
                labels are kept, in this order
        """
        if self.class_name_to_id is None:
            self.getClasses()
        for filename in tqdm(sorted(glob.glob(osp.join(self.input_dir, "*.json")))):
            label_file = labelme.LabelFile(filename=filename)

            base = osp.splitext(osp.basename(filename))[0]
            # NOTE(review): the embedded image bytes are written unchanged, so a
            # PNG-encoded imageData ends up with a .jpg name — confirm intended
            out_img_file = osp.join(self.output_dir, "JPEGImages", base + ".jpg")
            out_png_file = osp.join(
                self.output_dir, "SegmentationClassPNG", base + ".png"
            )
            if not self.noviz:
                out_viz_file = osp.join(
                    self.output_dir,
                    "SegmentationClassVisualization",
                    base + ".jpg",
                )

            with open(out_img_file, "wb") as f:
                f.write(label_file.imageData)
            img = labelme.utils.img_data_to_arr(label_file.imageData)
            if ordered_keys is not None:
                # presumably so that shapes_to_label paints later labels over
                # earlier ones — confirm
                label_file.shapes = [
                    shape
                    for key in ordered_keys
                    for shape in label_file.shapes
                    if shape["label"] == key
                ]

            if self.debug:
                print(label_file.shapes)
                break

            lbl, _ = labelme.utils.shapes_to_label(
                img_shape=img.shape,
                shapes=label_file.shapes,
                label_name_to_value=self.class_name_to_id,
            )
            labelme.utils.lblsave(out_png_file, lbl)

            if not self.noviz:
                # NOTE(review): this compares the image height, not the channel
                # count; 2-D gray images never enter this branch — confirm intent
                if img.shape[0] == 1: # gray img
                    img = imgviz.rgb2gray(img)
                viz = imgviz.label2rgb(
                    label=lbl,
                    img=img,
                    font_size=15,
                    label_names=self.class_names,
                    loc="rb",
                )
                imgviz.io.imsave(out_viz_file, viz)



import collections
import datetime
import glob
import uuid

import numpy as np
import labelme
import re
# Matches "<parent>-<child>" label names; group 1 is the parent object name
# under which keypoint labels are grouped (see Labelme2Cocoer.classNameToId).
# Raw string: "\w" / "\d" are invalid escapes in a normal string literal.
SubCategoryPatter = re.compile(r"(\w+)-(\w+|\d+)")

try:
    import pycocotools.mask as cocomask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)

class Labelme2Cocoer():
    """

    """
    def __init__(self, output_dir, cocomask=None, debug=False):
        """
        Create the output directories and empty COCO containers.

        Arguments:
            output_dir(str): annotations.json and the images will be stored here;
                it and its ``JPEGImages`` sub-directory are created if missing

            cocomask(pycocotools.mask):
                @deprecated, ignored

            debug(bool): print intermediate results
        """
        self.output_dir = output_dir
        self.debug = debug
        for directory in (output_dir, osp.join(output_dir, "JPEGImages")):
            os.makedirs(directory, exist_ok=True)
        created = datetime.datetime.now()

        # COCO sections that will be written into annotations.json
        self._data_info = {
            "description": None,
            "url": None,
            "version": None,
            "year": created.year,
            "contributor": None,
            "date_created": created.strftime("%Y-%m-%d %H:%M:%S.%f"),
        }
        self._data_licenses = [{"url": None, "id": 0, "name": None}]
        # entries: license, url, file_name, height, width, date_captured, id
        self._data_images = []
        # entries: segmentation, area, iscrowd, image_id, bbox, category_id, id
        self._data_annotations = []
        # entries: supercategory, id, name (set by setClassId / classNameToId)
        self._data_categories = None

    def classNameToId(self, labels_file, skeletondict=None):
        """
        @deprecated see setClassId

        Read the labels file and map class names to ids. Keypoint names must
        be prefixed with their object (mask) name and separated by "-", e.g.
        `xxx-1` or `xxx-2` (easy to order), and the object must be declared
        on an earlier line.

        The first non-empty line must be ``__ignore__``; objects then get ids
        0, 1, 2, ... in file order and each keypoint shares its object's id.

        Argument:
            labels_file (str): labels file loc
            skeletondict (list | None): skeleton stored on every keypoint category

        Raises:
            KeyError: if a keypoint appears before its object
        """
        class_id = -1  # next object id; -1 until "__ignore__" has been read
        # class name (object or keypoint) -> id of the object it belongs to
        class_name_to_id = {}
        categoriedict = {}
        with open(labels_file, "r") as f:
            for line in f:
                class_name = line.strip()
                if len(class_name) == 0:
                    continue

                if class_id == -1:
                    assert class_name == "__ignore__"
                    class_id += 1
                    continue

                matched = SubCategoryPatter.match(class_name)
                if matched is None:
                    # an object category
                    class_name_to_id[class_name] = class_id
                    categoriedict[class_id] = {
                        "name": class_name,
                        "supercategory": None,
                        "type": "object",
                        "keypoints": [],
                    }
                    class_id += 1
                else:
                    # a keypoint "<object>-<kp>": reuse the object's id without
                    # touching the running counter of the next object id
                    parent = matched[1]
                    if parent not in class_name_to_id:
                        raise KeyError(f"keypoint {class_name!r} listed before its object {parent!r}")
                    parent_id = class_name_to_id[parent]
                    class_name_to_id[class_name] = parent_id

                    cat = categoriedict[parent_id]
                    cat["type"] = "keypoint"
                    cat["keypoints"].append(class_name)
        if self.debug:
            print("Class to id", class_name_to_id, "============")
        self.class_name_to_id = class_name_to_id

        if self.debug:
            print("Categories Dict", categoriedict)

        categories = []
        skeleton = [] if skeletondict is None else skeletondict
        for cid, cat in categoriedict.items():
            if cat["type"] == "object":
                categories.append({
                    "supercategory": cat["supercategory"],
                    "id": cid,
                    "name": cat["name"],
                })
            elif cat["type"] == "keypoint":
                categories.append({
                    "supercategory": cat["supercategory"],
                    "id": cid,
                    "name": cat["name"],
                    "keypoints": cat["keypoints"],
                    "skeleton": skeleton
                })

        self._data_categories = sorted(categories, key=lambda x: x["id"])

        if self.debug:
            print("cats", self._data_categories, "============\n\n")

    def setClassId(self, class_dict):
        """Build the COCO categories and the label -> id mapping from ``class_dict``.

        Id 0 is the background (its category is emitted only when
        ``background`` is true); instances get ids 1, 2, ... in list order.
        ``class_dict`` is not modified; an internal copy gains a ``"keymap"``
        entry mapping each keypoint label to its instance name.

        Args:
            class_dict(dict):
            {
                "background": True,                      // default
                "instances": ["L1_1", "L1_2", "L1_3"],
                // L1_1 and L1_2 have keypoints, L1_3 has none
                "keypoints": {
                    "L1_1": ["L2_1", "L2_2"],
                    "L1_2": ["L2_3"]
                },
                "skeleton": {
                    // optional, per instance; copied as-is into the category
                }
            }
        """
        background = class_dict.get("background", True)
        categories = []
        if background:
            categories.append({
                "supercategory": None,
                "id": 0,
                "name": "_background_"
            })
        class_name_to_id = {
            "_background_": 0,
        }

        instances = class_dict["instances"]
        keypoints = class_dict.get("keypoints", {})
        skeleton = class_dict.get("skeleton", {})

        for cid, ins_name in enumerate(instances, start=1):
            categories.append({
                "supercategory": None,
                "id": cid,
                "name": ins_name,
                "keypoints": keypoints.get(ins_name, []),
                "skeleton": skeleton.get(ins_name, [])
            })
            class_name_to_id[ins_name] = cid
        self._data_categories = categories
        self.class_name_to_id = class_name_to_id

        # keypoint label -> name of the instance class that owns it
        keymap = {kp: obj for obj, kps in keypoints.items() for kp in kps}
        # store a copy so the caller's dict is not mutated
        self._class_dict = dict(class_dict, keymap=keymap)

    def _kps_from_categories(self, cls_name):
        """Return the ``keypoints`` list of the first category named ``cls_name``, or None if absent."""
        return next(
            (category["keypoints"] for category in self._data_categories if category["name"] == cls_name),
            None,
        )

    def _extractAnno(self, label_file, imshape):
        """
        Split the shapes of one labelme file into instance masks, polygons and keypoints.

        Keypoints can only be attributed per object class on one image; if
        several instances of a class appear on one image, set group_id in
        labelme to tell them apart. Shapes without a group_id each become a
        separate instance.

        Returns
            masks: dict keyed by (label, group_id); value is the instance mask
                (the union of all its shapes' masks)

            segmentations: defaultdict(list) with the same keys; value is a list
                of flattened polygons [x1, y1, x2, y2, ...]

            keypoints: dict keyed by object class name; value is a list of
                (keypoint label, point)
        """
        masks = {}
        segmentations = collections.defaultdict(list)
        keypoints = {}

        for shape in label_file.shapes:
            shape_type = shape.get("shape_type", "polygon")
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")

            if shape_type == "point":
                # keypoints are grouped under the object class that owns them
                owner = self._class_dict["keymap"][label]
                keypoints.setdefault(owner, []).append((label, points[0]))
                continue

            mask = labelme.utils.shape_to_mask(imshape, points, shape_type)
            # a fresh uuid keeps every ungrouped shape a distinct instance
            instance = (label, uuid.uuid1() if group_id is None else group_id)

            # shapes of the same instance are merged: their masks are OR-ed
            previous = masks.get(instance)
            masks[instance] = mask if previous is None else previous | mask

            if shape_type == "rectangle":
                (xa, ya), (xb, yb) = points
                left, right = sorted((xa, xb))
                top, bottom = sorted((ya, yb))
                # corners in clockwise order (image y axis points down)
                polygon = [left, top, right, top, right, bottom, left, bottom]
            else:
                # flattened as a single polygon; multi-part shapes are not split
                polygon = np.asarray(points).flatten().tolist()

            segmentations[instance].append(polygon)
        return masks, segmentations, keypoints

    def _convert_kp(self, cls_name, keypoints, imshape=None):
        kplist = []
        kpnum = 0
        i = 0
        keypoint_in_category = self._kps_from_categories(cls_name)
        boundary = [10000, 100000, 0, 0] # Top Left  Bottom Right
        if cls_name in keypoints:
            for kp in keypoints[cls_name]:
                # kp is the 2-element-tuple list
                name = kp[0]
                if name not in keypoint_in_category:
                    raise ValueError("KeyPoint and class not match")
                assert i == keypoint_in_category.index(name)
                i += 1
                point = kp[1]
                x = point[0]
                y = point[1]
                kplist.append(x) # x
                kplist.append(y) # y
                kplist.append(2)
                kpnum += 1

                if x < boundary[1]: boundary[1] = x
                if x > boundary[3]: boundary[3] = x
                if y < boundary[0]: boundary[0] = y
                if y > boundary[2]: boundary[2] = y
        if imshape is not None:
            boundary[0] -= 50; boundary[1] -= 50
            boundary[2] += 50; boundary[3] += 50
            if boundary[0] < 0: boundary[0] = 0
            if boundary[1] < 0: boundary[1] = 0
            if boundary[2] > imshape[0]: boundary[2] = imshape[0]
            if boundary[3] > imshape[1]: boundary[3] = imshape[1]
        return kplist, kpnum, boundary
           
    def _getAnno(self, label_file, img, image_id):
        """
        Extract all annotations from single labeled json file.

        Arguments:
            label_file: label_file object
            img: img_arr
            image_id (int): 
        """
        # masks is for area
        masks, segmentations, keypoints = self._extractAnno(label_file, img.shape[:2])
        segmentations = dict(segmentations)
        # sort keypoints to get ordered value
        for _, kp in keypoints.items():
            kp.sort(key=lambda x: x[0])

        if self.debug:
            print("Mask: ", len(masks), masks.keys())
            print("Segm", len(segmentations))
            # for segclass, items in keypoints.items():
            #     print("Keypoints: ", segclass, items)
            self.masks = masks

        annotations = self._data_annotations
        if len(masks) > 0:
            for instance, mask in masks.items():
                cls_name, group_id = instance
                if cls_name not in self.class_name_to_id:
                    continue
                cls_id = self.class_name_to_id[cls_name]

                mask = np.asfortranarray(mask.astype(np.uint8))
                mask = cocomask.encode(mask)
                bbox = cocomask.toBbox(mask).flatten().tolist()
                # it calculates mask area, not box area that torch needs
                area = float(cocomask.area(mask)) 
                # following is the area that torch needs
                # area = float(bbox[2] * bbox[3])
                # bbox = [bbox[0], bbox[1], bbox[0], bbox[1]] # format for torch
                kplist, kpnum, _ = self._convert_kp(cls_name, keypoints)
                anno = dict(
                    id=len(annotations),
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area, 
                    bbox=bbox, 
                    iscrowd=0,
                    keypoints = kplist, # [[kp[0], kp[1], 1] for kp in kps ], # format for torch
                    num_keypoints = kpnum,
                )
                annotations.append(anno)
        else:
            # 这种情况对应只有关键点，没有 mask 的情况
            assert len(keypoints) == 1, "Unable to extract annotation from this file"
            # print("No mask found\n\n")
            # if cls_name not in keypoints:
                # print("Error in class name and keypoint list")
                
            cls_name = list(keypoints.keys())[0] # 找到 keypoint 对应的类名，这里默认只有一个类
            cls_id = self.class_name_to_id[cls_name]
   
            kplist = []
            kpnum = 0
            # print(kps, "========")
            kplist, kpnum, boundary = self._convert_kp(cls_name=cls_name, keypoints=keypoints, imshape=img.shape)
            bbox = [boundary[1], boundary[0], (boundary[3] - boundary[1]), (boundary[2] - boundary[0])]
            area = (boundary[2] - boundary[0]) * (boundary[3] - boundary[1])
            segmentation = [ 
                boundary[1], boundary[0], 
                boundary[1], boundary[2], 
                boundary[3], boundary[2], 
                boundary[3], boundary[0], 
            ]
            anno = dict(
                id=len(annotations),
                image_id=image_id,
                category_id=cls_id,
                segmentation=[segmentation],
                area=area, 
                bbox=bbox, 
                iscrowd=0,
                keypoints = kplist, # [[kp[0], kp[1], 1] for kp in kps ], # format for torch
                num_keypoints = kpnum,
            )
            annotations.append(anno)
        if self.debug:
            print(annotations)

    def generateCocoJson(self, label_files):
        """
        Build the COCO structures from labelme json files.

        For every file, the decoded image is saved as
        ``<output_dir>/JPEGImages/<basename>.jpg``.  An entry is appended to
        ``self._data_images``, and the file's annotations are appended via
        ``_getAnno``.  Files that fail to load as ``labelme.LabelFile`` are
        reported and skipped.  In debug mode, processing stops after the first
        file that loads.

        Args:
            label_files: iterable of labelme json filenames.
        """
        # image paths are stored relative to the annotation file's directory
        out_ann_file = osp.join(self.output_dir, "annotations.json")
        for image_id, filename in tqdm(enumerate(label_files)):
            try:
                label_file = labelme.LabelFile(filename=filename)
            except Exception as err:
                # Best effort: skip unreadable files, but unlike a bare
                # ``except:`` let KeyboardInterrupt / SystemExit propagate.
                print("Error generating dataset from:", filename, err)
                continue

            base = osp.splitext(osp.basename(filename))[0]
            out_img_file = osp.join(self.output_dir, "JPEGImages", base + ".jpg")

            img = labelme.utils.img_data_to_arr(label_file.imageData)
            Image.fromarray(img).convert("RGB").save(out_img_file)

            self._data_images.append(
                dict(
                    license=0,
                    url=None,
                    file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                    height=img.shape[0],
                    width=img.shape[1],
                    date_captured=None,
                    id=image_id,
                )
            )

            self._getAnno(label_file, img, image_id)

            if self.debug:
                break


    def output(self, out_ann_file=None, dtype="instances"):
        """Write the collected COCO dataset as JSON under ``self.output_dir``.

        Arguments:
            out_ann_file (str | None): file name joined onto
                ``self.output_dir``; ``"annotations.json"`` when omitted.
            dtype (str): value stored under the top-level ``"type"`` key.
        """
        target = osp.join(
            self.output_dir,
            "annotations.json" if out_ann_file is None else out_ann_file,
        )
        with open(target, "w") as handle:
            payload = {
                "type": dtype,
                "annotations": self._data_annotations,
                "categories": self._data_categories,
                "images": self._data_images,
                "info": self._data_info,
                "licenses": self._data_licenses,
            }
            # indent=1 is exactly what indent=True meant (True == 1): one space.
            json.dump(payload, handle, indent=1)
