import numpy as np
import pandas as pd
import warnings

from sklearn.base import is_regressor, is_classifier

from scipy.stats import norm

from statsmodels.stats.multitest import multipletests

from abc import ABC, abstractmethod

from .double_ml_data import DoubleMLData
from .double_ml_resampling import DoubleMLResampling
from ._helper import _check_is_partition, _check_all_smpls


class DoubleML(ABC):
    """
    Double Machine Learning
    """

    def __init__(self,
                 obj_dml_data,
                 n_folds,
                 n_rep,
                 score,
                 dml_procedure,
                 draw_sample_splitting,
                 apply_cross_fitting):
        """
        Validate and store the data, the resampling settings, the score and the DML algorithm.

        Parameters
        ----------
        obj_dml_data : DoubleMLData
            The data backend. It is additionally checked by the model-specific ``_check_data``.

        n_folds : int
            Number of folds (at least 1).

        n_rep : int
            Number of repetitions for the sample splitting (at least 1).

        score : object
            The score function specification. It is passed to the model-specific ``_check_score`` and the
            returned value is stored.

        dml_procedure : str
            ``'dml1'`` or ``'dml2'``. Without cross-fitting ``'dml1'`` is always stored.

        draw_sample_splitting : bool
            If ``True``, ``draw_sample_splitting()`` is called right away; otherwise the sample splitting
            stays unset.

        apply_cross_fitting : bool
            Indicates whether cross-fitting should be applied. Reset to ``False`` (with a warning) if
            ``n_folds == 1``.

        Raises
        ------
        TypeError
            If ``obj_dml_data`` is no ``DoubleMLData``, ``n_folds`` / ``n_rep`` are no ``int`` or one of the
            flags is no ``bool``.
        ValueError
            If ``n_folds`` or ``n_rep`` is smaller than 1 or ``dml_procedure`` is invalid.
        AssertionError
            If cross-fitting is disabled and ``n_folds > 2``.
        """
        # check and pick up obj_dml_data
        if not isinstance(obj_dml_data, DoubleMLData):
            raise TypeError('The data must be of DoubleMLData type. '
                            f'{str(obj_dml_data)} of type {str(type(obj_dml_data))} was passed.')
        self._check_data(obj_dml_data)
        self._dml_data = obj_dml_data

        # initialize learners and parameters which are set model specific
        self._learner = None
        self._params = None

        # check resampling specifications
        if not isinstance(n_folds, int):
            raise TypeError('The number of folds must be of int type. '
                            f'{str(n_folds)} of type {str(type(n_folds))} was passed.')
        if n_folds < 1:
            raise ValueError('The number of folds must be positive. '
                             f'{str(n_folds)} was passed.')

        if not isinstance(n_rep, int):
            raise TypeError('The number of repetitions for the sample splitting must be of int type. '
                            f'{str(n_rep)} of type {str(type(n_rep))} was passed.')
        if n_rep < 1:
            raise ValueError('The number of repetitions for the sample splitting must be positive. '
                             f'{str(n_rep)} was passed.')

        if not isinstance(apply_cross_fitting, bool):
            raise TypeError('apply_cross_fitting must be True or False. '
                            f'got {str(apply_cross_fitting)}')
        if not isinstance(draw_sample_splitting, bool):
            raise TypeError('draw_sample_splitting must be True or False. '
                            f'got {str(draw_sample_splitting)}')

        # set resampling specifications
        self._n_folds = n_folds
        self._n_rep = n_rep
        self._apply_cross_fitting = apply_cross_fitting

        # check and set dml_procedure and score
        # `or` short-circuits, so the membership test only runs for strings and array-like input
        # gets the intended error message instead of a comparison error
        if (not isinstance(dml_procedure, str)) or (dml_procedure not in ['dml1', 'dml2']):
            raise ValueError('dml_procedure must be "dml1" or "dml2" '
                             f' got {str(dml_procedure)}')
        self._dml_procedure = dml_procedure
        self._score = self._check_score(score)

        if (self.n_folds == 1) and self.apply_cross_fitting:
            warnings.warn('apply_cross_fitting is set to False. Cross-fitting is not supported for n_folds = 1.')
            self._apply_cross_fitting = False

        if not self.apply_cross_fitting:
            # kept as an assert so that callers expecting an AssertionError keep working
            assert self.n_folds <= 2, 'Estimation without cross-fitting not supported for n_folds > 2.'
            if self.dml_procedure == 'dml2':
                # redirect to dml1 which works out-of-the-box; dml_procedure is of no relevance without cross-fitting
                self._dml_procedure = 'dml1'

        # perform sample splitting
        self._smpls = None
        if draw_sample_splitting:
            self.draw_sample_splitting()

        # initialize arrays according to obj_dml_data and the resampling settings
        self._psi, self._psi_a, self._psi_b,\
            self._coef, self._se, self._all_coef, self._all_se, self._all_dml1_coef = self._initialize_arrays()

        # also initialize bootstrap arrays with the default number of bootstrap replications
        self._n_rep_boot, self._boot_coef, self._boot_t_stat = self._initialize_boot_arrays(n_rep_boot=500)

        # initialize instance attributes which are later used for iterating
        self._i_rep = None
        self._i_treat = None

    def __str__(self):
        class_name = self.__class__.__name__
        header = f'================== {class_name} Object ==================\n'
        data_info = f'Outcome variable: {self._dml_data.y_col}\n' \
                    f'Treatment variable(s): {self._dml_data.d_cols}\n' \
                    f'Covariates: {self._dml_data.x_cols}\n' \
                    f'Instrument variable(s): {self._dml_data.z_cols}\n' \
                    f'No. Observations: {self._dml_data.n_obs}\n'
        score_info = f'Score function: {str(self.score)}\n' \
                     f'DML algorithm: {self.dml_procedure}\n'
        learner_info = ''
        for key, value in self.learner.items():
            learner_info += f'Learner {key}: {str(value)}\n'
        resampling_info = f'No. folds: {self.n_folds}\n' \
                          f'No. repeated sample splits: {self.n_rep}\n' \
                          f'Apply cross-fitting: {self.apply_cross_fitting}\n'
        fit_summary = str(self.summary)
        res = header + \
            '\n------------------ Data summary      ------------------\n' + data_info + \
            '\n------------------ Score & algorithm ------------------\n' + score_info + \
            '\n------------------ Machine learner   ------------------\n' + learner_info + \
            '\n------------------ Resampling        ------------------\n' + resampling_info + \
            '\n------------------ Fit summary       ------------------\n' + fit_summary
        return res

    @property
    def n_folds(self):
        """
        Number of folds.

        Taken from the constructor argument ``n_folds``, which is validated to be an ``int`` of at least 1.
        """
        return self._n_folds

    @property
    def n_rep(self):
        """
        Number of repetitions for the sample splitting.

        Taken from the constructor argument ``n_rep``, which is validated to be an ``int`` of at least 1.
        :meth:`fit` and :meth:`bootstrap` loop over this many repetitions.
        """
        return self._n_rep

    @property
    def apply_cross_fitting(self):
        """
        Indicates whether cross-fitting should be applied.

        May differ from the value passed to the constructor: for ``n_folds == 1`` the constructor
        issues a warning and stores ``False``.
        """
        return self._apply_cross_fitting

    @property
    def dml_procedure(self):
        """
        The double machine learning algorithm (``'dml1'`` or ``'dml2'``).

        Without cross-fitting the constructor always stores ``'dml1'``, even if ``'dml2'`` was requested.
        """
        return self._dml_procedure

    @property
    def n_rep_boot(self):
        """
        The number of bootstrap replications.

        Initialized with ``500`` by the constructor and replaced by the ``n_rep_boot`` argument of each
        :meth:`bootstrap` call.
        """
        return self._n_rep_boot

    @property
    def score(self):
        """
        The score function, as returned by the model-specific ``_check_score`` for the ``score``
        argument of the constructor.
        """
        return self._score

    @property
    def learner(self):
        """
        The machine learners for the nuisance functions.

        The base constructor stores ``None``; the value is set model specific. It is used as a mapping
        from learner name to learner (``.keys()`` / ``.items()`` are called on it).
        """
        return self._learner

    @property
    def learner_names(self):
        """
        The names of the learners, i.e. the keys of ``learner`` as a new list.

        These are the keys expected in ``param_grids`` (and accepted in ``scoring_methods``) by :meth:`tune`.
        """
        return list(self._learner.keys())

    @property
    def params(self):
        """
        The hyperparameters of the learners.

        The base constructor stores ``None``; the value is set model specific. Entries are addressed as
        ``params[learner][treat_var]`` and hold one list per repetition (see :meth:`set_ml_nuisance_params`).
        """
        return self._params

    @property
    def params_names(self):
        """
        The names of the nuisance models with hyperparameters, i.e. the keys of ``params`` as a new list.

        These are the valid ``learner`` arguments of :meth:`get_params` and :meth:`set_ml_nuisance_params`.
        """
        return list(self._params.keys())

    def get_params(self, learner):
        """
        Get hyperparameters for the nuisance model of DoubleML models.

        Parameters
        ----------
        learner : str
            The nuisance model / learner (see attribute ``params_names``).

        Returns
        -------
        params : dict
            Parameters for the nuisance model / learner.
        """
        valid_learner = self.params_names
        if (not isinstance(learner, str)) | (learner not in valid_learner):
            raise ValueError('invalid nuisance learner ' + str(learner) +
                             '\n valid nuisance learner ' + ' or '.join(valid_learner))
        return self._params[learner]

    # The private function _get_params delivers the single treatment, single (cross-fitting) sample subselection.
    # The slicing is based on the two properties self._i_treat, the index of the treatment variable, and
    # self._i_rep, the index of the cross-fitting sample.

    def _get_params(self, learner):
        return self._params[learner][self._dml_data.d_cols[self._i_treat]][self._i_rep]

    @property
    def smpls(self):
        """
        The partition used for cross-fitting.
        """
        if self._smpls is None:
            raise ValueError('sample splitting not specified\nEither draw samples via .draw_sample splitting() ' +
                             'or set external samples via .set_sample_splitting().')
        return self._smpls

    @property
    def psi(self):
        """
        Values of the score function :math:`\\psi(W; \\theta, \\eta) = \\psi_a(W; \\eta) \\theta + \\psi_b(W; \\eta)`
        after calling :meth:`fit`.

        The array is addressed internally as ``psi[:, i_rep, i_treat]``, i.e. the second axis indexes the
        repetition and the third axis the treatment variable.
        """
        return self._psi

    @property
    def psi_a(self):
        """
        Values of the score function component :math:`\\psi_a(W; \\eta)` after calling :meth:`fit`.

        The array is addressed internally as ``psi_a[:, i_rep, i_treat]``, i.e. the second axis indexes the
        repetition and the third axis the treatment variable.
        """
        return self._psi_a

    @property
    def psi_b(self):
        """
        Values of the score function component :math:`\\psi_b(W; \\eta)` after calling :meth:`fit`.

        The array is addressed internally as ``psi_b[:, i_rep, i_treat]``, i.e. the second axis indexes the
        repetition and the third axis the treatment variable.
        """
        return self._psi_b

    @property
    def coef(self):
        """
        Estimates for the causal parameter(s) after calling :meth:`fit`.

        :meth:`bootstrap` and :meth:`p_adjust` treat an all-NaN array as "not fitted yet".
        """
        return self._coef

    @coef.setter
    def coef(self, value):
        # plain assignment without validation: `value` replaces the stored estimates as-is
        self._coef = value

    @property
    def se(self):
        """
        Standard errors for the causal parameter(s) after calling :meth:`fit`.

        Used as the denominator of ``t_stat`` and to scale the confidence intervals in :meth:`confint`.
        """
        return self._se

    @se.setter
    def se(self, value):
        # plain assignment without validation: `value` replaces the stored standard errors as-is
        self._se = value

    @property
    def t_stat(self):
        """
        t-statistics for the causal parameter(s) after calling :meth:`fit`.
        """
        t_stat = self.coef / self.se
        return t_stat

    @property
    def pval(self):
        """
        p-values for the causal parameter(s) after calling :meth:`fit`.
        """
        pval = 2 * norm.cdf(-np.abs(self.t_stat))
        return pval

    @property
    def boot_coef(self):
        """
        Bootstrapped coefficients for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.

        Rows are addressed by treatment variable; the columns hold one block of ``n_rep_boot`` entries per
        repetition. :meth:`confint` and :meth:`p_adjust` treat an all-NaN array as "bootstrap not run yet".
        """
        return self._boot_coef

    @property
    def boot_t_stat(self):
        """
        Bootstrapped t-statistics for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.

        Rows are addressed by treatment variable; the columns hold one block of ``n_rep_boot`` entries per
        repetition.
        """
        return self._boot_t_stat

    @property
    def all_coef(self):
        """
        Estimates of the causal parameter(s) for the ``n_rep`` different sample splits after calling :meth:`fit`.

        Addressed internally as ``all_coef[i_treat, i_rep]``.
        """
        return self._all_coef

    @property
    def all_se(self):
        """
        Standard errors of the causal parameter(s) for the ``n_rep`` different sample splits after calling :meth:`fit`.

        Addressed internally as ``all_se[i_treat, i_rep]``.
        """
        return self._all_se

    @property
    def all_dml1_coef(self):
        """
        Estimates of the causal parameter(s) for the ``n_rep`` x ``n_folds`` different folds after calling
        :meth:`fit` with ``dml_procedure='dml1'``.

        Addressed internally as ``all_dml1_coef[i_treat, i_rep, :]``.
        """
        return self._all_dml1_coef

    @property
    def summary(self):
        """
        A summary for the estimated causal effect after calling :meth:`fit`.
        """
        col_names = ['coef', 'std err', 't', 'P>|t|']
        if self._dml_data.d_cols is None:
            df_summary = pd.DataFrame(columns=col_names)
        else:
            summary_stats = np.transpose(np.vstack(
                [self.coef, self.se,
                 self.t_stat, self.pval]))
            df_summary = pd.DataFrame(summary_stats,
                                      columns=col_names,
                                      index=self._dml_data.d_cols)
            ci = self.confint()
            df_summary = df_summary.join(ci)
        return df_summary

    # The private properties with __ always deliver the single treatment, single (cross-fitting) sample subselection.
    # The slicing is based on the two properties self._i_treat, the index of the treatment variable, and
    # self._i_rep, the index of the cross-fitting sample.

    @property
    def __smpls(self):
        # sample splitting of the repetition currently selected by self._i_rep;
        # reads self._smpls directly, so the "not specified" check of the public `smpls` property is bypassed
        return self._smpls[self._i_rep]

    @property
    def __psi(self):
        # score values for the current repetition (self._i_rep) and treatment variable (self._i_treat)
        return self._psi[:, self._i_rep, self._i_treat]

    @__psi.setter
    def __psi(self, value):
        # writes in place into the slice of self._psi selected by self._i_rep and self._i_treat
        self._psi[:, self._i_rep, self._i_treat] = value

    @property
    def __psi_a(self):
        # psi_a values for the current repetition (self._i_rep) and treatment variable (self._i_treat)
        return self._psi_a[:, self._i_rep, self._i_treat]

    @__psi_a.setter
    def __psi_a(self, value):
        # writes in place into the slice of self._psi_a selected by self._i_rep and self._i_treat
        self._psi_a[:, self._i_rep, self._i_treat] = value

    @property
    def __psi_b(self):
        # psi_b values for the current repetition (self._i_rep) and treatment variable (self._i_treat)
        return self._psi_b[:, self._i_rep, self._i_treat]

    @__psi_b.setter
    def __psi_b(self, value):
        # writes in place into the slice of self._psi_b selected by self._i_rep and self._i_treat
        self._psi_b[:, self._i_rep, self._i_treat] = value

    @property
    def __boot_coef(self):
        ind_start = self._i_rep * self.n_rep_boot
        ind_end = (self._i_rep + 1) * self.n_rep_boot
        return self._boot_coef[self._i_treat, ind_start:ind_end]

    @__boot_coef.setter
    def __boot_coef(self, value):
        ind_start = self._i_rep * self.n_rep_boot
        ind_end = (self._i_rep + 1) * self.n_rep_boot
        self._boot_coef[self._i_treat, ind_start:ind_end] = value

    @property
    def __boot_t_stat(self):
        ind_start = self._i_rep * self.n_rep_boot
        ind_end = (self._i_rep + 1) * self.n_rep_boot
        return self._boot_t_stat[self._i_treat, ind_start:ind_end]

    @__boot_t_stat.setter
    def __boot_t_stat(self, value):
        ind_start = self._i_rep * self.n_rep_boot
        ind_end = (self._i_rep + 1) * self.n_rep_boot
        self._boot_t_stat[self._i_treat, ind_start:ind_end] = value

    @property
    def __all_coef(self):
        # estimate for the current treatment variable (self._i_treat) and repetition (self._i_rep)
        return self._all_coef[self._i_treat, self._i_rep]

    @__all_coef.setter
    def __all_coef(self, value):
        # writes in place into the entry of self._all_coef selected by self._i_treat and self._i_rep
        self._all_coef[self._i_treat, self._i_rep] = value

    @property
    def __all_se(self):
        # standard error for the current treatment variable (self._i_treat) and repetition (self._i_rep)
        return self._all_se[self._i_treat, self._i_rep]

    @__all_se.setter
    def __all_se(self, value):
        # writes in place into the entry of self._all_se selected by self._i_treat and self._i_rep
        self._all_se[self._i_treat, self._i_rep] = value

    @property
    def __all_dml1_coef(self):
        # fold-wise estimates for the current treatment variable and repetition;
        # guarded by an assert because the array is only meaningful for dml_procedure 'dml1'
        assert self.dml_procedure == 'dml1', 'only available for dml_procedure `dml1`'
        return self._all_dml1_coef[self._i_treat, self._i_rep, :]

    @__all_dml1_coef.setter
    def __all_dml1_coef(self, value):
        # same guard as the getter; writes in place into the slice selected by self._i_treat and self._i_rep
        assert self.dml_procedure == 'dml1', 'only available for dml_procedure `dml1`'
        self._all_dml1_coef[self._i_treat, self._i_rep, :] = value

    def fit(self, n_jobs_cv=None, keep_scores=True):
        """
        Estimate DoubleML models.

        Parameters
        ----------
        n_jobs_cv : None or int
            The number of CPUs to use to fit the learners. ``None`` means ``1``.
            Default is ``None``.

        keep_scores : bool
            Indicates whether the score function evaluations should be stored in ``psi``, ``psi_a`` and ``psi_b``.
            Default is ``True``.

        Returns
        -------
        self : object
        """

        if n_jobs_cv is not None:
            if not isinstance(n_jobs_cv, int):
                raise TypeError('The number of CPUs used to fit the learners must be of int type. '
                                f'{str(n_jobs_cv)} of type {str(type(n_jobs_cv))} was passed.')

        if not isinstance(keep_scores, bool):
            raise TypeError('keep_scores must be True or False. '
                            f'got {str(keep_scores)}')

        for i_rep in range(self.n_rep):
            self._i_rep = i_rep
            for i_d in range(self._dml_data.n_treat):
                self._i_treat = i_d

                # this step could be skipped for the single treatment variable case
                if self._dml_data.n_treat > 1:
                    self._dml_data.set_x_d(self._dml_data.d_cols[i_d])

                # ml estimation of nuisance models and computation of score elements
                self.__psi_a, self.__psi_b = self._ml_nuisance_and_score_elements(self.__smpls, n_jobs_cv)

                # estimate the causal parameter
                self.__all_coef = self._est_causal_pars()

                # compute score (depends on estimated causal parameter)
                self._compute_score()

                # compute standard errors for causal parameter
                self.__all_se = self._se_causal_pars()

        # aggregated parameter estimates and standard errors from repeated cross-fitting
        self._agg_cross_fit()

        if not keep_scores:
            self._clean_scores()

        return self

    def bootstrap(self, method='normal', n_rep_boot=500):
        """
        Multiplier bootstrap for DoubleML models.

        Parameters
        ----------
        method : str
            A str (``'Bayes'``, ``'normal'`` or ``'wild'``) specifying the multiplier bootstrap method.
            Default is ``'normal'``

        n_rep_boot : int
            The number of bootstrap replications.

        Returns
        -------
        self : object
        """
        if np.isnan(self.coef).all():
            raise ValueError('apply fit() before bootstrap()')

        if (not isinstance(method, str)) | (method not in ['Bayes', 'normal', 'wild']):
            raise ValueError('method must be "Bayes", "normal" or "wild" '
                             f' got {str(method)}')

        if not isinstance(n_rep_boot, int):
            raise TypeError('The number of bootstrap replications must be of int type. '
                            f'{str(n_rep_boot)} of type {str(type(n_rep_boot))} was passed.')
        if n_rep_boot < 1:
            raise ValueError('The number of bootstrap replications must be positive. '
                             f'{str(n_rep_boot)} was passed.')

        self._n_rep_boot, self._boot_coef, self._boot_t_stat = self._initialize_boot_arrays(n_rep_boot)

        for i_rep in range(self.n_rep):
            self._i_rep = i_rep
            for i_d in range(self._dml_data.n_treat):
                self._i_treat = i_d

                self.__boot_coef, self.__boot_t_stat = self._compute_bootstrap(method)

        return self

    def confint(self, joint=False, level=0.95):
        """
        Confidence intervals for DoubleML models.

        Parameters
        ----------
        joint : bool
            Indicates whether joint confidence intervals are computed.
            Default is ``False``

        level : float
            The confidence level.
            Default is ``0.95``.

        Returns
        -------
        df_ci : pd.DataFrame
            A data frame with the confidence interval(s).
        """

        if not isinstance(joint, bool):
            raise TypeError('joint must be True or False. '
                            f'got {str(joint)}')

        if not isinstance(level, float):
            raise TypeError('The confidence level must be of float type. '
                            f'{str(level)} of type {str(type(level))} was passed.')
        if (level <= 0) | (level >= 1):
            raise ValueError('The confidence level must be in (0,1). '
                             f'{str(level)} was passed.')

        a = (1 - level)
        ab = np.array([a / 2, 1. - a / 2])
        if joint:
            if np.isnan(self.boot_coef).all():
                raise ValueError(f'apply fit() & bootstrap() before confint(joint=True)')
            sim = np.amax(np.abs(self.boot_t_stat), 0)
            hatc = np.quantile(sim, 1 - a)
            ci = np.vstack((self.coef - self.se * hatc, self.coef + self.se * hatc)).T
        else:
            fac = norm.ppf(ab)
            ci = np.vstack((self.coef + self.se * fac[0], self.coef + self.se * fac[1])).T

        df_ci = pd.DataFrame(ci,
                             columns=['{:.1f} %'.format(i * 100) for i in ab],
                             index=self._dml_data.d_cols)
        return df_ci

    def p_adjust(self, method='romano-wolf'):
        """
        Multiple testing adjustment for DoubleML models.

        Parameters
        ----------
        method : str
            A str (``'romano-wolf''``, ``'bonferroni'``, ``'holm'``, etc) specifying the adjustment method.
            In addition to ``'romano-wolf''``, all methods implemented in
            :py:func:`statsmodels.stats.multitest.multipletests` can be applied.
            Default is ``'romano-wolf'``.

        Returns
        -------
        p_val : np.array
            An array of adjusted p-values.
        """
        if np.isnan(self.coef).all():
            raise ValueError('apply fit() before p_adjust()')

        if not isinstance(method, str):
            raise TypeError('The p_adjust method must be of str type. '
                            f'{str(method)} of type {str(type(method))} was passed.')

        if method.lower() in ['rw', 'romano-wolf']:
            if np.isnan(self.boot_coef).all():
                raise ValueError(f'apply fit() & bootstrap() before p_adjust("{method}")')

            pinit = np.full_like(self.pval, np.nan)
            p_val_corrected = np.full_like(self.pval, np.nan)

            boot_t_stats = self.boot_t_stat
            t_stat = self.t_stat
            stepdown_ind = np.argsort(t_stat)[::-1]
            ro = np.argsort(stepdown_ind)

            for i_d in range(self._dml_data.n_treat):
                if i_d == 0:
                    sim = np.max(boot_t_stats, axis=0)
                    pinit[i_d] = np.minimum(1, np.mean(sim >= np.abs(t_stat[stepdown_ind][i_d])))
                else:
                    sim = np.max(np.delete(boot_t_stats, stepdown_ind[:i_d], axis=0),
                                 axis=0)
                    pinit[i_d] = np.minimum(1, np.mean(sim >= np.abs(t_stat[stepdown_ind][i_d])))

            for i_d in range(self._dml_data.n_treat):
                if i_d == 0:
                    p_val_corrected[i_d] = pinit[i_d]
                else:
                    p_val_corrected[i_d] = np.maximum(pinit[i_d], p_val_corrected[i_d - 1])

            p_val = p_val_corrected[ro]
        else:
            _, p_val, _, _ = multipletests(self.pval, method=method)

        p_val = pd.DataFrame(np.vstack((self.coef, p_val)).T,
                             columns=['coef', 'pval'],
                             index=self._dml_data.d_cols)

        return p_val

    def tune(self,
             param_grids,
             tune_on_folds=False,
             scoring_methods=None,  # if None the estimator's score method is used
             n_folds_tune=5,
             search_mode='grid_search',
             n_iter_randomized_search=100,
             n_jobs_cv=None,
             set_as_params=True,
             return_tune_res=False):
        """
        Hyperparameter-tuning for DoubleML models.

        The hyperparameter-tuning is performed using either an exhaustive search over specified parameter values
        implemented in :class:`sklearn.model_selection.GridSearchCV` or via a randomized search implemented in
        :class:`sklearn.model_selection.RandomizedSearchCV`.

        Parameters
        ----------
        param_grids : dict
            A dict with a parameter grid for each nuisance model / learner (see attribute ``learner_names``).

        tune_on_folds : bool
            Indicates whether the tuning should be done fold-specific or globally.
            Default is ``False``.

        scoring_methods : None or dict
            The scoring method used to evaluate the predictions. The scoring method must be set per nuisance model via
            a dict (see attribute ``learner_names`` for the keys).
            If None, the estimator’s score method is used.
            Default is ``None``.

        n_folds_tune : int
            Number of folds used for tuning.
            Default is ``5``.

        search_mode : str
            A str (``'grid_search'`` or ``'randomized_search'``) specifying whether hyperparameters are optimized via
            :class:`sklearn.model_selection.GridSearchCV` or :class:`sklearn.model_selection.RandomizedSearchCV`.
            Default is ``'grid_search'``.

        n_iter_randomized_search : int
            If ``search_mode == 'randomized_search'``. The number of parameter settings that are sampled.
            Default is ``100``.

        n_jobs_cv : None or int
            The number of CPUs to use to tune the learners. ``None`` means ``1``.
            Default is ``None``.

        set_as_params : bool
            Indicates whether the hyperparameters should be set in order to be used when :meth:`fit` is called.
            Default is ``True``.

        return_tune_res : bool
            Indicates whether detailed tuning results should be returned.
            Default is ``False``.

        Returns
        -------
        self : object
            Returned if ``return_tune_res`` is ``False``.

        tune_res: list
            A list containing detailed tuning results and the proposed hyperparameters.
            Returned if ``return_tune_res`` is ``False``.
        """

        if (not isinstance(param_grids, dict)) | (not all(k in param_grids for k in self.learner_names)):
            raise ValueError('invalid param_grids ' + str(param_grids) +
                             '\n param_grids must be a dictionary with keys ' + ' and '.join(self.learner_names))

        if scoring_methods is not None:
            if (not isinstance(scoring_methods, dict)) | (not all(k in self.learner_names for k in scoring_methods)):
                raise ValueError('invalid scoring_methods ' + str(scoring_methods) +
                                 '\n scoring_methods must be a dictionary.' +
                                 '\n Valid keys are ' + ' and '.join(self.learner_names))
            if not all(k in scoring_methods for k in self.learner_names):
                # if there are learners for which no scoring_method was set, we fall back to None, i.e., default scoring
                for learner in self.learner_names:
                    if learner not in scoring_methods:
                        scoring_methods[learner] = None

        if not isinstance(tune_on_folds, bool):
            raise TypeError('tune_on_folds must be True or False. '
                            f'got {str(tune_on_folds)}')

        if not isinstance(n_folds_tune, int):
            raise TypeError('The number of folds used for tuning must be of int type. '
                            f'{str(n_folds_tune)} of type {str(type(n_folds_tune))} was passed.')
        if n_folds_tune < 2:
            raise ValueError('The number of folds used for tuning must be at least two. '
                             f'{str(n_folds_tune)} was passed.')

        if (not isinstance(search_mode, str)) | (search_mode not in ['grid_search', 'randomized_search']):
            raise ValueError('search_mode must be "grid_search" or "randomized_search" '
                             f' got {str(search_mode)}')

        if not isinstance(n_iter_randomized_search, int):
            raise TypeError('The number of parameter settings sampled for the randomized search must be of int type. '
                            f'{str(n_iter_randomized_search)} of type {str(type(n_iter_randomized_search))} was passed.')
        if n_iter_randomized_search < 2:
            raise ValueError('The number of parameter settings sampled for the randomized search must be at least two. '
                             f'{str(n_iter_randomized_search)} was passed.')

        if n_jobs_cv is not None:
            if not isinstance(n_jobs_cv, int):
                raise TypeError('The number of CPUs used to fit the learners must be of int type. '
                                f'{str(n_jobs_cv)} of type {str(type(n_jobs_cv))} was passed.')

        if not isinstance(set_as_params, bool):
            raise TypeError('set_as_params must be True or False. '
                            f'got {str(set_as_params)}')

        if not isinstance(return_tune_res, bool):
            raise TypeError('return_tune_res must be True or False. '
                            f'got {str(return_tune_res)}')

        if tune_on_folds:
            tuning_res = [[None] * self.n_rep] * self._dml_data.n_treat
        else:
            tuning_res = [None] * self._dml_data.n_treat

        for i_d in range(self._dml_data.n_treat):
            self._i_treat = i_d
            # this step could be skipped for the single treatment variable case
            if self._dml_data.n_treat > 1:
                self._dml_data.set_x_d(self._dml_data.d_cols[i_d])

            if tune_on_folds:
                nuisance_params = list()
                for i_rep in range(self.n_rep):
                    self._i_rep = i_rep

                    # tune hyperparameters
                    res = self._ml_nuisance_tuning(self.__smpls,
                                                   param_grids, scoring_methods,
                                                   n_folds_tune,
                                                   n_jobs_cv,
                                                   search_mode, n_iter_randomized_search)

                    tuning_res[i_rep][i_d] = res
                    nuisance_params.append(res['params'])

                if set_as_params:
                    for nuisance_model in nuisance_params[0].keys():
                        params = [x[nuisance_model] for x in nuisance_params]
                        self.set_ml_nuisance_params(nuisance_model, self._dml_data.d_cols[i_d], params)

            else:
                smpls = [(np.arange(self._dml_data.n_obs), np.arange(self._dml_data.n_obs))]
                # tune hyperparameters
                res = self._ml_nuisance_tuning(smpls,
                                               param_grids, scoring_methods,
                                               n_folds_tune,
                                               n_jobs_cv,
                                               search_mode, n_iter_randomized_search)
                tuning_res[i_d] = res

                if set_as_params:
                    for nuisance_model in res['params'].keys():
                        params = res['params'][nuisance_model]
                        self.set_ml_nuisance_params(nuisance_model, self._dml_data.d_cols[i_d], params[0])

        if return_tune_res:
            return tuning_res
        else:
            return self

    def set_ml_nuisance_params(self, learner, treat_var, params):
        """
        Set hyperparameters for the nuisance models of DoubleML models.

        Parameters
        ----------
        learner : str
            The nuisance model / learner (see attribute ``params_names``).

        treat_var : str
            The treatment variable (hyperparameters can be set treatment-variable specific).

        params : dict or list
            A dict with estimator parameters (used for all folds) or a nested list with fold specific parameters. The
            outer list needs to be of length ``n_rep`` and the inner list of length ``n_folds``.

        Returns
        -------
        self : object
        """
        valid_learner = self.params_names
        if learner not in valid_learner:
            raise ValueError('invalid nuisance learner ' + learner +
                             '\n valid nuisance learner ' + ' or '.join(valid_learner))

        if treat_var not in self._dml_data.d_cols:
            raise ValueError('invalid treatment variable' + treat_var +
                             '\n valid treatment variable ' + ' or '.join(self._dml_data.d_cols))

        if isinstance(params, dict):
            if self.apply_cross_fitting:
                all_params = [[params] * self.n_folds] * self.n_rep
            else:
                all_params = [[params] * 1] * self.n_rep
        else:
            assert len(params) == self.n_rep
            if self.apply_cross_fitting:
                assert np.all(np.array([len(x) for x in params]) == self.n_folds)
            else:
                assert np.all(np.array([len(x) for x in params]) == 1)
            all_params = params

        self._params[learner][treat_var] = all_params

        return self

    @abstractmethod
    def _initialize_ml_nuisance_params(self):
        pass

    @abstractmethod
    def _check_score(self, score):
        pass

    @abstractmethod
    def _check_data(self, obj_dml_data):
        pass

    @abstractmethod
    def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv):
        pass

    @abstractmethod
    def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
                            search_mode, n_iter_randomized_search):
        pass

    @staticmethod
    def _check_learner(learner, learner_name, classifier=False):
        err_msg_prefix = f'invalid learner provided for {learner_name}: '
        warn_msg_prefix = f'learner provided for {learner_name} is probably invalid: '

        if isinstance(learner, type):
            raise TypeError(err_msg_prefix + f'provide an instance of a learner instead of a class')

        if not hasattr(learner, 'fit'):
            raise TypeError(err_msg_prefix + f'{str(learner)} has no method .fit()')
        if not hasattr(learner, 'set_params'):
            raise TypeError(err_msg_prefix + f'{str(learner)} has no method .set_params()')
        if not hasattr(learner, 'get_params'):
            raise TypeError(err_msg_prefix + f'{str(learner)} has no method .get_params()')

        if classifier:
            if not hasattr(learner, 'predict_proba'):
                raise TypeError(err_msg_prefix + f'{str(learner)} has no method .predict_proba()')
            if not is_classifier(learner):
                warnings.warn(warn_msg_prefix + f'{str(learner)} is (probably) no classifier')
        else:
            if not hasattr(learner, 'predict'):
                raise TypeError(err_msg_prefix + f'{str(learner)} has no method .predict()')
            if not is_regressor(learner):
                warnings.warn(warn_msg_prefix + f'{str(learner)} is (probably) no regressor')

        return learner

    def _initialize_arrays(self):
        psi = np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_treat), np.nan)
        psi_a = np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_treat), np.nan)
        psi_b = np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_treat), np.nan)

        coef = np.full(self._dml_data.n_treat, np.nan)
        se = np.full(self._dml_data.n_treat, np.nan)

        all_coef = np.full((self._dml_data.n_treat, self.n_rep), np.nan)
        all_se = np.full((self._dml_data.n_treat, self.n_rep), np.nan)

        if self.dml_procedure == 'dml1':
            if self.apply_cross_fitting:
                all_dml1_coef = np.full((self._dml_data.n_treat, self.n_rep, self.n_folds), np.nan)
            else:
                all_dml1_coef = np.full((self._dml_data.n_treat, self.n_rep, 1), np.nan)
        else:
            all_dml1_coef = None

        return psi, psi_a, psi_b, coef, se, all_coef, all_se, all_dml1_coef

    def _initialize_boot_arrays(self, n_rep_boot):
        boot_coef = np.full((self._dml_data.n_treat, n_rep_boot * self.n_rep), np.nan)
        boot_t_stat = np.full((self._dml_data.n_treat, n_rep_boot * self.n_rep), np.nan)
        return n_rep_boot, boot_coef, boot_t_stat

    def draw_sample_splitting(self):
        """
        Draw sample splitting for DoubleML models.

        A ``DoubleMLResampling`` object is set up from the attributes ``n_folds``, ``n_rep`` and
        ``apply_cross_fitting`` and the number of observations of the data; the samples it provides are
        stored on the model.

        Returns
        -------
        self : object
        """
        resampling_kwargs = {'n_folds': self.n_folds,
                             'n_rep': self.n_rep,
                             'n_obs': self._dml_data.n_obs,
                             'apply_cross_fitting': self.apply_cross_fitting}
        self._smpls = DoubleMLResampling(**resampling_kwargs).split_samples()

        return self

    def set_sample_splitting(self, all_smpls):
        """
        Set the sample splitting for DoubleML models.

        The  attributes ``n_folds`` and ``n_rep`` are derived from the provided partition.

        Parameters
        ----------
        all_smpls : list or tuple
            If nested list of lists of tuples:
                The outer list needs to provide an entry per repeated sample splitting (length of list is set as
                ``n_rep``).
                The inner list needs to provide a tuple (train_ind, test_ind) per fold (length of list is set as
                ``n_folds``). If tuples for more than one fold are provided, it must form a partition and
                ``apply_cross_fitting`` is set to True. Otherwise ``apply_cross_fitting`` is set to False and
                ``n_folds=2``.
            If list of tuples:
                The list needs to provide a tuple (train_ind, test_ind) per fold (length of list is set as
                ``n_folds``). If tuples for more than one fold are provided, it must form a partition and
                ``apply_cross_fitting`` is set to True. Otherwise ``apply_cross_fitting`` is set to False and
                ``n_folds=2``.
                ``n_rep=1`` is always set.
            If tuple:
                Must be a tuple with two elements train_ind and test_ind. No sample splitting is achieved if train_ind
                and test_ind are range(n_rep). Otherwise ``n_folds=2``.
                ``apply_cross_fitting=False`` and ``n_rep=1`` is always set.

        Returns
        -------
        self : object

        Examples
        --------
        >>> import numpy as np
        >>> import doubleml as dml
        >>> from doubleml.datasets import make_plr_CCDDHNR2018
        >>> from sklearn.ensemble import RandomForestRegressor
        >>> from sklearn.base import clone
        >>> np.random.seed(3141)
        >>> learner = RandomForestRegressor(max_depth=2, n_estimators=10)
        >>> ml_g = learner
        >>> ml_m = learner
        >>> obj_dml_data = make_plr_CCDDHNR2018(n_obs=10, alpha=0.5)
        >>> dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m)
        >>> # simple sample splitting with two folds and without cross-fitting
        >>> smpls = ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
        >>> dml_plr_obj.set_sample_splitting(smpls)
        >>> # sample splitting with two folds and cross-fitting
        >>> smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
        >>>          ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])]
        >>> dml_plr_obj.set_sample_splitting(smpls)
        >>> # sample splitting with two folds and repeated cross-fitting with n_rep = 2
        >>> smpls = [[([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
        >>>           ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
        >>>          [([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
        >>>           ([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
        >>> dml_plr_obj.set_sample_splitting(smpls)
        """
        if isinstance(all_smpls, tuple):
            if not len(all_smpls) == 2:
                raise ValueError('Invalid partition provided. '
                                 'Tuple for train_ind and test_ind must consist of exactly two elements.')
            if (_check_is_partition([all_smpls], self._dml_data.n_obs) &
                    _check_is_partition([(all_smpls[1], all_smpls[0])], self._dml_data.n_obs)):
                self._n_rep = 1
                self._n_folds = 1
                self._apply_cross_fitting = False
                self._smpls = [[all_smpls]]
            else:
                self._n_rep = 1
                self._n_folds = 2
                self._apply_cross_fitting = False
                self._smpls = _check_all_smpls([[all_smpls]], self._dml_data.n_obs)
        else:
            if not isinstance(all_smpls, list):
                raise TypeError('all_smpls must be of list or tuple type. '
                                f'{str(all_smpls)} of type {str(type(all_smpls))} was passed.')
            all_tuple = all([isinstance(tpl, tuple) for tpl in all_smpls])
            if all_tuple:
                if not all([len(tpl) == 2 for tpl in all_smpls]):
                    raise ValueError('Invalid partition provided. '
                                     'All tuples for train_ind and test_ind must consist of exactly two elements.')
                self._n_rep = 1
                if _check_is_partition(all_smpls, self._dml_data.n_obs):
                    self._n_folds = len(all_smpls)
                    self._apply_cross_fitting = True
                    self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs)
                else:
                    if not len(all_smpls) == 1:
                        raise ValueError('Invalid partition provided. '
                                         'Tuples for more than one fold provided that don\'t form a partition.')
                    self._n_folds = 2
                    self._apply_cross_fitting = False
                    self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs)
            else:
                all_list = all([isinstance(smpl, list) for smpl in all_smpls])
                if not all_list:
                    raise ValueError('Invalid partition provided. '
                                     'all_smpls is a list where neither all elements are tuples '
                                     'nor all elements are lists.')
                all_tuple = all([all([isinstance(tpl, tuple) for tpl in smpl]) for smpl in all_smpls])
                if not all_tuple:
                    raise TypeError('For repeated sample splitting all_smpls must be list of lists of tuples.')
                all_pairs = all([all([len(tpl) == 2 for tpl in smpl]) for smpl in all_smpls])
                if not all_pairs:
                    raise ValueError('Invalid partition provided. '
                                     'All tuples for train_ind and test_ind must consist of exactly two elements.')
                n_folds_each_smpl = np.array([len(smpl) for smpl in all_smpls])
                if not np.all(n_folds_each_smpl == n_folds_each_smpl[0]):
                    raise ValueError('Invalid partition provided. '
                                     'Different number of folds for repeated sample splitting.')
                smpls_are_partitions = [_check_is_partition(smpl, self._dml_data.n_obs) for smpl in all_smpls]

                if all(smpls_are_partitions):
                    self._n_rep = len(all_smpls)
                    self._n_folds = n_folds_each_smpl[0]
                    self._apply_cross_fitting = True
                    self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)
                else:
                    if not n_folds_each_smpl[0] == 1:
                        raise ValueError('Invalid partition provided. '
                                         'Tuples for more than one fold provided '
                                         'but at least one does not form a partition.')
                    self._n_rep = len(all_smpls)
                    self._n_folds = 2
                    self._apply_cross_fitting = False
                    self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)

        self._psi, self._psi_a, self._psi_b, \
            self._coef, self._se, self._all_coef, self._all_se, self._all_dml1_coef = self._initialize_arrays()
        self._initialize_ml_nuisance_params()

        return self

    def _est_causal_pars(self):
        dml_procedure = self.dml_procedure
        smpls = self.__smpls

        if dml_procedure == 'dml1':
            # Note that len(smpls) is only not equal to self.n_folds if self.apply_cross_fitting = False
            thetas = np.zeros(len(smpls))
            for idx, (train_index, test_index) in enumerate(smpls):
                thetas[idx] = self._orth_est(test_index)
            theta_hat = np.mean(thetas)
            coef = theta_hat

            self.__all_dml1_coef = thetas

        elif dml_procedure == 'dml2':
            theta_hat = self._orth_est()
            coef = theta_hat

        else:
            raise ValueError('invalid dml_procedure')

        return coef

    def _se_causal_pars(self):
        if self.apply_cross_fitting:
            se = np.sqrt(self._var_est())
        else:
            # In case of no-cross-fitting, the score function was only evaluated on the test data set
            smpls = self.__smpls
            test_index = smpls[0][1]
            se = np.sqrt(self._var_est(test_index))

        return se

    def _agg_cross_fit(self):
        # aggregate parameters from the repeated cross-fitting
        # don't use the getter (always for one treatment variable and one sample), but the private variable
        self.coef = np.median(self._all_coef, 1)

        # TODO: In the documentation of standard errors we need to cleary state what we return here, i.e.,
        # the asymptotic variance sigma_hat/N and not sigma_hat (which sometimes is also called the asympt var)!
        if self.apply_cross_fitting:
            n_obs = self._dml_data.n_obs
        else:
            # be prepared for the case of test sets of different size in repeated no-cross-fitting
            smpls = self.__smpls
            test_index = smpls[0][1]
            n_obs = len(test_index)
        xx = np.tile(self.coef.reshape(-1, 1), self.n_rep)
        self.se = np.sqrt(np.divide(np.median(np.multiply(np.power(self._all_se, 2), n_obs) +
                                              np.power(self._all_coef - xx, 2), 1), n_obs))

    def _compute_bootstrap(self, method):
        dml_procedure = self.dml_procedure
        smpls = self.__smpls
        if self.apply_cross_fitting:
            n_obs = self._dml_data.n_obs
        else:
            # be prepared for the case of test sets of different size in repeated no-cross-fitting
            test_index = smpls[0][1]
            n_obs = len(test_index)

        if method == 'Bayes':
            weights = np.random.exponential(scale=1.0, size=(self.n_rep_boot, n_obs)) - 1.
        elif method == 'normal':
            weights = np.random.normal(loc=0.0, scale=1.0, size=(self.n_rep_boot, n_obs))
        elif method == 'wild':
            xx = np.random.normal(loc=0.0, scale=1.0, size=(self.n_rep_boot, n_obs))
            yy = np.random.normal(loc=0.0, scale=1.0, size=(self.n_rep_boot, n_obs))
            weights = xx / np.sqrt(2) + (np.power(yy, 2) - 1) / 2
        else:
            raise ValueError('invalid boot method')

        if self.apply_cross_fitting:
            if dml_procedure == 'dml1':
                boot_coefs = np.full((self.n_rep_boot, self.n_folds), np.nan)
                boot_t_stats = np.full((self.n_rep_boot, self.n_folds), np.nan)
                for idx, (_, test_index) in enumerate(smpls):
                    J = np.mean(self.__psi_a[test_index])
                    boot_coefs[:, idx] = np.matmul(weights[:, test_index], self.__psi[test_index]) / (
                            len(test_index) * J)
                    boot_t_stats[:, idx] = np.matmul(weights[:, test_index], self.__psi[test_index]) / (
                            len(test_index) * self.__all_se * J)
                boot_coef = np.mean(boot_coefs, axis=1)
                boot_t_stat = np.mean(boot_t_stats, axis=1)

            elif dml_procedure == 'dml2':
                J = np.mean(self.__psi_a)
                boot_coef = np.matmul(weights, self.__psi) / (self._dml_data.n_obs * J)
                boot_t_stat = np.matmul(weights, self.__psi) / (self._dml_data.n_obs * self.__all_se * J)

            else:
                raise ValueError('invalid dml_procedure')
        else:
            J = np.mean(self.__psi_a[test_index])
            boot_coef = np.matmul(weights, self.__psi[test_index]) / (len(test_index) * J)
            boot_t_stat = np.matmul(weights, self.__psi[test_index]) / (len(test_index) * self.__all_se * J)

        return boot_coef, boot_t_stat

    def _var_est(self, inds=None):
        """
        Estimate the standard errors of the structural parameter
        """
        psi_a = self.__psi_a
        psi = self.__psi

        if inds is not None:
            assert not self.apply_cross_fitting
            psi_a = psi_a[inds]
            psi = psi[inds]
            n_obs = len(inds)
        else:
            assert self.apply_cross_fitting
            n_obs = self._dml_data.n_obs

        # TODO: In the documentation of standard errors we need to cleary state what we return here, i.e.,
        # the asymptotic variance sigma_hat/N and not sigma_hat (which sometimes is also called the asympt var)!
        J = np.mean(psi_a)
        sigma2_hat = 1 / n_obs * np.mean(np.power(psi, 2)) / np.power(J, 2)

        return sigma2_hat

    def _orth_est(self, inds=None):
        """
        Estimate the structural parameter
        """
        psi_a = self.__psi_a
        psi_b = self.__psi_b

        if inds is not None:
            psi_a = psi_a[inds]
            psi_b = psi_b[inds]

        theta = -np.mean(psi_b) / np.mean(psi_a)

        return theta

    def _compute_score(self):
        self.__psi = self.__psi_a * self.__all_coef + self.__psi_b

    def _clean_scores(self):
        del self._psi
        del self._psi_a
        del self._psi_b
