import pandas as pd
import numpy as np
import tqdm

import pyro as pyro
import numpy as np
import math
from pyro.infer.autoguide import AutoDelta
import pyro.poutine as poutine
import torch
import mobster
from mobster.utils_mobster import *
from mobster.stopping_criteria import *
import mobster.model_selection_mobster as ms
from mobster.calculate_posteriors import *
from pyro.util import ignore_jit_warnings


def fit_mobster(data, K, tail=1, truncated_pareto=True, subclonal_prior="Moyal", multi_tail=False, purity=0.96,
                number_of_trials_clonal_mean=500., number_of_trials_k=300.,
                number_of_trials_subclonal=500,
                alpha_precision_concentration=5, alpha_precision_rate=0.1,
                prior_lims_clonal=[0.1, 100000.], prior_lims_k=[0.1, 100000.],max_min_subclonal_ccf = [0.05,0.95], k_means_init = True, min_vaf_scale_tail = 0.1,stopping=all_stopping_criteria, lr=0.05,
                max_it=5000, e=0.001, compile=False, CUDA=False, seed=3, lrd_gamma=0.1):
    """Fit the MOBSTER model with SVI and return parameters, scores and losses.

    The function seeds Pyro, selects the default tensor type (CPU or CUDA),
    builds an ``SVI`` object around ``mobster.model`` / ``mobster.guide`` with
    a ``ClippedAdam`` optimiser, runs the optimisation loop via ``run``,
    computes posterior assignments and finally evaluates the likelihood, AIC,
    BIC and ICL.

    Args:
        data: Mapping whose values expose ``.cuda()`` (presumably one torch
            tensor per karyotype -- TODO confirm against the caller).
        K: Number of subclones (as reported in the start-up message).
        tail: ``1`` is reported as "model with tail", any other value as
            "model without tail".
        truncated_pareto, subclonal_prior, multi_tail, purity,
        number_of_trials_clonal_mean, number_of_trials_k,
        number_of_trials_subclonal, alpha_precision_concentration,
        alpha_precision_rate, prior_lims_clonal, prior_lims_k,
        max_min_subclonal_ccf, k_means_init, min_vaf_scale_tail:
            Model settings. They are collected into the run-parameter dict
            that is handed to ``run`` and echoed back in the result; several
            are also passed to the posterior / export helpers.
        stopping: Callable used by ``run`` as convergence test.
        lr: Initial learning rate for ``ClippedAdam``.
        max_it: Maximum number of SVI iterations. Must be non-zero because the
            per-step decay is computed as ``lrd_gamma ** (1 / max_it)``.
        e: Tolerance forwarded to ``stopping``.
        compile: Use ``JitTrace_ELBO`` instead of ``Trace_ELBO`` when true.
        CUDA: Move ``data`` to the GPU and make CUDA float tensors the default.
        seed: Seed passed to ``pyro.set_rng_seed``.
        lrd_gamma: Total learning-rate decay factor reached after ``max_it``
            steps.

    Returns:
        dict with the keys ``"information_criteria"`` (NumPy values for
        likelihood, AIC, BIC, ICL), ``"model_parameters"``,
        ``"run_parameters"`` and ``"loss"`` (NumPy array of per-step losses).

    Note:
        ``torch.set_default_tensor_type`` is a process-wide setting and stays
        in effect after this function returns.
    """
    pyro.set_rng_seed(seed)

    # The list-valued settings end up inside the returned "run_parameters".
    # Copy them so that a caller mutating the result can never alter the
    # shared default lists (or its own input lists) behind our back.
    prior_lims_clonal, prior_lims_k, max_min_subclonal_ccf = (
        list(value) if isinstance(value, list) else value
        for value in (prior_lims_clonal, prior_lims_k, max_min_subclonal_ccf)
    )

    if CUDA:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        data = {k: v.cuda() for k, v in data.items()}
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    elbo_cls = pyro.infer.JitTrace_ELBO if compile else pyro.infer.Trace_ELBO

    model = mobster.model
    guide = mobster.guide

    # Per-step multiplicative decay chosen so that lrd ** max_it == lrd_gamma.
    lrd = lrd_gamma ** (1 / max_it)

    optimiser = pyro.optim.ClippedAdam({"lr": lr, "lrd": lrd, "betas": (0.95, 0.999)})
    svi = pyro.infer.SVI(model=model,
                         guide=guide,
                         optim=optimiser,
                         loss=elbo_cls())

    print('Running MOBSTER on {} karyotypes with {} subclones.'.format(len(data), K), flush=True)
    if tail == 1:
        print("Fitting a model with tail", flush=True)
    else:
        print("Fitting a model without tail", flush=True)
    params = {
        'K': K,
        'tail': tail,
        'truncated_pareto': truncated_pareto,
        'purity': purity,
        "subclonal_prior": subclonal_prior,
        'multi_tail': multi_tail,
        'alpha_precision_concentration': alpha_precision_concentration,
        'alpha_precision_rate': alpha_precision_rate,
        'number_of_trials_clonal_mean': number_of_trials_clonal_mean,
        'number_of_trials_k': number_of_trials_k,
        'prior_lims_clonal': prior_lims_clonal,
        'prior_lims_k': prior_lims_k,
        'number_of_trials_subclonal': number_of_trials_subclonal,
        "max_min_subclonal_ccf": max_min_subclonal_ccf,
        "k_means_init": k_means_init,
        "min_vaf_scale_tail": min_vaf_scale_tail,
    }
    losses = run(data, params, svi, stopping, max_it, e)

    params_dict = ms.retrieve_params(CUDA)
    print("", flush=True, end="")
    print("Computing cluster assignements.", flush=True)
    params_dict, lk = retrieve_posterior_probs(data, truncated_pareto, params_dict, tail, purity, K, subclonal_prior,
                                               multi_tail, min_vaf_scale_tail)

    # Calculate information criteria
    print("Computing information criteria.", flush=True)
    likelihood = ms.likelihood(lk)
    AIC = ms.AIC(likelihood, params_dict)
    BIC = ms.BIC(likelihood, data, params_dict)
    # NOTE(review): params_dict is passed twice here; kept as in the original
    # call -- confirm against the signature of ms.ICL.
    ICL = ms.ICL(likelihood, data, params_dict, tail, params_dict)

    params_dict = format_parameters_for_export(data, params_dict, tail, K, purity, truncated_pareto, subclonal_prior,
                                               multi_tail)

    # .cpu() is required before .numpy() when the tensors live on the GPU
    # (CUDA=True); it is a no-op for tensors that are already on the CPU.
    information_dict = {"likelihood": likelihood.detach().cpu().numpy(),
                        "AIC": AIC.detach().cpu().numpy(),
                        "BIC": BIC.detach().cpu().numpy(),
                        "ICL": ICL.detach().cpu().numpy()}

    final_dict = {
        "information_criteria": information_dict,
        "model_parameters": params_dict,
        "run_parameters": params,
        "loss": np.array(losses)
    }
    print("Done!\n", flush=True)

    return final_dict


def run(data, params, svi, stopping, max_it, e):
    """Run the SVI optimisation loop until convergence or ``max_it`` steps.

    The Pyro parameter store is cleared, one warm-up ``svi.step`` is taken to
    initialise the parameters, and then up to ``max_it`` further steps are
    performed. After every step the current parameters are compared with the
    previous ones through ``stopping``; the loop ends as soon as it returns a
    truthy value.

    Args:
        data: Input data; stored under the ``"data"`` key of the keyword
            arguments given to ``svi.step``.
        params: Dict of run parameters; a shallow copy is extended with
            ``data`` and unpacked as keyword arguments of ``svi.step``.
        svi: Object exposing ``step(**kwargs)`` that returns the loss.
        stopping: Callable ``stopping(old_params, new_params, e)``.
        max_it: Maximum number of iterations of the main loop.
        e: Tolerance forwarded to ``stopping``.

    Returns:
        list of the losses returned by ``svi.step`` inside the main loop. The
        loss of the warm-up step is only shown on the progress bar and is not
        part of the list.
    """
    N = ms.number_of_samples(data)
    data_dict = params.copy()
    data_dict["data"] = data
    pyro.clear_param_store()

    # Warm-up step: creates the parameters so they can be retrieved below.
    loss = svi.step(**data_dict)
    new_w = retrieve_params()

    losses = []
    # The context manager closes the bar even when the loop is left early
    # through ``break``; otherwise it stays open until garbage collection.
    with trange(max_it, desc='Bar desc', leave=True) as t:
        for _ in t:
            # Shows the loss of the previous step, scaled by the sample count.
            t.set_description('ELBO: {:.9f}  '.format(loss / N))
            t.refresh()

            loss = svi.step(**data_dict)
            losses.append(loss)

            old_w, new_w = new_w, retrieve_params()

            if stopping(old_w, new_w, e):
                break

    return losses
