"""
.. module:: plot_slab
   :synopsis: Plot routines for prodimo slab models.

.. moduleauthor:: A. M. Arabhavi


"""


import numpy as np
import pandas as pd
from scipy.constants import h,c,k
from adjustText import adjust_text
import matplotlib.pyplot as plt
from prodimopy.read_slab import slab_data,slab
import matplotlib.colors as mcolors
import matplotlib.cm as cm
# Speed of light from scipy.constants under a second name. Several plotting
# functions below take a colour argument called ``c`` that shadows the
# module-level ``c``, so they use ``sci_c`` to convert GHz to wavelength.
sci_c = float(c)

class default_figsize:
    """Mutable holder for the module-wide default figure size.

    ``_get_set_fig`` reads ``width`` and ``height`` when it has to create a
    new figure; ``set_default_figsize`` overwrites them.
    """

    def __init__(self):
        # (width, height) handed to plt.subplots(figsize=...).
        self.width, self.height = 7, 5


# The single shared instance consulted by the helpers in this module.
defFigSize = default_figsize()

def set_default_figsize(w=None,h=None):
    """Change the default size of figures created by ``_get_set_fig``.

    Parameters
    ----------
    w : float, optional
        New default width. ``None`` keeps the current width.
    h : float, optional
        New default height. ``None`` keeps the current height.

    Calling without arguments leaves the current default unchanged.
    """
    # Resolve "keep current" at call time. Defaults such as
    # ``w=defFigSize.width`` are evaluated once when the module is imported,
    # so updating only one dimension would reset the other to 7/5.
    if w is not None:
        defFigSize.width = w
    if h is not None:
        defFigSize.height = h
    return

def _get_set_fig(ax=None,fig=None,figsize=None):
    if ax == None:
        if fig == None:
            fig,ax = plt.subplots(figsize=(defFigSize.width,defFigSize.height))
        else:
            ax = fig.add_subplot()
    else:
        if fig == None:
            fig = ax.get_figure()
        else:
            fig2 = ax.get_figure()
            if fig2!=fig:
                print('WARNING, passed figure and axis do not match')
                fig = ax.get_figure()

    if figsize != None:
        fig.set_size_inches(figsize[0], figsize[1], forward=True)
    return(fig,ax)

def plotBoltzmannDiagram(dat,ax=None,fig=None,figsize=None,NLTE=False,label=None,s=0.1,c='k',set_axis_limits=True):
    """
    This function plots the Boltzmann diagram (reduced flux vs upper energy level)

    Each line of each model is drawn as one point at
    ``(Eu, F / (nu * A * gu))``. Here ``F`` is the ``FLTE`` (or ``FNLTE``)
    column of ``linedata`` and ``nu`` is the ``GHz`` column multiplied by 1e9.

    Parameters
    ----------
    dat : slab, slab_data or list of slab
        Model(s) to plot; a ``slab_data`` is expanded to its ``models``.
    ax, fig, figsize :
        Passed to ``_get_set_fig``.
    NLTE : bool
        Use the ``FNLTE`` flux column instead of ``FLTE``.
    label, s, c :
        Legend label, marker size and colour. For several models, a list
        gives one entry per model; any other value is shared by all models.
        For a single ``slab`` they are passed to ``ax.scatter`` unchanged.
    set_axis_limits : bool
        If True, set x limits around the Eu range and y limits around the
        range of positive reduced fluxes.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    ValueError
        If ``dat`` has the wrong type or a per-model list is too short.
    """
    if isinstance(dat,slab):
        data_list = [dat]
        label_list, s_list, c_list = [label], [s], [c]
    elif isinstance(dat,(list,slab_data)):
        data_list = dat.models if isinstance(dat,slab_data) else dat
        n = len(data_list)
        # A list supplies one value per model; anything else (str, number,
        # RGB tuple, None) is shared by every model.
        label_list = label if isinstance(label,list) else [label]*n
        s_list = s if isinstance(s,list) else [s]*n
        c_list = c if isinstance(c,list) else [c]*n
        for name,values in (('label',label_list),('s',s_list),('c',c_list)):
            if len(values) < n:
                raise ValueError(f'"{name}" has {len(values)} entries but {n} models were passed')
    else:
        raise ValueError('The data passed should be of type "slab" or a list containing "slab"s ')

    # Validate the input before creating a figure, so bad input leaves no stray figure.
    fig,ax = _get_set_fig(ax,fig,figsize)

    FTag = 'FNLTE' if NLTE else 'FLTE'
    max_Eu,min_Eu = 1e-99,1e99
    max_F,min_F = 1e-99,1e99
    for i,Dat in enumerate(data_list):
        F  = Dat.linedata[FTag]
        nu = Dat.linedata['GHz']
        A  = Dat.linedata['A']
        gu = Dat.linedata['gu']
        Eu = Dat.linedata['Eu']
        # The GHz column is converted to Hz.
        RedF = F/(nu*A*gu*1e9)
        max_Eu = np.amax([np.amax(Eu),          max_Eu])
        max_F  = np.amax([np.amax(RedF),        max_F])
        min_Eu = np.amin([np.amin(Eu),          min_Eu])
        # The y axis is logarithmic, so only positive values bound it from below.
        positive = RedF[RedF>0]
        if len(positive):
            min_F = np.amin([np.amin(positive),min_F])
        ax.scatter(Eu,RedF,s=s_list[i],label=label_list[i],c=c_list[i])
    # Only draw a legend when at least one model actually carries a label.
    if any(lbl is not None and lbl != '' for lbl in label_list):
        ax.legend()
    ax.set_yscale('log')
    if set_axis_limits:
        # Skip limits that were never updated (e.g. no models or no positive fluxes).
        if min_Eu <= max_Eu:
            ax.set_xlim([min_Eu*0.9,max_Eu*1.02])
        if min_F <= max_F:
            ax.set_ylim([min_F*10**-0.5,max_F*10**0.5])
    ax.set_xlabel('Eu')
    ax.set_ylabel(r'$F/(\nu A g_u)$')
    return(fig,ax)

def plotLevelDiagram(data,ax=None,figsize=(10,18),seed=None,lambda_0=None,lambda_n=None,width_nlines=False):
    """
    This function plots level diagram for a single slab model

    Every global level involved in the selected lines is drawn as a short
    horizontal bar at its energy, at a random horizontal position. Every
    distinct (global_u, global_l) transition is drawn as a line joining the
    two bars. Line widths scale with the summed ``FLTE`` flux of the
    transition, or with its number of lines if ``width_nlines`` is True.

    Parameters
    ----------
    data : slab or (leveldata, linedata)
        Object with ``linedata``/``leveldata`` attributes, or a pair
        ``(leveldata, linedata)``.
    ax : matplotlib Axes, optional
        Axis to draw on; a new figure of size ``figsize`` is created if None.
    figsize : tuple
        Size of the new figure (only used when ``ax`` is None).
    seed : int, optional
        Seed for numpy's global random generator, which places the level
        bars. If None, a random seed is drawn. It is printed and returned so
        the layout can be reproduced.
    lambda_0, lambda_n : float, optional
        Keep only lines whose wavelength ``c/GHz*1e-3`` is above
        ``lambda_0`` / below ``lambda_n``.
    width_nlines : bool
        Scale line widths by number of lines instead of summed flux.

    Returns
    -------
    (fig, ax, seed)

    Raises
    ------
    ValueError
        If no lines are left after the wavelength selection.
    """
    if seed is None:
        seed = np.random.randint(0,2**25-1)
    # The bars are placed with numpy's global generator; seeding it makes the
    # returned seed sufficient to reproduce the layout.
    np.random.seed(seed)
    try:
        lineData = data.linedata
        levData = data.leveldata
    except AttributeError:
        # Not a slab object: expect a (leveldata, linedata) pair.
        levData = data[0]
        lineData = data[1]
    if lambda_0 is not None:
        lineData = lineData[c/lineData['GHz']*1e-3>lambda_0]
    if lambda_n is not None:
        lineData = lineData[c/lineData['GHz']*1e-3<lambda_n]
    if len(lineData) == 0:
        raise ValueError('No lines to plot, check lambda_0/lambda_n')
    levData = levData.set_index('i')
    lineData = lineData.set_index('i')

    # For each global level, take the lowest energy among the individual
    # levels mapped to it: first over upper levels, then over lower levels.
    levArr = []
    for glob_col,lev_col in (('global_u','u'),('global_l','l')):
        for glev in set(lineData[glob_col]):
            members = lineData[lineData[glob_col]==glev][lev_col]
            levArr.append([glev,np.amin(levData.loc[members]['E'])])
    levArr = pd.DataFrame(levArr,columns=['Level','E']).drop_duplicates().sort_values(by=['E'])
    # Distinct (upper, lower) global-level pairs; one connecting line each.
    linArr = pd.DataFrame([lineData['global_u'],lineData['global_l']]).transpose().drop_duplicates()
    # Random horizontal centre (axis fraction) of each level bar.
    levArr['err'] = np.random.uniform(0.05,0.95,len(levArr))

    if ax is None:
        fig,ax = plt.subplots(figsize=figsize)
    else:
        # Draw on the axis that was passed in (plt.gca() may be another axis).
        fig = ax.get_figure()

    text = []
    for i in range(len(levArr)):
        row = levArr.iloc[i]
        ax.axhline(y=row['E'],xmin=row['err']-0.05,xmax=row['err']+0.05,color='k')
        text.append(ax.text(row['err'],row['E'],row['Level']))

    # Width per transition: summed FLTE flux, or number of lines; normalised
    # so the largest is 1 (left as-is if they are all zero).
    pairs = list(zip(linArr['global_u'],linArr['global_l']))
    widths = []
    for gu,gl in pairs:
        sel = lineData[(lineData['global_u']==gu)&(lineData['global_l']==gl)]['FLTE']
        widths.append(len(sel) if width_nlines else np.sum(sel))
    widths = np.array(widths,dtype=float)
    peak = np.amax(widths)
    if peak != 0:
        widths = widths/peak

    for (gu,gl),width in zip(pairs,widths):
        upper = levArr[levArr['Level']==gu].iloc[0]
        lower = levArr[levArr['Level']==gl].iloc[0]
        ax.plot([lower['err'],upper['err']],[lower['E'],upper['E']],lw=width*2.5+0.1)

    ax.set_xlim([0,1])
    adjust_text(text,ax=ax)
    ax.set_ylabel('Energy (K)')
    ax.axes.xaxis.set_ticklabels([])
    ax.tick_params(top=False,bottom=False)

    print('Random seed = ',seed)
    return(fig,ax,seed)

def plot_lines(dat, normalise = False, fig=None, ax=None, overplot=False, c=None, cmap=None, colors=None, figsize=None, NLTE=False, label='', lw=1, scaling=1,offset=0):
    """
    This function plots total line fluxes (erg/s/cm2/sr)

    Each line is drawn as a vertical bar at wavelength ``sci_c/GHz*1e-3``
    with height equal to its flux.

    Parameters
    ----------
    dat : slab, slab_data or list of slab
    normalise : bool
        Divide each model's fluxes by their maximum.
    fig, ax, figsize :
        Passed to ``_get_set_fig``.
    overplot : bool
        Draw all models on one axis, coloured from ``cmap`` (default
        ``'jet'``). Otherwise ``_get_set_fig`` is called once per model and
        every model uses colour ``c`` (default ``'k'``).
    colors : str or list, optional
        Colour(s) per model; overrides ``cmap`` and ``c`` when given.
    NLTE : bool
        Plot ``FNLTE`` instead of ``FLTE``.
    label : str or list
        Legend label(s); ``''`` labels the models by their index.
    lw : float
        Line width.
    scaling : number or list
        Multiplicative factor per model.
    offset : number or list
        Vertical offset per model; a single number ``o`` shifts model ``i``
        by ``o * i``.

    Returns
    -------
    (fig, ax) if ``overplot``, otherwise a list of ``(fig, ax)``, one per model.

    Raises
    ------
    ValueError
        On wrong input types or per-model lists shorter than the model list.
    """
    if isinstance(dat,slab_data):
        dat = dat.models
    elif isinstance(dat,slab):
        dat = [dat]
    elif not isinstance(dat,list):
        raise ValueError('Wrong input slab data')
    n = len(dat)

    if isinstance(label,list):
        label_list = label
    elif label=='':
        # Unlabelled models are labelled by their index in ``dat``.
        label_list = list(range(n))
    else:
        label_list = [label]*n

    numeric = (int,float,np.integer,np.floating)
    sequence = (list,tuple,np.ndarray)
    if isinstance(offset,numeric):
        offset_list = [offset*i for i in range(n)]
    elif isinstance(offset,sequence):
        offset_list = list(offset)
    else:
        raise ValueError('offset takes only int, float or list of int or float')

    if isinstance(scaling,numeric):
        scaling_list = [scaling]*n
    elif isinstance(scaling,sequence):
        scaling_list = list(scaling)
    else:
        raise ValueError('scaling takes only int, float or list of int or float')

    if c is None: c='k'
    if colors is not None:
        # Explicit per-model colours take precedence over ``cmap`` and ``c``.
        color_list = [colors]*n if isinstance(colors,str) else list(colors)
    elif overplot:
        cmap = 'jet' if cmap is None else cmap
        scalarMap = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=max(n-1,0)), cmap=plt.get_cmap(cmap))
        color_list = [scalarMap.to_rgba(i) for i in range(n)]
    else:
        color_list = [c]*n

    for name,values in (('label',label_list),('offset',offset_list),('scaling',scaling_list),('colors',color_list)):
        if len(values) < n:
            raise ValueError(f'{name} has {len(values)} entries but {n} models were passed')

    flux_key = 'FNLTE' if NLTE else 'FLTE'

    def _draw(i, slb, fig_i, ax_i):
        # Wavelength axis from the GHz frequency column.
        return _basic_line_plot(sci_c/slb.linedata['GHz']*1e-3, slb.linedata[flux_key], scaling=scaling_list[i], normalise=normalise, fig=fig_i, ax=ax_i, figsize=figsize, label=label_list[i], color=color_list[i], lw=lw,offset=offset_list[i])

    if overplot:
        fig,ax = _get_set_fig(ax,fig,figsize)
        for i,slb in enumerate(dat):
            fig,ax = _draw(i, slb, fig, ax)
        return(fig,ax)
    return [_draw(i, slb, *_get_set_fig(ax,fig,figsize)) for i,slb in enumerate(dat)]
        
        
def _basic_line_plot(x,y,normalise=False,fig=None, ax=None, scaling=1, figsize=None, label='', color='k', lw=1,offset = 0.0):
    fig,ax = _get_set_fig(ax,fig,figsize)
    if x is None or y is None:
        raise ValueError('x or y is None')
    if normalise:
        y = y/np.amax(y)
    ax.vlines(x, 0.0+offset, y*scaling+offset, label = label, color=color, lw=lw)
    ax.set_xlabel('Wavelength [microns]')
    ax.legend()
    return(fig,ax)


def _spectra_stride(convR, sampling, npts):
    """Return the slice step used to thin a convolved spectrum of ``npts`` points."""
    # NOTE(review): 1e6/convR presumably ties the step to the convolution
    # resolution of the wavelength grid; confirm against slab.convolve().
    step = int(1e6/convR/sampling)
    if step > npts:
        step = npts
    if step == 0:
        step = 1
    return step


def plot_spectra(dat, normalise = False, fig=None, ax=None, overplot=False, add=False, cmap=None, colors=None, style='step', figsize=None, NLTE=False, label='', lw=1, c=None, scaling=1,sampling=1,offset=0):
    """
    This function plots convolved spectra (erg/s/cm2/sr)

    All models must already be convolved (``.convolve()``). Their
    ``convWavelength`` and ``convLTEflux``/``convNLTEflux`` arrays are drawn,
    keeping every ``int(1e6/convR/sampling)``-th point (at least every point,
    at most one point per grid length).

    Parameters
    ----------
    dat : slab, slab_data or list of slab
    normalise : bool
        Divide each spectrum by its maximum.
    fig, ax, figsize :
        Passed to ``_get_set_fig``.
    overplot : bool
        Draw all models on one axis, coloured from ``cmap`` (default
        ``'jet'``). Otherwise ``_get_set_fig`` is called once per model.
    add : bool
        Sum all models (which must share one wavelength grid) and draw the
        total once in colour ``c``. Takes precedence over ``overplot``.
    colors : str or list, optional
        Colour(s) per model; overrides ``cmap``/``c`` when not adding.
    style : str
        ``'step'`` for a step plot, anything else for a line plot.
    NLTE : bool
        Use ``convNLTEflux`` instead of ``convLTEflux``.
    label : str or list
        Legend label(s); ``''`` labels models by their index.
    lw, c :
        Line width and default colour (``'k'``).
    scaling, offset : number or list
        Per-model factor and vertical offset; a single offset ``o`` shifts
        model ``i`` by ``o * i``. In ``add`` mode only the first entry is used.
    sampling : number
        Larger values keep more points.

    Returns
    -------
    (fig, ax) when ``add`` or ``overplot``, otherwise a list of ``(fig, ax)``.

    Raises
    ------
    ValueError
        On wrong input type, an empty model list, unconvolved models,
        mismatched grids in ``add`` mode, or too-short per-model lists.
    """
    if isinstance(dat,slab_data):
        dat = dat.models
    elif isinstance(dat,slab):
        dat = [dat]
    elif not isinstance(dat,list):
        raise ValueError('Wrong input slab data')
    n = len(dat)
    if n == 0:
        raise ValueError('No slab models to plot')

    if isinstance(label,list):
        label_list = label
    elif label=='':
        # Unlabelled models are labelled by their index in ``dat``.
        label_list = list(range(n))
    else:
        label_list = [label]*n
    if c is None: c='k'

    numeric = (int,float,np.integer,np.floating)
    sequence = (list,tuple,np.ndarray)
    if isinstance(offset,numeric):
        offset_list = [offset*i for i in range(n)]
    elif isinstance(offset,sequence):
        offset_list = list(offset)
    else:
        raise ValueError('offset takes only int, float or list of int or float')

    if isinstance(scaling,numeric):
        scaling_list = [scaling]*n
    elif isinstance(scaling,sequence):
        scaling_list = list(scaling)
    else:
        raise ValueError('scaling takes only int, float or list of int or float')

    for d in dat:
        if d.convWavelength is None:
            raise ValueError('The model is not convolved, please use .convolve() method')
    flux_attr = 'convNLTEflux' if NLTE else 'convLTEflux'

    if add:
        wave = dat[0].convWavelength
        for d in dat:
            if not np.array_equal(wave, d.convWavelength):
                raise ValueError('Convolved wavelength grid not same for molecules within a model')
        # ``wave*0.0`` gives a zero float array on the shared grid to sum into.
        total = wave*0.0
        for d in dat:
            total += getattr(d, flux_attr)
        fig,ax = _get_set_fig(ax,fig,figsize)
        step = _spectra_stride(dat[0].convR, sampling, len(total))
        return _basic_spectra_plot(wave[::step], total[::step], scaling=scaling_list[0], offset=offset_list[0], normalise=normalise, fig=fig, ax=ax, style=style, figsize=figsize, label=label, color=c,lw=lw, where='mid')

    if colors is not None:
        # Explicit per-model colours take precedence over ``cmap`` and ``c``.
        color_list = [colors]*n if isinstance(colors,str) else list(colors)
    elif overplot:
        cmap = 'jet' if cmap is None else cmap
        scalarMap = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=max(n-1,0)), cmap=plt.get_cmap(cmap))
        color_list = [scalarMap.to_rgba(i) for i in range(n)]
    else:
        color_list = [c]*n

    for name,values in (('label',label_list),('offset',offset_list),('scaling',scaling_list),('colors',color_list)):
        if len(values) < n:
            raise ValueError(f'{name} has {len(values)} entries but {n} models were passed')

    def _draw(i, d, fig_i, ax_i):
        flux = getattr(d, flux_attr)
        step = _spectra_stride(d.convR, sampling, len(flux))
        return _basic_spectra_plot(d.convWavelength[::step], flux[::step], scaling=scaling_list[i], offset=offset_list[i],  normalise=normalise, fig=fig_i, ax=ax_i, style=style, figsize=figsize, label=label_list[i], color=color_list[i],lw=lw, where='mid')

    if overplot:
        fig,ax = _get_set_fig(ax,fig,figsize)
        for i,d in enumerate(dat):
            fig,ax = _draw(i, d, fig, ax)
        return(fig,ax)
    return [_draw(i, d, *_get_set_fig(ax,fig,figsize)) for i,d in enumerate(dat)]
            
        
def _basic_spectra_plot(x, y, normalise=False, fig=None, ax=None, scaling=1, offset=0.0, figsize=None, label='', color='k', style='step', lw=1, where='post'):
    if x is None or y is None:
        raise ValueError('x or y is None, mostly the data is not convolved. Please use .convolve() method first')
    fig,ax = _get_set_fig(ax,fig,figsize)
    if normalise:
        y = y/np.amax(y)
    if style=='step':
        ax.step(x, y*scaling+offset, lw=lw, label = label, color=color, where=where)
    else:
        ax.plot(x, y*scaling+offset, lw=lw, label = label, color=color)
    ax.set_xlabel('Wavelength [microns]')
    ax.legend()
    return(fig,ax)