"""Personal Identifiable Information Anonymizer."""

import hashlib
import importlib
import importlib.util
import inspect
import warnings
from collections.abc import Iterable
from copy import deepcopy
from operator import attrgetter

import faker
import numpy as np
import pandas as pd

from rdt.errors import TransformerInputError, TransformerProcessingError
from rdt.transformers.base import BaseTransformer
from rdt.transformers.categorical import LabelEncoder


class AnonymizedFaker(BaseTransformer):
    """Personally Identifiable Information Anonymizer using Faker.

    This transformer will drop a column and regenerate it with the previously specified
    ``Faker`` provider and ``function``.

    Args:
        provider_name (str):
            The name of the provider in ``Faker``. If ``None`` the ``BaseProvider`` is used.
            Defaults to ``None``.
        function_name (str):
            The name of the function to use within the ``faker.provider``. Defaults to
            ``lexify``. Required when ``provider_name`` is not ``BaseProvider``.
        function_kwargs (dict):
            Keyword args to pass into the ``function_name`` when being called.
        locales (list):
            List of localized providers to use instead of the global provider.
        cardinality_rule (str):
            If ``'unique'`` enforce that every created value is unique.
            If ``'match'`` match the cardinality of the data seen during fit.
            If ``None`` do not consider cardinality.
            Defaults to ``None``.
        enforce_uniqueness (bool):
            **DEPRECATED** Whether or not to ensure that the new anonymized data is all unique.
            If it isn't possible to create the requested number of rows, then an error will be
            raised.
            Defaults to ``False``.
        missing_value_generation (str or None):
            The way missing values are being handled. There are two strategies:

                * ``random``: Randomly generates missing values based on the percentage of
                  missing values.
                * ``None``: Don't learn anything during fit. Then during reverse transform,
                  don't create any missing values.

            Defaults to ``'random'``.

    Raises:
        TransformerInputError:
            If a non-default provider is given without a ``function_name``, or if
            ``missing_value_generation`` is not ``'random'`` or ``None``.
        TransformerProcessingError:
            If the provider does not expose ``function_name``.
    """

    # pylint: disable=too-many-instance-attributes

    IS_GENERATOR = True
    INPUT_SDTYPE = 'pii'

    @staticmethod
    def check_provider_function(provider_name, function_name):
        """Check that the provider and the function exist.

        Attempt to get the provider from ``faker.providers`` and then get the ``function``
        from the provider object (``BaseProvider`` is looked up directly, other providers
        through their ``Provider`` class).

        Raises:
            TransformerProcessingError:
                If the provider or the function is not found.
        """
        try:
            module = attrgetter(provider_name)(faker.providers)
            if provider_name.lower() == 'baseprovider':
                getattr(module, function_name)

            else:
                provider = getattr(module, 'Provider')
                getattr(provider, function_name)

        except AttributeError as exception:
            raise TransformerProcessingError(
                f"The '{provider_name}' module does not contain a function named "
                f"'{function_name}'.\nRefer to the Faker docs to find the correct function: "
                'https://faker.readthedocs.io/en/master/providers.html'
            ) from exception

    def _check_locales(self):
        """Warn about locales that have no localized module for the configured provider.

        ``en_US`` is never reported, since it is the fallback used in place of missing locales.
        """
        locales = self.locales if isinstance(self.locales, list) else [self.locales]
        missed_locales = []
        for locale in locales:
            provider_name = self.provider_name
            # Allow provider names that already include the locale, e.g. 'address.fr_FR'.
            if self.provider_name.endswith(f'.{locale}'):
                provider_name = self.provider_name.replace(f'.{locale}', '')

            spec = importlib.util.find_spec(f'faker.providers.{provider_name}.{locale}')
            if spec is None and locale != 'en_US':
                missed_locales.append(locale)

        if missed_locales:
            warnings.warn(
                f"Locales {missed_locales} do not support provider '{self.provider_name}' "
                f"and function '{self.function_name}'.\nIn place of these locales, 'en_US' will "
                'be used instead. Please refer to the localized provider docs for more '
                'information: https://faker.readthedocs.io/en/master/locales.html'
            )

    def __init__(
        self,
        provider_name=None,
        function_name=None,
        function_kwargs=None,
        locales=None,
        cardinality_rule=None,
        enforce_uniqueness=False,
        missing_value_generation='random',
    ):
        super().__init__()
        self._data_cardinality = None
        self.data_length = None
        self.enforce_uniqueness = enforce_uniqueness
        self.cardinality_rule = cardinality_rule.lower() if cardinality_rule else None
        if enforce_uniqueness:
            warnings.warn(
                "The 'enforce_uniqueness' parameter is no longer supported. "
                "Please use the 'cardinality_rule' parameter instead.",
                FutureWarning,
            )
            # The deprecated flag only maps onto 'unique' when no explicit rule was given.
            if not self.cardinality_rule:
                self.cardinality_rule = 'unique'

        self.provider_name = provider_name if provider_name else 'BaseProvider'
        if self.provider_name != 'BaseProvider' and function_name is None:
            raise TransformerInputError(
                f"Please specify the function name to use from the '{self.provider_name}' provider."
            )

        self.function_name = function_name if function_name else 'lexify'
        # Copy so later mutation of the caller's dict does not affect this transformer.
        self.function_kwargs = deepcopy(function_kwargs) if function_kwargs else {}
        self.check_provider_function(self.provider_name, self.function_name)
        self.output_properties = {None: {'next_transformer': None}}

        self._faker_random_seed = None
        self.locales = locales
        self.faker = faker.Faker(self.locales)
        if self.provider_name != 'BaseProvider' and self.locales:
            self._check_locales()

        if missing_value_generation not in ['random', None]:
            raise TransformerInputError(
                f"Missing value generation '{missing_value_generation}' is not supported "
                "for AnonymizedFaker. Please use either 'random' or None."
            )

        self.missing_value_generation = missing_value_generation
        self._nan_frequency = 0.0

    @classmethod
    def get_supported_sdtypes(cls):
        """Return the supported sdtypes by the transformer.

        Collects the sdtypes of every ``BaseTransformer`` subclass that is not this class
        (or a subclass of it), plus ``INPUT_SDTYPE``, minus the basic non-PII sdtypes.

        Returns:
            list:
                Accepted input sdtypes of the transformer.
        """
        unsupported_sdtypes = {
            'numerical',
            'datetime',
            'categorical',
            'boolean',
            None,
        }
        all_sdtypes = {cls.INPUT_SDTYPE}
        for transformer in BaseTransformer.get_subclasses():
            if not issubclass(transformer, cls):
                all_sdtypes.update(transformer.get_supported_sdtypes())

        supported_sdtypes = all_sdtypes - unsupported_sdtypes
        return list(supported_sdtypes)

    def reset_randomization(self):
        """Create a new ``Faker`` instance seeded with the seed learned during fit."""
        super().reset_randomization()
        self.faker = faker.Faker(self.locales)
        self.faker.seed_instance(self._faker_random_seed)

    def _function(self):
        """Return the result of calling the ``faker`` function.

        Uses ``faker.unique`` when values must be unique. Non-string iterable results are
        joined into a single comma-separated string.
        """
        try:
            if self.cardinality_rule in {'unique', 'match'}:
                faker_attr = self.faker.unique
            else:
                faker_attr = self.faker
        except AttributeError:
            # Instances created before ``cardinality_rule`` existed only have the
            # deprecated ``enforce_uniqueness`` flag.
            faker_attr = self.faker.unique if self.enforce_uniqueness else self.faker

        result = getattr(faker_attr, self.function_name)(**self.function_kwargs)

        if isinstance(result, Iterable) and not isinstance(result, str):
            result = ', '.join(map(str, result))

        return result

    def _set_faker_seed(self, data):
        """Seed the ``Faker`` instance from the input column name and the first rows of data.

        The seed is a SHA-256 hash of the column name followed by the first five values of
        ``data``, reduced modulo ``2**32 - 1``.
        """
        # Column labels are not guaranteed to be strings (e.g. integer labels), so convert
        # explicitly before concatenating. For string labels this yields the same seed.
        hash_value = str(self.get_input_column())
        for value in data.head(5):
            hash_value += str(value)

        hash_value = int(hashlib.sha256(hash_value.encode('utf-8')).hexdigest(), 16)
        self._faker_random_seed = hash_value % ((2**32) - 1)
        self.faker.seed_instance(self._faker_random_seed)

    def _fit(self, data):
        """Fit the transformer to the data.

        Learns the row count, the fraction of missing values (when
        ``missing_value_generation`` is ``'random'``) and, for the ``'match'`` cardinality
        rule, the number of distinct non-null values.

        Args:
            data (pandas.Series):
                Data to fit to.
        """
        self._set_faker_seed(data)
        self.data_length = len(data)
        if self.missing_value_generation == 'random':
            # Avoid 0 / 0 on empty data: a NaN frequency would later break ``int()``.
            self._nan_frequency = data.isna().sum() / len(data) if len(data) else 0.0

        if self.cardinality_rule == 'match':
            # remove nans from data
            self._data_cardinality = len(data.dropna().unique())

    def _transform(self, _data):
        """Drop the input column by returning ``None``."""
        return None

    def _get_unique_categories(self, samples):
        """Generate ``samples`` values with the faker function as an object array."""
        return np.array([self._function() for _ in range(samples)], dtype=object)

    def _reverse_transform_cardinality_rule_match(self, sample_size):
        """Reverse transform the data when the cardinality rule is 'match'.

        Produces ``sample_size`` values containing (at most) ``_data_cardinality`` distinct
        generated values, padded with random repeats, plus the learned share of NaNs when
        ``missing_value_generation`` is ``'random'``. The result is shuffled.
        """
        reverse_transformed = np.array([], dtype=object)
        if self.missing_value_generation == 'random':
            num_nans = int(self._nan_frequency * sample_size)
            reverse_transformed = np.concatenate([
                reverse_transformed,
                np.full(num_nans, np.nan),
            ])
        else:
            num_nans = 0

        if sample_size <= num_nans:
            return reverse_transformed

        if sample_size < num_nans + self._data_cardinality:
            # Not enough room for every category: every non-NaN slot gets a distinct value.
            unique_categories = self._get_unique_categories(sample_size - num_nans)
            reverse_transformed = np.concatenate([
                reverse_transformed,
                unique_categories,
            ])
        else:
            # Emit each category once, then fill the remaining slots with random repeats.
            unique_categories = self._get_unique_categories(self._data_cardinality)
            num_copies = sample_size - self._data_cardinality - num_nans
            copies = np.random.choice(unique_categories, num_copies)
            reverse_transformed = np.concatenate([
                reverse_transformed,
                unique_categories,
                copies,
            ])

        np.random.shuffle(reverse_transformed)

        return reverse_transformed

    def _reverse_transform(self, data):
        """Generate new anonymized data using a ``faker.provider.function``.

        Args:
            data (pd.Series or numpy.ndarray):
                Data to transform. Only its length is used; when empty or ``None`` the
                number of rows seen during fit is generated.

        Returns:
            np.array

        Raises:
            TransformerProcessingError:
                If the Faker function cannot produce enough unique values.
        """
        if data is not None and len(data):
            sample_size = len(data)
        else:
            sample_size = self.data_length

        # ``cardinality_rule`` may be absent on instances pickled by older versions.
        match_cardinality = getattr(self, 'cardinality_rule', None) == 'match'
        try:
            if match_cardinality:
                # This path inserts its own missing values.
                reverse_transformed = self._reverse_transform_cardinality_rule_match(sample_size)
            else:
                reverse_transformed = np.array(
                    [self._function() for _ in range(sample_size)],
                    dtype=object,
                )

        except faker.exceptions.UniquenessException as exception:
            raise TransformerProcessingError(
                f'The Faker function you specified is not able to generate {sample_size} unique '
                'values. Please use a different Faker function for column '
                f"('{self.get_input_column()}')."
            ) from exception

        # Decide explicitly instead of probing for NaNs in the output: a Faker function that
        # itself returns None/NaN must not suppress the learned missing-value injection.
        if not match_cardinality and self.missing_value_generation == 'random':
            num_nans = int(self._nan_frequency * sample_size)
            nan_indices = np.random.choice(sample_size, num_nans, replace=False)
            reverse_transformed[nan_indices] = np.nan

        return reverse_transformed

    def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=None):
        """Manually set the parameters on the transformer to get it into a fitted state.

        Args:
            column_name (str):
                The name of the column to use for the transformer.
            nan_frequency (float):
                The fraction of values that should be replaced with nan values
                if self.missing_value_generation is 'random'.
            cardinality (int or None):
                The number of unique values to generate if cardinality rule is set to
                'match'.

        Raises:
            TransformerInputError:
                If the cardinality rule is ``'match'`` and no (truthy) cardinality is given.
        """
        self.reset_randomization()
        self.columns = [column_name]
        self.output_columns = [column_name]
        if self.cardinality_rule == 'match':
            if not cardinality:
                raise TransformerInputError(
                    'Cardinality "match" rule must specify a cardinality value.'
                )
        self._data_cardinality = cardinality
        self._nan_frequency = nan_frequency

    def __repr__(self):
        """Represent initialization of transformer as text.

        Returns:
            str:
                The name of the transformer followed by any non-default parameters.
        """
        class_name = self.__class__.get_name()
        custom_args = []
        args = inspect.getfullargspec(self.__init__)
        keys = args.args[1:]
        defaults = dict(zip(keys, args.defaults))
        # ``enforce_uniqueness`` is deprecated and hidden. Subclasses may not accept it at all
        # (e.g. ``PseudoAnonymizedFaker``), so only drop it when present.
        if 'enforce_uniqueness' in keys:
            keys.remove('enforce_uniqueness')

        instanced = {key: getattr(self, key) for key in keys}

        # ``function_name`` falls back to 'lexify', so it is always reported.
        defaults['function_name'] = None
        for arg, value in instanced.items():
            default = defaults[arg]
            if value == default or value == 'BaseProvider':
                continue

            # Empty values (``{}`` kwargs, ``None`` locales, ...) are equivalent to an unset
            # ``None`` default. A falsy value with a non-None default (e.g.
            # ``missing_value_generation=None``) is a real customization and is reported.
            if not value and default is None:
                continue

            value = f"'{value}'" if isinstance(value, str) else value
            custom_args.append(f'{arg}={value}')

        args_string = ', '.join(custom_args)
        return f'{class_name}({args_string})'


class PseudoAnonymizedFaker(AnonymizedFaker):
    """Pseudo-anonymization Transformer using Faker.

    This transformer anonymizes values that can be traced back to the original input by using
    a mapping. The transformer will generate a mapping with the previously specified
    ``Faker`` provider and ``function``. Generated values are always unique
    (``cardinality_rule='unique'``).

    Args:
        provider_name (str):
            The name of the provider in ``Faker``. If ``None`` the ``BaseProvider`` is used.
            Defaults to ``None``.
        function_name (str):
            The name of the function to use within the ``faker.provider``. Defaults to
            ``lexify``.
        function_kwargs (dict):
            Keyword args to pass into the ``function_name`` when being called.
        locales (list):
            List of localized providers to use instead of the global provider.
    """

    def __getstate__(self):
        """Return a dictionary representation of the instance and warn the user when pickling."""
        warnings.warn(
            (
                'You are saving the mapping information, which includes the original data. '
                'Sharing this object with others will also give them access to the original data '
                'used with this transformer.'
            )
        )

        return self.__dict__

    def __init__(
        self,
        provider_name=None,
        function_name=None,
        function_kwargs=None,
        locales=None,
    ):
        super().__init__(
            provider_name=provider_name,
            function_name=function_name,
            function_kwargs=function_kwargs,
            locales=locales,
            cardinality_rule='unique',
        )
        self._mapping_dict = {}
        self._reverse_mapping_dict = {}
        self.output_properties = {
            None: {
                'sdtype': 'categorical',
                'next_transformer': LabelEncoder(add_noise=True),
            }
        }

    def get_mapping(self):
        """Return a deep copy of the original-to-generated mapping dictionary."""
        return deepcopy(self._mapping_dict)

    def _fit(self, columns_data):
        """Fit the transformer to the data.

        Generate a ``_mapping_dict`` and a ``_reverse_mapping_dict`` for each
        non-null value in the provided ``columns_data`` using the ``Faker`` provider and
        ``function``.

        Args:
            columns_data (pandas.Series):
                Data to fit the transformer to.

        Raises:
            TransformerProcessingError:
                If the Faker function cannot generate enough unique values.
        """
        self._set_faker_seed(columns_data)
        unique_values = columns_data[columns_data.notna()].unique()
        unique_data_length = len(unique_values)
        try:
            generated_values = [self._function() for _ in range(unique_data_length)]
        except faker.exceptions.UniquenessException as exception:
            raise TransformerProcessingError(
                'The Faker function you specified is not able to generate '
                f'{unique_data_length} unique values. Please use a different '
                'Faker function for this column.'
            ) from exception

        # Drop duplicates while keeping generation order. A ``set`` would reorder string
        # values according to per-process hash randomization, so the seeded mapping would
        # differ from run to run.
        generated_values = list(dict.fromkeys(generated_values))
        self._mapping_dict = dict(zip(unique_values, generated_values))
        self._reverse_mapping_dict = dict(zip(generated_values, unique_values))

    def _transform(self, columns_data):
        """Replace each value with its previously generated fake value.

        Map the input ``columns_data`` using the previously generated values for each one.
        If the ``columns_data`` contain unknown values, an error will be raised listing up
        to five of them, in order of appearance.

        Args:
            columns_data (pandas.Series):
                Data to transform.

        Returns:
            pd.Series

        Raises:
            TransformerProcessingError:
                If ``columns_data`` contains non-null values not seen during fit.
        """
        unique_values = columns_data[columns_data.notna()].unique()
        # Keep order of appearance so the error message is deterministic.
        new_values = [
            str(value) for value in unique_values if value not in self._mapping_dict
        ]
        if new_values:
            if len(new_values) <= 5:
                new_values = ', '.join(new_values)
                error_msg = (
                    'The data you are transforming has new, unexpected values '
                    f'({new_values}). Please fit the transformer again using this '
                    'new data.'
                )
            else:
                diff = len(new_values) - 5
                new_values = ', '.join(new_values[:5])
                error_msg = (
                    'The data you are transforming has new, unexpected values '
                    f'({new_values} and {diff} more). Please fit the transformer again '
                    'using this new data.'
                )

            raise TransformerProcessingError(error_msg)

        mapped_data = columns_data.map(self._mapping_dict)
        return mapped_data

    def _reverse_transform(self, columns_data):
        """Return the input data unchanged (the already pseudo-anonymized values).

        Args:
            columns_data (pd.Series or numpy.ndarray):
                Data to revert.

        Returns:
            pandas.Series
        """
        return columns_data
