class ALD:
    '''Reader and post-processor for the output folder of an ALD run.

    All files are expected exactly as written by the ALD code. Unless a method
    says otherwise (e.g. the cumulative-conductivity helpers, which return
    their data sorted), per-mode arrays follow the row order of
    harmonic_properties.dat.
    '''

    # Signed decimal / integer with an optional exponent ("-1.5e-05"). str() of
    # a small float uses exponent notation, so the exponent must be part of the
    # match; otherwise one value would be split into two numbers.
    _NUMBER_PATTERN = r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?"

    def __init__(self,
                 outdir='./'
                 ):
        '''Load the primitive cell, harmonic properties and eigenvectors.

        Parameters
        ----------
        outdir : str
            Path to the ALD output folder.
        '''
        self.outdir = outdir
        self.primcell = self.load_data(f'{self.outdir}/cell_primitive.dat', head=1)
        self.harmprop = self.load_data(f"{self.outdir}/harmonic_properties.dat", head=1)
        self.eigdata = self.load_data(f'{self.outdir}/eigenVector.dat', head=1)
        self.natoms = len(self.get_scaled_positions())
        # harmonic_properties.dat holds one row per (q-point, mode) pair, with
        # 3*natoms modes per q-point.
        self.nqpts = len(self.harmprop) // (3 * self.natoms)
        # This previously called the non-existent ``get_parameters`` (AttributeError).
        self.nlength = float(self.get_parameter('nlength'))

    def get_parameter(self, parameter):
        '''Return the value of *parameter* from parameters.dat as a string.

        The first line containing *parameter* (case-insensitive substring
        match) is used. The value is the text after the line's last ':' with
        surrounding whitespace removed. If no line matches, a message is
        printed and 0 is returned.
        '''
        with open(f'{self.outdir}/parameters.dat', 'r') as f:
            lines = f.readlines()

        key = parameter.lower()
        for line in lines:
            if key in line.lower():
                # strip() replaces the old [:-1], which removed a real character
                # when the final line had no trailing newline.
                return line.split(':')[-1].strip()

        print(f'{parameter} is not specified in parameters.dat file')
        return 0

    def load_data(self, filename, head=0):
        '''Read a whitespace-separated table into a numpy array.

        The first *head* lines are skipped. Tokens that parse as float become
        floats; others stay strings. Short rows are padded with None to the
        longest row's length, so mixed/ragged tables give an object array.
        A file with no data rows yields an empty array.
        '''
        import numpy as np

        def convert(token):
            try:
                return float(token)
            except ValueError:
                return token

        with open(filename, 'r') as f:
            rows = [[convert(tok) for tok in line.split()] for line in f.readlines()[head:]]

        width = max(map(len, rows), default=0)
        return np.array([row + [None] * (width - len(row)) for row in rows])

    def flatten(self, list_of_lists):
        '''Flatten arbitrarily nested lists into one list, depth-first.

        Only ``list`` instances are expanded. An empty input is returned as-is.
        Iterative, so deep nesting cannot hit the recursion limit.
        '''
        if len(list_of_lists) == 0:
            return list_of_lists
        flat = []
        stack = [iter(list_of_lists)]
        while stack:
            for item in stack[-1]:
                if isinstance(item, list):
                    stack.append(iter(item))
                    break
                flat.append(item)
            else:
                stack.pop()
        return flat

    def _parse_numbers(self, row):
        '''Return every number embedded in the tokens of *row*, in order, as floats.'''
        import re
        return [float(tok)
                for item in row
                for tok in re.findall(self._NUMBER_PATTERN, str(item))]

    def get_cell(self):
        '''Return the 3x3 cell from cell_primitive.dat (first three rows) scaled by ``nlength``.'''
        import numpy as np
        return np.asarray(self.primcell[0:3, 0:3], dtype=float) * self.nlength

    def _qpt_columns(self, start, stop):
        '''Return columns start:stop of the first harmonic_properties row of each q-point.'''
        import numpy as np
        qpts = []
        for row in self.harmprop:
            # Rows are grouped by the q-point index in column 0; keep a row
            # whenever that index differs from the index of the last kept q-point.
            if not qpts or row[0] != len(qpts) - 1:
                qpts.append(row[start:stop])
        return np.array(qpts, dtype=float)

    def get_direct_qpts(self):
        '''Return direct q-points (columns 2-4 of harmonic_properties.dat), in file order.'''
        return self._qpt_columns(2, 5)

    def get_cartesian_qpts(self):
        '''Return cartesian q-points (columns 5-7 of harmonic_properties.dat), in file order.'''
        return self._qpt_columns(5, 8)

    def get_elements(self):
        '''Return the element symbols (column 1 of the atom rows) from cell_primitive.dat.'''
        return self.primcell[3:, 1]

    def get_scaled_positions(self):
        '''Return columns 2-4 of the atom rows of cell_primitive.dat (scaled positions).'''
        return self.primcell[3:, 2:5]

    def get_masses(self):
        '''Return column 8 of the atom rows of cell_primitive.dat (masses).'''
        return self.primcell[3:, 8]

    def get_cartesian_positions(self):
        '''Return columns 5-7 of the atom rows of cell_primitive.dat (cartesian positions).'''
        return self.primcell[3:, 5:8]

    def get_eigvecs(self):
        '''Return (real, imaginary) eigenvector arrays of shape (nqpts, 3*natoms, 3*natoms).

        Each row of eigenVector.dat (as loaded at construction) is parsed into
        numbers. Values 0 and 1 are the q-point and mode indices. From value 3
        on they alternate real, imaginary, and are stored at [q, mode].
        '''
        import numpy as np
        nmodes = 3 * self.natoms
        real = np.zeros((self.nqpts, nmodes, nmodes))
        imaginary = np.zeros((self.nqpts, nmodes, nmodes))
        for row in self.eigdata:
            nums = self._parse_numbers(row)
            qid, modeid = int(nums[0]), int(nums[1])
            real[qid, modeid] = nums[3::2]
            imaginary[qid, modeid] = nums[4::2]
        return real, imaginary

    def get_frequencies(self):
        '''Return column 8 of harmonic_properties.dat (mode frequencies).'''
        return self.harmprop[:, 8]

    def get_heat_capacities(self):
        '''Return column 9 of harmonic_properties.dat as floats, with NaN replaced by 0.

        A copy is returned; ``self.harmprop`` is left unchanged.
        '''
        import numpy as np
        cps = np.array(self.harmprop[:, 9], dtype=float)
        cps[np.isnan(cps)] = 0
        return cps

    def get_group_velocities(self):
        '''Return columns 10-12 of harmonic_properties.dat as an (N, 3) float array, NaN -> 0.

        A copy is returned; ``self.harmprop`` is left unchanged.
        '''
        import numpy as np
        vgs = np.array(self.harmprop[:, 10:13], dtype=float)
        vgs[np.isnan(vgs)] = 0
        return vgs

    def get_gruneissen_parameters(self):
        '''Return columns 14 onwards of harmonic_properties.dat.'''
        return self.harmprop[:, 14:]

    def get_mean_free_paths(self):
        '''Return |group velocity| * lifetime for every mode.'''
        import numpy as np
        return np.linalg.norm(self.get_group_velocities(), axis=1) * self.get_life_times()

    def get_life_times(self, index=-1):
        '''Return column *index* of phonon_RTA.dat as floats, with +/-inf replaced by 0.'''
        import numpy as np
        rta = self.load_data(f'{self.outdir}/phonon_RTA.dat', head=1)
        times = np.array(rta[:, index], dtype=float)
        times[np.isinf(times)] = 0
        return times

    def get_band_structure(self):
        '''Read directional_harmonic_properties.dat.

        Returns
        -------
        distances, bands : ndarray, shape (3*natoms, npoints)
            Columns 8 and 9, indexed [band (column 1), point (column 0)].
        gvs, gps : ndarray, shape (3*natoms, npoints, 3)
            Columns 10-12, and columns 14 up to (excluding) the last one.
            Rows whose gps columns cannot be stored as 3 floats keep zeros.
        '''
        import numpy as np
        data = self.load_data(f'{self.outdir}/directional_harmonic_properties.dat', head=1)
        nbands = 3 * self.natoms
        npoints = len(data) // nbands
        distances = np.zeros((nbands, npoints))
        bands = np.zeros((nbands, npoints))
        gvs = np.zeros((nbands, npoints, 3))
        gps = np.zeros((nbands, npoints, 3))

        for row in data:
            point, band = int(float(row[0])), int(float(row[1]))
            bands[band, point] = row[9]
            distances[band, point] = row[8]
            gvs[band, point] = row[10:13]
            try:
                gps[band, point] = row[14:-1]
            except (TypeError, ValueError):
                # Missing (None) or wrongly-sized Grueneisen columns are optional.
                pass
        return distances, bands, gvs, gps

    def plot_bands(self, npoints=100, filename=None):
        '''Plot the band structure; optionally save it to *filename* at 600 dpi.

        Dotted vertical lines are drawn at every (npoints+1)-th distance of the
        first band. Returns (fig, ax).
        '''
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        dists, bands, _, _ = self.get_band_structure()
        ymax = max(bands[-1]) + 1
        for band_dists, band_freqs in zip(dists, bands):
            ax.plot(band_dists, band_freqs, 'k-')
        # Limits are loop-invariant; set them once after plotting.
        ax.set_xlim([0, max(dists[0])])
        ax.set_ylim([0, ymax])
        nsegments = len(bands[0]) // npoints
        ax.vlines(x=[dists[0][j * npoints + j] for j in range(1, nsegments)],
                  ymin=0, ymax=ymax, linestyle='dotted')
        ax.set_xticks([])
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def get_voronoi_volumes(self):
        '''Return Voronoi cell volumes of the atoms computed with freud.

        Prints a hint and returns None when freud is not installed. Other
        errors are no longer masked behind that hint.
        '''
        import numpy as np
        try:
            from freud.box import Box
            from freud.locality import Voronoi
        except ImportError:
            print("Please install freud python package")
            return None

        box = Box.from_matrix(self.get_cell())
        positions = np.asarray(self.get_cartesian_positions(), dtype=float)
        return Voronoi().compute((box, positions)).volumes

    def write_modes(self, qid=0, mode=0, nframes=100, repeat=(1, 1, 1), factor=1):
        '''Write ``<outdir>/q{qid}_m{mode}.traj`` animating one phonon mode (viewable with ase-gui).

        Frame k displaces the atoms by factor * Re(exp(i*phi_k) * e), with
        phi_k evenly spaced over [0, 2*pi). Velocities are set to half the
        displacement, and each frame is repeated by *repeat*. Returns the last
        frame's displacements (zeros when nframes is 0).
        '''
        from ase import Atoms
        from ase.io.trajectory import Trajectory
        import numpy as np

        nums = self._parse_numbers(self.eigdata[qid * 3 * self.natoms + mode])
        re_parts = np.asarray(nums[3::2])
        im_parts = np.asarray(nums[4::2])
        # np.complex was removed in NumPy 1.24; build the complex vector directly.
        eigvec = np.zeros(3 * self.natoms, dtype=complex)
        eigvec[:len(re_parts)] = re_parts + 1j * im_parts
        eigvec = eigvec.reshape(self.natoms, -1)

        cell = self.get_cell()
        elements = self.get_elements()
        positions = np.asarray(self.get_cartesian_positions(), dtype=float)

        disps = np.zeros((self.natoms, 3))
        traj = Trajectory(f'{self.outdir}/q{qid}_m{mode}.traj', 'w')
        try:
            for phase in np.linspace(0, 2 * np.pi, nframes, endpoint=False):
                disps = (np.exp(1.j * phase) * eigvec).real * factor
                atoms = Atoms(symbols=elements, positions=positions + disps, cell=cell, pbc=True)
                atoms.set_velocities(disps / 2)
                traj.write(atoms * repeat)
        finally:
            traj.close()

        return disps

    def get_pdf(self, x, w, bandwidth=0.05, xmin=0, xmax=None, npts=1000):
        '''Weighted Gaussian KDE of *x* evaluated on ``linspace(xmin, xmax, npts)``.

        *bandwidth* is used directly as the KDE bandwidth factor, and *xmax*
        defaults to ``max(x)``. Returns (grid, density).
        '''
        from scipy.stats import gaussian_kde
        import numpy as np

        # A scalar bw_method is used directly as the covariance factor. This is
        # what the former covariance_factor override + _compute_covariance() did.
        kde = gaussian_kde(x, bw_method=bandwidth, weights=w)
        if xmax is None:
            xmax = np.max(x)
        grid = np.linspace(xmin, xmax, npts)
        # The old code also built a copy with its last point zeroed but returned
        # a fresh evaluation; that returned (unzeroed) density is kept.
        return grid, kde(grid)

    def get_pdos(self, bandwidth=0.05, xmin=0, xmax=None, npts=1000):
        '''Return (grid, per-atom densities, total density).

        Each atom's density is a KDE of the mode frequencies, weighted by
        column 1+atom of pdos.dat. The total is their sum.
        '''
        import numpy as np

        natoms = self.natoms
        freqs = np.asarray(self.get_frequencies(), dtype=float)
        weights = self.load_data(f'{self.outdir}/pdos.dat', head=1)[:, 1:natoms + 1].T

        spdos = []
        for i in range(natoms):
            sfreqs, density = self.get_pdf(freqs, weights[i], xmin=xmin, xmax=xmax,
                                           npts=npts, bandwidth=bandwidth)
            spdos.append(density)

        spdos = np.array(spdos)
        return np.array(sfreqs), spdos, spdos.sum(axis=0)

    def get_scattering_phase_space3(self):
        '''Return column 2 of scattering_phase_space_3.dat (no header skipped).'''
        phase3 = self.load_data(f"{self.outdir}/scattering_phase_space_3.dat", head=0)
        return phase3[:, 2]

    def get_scattering_phase_space4(self):
        '''Return column 2 of scattering_phase_space_4.dat (no header skipped).'''
        phase4 = self.load_data(f"{self.outdir}/scattering_phase_space_4.dat", head=0)
        return phase4[:, 2]

    def get_kappas(self, indices=(1, 1)):
        '''Return per-mode C * tau * v[a] * v[b] with (a, b) = *indices*.

        No volume or unit prefactor is applied.
        '''
        vels = self.get_group_velocities()
        times = self.get_life_times()
        heats = self.get_heat_capacities()
        return heats * times * vels[:, indices[0]] * vels[:, indices[1]]

    def _cumulative_conductivity(self, keys, indices):
        '''Sort *keys* ascending; return (sorted keys, running sum of kappas in that order).'''
        import numpy as np
        order = np.argsort(keys)
        return keys[order], np.cumsum(self.get_kappas(indices=indices)[order])

    def get_cummulative_conductivity_with_frequency(self, indices=(1, 1)):
        '''Return (sorted frequencies, cumulative kappa) for component *indices*.'''
        return self._cumulative_conductivity(self.get_frequencies(), indices)

    def get_cummulative_conductivity_with_mean_free_paths(self, indices=(1, 1)):
        '''Return (sorted mean free paths, cumulative kappa) for component *indices*.'''
        return self._cumulative_conductivity(self.get_mean_free_paths(), indices)

    def get_cummulative_conductivity_with_life_times(self, indices=(1, 1)):
        '''Return (sorted lifetimes, cumulative kappa) for component *indices*.'''
        return self._cumulative_conductivity(self.get_life_times(), indices)

    def get_force_constants_3(self):
        '''Return the last column of red_3.dat.'''
        fcs = self.load_data(f"{self.outdir}/red_3.dat", head=1)
        return fcs[:, -1]

    def get_force_constants_4(self):
        '''Return the last column of red_4.dat.'''
        fcs = self.load_data(f"{self.outdir}/red_4.dat", head=1)
        return fcs[:, -1]

    def get_thermal_conductivity(self, indices=(1, 1)):
        '''Return the sum of ``get_kappas(indices)`` over all modes.'''
        import numpy as np
        return np.sum(self.get_kappas(indices=indices))

    def write_json(self, filename='data'):
        '''Dump the main per-mode quantities to ``<filename>.json``.

        Band-structure data is included only when it can be read. The file is
        written only after all data has been computed successfully.
        '''
        import json

        real, imaginary = self.get_eigvecs()
        data = {
            'group_velocities': self.get_group_velocities().tolist(),
            'specific_heats': self.get_heat_capacities().tolist(),
            'life_times': self.get_life_times().tolist(),
            'phase_space3': self.get_scattering_phase_space3().tolist(),
            'phase_space4': self.get_scattering_phase_space4().tolist(),
            'force_constants3': self.get_force_constants_3().tolist(),
            'force_constants4': self.get_force_constants_4().tolist(),
            'eigen_vectors_real': real.tolist(),
            'eigen_vectors_img': imaginary.tolist(),
        }

        try:
            dists, bands, _, _ = self.get_band_structure()
        except (OSError, ValueError, TypeError, IndexError):
            pass  # band structure is optional (file may be absent or malformed)
        else:
            data['band_dists'] = dists.tolist()
            data['band_freqs'] = bands.tolist()

        with open(f"{filename}.json", 'w') as f:
            json.dump(data, f)

    def get_volume(self):
        '''Return the (signed) cell volume a1 . (a2 x a3).'''
        import numpy as np
        a1, a2, a3 = self.get_cell()
        return np.dot(a1, np.cross(a2, a3))

    def get_reciprocal_lattice_vectors(self):
        '''Return the reciprocal lattice vectors b_i = 2*pi (a_j x a_k) / V as rows.'''
        import numpy as np
        a1, a2, a3 = self.get_cell()
        scale = 2 * np.pi / self.get_volume()
        return np.array([scale * np.cross(a2, a3),
                         scale * np.cross(a3, a1),
                         scale * np.cross(a1, a2)])

    def get_heat_capacity_weighted_vgs(self, direction=0):
        '''Return sqrt(sum C v^2 / sum C).

        v is the group-velocity component *direction*. When *direction* is
        None, all three components of each mode are included.
        '''
        import numpy as np
        cps = self.get_heat_capacities()
        vgs = self.get_group_velocities()

        if direction is None:
            # cps is (N,) and vgs is (N, 3): weight each mode's components by its
            # own heat capacity (the old cps*vgs**2 failed to broadcast for N != 3).
            weighted = cps[:, None] * vgs ** 2
        else:
            weighted = cps * vgs[:, direction] ** 2
        return np.sqrt(np.sum(weighted) / np.sum(cps))

    def get_heat_capacity_weighted_gps(self, direction=0):
        '''Return sum C |gamma| / sum C.

        gamma is Grueneisen column *direction*. When *direction* is None, all
        columns of each mode are included.
        '''
        import numpy as np
        cps = self.get_heat_capacities()
        gps = np.asarray(self.get_gruneissen_parameters(), dtype=float)

        if direction is None:
            weighted = cps[:, None] * np.abs(gps)
        else:
            weighted = cps * np.abs(gps[:, direction])
        return np.sum(weighted) / np.sum(cps)

    def get_debay_temperature(self):
        '''Return the Debye temperature T_D = (hbar v_D / k_B) (6 pi^2 n)^(1/3).

        From Reference 29 in https://doi.org/10.1063/1.4893185.

        For modes 0, 1 and 2 of every q-point block (presumably the acoustic
        branches; confirm the mode ordering), the largest absolute
        group-velocity component is taken. v_D = (3 / sum v_i^-3)^(1/3), and
        n = natoms / volume. The 1e30 factor assumes the volume is in
        Angstrom^3, and the print below reports v_D in m/s.
        '''
        import numpy as np
        hbar = 1.054571817e-34  # J s (CODATA 2018)
        kb = 1.380649e-23       # J/K (exact SI value)

        # The old band-structure branch was disabled by a deliberate NameError,
        # so these harmonic_properties velocities were always the ones used.
        vgs = self.get_group_velocities()
        stride = 3 * self.natoms
        vg0, vg1, vg2 = (np.max(np.abs(vgs[m::stride])) for m in range(3))
        vd = np.cbrt(3 / ((1 / vg0**3) + (1 / vg1**3) + (1 / vg2**3)))
        print("Debay velocity : ", vd, " m/s.\n")

        number_density = self.natoms / self.get_volume() * 1e30
        return (hbar * vd / kb) * np.cbrt(6 * np.pi * np.pi * number_density)

    def get_slack_factor(self):
        '''Return (mean mass * V^(1/3) * T_D^3) / (T * gamma^2).

        From References 30, 29 in https://doi.org/10.1063/1.4893185.
        Considers amu = 1 and Angstrom = 1. gamma is the mean of the heat-capacity
        weighted Grueneisen parameters for directions 0-2, and T is the
        'temperature' entry of parameters.dat.
        '''
        import numpy as np
        mbar = np.mean(np.asarray(self.get_masses(), dtype=float))
        a = np.cbrt(self.get_volume())
        gamma = np.mean([self.get_heat_capacity_weighted_gps(d) for d in range(3)])
        dt = self.get_debay_temperature()
        T = float(self.get_parameter('temperature'))

        return (mbar * a * dt**3) / (T * gamma**2)


        
#########################################################################################################

def load_data(filename, head=0):
    '''Read a whitespace-separated table into a numpy array.

    Parameters
    ----------
    filename : str
        Path of the text file to read.
    head : int
        Number of leading lines to skip (e.g. a header).

    Returns
    -------
    numpy.ndarray
        One row per remaining line. Tokens that parse as float become floats;
        others stay strings. Short rows are padded with None to the longest
        row's length, so mixed/ragged tables give an object array. A file with
        no data rows yields an empty array instead of raising.
    '''
    import numpy as np

    def convert(token):
        try:
            return float(token)
        except ValueError:
            return token

    # 'with' guarantees the file is closed even if reading fails.
    with open(filename, 'r') as f:
        rows = [[convert(tok) for tok in line.split()] for line in f.readlines()[head:]]

    width = max(map(len, rows), default=0)
    return np.array([row + [None] * (width - len(row)) for row in rows])

def get_pdf(x, w, bandwidth=0.05, xmin=0, xmax=None, npts=1000):
    '''Weighted Gaussian kernel density estimate of *x* on a uniform grid.

    Parameters
    ----------
    x : array-like
        Sample positions.
    w : array-like
        Weight of each sample.
    bandwidth : float
        Used directly as the KDE bandwidth factor (scipy's ``bw_method``).
    xmin, xmax : float
        Grid limits; ``xmax`` defaults to ``max(x)``.
    npts : int
        Number of grid points.

    Returns
    -------
    (grid, density) : tuple of numpy.ndarray
    '''
    from scipy.stats import gaussian_kde
    import numpy as np

    # A scalar bw_method is used directly as the covariance factor. This is
    # equivalent to the former covariance_factor override followed by the
    # private _compute_covariance() call.
    kde = gaussian_kde(x, bw_method=bandwidth, weights=w)

    if xmax is None:
        xmax = np.max(x)

    grid = np.linspace(xmin, xmax, npts)
    # The old code also built a copy with its last point zeroed but returned a
    # second, fresh evaluation. That returned (unzeroed) density is kept, and
    # the KDE is now evaluated only once.
    return grid, kde(grid)
