import colorsys
import hashlib
import math
from collections import defaultdict
from collections.abc import Generator, Hashable, Iterator, Mapping

import cv2
import matplotlib.colors
import numpy as np
from bidict import bidict

from luxonis_ml.data.utils import get_task_name, task_type_iterator
from luxonis_ml.typing import HSV, RGB, Color, Labels

# OpenCV font face shared by all text-drawing helpers in this module.
FONT = cv2.FONT_HERSHEY_SIMPLEX


class ColorMap(Mapping[Hashable, RGB]):
    """Mapping from arbitrary hashable labels to distinct RGB colors.

    Colors are handed out on first access: looking up a label that has
    no color yet takes the next color from a
    L{distinct_color_generator} and remembers it, so a label keeps the
    same color for the lifetime of the map. Iteration and C{len} only
    cover labels that have already been looked up.

    Note that the C{in} operator and C{get} are inherited from
    C{Mapping} and go through C{__getitem__}, so they assign a color to
    a previously unseen label instead of reporting it as missing.
    """

    def __init__(self):
        self._generator = distinct_color_generator()
        self._color_dict: dict[Hashable, RGB] = {}

    def __getitem__(self, label: Hashable) -> RGB:
        try:
            return self._color_dict[label]
        except KeyError:
            # First request for this label: draw a fresh color.
            fresh = next(self._generator)
            self._color_dict[label] = fresh
            return fresh

    def __iter__(self) -> Iterator[Hashable]:
        yield from self._color_dict

    def __len__(self) -> int:
        return len(self._color_dict)


def distinct_color_generator(stop: int = -1) -> Generator[RGB, None, None]:
    """Yield a sequence of visually distinct RGB colors.

    Successive colors are produced by advancing the hue by the golden
    ratio conjugate (modulo 1) while saturation (0.8) and value (0.95)
    stay fixed, which spreads neighbouring colors far apart on the hue
    circle.

    @param stop: Optional. The number of colors to yield. With the
        default of -1 (or any other negative number, which the counter
        never reaches) the generator never stops.
    @type stop: int
    @yield: An C{(R, G, B)} tuple of integers in the range [0, 255].
    @rtype: Generator[tuple[int, int, int], None, None]
    """
    golden_ratio = 0.618033988749895
    saturation, value = 0.8, 0.95
    hue = 0.0
    produced = 0
    while produced != stop:
        hue = (hue + golden_ratio) % 1
        red, green, blue = colorsys.hsv_to_rgb(hue, saturation, value)
        yield int(red * 255), int(green * 255), int(blue * 255)
        produced += 1


def resolve_color(color: Color) -> RGB:
    """Resolves a color to an RGB tuple.

    @type color: Color
    @param color: The color to resolve. Can be a string (any color
        specification accepted by C{matplotlib.colors.to_rgb}, such as
        a color name or a hex string), a single integer used for all
        three channels, or a sequence of channel values, which is
        returned unchanged after range checking.
    @rtype: Tuple[int, int, int]
    @return: The RGB tuple with components in the range [0, 255].
    @raise ValueError: If an integer color or a sequence component is
        outside the range [0, 255].
    """

    def _check_range(val: int) -> None:
        if val < 0 or val > 255:
            raise ValueError(f"Color value {val} is out of range [0, 255]")

    if isinstance(color, str):
        # matplotlib returns floats in [0, 1]; scale them to the [0, 255]
        # integer range produced by the other branches and expected by
        # the drawing helpers. Rounding keeps hex colors exact.
        r, g, b = matplotlib.colors.to_rgb(color)
        return round(r * 255), round(g * 255), round(b * 255)
    if isinstance(color, int):
        _check_range(color)
        return color, color, color
    for c in color:
        _check_range(c)
    return color


def rgb_to_hsv(color: Color) -> HSV:
    """Converts a color to HSV.

    The color is first resolved with L{resolve_color}; its 0-255
    channels are normalized to [0, 1] before conversion.

    @type color: Color
    @param color: The color to convert.
    @rtype: Tuple[float, float, float]
    @return: C{(hue, saturation, value)} where hue is in degrees
        [0, 360) and saturation and value are in [0, 1].
    """
    red, green, blue = resolve_color(color)
    hue, saturation, value = colorsys.rgb_to_hsv(
        red / 255, green / 255, blue / 255
    )
    return hue * 360, saturation, value


def hsv_to_rgb(color: HSV) -> RGB:
    """Converts an HSV color to RGB.

    @type color: Tuple[float, float, float]
    @param color: C{(hue, saturation, value)} with hue in degrees and
        saturation and value in [0, 1].
    @rtype: Tuple[int, int, int]
    @return: The RGB tuple; each channel is scaled to [0, 255] and
        truncated to an integer.
    """
    hue_degrees, saturation, value = color
    channels = colorsys.hsv_to_rgb(hue_degrees / 360, saturation, value)
    return tuple(int(channel * 255) for channel in channels)


def get_contrast_color(color: Color) -> RGB:
    """Returns a contrasting color for the given color.

    The color is converted to HSV, its hue is rotated by 180 degrees
    while saturation and value are kept, and the result is converted
    back to RGB. For achromatic colors (zero saturation) the hue has no
    effect, so the same gray comes back, possibly off by one because of
    integer truncation.

    @type color: Color
    @param color: The color to contrast.
    @rtype: Tuple[int, int, int]
    @return: The contrasting color.
    """
    # rgb_to_hsv already resolves the color, so no separate resolve step.
    hue, saturation, value = rgb_to_hsv(color)
    opposite_hue = (hue + 180) % 360
    return hsv_to_rgb((opposite_hue, saturation, value))


def str_to_rgb(string: str) -> RGB:
    """Deterministically maps a string to an RGB color.

    The color is taken from the last three bytes of the MD5 digest of
    the encoded string, so equal strings always get the same color
    across runs. MD5 is used here only as a stable hash, not for
    security.

    @type string: str
    @param string: The string to convert.
    @rtype: Tuple[int, int, int]
    @return: The RGB tuple.
    """
    digest = hashlib.md5(string.encode()).digest()  # noqa: S324
    *_, red, green, blue = digest
    return red, green, blue


def draw_dashed_rectangle(
    image: np.ndarray,
    pt1: tuple[int, int],
    pt2: tuple[int, int],
    color: Color,
    thickness: int = 1,
    dash_length: int = 10,
) -> None:
    """Draws a dashed rectangle on the image.

    Each side is drawn as dashes of C{dash_length} pixels separated by
    gaps of the same length. The image is modified in place.

    @type image: np.ndarray
    @param image: The image to draw on.
    @type pt1: Tuple[int, int]
    @param pt1: The top-left corner of the rectangle.
    @type pt2: Tuple[int, int]
    @param pt2: The bottom-right corner of the rectangle.
    @type color: Color
    @param color: The color of the rectangle.
    @type thickness: int
    @param thickness: The thickness of the rectangle. Default is 1.
    @type dash_length: int
    @param dash_length: The length of the dashes. Default is 10.
    @raise ValueError: If C{dash_length} is not positive or the color
        is out of range.
    """
    if dash_length <= 0:
        raise ValueError(f"dash_length must be positive, got {dash_length}")

    # Resolve (and validate) the color once instead of once per dash.
    line_color = resolve_color(color)
    x1, y1 = pt1
    x2, y2 = pt2

    def draw_dashed_line(p1: tuple[int, int], p2: tuple[int, int]) -> None:
        dx = p2[0] - p1[0]
        dy = p2[1] - p1[1]
        line_length = int(np.hypot(dx, dy))
        # A zero-length side yields an empty range, so no division by 0.
        for start in range(0, line_length, 2 * dash_length):
            end = min(start + dash_length, line_length)
            start_point = (
                int(p1[0] + dx * start / line_length),
                int(p1[1] + dy * start / line_length),
            )
            end_point = (
                int(p1[0] + dx * end / line_length),
                int(p1[1] + dy * end / line_length),
            )
            cv2.line(image, start_point, end_point, line_color, thickness)

    draw_dashed_line((x1, y1), (x2, y1))
    draw_dashed_line((x2, y1), (x2, y2))
    draw_dashed_line((x2, y2), (x1, y2))
    draw_dashed_line((x1, y2), (x1, y1))


def draw_cross(
    img: np.ndarray,
    center: tuple[int, int],
    size: int = 5,
    color: Color = 0,
    thickness: int = 1,
) -> None:
    """Draws an axis-aligned "+" marker onto the image in place.

    The marker consists of a horizontal and a vertical line, each
    extending C{size} pixels to either side of C{center}.

    @type img: np.ndarray
    @param img: The image to draw on; it is modified in place.
    @type center: Tuple[int, int]
    @param center: The C{(x, y)} position of the center of the cross.
    @type size: int
    @param size: How far each line extends from the center, in pixels.
        Default is 5.
    @type color: Color
    @param color: The color of the cross, resolved with
        L{resolve_color}. Default is 0 (black).
    @type thickness: int
    @param thickness: The line thickness. Default is 1.
    """
    cx, cy = center
    line_color = resolve_color(color)
    arms = (
        ((cx - size, cy), (cx + size, cy)),
        ((cx, cy - size), (cx, cy + size)),
    )
    for start, end in arms:
        cv2.line(img, start, end, line_color, thickness)


def create_text_image(
    text: str,
    width: int,
    height: int,
    font_size: float = 0.7,
    bg_color: Color = 255,
    text_color: Color = 0,
) -> np.ndarray:
    """Renders C{text} centered on a new solid-color 3-channel image.

    Centering uses the text width and height reported by OpenCV for a
    thickness of 1; the baseline offset is not taken into account.

    @type text: str
    @param text: The text to display.
    @type width: int
    @param width: The width of the image in pixels.
    @type height: int
    @param height: The height of the image in pixels.
    @type font_size: float
    @param font_size: The font scale passed to OpenCV. Default is 0.7.
    @type bg_color: Color
    @param bg_color: The background color of the image. Default is
        255 (white).
    @type text_color: Color
    @param text_color: The color of the text. Default is 0 (black).
    @rtype: np.ndarray
    @return: A C{uint8} image of shape C{(height, width, 3)}.
    """
    canvas = np.full(
        (height, width, 3), resolve_color(bg_color), dtype=np.uint8
    )

    (text_width, text_height), _baseline = cv2.getTextSize(
        text, FONT, font_size, 1
    )
    origin = ((width - text_width) // 2, (height + text_height) // 2)

    cv2.putText(
        canvas,
        text,
        origin,
        fontFace=FONT,
        fontScale=font_size,
        color=resolve_color(text_color),
        thickness=1,
        lineType=cv2.LINE_AA,
    )
    return canvas


def concat_images(
    image_dict: dict[str, np.ndarray],
    padding: int = 10,
    label_height: int = 30,
) -> np.ndarray:
    """Concatenates images into a single image with labels.

    The images are placed on a white canvas in a grid with
    C{ceil(sqrt(n))} columns, filled row by row in the iteration order
    of C{image_dict}. Every cell is sized to fit the largest image and
    shows the image name in a label strip above the image.

    @type image_dict: Dict[str, np.ndarray]
    @param image_dict: A dictionary mapping image names to images. The
        images are written into a C{uint8} canvas; single-channel (2D)
        images are replicated to three channels.
    @type padding: int
    @param padding: The padding between images. Default is 10.
    @type label_height: int
    @param label_height: The height of the label. Default is 30.
    @rtype: np.ndarray
    @return: The concatenated 3-channel image.
    @raise ValueError: If C{image_dict} is empty.
    """
    if not image_dict:
        # Without this guard the grid computation divides by zero.
        raise ValueError("Cannot concatenate an empty dictionary of images.")

    n_images = len(image_dict)
    n_cols = math.ceil(math.sqrt(n_images))
    n_rows = math.ceil(n_images / n_cols)

    max_h = max(img.shape[0] for img in image_dict.values())
    max_w = max(img.shape[1] for img in image_dict.values())

    # Each cell holds the label strip on top, then the image surrounded
    # by `padding` pixels.
    cell_height = max_h + 2 * padding + label_height
    cell_width = max_w + 2 * padding

    output = np.full(
        (cell_height * n_rows, cell_width * n_cols, 3), 255, dtype=np.uint8
    )

    for idx, (name, img) in enumerate(image_dict.items()):
        row, col = divmod(idx, n_cols)
        y_start = row * cell_height
        x_start = col * cell_width

        label = create_text_image(name, cell_width, label_height)
        output[
            y_start : y_start + label_height, x_start : x_start + cell_width
        ] = label

        if img.ndim == 2:
            # A 2D image does not broadcast into the 3-channel canvas;
            # replicate its single channel instead.
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        h, w = img.shape[:2]
        y_img = y_start + label_height + padding
        x_img = x_start + padding
        output[y_img : y_img + h, x_img : x_img + w] = img

    return output


def draw_bbox_label(
    image: np.ndarray,
    class_name: str,
    box: np.ndarray,
    color: tuple[int, int, int],
    font_scale: float,
) -> None:
    """Draws the class name just above the top-left corner of a bounding
    box.

    The text is drawn in white over a filled rectangle of C{color}. If
    the box is too close to the top edge of the image, the label is
    moved down so the text does not extend above the image.

    @type image: np.ndarray
    @param image: The image to draw on; it is modified in place.
    @type class_name: str
    @param class_name: The name of the class.
    @type box: np.ndarray
    @param box: The bounding box coordinates. The format is [class_id,
        x1, y1, x2, y2], where (x1, y1) is the top-left corner and (x2,
        y2) is the bottom-right corner. Only x1 and y1 are used.
    @type color: Tuple[int, int, int]
    @param color: The fill color of the label background.
    @type font_scale: float
    @param font_scale: The scale of the font.
    """
    (text_w, text_h), _ = cv2.getTextSize(class_name, FONT, font_scale, 1)
    left = box[1]
    # Text origin 5 px above the box top, clamped so that the text
    # (text_h pixels tall) stays below the top edge of the image.
    baseline = max(box[2] - 5, text_h)

    background_top_left = (left - 2, baseline - text_h - 2)
    background_bottom_right = (left + text_w + 2, baseline + 2)
    cv2.rectangle(
        image, background_top_left, background_bottom_right, color, -1
    )
    cv2.putText(
        image,
        class_name,
        (left, baseline),
        FONT,
        font_scale,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )


def draw_keypoint_label(
    image: np.ndarray,
    text: str,
    point: tuple[int, int],
    size: int,
    color: tuple[int, int, int],
    font_scale: float,
) -> None:
    """Draws a text label to the right of a keypoint on the image.

    The text starts 2 pixels past the keypoint marker and is offset
    vertically by half its height so that it is centered on the point.

    @type image: np.ndarray
    @param image: The image to draw on; it is modified in place.
    @type text: str
    @param text: The text to draw.
    @type point: Tuple[int, int]
    @param point: The C{(x, y)} pixel coordinates of the keypoint.
    @type size: int
    @param size: The size of the keypoint marker, used to offset the
        text horizontally.
    @type color: Tuple[int, int, int]
    @param color: The color of the text.
    @type font_scale: float
    @param font_scale: The scale of the font.
    """
    text_height = cv2.getTextSize(text, FONT, font_scale, 1)[0][1]
    origin = (point[0] + size + 2, point[1] + text_height // 2)
    cv2.putText(image, text, origin, FONT, font_scale, color, 1, cv2.LINE_AA)


def visualize(
    image: np.ndarray,
    source_name: str,
    labels: Labels,
    classes: dict[str, dict[str, int]],
    blend_all: bool = False,
) -> np.ndarray:
    """Visualizes the labels on the image.

    The unannotated image is shown under C{source_name}. Each task is
    drawn onto its own image named after the task. Tasks with an empty
    name, and all tasks when C{blend_all} is set, share a single image
    named C{"labels"}. All images are tiled into one output with
    L{concat_images}.

    Semantic segmentations are drawn first, then bounding boxes,
    instance segmentations and keypoints. Instance masks are colored by
    the class of the bounding box with the same index in the same task,
    and keypoints use the contrasting color of that class. If there is
    no such box, both fall back to the fixed color C{(255, 0, 0)}.

    @type image: np.ndarray
    @param image: The image to visualize, with shape C{(height, width,
        channels)}.
    @type source_name: str
    @param source_name: The name of the source of the image, used as the
        title of the unannotated image.
    @type labels: Labels
    @param labels: The labels to visualize. Bounding boxes are rows of
        C{[class_id, x, y, width, height]} and keypoints are flat
        sequences of C{(x, y, visibility)} triplets, with coordinates
        normalized by the image size. The label arrays are not
        modified.
    @type classes: Dict[str, Dict[str, int]]
    @param classes: A dictionary mapping task names to a mapping from
        class names to class IDs.
    @type blend_all: bool
    @param blend_all: Whether to blend all labels into a single image.
        This means mixing labels belonging to different tasks. Semantic
        segmentations are not blended with each other: each one is drawn
        onto a fresh copy of the input image, replacing any previous
        image of the same name. Default is False.
    @rtype: np.ndarray
    @return: The visualized image.
    """
    h, w, _ = image.shape
    images = {source_name: image}
    mappings = {task: bidict(c) for task, c in classes.items()}

    # Text and marker sizes scale with the smaller image dimension.
    min_dimension = min(h, w)
    font_scale = max(0.25, min(1.1, 0.4 * min_dimension / 500))
    small_kp_size = max(2, round(0.002 * min_dimension))
    large_kp_size = max(4, round(0.004 * min_dimension))

    # Class IDs of the drawn bounding boxes, per task and in drawing
    # order. Instance masks and keypoints look up their color here by
    # index.
    bbox_classes: dict[str, list[int]] = defaultdict(list)

    def target_name(task_name: str) -> str:
        # Name of the output image the labels of this task are drawn on.
        return task_name if task_name and not blend_all else "labels"

    def canvas_for(image_name: str) -> np.ndarray:
        # Keep drawing on an existing image of this name, otherwise start
        # from a copy of the input image.
        if image_name in images:
            return images[image_name]
        return image.copy()

    def create_mask(
        image: np.ndarray, arr: np.ndarray, task_name: str, is_instance: bool
    ) -> np.ndarray:
        mask_viz = np.zeros((h, w, 3), dtype=np.uint8)
        for i, mask in enumerate(arr):
            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            if is_instance:
                task_classes = mappings[task_name]
                if len(bbox_classes[task_name]) > i:
                    class_id = bbox_classes[task_name][i]
                    color = str_to_rgb(task_classes.inverse[class_id])
                else:
                    color = (255, 0, 0)
                mask_viz[mask > 0] = color
            else:
                # Index 0 (presumably the background class) is left
                # uncolored unless it is the only class.
                mask_viz[mask == 1] = (
                    str_to_rgb(mappings[task_name].inverse[i])
                    if (i != 0 or len(arr) == 1)
                    else (0, 0, 0)
                )

        # Blend all channels of every colored pixel. A per-channel mask
        # would leave channels in which the class color is 0 unblended.
        colored = np.any(mask_viz > 0, axis=-1, keepdims=True)

        return np.where(
            colored,
            cv2.addWeighted(image, 0.4, mask_viz, 0.6, 0),
            image,
        )

    for task, arr in task_type_iterator(labels, "segmentation"):
        task_name = get_task_name(task)
        images[target_name(task_name)] = create_mask(
            image, arr, task_name, is_instance=False
        )

    for task, arr in task_type_iterator(labels, "boundingbox"):
        task_name = get_task_name(task)
        image_name = target_name(task_name)
        curr_image = canvas_for(image_name)

        # Tasks with more than two "/"-separated parts are sub-labels and
        # are drawn with dashed boxes.
        is_sublabel = len(task.split("/")) > 2
        draw_function = (
            draw_dashed_rectangle if is_sublabel else cv2.rectangle
        )

        # Convert normalized [class_id, x, y, w, h] rows to pixel
        # [class_id, x1, y1, x2, y2]. Work on a copy: scaling `arr` in
        # place would modify the caller's labels.
        boxes = arr.copy()
        boxes[:, [1, 3]] *= w
        boxes[:, [2, 4]] *= h
        boxes[:, 3] += boxes[:, 1]
        boxes[:, 4] += boxes[:, 2]
        boxes = boxes.astype(int)

        for box in boxes:
            class_id = int(box[0])
            bbox_classes[task_name].append(class_id)
            class_name = mappings[task_name].inverse[class_id]
            color = str_to_rgb(class_name)
            draw_function(
                curr_image,
                (box[1], box[2]),
                (box[3], box[4]),
                color,
                thickness=2,
            )
            draw_bbox_label(curr_image, class_name, box, color, font_scale)
        images[image_name] = curr_image

    for task, arr in task_type_iterator(labels, "instance_segmentation"):
        task_name = get_task_name(task)
        image_name = target_name(task_name)
        curr_image = canvas_for(image_name)
        images[image_name] = create_mask(
            curr_image, arr, task_name, is_instance=True
        )

    for task, arr in task_type_iterator(labels, "keypoints"):
        task_name = get_task_name(task)
        image_name = target_name(task_name)
        curr_image = canvas_for(image_name)

        task_classes = mappings[task_name]

        for i, kp in enumerate(arr):
            kp = kp.reshape(-1, 3)
            if len(bbox_classes[task_name]) > i:
                class_id = bbox_classes[task_name][i]
                color = get_contrast_color(
                    str_to_rgb(task_classes.inverse[class_id])
                )
            else:
                color = (255, 0, 0)
            for j, k in enumerate(kp):
                visibility = int(k[-1])
                if visibility == 0:
                    continue

                # Visibility 2 is drawn as a small circle, any other
                # non-zero visibility as a larger cross.
                if visibility == 2:
                    draw_marker = cv2.circle
                    size = small_kp_size
                else:
                    draw_marker = draw_cross
                    size = large_kp_size

                point = (int(k[0] * w), int(k[1] * h))
                draw_marker(
                    curr_image,
                    point,
                    size,
                    color=color,
                    thickness=2,
                )

                draw_keypoint_label(
                    curr_image, str(j), point, size, color, font_scale
                )

        images[image_name] = curr_image

    return concat_images(images)
