"""Build neurite diameters from a pre-generated model. TO REVIEW!."""
import logging
from collections import deque
from functools import partial
from copy import copy

import numpy as np
from numpy.polynomial import polynomial

from morphio import SectionType, IterType

import diameter_synthesis.morph_functions as morph_funcs
from diameter_synthesis import utils
from diameter_synthesis.distribution_fitting import sample_distribution
from diameter_synthesis.exception import DiameterSynthesisError

# Amount subtracted from the trunk diameter scaling fraction at each reduction step.
TRUNK_FRAC_DECREASE = 0.1
# Failed attempts (roughly) between two successive trunk diameter reductions.
N_TRIES_BEFORE_REDUC = 5
L = logging.getLogger(__name__)

# Neurite type names used in the model parameters -> morphio section types.
STR_TO_TYPES = {
    "basal": SectionType.basal_dendrite,
    "apical": SectionType.apical_dendrite,
    "axon": SectionType.axon,
}

# Inverse of STR_TO_TYPES: morphio section types -> neurite type names.
TYPES_TO_STR = {section_type: type_name for type_name, section_type in STR_TO_TYPES.items()}


def _reset_caches():
    """Clear the memoization caches of the cached morphology helper functions."""
    cached_helpers = (
        morph_funcs.sec_length,
        morph_funcs.partition_asymmetry_length,
        morph_funcs.lengths_from_origin,
        morph_funcs.n_children_downstream,
    )
    for helper in cached_helpers:
        helper.cache_clear()


def _get_neurites(neuron, neurite_type):
    """Get a list of neurites to diametrize.

    Args:
        neuron (morphio.mu.Morphology): a neuron
        neurite_type (morphio.SectionType): the neurite type to consider

    Returns:
        list: list of neurites to consider
    """
    return [
        list(neurite.iter()) for neurite in neuron.root_sections if neurite.type == neurite_type
    ]


def _sample_sibling_ratio(
    params, neurite_type, apply_asymmetry=False, mode="generic", rng=np.random
):
    """Sample a sibling ratio from distribution.

    Args:
        params (dict): model parameters
        neurite_type (str): the neurite type to consider
        apply_asymmetry (bool): asymmetry of current branching point
        mode (str): to use or not the asymmetry_threshold

    Returns:
        float: sibling ratio
    """
    if mode == "generic":
        return sample_distribution(params["sibling_ratios"][neurite_type], rng=rng)
    if mode == "threshold":
        if apply_asymmetry:
            return 0.0
        return sample_distribution(params["sibling_ratios"][neurite_type], rng=rng)
    # This case should never happen since the mode is already checked in `_select_model`
    raise DiameterSynthesisError("mode not understood {}".format(mode))


def _sample_diameter_power_relation(
    params, neurite_type, apply_asymmetry=False, mode="generic", rng=np.random
):
    """Sample a diameter power relation from distribution.

    Args:
        params (dict): model parameters
        neurite_type (str): the neurite type to consider
        apply_asymmetry (bool): asymmetry of current branching point
        mode (str): to use or not the asymmetry_threshold

    Returns:
        float: diameter power relation
    """
    if mode == "generic":
        return sample_distribution(params["diameter_power_relation"][neurite_type], rng=rng)
    if mode == "threshold":
        if apply_asymmetry:
            return 1.0
        return sample_distribution(params["diameter_power_relation"][neurite_type], rng=rng)
    if mode == "exact":
        # This case should never happen since this mode is not known by `_select_model`
        return 1.0
    # This case should never happen since the mode is already checked in `_select_model`
    raise DiameterSynthesisError("mode not understood {}".format(mode))


def _sample_trunk_diameter(params, neurite_type, rng=np.random):
    """Sample a trunk diameter from distribution.

    Args:
        params (dict): model parameters
        neurite_type (str): the neurite type to consider

    Returns:
        float: trunk diameter
    """
    return sample_distribution(params["trunk_diameters"][neurite_type], rng=rng)


def _sample_terminal_diameter(params, neurite_type, rng=np.random):
    """Sample a terminal diameter.

    Args:
        params (dict): model parameters
        neurite_type (str): the neurite type to consider

    Returns:
        float: terminal diameter
    """
    return sample_distribution(params["terminal_diameters"][neurite_type], rng=rng)


def _sample_taper(params, neurite_type, rng=np.random):
    """Sample a taper rate from distributions.

    Args:
        params (dict): model parameters
        neurite_type (str): the neurite type to consider

    Returns:
        float: taper rate
    """
    return sample_distribution(params["tapers"][neurite_type], rng=rng)


def _sample_daughter_diameters(section, params, params_tree, rng=np.random):
    """Compute the daughter diameters of the current section.

    Args:
        section (morphio  section): section to consider
        params (dict): model parameters
        params_tree (dict): specific parameters of the current tree

    Returns:
       list: list of daughter diameters
    """
    # pylint: disable=too-many-locals
    major_sections = params_tree["major_sections"]

    apply_asymmetry = section.id in major_sections

    reduction_factor = params_tree["reduction_factor_max"] + 1.0
    # try until we get a reduction of diameter in the branching
    while reduction_factor > params_tree["reduction_factor_max"]:

        sibling_ratio = _sample_sibling_ratio(
            params,
            params_tree["neurite_type"],
            apply_asymmetry=apply_asymmetry,
            mode=params_tree["mode_sibling"],
            rng=rng,
        )

        diameter_power_relation = _sample_diameter_power_relation(
            params,
            params_tree["neurite_type"],
            apply_asymmetry=apply_asymmetry,
            mode=params_tree["mode_diameter_power_relation"],
            rng=rng,
        )

        reduction_factor = morph_funcs.diameter_power_relation_factor(
            diameter_power_relation, sibling_ratio
        )

    diam_0 = section.diameters[-1]
    terminal_diam = min(diam_0, params_tree["terminal_diam"])

    diam_1 = reduction_factor * diam_0
    diam_2 = sibling_ratio * diam_1

    diam_1 = max(diam_1, terminal_diam)
    diam_2 = max(diam_2, terminal_diam)

    diams = [diam_1] + (len(section.children) - 1) * [diam_2]

    if params_tree.get("with_asymmetry", False):
        # This case should always happen since the `with_asymmetry` attribute is always set to True
        # in `_select_model`

        # returns child diameters sorted by child length (major/secondary for apical tree)
        child_not_in_major = [child.id not in major_sections for child in section.children]
        if False in child_not_in_major:
            child_sort = np.argsort(child_not_in_major)
            return list(np.array(diams)[child_sort])

    # At the moment we don't have enough information to do better than a random choice in this case
    rng.shuffle(diams)
    return diams


def _diametrize_section(section, initial_diam, taper, min_diam=0.07, max_diam=10.0):
    """Set linearly tapered diameters on a section, in place.

    Each point gets ``initial_diam + taper * d``, clipped to ``[min_diam, max_diam]``. Here ``d``
    is the corresponding value returned by ``morph_funcs.lengths_from_origin(section)``
    (presumably the path length from the section start; see that helper).

    Args:
        section (morphio section): current section
        initial_diam (float): initial diameter
        taper (float): taper rate
        min_diam (float): minimum diameter
        max_diam (float): maximum diameter
    """
    linear_coeffs = [initial_diam, taper]
    distances = morph_funcs.lengths_from_origin(section)
    tapered = polynomial.polyval(distances, linear_coeffs)
    section.diameters = np.clip(tapered, min_diam, max_diam)


def _diametrize_tree(neurite, params, params_tree, rng=np.random):
    """Diametrize a tree, or neurite, in place.

    Sections are processed breadth-first. Each section gets linearly tapered diameters starting
    from its first diameter (the trunk diameter for the root section). The first diameter of
    each of its children is then set from ``_sample_daughter_diameters``. The random draws happen
    in a fixed order, so statement order matters for reproducibility.

    Args:
        neurite (list): sections of the neurite, as built by ``_get_neurites`` (presumably the
            root section first)
        params (dict): model parameters
        params_tree (dict): specific parameters of the current tree; this function writes its
            ``tot_length`` and ``terminal_diam`` entries
        rng: random number generator (defaults to the ``numpy.random`` module)

    Returns:
        bool: True if at least one tip diameter is larger than the ``max`` parameter of the
        terminal diameter distribution (a "wrong tip"), False otherwise
    """
    # Not read in this function; presumably used elsewhere via ``params_tree`` (TODO confirm).
    params_tree["tot_length"] = morph_funcs.get_total_length(neurite)
    # Upper bound on tip diameters, taken from the terminal diameter distribution parameters.
    max_diam = params["terminal_diameters"][params_tree["neurite_type"]]["params"]["max"]
    wrong_tips = False
    # Breadth-first traversal starting from the first section of the neurite.
    active = deque([neurite[0]])
    while active:
        section = active.popleft()

        if section.is_root:
            init_diam = params_tree["trunk_diam"]
        else:
            # The first diameter was set when the parent section was processed (see below).
            init_diam = section.diameters[0]

        taper = _sample_taper(params, params_tree["neurite_type"], rng=rng)

        # The terminal diameter never exceeds the starting diameter of the section.
        params_tree["terminal_diam"] = min(
            init_diam, _sample_terminal_diameter(params, params_tree["neurite_type"], rng=rng)
        )

        _diametrize_section(
            section,
            init_diam,
            taper=taper,
            min_diam=params_tree["terminal_diam"],
            max_diam=params_tree["trunk_diam"],
        )

        if len(section.children) > 0:
            diams = _sample_daughter_diameters(section, params, params_tree, rng=rng)

            for i, child in enumerate(section.children):
                # Presumably sets the diameter of point 0 of ``child`` (helper not shown here).
                utils.redefine_diameter_section(child, 0, diams[i])
                active.append(child)

        # if we are at a tip, check if tip diameters are small enough
        elif section.diameters[-1] > max_diam:
            wrong_tips = True

    return wrong_tips


def _diametrize_neuron(params_tree, neuron, params, neurite_types, config, rng=np.random):
    """Diametrize, in place, all neurites of the requested types of a neuron.

    Each neurite is diametrized repeatedly until no tip diameter is too large (see
    ``_diametrize_tree``). After some failed attempts, the sampled trunk diameter is scaled down
    by a further ``TRUNK_FRAC_DECREASE`` roughly every ``N_TRIES_BEFORE_REDUC`` attempts. Once
    the attempt counter exceeds ``config["trunk_max_tries"]``, the last result is kept. A warning
    is logged in that case only if that last attempt still had wrong tips.

    Args:
        params_tree (dict): specific parameters of the current tree (updated in place)
        neuron (morphio.mut.Morphology): neuron to diametrize
        params (dict): model parameters
        neurite_types (list): neurite types to consider, each a str or a morphio.SectionType
        config (dict): general configuration parameters
        rng: random number generator (defaults to the ``numpy.random`` module)
    """
    # pylint: disable=too-many-locals
    major_sections = set()
    if params_tree["with_asymmetry"]:
        # Get sections on the major branch: the upstream path of each apical point section
        for apical_section in params.get("apical_point_sec_ids", []):
            for sec in neuron.sections[apical_section].iter(IterType.upstream):
                major_sections.add(sec.id)
    params_tree["major_sections"] = major_sections

    for neurite_type in neurite_types:
        if isinstance(neurite_type, str):
            morphio_neurite_type = STR_TO_TYPES[neurite_type]
        else:
            morphio_neurite_type = neurite_type
            neurite_type = TYPES_TO_STR[neurite_type]

        params_tree["neurite_type"] = neurite_type

        for neurite in _get_neurites(neuron, morphio_neurite_type):
            wrong_tips = True
            n_tries = 0
            trunk_diam_frac = 1.0
            n_tries_step = 1
            while wrong_tips:
                trunk_diam = trunk_diam_frac * _sample_trunk_diameter(params, neurite_type, rng=rng)

                # NOTE(review): after enough failures trunk_diam_frac can reach <= 0, and the
                # trunk diameter then jumps back to 1.0 here.
                if trunk_diam < 0.01:
                    trunk_diam = 1.0
                    L.warning("sampled trunk diameter < 0.01, so use 1 instead")

                params_tree["trunk_diam"] = trunk_diam
                wrong_tips = _diametrize_tree(neurite, params, params_tree, rng=rng)

                # if we can't get a good model, reduce the trunk diameter progressively
                if n_tries > N_TRIES_BEFORE_REDUC * n_tries_step:
                    trunk_diam_frac -= TRUNK_FRAC_DECREASE
                    n_tries_step += 1
                # give up only if the last attempt still failed (no spurious warning otherwise)
                if wrong_tips and n_tries > config["trunk_max_tries"]:
                    L.warning("max tries attained with %s", neurite_type)
                    wrong_tips = False
                n_tries += 1


def _select_model(model):
    """Build the diametrizer function for a named model.

    Args:
        model (str): model name, ``"generic"`` or ``"astrocyte"``

    Returns:
        function: ``_diametrize_neuron`` with a new ``params_tree`` bound as first argument

    Raises:
        DiameterSynthesisError: if the model name is unknown
    """
    if model == "generic":
        sampling_mode, max_reduction = "threshold", 1.0
    elif model == "astrocyte":
        sampling_mode, max_reduction = "generic", 3.0
    else:
        raise DiameterSynthesisError("Unknown diameter model: {}".format(model))

    # A new dict on each call, since `_diametrize_neuron` stores per-run state in it.
    tree_settings = {
        "mode_sibling": sampling_mode,
        "mode_diameter_power_relation": sampling_mode,
        "with_asymmetry": True,
        "reduction_factor_max": max_reduction,
    }
    return partial(_diametrize_neuron, tree_settings)


def build(neuron, model_params, neurite_types, config, rng=np.random):
    """Generate, in place, the diameters of a neuron from a diameter model.

    Only the first model of ``config["models"]`` is used. When ``config["n_samples"]`` is
    greater than 1, the neuron is diametrized that many times. The arrays returned by
    ``utils.get_all_diameters`` are then averaged and written back.

    Args:
        neuron (morphio.mut.Morphology): neuron to diametrize
        model_params (dict): model parameters
        neurite_types (list): the neurite types to consider
        config (dict): general configuration parameters
        rng: random number generator (defaults to the ``numpy.random`` module)
    """
    if "seed" in config:
        # NOTE(review): this only seeds the global numpy RNG; it has no effect on a separate
        # generator passed as ``rng``.
        np.random.seed(config["seed"])
    _reset_caches()

    models = config["models"]
    if len(models) > 1:
        L.warning("Several models provided, we will only use the first")
    run_model = partial(
        _select_model(models[0]), neuron, model_params, neurite_types, config, rng=rng
    )

    run_model()
    n_samples = config["n_samples"]
    if n_samples > 1:
        # Sum the diameters over all samples, then divide by the number of samples.
        summed = utils.get_all_diameters(neuron)
        for _ in range(n_samples - 1):
            run_model()
            for idx, sample_diams in enumerate(utils.get_all_diameters(neuron)):
                summed[idx] += sample_diams
        for idx, _ in enumerate(summed):
            summed[idx] /= n_samples
        utils.set_all_diameters(neuron, summed)


def _iter_ais_sections(morphology, length):
    """Yield the axon sections that start closer than ``length`` to their axon root.

    Distances are path lengths along the section points, measured from the first point of each
    axon root section and following the tree. Distances are not accumulated in traversal order,
    so sibling branches do not add to each other.

    Yields:
        tuple: ``(section, n_points)``. ``n_points`` counts the leading points of ``section``
        closer than ``length``, plus the first point at or beyond ``length`` (if any). Sections
        downstream of such a point are not yielded.
    """
    for root_section in morphology.root_sections:
        if root_section.type != SectionType.axon:
            continue
        # (section, path length at the point preceding it, that preceding point)
        to_visit = [(root_section, 0.0, root_section.points[0])]
        while to_visit:
            section, dist, prev_point = to_visit.pop()
            n_points = 0
            for point in section.points:
                dist += np.linalg.norm(point - prev_point)
                prev_point = point
                n_points += 1
                if dist >= length:
                    break
            else:
                # Whole section within ``length``: children continue from its last point.
                to_visit.extend((child, dist, prev_point) for child in section.children)
            yield section, n_points


def _save_first_diams(morphology, length):
    """Save the diameters of the axon sections lying within ``length`` of their axon root.

    Returns:
        dict: section id -> copy of the full diameter array of that section
    """
    return {
        section.id: copy(section.diameters)
        for section, _ in _iter_ais_sections(morphology, length)
    }


def _set_first_diams(morphology, diams, length):
    """Restore diameters saved by ``_save_first_diams`` for the points within ``length``.

    Only points within ``length`` of the axon root are restored; the other diameters are kept.
    The section points must be unchanged since the diameters were saved.
    """
    for section, n_points in _iter_ais_sections(morphology, length):
        current_diams = copy(section.diameters)
        current_diams[:n_points] = diams[section.id][:n_points]
        section.diameters = current_diams


def diametrize_axon(
    morphology,
    main_diameter=1.0,
    colateral_diameter=0.1,
    main_taper=-0.0005,
    axon_point_isec=None,
    ais_length=60,
    rng=np.random,
):
    """Diametrize axon in place without learning from reconstructed axons.

    The main axon branch (from soma to axon point) will have a tapered diameter, and the
    colaterals a constant diameter.

    If an axon point is not provided, and main_diameter > colateral_diameter, the diameters will
    decrease with taper and bifurcations, with hardcoded parameters sibling_ratio = 0.5 and
    diameter_power_relation = 5.0.

    The original diameters of the first ``ais_length`` of path length along the axon are kept.

    Args:
        morphology (morphio.mut.Morphology): morphology to diametrize
        main_diameter (float): diameter of main axon branch (from soma to axon_point_isec)
        colateral_diameter (float): diameter of colateral branches
        main_taper (float): taper rate of main branch (set to 0 for no taper, should be negative)
        axon_point_isec (int): morphio section id of axon point (see morph_tool.axon_point module);
            None if there is no axon point
        ais_length (float): length of ais for which we keep original diameters
        rng: random number generator (defaults to the ``numpy.random`` module)
    """
    model_params = {
        "trunk_diameters": {
            "axon": {
                "distribution": "constant",
                "params": {"value": main_diameter},
            }
        },
        "terminal_diameters": {
            "axon": {
                "distribution": "constant",
                "params": {"value": colateral_diameter, "max": main_diameter},
            }
        },
        "tapers": {
            "axon": {
                "distribution": "constant",
                "params": {"value": main_taper},
            }
        },
        "sibling_ratios": {
            "axon": {
                "distribution": "constant",
                "params": {"value": 0.5},
            }
        },
        "diameter_power_relation": {
            "axon": {
                "distribution": "constant",
                "params": {"value": 5.0},
            }
        },
    }
    # Section id 0 is a valid axon point, so compare with None instead of testing truthiness.
    if axon_point_isec is not None:
        model_params["apical_point_sec_ids"] = [axon_point_isec]

    config = {"models": ["generic"], "trunk_max_tries": 1, "n_samples": 1}

    diams = _save_first_diams(morphology, ais_length)
    build(morphology, model_params, neurite_types=["axon"], config=config, rng=rng)
    _set_first_diams(morphology, diams, ais_length)
