# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from typing import Iterable, List, Optional

import cv2
import numpy as np

from ydocr.cls import TextClassifier
from ydocr.det import TextDetector
from ydocr.rec import TextRecognizer
from copy import deepcopy


logger = logging


def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


class TextSystem(object):
    def __init__(self, use_angle_cls=True, box_thresh=0.5, unclip_ratio=2.0, rec_model_path=None, det_model_path=None,
                 ort_providers=None):
        self.text_detector = TextDetector(box_thresh=box_thresh, unclip_ratio=unclip_ratio,
                                          det_model_path=det_model_path, ort_providers=ort_providers)
        self.text_recognizer = TextRecognizer(rec_model_path=rec_model_path, ort_providers=ort_providers)
        self.use_angle_cls = use_angle_cls
        if self.use_angle_cls:
            self.text_classifier = TextClassifier(ort_providers=ort_providers)

    def set_char_whitelist(self, chars: Optional[Iterable[str]]):
        self.text_recognizer.set_char_whitelist(chars)

    def ocr_single_line(self, img):
        res = self.ocr_lines([img])
        if res:
            return res[0]

    def ocr_lines(self, img_list: List[np.ndarray]):
        tmp_img_list = []
        for img in img_list:
            img_height, img_width = img.shape[0:2]
            if img_height * 1.0 / img_width >= 1.5:
                img = np.rot90(img)
            tmp_img_list.append(img)
        rec_res, elapse = self.text_recognizer(tmp_img_list)
        return rec_res

    
    def trans_poly_to_bbox(self, poly):
        x1 = int(np.min([p[0] for p in poly]))
        x2 = int(np.max([p[0] for p in poly]))
        y1 = int(np.min([p[1] for p in poly]))
        y2 = int(np.max([p[1] for p in poly]))
        return [x1, y1, x2, y2]

    def detect_and_ocr(self, img: np.ndarray, drop_score=0.5, unclip_ratio=None, box_thresh=None):
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img, unclip_ratio, box_thresh)
        logger.debug("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
        if dt_boxes is None:
            return []
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls:
            img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
            logger.debug("cls num  : {}, elapse : {}".format(len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num  : {}, elapse : {}".format(len(rec_res), elapse))
        res = []
        for box, rec_reuslt, img_crop in zip(dt_boxes, rec_res, img_crop_list):
            text, score = rec_reuslt
            if score >= drop_score:
                res.append(BoxedResult(self.trans_poly_to_bbox(box), img_crop, text, score))
        return res


class BoxedResult(object):
    box: List[int]
    text_img: np.ndarray
    ocr_text: str
    score: float

    def __init__(self, box, text_img, ocr_text, score):
        self.box = box
        self.text_img = text_img
        self.ocr_text = ocr_text
        self.score = score

    def __str__(self):
        return 'BoxedResult[%s, %s %s ]' % (self.ocr_text, self.score,self.box)

    def __repr__(self):
        return self.__str__()



def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array):detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
            tmp = _boxes[i]
            _boxes[i] = _boxes[i + 1]
            _boxes[i + 1] = tmp
    return _boxes

def  order_onrow(ocr_info):
    dt_boxes = []
    for res in ocr_info:
        dt_boxes.append(res.box)
    # 平均字间距
    avg_word_height = int(np.mean([dt_boxes[i][3] - dt_boxes[i][1] for i in range(len(dt_boxes))]) / 2  )

    text = ''
    res = sorted(ocr_info, key=lambda r: (r.box[1], r.box[0]))
    res.append(res[-1])
    for i in range(len(res) - 1):
        for j in range(i, 0, -1):
            if abs(res[j + 1].box[1] - res[j].box[1]) < avg_word_height and \
                    (res[j + 1].box[0] < res[j].box[0]):
                tmp = deepcopy(res[j])
                res[j] = deepcopy(res[j + 1])
                res[j + 1] = deepcopy(tmp)
        if abs(res[i + 1].box[1] - res[i].box[1]) < avg_word_height :
            text += res[i].ocr_text + '\\r'
        else:
            text += res[i].ocr_text + '\\n'
    
    # # 判断最后一位
    # if abs(res[-1].box[1] - res[-2].box[1]) < 20 :  
    #     text += res[-1].ocr_text 
    # else:
    #     text += res[-1].ocr_text     
    
    return text

