#!/usr/bin/env python

#=========================================
#
# Analyze results from Spirit
# Results of spin configurations
# mainly stored in ovf format
#
# Shunhong Zhang
#=========================================

import numpy as np
from asd.utility.spirit_tool import *
from asd.utility.asd_arguments import *
import os
import glob
import pickle
import matplotlib.pyplot as plt

import_spirit_err='''
\nNotes from {0}
fail to import modules from Spirit
{0} cannot be used
other scirpts are not affected
Please install Spirit properly
if you want ot use this script\n'''.format(os.path.basename(__file__))

# Only an unusable Spirit installation should produce the friendly message
# above; a bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
# OSError is included on the assumption that the bindings may fail while
# loading a compiled library rather than at import resolution (TODO confirm).
try:
    from spirit import state,system,geometry,parameters
except (ImportError, OSError):
    raise SystemExit(import_spirit_err)


def collect_confs_from_ovfs(outdir,prefix,parse_ovf_method='pyasd'):
    """
    Load spin configurations from numbered Spirit snapshot files.

    Files matching ``{outdir}/{prefix}*Spins_<N>.ovf`` are read in ascending
    order of the integer N (the part after the last underscore), so that
    e.g. ``..._Spins_9.ovf`` comes before ``..._Spins_10.ovf`` regardless of
    zero padding. Files whose suffix is not an integer are ignored.

    Parameters
    ----------
    outdir : str
        directory containing the ovf files
    prefix : str
        file-name prefix of the Spirit run; only files with this prefix are used
    parse_ovf_method : {'pyasd', 'ovf'}
        'pyasd' parses with ``parse_ovf``, 'ovf' with ``parse_ovf_1``

    Returns
    -------
    numpy.ndarray
        the configurations (element [1] of the parser's result) stacked
        along a new leading axis

    Raises
    ------
    ValueError
        if parse_ovf_method is not recognized
    SystemExit
        if no numbered ovf file is found
    """
    if parse_ovf_method not in ('pyasd','ovf'):
        raise ValueError('parse_ovf_method should be "pyasd" or "ovf", got {!r}'.format(parse_ovf_method))

    indexed = []
    for fil in glob.glob('{}/{}*Spins_*.ovf'.format(outdir,prefix)):
        # '<...>_Spins_00012.ovf' -> '00012'
        tag = os.path.splitext(os.path.basename(fil))[0].split('_')[-1]
        try:
            indexed.append((int(tag),fil))
        except ValueError:
            continue   # not a numbered snapshot
    if len(indexed)==0:
        raise SystemExit('Cannot find ovf files in the directory\n{}'.format(outdir))
    indexed.sort()

    parser = parse_ovf_1 if parse_ovf_method=='ovf' else parse_ovf
    confs = [parser(fil)[1] for _,fil in indexed]
    return np.array(confs)
 

def get_GNEB(outdir):
    """
    Read the interpolated energy profile of a finished GNEB calculation.

    Parses the first (alphabetically) ``*Chain_Energies-interpolated-final.txt``
    in *outdir*, skips its 3 header lines, and takes column index 4 as the
    reaction coordinate and column index 6 as the total energy.

    Returns
    -------
    Rx : numpy.ndarray
        reaction coordinates
    Etot : numpy.ndarray
        total energies shifted so that the first entry is zero

    Raises
    ------
    FileNotFoundError
        if no matching file exists in *outdir*
    """
    fils = sorted(glob.glob('{}/*Chain_Energies-interpolated-final.txt'.format(outdir)))
    if len(fils)==0:
        raise FileNotFoundError('Cannot find *Chain_Energies-interpolated-final.txt under {}'.format(outdir))
    with open(fils[0]) as fh:
        lines = fh.readlines()[3:]
    # blank (e.g. trailing) lines have no columns to index
    data = [line.split() for line in lines if line.strip()]
    Rx = np.array([d[4] for d in data],float)
    Etot = np.array([d[6] for d in data],float)
    Etot -= Etot[0]
    return Rx,Etot


def plot_GNEB(outdir,show=False):
    """
    Plot the GNEB energy profile returned by ``get_GNEB(outdir)``.

    Every 10th point and the last point are marked in red.

    Parameters
    ----------
    outdir : str
        directory holding the GNEB output
    show : bool
        call ``plt.show()`` after drawing

    Returns
    -------
    fig, ax
        the matplotlib Figure and Axes
    """
    Rx,Etot = get_GNEB(outdir)
    fig,ax=plt.subplots(1,1)
    ax.plot(Rx,Etot)
    ax.scatter(Rx[::10],Etot[::10],c='r')
    ax.scatter(Rx[-1],Etot[-1],c='r')
    emin = np.min(Etot)
    emax = np.max(Etot)
    erange = emax - emin
    if erange<1e-4:
        # (nearly) flat profile: a range-proportional margin would give
        # (nearly) identical limits, which matplotlib warns about
        ax.set_ylim(-0.1,0.1)
    else:
        ax.set_ylim(emin-erange*0.1,emax+erange*0.1)
    ax.set_ylabel('E (meV)')
    ax.set_xlabel('Reaction coord')
    fig.tight_layout()
    if show: plt.show()
    return fig,ax


def get_params_from_cfg(fil_cfg,nx=0,ny=0,nz=0,lat_type=1,dt=0):
    """
    Build a Spirit state from an input file and extract geometry parameters.

    Side effect: writes the atomic positions to ``Positions.dat`` in the
    current working directory.

    Parameters
    ----------
    fil_cfg : str
        Spirit input file; when '' the Bravais lattice type is set to *lat_type*
    nx, ny, nz : int
        if all nonzero, override the number of unit cells
    lat_type : int
        Bravais lattice type, only used when fil_cfg == ''
    dt : float
        time step; if 0 it is read from the LLG parameters of the state

    Returns
    -------
    latt, pos, nx, ny, nz, nat, dt
        latt is the Bravais vectors scaled by the ``lattice_constant`` found
        in fil_cfg (1 if absent or unreadable)
    """
    lattice_constant = 1
    # Read the value directly instead of shelling out to grep: no subprocess,
    # safe for paths with spaces, and no blocking read of stdin when fil_cfg
    # is '' (grep without a file operand reads standard input).
    if fil_cfg:
        try:
            with open(fil_cfg) as fh:
                for line in fh:
                    if 'lattice_constant' in line:
                        # first matching line, last token (as grep|readline did)
                        lattice_constant = float(line.split()[-1])
                        break
        except (OSError, ValueError, IndexError):
            pass   # best effort: keep the default of 1
    with state.State(fil_cfg,quiet=True) as p_state:
        if fil_cfg=='': geometry.set_bravais_lattice_type(p_state, lat_type)
        if nx*ny*nz!=0: geometry.set_n_cells(p_state,[nx,ny,nz])
        pos=geometry.get_positions(p_state)
        nx,ny,nz=geometry.get_n_cells(p_state)
        nat=geometry.get_n_cell_atoms(p_state)
        latt=geometry.get_bravais_vectors(p_state)
        if dt==0: dt = parameters.llg.get_timestep(p_state)
        latt = lattice_constant*np.array(latt)
        np.savetxt('Positions.dat',pos,fmt='%10.5f')
        return latt,pos,nx,ny,nz,nat,dt


def analyze_GNEB_results(scatter_size=10,parse_ovf_method='pyasd',outdir='.',fil_cfg=''):
    """
    Visualize the results of a Spirit GNEB calculation.

    Plots the energy profile, the spin configurations of the initial and
    final images, and an animation of the chain stored in the first
    (alphabetically) ``*Chain*final.ovf`` file.

    Parameters
    ----------
    scatter_size : float
        marker size forwarded to plot_spin_2d / make_ani
    parse_ovf_method : {'pyasd', 'ovf'}
        ovf parser: 'pyasd' -> parse_ovf, 'ovf' -> parse_ovf_1
    outdir : str
        directory holding the Spirit output
    fil_cfg : str
        Spirit input file passed to get_params_from_cfg

    Raises
    ------
    ValueError
        if parse_ovf_method is not recognized
    """
    if parse_ovf_method=='ovf':     parse = parse_ovf_1
    elif parse_ovf_method=='pyasd': parse = parse_ovf
    else: raise ValueError('parse_ovf_method should be "pyasd" or "ovf", got {!r}'.format(parse_ovf_method))

    def read_iterations(fil):
        # last token of every line mentioning 'Iteration' (what `grep Iteration`
        # printed), read without a shell; binary mode avoids decoding binary ovf data
        with open(fil,'rb') as fh:
            return np.array([int(line.split()[-1]) for line in fh if b'Iteration' in line],int)

    plot_GNEB(outdir)

    fil_ovf_init = glob.glob('{}/*Spins-initial.ovf'.format(outdir))[0]
    fil_ovf_finl = glob.glob('{}/*Spins-final.ovf'.format(outdir))[0]

    print ('\nExtracting configurations for first and last images from these files')
    print (fil_ovf_init)
    print (fil_ovf_finl)
    # Header parameters come from the initial image (parsed with
    # parse_params=True); the final image is parsed without it, so its
    # params must not replace them.
    params,spins_init = parse(fil_ovf_init,parse_params=True)
    spins_final = parse(fil_ovf_finl)[1]

    nx = params['xnodes']
    ny = params['ynodes']
    nz = params['znodes']

    latt,pos,nx,ny,nz,nat,dt = get_params_from_cfg(fil_cfg,nx,ny,nz)

    print (spins_init.shape,spins_final.shape,scatter_size)
    plot_spin_2d(pos,spins_init, scatter_size=scatter_size,title='initial')
    plot_spin_2d(pos,spins_final,scatter_size=scatter_size,title='final'  )
    plt.show()

    fil = sorted(glob.glob('{}/*Chain*final.ovf'.format(outdir)))[0]
    print ('confs from file {}'.format(fil))
    titles = ['Iter = {:10d}'.format(tt) for tt in read_iterations(fil)]
    confs = parse(fil,parse_params=False)[1]
    make_ani(pos,confs,scatter_size=scatter_size,titles=titles,interval=5e3)


def get_energy_from_txt(start_conf=0,outdir='.',prefix='temp'):
    """
    Read the energy trajectory from ``{outdir}/{prefix}*Energy-archive.txt``.

    The first (alphabetically) matching file is used; its 3 header lines
    are skipped.

    Parameters
    ----------
    start_conf : int
        number of data rows (after the header) to skip
    outdir : str
        directory to search
    prefix : str
        file-name prefix of the Spirit run

    Returns
    -------
    iters : numpy.ndarray
        first column, as float
    ens : numpy.ndarray
        third column, as float

    Raises
    ------
    FileNotFoundError
        if no matching file exists
    """
    fils = sorted(glob.glob('{}/{}*Energy-archive.txt'.format(outdir,prefix)))
    if len(fils)==0:
        raise FileNotFoundError('Energy file not found in {} (prefix {})'.format(outdir,prefix))
    with open(fils[0]) as fh:
        lines = fh.readlines()[3:][start_conf:]
    # blank (e.g. trailing) lines have no columns to index
    rows = [line.split() for line in lines if line.strip()]
    iters = np.array([row[0] for row in rows],float)
    ens = np.array([row[2] for row in rows],float)
    return iters,ens


def _llg_iterations_from_ovf(fil):
    """Return, as an int array, the last token of every line of *fil* that mentions 'Iteration'."""
    # Same output as the former `grep Iteration` pipeline, but without a shell
    # (safe for paths with spaces); binary mode avoids decoding binary ovf data.
    with open(fil,'rb') as fh:
        return np.array([int(line.split()[-1]) for line in fh if b'Iteration' in line],int)


def _plot_llg_energies(iters,dt,ens,show=True):
    """Plot energy vs time (one panel per run) and save it as 'LLG_energies_profile'."""
    if dt<=0: raise ValueError('Time step should be positive!')
    iters = np.asarray(iters)
    ens = np.asarray(ens)
    # A drop of the iteration counter marks a new run appended to the same
    # file; restarts[k] is the first row index of the (k+2)-th run.
    restarts = np.where(np.diff(iters)<0)[0]+1
    if len(restarts)>0:
        print ('Iters are non-monotonic.')
        print ('This is usually because you have multiple runs')
        print ('And store the energies in the same file')
    bounds = np.concatenate(([0],restarts,[len(iters)])).astype(int)
    # squeeze=False: always a 2D array of Axes, even with a single panel
    fig,axes = plt.subplots(len(bounds)-1,1,sharex=True,squeeze=False)
    axes = axes[:,0]
    for ax,start,stop in zip(axes,bounds[:-1],bounds[1:]):
        ax.plot(iters[start:stop]*dt,ens[start:stop])
        ax.set_ylabel('E (meV/site)')
    axes[-1].set_xlabel('t (ps)')
    fig.tight_layout()
    fig.savefig('LLG_energies_profile',dpi=300)
    if show: plt.show()


def _load_llg_confs(args,outdir,dt,nx,ny,nat):
    """Collect LLG snapshots from a pickle, the ovf archive or numbered ovf files; return (confs, titles or None)."""
    archive = glob.glob('{}/{}*Spins-archive.ovf'.format(outdir,args.prefix))
    titles = None
    if args.pick_confs:
        if not os.path.isfile(args.confs_pickle):
            print ('You set pick_confs=True, but {} not found'.format(args.confs_pickle))
            print ('Use --dump_confs to dump the configurations to the pickle file')
            print ('Next time you can pick up them directly')
            raise SystemExit('See you!')
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # pickles you produced yourself (e.g. with --dump_confs).
        with open(args.confs_pickle,'rb') as fh:
            confs = pickle.load(fh)
        if len(archive)==1:
            titles = ['t = {:10.4f} ps'.format(tt) for tt in _llg_iterations_from_ovf(archive[0])*dt]
        return confs,titles

    if len(archive)==1:
        fil = archive[0]
        print ('\nSpin configs from achive file {}'.format(fil))
        Iter = _llg_iterations_from_ovf(fil)
        if args.parse_ovf_method=='ovf':
            confs = parse_ovf_1(fil,parse_params=False)[1]
        elif args.parse_ovf_method=='pyasd':
            confs = parse_ovf(fil,parse_params=False)[1]
        else:
            raise ValueError('parse_ovf_method should be "pyasd" or "ovf", got {!r}'.format(args.parse_ovf_method))
        titles = ['t = {:10.4f} ps'.format(tt) for tt in Iter*dt]
    else:
        confs = collect_confs_from_ovfs(outdir,args.prefix,parse_ovf_method=args.parse_ovf_method)
    confs = confs.reshape(-1,ny,nx,nat,3)
    if args.dump_confs:
        print ('Dump configurations to {}'.format(args.confs_pickle))
        with open(args.confs_pickle,'wb') as fh:
            pickle.dump(confs,fh)
    return confs,titles


def analyze_LLG_results(args,lat_type=1,quiver_kws=None):
    """
    Post-process a Spirit LLG run.

    Reads the geometry from the Spirit input file, plots the energy
    trajectory (best effort), displays the initial/final spin configurations
    and, if ``args.make_ani``, animates the stored snapshots (annotated with
    topological charges when ``args.topo_chg``).

    Parameters
    ----------
    args : argparse.Namespace
        options from gen_args (outdir, prefix, spirit_input_file, nx, ny, nz,
        dt, parse_ovf_method, make_ani, pick_confs, confs_pickle, dump_confs,
        topo_chg, plus the spin plot/animation options)
    lat_type : int
        Bravais lattice type forwarded to get_params_from_cfg
    quiver_kws : dict or None
        keyword arguments for the quiver arrows in plots and animations

    Returns
    -------
    int
        always 1
    """
    spin_plot_kwargs = get_spin_plot_kwargs(args)
    spin_anim_kwargs = get_spin_anim_kwargs(args)
    spin_plot_kwargs.update(quiver_kws=quiver_kws)
    spin_anim_kwargs.update(quiver_kws=quiver_kws)

    outdir = args.outdir
    if outdir=='.': outdir=os.getcwd()
    print ('spirit input file: {}\n'.format(args.spirit_input_file))
    latt,pos,nx,ny,nz,nat,dt = get_params_from_cfg(args.spirit_input_file,args.nx,args.ny,args.nz,lat_type=lat_type,dt=args.dt)
    print ('\nGet timestep = {:8.4f} ps from {}'.format(dt,args.spirit_input_file))
    print ('If this is incorrect, you can use the option like')
    print ('"--dt=0.01" (in ps) to overwrite the time step.\n')

    def display_snapshot(pos,latt,outdir,prefix,status='initial',show=False):
        # plot the '<prefix>*Spins-<status>.ovf' snapshot if it exists
        superlatt = np.dot(np.diag([nx,ny]),latt)
        fils=glob.glob('{}/{}*Spins-{}.ovf'.format(outdir,prefix,status))
        if len(fils)==0:
            print ('\nFile for display not found at\n{}'.format(outdir))
            return
        fil_ovf = fils[0]
        print ('\nDisplay config of file {}'.format(fil_ovf))
        if args.parse_ovf_method=='ovf':
            spins = parse_ovf_1(fil_ovf)[1]
        elif args.parse_ovf_method=='pyasd':
            spins = parse_ovf(fil_ovf)[1]
        else:
            raise ValueError('parse_ovf_method should be "pyasd" or "ovf", got {!r}'.format(args.parse_ovf_method))
        # reshape as (ny, nx, nat, 3) and swap to (nx, ny, nat, 3), the layout of pos
        conf = np.swapaxes(spins.reshape(ny,nx,nat,3),0,1)
        spin_plot_kwargs.update(title=status,superlatt=superlatt,
        show=show,save=True,figname='{}_spin_conf'.format(status))
        plot_spin_2d(pos,conf, **spin_plot_kwargs)

    try:
        iters,ens = get_energy_from_txt(0,outdir,args.prefix)
        _plot_llg_energies(iters,dt,ens)
    except Exception as err:
        # best effort: the energy plot is optional, keep going without it
        print ('Fail to read energy from txt file, skip plotting energy')
        print ('Reason: {}'.format(err))

    latt = np.array(latt)[:2,:2].T
    pos = np.loadtxt('Positions.dat')   # written by get_params_from_cfg
    if nx*ny*nz!=0:
        pos = np.swapaxes(pos.reshape(ny,nx,-1,3),0,1)
        display_snapshot(pos,latt,outdir,args.prefix,'initial')
        display_snapshot(pos,latt,outdir,args.prefix,'final',show=True)
    os.remove('Positions.dat')

    if args.make_ani:
        confs,titles = _load_llg_confs(args,outdir,dt,nx,ny,nat)
        # (frame, ny, nx, nat, 3) -> (frame, nx, ny, nat, 3), the layout of pos;
        # done before the topological charge so spins and positions line up
        confs = np.swapaxes(confs,1,2)
        if args.topo_chg:
            from asd.core.topological_charge import calc_topo_chg
            print ('Set topo_chg=T, calculate the topological charge')
            print ('This might take several minutes')
            Qs = [calc_topo_chg(conf,pos[...,:2]) for conf in confs]
            if titles is None:
                titles = ['Q = {:6.3f}'.format(Q) for Q in Qs]
            else:
                titles = ['{}, Q = {:6.3f}'.format(tl,Q) for (tl,Q) in zip(titles,Qs)]
        spin_anim_kwargs.update(latt=latt)
        make_ani(pos, confs, titles=titles, **spin_anim_kwargs)

    return 1


def gen_args():
    """Build the command-line parser for Spirit post-processing and parse ``sys.argv``."""
    import argparse
    parser = argparse.ArgumentParser(
        prog='analyze_Spirit_results.py',
        description='post-processing of Spirit LLG simulations')
    # registration order fixes the order of option groups in --help
    option_groups = (
        add_switch_arguments,
        add_llg_arguments,
        add_spirit_arguments,
        add_quiver_arguments,
        add_spin_plot_arguments,
        add_common_arguments,
    )
    for register in option_groups:
        register(parser)
    return parser.parse_args()


if __name__=='__main__':
    from asd.utility.head_figlet import pkg_info
    code_info = pkg_info()
    code_info.verbose_head()

    args = gen_args()
    # 'quiver_<name>' options become quiver keyword '<name>'. Split only once so
    # that option names with more underscores keep their full keyword instead
    # of being truncated (and possibly colliding with another option).
    quiver_kws = {k.split('_',1)[1]:v for (k,v) in vars(args).items() if k.startswith('quiver')}
    quiver_kws.update(pivot='mid',units='x')
    if args.verbose_qv_kws:
        print ('keyword arguments for quivers\n')
        for key,val in quiver_kws.items(): print ('{:>15s} = {}'.format(key,val))

    print ('\ntask = {}\n'.format(args.job))
    if args.job=='llg':     analyze_LLG_results(args,quiver_kws=quiver_kws)
    elif args.job=='gneb':  analyze_GNEB_results(parse_ovf_method=args.parse_ovf_method)
    else: print ('Unknown job {}, nothing to do'.format(args.job))
