from abc import ABC
from typing import Callable, Optional, Union, Tuple, Any, Iterable

from sakkara.model.base import ModelComponent
from sakkara.model.composable.base import T
from sakkara.model.fixed.base import FixedValueComponent
from sakkara.model.composable.hierarchical.base import HierarchicalComponent


class DistributionComponent(HierarchicalComponent[T], ABC):
    r"""
    Class for components whose variable is generated by a PyMC distribution.

    Every keyword argument in ``subcomponents`` that is not already a :class:`ModelComponent`
    is wrapped in a :class:`FixedValueComponent` before being handed to the base class.

    :param generator: PyMC callable for the distribution to use
    :type generator: Callable

    :param name: Name of the corresponding variable to register in PyMC, defaults to None
    :type name: str, optional

    :param group: Group for which the component is defined, defaults to None
    :type group: Union[str, Tuple[str, ...]], optional

    :param members: Subset of members of the column the component is defined for, defaults to None
    :type members: Iterable[Any], optional

    :param \**subcomponents: Underlying components/objects passed as parameters to the PyMC
        distribution; each key should correspond to a keyword argument of `generator`

    **Example**

    .. highlight:: python
    .. code-block:: python

        import pymc as pm
        from sakkara.model import DistributionComponent as DC
        sigma_comp = DC(pm.HalfNormal)
        n = DC(pm.Normal, sigma=sigma_comp)

    """
    def __init__(self, generator: Callable, name: Optional[str] = None,
                 group: Optional[Union[str, Tuple[str, ...]]] = None,
                 members: Optional[Iterable[Any]] = None, **subcomponents: Any):
        # Values that are not already model components (e.g. plain constants) are wrapped
        # in FixedValueComponent so that every subcomponent passed on is a ModelComponent.
        wrapped_subcomponents = {
            key: value if isinstance(value, ModelComponent) else FixedValueComponent(value)
            for key, value in subcomponents.items()
        }
        super().__init__(name, group, members, subcomponents=wrapped_subcomponents)
        self.generator = generator

    def build_variable(self) -> None:
        """
        Create the PyMC variable by calling ``generator`` and store it in ``self.variable``.

        The generator receives this component's ``name`` positionally, the result of
        :meth:`get_built_components` unpacked as keyword arguments, ``shape`` taken from
        ``self.node.get_members().shape``, and ``dims`` from :meth:`dims`.
        """
        # NOTE(review): a subcomponent keyed "shape" or "dims" would collide with the explicit
        # keywords below and raise TypeError — confirm callers never pass those names.
        self.variable = self.generator(
            self.name,
            **self.get_built_components(),
            shape=self.node.get_members().shape,
            dims=self.dims(),
        )
