# Copyright 2025 Louder Digital Pty Ltd.
# All Rights Reserved.
"""A serializable key / value pair that handles primitive BigQuery types."""

from __future__ import annotations

import datetime
import decimal
from types import UnionType
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Self,
    Union,
    get_args,
    get_origin,
)

import pydantic
from ldr.modelling import bigquery
from ldr.modelling.bigquery import types
from pydantic.alias_generators import to_camel

if TYPE_CHECKING:
    from collections.abc import Iterable


def _unwrap_type(t: type[Any], /) -> type[Any]:
    """
    Unwrap a type wrapped in `typing.Annotated` or other Union types.

    Params
    ------
    t: The type to unwrap.

    Returns
    -------
    The unwrapped type, or the original type if it is not wrapped.

    Raises
    ------
    ldr.modelling.bigquery.UnsupportedTypeError: If the inner type cannot be parsed.

    """
    origin = get_origin(t)

    if origin is Annotated:
        return _unwrap_type(get_args(t)[0])

    if origin is Union or origin is UnionType:
        args = get_args(t)
        non_none_args = [arg for arg in args if arg is not type(None)]
        if len(non_none_args) == 1:
            return _unwrap_type(next(iter(non_none_args)))

        raise bigquery.UnsupportedTypeError(
            f"Cannot determine single inner type from {t}",
        )

    return t


class _PydanticProperty(pydantic.BaseModel):
    """
    The subset of a pydantic JSON-schema property entry that is modelled here.

    JSON-schema keys are camelCase (e.g. `anyOf`); the `to_camel` alias
    generator maps them onto the snake_case fields below, so a raw schema
    property dict can be passed straight to `model_validate`.
    """

    model_config = pydantic.ConfigDict(alias_generator=to_camel)

    any_of: list[Self] | None = None  # `anyOf`: union members, each a nested property
    title: str | None = None
    type: str | None = None  # the JSON-schema "type" keyword
    format: str | None = None
    items: Self | None = None  # element schema of an array property


class Param(pydantic.BaseModel):
    """A key value pair serializable to BigQuery."""

    key: str
    string_value: str | None = None
    bool_value: bool | None = None
    int_value: int | None = None
    float_value: float | None = None
    numeric_value: Annotated[types.Numeric, types.NumericValidator] | None = None
    bignumeric_value: (
        Annotated[
            types.BigNumeric,
            types.BigNumericValidator,
        ]
        | None
    ) = None
    date_value: datetime.date | None = None
    datetime_value: datetime.datetime | None = None
    timestamp_value: Annotated[types.Timestamp, types.TimestampValidator] | None = None
    geography_value: Annotated[types.Geography, types.GeographyValidator] | None = None

    @classmethod
    def from_model(
        cls,
        model: pydantic.BaseModel,
        *,
        by_alias: bool = False,
    ) -> list[Self]:
        """
        Generate a list of `Param`s from a pydantic model instance.

        Returns
        -------
        The encoded model.

        Raises
        ------
        ldr.modelling.bigquery.MissingAnnotationError: If the field has no annotation.
        ldr.modelling.bigquery.MissingAliasError: If using 'by_alias'
                                                  but no alias exists.
        ValueError: If a model has an alias set but cannot find the respective field.

        """
        model_cls = type(model)
        schema = model.model_json_schema(by_alias=by_alias).get("properties", {})
        dump = model.model_dump(by_alias=by_alias)
        params: list[Self] = []

        for name, info in model_cls.model_fields.items():
            if not info.annotation:  # pragma: nocover
                # pydantic model validation makes this unreachable
                raise bigquery.MissingAnnotationError(
                    f"No annotation set on {model_cls}.{name}",
                )

            if by_alias and not info.alias:
                raise bigquery.MissingAliasError(
                    f"'by_alias' used but no alias provided for {model_cls}.{name}",
                )
            key = info.alias if by_alias and info.alias else name

            if key not in schema:  # pragma: nocover
                # pydantic model validation makes this unreachable
                raise ValueError(
                    f"Field '{key}' does not exist on class {model_cls}.\n"
                    f"{model_cls} schema = {schema}",
                )

            params.append(
                cls._from_key_info(
                    key,
                    field_type=info.annotation,
                    info=_PydanticProperty.model_validate(schema[key]),
                    value=dump.get(key),
                ),
            )

        return params

    @classmethod
    def _from_key_info[
        T
    ](cls, key: str, field_type: type[T], info: _PydanticProperty, value: T) -> Self:
        if value is None:
            return cls(key=key)

        if info.any_of is not None:
            # Take first index as unions should be annotated with None as
            # the second value. Unions of T and None are supported, any other
            # union is unsupported.
            return cls._from_key_info(
                key,
                field_type=field_type,
                info=info.any_of[0],
                value=value,
            )

        unwrapped_type = _unwrap_type(field_type)

        attr: str = {
            types.Numeric: "numeric",
            types.BigNumeric: "bignumeric",
            decimal.Decimal: "numeric",
            types.Geography: "geography",
            types.Timestamp: "timestamp",
            datetime.datetime: "datetime",
            datetime.date: "date",
            bool: "bool",
            int: "int",
            float: "float",
            str: "string",
        }.get(unwrapped_type, "") + "_value"

        if attr == "_value":
            raise bigquery.UnsupportedTypeError(f"Unsupported type {unwrapped_type}")

        param = cls(key=key)
        setattr(param, attr, value)
        return param

    def as_kv(self) -> dict[str, Any]:
        """
        Return the param as a `{key: value}` pair with the set value (or None).

        Returns
        -------
        The param serialized as `{key: value}`.

        Raises
        ------
        ValueError: If more than one value is set per key.

        """
        values = [
            value
            for key, value in self.model_dump().items()
            if key != "key"
            if value is not None
        ]

        if len(values) <= 1:
            return {self.key: next(iter(values), None)}

        raise ValueError(f"Multiple values set for {self.key}: {values}")

    @staticmethod
    def parse_into[T: pydantic.BaseModel](model: type[T], params: Iterable[Param]) -> T:
        """
        Construct the provided Pydantic model from some `Param`s.

        The provided `Param`s must have keys matching the model's properties, and are
        provided to the provided pydantic model's `model_validate` method.

        Returns
        -------
        A new instance of the model, if validation was successful.

        """
        return model.model_validate(
            {key: value for param in params for key, value in param.as_kv().items()},
        )
