"""SQLAlchemy definition of the pvsite database schema."""

from __future__ import annotations

# The `annotations` future import above lets type hints reference classes that have yet to be defined
import uuid
from datetime import datetime
from typing import List

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.schema import UniqueConstraint

# Shared declarative base; the model classes defined below all subclass it.
Base = declarative_base()


def _utc_now() -> datetime:
    """Return the current UTC time as a naive datetime."""
    return datetime.utcnow()


class CreatedMixin:
    """Mixin that gives a model a `created_utc` timestamp column."""

    # The default is a callable, so it is evaluated each time a default value
    # is needed rather than once when the module is imported.
    created_utc = sa.Column(sa.DateTime, default=_utc_now)


class SiteSQL(Base, CreatedMixin):
    """Class representing the sites table.

    Each site row specifies a single panel or cluster of panels
    found on a residential house or commercial building. Their
    data is provided by a client.

    The pair (`client_site_id`, `client_uuid`) is unique, i.e. a client
    cannot register the same client-side site ID twice.

    *Approximate size:*
    4 clients * ~1000 sites each = ~4000 rows
    """

    __tablename__ = "sites"

    site_uuid = sa.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    client_uuid = sa.Column(
        UUID(as_uuid=True),
        sa.ForeignKey("clients.client_uuid"),
        nullable=False,
        comment="The internal ID of the client providing the site data",
    )
    client_site_id = sa.Column(
        sa.Integer, index=True, comment="The ID of the site as given by the providing client"
    )
    client_site_name = sa.Column(
        sa.String(255), index=True, comment="The name of the site as given by the providing client"
    )

    region = sa.Column(sa.String(255), comment="The region in the UK in which the site is located")
    dno = sa.Column(sa.String(255), comment="The Distribution Node Operator that owns the site")
    gsp = sa.Column(sa.String(255), comment="The Grid Supply Point in which the site is located")

    # For metadata `NULL` means "we don't know".
    orientation = sa.Column(
        sa.Float, comment="The rotation of the panel in degrees. 180° points south"
    )
    tilt = sa.Column(
        sa.Float, comment="The tilt of the panel in degrees. 90° indicates the panel is vertical"
    )
    latitude = sa.Column(sa.Float)
    longitude = sa.Column(sa.Float)
    capacity_kw = sa.Column(
        sa.Float, comment="The physical limit on the production capacity of the site"
    )

    # NOTE(review): SQLAlchemy documents `autoincrement` as only taking effect on
    # integer primary-key columns. This column is not part of the primary key and
    # has no Python-side default, so presumably the incrementing value is supplied
    # by the database itself (e.g. a sequence created in a migration) - confirm.
    ml_id = sa.Column(
        sa.Integer,
        autoincrement=True,
        nullable=False,
        unique=True,
        comment="Auto-incrementing integer ID of the site for use in ML training",
    )

    # Both columns are referenced by name so the constraint reads uniformly.
    __table_args__ = (UniqueConstraint("client_site_id", "client_uuid", name="idx_client"),)

    # Each relationship names its counterpart on the related class so that
    # in-memory changes made from either side are mirrored on the other.
    forecasts: List["ForecastSQL"] = relationship("ForecastSQL", back_populates="site")
    generation: List["GenerationSQL"] = relationship("GenerationSQL", back_populates="site")
    client: ClientSQL = relationship("ClientSQL", back_populates="sites")


class GenerationSQL(Base, CreatedMixin):
    """ORM model for the generation table.

    A row records the power generated by one site over the time
    interval running from `start_utc` to `end_utc`. The combination
    of site, start and end is unique.

    *Approximate size:*
    One row per site every 5 minutes (288 per day) * 4000 sites = ~1,152,000 rows per day
    """

    __tablename__ = "generation"

    generation_uuid = sa.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

    # NOTE(review): "geenration" below is a typo, but the string is part of the
    # column metadata, so it is deliberately left unchanged here.
    site_uuid = sa.Column(
        UUID(as_uuid=True),
        sa.ForeignKey("sites.site_uuid"),
        index=True,
        nullable=False,
        comment="The site for which this geenration yield belongs to",
    )
    generation_power_kw = sa.Column(
        sa.Float,
        comment="The actual generated power in kW at this site for this datetime interval",
        nullable=False,
    )

    # Interval covered by the reading; only the start is indexed on its own.
    start_utc = sa.Column(
        sa.DateTime,
        index=True,
        nullable=False,
        comment="The start of the time interval over which this generated power value applies",
    )
    end_utc = sa.Column(
        sa.DateTime,
        comment="The end of the time interval over which this generated power value applies",
        nullable=False,
    )

    # At most one reading per site for any given (start, end) interval.
    __table_args__ = (
        sa.UniqueConstraint("site_uuid", "start_utc", "end_utc", name="uniq_cons_site_start_end"),
    )

    # Many generation rows belong to one site.
    site: SiteSQL = relationship("SiteSQL", back_populates="generation")


class ForecastSQL(Base, CreatedMixin):
    """Class representing the forecasts table.

    Each forecast row refers to a sequence of predicted solar generation values
    over a set of target times for one site.

    *Approximate size:*
    One forecast per site every 5 minutes (288 per day); assuming ~4000 sites
    this is ~1,152,000 rows per day
    """

    __tablename__ = "forecasts"

    forecast_uuid = sa.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    site_uuid = sa.Column(
        UUID(as_uuid=True),
        sa.ForeignKey("sites.site_uuid"),
        nullable=False,
        index=True,
        comment="The site for which the forecast sequence was generated",
    )

    # The timestamp at which we are making the forecast. This is often referred as "now" in the
    # modelling code.
    # Note that this could be very different from the `created_utc` time, for instance if we
    # run the model for a given "now" timestamp in the past.
    timestamp_utc = sa.Column(
        sa.DateTime,
        nullable=False,
        index=True,
        comment="The creation time of the forecast sequence",
    )

    forecast_version = sa.Column(
        sa.String(32),
        nullable=False,
        comment="The semantic version of the model used to generate the forecast",
    )

    # one (forecasts) to many (forecast_values)
    # `back_populates` is declared here as well as on `ForecastValueSQL.forecast`
    # so that in-memory changes made from either side are mirrored on the other.
    forecast_values: List["ForecastValueSQL"] = relationship(
        "ForecastValueSQL", back_populates="forecast"
    )
    # many (forecasts) to one (site)
    site: SiteSQL = relationship("SiteSQL", back_populates="forecasts")


class ForecastValueSQL(Base, CreatedMixin):
    """ORM model for the forecast_values table.

    A row holds the predicted power output of a site for a single
    target time interval and belongs to exactly one forecast. Many
    predictions are made for each site at each target interval.

    *Approximate size:*
    One forecast value every 5 minutes per site per forecast.
    Each forecast's prediction sequence covers 24 hours of target
    intervals (288 values); assuming ~1,152,000 forecasts per day
    this is ~332,000,000 rows per day
    """

    __tablename__ = "forecast_values"

    forecast_value_uuid = sa.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

    # Target interval of the prediction; only the start is indexed.
    start_utc = sa.Column(
        sa.DateTime,
        index=True,
        nullable=False,
        comment="The start of the time interval over which this predicted power value applies",
    )
    end_utc = sa.Column(
        sa.DateTime,
        comment="The end of the time interval over which this predicted power value applies",
        nullable=False,
    )
    forecast_power_kw = sa.Column(
        sa.Float,
        comment="The predicted power generation of this site for the given time interval",
        nullable=False,
    )

    # This is the difference between `start_utc` and the `forecast`'s `timestamp_utc`, in minutes.
    # Keeping it in its own indexed column makes querying forecasts for a given horizon efficient.
    # TODO Set to nullable=False
    horizon_minutes = sa.Column(
        sa.Integer,
        index=True,
        nullable=True,
        comment="The time difference between the creation time of the forecast value "
        "and the start of the time interval it applies for",
    )

    # NOTE(review): "forcast" below is a typo, but the string is part of the
    # column metadata, so it is deliberately left unchanged here.
    forecast_uuid = sa.Column(
        UUID(as_uuid=True),
        sa.ForeignKey("forecasts.forecast_uuid"),
        index=True,
        nullable=False,
        comment="The forecast sequence this forcast value belongs to",
    )

    # Many forecast values belong to one forecast.
    forecast: ForecastSQL = relationship("ForecastSQL", back_populates="forecast_values")


class ClientSQL(Base, CreatedMixin):
    """Class representing the clients table.

    Each client row defines a provider of site data

    *Approximate size:*
    One row per client = ~4 rows
    """

    __tablename__ = "clients"

    client_uuid = sa.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    client_name = sa.Column(sa.String(255), nullable=False, comment="The name of the client")

    # one (client) to many (sites)
    # `back_populates` is declared here as well as on `SiteSQL.client` so that
    # in-memory changes made from either side are mirrored on the other.
    sites: List[SiteSQL] = relationship("SiteSQL", back_populates="client")


class StatusSQL(Base, CreatedMixin):
    """ORM model for the status table.

    A row is a message reporting on the status of the services
    within the nowcasting domain.

    *Approximate size:*
    ~1 row per day
    """

    __tablename__ = "status"

    status_uuid = sa.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

    # Both text fields are nullable and limited to 255 characters.
    status = sa.Column(sa.String(255))
    message = sa.Column(sa.String(255))
