import copy
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager
from threading import Thread
from typing import Any

import Levenshtein

# from func_timeout import FunctionTimedOut, func_timeout
from pylatexenc.latex2text import LatexNodes2Text
from tqdm import tqdm

from .cdm_metric import CDM
from .table_metric import TEDS
from .utils.data_preprocess import clean_string, normalized_table
from .utils.extract import md_tex_filter
from .utils.match import match_gt2pred_no_split, match_gt2pred_simple
from .utils.match_quick import match_gt2pred_quick


def get_blocks(label_data: dict) -> list[dict]:
    """Flatten a page annotation into an ordered list of evaluation blocks.

    Every entry of ``label_data["blocks"]`` is tagged **in place** with an
    ``"order"`` key: ``None`` for header / footer / page_number /
    page_footnote blocks and the 1-based list position for everything else.
    The relations are then applied:

    * ``"item_of"``: all items that point at the same list block are joined
      with newlines and replace that list block. The merged block is a deep
      copy of the list block that takes the type of its first item.
    * ``"truncated"``: blocks linked by truncation, directly or transitively,
      are concatenated (no separator) into a deep copy of the first block of
      the group, stored at the smallest index of the group.

    Args:
        label_data: Annotation dict with a ``"blocks"`` list (each block needs
            at least ``"type"`` and ``"content"``) and a ``"relations"`` list
            whose entries carry ``"relation"``, ``"from"`` and ``"to"``.

    Returns:
        The remaining blocks, sorted by their index in the original list.

    Raises:
        KeyError: If a relation refers to a block index that does not exist
            or that was already consumed by another merge.
    """
    blocks: dict[int, dict] = {}
    for idx, block in enumerate(label_data["blocks"]):
        if block["type"] in ("header", "footer", "page_number", "page_footnote"):
            block["order"] = None
        else:
            # 1-based on purpose: the original note says an order of 0 "has issue"
            # (presumably because 0 is falsy for downstream filters).
            block["order"] = idx + 1
        blocks[idx] = block

    list_items: dict[int, set[int]] = {}
    truncated_sets: list[set[int]] = []
    for relation in label_data["relations"]:
        if relation["relation"] == "item_of":
            item_idx, list_idx = relation["from"], relation["to"]
            list_items.setdefault(list_idx, set()).add(item_idx)
        elif relation["relation"] == "truncated":
            from_idx, to_idx = relation["from"], relation["to"]
            found_sets = [s for s in truncated_sets if from_idx in s or to_idx in s]
            if not found_sets:
                truncated_sets.append({from_idx, to_idx})
                continue
            # Grow the first matching group; if this relation bridges two
            # groups, fold the other one in so that the result does not
            # depend on the order in which relations are listed.
            target = found_sets[0]
            target.update((from_idx, to_idx))
            for other in found_sets[1:]:
                target |= other
                truncated_sets.remove(other)

    for list_idx, item_indices in list_items.items():
        item_blocks = [blocks.pop(idx) for idx in sorted(item_indices)]
        joint_block = copy.deepcopy(blocks.pop(list_idx))
        joint_block["type"] = item_blocks[0]["type"]
        joint_block["content"] = "\n".join(b["content"] for b in item_blocks)
        blocks[list_idx] = joint_block

    for truncated_set in truncated_sets:
        truncated_blocks = [blocks.pop(idx) for idx in sorted(truncated_set)]
        joint_block = copy.deepcopy(truncated_blocks[0])
        joint_block["content"] = "".join(b["content"] for b in truncated_blocks)
        blocks[min(truncated_set)] = joint_block

    return [blocks[idx] for idx in sorted(blocks)]


def filter_blocks(blocks: list[dict], type_list: list[str]) -> list[dict]:
    """Return the blocks whose ``"type"`` is listed in ``type_list``.

    The input order is preserved and the block dicts are returned as-is
    (not copied).
    """
    selected: list[dict] = []
    for candidate in blocks:
        if candidate["type"] in type_list:
            selected.append(candidate)
    return selected


def filtered_out_ignore(items, ignore_category_list):
    """Drop the items whose ``"gt_category_type"`` is in ``ignore_category_list``.

    Returns a new list; the remaining items keep their original order.
    """
    return [entry for entry in items if entry["gt_category_type"] not in ignore_category_list]


def get_order_paired(order_match_s, img_name):
    """Compute the reading-order edit distance for one page.

    Entries whose ``"gt_position"`` equals ``[""]`` are skipped. The GT
    sequence is every remaining GT position, flattened and sorted. The
    predicted sequence is the GT positions of the entries that also have a
    non-empty ``"pred_position"``, arranged in ascending ``"pred_position"``
    order. Falsy positions are removed from both sequences (truncated merges
    may pull in positions of discarded classes).

    Returns:
        ``{"gt", "pred", "img_id", "edit"}`` where ``edit`` is the Levenshtein
        distance divided by the longer sequence length, or ``{}`` when both
        sequences are empty.
    """
    gt_positions = []
    paired = []
    for entry in order_match_s:
        gt_position = entry["gt_position"]
        if gt_position == [""]:
            continue
        gt_positions.append(gt_position)
        if entry["pred_position"] != "":
            paired.append((gt_position, entry["pred_position"]))

    # Stable sort by predicted position: yields the GT indices in the order
    # in which the prediction emitted them.
    paired.sort(key=lambda pair: pair[1])

    gt = sorted(pos for group in gt_positions for pos in group if pos)
    pred = [pos for group, _ in paired for pos in group if pos]

    if not pred and not gt:
        return {}  # Nothing to compare on this page

    edit = Levenshtein.distance(gt, pred) / max(len(pred), len(gt))
    return {"gt": gt, "pred": pred, "img_id": img_name, "edit": edit}


def formula_format(formula_matches, img_name):
    """Tag each formula match with an ``"img_id"`` of the form ``<img_name>_<index>``.

    The match dicts are updated in place and the same list is returned.
    """
    for position, match in enumerate(formula_matches):
        match["img_id"] = img_name + f"_{position}"
    return formula_matches


# Ground-truth block types that are scored as plain text after matching.
_PLAIN_TEXT_CATEGORIES = (
    "text",
    "title",
    "code",
    "code_caption",
    "ref_text",
    "equation_caption",
    "image_caption",
    "image_footnote",
    "table_caption",
    "table_footnote",
    "code_algorithm",
    "code_algorithm_caption",
    "header",
    "footer",
    "page_footnote",
    "page_number",
)

# Plain-text categories that are matched but excluded from the returned
# text / reading-order results.
_IGNORED_TEXT_CATEGORIES = (
    "image_caption",
    "image_footnote",
    "table_caption",
    "table_footnote",
    "code_algorithm",
    "code_algorithm_caption",
    "header",
    "footer",
    "page_footnote",
    "page_number",
    "equation_caption",
)

# Predicted table categories that are kept out of the mixed text matching.
_PRED_TABLE_CATEGORIES = ("html_table", "latex_table", "md2html_table")


def process_get_matched_elements(match_method: str, label_data: dict, pred_content: str, img_name: str):
    """Match the ground-truth blocks of one page against the predicted markdown.

    GT tables are matched against the predicted tables first; predicted tables
    left unmatched are then added to the mixed text/equation matching
    (original note, dated 0403: unmatched predicted tables have their HTML
    formatting removed before being mixed in — presumably done inside the
    matcher; verify there).

    NOTE: when the page has no GT table, predicted tables are neither matched
    nor added to the mixed matching.

    Args:
        match_method: ``"simple_match"``, ``"quick_match"`` or ``"no_split"``.
            Any other value prints a warning and falls back to quick match.
        label_data: Page annotation accepted by ``get_blocks``.
        pred_content: Predicted markdown of the page.
        img_name: Page identifier forwarded to the matchers.

    Returns:
        A five-element list:
        ``[plain_text_matches, display_formula_matches, latex_table_matches,
        html_table_matches, order_match]``, where ``order_match`` is the result
        of ``get_order_paired`` or ``[]`` when there is no plain-text match.

    Raises:
        Exception: Whatever the selected matcher raises. Errors propagate to
            the caller instead of terminating the interpreter, so a failed
            page can no longer end the process with a success exit status.
    """
    matchers = {
        "simple_match": match_gt2pred_simple,
        "quick_match": match_gt2pred_quick,
        "no_split": match_gt2pred_no_split,
    }
    match_gt2pred = matchers.get(match_method)
    if match_gt2pred is None:
        print("Invalid match method name. The quick_match will be used.")
        match_gt2pred = match_gt2pred_quick

    pred_dataset = md_tex_filter(pred_content)
    all_blocks = get_blocks(label_data)

    # Every predicted element except tables goes into the mixed matching.
    pred_dataset_mix = []
    for category in pred_dataset:
        if category not in _PRED_TABLE_CATEGORIES:
            pred_dataset_mix.extend(pred_dataset[category])

    gt_mix = filter_blocks(all_blocks, [*_PLAIN_TEXT_CATEGORIES, "equation"])

    latex_table_match_s = []
    html_table_match_s = []

    table_blocks = filter_blocks(all_blocks, ["table"])
    if table_blocks:
        latex_tables = pred_dataset["latex_table"] or []
        html_tables = pred_dataset["html_table"] or []
        # The prediction is treated as LaTeX only when it has strictly more
        # LaTeX tables than HTML tables; otherwise (including "no predicted
        # table at all") the GT tables are matched as HTML.
        # Truncated merging is not considered for tables.
        if len(latex_tables) > len(html_tables):
            table_matches, unmatch_table_pred = match_gt2pred_simple(
                table_blocks, latex_tables, "latex_table", img_name
            )
            latex_table_match_s = [x for x in table_matches if x["gt_idx"] != [""]]  # Remove extra preds
        else:
            table_matches, unmatch_table_pred = match_gt2pred_simple(
                table_blocks, html_tables, "html_table", img_name
            )
            html_table_match_s = [x for x in table_matches if x["gt_idx"] != [""]]  # Remove extra preds

        if unmatch_table_pred:
            pred_dataset_mix.extend(unmatch_table_pred)

    match = match_gt2pred(gt_mix, pred_dataset_mix, "text_all", img_name)

    plain_text_match_s = []
    display_formula_match_s = []
    for item in match:
        gt_category = item.get("gt_category_type", None)
        if gt_category in _PLAIN_TEXT_CATEGORIES:
            plain_text_match_s.append(item)
        elif gt_category == "equation" and item["gt_idx"] != [""]:
            display_formula_match_s.append(item)

    plain_text_match_clean = []
    if plain_text_match_s:
        plain_text_match_clean = filtered_out_ignore(plain_text_match_s, _IGNORED_TEXT_CATEGORIES)

    order_match_single = []
    if plain_text_match_clean:
        order_match_single = get_order_paired(plain_text_match_clean, img_name)

    return [plain_text_match_clean, display_formula_match_s, latex_table_match_s, html_table_match_s, order_match_single]


###################### Metrics #######################


def calc_section_edit_dist(section: dict):
    """Compute the normalized edit distance between one matched pred/GT pair.

    The normalized values (``"norm_pred"`` / ``"norm_gt"``) are used when they
    are present and non-empty; otherwise the raw ``"pred"`` / ``"gt"`` values.

    Returns:
        ``{"max_len", "edit_len", "edit_ratio"}``. Two empty sequences give a
        ratio of 0.0; exactly one empty sequence gives a ratio of 1.0 without
        invoking Levenshtein.
    """
    pred = section.get("norm_pred") or section["pred"]
    gt = section.get("norm_gt") or section["gt"]

    pred_len, gt_len = len(pred), len(gt)
    longest = max(pred_len, gt_len)

    if longest == 0:
        return {"max_len": 0, "edit_len": 0, "edit_ratio": 0.0}
    if min(pred_len, gt_len) == 0:
        # One side is empty: everything on the other side counts as an edit.
        return {"max_len": longest, "edit_len": longest, "edit_ratio": 1.0}

    distance = Levenshtein.distance(pred, gt)
    return {"max_len": longest, "edit_len": distance, "edit_ratio": distance / longest}


def calc_page_edit_dist(sections: list[dict]):
    """Aggregate the per-section edit distances of one page.

    Returns:
        ``{"max_len", "edit_len", "edit_ratio", "sections"}``. With no section
        the three scalar fields are ``None``. Otherwise the lengths are summed
        over all sections and ``edit_ratio`` is their quotient (0.0 when the
        summed ``max_len`` is zero). ``"sections"`` holds the per-section
        results in input order.
    """
    if not sections:
        return {
            "max_len": None,
            "edit_len": None,
            "edit_ratio": None,
            "sections": [],
        }

    section_results = [calc_section_edit_dist(section) for section in sections]
    total_max_len = sum(result["max_len"] for result in section_results)
    total_edit_len = sum(result["edit_len"] for result in section_results)
    ratio = total_edit_len / total_max_len if total_max_len else 0.0

    return {
        "max_len": total_max_len,
        "edit_len": total_edit_len,
        "edit_ratio": ratio,
        "sections": section_results,
    }


def calc_section_cdm(section: dict, output_dir: str):
    """Compute the CDM F1 score of one matched formula.

    The raw ``"pred"`` / ``"gt"`` strings are used (not the normalized ones).
    For the prediction, only the text after the last "```latex" marker and
    before the following "```" is kept. Surrounding ``$`` delimiters and
    whitespace are then stripped from both strings.

    Returns:
        ``{"cdm": score}``; the score is 0.0 when the evaluation raises (the
        error is printed).
    """

    def _unwrap(formula: str) -> str:
        # Two passes: "$$ ... $$" first, then a possible inner "$ ... $".
        for _ in range(2):
            formula = formula.strip("$").strip()
        return formula

    raw_pred: str = section["pred"]
    raw_gt: str = section["gt"]

    fenced = raw_pred.rpartition("```latex")[2].partition("```")[0]
    pred = _unwrap(fenced)
    gt = _unwrap(raw_gt)

    cdm_evaluator = CDM()
    try:
        score = cdm_evaluator.evaluate(gt, pred, output_dir)["F1_score"]
    except Exception as e:
        print("Error calculating CDM:", str(e))
        return {"cdm": 0.0}
    return {"cdm": score}


def _calc_section_cdm_wrapper(args):
    """Unpack a ``(section, output_dir)`` tuple and call ``calc_section_cdm``.

    Single-argument adapter so the call can be dispatched with ``executor.map``.
    """
    section, output_dir = args
    return calc_section_cdm(section, output_dir)


def calc_page_cdm(sections: list[dict], output_dir: str, executor: ProcessPoolExecutor | None = None):
    args_list = [(s, os.path.join(output_dir, str(idx))) for idx, s in enumerate(sections)]
    if executor:
        section_results = list(executor.map(_calc_section_cdm_wrapper, args_list))
    else:
        section_results = [calc_section_cdm(*args) for args in args_list]

    cdm_sum = 0
    for section_result in section_results:
        cdm_sum += section_result["cdm"]

    if len(section_results) == 0:
        return {
            "cdm": None,
            "sections": section_results,
        }
    else:
        return {
            "cdm": cdm_sum / len(section_results),
            "sections": section_results,
        }


def calc_section_teds(section: dict):
    """Compute TEDS and structure-only TEDS for one matched table.

    The normalized values (``"norm_pred"`` / ``"norm_gt"``) are used when they
    are present and non-empty; otherwise the raw ``"pred"`` / ``"gt"`` values.
    The two scores are computed independently: if one evaluation raises, the
    error is printed and only that score falls back to 0.0.

    Returns:
        ``{"teds": ..., "s_teds": ...}``
    """
    pred = section.get("norm_pred") or section["pred"]
    gt = section.get("norm_gt") or section["gt"]

    variants = (
        ("teds", False, "Error calculating TEDS:"),
        ("s_teds", True, "Error calculating S-TEDS:"),
    )
    scores = {}
    for key, structure_only, error_prefix in variants:
        try:
            evaluator = TEDS(structure_only=structure_only)
            scores[key] = evaluator.evaluate(pred, gt)
        except Exception as e:
            print(error_prefix, str(e))
            scores[key] = 0.0
    return scores


def calc_page_teds(sections: list[dict]):
    """Compute the mean TEDS and structure-only TEDS over the tables of one page.

    Returns:
        ``{"teds", "s_teds", "sections"}``; both means are ``None`` when the
        page has no table section.
    """
    section_results = [calc_section_teds(section) for section in sections]
    count = len(section_results)

    if count == 0:
        return {
            "teds": None,
            "s_teds": None,
            "sections": section_results,
        }

    return {
        "teds": sum(result["teds"] for result in section_results) / count,
        "s_teds": sum(result["s_teds"] for result in section_results) / count,
        "sections": section_results,
    }


#############################################


def _merge_list(list1: list[dict], list2: list[dict]):
    """Merge two equally long lists of dicts element by element.

    Raises:
        RuntimeError: If the lists differ in length.
    """
    if len(list1) != len(list2):
        raise RuntimeError("Cannot merge lists of different lengths")
    return [_merge_dict(left, right) for left, right in zip(list1, list2)]


def _merge_dict(dict1: dict[str, Any], dict2: dict[str, Any]):
    """Recursively merge ``dict2`` into a shallow copy of ``dict1``.

    For a key present in both: two dicts are merged recursively, two lists are
    merged element-wise, and in every other case the value from ``dict2``
    wins. Neither input dict is modified.
    """
    merged = dict1.copy()
    for key, incoming in dict2.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            replacement = _merge_dict(existing, incoming)
        elif isinstance(existing, list) and isinstance(incoming, list):
            replacement = _merge_list(existing, incoming)
        else:
            # Missing key, None, or mismatching/scalar types: take the new value.
            replacement = incoming
        merged[key] = replacement
    return merged


def evaluate_page(
    predicted_markdown: str,
    ground_truth: dict,
    cdm_dir: str,
    cdm_executor: ProcessPoolExecutor | None = None,
    match_method: str = "quick_match",
    disable_cdm: bool = False,
    disable_teds: bool = False,
) -> dict:
    """Evaluate the predicted markdown of one page against its ground truth.

    Args:
        predicted_markdown: Model output for the page.
        ground_truth: Page annotation (``"blocks"`` and ``"relations"``).
        cdm_dir: Directory handed to the CDM evaluation for its output.
        cdm_executor: Optional executor used to evaluate formulas in parallel.
        match_method: Matching strategy forwarded to
            ``process_get_matched_elements``.
        disable_cdm: Skip the CDM metric when true.
        disable_teds: Skip the TEDS metrics when true.

    Returns:
        A dict with the keys ``"text"``, ``"equation"``, ``"table"`` and
        ``"reading_order"``. Each value holds the edit-distance result for
        that element type; ``"equation"`` and ``"table"`` additionally carry
        the CDM / TEDS results unless the metric is disabled.
    """
    # Match the GT against the prediction for every element type (text
    # blocks, display formulas, tables, ...) with the selected method.
    (
        plain_text_match_clean,
        formated_display_formula,
        latex_table_match_s,
        html_table_match_s,
        order_match_single,
    ) = process_get_matched_elements(match_method, ground_truth, predicted_markdown, "FAKE-IMAGE-NAME")

    # Work on fresh lists so the lists returned by the matcher are not extended.
    plain_text_match = list(plain_text_match_clean or [])
    table_match = list(html_table_match_s or [])
    order_match = [order_match_single] if order_match_single else []

    # A GT formula matched to something that is not a formula is scored as
    # text instead: its GT is converted from LaTeX to unicode text and
    # normalized with the text normalizer (the existing "norm_gt" came from
    # the formula normalizer and cannot be reused here).
    display_formula_match, display_formula_match_others = [], []
    for item in formated_display_formula or []:
        if item.get("pred_category_type", None) in ("equation_inline", "equation", ""):
            display_formula_match.append(item)
            continue
        gt = item.get("gt", None)
        try:
            item["gt"] = LatexNodes2Text().latex_to_text(gt)
        except Exception:
            # Best effort: keep the raw LaTeX when the conversion fails.
            item["gt"] = gt
        item["norm_gt"] = clean_string(item["gt"])
        display_formula_match_others.append(item)
    if display_formula_match_others and plain_text_match:
        plain_text_match.extend(display_formula_match_others)

    # LaTeX table matches are folded into the HTML table matches with every
    # prediction field blanked and the edit distance forced to 1.
    for latex_table in latex_table_match_s or []:
        for key in latex_table:
            if "pred" in key:
                latex_table[key] = ""
        latex_table["edit"] = 1
        table_match.append(latex_table)

    # After the fold above the HTML list always contains every LaTeX entry,
    # so tables are always normalized as HTML.
    for table in table_match:
        table["norm_pred"] = normalized_table(table["pred"], "html")
        table["norm_gt"] = normalized_table(table["gt"])

    result = {
        "text": calc_page_edit_dist(plain_text_match),
        "equation": _merge_dict(
            calc_page_edit_dist(display_formula_match),
            calc_page_cdm(display_formula_match, cdm_dir, cdm_executor) if not disable_cdm else {},
        ),
        "table": _merge_dict(
            calc_page_edit_dist(table_match),
            calc_page_teds(table_match) if not disable_teds else {},
        ),
        "reading_order": calc_page_edit_dist(order_match),
    }

    return result


class _EvaluatePageProgressHelper:
    """Run ``evaluate_page`` and report each finished page on a queue."""

    def __init__(self, progress_queue):
        # Queue that receives one item per successfully evaluated page.
        self.progress_queue = progress_queue

    def evaluate_page(self, kwargs):
        """Evaluate one page from its keyword arguments, then signal progress.

        Nothing is put on the queue if the evaluation raises.
        """
        page_result = evaluate_page(**kwargs)
        self.progress_queue.put(1)
        return page_result


def evaluate_pages(
    predicted_markdowns_and_ground_truths: list[tuple[str, dict]],
    cdm_root_dir: str,
    match_method: str = "quick_match",
    disable_cdm: bool = False,
    disable_teds: bool = False,
    max_workers: int = 0,
) -> list[dict]:
    """Evaluate many pages, optionally in parallel, with a progress bar.

    Args:
        predicted_markdowns_and_ground_truths: ``(predicted_markdown,
            ground_truth)`` pairs, one per page.
        cdm_root_dir: Root directory for CDM output; page ``i`` uses
            ``<cdm_root_dir>/<i>``.
        match_method: Matching strategy forwarded to ``evaluate_page``.
        disable_cdm: Skip the CDM metric when true.
        disable_teds: Skip the TEDS metrics when true.
        max_workers: Number of worker processes. Values ``<= 0`` evaluate the
            pages sequentially in the current process.

    Returns:
        One ``evaluate_page`` result per input pair, in input order.
    """
    kwargs_list = [
        {
            "predicted_markdown": predicted_markdown,
            "ground_truth": ground_truth,
            "cdm_dir": f"{cdm_root_dir}/{idx}",
            "match_method": match_method,
            "disable_cdm": disable_cdm,
            "disable_teds": disable_teds,
        }
        for idx, (predicted_markdown, ground_truth) in enumerate(predicted_markdowns_and_ground_truths)
    ]

    if max_workers <= 0:
        return [evaluate_page(**kwargs) for kwargs in tqdm(kwargs_list, total=len(kwargs_list))]

    # The manager is a context manager so its server process is shut down
    # even when a worker raises.
    with tqdm(total=len(kwargs_list)) as progress_bar, Manager() as manager:
        progress_queue = manager.Queue()
        helper = _EvaluatePageProgressHelper(progress_queue)

        def queue_listener():
            # Workers put 1 per finished page; None is the stop sentinel.
            while progress_queue.get() is not None:
                progress_bar.update()

        listener_thread = Thread(target=queue_listener, daemon=True)
        listener_thread.start()

        try:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                results = list(executor.map(helper.evaluate_page, kwargs_list))
        finally:
            # Always stop the listener, otherwise it would block forever on
            # the queue after a failed evaluation.
            progress_queue.put(None)
            listener_thread.join()

    return results
