import json
import cv2 as cv
import os
import os.path as osp
from tqdm import tqdm
from PIL import Image
import base64
import numpy as np
from labelme.label_file import LabelFile, PY2, QT4, utils, io
import shutil as sh

# Version string of this module.
__version__ = "0.1.0"

def encodeImage(imarr, ext=".jpg"):
    """Encode an image as a base64 string.

    The image is re-encoded in memory (JPEG or PNG, chosen from ``ext``)
    after its EXIF orientation has been applied.

    @See Also LabelFile.load_image_file

    Arguments:
        imarr(np.ndarray | PIL.Image | str | os.PathLike): the image itself,
            or the location of an image file to open
        ext(str): extension deciding the encoding; ".jpg" / ".jpeg"
            (case-insensitive) select JPEG, anything else selects PNG

    Return:
        (str) base64 encoded image

    Raises:
        ValueError: if ``imarr`` is of an unsupported type
    """
    if isinstance(imarr, Image.Image):
        image = imarr
    elif isinstance(imarr, np.ndarray):
        image = Image.fromarray(imarr)
    elif isinstance(imarr, (str, os.PathLike)):
        # Only `Image` is imported from PIL; the bare name `PIL` is not in
        # scope, so the file has to be opened through `Image.open`.
        image = Image.open(imarr)
    else:
        raise ValueError("Unsupported Type: ", type(imarr))
    # apply orientation to image according to exif
    image_pil = utils.apply_exif_orientation(image)

    if PY2 and QT4:
        fmt = "PNG"
    elif ext.lower() in (".jpg", ".jpeg"):
        fmt = "JPEG"
    else:
        fmt = "PNG"

    with io.BytesIO() as f:
        image_pil.save(f, format=fmt)
        imageData = f.getvalue()
    return base64.b64encode(imageData).decode("utf-8")


def calcY(line, x):
    """Evaluate a straight line at ``x``.

    ``line`` holds the slope at index 0 and the intercept at index 1;
    the result is ``slope * x + intercept``.
    """
    slope = line[0]
    intercept = line[1]
    return slope * x + intercept

import copy
from PIL import ImageEnhance

class LabelmeModifier(object):
    """Resize, rotate, shift or copy images that were annotated with Labelme.

    Images are discovered in ``root`` by file suffix; the annotation of
    ``<name><image_suffix>`` is looked up as ``<name>.json`` in the same
    directory. Results are written to ``dst_dir``.

    Arguments:
        root(str): root directory that contains the raw image and json file
            generated by `labelme`
        dst_dir(str): output directory, created when it does not exist
        image_suffix(str): suffix (including the dot) identifying image files
        debug(bool): when true, the batch methods stop after the first
            processed item and print extra information
        **kwargs: accepted and ignored
    """

    def __init__(self, root, dst_dir, image_suffix=".jpg", debug=False, **kwargs):
        self.root = root
        self.dst_dir = dst_dir
        self.debug = debug
        self.image_suffix = image_suffix
        files = sorted(os.listdir(root))
        # Match on the real suffix: a substring test would also pick up files
        # such as "a.jpg.bak" and derive a stem whose image does not exist.
        self.images = [file for file in files if file.endswith(image_suffix)]
        jsonfiles = [file for file in files if file.endswith(".json")]
        self.names = [img[: len(img) - len(image_suffix)] for img in self.images]

        print(f"Found files at {root} images ({image_suffix}): {len(self.images)}, jsons: {len(jsonfiles)}")

        os.makedirs(self.dst_dir, exist_ok=True)

    def readJsonFile(self, file):
        """Load ``file`` (a name relative to ``root``) as JSON.

        Return:
            the parsed content, or None when the file does not exist

        Raises:
            ValueError: when the file cannot be read or parsed
        """
        loc = osp.join(self.root, file)
        if not osp.exists(loc):
            return None
        try:
            with open(loc, "r") as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # Keep the original cause attached instead of only printing it.
            raise ValueError(f"Exception for {file}: {e}") from e

    @staticmethod
    def _clipPoints(points, width, height):
        """Clamp ``[x, y]`` points into ``[0, width] x [0, height]``.

        Return:
            list of ``[x, y]`` lists holding Python floats
        """
        pts = np.array(points, dtype=float).reshape(-1, 2)
        pts[:, 0] = np.clip(pts[:, 0], 0, width)
        pts[:, 1] = np.clip(pts[:, 1], 0, height)
        return pts.tolist()

    def _writeOutputs(self, newName, image, jsoncontent, saveImageData=True):
        """Save ``image`` and, when given, its annotation into ``dst_dir``.

        Arguments:
            newName(str): output file stem
            image(PIL.Image): image to save as ``<newName><image_suffix>``
            jsoncontent(dict | None): annotation to save as ``<newName>.json``;
                nothing is written for None
            saveImageData(bool): embed the saved image as base64 ``imageData``
                (otherwise ``imageData`` is set to None)
        """
        imageFile = f"{newName}{self.image_suffix}"
        savedLoc = osp.join(self.dst_dir, imageFile)
        image.save(savedLoc)
        if jsoncontent is None:
            return
        # The annotation is stored next to the image it describes, so it must
        # reference that file rather than the source image it was derived from.
        jsoncontent["imagePath"] = imageFile
        if saveImageData:
            imageData = LabelFile.load_image_file(savedLoc)
            jsoncontent["imageData"] = base64.b64encode(imageData).decode("utf-8")
        else:
            jsoncontent["imageData"] = None
        with open(osp.join(self.dst_dir, f"{newName}.json"), "w") as f:
            f.write(json.dumps(jsoncontent, indent=True))

    def _resize(self, name, size, jsoncontent):
        """Resize one image and scale its annotation points accordingly.

        Note: **imageData in jsoncontent is not modified, it should be modified by the last step.**

        Arguments:
            name(str): filename stem without path location
            size(tuple): new size (W, H) of the resized image
            jsoncontent(dict | None): corresponding labelme json file content,
                modified in place

        Return:
            tuple: (newImage, jsoncontent); jsoncontent is None when no
            annotation was given
        """
        with Image.open(osp.join(self.root, f"{name}{self.image_suffix}")) as image:
            iw, ih = image.size
            newImage = image.resize(size)
        wscale = size[0] / iw
        hscale = size[1] / ih

        if jsoncontent is None:
            if self.debug:
                print(f"Not Found json file for {name}")
            return newImage, None

        jsoncontent["imageWidth"] = size[0]
        jsoncontent["imageHeight"] = size[1]
        if self.debug:
            newImage.show(f"Resized Image {name} {wscale} {hscale}")
        for shape in jsoncontent["shapes"]:
            for point in shape["points"]:
                point[0] *= wscale
                point[1] *= hscale

        if self.debug:
            print("Name: ", name)
        return newImage, jsoncontent

    def resize(self, size, saveImageData=False):
        """Resize every image to ``size`` and write the results to ``dst_dir``.

        Images without an annotation are resized as well; only the image is
        written for them. If cropping is needed before resizing, use
        ``cropResize`` instead.

        Arguments:
            size(tuple): W X H
            saveImageData(bool): embed the resized image in the written JSON
        """
        count = 0
        for name in tqdm(self.names):
            jsoncontent = self.readJsonFile(f"{name}.json")
            newImage, jsoncontent = self._resize(name, size, jsoncontent)
            self._writeOutputs(name, newImage, jsoncontent, saveImageData=saveImageData)
            count += 1
            if self.debug:
                break

        print(f"Done for {count}")

    def _rotate(self, name, jsoncontent, degree, center, boarderValue, scale=1.0):
        """Rotate one image and its annotation points around ``center``.

        Arguments:
            name(str): image name (stem)
            jsoncontent(dict | None): annotation, its shape points are
                replaced in place by the rotated, clamped points
            degree(float): rotation angle in degrees, passed to
                ``cv.getRotationMatrix2D``
            center(tuple | None): X,Y; the image center when None
            boarderValue: fill value passed to ``cv.warpAffine``
            scale(float): scale factor passed to ``cv.getRotationMatrix2D``

        Return:
            PIL.Image: the rotated image (same size as the input)
        """
        with Image.open(osp.join(self.root, f"{name}{self.image_suffix}")) as image:
            iw, ih = image.size
            imarr = np.array(image)
        if center is None:
            center = (iw // 2, ih // 2)

        M = cv.getRotationMatrix2D(center, degree, scale)
        rotated = Image.fromarray(
            cv.warpAffine(imarr, M, (iw, ih), borderValue=boarderValue)
        )
        if jsoncontent is None:
            return rotated

        for shape in jsoncontent["shapes"]:
            pts = np.array(shape["points"], dtype=float).reshape(-1, 2)
            # Apply the same 2x3 affine matrix used for the pixels: p' = A.p + t
            moved = np.dot(pts, M[:, :2].T) + M[:, 2]
            shape["points"] = self._clipPoints(moved, iw, ih)
        return rotated

    def rotate(self, degs, center=None, boarderValue = (0, 0, 0), scale=1.0, **kwargs):
        """Write one rotated copy per angle for every annotated image.

        Images without a JSON file are skipped.

        Arguments:
            degs: list or iterable (float)
            center(tuple | None): may be given when all images share one size;
                leave it as None to rotate each image around its own center
            boarderValue: fill value for areas uncovered by the rotation
            scale(float): scale factor applied together with the rotation
        """
        count = 0
        if kwargs.get("start"):
            print("[Warning] api changed, using `degs` instead")
        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            for rc, deg in enumerate(degs):
                # Every angle starts from an untouched copy of the annotation.
                jsoncontent = copy.deepcopy(orijsoncontent)
                newImage = self._rotate(name, jsoncontent, deg, center, boarderValue, scale=scale)
                self._writeOutputs(f"{name}-r{deg}-s{scale}-{rc:04}", newImage, jsoncontent)
                count += 1

                if self.debug:
                    break
            if self.debug:
                break

        print(f"Done Rotation for {count}")

    def _shift(self, name, jsoncontent, xy, boarderValue):
        """Translate one image and its annotation points by ``xy``.

        Arguments:
            name(str): image name (stem)
            jsoncontent(dict | None): annotation, its shape points are
                replaced in place by the shifted, clamped points
            xy(tuple): (dx, dy) offset in pixels, added to every point and
                passed as ``translate`` to ``PIL.Image.rotate``
            boarderValue: currently unused

        Return:
            PIL.Image: the shifted image (same size as the input)
        """
        with Image.open(osp.join(self.root, f"{name}{self.image_suffix}")) as image:
            iw, ih = image.size
            # A zero-degree rotation with `translate` is a pure translation.
            shifted = image.rotate(0, translate=xy)

        if jsoncontent is None:
            return shifted

        offset = np.asarray(xy, dtype=float)
        for shape in jsoncontent["shapes"]:
            pts = np.array(shape["points"], dtype=float).reshape(-1, 2)
            shape["points"] = self._clipPoints(pts + offset, iw, ih)
        return shifted

    def shift(self, maxpixel=10, step = 10, direction = (1, 1), boarderValue = (0, 0, 0)) :
        """Write shifted copies of every annotated image.

        The shift distances are ``maxpixel, maxpixel - step, ...`` down to the
        last positive value (10 and 5 for ``maxpixel=10, step=5``). Each
        distance is multiplied by ``direction`` to obtain the (dx, dy) offset
        that is added to the annotation points. Images without a JSON file
        are skipped.
        """
        count = 0
        assert direction[0] != 0 or direction[1] != 0
        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            for rc, pixel in enumerate(range(maxpixel, 0, -step)):
                xy = (direction[0] * pixel, direction[1] * pixel)
                jsoncontent = copy.deepcopy(orijsoncontent)
                newImage = self._shift(name, jsoncontent, xy, boarderValue)
                newName = f"{name}-t{direction[0]}-{direction[1]}-{pixel}-{rc:04}"
                self._writeOutputs(newName, newImage, jsoncontent)
                count += 1

                if self.debug:
                    break
            if self.debug:
                break

        print(f"Done Shift for {count}")

    def cropResize(self, leftTop, cropSize, newSize = None):
        """Crop and optionally resize every image (not implemented).

        Arguments:
            leftTop(tuple): (X, Y)

            cropSize(tuple): W X H, cropped size.
                (cropSize + leftTop) should be less than the original size in
                W & H; if not, the crop is meant to extend to the bottom-right
                corner of the original image

            newSize(tuple): W X H, rescaled size after cropping; no resize
                when it is None

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("Not Impletmented Yet")

    def copy(self):
        """Copy every image, and its JSON file when present, to ``dst_dir``."""
        for name in tqdm(self.names):
            imname = f"{name}{self.image_suffix}"
            sh.copy(osp.join(self.root, imname), osp.join(self.dst_dir, imname))
            jsonname = f"{name}.json"
            jsonsrc = osp.join(self.root, jsonname)
            # An image without annotation is valid input for the other
            # methods, so a missing JSON must not abort the copy either.
            if osp.exists(jsonsrc):
                sh.copy(jsonsrc, osp.join(self.dst_dir, jsonname))

    def output(self, destdir):
        """Placeholder; does nothing.

        Arguments:
            destdir(str): dest directory
        """
        pass

    def polyfitSlope(self, labelType=None):
        """Remove every shape labelled ``labelType`` from the JSON files in ``root``.

        The files are rewritten in place. The line-fitting step that used to
        re-create the removed shape from the 'SphenoidBone' contour is
        disabled, so only the removal is performed.

        NOTE(review): this method referenced the undefined names `jsonfiles`,
        `src` and `labelType` and could not run. They are mapped to the
        ``.json`` files of ``root``, to ``root`` itself and to the new
        ``labelType`` argument — confirm this matches the intended use.

        Arguments:
            labelType(str): label of the shapes to remove

        Raises:
            ValueError: when ``labelType`` is not given
        """
        if labelType is None:
            raise ValueError("labelType is required")
        jsonfiles = [f for f in sorted(os.listdir(self.root)) if f.endswith(".json")]
        for jsonfile in tqdm(jsonfiles):
            content = self.readJsonFile(jsonfile)
            content["shapes"] = [
                shape for shape in content["shapes"] if shape["label"] != labelType
            ]
            try:
                # Serialize first so a failure cannot leave a truncated file.
                data = json.dumps(content, indent=True)
                with open(osp.join(self.root, jsonfile), "w") as f:
                    f.write(data)
            except (TypeError, ValueError, OSError) as e:
                print(e)

class LabelmeLabelModifier(LabelmeModifier):
    """Rename labels or limit the number of shapes kept per label."""

    def removeLabels(self, labels):
        """Remove every shape whose label is listed in ``labels``.

        Arguments:
            labels(list(str)): labels to remove; labels that do not occur
                are ignored
        """
        assert isinstance(labels, (list, tuple))
        # Keeping zero instances of a label is the same as removing it.
        self._renameShapes({label: {"num": 0} for label in labels})

    @staticmethod
    def _hasIndexSuffix(label):
        """Return True when the text after the last "-" of ``label`` parses as an int."""
        try:
            int(label[label.rfind("-") + 1:])
        except ValueError:
            return False
        return True

    def _renameShapes(self, labelShapes):
        """Rename and/or thin out shapes in every annotated image.

        Shapes whose label is not a key of ``labelShapes`` are kept as they
        are. The resulting shapes are sorted by label and written to
        ``dst_dir``; the image is copied along when ``dst_dir`` differs from
        ``root``.

        Arguments:
            labelShapes(dict):
            ```
            {
                "old_label_name": {
                    "label": "new_label_name", // missing or empty: keep the old name
                    "num": 1000, // max shapes kept per image for this label, default 1000
                    "shape_idx": False, // append "-<index>" to the new name, default False
                }
            }
            ```
        """
        count = 0

        # Compare normalized paths so that "data" and "data/" are recognized
        # as the same directory.
        copyImage = osp.abspath(self.root) != osp.abspath(self.dst_dir)

        for name in tqdm(self.names):
            orijsoncontent = self.readJsonFile(f"{name}.json")
            if orijsoncontent is None:
                continue
            new_shapes = []
            # Counters are kept per label: the limit of one label must not be
            # consumed by shapes of another label in the same image.
            reserved = {}  # old label -> number of shapes kept so far
            indexed = {}   # new label -> next index suffix
            for shape in orijsoncontent["shapes"]:
                oldName = shape["label"]
                if oldName not in labelShapes:
                    # Labels that are not listed are kept unchanged.
                    new_shapes.append(shape)
                    continue
                prop = labelShapes[oldName]
                # Fall back to the old name when no (or an empty) new name is given.
                new_label_name = prop.get("label") or oldName
                kept = reserved.get(oldName, 0)
                if kept >= prop.get("num", 1000):
                    continue
                reserved[oldName] = kept + 1

                nshape = copy.deepcopy(shape)
                # Do not index a name that already ends with an index.
                if prop.get("shape_idx", False) and not self._hasIndexSuffix(new_label_name):
                    idx = indexed.get(new_label_name, 0)
                    indexed[new_label_name] = idx + 1
                    nshape["label"] = f"{new_label_name}-{idx}"
                else:
                    nshape["label"] = new_label_name
                new_shapes.append(nshape)

            orijsoncontent["shapes"] = sorted(new_shapes, key=lambda x: x["label"])
            with open(osp.join(self.dst_dir, f"{name}.json"), "w") as f:
                f.write(json.dumps(orijsoncontent, indent=True))
            if copyImage:
                sh.copy(osp.join(self.root, f"{name}{self.image_suffix}"), self.dst_dir)
            count += 1
        print(f"Done rename shapes: {count}")

    def renameShapes(self, shapename, new_shapename=None, num_reserved=None, shape_idxes=None):
        """Rename shapes, given either one dict or parallel lists.

        Arguments:
            shapename (list(str) | dict): old label names, or the complete
                mapping accepted by ``_renameShapes``
            new_shapename (list(str)): new label names, defaults to the old names
            num_reserved (list(Number)): max number of instances kept per
                label, default 1000 each
            shape_idxes (list(Boolean)): whether to append an index to tell
                apart the instances of a label in one picture, default False

        Raises:
            ValueError: when ``shapename`` is a string

        Example:

        use one dict:

        ```py
        lm = LabelmeLabelModifier(src_dir, dst_dir, image_suffix=".png")
        labels = {
            "Person": {
                "label": "Human",
                "num": 100,
                "shape_idx": True,
            }
        }
        lm.renameShapes(labels)
        ```

        or use multiple arrays instead:

        ```py
        lm = LabelmeLabelModifier(src_dir, dst_dir, image_suffix=".png")
        ori_label = ["Person"]
        new_label = ["Human"]
        nums = [100]
        lm.renameShapes(ori_label, new_label, nums)
        ```

        """
        if isinstance(shapename, str):
            raise ValueError("String is not supported anymore. Please Use List instead.")
        if isinstance(shapename, dict):
            self._renameShapes(shapename)
            return

        assert isinstance(shapename, list), " shapename must be either List or Dict"

        slen = len(shapename)
        if new_shapename is None:
            new_shapename = shapename
        if shape_idxes is None:
            shape_idxes = [False] * slen
        if num_reserved is None:
            num_reserved = [1000] * slen
        assert len(new_shapename) == len(shapename) == len(shape_idxes) == len(num_reserved)

        shapes = {
            old: {"label": new, "num": num, "shape_idx": idx}
            for old, new, num, idx in zip(shapename, new_shapename, num_reserved, shape_idxes)
        }
        self._renameShapes(shapes)
        


def moveImages(src, dst):
    """Move the files under ``src`` that have no annotation into ``dst``.

    ``src`` is walked recursively. A file counts as annotated when a
    ``.json`` file with the same stem exists anywhere under ``src``; such
    files, and the ``.json`` files themselves, stay where they are. All other
    files are moved directly into ``dst`` (no sub-directories are re-created).

    Arguments:
        src(str | os.PathLike): directory to scan
        dst(str | os.PathLike): target directory, created when missing

    Raises:
        shutil.Error: when a file of the same name already exists in ``dst``
    """
    src = os.fspath(src)
    dst = os.fspath(dst)

    annotated = set()
    candidates = []
    for root, _dirs, files in os.walk(src):
        for file in files:
            if file.endswith(".json"):
                annotated.add(osp.splitext(file)[0])
            else:
                # Remember the directory the file was found in; it is not
                # necessarily `src` itself.
                candidates.append((root, file))

    # Without an existing directory `shutil.move` would rename each file to
    # `dst` itself, every move overwriting the previous one.
    os.makedirs(dst, exist_ok=True)
    for root, file in candidates:
        if osp.splitext(file)[0] in annotated:
            continue
        sh.move(osp.join(root, file), dst)

from .labelme2voc import Labelme2Vocor
from .labelme2coco import Labelme2Cocoer
