# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_api.ipynb.

# %% auto 0
__all__ = ['suspect']

# %% ../nbs/10_api.ipynb 2
import numpy as np
import pandas as pd

from plum import dispatch
from loguru import logger

# %% ../nbs/10_api.ipynb 13
@dispatch
def suspect(
    probas: np.ndarray,
    *,
    labels: np.ndarray,
) -> pd.DataFrame:
    """Rank the suspicious labels given probas from a classifier.
    Accept Numpy arrays, Pandas dataframes and series, and normal Python lists.
    We can use integer, string or even float labels, given that
    the probability matrix's columns are indexed by the same label set.

    #### Args

    - probas (n x m matrix): probabilities for possible classes.

    #### KwArgs

    - labels (n x 1 vector): observed class labels. Only `probas` takes part
      in dispatch, so list-like labels are converted to a Numpy array here.

    #### Returns

    a Pandas DataFrame with 1 index and 2 columns:

    - id (int): the index, which is the same as the original data row index
    - err (float): the magnitude of suspiciousness, valued between [0, 1]
    - suspected (bool): whether the data row is suspected as having a label error.

    #### Raises

    - ValueError: if `labels` and `probas` differ in length and `probas`
      cannot be reshaped to `len(labels)` rows (logged, then re-raised).
    """
    # Keyword arguments are not used for plum dispatch, so `labels` may be a
    # list/tuple even though `probas` is an array; `.shape` below needs one.
    if not isinstance(labels, (np.ndarray, pd.Series)):
        labels = np.asarray(labels)

    if len(labels) != len(probas):
        logger.debug(
            f"""
            Trying to reshape probas"""
        )
        try:
            # e.g. a flat vector of n*m probabilities becomes an (n, m) matrix.
            probas = probas.reshape((len(labels), -1))
        except ValueError:
            logger.error(
                f"Labels and probas MUST have same length, BUT {len(labels)} != {len(probas)}"
            )
            raise

    logger.debug(f"Shape of labels and probas: {labels.shape} vs {probas.shape}")

    # Imported lazily and aliased so it does not shadow this public `suspect`.
    # NOTE(review): presumably lazy to avoid an import cycle — confirm.
    from lapros.alg import suspect as _alg_suspect

    ranks = _alg_suspect(probas=probas, labels=labels)

    return ranks


# show_doc(suspect)

# %% ../nbs/10_api.ipynb 23
def validate_labels_cols_matching(
    *,
    labels: pd.Series,
    cols: pd.Index,
) -> bool:
    """Check that the labels and the probability columns describe the same classes.

    #### KwArgs

    - labels: observed class labels.
    - cols: column index of the probability matrix.

    #### Returns

    True when every unique label appears in `cols` and every entry of `cols`
    appears among the labels, i.e. their symmetric difference is empty.
    """
    unique_labels = labels.unique()
    logger.debug(f"Unique lables {unique_labels}")
    only_in_one = cols.symmetric_difference(unique_labels)
    logger.debug(f"Cols {cols}")
    logger.debug(f"diff_labs {only_in_one}")
    return len(only_in_one) == 0

# %% ../nbs/10_api.ipynb 26
def to_numpy(
    *,
    labels: pd.Series,
    cols: pd.Index,
) -> np.ndarray:
    """Convert a Pandas series of labels to a Numpy array of integer column positions.

    Each label is replaced by the position of the equal entry in `cols`, so
    the result indexes the columns of a probability matrix whose column index
    is `cols`.

    #### KwArgs

    - labels: observed class labels (e.g. strings or integers).
    - cols: column index of the probability matrix; its entries must be
      exactly the unique labels.

    #### Returns

    a Numpy integer array with one column position per label.

    #### Raises

    - ValueError: if the unique labels and `cols` do not match.
    """
    if not validate_labels_cols_matching(labels=labels, cols=cols):
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError("Labels and columns not matching")

    # `Index.is_integer()` is deprecated since pandas 2.0; it was defined as
    # `inferred_type == "integer"`, so that check is used directly.
    is_int_cols = cols.inferred_type == "integer"
    if is_int_cols and np.array_equal(cols.to_numpy(), np.arange(len(cols))):
        # Columns are exactly 0..m-1, so each label already equals its position.
        logger.debug("Cols are integer. The labels should already be, too.")
        return labels.astype(int).to_numpy()

    # General case (strings, floats, or integer columns other than 0..m-1):
    # map each label to the position of its column, as with string labels.
    lab2int = {col: i for i, col in enumerate(cols)}
    logger.debug(f"lab2int mapping {lab2int}")
    try:
        # dtype=int keeps the result integer even when `labels` is empty.
        int_labels = np.array([lab2int[lab] for lab in labels.values], dtype=int)
    except (KeyError, TypeError):
        logger.error("Can not convert labels from strings to integers")
        raise
    logger.debug(f"labels converted to integers {int_labels}")
    return int_labels

# %% ../nbs/10_api.ipynb 27
@logger.catch(reraise=True)
@dispatch
def suspect(
    probas: pd.DataFrame,
    *,
    labels: pd.Series,
) -> pd.DataFrame:
    """Rank suspicious labels given a Pandas probability matrix and label series.

    The columns of `probas` name the classes and `labels` holds one observed
    label per row of `probas`. Labels are converted to column positions with
    `to_numpy`, and the ranking is delegated to the Numpy overload of `suspect`.

    #### Raises

    - ValueError: if `labels` and `probas` differ in length, or if the unique
      labels do not match the columns of `probas`.
    """
    # NOTE(review): `logger.catch` wraps the plum function object bound to
    # `suspect` at this point, not this method; since a later `@dispatch def
    # suspect` rebinds the module name, this wrapper is presumably bypassed
    # for callers of the final `suspect` — confirm before relying on it.

    # Keyword arguments are not used for dispatch, so accept any list-like labels.
    if not isinstance(labels, pd.Series):
        labels = pd.Series(labels)
    logger.debug(
        f"Pandas series labels and dataframe probas must have same length {len(labels)} vs {len(probas)}"
    )
    # The Numpy overload would try to reshape mismatched inputs, which would
    # scramble the class columns of a DataFrame, so fail early instead.
    if len(labels) != len(probas):
        raise ValueError(
            f"Labels and probas MUST have same length, BUT {len(labels)} != {len(probas)}"
        )
    labels = to_numpy(labels=labels, cols=probas.columns)
    # Resolves the module-level `suspect` at call time -> Numpy overload.
    ranks = suspect(
        probas.to_numpy(),
        labels=labels,
    )
    return ranks

# %% ../nbs/10_api.ipynb 40
@dispatch
def suspect(
    probas: list,
    *,
    labels: list,
) -> pd.DataFrame:
    """Handle plain Python lists by converting both inputs to Numpy arrays.

    The converted arrays are handed to the Numpy overload of `suspect`, whose
    result is returned unchanged.
    """
    logger.debug("Normal Python lists")
    probas_arr = np.array(probas)
    labels_arr = np.array(labels)
    return suspect(probas_arr, labels=labels_arr)
