from evalml.pipelines.components.transformers import Transformer
from evalml.pipelines.components.transformers.imputers.simple_imputer import (
    SimpleImputer
)
from evalml.utils.gen_utils import (
    _convert_to_woodwork_structure,
    _convert_woodwork_types_wrapper
)


class PerColumnImputer(Transformer):
    """Imputes missing data according to a specified imputation strategy per column"""
    name = 'Per Column Imputer'
    hyperparameter_ranges = {}

    def __init__(self, impute_strategies=None, default_impute_strategy="most_frequent", random_state=0, **kwargs):
        """Initializes a transformer that imputes missing data according to the specified imputation strategy per column.

        Arguments:
            impute_strategies (dict): Column and {"impute_strategy": strategy, "fill_value":value} pairings.
                Valid values for impute strategy include "mean", "median", "most_frequent", "constant" for numerical data,
                and "most_frequent", "constant" for object data types. Defaults to None, in which case
                default_impute_strategy is used for all columns.

                When impute_strategy == "constant", fill_value is used to replace missing data.
                Defaults to 0 when imputing numerical data and "missing_value" for strings or object data types.

            default_impute_strategy (str): Impute strategy to fall back on when none is provided for a certain column.
                Valid values include "mean", "median", "most_frequent", "constant" for numerical data,
                and "most_frequent", "constant" for object data types. Defaults to "most_frequent"

            random_state (int): Passed through to the Transformer base class. Defaults to 0.

            **kwargs: Accepted but not used by this component.

        Raises:
            ValueError: If impute_strategies is provided and is not a dictionary.
        """
        # Validate the raw argument, not the result of `impute_strategies or dict()`:
        # otherwise a falsy non-dict such as [] or "" would be silently replaced by an
        # empty dict and slip past the type check.
        if impute_strategies is not None and not isinstance(impute_strategies, dict):
            raise ValueError("`impute_strategies` is not a dictionary. Please provide in Column and {`impute_strategy`: strategy, `fill_value`:value} pairs. ")

        parameters = {"impute_strategies": impute_strategies,
                      "default_impute_strategy": default_impute_strategy}
        # Populated by fit() with one SimpleImputer per column; None until then.
        self.imputers = None
        self.default_impute_strategy = default_impute_strategy
        self.impute_strategies = impute_strategies or dict()

        super().__init__(parameters=parameters,
                         component_obj=None,
                         random_state=random_state)

    def fit(self, X, y=None):
        """Fits imputers on input data

        One SimpleImputer is created and fitted for every column of X. Columns without an
        entry in impute_strategies use default_impute_strategy and no explicit fill value.

        Arguments:
            X (ww.DataTable, pd.DataFrame or np.ndarray): The input training data of shape [n_samples, n_features] to fit.
            y (ww.DataColumn, pd.Series, optional): The target training data of length [n_samples]. Ignored.

        Returns:
            self
        """
        X = _convert_to_woodwork_structure(X)
        X = _convert_woodwork_types_wrapper(X.to_dataframe())

        # Construct every imputer first, then fit them, so that all per-column
        # configurations are instantiated before any fitting work starts.
        self.imputers = dict()
        for column in X.columns:
            column_config = self.impute_strategies.get(column, dict())
            strategy = column_config.get('impute_strategy', self.default_impute_strategy)
            fill_value = column_config.get('fill_value', None)
            self.imputers[column] = SimpleImputer(impute_strategy=strategy, fill_value=fill_value)

        for column, imputer in self.imputers.items():
            # Double brackets keep the column as a single-column DataFrame.
            imputer.fit(X[[column]])

        return self

    def transform(self, X, y=None):
        """Transforms input data by imputing missing values.

        Each column seen during fit is passed through its own imputer. Columns for which
        the imputer returns an empty result are dropped from the output.

        Arguments:
            X (ww.DataTable, pd.DataFrame or np.ndarray): The input training data of shape [n_samples, n_features] to transform.
            y (ww.DataColumn, pd.Series, optional): The target training data of length [n_samples]. Ignored.

        Returns:
            pd.DataFrame: Transformed X
        """
        X = _convert_to_woodwork_structure(X)
        X = _convert_woodwork_types_wrapper(X.to_dataframe())
        # Work on a copy so the converted input frame is not modified in place.
        X_t = X.copy()
        cols_to_drop = []
        for column, imputer in self.imputers.items():
            transformed = imputer.transform(X[[column]])
            if transformed.empty:
                # NOTE(review): presumably SimpleImputer returns an empty frame when the
                # column could not be imputed (e.g. all values missing at fit time) -- confirm.
                cols_to_drop.append(column)
            else:
                X_t[column] = transformed
        X_t = X_t.drop(cols_to_drop, axis=1)
        return X_t

    def fit_transform(self, X, y=None):
        """Fits imputer and imputes missing values in input data.

        Arguments:
            X (ww.DataTable, pd.DataFrame or np.ndarray): The input training data of shape [n_samples, n_features] to transform.
            y (ww.DataColumn, pd.Series, optional): The target training data of length [n_samples]. Ignored.

        Returns:
            pd.DataFrame: Transformed X
        """

        self.fit(X, y)
        return self.transform(X, y)
