from typing_extensions import Literal
import numpy as np
import matplotlib.pyplot as plt

from .graph_fitting import *
from .pseudotime import *
from .root import root
from .. import logging as logg
from .. import settings


def cellrank_to_tree(
    adata,
    time,
    Nodes: int,
    method: Literal["ppt", "epg"] = "ppt",
    ppt_lambda: int = 100,
    auto_root: bool = False,
    root_params: dict = None,
    reassign_pseudotime: bool = False,
    copy=False,
    **kwargs
):
    """\
    Converts CellRank [Lange21]_ fate probabilities into a principal tree that can be analysed by scFates.

    It combines the projection generated by `cr.pl.circular_projection`
    with any measure of differentiation (CytoTRACE, latent time). A tree is fitted onto this new embedding.
    When only two terminal states are present, the probability of the first
    terminal state is used in place of the circular projection.

    Parameters
    ----------
    adata
        Annotated data matrix. Must contain `.obsm['to_terminal_states']`.
    time
        time key (column of `adata.obs`) to use for the additional dimension
        used in combination with `cr.pl.circular_projection`.
    Nodes
        Number of nodes that compose the principal graph.
    method
        If `ppt`, uses simpleppt approach, `ppt_lambda` and `ppt_sigma` are the
        parameters controlling the algorithm. If `epg`, uses ComputeElasticPrincipalTree
        function from elpigraph python package, `epg_lambda` `epg_mu` and `epg_trimmingradius`
        are the parameters controlling the algorithm.
    ppt_lambda
        Parameter for simpleppt, penalty for the tree length [Mao15]_. Usually works well at default for the conversion.
    auto_root
        Automatically select the root tip using the time key, then compute
        pseudotime on the fitted tree.
    root_params
        Dictionary of keyword arguments passed to :func:`scFates.tl.root`
        when `auto_root` is enabled (e.g. `min_val`). `None` means no
        extra arguments.
    reassign_pseudotime
        Whether to use the time key to replace the distances computed from the tree.
        Enabling it forces `auto_root=True`.
    copy
        Return a copy instead of writing to adata.
    kwargs
        arguments to pass to function :func:`scFates.tl.tree`.

    Returns
    -------
    adata : anndata.AnnData
        if `copy=True` it returns or else add fields to `adata`:

        `.obsm['X_fates']`
            representation generated by combining the time key with projection generated by :func:`cellrank.pl.circular_projection`.
        `.uns['ppt']`
            dictionary containing information from simpleppt tree if method='ppt'
        `.uns['epg']`
            dictionary containing information from elastic principal tree if method='epg'
        `.uns['graph']['B']`
            adjacency matrix of the principal points
        `.uns['graph']['R']`
            soft assignment of cells to principal point in representation space
        `.uns['graph']['F']`
            coordinates of principal points in representation space

        The `.uns` entries above are written by :func:`scFates.tl.tree`;
        refer to its documentation for the authoritative list.
        If `reassign_pseudotime=True`, additionally:

        `.obs['t']`
            replaced by `.obs[time]`.
        `.uns['graph']['pp_info'].time`
            average of the time key per principal point, weighted by `.obsm['X_R']`.
        `.uns['graph']['pp_seg'].d`
            difference of principal point time between the ends of each segment.

    Raises
    ------
    ImportError
        if cellrank is not installed.
    KeyError
        if `.obsm['to_terminal_states']` or `.obs[time]` is missing.

    """
    try:
        import cellrank as cr
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "cellrank installation is necessary for conversion.\n"
            'Please use "pip3 install cellrank" to install it'
        ) from e

    # Fail early, before copying the data or running any projection.
    if "to_terminal_states" not in adata.obsm:
        raise KeyError(
            ".obsm['to_terminal_states'] is missing, "
            "CellRank fate probabilities are required for the conversion."
        )
    if time not in adata.obs:
        raise KeyError("'" + str(time) + "' was not found in adata.obs.")

    logg.info(
        "Converting CellRank results to a principal tree",
        end=" " if settings.verbosity > 2 else "\n",
    )

    # Replacing the pseudotime requires a rooted tree with computed pseudotime.
    if reassign_pseudotime:
        auto_root = True

    adata = adata.copy() if copy else adata

    n_states = adata.obsm["to_terminal_states"].shape[1]

    if n_states == 2:
        # With two fates the probabilities are complementary, so a single
        # column combined with the time key is used as the embedding.
        adata.obsm["X_fates"] = np.vstack(
            [
                np.array(adata.obsm["to_terminal_states"][:, 0].flatten()),
                adata.obs[time],
            ]
        ).T
        logg.hint(
            "with .obsm['X_fates'], created by combining:\n"
            "    .obsm['to_terminal_states'][:,0] and adata.obs['" + time + "']\n"
        )
    else:
        logg.hint(
            "with .obsm['X_fates'], created by combining:\n"
            "    .obsm['X_fate_simplex_fwd'] (from cr.pl.circular_projection) and adata.obs['"
            + time
            + "']\n"
        )
        # Only the embedding stored in .obsm['X_fate_simplex_fwd'] is needed,
        # the figure produced by the plotting call is discarded.
        cr.pl.circular_projection(adata, keys=["kl_divergence"])
        plt.close()

        adata.obsm["X_fates"] = np.concatenate(
            [adata.obsm["X_fate_simplex_fwd"], adata.obs[time].values.reshape(-1, 1)],
            axis=1,
        )

    tree(
        adata,
        Nodes=Nodes,
        use_rep="X_fates",
        method=method,
        ppt_lambda=ppt_lambda,
        **kwargs
    )

    if auto_root:
        logg.info("\nauto selecting a tip as a root using " + time + ".\n")
        root(adata, time, **(root_params or {}))
        pseudotime(adata)

    summary = ["    .obsm['X_fates'] representation used for fitting the tree."]
    if reassign_pseudotime:
        _overwrite_pseudotime(adata, time)
        # Only report these fields when they were actually overwritten.
        summary += [
            "    .obs['t'] has been replaced by .obs['" + time + "']",
            "    .uns['graph']['pp_info'].time has been updated with " + time,
            "    .uns['graph']['pp_seg'].d has been updated with " + time,
        ]

    logg.info("\nfinished", time=True, end="\n" if reassign_pseudotime else " ")
    logg.info("\n".join(summary))

    return adata if copy else None


def _overwrite_pseudotime(adata, time):
    """Replace the tree-derived pseudotime of `adata` by `adata.obs[time]`, in place."""
    adata.obs["t"] = adata.obs[time]
    adata.uns["pseudotime_list"]["0"]["t"] = adata.obs[time]

    pp_info = adata.uns["graph"]["pp_info"]
    pp_seg = adata.uns["graph"]["pp_seg"]
    soft_assignment = adata.obsm["X_R"]

    # Time of a principal point: average over cells, weighted by the
    # assignment of each cell to that point.
    for n in range(soft_assignment.shape[1]):
        pp_info.loc[n, "time"] = np.average(
            adata.obs["t"], weights=soft_assignment[:, n]
        )

    # Segment length: time at the "to" node minus time at the "from" node.
    for seg in pp_seg.index:
        ends = pp_seg.loc[seg, ["from", "to"]].values
        pp_seg.loc[seg, "d"] = np.diff(pp_info.loc[ends, "time"].values)[0]
