from typing import Final, Optional, Union, List, Dict, Any
from pathlib import Path

from layoutparser.models.detectron2.layoutmodel import (
    is_detectron2_available,
    Detectron2LayoutModel,
)
from layoutparser.models.model_config import LayoutModelConfig
from PIL import Image
from huggingface_hub import hf_hub_download

from unstructured_inference.logger import logger
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
from unstructured_inference.utils import LazyDict, LazyEvaluateInfo


# layoutparser-style URI for the default PubLayNet faster_rcnn_R_50_FPN_3x config.
DETECTRON_CONFIG: Final = "lp://PubLayNet/faster_rcnn_R_50_FPN_3x/config"

# Class index -> element label; each index is the label's position in this tuple.
DEFAULT_LABEL_MAP: Final[Dict[int, str]] = dict(
    enumerate(("Text", "Title", "List", "Table", "Figure"))
)

# [config key, value] override forwarded as ``extra_config`` to the detectron2 model;
# raises the ROI-heads test-time score threshold to 0.8.
DEFAULT_EXTRA_CONFIG: Final[List[Any]] = [
    "MODEL.ROI_HEADS.SCORE_THRESH_TEST",
    0.8,
]


def _lazy_hub_file(repo_id: str, filename: str) -> LazyEvaluateInfo:
    """Wrap an ``hf_hub_download(repo_id, filename)`` call in a LazyEvaluateInfo so it is deferred."""
    return LazyEvaluateInfo(hf_hub_download, repo_id, filename)


# NOTE(alan): Every registry entry is a LazyDict, so a model's files are only fetched from the
# Hugging Face Hub once that entry is actually needed.
# Keys are model names (None selects the default PubLayNet layout model).
MODEL_TYPES = {
    None: LazyDict(
        model_path=_lazy_hub_file(
            "layoutparser/detectron2",
            "PubLayNet/faster_rcnn_R_50_FPN_3x/model_final.pth",
        ),
        config_path=_lazy_hub_file(
            "layoutparser/detectron2",
            "PubLayNet/faster_rcnn_R_50_FPN_3x/config.yml",
        ),
        label_map=DEFAULT_LABEL_MAP,
        extra_config=DEFAULT_EXTRA_CONFIG,
    ),
    "checkbox": LazyDict(
        model_path=_lazy_hub_file(
            "unstructuredio/oer-checkbox",
            "detectron2_finetuned_oer_checkbox.pth",
        ),
        config_path=_lazy_hub_file(
            "unstructuredio/oer-checkbox",
            "detectron2_oer_checkbox.json",
        ),
        label_map={0: "Unchecked", 1: "Checked"},
        extra_config=None,
    ),
}


class UnstructuredDetectronModel(UnstructuredModel):
    """Unstructured model wrapper for Detectron2LayoutModel."""

    def predict(self, x: Image.Image):
        """Run layout detection on a single PIL image.

        Calls the base-class ``predict`` first (presumably a readiness/validation hook —
        TODO confirm against UnstructuredModel), then returns whatever the wrapped
        ``Detectron2LayoutModel.detect`` produces for ``x``.
        """
        super().predict(x)
        return self.model.detect(x)

    def initialize(
        self,
        config_path: Union[str, Path, LayoutModelConfig],
        model_path: Optional[Union[str, Path]] = None,
        label_map: Optional[Dict[int, str]] = None,
        extra_config: Optional[list] = None,
        device: Optional[str] = None,
    ):
        """Loads the detectron2 model using the specified parameters.

        Args:
            config_path: Location of the detectron2 config; converted with ``str()`` before
                being handed to ``Detectron2LayoutModel``.
            model_path: Optional location of the model weights; converted with ``str()`` when
                given, otherwise passed as ``None``.
            label_map: Optional mapping from class index to label name, forwarded unchanged.
            extra_config: Optional flat list of config overrides. A copy is forwarded, so the
                caller's list (e.g. a shared module-level default) is never handed over.
            device: Optional device string, forwarded unchanged.

        Raises:
            ImportError: If detectron2 is not available in the environment.
        """
        if not is_detectron2_available():
            raise ImportError(
                "Failed to load the Detectron2 model. Ensure that the Detectron2 "
                "module is correctly installed."
            )

        # NOTE(review): str() is also applied when config_path is a LayoutModelConfig —
        # confirm that its string form is the URI Detectron2LayoutModel expects.
        config_path_str = str(config_path)
        model_path_str: Optional[str] = None if model_path is None else str(model_path)
        # Defensive copy: callers typically pass a module-level default list by reference,
        # which the downstream model must not alias (or extend, should it ever do so).
        extra_config_list: Optional[list] = None if extra_config is None else list(extra_config)
        logger.info("Loading the Detectron2 layout model ...")
        self.model = Detectron2LayoutModel(
            config_path_str,
            model_path=model_path_str,
            label_map=label_map,
            extra_config=extra_config_list,
            device=device,
        )
