from pathlib import Path
from typing import Any

import cv2
import numpy as np
import polars as pl

from luxonis_ml.data import DatasetIterator

from .base_parser import BaseParser, ParserOutput


class SegmentationMaskDirectoryParser(BaseParser):
    """Parses directory with SegmentationMask annotations to LDF.

    Expected format::

        dataset_dir/
        ├── train/
        │   ├── img1.jpg
        │   ├── img1_mask.png
        │   ├── ...
        │   └── _classes.csv
        ├── valid/
        └── test/

    Every mask file is named C{<image stem>_mask.<ext>} and lives next to
    the image it annotates. The image itself may have any extension.

    C{_classes.csv} contains mappings between pixel value and class name.
    The parser reads the class names from the column called C{" Class"}
    (with a leading space) and uses the pixel value found in the mask as
    an index into that list.

    This is one of the formats that can be generated by
    U{Roboflow <https://roboflow.com/>}.
    """

    @staticmethod
    def _image_for_mask(image_dir: Path, mask_path: Path) -> Path | None:
        """Finds the image that a mask file belongs to.

        @type image_dir: Path
        @param image_dir: Directory that is searched for the image.
        @type mask_path: Path
        @param mask_path: Path to a mask file named
            C{<image stem>_mask.<ext>}.
        @rtype: Path | None
        @return: First file in C{image_dir} matching
            C{<image stem>.*}, or C{None} if there is no such file.
        """
        # Strip the trailing "_mask" from the mask stem to get the image stem.
        image_stem = mask_path.stem[: -len("_mask")]
        return next(image_dir.glob(f"{image_stem}.*"), None)

    @staticmethod
    def validate_split(split_path: Path) -> dict[str, Any] | None:
        """Checks whether a directory is a valid split of this format.

        A split is valid when it contains C{_classes.csv} and every
        C{*_mask.*} file in it has a matching image. A split without any
        mask files is still considered valid.

        @type split_path: Path
        @param split_path: Path to the split directory.
        @rtype: dict[str, Any] | None
        @return: Keyword arguments for L{from_split} (C{image_dir},
            C{seg_dir} and C{classes_path}) if the split is valid,
            C{None} otherwise.
        """
        classes_path = split_path / "_classes.csv"
        # This is also False when `split_path` itself does not exist.
        if not classes_path.exists():
            return None

        for mask_path in split_path.glob("*_mask.*"):
            # Use the same lookup as `from_split` so that validation and
            # parsing agree on which datasets are acceptable (the image
            # is not required to be a `.jpg`).
            image_path = SegmentationMaskDirectoryParser._image_for_mask(
                split_path, mask_path
            )
            if image_path is None:
                return None

        return {
            "image_dir": split_path,
            "seg_dir": split_path,
            "classes_path": classes_path,
        }

    @staticmethod
    def validate(dataset_dir: Path) -> bool:
        """Checks whether a directory is a valid dataset of this format.

        @type dataset_dir: Path
        @param dataset_dir: Path to the dataset root.
        @rtype: bool
        @return: C{True} only if all of the C{train}, C{valid} and
            C{test} subdirectories pass L{validate_split}.
        """
        return all(
            SegmentationMaskDirectoryParser.validate_split(dataset_dir / split)
            is not None
            for split in ("train", "valid", "test")
        )

    def from_dir(
        self, dataset_dir: Path
    ) -> tuple[list[Path], list[Path], list[Path]]:
        """Parses the C{train}, C{valid} and C{test} splits of a dataset.

        Each split directory serves as image directory and mask directory
        at once and provides its own C{_classes.csv}.

        @type dataset_dir: Path
        @param dataset_dir: Path to the dataset root.
        @rtype: tuple[list[Path], list[Path], list[Path]]
        @return: Results of C{_parse_split} for the train, validation and
            test split, in this order.
        """
        added_train_imgs, added_val_imgs, added_test_imgs = [
            self._parse_split(
                image_dir=split_dir,
                seg_dir=split_dir,
                classes_path=split_dir / "_classes.csv",
            )
            for split_dir in (
                dataset_dir / "train",
                dataset_dir / "valid",
                dataset_dir / "test",
            )
        ]
        return added_train_imgs, added_val_imgs, added_test_imgs

    def from_split(
        self, image_dir: Path, seg_dir: Path, classes_path: Path
    ) -> ParserOutput:
        """Parses annotations with SegmentationMask format to LDF.

        Annotations include classification and segmentation. For every
        mask file one record is produced per distinct pixel value in the
        mask, holding the class name and a binary mask of that class.

        @type image_dir: Path
        @param image_dir: Path to directory with images
        @type seg_dir: Path
        @param seg_dir: Path to directory with segmentation mask
        @type classes_path: Path
        @param classes_path: Path to CSV file with class names
        @rtype: L{ParserOutput}
        @return: Annotation generator, an empty dictionary (presumably
            the keypoint skeletons, which this format does not define)
            and the added images as reported by C{_get_added_images}.
        @raise FileNotFoundError: While iterating the generator, if a
            mask has no matching image in C{image_dir}.
        @raise ValueError: While iterating the generator, if a mask file
            cannot be read as an image.
        @raise IndexError: While iterating the generator, if a mask
            contains a pixel value without an entry in the class list.
        """

        # NOTE: space prefix included. Presumably the CSV header is
        # written as "Pixel Value, Class" - verify against the exporter.
        idx_class = " Class"

        df = pl.read_csv(classes_path).filter(pl.col(idx_class).is_not_null())
        # NOTE(review): the pixel value is used as a position in this
        # list, which assumes the rows are ordered by pixel value
        # starting at 0 - confirm.
        class_names = df[idx_class].to_list()

        def generator() -> DatasetIterator:
            for mask_path in seg_dir.glob("*_mask.*"):
                image_path = self._image_for_mask(image_dir, mask_path)
                if image_path is None:
                    # A bare `next()` here would leak `StopIteration`,
                    # which Python turns into an opaque `RuntimeError`
                    # inside a generator.
                    raise FileNotFoundError(
                        f"No image found in '{image_dir}' "
                        f"for mask '{mask_path}'."
                    )
                file = str(image_path.absolute().resolve())

                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    # `cv2.imread` signals failure by returning `None`
                    # instead of raising.
                    raise ValueError(
                        f"Failed to read segmentation mask '{mask_path}'."
                    )

                # `tolist()` yields plain Python ints.
                for pixel_value in np.unique(mask).tolist():
                    if pixel_value >= len(class_names):
                        raise IndexError(
                            f"Mask '{mask_path}' contains pixel value "
                            f"{pixel_value}, but '{classes_path}' defines "
                            f"only {len(class_names)} classes."
                        )

                    # Binary mask: 1 where the pixel belongs to the
                    # current class, 0 elsewhere.
                    curr_seg_mask = (mask == pixel_value).astype(mask.dtype)
                    yield {
                        "file": file,
                        "annotation": {
                            "class": class_names[pixel_value],
                            "segmentation": {"mask": curr_seg_mask},
                        },
                    }

        # The first generator is handed to `_get_added_images`; a fresh
        # one is returned to the caller.
        added_images = self._get_added_images(generator())
        return generator(), {}, added_images
