# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20-alg.ipynb.

# %% auto 0
__all__ = ['avg_confidence', 'find_likely_label', 'rank_suspicious', 'suspect']

# %% ../nbs/20-alg.ipynb 2
import numpy as np
import pandas as pd
from lapros.util import set_logger

# Module-level logger; the functions below emit their intermediate results through it at debug level.
logger = set_logger();

# %% ../nbs/20-alg.ipynb 15
def avg_confidence(
    *,
    ps: np.ndarray,
    ys: np.ndarray,
):
    """Return, for every class, the mean confidence the model assigns to it.

    ## Params:

    - ps: Score matrix indexed as `ps[sample, class]`; the number of classes is `ps.shape[1]`.
    - ys: Label of each sample, compared against the class indices `0 .. ps.shape[1] - 1`.

    ## Returns:

    A float array with one entry per class: the mean of that class's column over the samples labelled with the class. If no sample carries the label, the average is taken over all samples instead.
    """

    def class_mean(label):
        members = ys == label
        # Fall back to the whole column when the class has no labelled sample.
        column = ps[members, label] if members.any() else ps[:, label]
        return column.mean()

    n_classes = ps.shape[1]
    thresholds = np.array(
        [class_mean(label) for label in range(n_classes)], dtype=float
    )
    logger.debug(thresholds)
    return thresholds

# %% ../nbs/20-alg.ipynb 18
def find_likely_label(
    *,
    t2p: np.ndarray,
    mask_negative: bool = False,
):
    """Pick, for each sample, the class with the highest normalized score.

    ## Params:

    - t2p: Score matrix with one row per sample and one column per class.
    - mask_negative: When true, a sample that has no score `>= 0` gets -1 instead of a class index, marking that there is no likely class label for it.

    ## Returns:

    An integer array holding one label (or -1) per row of `t2p`.
    """
    likely = t2p.argmax(axis=1)
    if mask_negative:
        has_candidate = (t2p >= 0).any(axis=1)
        # -1 is the "no likely label" marker; it can never be a class index.
        likely = np.where(~has_candidate, -1, likely)
    logger.debug(likely)
    return likely

# %% ../nbs/20-alg.ipynb 20
def rank_suspicious(
    *,
    t2p: np.ndarray,
    ll: np.ndarray,
    ys: np.ndarray,
):
    """Score the samples whose given label disagrees with the most likely label.

    ## Params:

    - t2p: 2-D score matrix indexed as `t2p[sample, label]`.
    - ll: Most likely label of each sample; entries that are not greater than -1 (the "no likely label" marker) are skipped.
    - ys: Given label of each sample, used as a column index into `t2p`.

    ## Returns:

    A `pd.DataFrame` indexed by `id` (the row position of the sample, in ascending order) with a single float column `err = t2p[k, ll[k]] - t2p[k, ys[k]]`. Samples with `ll == ys`, and samples whose `err` is exactly 0, are left out.
    """
    t2p = np.asarray(t2p)
    ll = np.asarray(ll)
    ys = np.asarray(ys)

    # Candidates: a likely label exists and it differs from the given one.
    rows = np.flatnonzero((ll != ys) & (ll > -1))

    # Paired fancy indexing evaluates every difference in one shot, replacing
    # a Python-level loop over the candidate rows.
    err = np.asarray(
        t2p[rows, ll[rows]] - t2p[rows, ys[rows]], dtype=np.float64
    )

    # A difference of exactly 0 means the given label ties with the likely
    # one; such rows are excluded from the result.
    keep = err != 0
    err_df = pd.DataFrame(
        {"err": err[keep]},
        index=pd.Index(rows[keep], name="id"),
    )
    logger.debug(err_df)
    return err_df

# %% ../nbs/20-alg.ipynb 23
def suspect(
    *,
    ps: np.ndarray,
    ys: np.ndarray,
) -> pd.DataFrame:
    """The internal method to rank the suspicious ys given ps from a classifier.

    ## Params:

    - ps: Class probabilities from the classifier, one row per sample and one column per class; every value must lie in [0, 1].
    - ys: Given label of each sample, same length as `ps`; used as a column index into `ps`.

    ## Returns:

    The data frame built by `rank_suspicious`: indexed by sample `id`, with an `err` column holding the gap between the normalized score of the most likely label and that of the given label.
    """
    # Coerce array-likes (lists, pandas objects) so the positional indexing
    # done by the helpers behaves exactly as it does for plain ndarrays.
    ps = np.asarray(ps)
    ys = np.asarray(ys)

    assert len(ys) == len(ps), "ys and ps must have the same number of samples"
    assert (ps >= 0).all() and (ps <= 1).all(), "ps should be between 0 and 1"

    # Normalize each column by the average confidence of its class
    # (t has one entry per class and broadcasts across the rows of ps).
    t = avg_confidence(ps=ps, ys=ys)
    t2p = ps - t
    ll = find_likely_label(t2p=t2p)
    err_df = rank_suspicious(t2p=t2p, ll=ll, ys=ys)

    return err_df
