import woodwork as ww


def _downcast_nullable_X(X, handle_boolean_nullable=True, handle_integer_nullable=True):
    """Replace Pandas nullable integer/boolean dtypes in a DataFrame via Woodwork logical type changes.

    Every column whose Woodwork logical type is one of the selected nullable types is switched
    to the non-nullable counterpart chosen by ``_determine_downcast_type``.

    Args:
        X (pd.DataFrame): Input data of shape [n_samples, n_features] whose nullable types will be changed.
            If Woodwork has not been initialized on X yet, ``X.ww.init()`` is called on it here.
        handle_boolean_nullable (bool, optional): Whether or not to downcast data with BooleanNullable logical types.
        handle_integer_nullable (bool, optional): Whether or not to downcast data with IntegerNullable or AgeNullable logical types.

    Returns:
        The same X object (its Woodwork schema updated via ``X.ww.set_types``), with any incompatible
        nullable types downcast to compatible equivalents.
    """
    if X.ww.schema is None:
        X.ww.init()

    targets = X.ww.select(
        _get_incompatible_nullable_types(handle_boolean_nullable, handle_integer_nullable),
    )
    column_names = list(targets.columns)
    # Nothing selected means there is nothing to downcast; hand X back untouched.
    if len(column_names) == 0:
        return X

    replacement_types = {}
    for name in column_names:
        replacement_types[name] = _determine_downcast_type(targets.ww[name])

    X.ww.set_types(logical_types=replacement_types)
    return X


def _downcast_nullable_y(y, handle_boolean_nullable=True, handle_integer_nullable=True):
    """Replace a Pandas nullable integer/boolean dtype in a target series via Woodwork logical types.

    Args:
        y (pd.Series): Target data of shape [n_samples] whose nullable types will be changed.
            If Woodwork has not been initialized on y, the series produced by ``ww.init_series(y)``
            is used instead.
        handle_boolean_nullable (bool, optional): Whether or not to downcast data with BooleanNullable logical types.
        handle_integer_nullable (bool, optional): Whether or not to downcast data with IntegerNullable or AgeNullable logical types.

    Returns:
        The result of ``set_logical_type`` with the downcast type when y's logical type is one of the
        selected nullable types; otherwise the (Woodwork-initialized) series unchanged.
    """
    series = y if y.ww.schema is not None else ww.init_series(y)

    nullable_classes = tuple(
        _get_incompatible_nullable_types(handle_boolean_nullable, handle_integer_nullable),
    )
    # Guard clause: a compatible logical type needs no conversion.
    if not isinstance(series.ww.logical_type, nullable_classes):
        return series

    replacement = _determine_downcast_type(series)
    return series.ww.set_logical_type(replacement)


def _get_incompatible_nullable_types(handle_boolean_nullable, handle_integer_nullable):
    """Collect the Woodwork nullable logical types that should be treated as incompatible.

    Args:
        handle_boolean_nullable (bool): Whether boolean nullable logical types are incompatible.
        handle_integer_nullable (bool): Whether integer nullable logical types are incompatible.

    Returns:
        list[ww.LogicalType]: BooleanNullable (when requested) followed by IntegerNullable and
        AgeNullable (when requested); an empty list when neither flag is set.
    """
    boolean_types = [ww.logical_types.BooleanNullable] if handle_boolean_nullable else []
    integer_types = (
        [ww.logical_types.IntegerNullable, ww.logical_types.AgeNullable]
        if handle_integer_nullable
        else []
    )
    return boolean_types + integer_types


def _determine_downcast_type(col):
    """Pick the non-nullable logical type that should replace a nullable one.

    The choice depends on the column's current logical type and on whether it contains nans:
        - BooleanNullable becomes Boolean without nans and Categorical with nans.
        - IntegerNullable becomes Integer without nans and Double with nans.
        - AgeNullable becomes Age without nans and AgeFractional with nans.

    Args:
        col (Woodwork Series): The data whose downcast logical type we are determining by inspecting
            its current logical type and whether nans are present.

    Returns:
        str: Name of the logical type to be used to downcast the incompatible nullable logical type.

    Raises:
        KeyError: If the column's logical type is not one of the three nullable types listed above.
    """
    # Outer key: current logical type name; inner key: whether any nulls are present.
    replacements = {
        "BooleanNullable": {False: "Boolean", True: "Categorical"},
        "IntegerNullable": {False: "Integer", True: "Double"},
        "AgeNullable": {False: "Age", True: "AgeFractional"},
    }

    options = replacements[str(col.ww.logical_type)]
    return options[bool(col.isnull().any())]
