from __future__ import annotations

import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
from uuid import UUID

from orca_sdk._utils.common import UNSET

from ._generated_api_client.api import (
    drop_feedback_category_with_data,
    get_prediction,
    list_feedback_categories,
    list_predictions,
    record_prediction_feedback,
    update_prediction,
)
from ._generated_api_client.client import get_client
from ._generated_api_client.errors import get_error_for_response
from ._generated_api_client.models import (
    FeedbackType,
    LabelPredictionWithMemoriesAndFeedback,
    ListPredictionsRequest,
    PredictionFeedbackCategory,
    PredictionFeedbackRequest,
    UpdatePredictionRequest,
)
from ._generated_api_client.types import UNSET as CLIENT_UNSET
from ._utils.prediction_result_ui import inspect_prediction_result
from .memoryset import LabeledMemoryLookup, LabeledMemoryset

if TYPE_CHECKING:
    from .classification_model import ClassificationModel


def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
    """
    Build a feedback request model from a plain feedback dict.

    Params:
        feedback: Mapping with the required keys `prediction_id` and `category`, and the
            optional keys `value` and `comment`.

    Returns:
        A `PredictionFeedbackRequest`; missing optional keys are passed as the generated
        client's `UNSET` sentinel.

    Raises:
        ValueError: If `category` or `prediction_id` is missing or `None` (checked in that order).
    """
    category_name = feedback.get("category")
    if category_name is None:
        raise ValueError("`category` must be specified")

    target_prediction = feedback.get("prediction_id")
    if target_prediction is None:
        raise ValueError("`prediction_id` must be specified")

    feedback_value = feedback.get("value", CLIENT_UNSET)
    feedback_comment = feedback.get("comment", CLIENT_UNSET)
    return PredictionFeedbackRequest(
        prediction_id=target_prediction,
        category_name=category_name,
        value=feedback_value,
        comment=feedback_comment,
    )


class FeedbackCategory:
    """
    A category of feedback for predictions.

    Categories are created automatically, the first time feedback with a new name is recorded.
    The value type of the category is inferred from the first recorded value. Subsequent feedback
    for the same category must be of the same type. Categories are not model specific.

    Attributes:
        id: Unique identifier for the category.
        name: Name of the category.
        value_type: Type that values for this category must have.
        created_at: When the category was created.
    """

    id: str
    name: str
    value_type: type[bool] | type[float]
    created_at: datetime

    def __init__(self, category: PredictionFeedbackCategory):
        # for internal use only, do not document
        self.id = category.id
        self.name = category.name
        # binary categories hold bools; any other category type is exposed as float
        self.value_type = bool if category.type == FeedbackType.BINARY else float
        self.created_at = category.created_at

    @classmethod
    def all(cls) -> list[FeedbackCategory]:
        """
        Get a list of all existing feedback categories.

        Returns:
            List with information about all existing feedback categories.
        """
        return [cls(category) for category in list_feedback_categories()]

    @classmethod
    def drop(cls, name: str) -> None:
        """
        Drop all feedback for this category and drop the category itself, allowing it to be
        recreated with a different value type.

        Warning:
            This will delete all feedback in this category across all models.

        Params:
            name: Name of the category to drop.

        Raises:
            LookupError: If the category is not found.
        """
        drop_feedback_category_with_data(name)
        # Use a named logger: the module-level `logging.info` helper calls `logging.basicConfig()`
        # when the root logger has no handlers, which would silently reconfigure logging for
        # applications that use this SDK.
        logging.getLogger(__name__).info("Deleted feedback category %s with all associated feedback", name)

    def __repr__(self):
        return f"FeedbackCategory({{name: {self.name}, value_type: {self.value_type}}})"


class LabelPrediction:
    """
    A prediction made by a model

    Attributes:
        prediction_id: Unique identifier for the prediction
        label: Predicted label for the input value
        label_name: Name of the predicted label
        confidence: Confidence of the prediction
        anomaly_score: The score for how anomalous the input is relative to the memories
        memory_lookups: List of memories used to ground the prediction
        input_value: Input value that this prediction was for
        model: Model that was used to make the prediction
        memoryset: Memoryset that was used to lookup memories to ground the prediction
        expected_label: Optional expected label that was set for the prediction
        expected_label_name: Name of the expected label
        tags: tags that were set for the prediction
        feedback: Feedback recorded, mapping from category name to value
        explanation: Explanation why the model made the prediction generated by a reasoning agent
    """

    prediction_id: str
    label: int
    label_name: str | None
    confidence: float
    anomaly_score: float | None
    memoryset: LabeledMemoryset
    model: ClassificationModel

    def __init__(
        self,
        prediction_id: str,
        *,
        label: int,
        label_name: str | None,
        confidence: float,
        anomaly_score: float | None,
        memoryset: LabeledMemoryset | str,
        model: ClassificationModel | str,
        telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
    ):
        # for internal use only, do not document
        from .classification_model import ClassificationModel

        self.prediction_id = prediction_id
        self.label = label
        self.label_name = label_name
        self.confidence = confidence
        self.anomaly_score = anomaly_score
        # string arguments are treated as identifiers and opened eagerly
        self.memoryset = LabeledMemoryset.open(memoryset) if isinstance(memoryset, str) else memoryset
        self.model = ClassificationModel.open(model) if isinstance(model, str) else model
        # cached telemetry record; `_telemetry` fetches it on first access when not supplied
        self.__telemetry = telemetry

    def __repr__(self):
        input_value = self.input_value
        input_text = str(input_value)
        # truncate long inputs so the repr stays readable
        input_preview = input_text[:100] + "..." if len(input_text) > 100 else input_value
        # the anomaly score is optional, so only include it when it is present
        anomaly_part = f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else ""
        return (
            "LabelPrediction({"
            + f"label: <{self.label_name}: {self.label}>, "
            + f"confidence: {self.confidence:.2f}, "
            + anomaly_part
            + f"input_value: '{input_preview}'"
            + "})"
        )

    @property
    def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback:
        # for internal use only, do not document
        if self.__telemetry is None:
            self.__telemetry = get_prediction(prediction_id=UUID(self.prediction_id))
        return self.__telemetry

    @property
    def memory_lookups(self) -> list[LabeledMemoryLookup]:
        return [LabeledMemoryLookup(self.memoryset.id, lookup) for lookup in self._telemetry.memories]

    @property
    def input_value(self) -> str | list[list[float]] | None:
        return self._telemetry.input_value

    @property
    def feedback(self) -> dict[str, bool | float]:
        # continuous feedback is returned as-is; binary feedback is converted to a bool
        return {
            f.category_name: f.value if f.category_type == FeedbackType.CONTINUOUS else f.value == 1
            for f in self._telemetry.feedbacks
        }

    @property
    def expected_label(self) -> int | None:
        return self._telemetry.expected_label

    @property
    def expected_label_name(self) -> str | None:
        return self._telemetry.expected_label_name

    @property
    def tags(self) -> set[str]:
        return set(self._telemetry.tags)

    def _explanation_stream(self, refresh: bool = False) -> Generator[str, None, None]:
        # for internal use only: stream the explanation text from the API chunk by chunk;
        # a non-200 response is turned into an error via `get_error_for_response`
        httpx_client = get_client().get_httpx_client()
        url = f"/telemetry/prediction/{self.prediction_id}/explanation?refresh={refresh}"
        with httpx_client.stream("GET", url) as res:
            if res.status_code != 200:
                # Read the response body before raising the error
                res.read()
                raise get_error_for_response(res)
            for chunk in res.iter_text():
                yield chunk

    @property
    def explanation(self) -> str:
        # fetched once and then cached on the telemetry record
        if self._telemetry.explanation is None:
            self._telemetry.explanation = "".join(self._explanation_stream())
        return self._telemetry.explanation

    def explain(self, refresh: bool = False) -> None:
        """
        Print an explanation of the prediction as a stream of text.

        Params:
            refresh: Force the explanation agent to re-run even if an explanation already exists.
        """
        if not refresh and self._telemetry.explanation is not None:
            print(self._telemetry.explanation)
            return
        chunks: list[str] = []
        for chunk in self._explanation_stream(refresh):
            chunks.append(chunk)
            # flush so each chunk appears immediately instead of waiting for a newline
            print(chunk, end="", flush=True)
        print()  # final newline
        # cache the complete explanation so `explanation` does not request it again
        self._telemetry.explanation = "".join(chunks)

    @overload
    @classmethod
    def get(cls, prediction_id: str) -> LabelPrediction:  # type: ignore -- this takes precedence
        ...

    @overload
    @classmethod
    def get(cls, prediction_id: Iterable[str]) -> list[LabelPrediction]:
        ...

    @classmethod
    def get(cls, prediction_id: str | Iterable[str]) -> LabelPrediction | list[LabelPrediction]:
        """
        Fetch a prediction or predictions

        Params:
            prediction_id: Unique identifier of the prediction or predictions to fetch

        Returns:
            Prediction or list of predictions

        Raises:
            LookupError: If no prediction with the given id is found

        Examples:
            Fetch a single prediction:
            >>> LabelPrediction.get("0195019a-5bc7-7afb-b902-5945ee1fb766")
            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'})

            Fetch multiple predictions:
            >>> LabelPrediction.get([
            ...     "0195019a-5bc7-7afb-b902-5945ee1fb766",
            ...     "019501a1-ea08-76b2-9f62-95e4800b4841",
            ... ])
            [LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'}), LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.20, input_value: 'I am sad'})]
        """
        if isinstance(prediction_id, str):
            return cls._from_telemetry(get_prediction(prediction_id=UUID(prediction_id)))
        # NOTE(review): result order is whatever `list_predictions` returns — confirm it matches input order
        request = ListPredictionsRequest(prediction_ids=list(prediction_id))
        return [cls._from_telemetry(prediction) for prediction in list_predictions(body=request)]

    @classmethod
    def _from_telemetry(cls, prediction: LabelPredictionWithMemoriesAndFeedback) -> LabelPrediction:
        # for internal use only: build an instance from a fetched telemetry record, reusing it as cache
        return cls(
            prediction_id=prediction.prediction_id,
            label=prediction.label,
            label_name=prediction.label_name,
            confidence=prediction.confidence,
            anomaly_score=prediction.anomaly_score,
            memoryset=prediction.memoryset_id,
            model=prediction.model_id,
            telemetry=prediction,
        )

    def refresh(self):
        """Refresh the prediction data from the OrcaCloud"""
        self.__dict__.update(type(self).get(self.prediction_id).__dict__)

    def inspect(self):
        """Open a UI to inspect the memories used by this prediction"""
        inspect_prediction_result(self)

    def update(self, *, expected_label: int | None = UNSET, tags: set[str] | None = UNSET) -> None:
        """
        Update editable prediction properties.

        Params:
            expected_label: Value to set for the expected label, defaults to `[UNSET]` if not provided.
            tags: Value to replace existing tags with, defaults to `[UNSET]` if not provided.

        Examples:
            Update the expected label:
            >>> prediction.update(expected_label=1)

            Add a new tag:
            >>> prediction.update(tags=prediction.tags | {"new_tag"})

            Remove expected label and tags:
            >>> prediction.update(expected_label=None, tags=None)
        """
        # omitted arguments are passed as the generated client's UNSET sentinel;
        # `tags=None` is sent as an empty list
        if tags is UNSET:
            tags_value = CLIENT_UNSET
        elif tags is None:
            tags_value = []
        else:
            tags_value = list(tags)
        update_prediction(
            prediction_id=self.prediction_id,
            body=UpdatePredictionRequest(
                expected_label=CLIENT_UNSET if expected_label is UNSET else expected_label,
                tags=tags_value,
            ),
        )
        self.refresh()

    def add_tag(self, tag: str) -> None:
        """
        Add a tag to the prediction

        Params:
            tag: Tag to add to the prediction
        """
        self.update(tags=self.tags | {tag})

    def remove_tag(self, tag: str) -> None:
        """
        Remove a tag from the prediction

        Params:
            tag: Tag to remove from the prediction
        """
        self.update(tags=self.tags - {tag})

    def record_feedback(
        self,
        category: str,
        value: bool | float,
        *,
        comment: str | None = None,
    ):
        """
        Record feedback for the prediction.

        We support recording feedback in several categories for each prediction. A
        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
        the first time feedback with a new name is recorded. Categories are global across models.
        The value type of the category is inferred from the first recorded value. Subsequent
        feedback for the same category must be of the same type.

        Params:
            category: Name of the category under which to record the feedback.
            value: Feedback value to record, should be `True` for positive feedback and `False` for
                negative feedback or a [`float`][float] between `-1.0` and `+1.0` where negative
                values indicate negative feedback and positive values indicate positive feedback.
            comment: Optional comment to record with the feedback.

        Examples:
            Record whether a suggestion was accepted or rejected:
            >>> prediction.record_feedback("accepted", True)

            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
            >>> prediction.record_feedback("rating", -0.5, comment="2 stars")

        Raises:
            ValueError: If the value does not match previous value types for the category, or is a
                [`float`][float] that is not between `-1.0` and `+1.0`.
        """
        record_prediction_feedback(
            body=[
                _parse_feedback(
                    {"prediction_id": self.prediction_id, "category": category, "value": value, "comment": comment}
                )
            ]
        )
        self.refresh()

    def delete_feedback(self, category: str) -> None:
        """
        Delete prediction feedback for a specific category.

        Params:
            category: Name of the category of the feedback to delete.

        Raises:
            ValueError: If the category is not found.
        """
        # recording a `None` value is how feedback for a category is removed
        record_prediction_feedback(
            body=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)]
        )
        self.refresh()
