from probatus.utils import assure_pandas_df, shap_calc, calculate_shap_importance, \
    NotFittedError, assure_pandas_series
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import warnings
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, check_cv
from sklearn.base import clone, is_classifier
from sklearn.metrics import check_scoring
from joblib import Parallel, delayed

class ShapRFECV:
    """
    This class performs Backwards Recursive Feature Elimination, using SHAP feature importance. At each round, for a
        given feature set, starting from all available features, the following steps are applied:

    1. (Optional) Tune the hyperparameters of the model using [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
        or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV),
    2. Apply Cross-validation (CV) to estimate the SHAP feature importance on the provided dataset. In each CV
        iteration, the model is fitted on the train folds, and applied on the validation fold to estimate
        SHAP feature importance.
    3. Remove `step` lowest SHAP importance features from the dataset.

    At the end of the process, the user can plot the performance of the model for each iteration, and select the
        optimal number of features and the features set.

    The functionality is similar to [RFECV](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFECV.html).
        The main difference is removing the lowest importance features based on SHAP features importance. It also
        supports the use of [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
        and [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
        passed as the `clf`, thanks to which you can optimize the hyperparameters of the model at each round,
        to tune the model for each features set. Lastly, it supports
        categorical features (object and category dtype) and missing values in the data, as long as the model supports
        them.

    We recommend using [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
        because by default it handles missing values and categorical features. In case of other models, make sure to
        handle these issues for your dataset and consider impact it might have on features importance.


    Example:
    ```python
    from probatus.feature_elimination import ShapRFECV
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split
    import numpy as np
    import pandas as pd
    import lightgbm
    from sklearn.model_selection import RandomizedSearchCV

    feature_names = ['f1_categorical', 'f2_missing', 'f3_static', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20']

    # Prepare two samples
    X, y = make_classification(n_samples=1000, class_sep=0.05, n_informative=6, n_features=20,
                               random_state=0, n_redundant=10, n_clusters_per_class=1)
    X = pd.DataFrame(X, columns=feature_names)
    X['f1_categorical'] = X['f1_categorical'].apply(lambda x: str(np.round(x*10)))
    X['f2_missing'] = X['f2_missing'].apply(lambda x: x if np.random.rand()<0.8 else np.nan)
    X['f3_static'] = 0

    # Prepare model and parameter search space
    clf = lightgbm.LGBMClassifier(max_depth=5, class_weight='balanced')

    param_grid = {
        'n_estimators': [5, 7, 10],
        'num_leaves': [3, 5, 7, 10],
    }
    search = RandomizedSearchCV(clf, param_grid)


    # Run feature elimination
    shap_elimination = ShapRFECV(
        clf=search, step=0.2, cv=10, scoring='roc_auc', n_jobs=3)
    report = shap_elimination.fit_compute(X, y)

    # Make plots
    performance_plot = shap_elimination.plot()

    # Get final feature set
    final_features_set = shap_elimination.get_reduced_features_set(num_features=3)
    ```

    """

    def __init__(self, clf, step=1, min_features_to_select=1, cv=None, scoring=None, n_jobs=-1, random_state=None):
        """
        Validates the configuration and stores it on the object.

        Args:
            clf (binary classifier, GridSearchCV or RandomizedSearchCV):
                Model that is trained (and, for the search objects, tuned) at each round of features elimination.
                The recommended model is
                [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
                because by default it handles missing values and categorical variables.
                [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
                and [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
                are supported as well.

            step (Optional, int or float):
                Number of lowest importance features removed each round. An int removes that many features per
                round. A float removes that fraction of the remaining features (rounded down) per round. A float is
                recommended, since it is faster for a large number of features, and slows down and becomes more
                precise towards less features. Note: the last round may remove fewer features in order to reach
                min_features_to_select.

            min_features_to_select (Optional, int):
                Minimum number of features to be kept. This is the stopping criterion of the feature elimination.
                By default the process stops when one feature is left.

            cv (Optional, int, cross-validation generator or an iterable):
                Determines the cross-validation splitting strategy. Compatible with sklearn
                [cv parameter](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFECV.html).
                If None, then cv of 5 is used.

            scoring (Optional, string, callable or None):
                A string (see sklearn [model scoring](https://scikit-learn.org/stable/modules/model_evaluation.html))
                or a scorer callable object, function with the signature `scorer(estimator, X, y)`.

            n_jobs (Optional, int):
                Number of cores to run in parallel while fitting across folds. None means 1 unless in a
                `joblib.parallel_backend` context. -1 means using all processors.

            random_state (Optional, int):
                Random state set at each round of feature elimination. If it is None, the results will not be
                reproducible and in random search a different set of hyperparameters might be tested at each
                iteration. For reproducible results set it to an integer.

        Raises:
            ValueError:
                If `step` is not a positive int or float, or `min_features_to_select` is not a positive int.
        """
        self.clf = clf

        # Search objects trigger a hyperparameter optimization at every elimination round
        self.search_clf = isinstance(self.clf, (RandomizedSearchCV, GridSearchCV))

        step_is_valid = isinstance(step, (int, float)) and step > 0
        if not step_is_valid:
            raise ValueError(f"The current value of step = {step} is not allowed. "
                             f"It needs to be a positive integer or positive float.")
        self.step = step

        min_features_is_valid = isinstance(min_features_to_select, int) and min_features_to_select > 0
        if not min_features_is_valid:
            raise ValueError(f"The current value of min_features_to_select = {min_features_to_select} is not allowed. "
                             f"It needs to be a positive integer.")
        self.min_features_to_select = min_features_to_select

        self.cv = cv
        self.scorer = check_scoring(self.clf, scoring=scoring)
        self.random_state = random_state
        self.n_jobs = n_jobs

        # The report is filled round by round during fit()
        self.report_df = pd.DataFrame([])
        self.fitted = False


    def _check_if_fitted(self):
        """
        Guards methods that need the results of fit().

        Raises:
            NotFittedError:
                If the `fitted` flag of the object is False.
        """
        if self.fitted is not False:
            return
        raise NotFittedError('The object has not been fitted. Please run fit() method first')


    @staticmethod
    def _preprocess_data(X):
        """
        Does basic preprocessing of the data: removes static features, warns about features with missing values,
        and casts object dtype features to category dtype, such that LightGBM handles them by default.

        The dtype cast is done with `DataFrame.astype`, which returns a new DataFrame, so the columns of the object
        passed by the caller are not overwritten in place.

        Args:
            X (pd.DataFrame):
                Provided dataset.

        Returns:
            (pd.DataFrame):
                Preprocessed dataset.
        """
        # Make sure that X is a pd.DataFrame
        # NOTE(review): assure_pandas_df presumably returns the caller's own object when X already is a
        # DataFrame, hence the non-mutating operations below - confirm against probatus.utils.
        X = assure_pandas_df(X)

        # Remove static features, those that have only one value for all samples
        static_features = [column for column in X.columns if len(X[column].unique()) == 1]
        if static_features:
            warnings.warn(f'Removing static features {static_features}.')
            X = X.drop(columns=static_features)

        # Warn if missing
        columns_with_missing = [column for column in X.columns if X[column].isnull().values.any()]
        if columns_with_missing:
            warnings.warn(f'The following variables contain missing values {columns_with_missing}. Make sure to impute '
                          f'missing or apply a model that handles them automatically.')

        # Find the features stored with object dtype; they are treated as categorical variables
        obj_dtype_features = [column for column, dtype in zip(X.columns, X.dtypes) if dtype == 'O']

        # Set categorical features type to category
        if obj_dtype_features:
            warnings.warn(f'Changing dtype of {obj_dtype_features} from "object" to "category". Treating it as '
                          f'categorical variable. Make sure that the model handles categorical variables, or encode '
                          f'them first.')
            X = X.astype({feature: 'category' for feature in obj_dtype_features})
        return X


    def _get_current_features_to_remove(self, shap_importance_df):
        """
        Implements the logic used to determine which features to remove. If step is a positive integer,
            at each round step lowest SHAP importance features are selected. If it is a float, such percentage
            of remaining features (rounded up) is removed each iteration. It is recommended to use float, since it is
            faster for a large set of features, and slows down and becomes more precise towards less features.

        Args:
            shap_importance_df (pd.DataFrame):
                DataFrame presenting SHAP importance of remaining features.

        Returns:
            (list):
                List of features to be removed at a given round.
        """

        # If the step is an int remove n features.
        if isinstance(self.step, int):
            num_features_to_remove = self._calculate_number_of_features_to_remove(
                current_num_of_features=shap_importance_df.shape[0],
                num_features_to_remove=self.step,
                min_num_features_to_keep=self.min_features_to_select
            )
        # If the step is a float remove n * number features that are left, rounded down
        elif isinstance(self.step, float):
            current_step = int(np.floor(shap_importance_df.shape[0] * self.step))
            # The step after rounding down should be at least 1
            if current_step < 1:
                current_step = 1

            num_features_to_remove = self._calculate_number_of_features_to_remove(
                current_num_of_features=shap_importance_df.shape[0],
                num_features_to_remove=current_step,
                min_num_features_to_keep=self.min_features_to_select
            )

        if num_features_to_remove == 0:
            return []
        else:

            return shap_importance_df.iloc[-num_features_to_remove:].index.tolist()


    @staticmethod
    def _calculate_number_of_features_to_remove(current_num_of_features, num_features_to_remove,
                                                min_num_features_to_keep):
        """
        Calculates the number of features to be removed, and makes sure that after removal at least
            min_num_features_to_keep are kept

         Args:
            current_num_of_features (int):
                Current number of features in the data.

            num_features_to_remove (int):
                Number of features to be removed at this stage.

            min_num_features_to_keep (int):
                Minimum number of features to be left after removal.

        Returns:
            (int):
                Number of features to be removed.
        """
        num_features_after_removal = current_num_of_features - num_features_to_remove
        if num_features_after_removal >= min_num_features_to_keep:
            num_to_remove = num_features_to_remove
        else:
            # take all available features minus number of them that should stay
            num_to_remove = current_num_of_features - min_num_features_to_keep
        return num_to_remove


    def _report_current_results(self, round_number, current_features_set, features_to_remove, train_metric_mean,
                                train_metric_std, val_metric_mean, val_metric_std):
        """
        This function adds the results from a current iteration to the report.

        Args:
            round_number (int):
                Current number of the round.

            current_features_set (list of str):
                Current list of features.

            features_to_remove (list of str):
                List of features to be removed at the end of this iteration.

            train_metric_mean (float or int):
                Mean scoring metric measured on train set during CV.

            train_metric_std (float or int):
                Std scoring metric measured on train set during CV.

            val_metric_mean (float or int):
                Mean scoring metric measured on validation set during CV.

            val_metric_std (float or int):
                Std scoring metric measured on validation set during CV.
        """

        current_results = {
            'num_features': len(current_features_set),
            'features_set': None,
            'eliminated_features':  None,
            'train_metric_mean': train_metric_mean,
            'train_metric_std': train_metric_std,
            'val_metric_mean': val_metric_mean,
            'val_metric_std': val_metric_std,
        }

        current_row = pd.DataFrame(current_results, index=[round_number])
        current_row['features_set'] = [current_features_set]
        current_row['eliminated_features'] = [features_to_remove]

        self.report_df = pd.concat([self.report_df, current_row], axis=0)


    @staticmethod
    def _get_feature_shap_values_per_fold(X, y, clf, train_index, val_index, scorer):
        """
        Processes a single CV fold: fits the model on the train part, scores it on the train and validation parts,
        and calculates the SHAP values on the validation part.

        Args:
            X (pd.DataFrame):
                Dataset used in CV.

            y (pd.Series):
                Binary labels for X.

            clf (binary classifier):
                Model to be fitted on the train folds.

            train_index (np.array):
                Positions of train folds samples.

            val_index (np.array):
                Positions of validation fold samples.

            scorer (callable):
                Scorer called as `scorer(estimator, X, y)`.

        Returns:
            (np.array, float, float):
                Tuple with the results: Shap Values on validation fold, train score, validation score.
        """
        # The indices are positions, not labels, therefore iloc is used for the split
        train_features = X.iloc[train_index, :]
        val_features = X.iloc[val_index, :]
        train_labels = y.iloc[train_index]
        val_labels = y.iloc[val_index]

        # Fit on the train part only; the object returned by fit() is the one scored and explained below
        fitted_clf = clf.fit(train_features, train_labels)

        train_score = scorer(fitted_clf, train_features, train_labels)
        val_score = scorer(fitted_clf, val_features, val_labels)

        # SHAP values are calculated on the validation part
        fold_shap_values = shap_calc(fitted_clf, val_features, suppress_warnings=True)
        return fold_shap_values, train_score, val_score


    def fit(self, X, y):
        """
        Fits the object with the provided data. The algorithm starts with the entire dataset, and then sequentially
             eliminates features. If [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
             or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
             object assigned as clf, the hyperparameter optimization is applied first. Then, the SHAP feature importance
             is calculated using Cross-Validation, and `step` lowest importance features are removed.

        Every call starts from an empty report, so fitting the same object again replaces the previous results
             instead of appending to them. The features that survive a round keep the column order of the
             preprocessed dataset, which makes the results independent of the hash-based ordering of Python sets.

        Args:
            X (pd.DataFrame):
                Provided dataset.

            y (pd.Series):
                Binary labels for X.
        """
        # Set seed for results reproducibility
        if self.random_state is not None:
            np.random.seed(self.random_state)

        # Drop the results of any previous fit; otherwise the rounds of both fits would share index values
        # and the object would still claim to be fitted if this fit fails midway
        self.report_df = pd.DataFrame([])
        self.fitted = False

        self.X = self._preprocess_data(X)
        self.y = assure_pandas_series(y, index=self.X.index)
        self.cv = check_cv(self.cv, self.y, classifier=is_classifier(self.clf))

        remaining_features = current_features_set = self.X.columns.tolist()
        round_number = 0

        # The condition looks at the features set of the previous round, so that one more round is run to evaluate
        # the final features set (in that round nothing is removed)
        while len(current_features_set) > self.min_features_to_select:
            round_number += 1

            # Get current dataset info
            current_features_set = remaining_features
            current_X = self.X[current_features_set]

            # Set seed for results reproducibility
            if self.random_state is not None:
                np.random.seed(self.random_state)

            # Optimize parameters
            if self.search_clf:
                current_search_clf = clone(self.clf).fit(current_X, self.y)
                current_clf = current_search_clf.estimator.set_params(**current_search_clf.best_params_)
            else:
                current_clf = clone(self.clf)

            # Perform CV to estimate feature importance with SHAP
            results_per_fold = Parallel(n_jobs=self.n_jobs)(delayed(self._get_feature_shap_values_per_fold)(
                X=current_X, y=self.y, clf=current_clf, train_index=train_index, val_index=val_index, scorer=self.scorer
            ) for train_index, val_index in self.cv.split(current_X, self.y))

            shap_values = np.vstack([fold_result[0] for fold_result in results_per_fold])
            scores_train = [fold_result[1] for fold_result in results_per_fold]
            scores_val = [fold_result[2] for fold_result in results_per_fold]

            shap_importance_df = calculate_shap_importance(shap_values, current_features_set)

            # Get features to remove
            features_to_remove = self._get_current_features_to_remove(shap_importance_df)

            # Keep the column order stable; the set is only used for the membership test
            removed_features = set(features_to_remove)
            remaining_features = [feature for feature in current_features_set if feature not in removed_features]

            # Report results
            train_metric_mean = np.round(np.mean(scores_train), 3)
            train_metric_std = np.round(np.std(scores_train), 3)
            val_metric_mean = np.round(np.mean(scores_val), 3)
            val_metric_std = np.round(np.std(scores_val), 3)

            self._report_current_results(round_number=round_number, current_features_set=current_features_set,
                                         features_to_remove=features_to_remove,
                                         train_metric_mean=train_metric_mean,
                                         train_metric_std=train_metric_std,
                                         val_metric_mean=val_metric_mean,
                                         val_metric_std=val_metric_std)

            print(f'Round: {round_number}, Current number of features: {len(current_features_set)}, '
                  f'Current performance: Train {train_metric_mean} '
                  f'+/- {train_metric_std}, CV Validation '
                  f'{val_metric_mean} '
                  f'+/- {val_metric_std}. \n'
                  f'Num of features left: {len(remaining_features)}. '
                  f'Removed features at the end of the round: {features_to_remove}')
        self.fitted = True


    def compute(self):
        """
        Checks if fit() method has been run and computes the DataFrame with results of feature elimintation for each
         round.

        Returns:
            (pd.DataFrame):
                DataFrame with results of feature elimination for each round.
        """
        self._check_if_fitted()

        return self.report_df


    def fit_compute(self, X, y):
        """
        Runs fit() on the provided data and then returns the result of compute().

        The algorithm starts with the entire dataset, and then sequentially eliminates features. If
             [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
             or [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
             object is assigned as clf, the hyperparameter optimization is applied first. Then, the SHAP feature
             importance is calculated using Cross-Validation, and `step` lowest importance features are removed.

        Args:
            X (pd.DataFrame):
                Provided dataset.

            y (pd.Series):
                Binary labels for X.

        Returns:
            (pd.DataFrame):
                DataFrame containing results of feature elimination from each iteration.
        """
        self.fit(X, y)
        report = self.compute()
        return report


    def get_reduced_features_set(self, num_features):
        """
        Gets the features set after the feature elimination process, for a given number of features.

        Args:
            num_features (int):
                Number of features in the reduced features set.

        Returns:
            (list of str):
                Reduced features set.
        """
        self._check_if_fitted()

        if num_features not in self.report_df.num_features.tolist():
            raise(ValueError(f'The provided number of features has not been achieved at any stage of the process. '
                             f'You can select one of the following: {self.report_df.num_features.tolist()}'))
        else:
            return self.report_df[self.report_df.num_features == num_features]['features_set'].values[0]


    def plot(self, show=True, **figure_kwargs):
        """
        Generates plot of the model performance for each iteration of feature elimination. For both the train and
        the validation score, the mean is plotted as a line and the range of mean +/- std as a band around it.

        Args:
            show (Optional, bool):
                If True, the plots are showed to the user, otherwise they are not shown.

            **figure_kwargs:
                Keyword arguments that are passed to the plt.figure, at its initialization.

        Returns:
            (plt.axis):
                Axis containing the performance plot.

        Raises:
            NotFittedError:
                If the object has not been fitted yet.
        """
        # Same guard as in compute() and get_reduced_features_set(); without it, plotting an unfitted object
        # fails on the missing columns of the empty report
        self._check_if_fitted()

        num_features = self.report_df['num_features']
        num_features_numeric = pd.to_numeric(self.report_df.num_features, errors='coerce')
        x_ticks = list(reversed(num_features.tolist()))

        plt.figure(**figure_kwargs)

        train_mean = self.report_df['train_metric_mean']
        train_std = self.report_df['train_metric_std']
        plt.plot(num_features, train_mean, label='Train Score')
        plt.fill_between(num_features_numeric, train_mean - train_std, train_mean + train_std, alpha=.3)

        val_mean = self.report_df['val_metric_mean']
        val_std = self.report_df['val_metric_std']
        plt.plot(num_features, val_mean, label='Validation Score')
        plt.fill_between(num_features_numeric, val_mean - val_std, val_mean + val_std, alpha=.3)

        plt.xlabel('Number of features')
        plt.ylabel('Performance')
        plt.title('Backwards Feature Elimination using SHAP & CV')
        plt.legend(loc="lower left")
        ax = plt.gca()
        # The x axis is inverted so that the plot reads in the order of the elimination rounds
        ax.invert_xaxis()
        ax.set_xticks(x_ticks)
        if show:
            plt.show()
        else:
            plt.close()
        return ax

