import inspect
import json
import os
from datetime import datetime
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    get_args,
    get_origin,
)

from pydantic import BaseModel, field_validator, model_validator
from pyspark.sql.dataframe import DataFrame

# Absolute path of the directory that contains this module file.
CURRENT_FILE_DIR = os.path.dirname(os.path.abspath(__file__))


def _two_digits(value) -> str:
    """Render *value* as text, left-padded with zeros to at least two characters."""
    return str(value).zfill(2)


# Maps a date-part name to a function that renders that part of a date-like
# object (anything exposing ``year`` / ``month`` / ``day``) as a string.
# Month and day are zero-padded to two characters; the year is left as-is.
DATE_FORMATTER = {
    "year": lambda x: str(x.year),
    "month": lambda x: _two_digits(x.month),
    "day": lambda x: _two_digits(x.day),
}


class ReadPreviousDayException(Exception):
    """
    Signals that the dataframe from the previous day could not be found.

    Attributes:
        message (str): Description of the failure; it is also handed to
            ``Exception`` as the sole positional argument.
    """

    def __init__(self, message: str = "Dataframe from the previous day not found"):
        super().__init__(message)
        self.message = message


class EmptyDataFrameException(Exception):
    """
    Signals that a dataframe turned out to be empty.

    Attributes:
        message (str): Description of the failure; it is also handed to
            ``Exception`` as the sole positional argument.
    """

    def __init__(self, message: str = "Dataframe is empty"):
        super().__init__(message)
        self.message = message


class FileType(str, Enum):
    """
    Represents the type of file used in data pipelines.

    Members also subclass ``str``, so each member compares equal to its
    lowercase string value (e.g. ``FileType.CSV == "csv"``).

    Attributes:
        CSV (str): Represents a CSV file.
        JSON (str): Represents a JSON file.
        PARQUET (str): Represents a Parquet file.
    """

    CSV = "csv"
    JSON = "json"
    PARQUET = "parquet"


class Layer(str, Enum):
    """
    Enum class representing different layers in a data pipeline.

    Members also subclass ``str``; each value is the lowercase layer name.
    Within this module the value is interpolated into the bucket name that
    ``S3BucketParams`` generates when no explicit bucket name is given.

    The available layers are:
    - LANDING: The landing layer where raw data is initially ingested.
    - RAW: The raw layer where the ingested data is stored without any transformations.
    - TRUSTED: The trusted layer where the data is cleaned and validated.
    - ANALYTICS: The analytics layer where the data is transformed and aggregated for analysis.
    """

    LANDING = "landing"
    RAW = "raw"
    TRUSTED = "trusted"
    ANALYTICS = "analytics"


class Env(str, Enum):
    """
    Enumeration representing different environments.

    Members also subclass ``str``; each value is the lowercase environment
    name. Within this module ``S3BucketParams`` builds an ``Env`` from the
    ``ENV`` environment variable when no environment is passed explicitly.

    Attributes:
        DEV (str): Development environment.
        PRD (str): Production environment.
    """

    DEV = "dev"
    PRD = "prd"


class DatePartition(str, Enum):
    """
    Enum representing different types of date partitions.

    Members are used as the keys of the ``date_partitions`` /
    ``single_write_date_partitions`` dictionaries declared on
    ``ReadDateParams`` and ``SingleWriteDateParams``.

    Attributes:
        YEAR (str): Represents a partition by year, with value "year".
        MONTH (str): Represents a partition by month, with value "month".
        DAY (str): Represents a partition by day, with value "day".
    """

    YEAR = "year"
    MONTH = "month"
    DAY = "day"


class FnKind(str, Enum):
    """
    Enumeration representing the kind of a transformation function.

    ``TransformParams`` derives the kind from the function's type hints.

    Attributes:
        SINGLE (str): The first parameter and the return value are hinted as a
            (pyspark) DataFrame.
        BATCH (str): The first parameter and the return value are hinted as a
            dict of [str, DataFrame].
    """

    SINGLE = "single"
    BATCH = "batch"


class BaseModelJsonDumps(BaseModel):
    """Pydantic base model whose ``str()`` form is the model pretty-printed as indented JSON."""

    def __str__(self):
        dumped = self.model_dump()
        # ``default=str`` stringifies any value the JSON encoder cannot
        # serialise natively instead of raising.
        return json.dumps(dumped, default=str, indent=2)


class AWSCredentials(BaseModelJsonDumps):
    """
    Represents AWS credentials required for authentication.

    NOTE(review): the ``__str__`` inherited from ``BaseModelJsonDumps`` dumps
    every field, so ``str()`` / ``print()`` of an instance exposes the secret
    access key in plain text. Avoid logging instances of this class.

    Attributes:
        aws_access_key_id (str): The AWS access key ID.
        aws_secret_access_key (str): The AWS secret access key.
    """

    aws_access_key_id: str
    aws_secret_access_key: str


class S3BucketParams(BaseModelJsonDumps):
    """
    Represents the parameters for an S3 bucket, most specially the bucket name.

    Attributes:
        env (Env): The environment. Must be one of "Env.DEV" or "Env.PRD" (or the
            equivalent strings "dev" / "prd"). If not provided, it is read from the
            environment variable 'ENV' at validation time.
        layer (Optional[Layer]): The layer (optional). Required when 'bucket_name' is not provided.
        ng_prefix (str): The NG prefix. Defaults to "ng".
        bucket_name (Optional[str]): The bucket name (optional). If not provided, it will be
            generated as "<ng_prefix>-datalake-<layer>-<env>".

    Raises:
        ValueError: If neither 'env' nor the environment variable 'ENV' is available, if
            'env' / 'layer' are not valid enum values, or if both 'bucket_name' and 'layer'
            are missing.
    """

    # No class-level default: the "before" validator below always fills 'env'
    # in (from the argument or from the 'ENV' variable). Evaluating
    # ``Env(os.environ.get("ENV"))`` here would run at import time and raise
    # whenever 'ENV' is unset, making the whole module unimportable.
    env: Env
    layer: Optional[Layer] = None
    ng_prefix: str = "ng"
    bucket_name: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def set_env_and_bucket_name_if_not_provided(cls, data):
        """
        Fill in 'env' and 'bucket_name' when they were not provided.

        Args:
            data: The raw input passed to the model.

        Returns:
            A copy of the input with 'env' (and, if needed, 'bucket_name') set.
            Non-dict input is returned unchanged.

        Raises:
            ValueError: See the class docstring.
        """
        # "before" validators may receive arbitrary input; leave anything that
        # is not a plain mapping for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data

        # Shallow copy so the caller's dictionary is not modified.
        data = dict(data)

        env = data.get("env")
        if env is None:
            env = os.environ.get("ENV")

            if env is None:
                raise ValueError(
                    "If 'env' is not provided, the environment variable 'ENV' must be set"
                )

        # Accept both enum members and their string values ("dev" / "prd").
        env = Env(env)
        data["env"] = env

        if data.get("bucket_name") is None:
            layer = data.get("layer")
            if layer is None:
                raise ValueError(
                    "'layer' must be provided if 'bucket_name' is not provided"
                )

            ng_prefix = data.get("ng_prefix", "ng")
            # ``Layer(layer)`` likewise accepts a member or its string value.
            data["bucket_name"] = (
                f"{ng_prefix}-datalake-{Layer(layer).value}-{env.value}"
            )

        return data


class S3ReadSchemaParams(BaseModelJsonDumps):
    """
    Parameters locating a JSON schema file inside an S3 bucket.

    Attributes:
        bucket_params (S3BucketParams): The parameters for the S3 bucket.
        path (str): The path to the schema file in the S3 bucket. Leading and
            trailing slashes are removed, and the path must end with '.json'.

    Methods:
        strip_slashes(v: str) -> str: Removes leading and trailing slashes from the path.
        ensure_path_to_file_has_json_extension(v: str) -> str: Rejects paths without a '.json' extension.
    """

    bucket_params: S3BucketParams
    path: str

    @field_validator("path")
    @classmethod
    def strip_slashes(cls, v: str) -> str:
        trimmed = v.strip("/")
        return trimmed

    @field_validator("path")
    @classmethod
    def ensure_path_to_file_has_json_extension(cls, v: str) -> str:
        if v.endswith(".json"):
            return v

        raise ValueError(
            f"For S3 schema, 'path' must have a '.json' extension. Received path: '{v}'"
        )


class DataFrameBaseParams(BaseModelJsonDumps):
    """
    Represents the base parameters for a DataFrame.

    Subclasses in this module add their own "specific path" field(s) next to
    the base path declared here.

    Attributes:
        dataframe_bucket_params (S3BucketParams): The bucket parameters for the DataFrame.
        dataframe_base_path (Optional[str]): The base path for the DataFrame. Defaults to None.
            When given, leading and trailing slashes are stripped and the result must
            still contain at least one '/' (i.e. at least one parent folder).
        dataframe_file_type (FileType): The file type of the DataFrame.
    """

    dataframe_bucket_params: S3BucketParams
    dataframe_base_path: Optional[str] = None
    dataframe_file_type: FileType

    @field_validator("dataframe_base_path")
    @classmethod
    def strip_and_ensure_dataframe_base_path_parent_folder(
        cls, v: Optional[str]
    ) -> Optional[str]:
        """
        Strip surrounding slashes and require at least one parent folder.

        Args:
            v (Optional[str]): The provided base path, or None.

        Returns:
            Optional[str]: The stripped path, or None when no path was provided.

        Raises:
            ValueError: If the stripped path does not contain a '/'.
        """
        # An explicitly passed ``None`` is a legal value for this optional
        # field; the original fell through to ``"/" not in None`` and raised a
        # TypeError instead of accepting it.
        if v is None:
            return None

        v = v.strip("/")

        if "/" not in v:
            raise ValueError(
                f"'dataframe_base_path' should have at least one parent folder inside the bucket. Ensure there is at least one slash ('/') in the path. Received path: '{v}'"
            )

        return v


class ReadDateParams(BaseModelJsonDumps):
    """
    Represents the parameters for reading dates in a data pipeline.

    Attributes:
        read_dates (Union[Sequence[Union[str, datetime, List[Union[str, datetime]]]], datetime, Literal["{{processing_date}}"]]):
            The dates to be read. It can be a sequence whose items are strings, datetimes or
            lists of those, a single datetime, or the special placeholder "{{processing_date}}".
            The default value is "{{processing_date}}".
        processing_date_offset_days (Optional[int]):
            The offset to be applied to the processing date. It represents the number of days to add or subtract from the processing date.
            May only be provided when 'read_dates' is "{{processing_date}}".
            The default value is None.
        was_offset_applied (bool):
            Defaults to False. NOTE(review): presumably flipped by the code that applies
            the offset; nothing in this module changes it — confirm with the caller.
        date_partitions (dict[DatePartition, str]):
            A dictionary that maps DatePartition enum values to their corresponding string representations.
            All three partitions (year, month, day) must be present.
            The default value is:
            {
                DatePartition.YEAR: "year",
                DatePartition.MONTH: "month",
                DatePartition.DAY: "day",
            }
        add_all_date_columns (bool):
            Whether to add all date columns to the DataFrame. If set to True, the DataFrame will have columns for year, month, and day.
            The default value is True.
    """

    read_dates: Union[
        Sequence[Union[Union[str, datetime], List[Union[str, datetime]]]],
        datetime,
        Literal["{{processing_date}}"],
    ] = "{{processing_date}}"
    processing_date_offset_days: Optional[int] = None
    was_offset_applied: bool = False
    date_partitions: dict[DatePartition, str] = {
        DatePartition.YEAR: "year",
        DatePartition.MONTH: "month",
        DatePartition.DAY: "day",
    }
    add_all_date_columns: bool = True

    @model_validator(mode="before")
    @classmethod
    def validate_offset_and_read_dates(cls, data):
        """
        Ensure an offset is only combined with the "{{processing_date}}" placeholder.

        Args:
            data: The raw input passed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If 'processing_date_offset_days' is provided while 'read_dates'
                is anything other than "{{processing_date}}".
        """
        # "before" validators may receive arbitrary input; leave anything that
        # is not a plain mapping for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data

        if data.get("processing_date_offset_days") is None:
            return data

        # When 'read_dates' is omitted the field default applies, and that
        # default is the placeholder, so an omitted value must be accepted.
        read_dates = data.get("read_dates", "{{processing_date}}")
        if read_dates != "{{processing_date}}":
            raise ValueError(
                "If 'processing_date_offset_days' is provided, 'read_dates' must be '{{processing_date}}'"
            )

        return data

    @field_validator("date_partitions")
    @classmethod
    def ensure_all_date_partitions(
        cls, v: dict[DatePartition, str]
    ) -> dict[DatePartition, str]:
        """
        Ensures that all date partitions are present in the input dictionary.

        Args:
            v (dict[DatePartition, str]): The input dictionary.

        Returns:
            dict[DatePartition, str]: The input dictionary.

        Raises:
            ValueError: If any date partition is missing.
        """
        date_partitions = [DatePartition.YEAR, DatePartition.MONTH, DatePartition.DAY]

        for date_partition in date_partitions:
            if date_partition not in v:
                raise ValueError(
                    f"Missing date partition: '{date_partition}'. The dictionary must contain all date partitions: {date_partitions}"
                )

        return v


class InputDataFrameParams(DataFrameBaseParams):
    """
    Parameters for input data frames.

    Attributes:
        dataframe_specific_paths (Optional[Union[List[str], str]]): One path or a list of paths to read.
            Leading and trailing slashes are stripped. Mutually exclusive with
            'dataframe_base_path' and with 'read_date_params'.
        pyspark_schema_struct (Optional[Dict[str, Any]]): The schema of the input data frame in PySpark StructType format.
        s3_schema_path_params (Optional[S3ReadSchemaParams]): The parameters for reading the schema from an S3 path.
        read_date_params (Optional[ReadDateParams]): The parameters for reading the date from the input data frame.

    Methods:
        check_schema_mode(cls, data): Check the schema mode and validate the input data.
        xor_specific_path_and_dataframe_base_path(cls, data): Require exactly one of the two path options.
        xor_specific_path_and_read_date_params(cls, data): Forbid combining specific paths with read-date parameters.
        strip_dataframe_specific_paths(cls, v): Strip surrounding slashes from the specific path(s).

    Raises:
        ValueError: If 'pyspark_schema_struct' and 's3_schema_path_params' are passed together,
            if 'dataframe_specific_paths' and 'dataframe_base_path' are both passed or both
            missing, or if 'dataframe_specific_paths' and 'read_date_params' are passed together.
    """

    dataframe_specific_paths: Optional[Union[List[str], str]] = None
    pyspark_schema_struct: Optional[Dict[str, Any]] = None
    s3_schema_path_params: Optional[S3ReadSchemaParams] = None
    read_date_params: Optional[ReadDateParams] = None

    @model_validator(mode="before")
    @classmethod
    def check_schema_mode(cls, data):
        """
        Check the schema mode and validate the input data.

        Args:
            data: The raw input passed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If 'pyspark_schema_struct' and 's3_schema_path_params' are passed together.
        """
        # "before" validators may receive arbitrary input; leave anything that
        # is not a plain mapping for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data

        pyspark_schema_struct = data.get("pyspark_schema_struct")
        s3_schema_path_params = data.get("s3_schema_path_params")

        if pyspark_schema_struct is not None and s3_schema_path_params is not None:
            raise ValueError(
                "'pyspark_schema_struct' and 's3_schema_path_params' cannot be passed together"
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def xor_specific_path_and_dataframe_base_path(cls, data):
        """
        Require exactly one of 'dataframe_specific_paths' and 'dataframe_base_path'.

        Args:
            data: The raw input passed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If both options are passed, or if neither is.
        """
        if not isinstance(data, dict):
            return data

        dataframe_specific_paths = data.get("dataframe_specific_paths")
        dataframe_base_path = data.get("dataframe_base_path")

        if dataframe_specific_paths is not None and dataframe_base_path is not None:
            raise ValueError(
                "'dataframe_specific_paths' and 'dataframe_base_path' cannot be passed together"
            )

        if dataframe_specific_paths is None and dataframe_base_path is None:
            raise ValueError(
                "Either 'dataframe_specific_paths' or 'dataframe_base_path' should be passed"
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def xor_specific_path_and_read_date_params(cls, data):
        """
        Forbid combining 'dataframe_specific_paths' with 'read_date_params'.

        Args:
            data: The raw input passed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If both options are passed.
        """
        if not isinstance(data, dict):
            return data

        dataframe_specific_paths = data.get("dataframe_specific_paths")
        read_date_params = data.get("read_date_params")

        if dataframe_specific_paths is not None and read_date_params is not None:
            raise ValueError(
                "'dataframe_specific_paths' and 'read_date_params' cannot be passed together"
            )

        return data

    @field_validator("dataframe_specific_paths")
    @classmethod
    def strip_dataframe_specific_paths(
        cls, v: Optional[Union[List[str], str]]
    ) -> Optional[Union[List[str], str]]:
        """
        Strip leading and trailing slashes from the specific path(s).

        Args:
            v (Optional[Union[List[str], str]]): A single path, a list of paths, or None.

        Returns:
            Optional[Union[List[str], str]]: The stripped path(s) in the same shape as the
            input, or None when no path was provided.
        """
        if v is None:
            return None

        if isinstance(v, list):
            return [path.strip("/") for path in v]

        return v.strip("/")


class SingleWriteDateParams(BaseModelJsonDumps):
    """
    Parameters describing one single date to write under.

    Attributes:
        single_write_date (Union[datetime, Literal["{{processing_date}}"]]):
            Either a datetime object or the placeholder string "{{processing_date}}",
            which is also the default.
        single_write_date_partitions (dict[DatePartition, str]):
            Maps every DatePartition to a string. All three partitions must be present.
            Defaults to {YEAR: "year", MONTH: "month", DAY: "day"}.
    """

    single_write_date: Union[datetime, Literal["{{processing_date}}"]] = "{{processing_date}}"
    single_write_date_partitions: dict[DatePartition, str] = {
        DatePartition.YEAR: "year",
        DatePartition.MONTH: "month",
        DatePartition.DAY: "day",
    }

    @field_validator("single_write_date_partitions")
    def ensure_all_date_partitions(
        cls, v: dict[DatePartition, str]
    ) -> dict[DatePartition, str]:
        """
        Verify that the mapping has an entry for year, month and day.

        Args:
            v (dict[DatePartition, str]): The mapping to check.

        Returns:
            dict[DatePartition, str]: The same mapping, untouched.

        Raises:
            ValueError: Naming the first partition (in year, month, day order) that is absent.
        """
        required = [DatePartition.YEAR, DatePartition.MONTH, DatePartition.DAY]

        # First required partition that has no entry, or None if all are there.
        absent = next((part for part in required if part not in v), None)
        if absent is not None:
            raise ValueError(
                f"Missing date partition: '{absent}'. The dictionary must contain all date partitions: {required}"
            )

        return v


class OutputDataFrameParams(DataFrameBaseParams):
    """
    Parameters for writing output dataframes.

    Exactly one of 'dataframe_specific_path' and 'dataframe_base_path' must be provided.

    Attributes:
        dataframe_specific_path (Optional[str]): A specific path to write to. Defaults to None.
        write_schema_on_s3 (bool): Whether to write the schema on S3. Defaults to False.
        overwrite (bool): Whether to overwrite existing data. Defaults to False.
        dataframe_file_type (FileType): The output file type. Defaults to FileType.PARQUET.
        single_write_date (Optional[SingleWriteDateParams]): Parameters for single write date.
        partition_by (Optional[List[str]]): List of columns to partition the data by.

    Raises:
        ValueError: If 'dataframe_specific_path' and 'dataframe_base_path' are both passed
            or both missing.
    """

    dataframe_specific_path: Optional[str] = None
    write_schema_on_s3: bool = False
    overwrite: bool = False
    dataframe_file_type: FileType = FileType.PARQUET
    single_write_date: Optional[SingleWriteDateParams] = None
    partition_by: Optional[List[str]] = None

    @model_validator(mode="before")
    @classmethod
    def xor_specific_path_and_dataframe_base_path(cls, data):
        """
        Require exactly one of 'dataframe_specific_path' and 'dataframe_base_path'.

        Args:
            data: The raw input passed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If both options are passed, or if neither is.
        """
        # "before" validators may receive arbitrary input; leave anything that
        # is not a plain mapping for pydantic's own validation to handle
        # instead of failing on ``.get``.
        if not isinstance(data, dict):
            return data

        dataframe_specific_path = data.get("dataframe_specific_path")
        dataframe_base_path = data.get("dataframe_base_path")

        if dataframe_specific_path is not None and dataframe_base_path is not None:
            raise ValueError(
                "'dataframe_specific_path' and 'dataframe_base_path' cannot be passed together"
            )

        if dataframe_specific_path is None and dataframe_base_path is None:
            raise ValueError(
                "Either 'dataframe_specific_path' or 'dataframe_base_path' should be passed"
            )

        return data


class FnIndirect(BaseModelJsonDumps):
    """
    Reference to a function given by name and path instead of as a callable.

    ``TransformParams`` accepts this as an alternative to 'transform_function',
    but currently raises NotImplementedError whenever it is provided.

    Attributes:
        fn_name (str): NOTE(review): presumably the name of the function to load — confirm.
        fn_path (str): NOTE(review): presumably where the function lives; whether this is a
            file path or a dotted module path is not visible in this module — confirm.
    """

    fn_name: str
    fn_path: str


class TransformParams(BaseModelJsonDumps):
    """
    Represents the parameters for a data transformation.

    Attributes:
        transform_label (str): The label for the transformation.
        transform_function (Callable): The transformation function. Must have a (pyspark) DataFrame type hint as first parameter and return type hint, or a dict of [str, DataFrame] as first parameter hint and return type hint.
        fn_kwargs (Optional[dict]): Additional keyword arguments for the transformation function.
        apply_only_on (Optional[List[str]]): A list of target dataframes to apply the transformation on. Only accepted for 'single' functions.
        fn_indirect (Optional[FnIndirect]): Alternative to 'transform_function'. Not implemented yet: providing it raises NotImplementedError.
        fn_kind (Optional[FnKind]): Derived from the type hints of 'transform_function' during validation; any value passed in is overwritten.
    """

    transform_label: str
    transform_function: Callable
    fn_kwargs: Optional[dict] = None
    apply_only_on: Optional[List[str]] = None
    fn_indirect: Optional[FnIndirect] = None
    fn_kind: Optional[FnKind] = None

    @model_validator(mode="before")
    @classmethod
    def validate_transform_function_indirect_and_apply_only_on(cls, data):
        """
        Validate the function-related inputs and derive 'fn_kind'.

        Args:
            data: The raw input passed to the model.

        Returns:
            A copy of the input with 'fn_kind' set. Non-dict input, and input whose
            'transform_function' is not callable, is returned unchanged so that
            pydantic's own field validation reports the problem.

        Raises:
            ValueError: If 'fn_indirect' and 'transform_function' are both passed or both
                missing, if the type hints of 'transform_function' are not the supported
                DataFrame / dict of [str, DataFrame] pair, or if 'apply_only_on' is used
                with a non-'single' function.
            NotImplementedError: If 'fn_indirect' is provided.
        """

        def is_dataframe_class(candidate) -> bool:
            # ``issubclass`` raises TypeError for anything that is not a class
            # (None, string annotations, generic aliases, ...); none of those
            # is a DataFrame hint.
            try:
                return issubclass(candidate, DataFrame)
            except TypeError:
                return False

        def get_annotation_kind(annotation):
            if get_origin(annotation) is dict:
                args = get_args(annotation)
                # A bare ``Dict`` has no type arguments at all.
                if len(args) == 2 and args[0] is str and is_dataframe_class(args[1]):
                    return FnKind.BATCH
                return None

            if is_dataframe_class(annotation):
                return FnKind.SINGLE

            return None

        # "before" validators may receive arbitrary input; leave anything that
        # is not a plain mapping for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data

        fn_indirect = data.get("fn_indirect")
        transform_function = data.get("transform_function")
        apply_only_on = data.get("apply_only_on")

        if fn_indirect is not None and transform_function is not None:
            raise ValueError(
                "'fn_indirect' and 'transform_function' cannot be passed together"
            )

        if fn_indirect is None and transform_function is None:
            raise ValueError(
                "Either 'fn_indirect' or 'transform_function' should be passed"
            )

        if fn_indirect is not None:
            raise NotImplementedError("fn_indirect is not implemented yet")

        if not callable(transform_function):
            # ``inspect.signature`` would raise a bare TypeError here; the
            # field's ``Callable`` validation produces a proper error instead.
            return data

        signature = inspect.signature(transform_function)
        parameters = signature.parameters
        first_param_annotation = (
            next(iter(parameters.values())).annotation if parameters else None
        )

        first_param_kind = get_annotation_kind(first_param_annotation)
        return_kind = get_annotation_kind(signature.return_annotation)

        if (
            first_param_kind is None
            or return_kind is None
            or first_param_kind != return_kind
        ):
            raise ValueError(
                "Function must have a DataFrame type hint as first parameter and return type hint, or a dict of [str, DataFrame] as first parameter hint and return type hint"
            )

        if first_param_kind != FnKind.SINGLE and apply_only_on is not None:
            raise ValueError(
                "'apply_only_on' is only accepted when 'transform_function' is of type 'single', that is, when the first parameter is a DataFrame type hint and the return type hint is also a DataFrame type hint."
            )

        # If the function has a DataFrame as first parameter, then it is a single function. If it has a dict of [str, DataFrame] as first parameter, then it is a batch function
        # TODO: This should be a private attribute
        # Return a new mapping so the caller's dictionary is not modified.
        return {**data, "fn_kind": first_param_kind}


# Type aliases for string-keyed dictionaries of dataframes and of the
# parameter models declared in this module.
DataFrameDict = Dict[str, DataFrame]
InputDataFrameParamsDict = Dict[str, InputDataFrameParams]
OutputDataFrameParamsDict = Dict[str, OutputDataFrameParams]
TransformParamsDict = Dict[str, TransformParams]


class StepParams(BaseModelJsonDumps):
    """
    Represents the parameters for a step in a data pipeline.

    Only the input dataframes are mandatory; transformations and outputs are optional.
    Each attribute is a dictionary keyed by a string.

    Attributes:
        input_dataframes_params (InputDataFrameParamsDict): The parameters for input dataframes.
        transform_params (Optional[TransformParamsDict]): The parameters for the transformation step. Defaults to None.
        output_dataframes_params (Optional[OutputDataFrameParamsDict]): The parameters for output dataframes. Defaults to None.
    """

    input_dataframes_params: InputDataFrameParamsDict
    transform_params: Optional[TransformParamsDict] = None
    output_dataframes_params: Optional[OutputDataFrameParamsDict] = None


# Type alias for a string-keyed dictionary of step parameters.
StepParamsDict = Dict[str, StepParams]
