from pathlib import Path
from typing import Dict, Optional, Union

import anndata as ad
import lamindb as ln
from lnschema_core.types import FieldAttr
from pandas.core.api import DataFrame as DataFrame

from ._register import register_artifact, register_labels
from ._validate import validate_anndata
from ._validator import Validator


class AnnDataValidator(Validator):
    """Lamin AnnData validator.

    Validates the variables index (``adata.var_names``) against a single
    registry field and the ``adata.obs`` columns listed in ``obs_fields``
    against their registry fields, and offers helpers to register variable
    records and the validated AnnData as an artifact.

    Args:
        adata: The AnnData object to validate, or a path (``str`` or
            ``Path``) to an ``.h5ad`` file, which is loaded with
            ``anndata.read_h5ad``.
        var_field: The registry field to validate the variables index against.
        obs_fields: A dictionary mapping obs column name to registry field.
            For example:
            {"cell_type_ontology_id": bt.CellType.ontology_id, "donor_id": ln.ULabel.name}
        using: The reference instance containing registries to validate against.
        verbosity: Verbosity level, forwarded to the base ``Validator`` and to
            ``validate_anndata``.
    """

    def __init__(
        self,
        adata: Union[ad.AnnData, str, Path],
        var_field: FieldAttr,
        obs_fields: Dict[str, FieldAttr],
        using: str = "default",
        verbosity: str = "hint",
    ) -> None:
        # Accept either an in-memory AnnData or a path to an .h5ad file.
        self._adata = ad.read_h5ad(adata) if isinstance(adata, (str, Path)) else adata
        # The base Validator operates on the obs DataFrame and its field mapping.
        super().__init__(
            df=self._adata.obs, fields=obs_fields, using=using, verbosity=verbosity
        )
        self._obs_fields = obs_fields
        self._var_field = var_field

    @property
    def var_field(self) -> FieldAttr:
        """Return the registry field to validate variables index against."""
        return self._var_field

    @property
    def obs_fields(self) -> Dict[str, FieldAttr]:
        """Return the mapping of obs column name to registry field."""
        return self._obs_fields

    def register_variables(self, organism: Optional[str] = None, **kwargs) -> None:
        """Register variable records.

        Registers the values of ``adata.var_names`` in the registry of
        ``var_field``, using the reference instance given at construction.

        Args:
            organism: The name of the organism. If omitted, any organism
                stored by a previous call is reused.
            **kwargs: Additional keyword arguments to pass to the registry model.
        """
        # Fall back to a previously stored organism so callers need not repeat it.
        self._add_kwargs(
            organism=organism or self._kwargs.get("organism"),
            **kwargs,
        )
        register_labels(
            values=self._adata.var_names,
            field=self.var_field,
            using=self._using,
            kwargs=self._kwargs,
        )

    def validate(
        self,
        organism: Optional[str] = None,
        **kwargs,
    ) -> bool:
        """Validate variables and categorical observations.

        The result is stored on the instance and gates ``register_artifact``.

        Args:
            organism: name of the organism; if omitted, any organism stored
                by a previous call is reused
            **kwargs: object level metadata

        Returns:
            whether the AnnData object is validated
        """
        # Fall back to a previously stored organism so callers need not repeat it.
        self._add_kwargs(
            organism=organism or self._kwargs.get("organism"),
            **kwargs,
        )
        # NOTE(review): unlike register_variables, `self._using` is not passed
        # here — confirm whether validate_anndata should check against the
        # reference instance as well.
        self._validated = validate_anndata(
            self._adata,
            var_field=self.var_field,
            obs_fields=self.obs_fields,
            verbosity=self._verbosity,
            **self._kwargs,
        )

        return self._validated

    def register_artifact(
        self,
        description: str,
        **kwargs,
    ) -> ln.Artifact:
        """Register the validated AnnData and metadata.

        Args:
            description: description of the AnnData object
            **kwargs: object level metadata

        Returns:
            a registered artifact record

        Raises:
            ValueError: if ``validate()`` has not marked the data as validated.
        """
        self._add_kwargs(**kwargs)
        # Assumes the base Validator initializes `_validated` to a falsy value
        # before `validate()` runs — TODO confirm.
        if not self._validated:
            raise ValueError("please run `validate()` first!")

        # `register_artifact` here is the module-level helper imported from
        # `._register`, not this method (method names are not in local scope).
        self._artifact = register_artifact(
            self._adata,
            description=description,
            var_field=self.var_field,
            fields=self.obs_fields,
            **self._kwargs,
        )

        return self._artifact
