from .tools import *
####################################################################################################

class ALD:
    '''
    Reader and analysis helpers for the output folder written by the ALD code.

    Unless a method says otherwise, per-mode arrays are returned in the row
    order of harmonic_properties.dat.
    '''

    def __init__(self,
                 outdir='./'
                 ):
        '''
        Parameters
        ----------
        outdir : str
            Path to the ALD output folder. All files are expected to be exactly
            as generated by the ALD code. All returned lists follow the order of
            harmonic_properties.dat unless stated otherwise (e.g. the cumulative
            lists, which are sorted by their key).
        '''
        self.outdir = outdir
        self.primcell = self.load_data(f'{self.outdir}/cell_primitive.dat', head=1)
        # Rows 0-2 hold the cell; the atom rows follow and are reordered by
        # their first column via pair_sort (from .tools; presumably a sort on
        # the atom-index column -- TODO confirm).
        self.pos_data = pair_sort(self.primcell[3:][:, 0], self.primcell[3:])[1]
        self.harmprop = self.load_data(f"{self.outdir}/harmonic_properties.dat", head=1)
        self.natoms = len(self.get_scaled_positions())
        # Assumes 3*natoms rows (one per mode) for every q-point.
        self.nqpts = int((len(self.harmprop)) / int(3 * self.natoms))
        # NOTE(review): presumably nlength is in metres and *1e10 gives angstrom.
        self.nlength = float(self.get_parameter('nlength')) * 1e10
        self.iso = self.get_parameter('isoscat')
        self.boundary = self.get_parameter('ibdry')
        self.coherent = self.get_parameter('coherent_bte')
        self.fqcell = [int(i) for i in self.get_parameter('fqcell').replace('[', '').replace(']', '').split(',')]
        self.set_name()
        self.dos_done = False

    def move_axes(self, ax, fig, pos=(0.1, 0.1, 0.5, 0.5)):
        '''
        Move an axes from its current figure to `fig`; the old figure is closed.

        Parameters
        ----------
        ax : matplotlib Axes to move.
        fig : destination matplotlib Figure.
        pos : [left, bottom, width, height] of the axes inside `fig`.
        '''
        import matplotlib.pyplot as plt

        old_fig = ax.figure
        ax.remove()
        ax.figure = fig
        fig.axes.append(ax)
        fig.add_axes(ax)
        # Presumably a workaround so the moved axes can be positioned; the
        # temporary subplot is removed immediately.
        dummy_ax = fig.add_subplot()
        ax.set_position(pos)
        dummy_ax.remove()
        plt.close(old_fig)

    def get_parameter(self, parameter):
        '''
        Return the value (text after the last ':') of the first line of
        parameters.dat that contains `parameter` (case-insensitive substring
        match). Prints a message and returns 0 when no line matches.
        '''
        with open(f'{self.outdir}/parameters.dat', 'r') as f:
            lines = f.readlines()

        key = parameter.lower()
        for line in lines:
            if key in line.lower():
                # strip() (not [:-1]) so a final line without '\n' keeps its
                # last character.
                return line.split(':')[-1].strip()

        print(f'{parameter} is not specified in parameters.dat file')
        return 0

    def load_data(self, filename, head=0):
        '''
        Read a whitespace-separated text file into a numpy array.

        Parameters
        ----------
        filename : path of the file.
        head : number of leading lines to skip.

        Returns
        -------
        numpy.ndarray with one row per line. Tokens containing "nan"/"inf"
        (case-insensitive) become 0, other numeric tokens become floats and
        anything else stays a string. Short rows are padded with None, so
        ragged files give an ``object`` array.
        '''
        import numpy as np

        with open(filename, 'r') as f:
            rows = [line.split() for line in f.readlines()[head:]]

        for row in rows:
            for col, token in enumerate(row):
                lowered = token.lower()
                if 'nan' in lowered or 'inf' in lowered:
                    row[col] = 0
                    continue
                try:
                    row[col] = float(token)
                except ValueError:
                    row[col] = token

        width = max(map(len, rows))
        return np.array([row + [None] * (width - len(row)) for row in rows])

    def flatten(self, list_of_lists):
        '''
        Flatten arbitrarily nested Python lists into one flat list.

        Only ``list`` instances are descended into. An empty input is returned
        unchanged. Iterative, so long or deep inputs do not hit the recursion
        limit.
        '''
        if len(list_of_lists) == 0:
            return list_of_lists

        flat = []
        stack = [iter(list_of_lists)]
        while stack:
            for item in stack[-1]:
                if isinstance(item, list):
                    stack.append(iter(item))
                    break
                flat.append(item)
            else:
                stack.pop()
        return flat

    def get_cell(self):
        '''
        Returns cell parameters (first three rows of cell_primitive.dat)
        scaled by nlength, as a float (3, 3) array.
        '''
        import numpy as np
        return np.asarray(self.primcell[0:3, 0:3], dtype=float) * self.nlength

    def _per_qpoint(self, columns):
        '''
        Collect `columns` of harmonic_properties.dat once per q-point.

        A row starts a new q-point when its first column differs from the
        number of q-points collected so far minus one.
        '''
        import numpy as np
        rows = []
        for row in self.harmprop:
            if not rows or row[0] != len(rows) - 1:
                rows.append(row[columns])
        return np.array(rows, dtype=float)

    def get_direct_qpts(self):
        '''
        Returns direct q-points (columns 2:5 of harmonic_properties.dat), in
        the same order as in the file.
        '''
        return self._per_qpoint(slice(2, 5))

    def get_cartesian_qpts(self):
        '''
        Returns cartesian q-points (columns 5:8 of harmonic_properties.dat),
        in the same order as in the file.
        '''
        return self._per_qpoint(slice(5, 8))

    def get_elements(self):
        '''
        Returns list of elements from cell_primitive.dat.'''
        return self.pos_data[:, 1]

    def get_scaled_positions(self):
        '''Scaled (direct) atom positions, columns 2:5 of the atom rows.'''
        import numpy as np
        return np.asarray(self.pos_data[:, 2:5], dtype=float)

    def get_masses(self):
        '''Atomic masses, column 8 of the atom rows.'''
        import numpy as np
        return np.asarray(self.pos_data[:, 8], dtype=float)

    def get_cartesian_positions(self):
        '''Cartesian atom positions (columns 5:8) scaled by nlength.'''
        import numpy as np
        return np.asarray(self.pos_data[:, 5:8], dtype=float) * self.nlength

    @staticmethod
    def _parse_numbers(row):
        '''
        Extract every number contained in the string form of each entry of
        `row` (e.g. "(0.1,-0.2)" -> 0.1, -0.2) as a flat float array.
        Exponents ("1.0E-05") and signed integers are recognised.
        '''
        import re
        import numpy as np

        pattern = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")
        return np.array([float(tok) for entry in row
                         for tok in pattern.findall(str(entry))], dtype=float)

    def get_eigvecs(self):
        '''
        Read eigenVector.dat and return ``(real, imaginary)`` arrays of shape
        (nqpts, 3*natoms, 3*natoms), indexed as [qpoint, mode, component].
        The results are also stored as ``real_eigvec`` and ``imag_eigvec``.
        '''
        import numpy as np

        nmodes = 3 * self.natoms
        real = np.zeros((self.nqpts, nmodes, nmodes))
        imaginary = np.zeros((self.nqpts, nmodes, nmodes))

        for row in self.load_data(f'{self.outdir}/eigenVector.dat', head=1):
            values = self._parse_numbers(row)
            qid, modeid = int(values[0]), int(values[1])
            # From index 3 on, numbers alternate real, imaginary, real, ...
            real[qid][modeid] = values[3::2]
            imaginary[qid][modeid] = values[4::2]

        self.real_eigvec = real
        self.imag_eigvec = imaginary
        return real, imaginary

    def get_frequencies(self):
        '''Mode frequencies, column 8 of harmonic_properties.dat.'''
        return self.harmprop[:, 8]

    def get_heat_capacities(self):
        '''Mode heat capacities (column 9) as a float copy, NaN replaced by 0.'''
        import numpy as np
        sps = np.array(self.harmprop[:, 9], dtype=float)
        sps[np.isnan(sps)] = 0
        return sps

    def get_group_velocities(self):
        '''Mode group velocities (columns 10:13) as a float copy, NaN -> 0.'''
        import numpy as np
        vgs = np.array(self.harmprop[:, 10:13], dtype=float)
        vgs[np.isnan(vgs)] = 0
        return vgs

    def get_gruneissen_parameters(self):
        '''Gruneisen parameters, columns 14: of harmonic_properties.dat.'''
        return self.harmprop[:, 14:]

    def get_mean_free_paths(self):
        '''|group velocity| times the lifetimes in the last column of phonon_RTA.dat.'''
        import numpy as np
        return np.linalg.norm(self.get_group_velocities(), axis=1) * self.get_lifetimes()

    def get_lifetimes(self, index=-1):
        '''
        Column `index` of phonon_RTA.dat as floats; inf/NaN entries become 0.
        '''
        import numpy as np
        rta = np.array(self.load_data(f'{self.outdir}/phonon_RTA.dat', head=1), dtype=float)
        rta[~np.isfinite(rta)] = 0
        return rta[:, index]

    def get_total_lifetimes(self):
        '''
        Combine lifetimes with Matthiessen's rule according to the isoscat
        (`iso`) and ibdry (`boundary`) parameters.

        Column layout of phonon_RTA.dat used here: last column = phonon-phonon;
        with both isotope and boundary enabled, -2 = isotope and -3 = boundary;
        with only one of them enabled it is in column -2. A zero component
        yields a zero total lifetime.
        '''
        import numpy as np

        iso = str(self.iso).lower()
        boundary = str(self.boundary).lower()

        # 1/0 -> inf -> total 0; suppress the expected divide warnings.
        with np.errstate(divide='ignore'):
            if iso == 'yes' and boundary == 'yes':
                ti = self.get_lifetimes(-2)
                tb = self.get_lifetimes(-3)
                tp = self.get_lifetimes(-1)
                tt = 1 / (1 / ti + 1 / tb + 1 / tp)
            elif iso == 'no' and boundary == 'yes':
                tb = self.get_lifetimes(-2)
                tp = self.get_lifetimes(-1)
                tt = 1 / (1 / tb + 1 / tp)
            elif iso == 'yes' and boundary == 'no':
                ti = self.get_lifetimes(-2)
                tp = self.get_lifetimes(-1)
                tt = 1 / (1 / ti + 1 / tp)
            else:
                tt = self.get_lifetimes(-1)

        return tt

    def get_band_structure(self):
        '''
        Read directional_harmonic_properties.dat.

        Returns dists, bands, velocities, GPs, PRs, each indexed
        [mode, path_point] (velocities and GPs carry a trailing axis of 3).
        Columns used: 0 point index, 1 mode index, 8 distance, 9 frequency,
        10:13 group velocity, 13 participation ratio, 14:17 Gruneisen.
        '''
        import numpy as np

        data = self.load_data(f'{self.outdir}/directional_harmonic_properties.dat', head=1)
        nmodes = 3 * self.natoms
        npath = int(len(data) / 3 / self.natoms)
        distances = np.zeros((nmodes, npath))
        bands = np.zeros((nmodes, npath))
        gvs = np.zeros((nmodes, npath, 3))
        gps = np.zeros((nmodes, npath, 3))
        prs = np.zeros((nmodes, npath))

        for row in data:
            point = int(float(row[0]))
            mode = int(float(row[1]))
            bands[mode][point] = row[9]
            distances[mode][point] = row[8]
            gvs[mode][point] = row[10:13]
            gps[mode][point] = row[14:17]
            prs[mode][point] = row[13]

        return distances, bands, gvs, gps, prs

    def _band_path_ticks(self, dists, npoints):
        '''
        Return x positions [start, internal segment boundaries..., end] of the
        band path. Each segment is assumed to occupy npoints+1 points.
        '''
        import numpy as np
        nsegments = len(dists[0]) // npoints
        boundaries = [dists[0][j * npoints + j] for j in range(1, nsegments)]
        return [0] + boundaries + [np.max(dists[0])]

    def _decorate_band_axis(self, ax, dists, bands, npoints, bandpath):
        '''Limits, segment separators, optional labels and title of a band axis.'''
        ymax = max(bands[-1]) + 1
        ax.set_xlim([0, max(dists[0])])
        ax.set_ylim([0, ymax])
        ticks = self._band_path_ticks(dists, npoints)
        ax.vlines(x=ticks[1:-1], ymin=0, ymax=ymax, linestyle='dotted')
        ax.set_xticks([])
        if bandpath is not None:
            # Ticks are sorted: start, boundaries, end -- one label each.
            ax.set_xticks(ticks)
            ax.set_xticklabels(list(bandpath))
        ax.set_title(self.name)

    def _plot_colored_bands(self, fig, ax, dists, bands, values):
        '''
        Draw each band as a line colored by `values` (one array per band)
        using a single shared color normalization, and add a colorbar.
        '''
        import numpy as np
        from matplotlib.collections import LineCollection
        from matplotlib.colors import Normalize

        values = [np.asarray(v, dtype=float) for v in values]
        # One norm for all bands so colors are comparable and the colorbar
        # is valid for every band, not only the last one drawn.
        norm = Normalize(vmin=min(np.min(v) for v in values),
                         vmax=max(np.max(v) for v in values))
        line = None
        for x, y, c in zip(dists, bands, values):
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, cmap='viridis', norm=norm)
            lc.set_array(c)
            line = ax.add_collection(lc)
        fig.colorbar(line, ax=ax)

    def plot_bands(self, npoints=100, filename=None, bandpath=None):
        '''
        Plot the dispersion from directional_harmonic_properties.dat.

        npoints : points per path segment (segments hold npoints+1 points).
        filename : if given, the figure is saved there at 600 dpi.
        bandpath : labels (e.g. "GXMG"), one for the start, each boundary and the end.
        Returns (fig, ax).
        '''
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        dists, bands, _, _, _ = self.get_band_structure()
        for x, y in zip(dists, bands):
            ax.plot(x, y, 'k-')
        self._decorate_band_axis(ax, dists, bands, npoints, bandpath)
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def plot_bands_with_gps(self, npoints=100, filename=None, bandpath=None):
        '''
        Like plot_bands, with each band colored by |mean Gruneisen parameter|.
        Returns (fig, ax).
        '''
        import matplotlib.pyplot as plt
        import numpy as np

        fig, ax = plt.subplots()
        dists, bands, _, gps, _ = self.get_band_structure()
        colors = [np.abs(np.mean(g, axis=1)) for g in gps]
        self._plot_colored_bands(fig, ax, dists, bands, colors)
        self._decorate_band_axis(ax, dists, bands, npoints, bandpath)
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def plot_bands_with_prs(self, npoints=100, filename=None, bandpath=None):
        '''
        Like plot_bands, with each band colored by its participation ratio.
        Returns (fig, ax).
        '''
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        dists, bands, _, _, prs = self.get_band_structure()
        self._plot_colored_bands(fig, ax, dists, bands, prs)
        self._decorate_band_axis(ax, dists, bands, npoints, bandpath)
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def plot_bands_with_tdos(self, npoints=100, filename=None, bandpath=None, bandwidth=0.05):
        '''
        Bands (left) and total DOS (right). A previously computed DOS
        (dos_done) is reused. Returns (fig, ax) with ax of length 2.
        '''
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [3, 1]})
        dists, bands, _, _, _ = self.get_band_structure()
        for x, y in zip(dists, bands):
            ax[0].plot(x, y, 'k-')
        self._decorate_band_axis(ax[0], dists, bands, npoints, bandpath)

        if self.dos_done:
            pfs = self.pfs
            tdos = self.tdos
        else:
            pfs, _, tdos = self.get_pdos(xmax=max(bands[-1]) + 1, bandwidth=bandwidth)

        ax[1].plot(tdos, pfs, 'k-')
        ax[1].set_xlim([0, max(tdos) + 0.2])
        ax[1].set_ylim([0, max(bands[-1]) + 1])
        ax[1].set_yticks([])

        fig.tight_layout()
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def get_voronoi_volumes(self):
        '''
        Voronoi cell volumes of the atoms computed with freud. Prints a hint
        and returns None when freud is not installed.
        '''
        import numpy as np

        try:
            from freud.box import Box
            from freud.locality import Voronoi
        except ImportError:
            print("Please install freud python package")
            return None

        # NOTE(review): freud documents Box.from_matrix with box vectors as
        # columns, while get_cell() stores them as rows -- confirm orientation.
        box = Box.from_matrix(np.asarray(self.get_cell(), dtype=float))
        voro = Voronoi()
        computed = voro.compute((box, np.asarray(self.get_cartesian_positions(), dtype=float)))
        return computed.volumes

    def write_modes(self, qid=0, mode=0, nframes=100, repeat=(1, 1, 1), factor=1):
        '''
        Writes {outdir}/q{qid}_m{mode}.traj animating the given phonon mode
        over one period (viewable with ase-gui). Returns the displacements of
        the last frame.
        '''
        from ase import Atoms
        from ase.io.trajectory import Trajectory
        import numpy as np

        eigdata = self.load_data(f'{self.outdir}/eigenVector.dat', head=1)
        values = self._parse_numbers(eigdata[qid * 3 * self.natoms + mode])
        real = values[3::2]
        imag = values[4::2]
        # np.complex was removed in NumPy 1.24; build the complex vector directly.
        cmplxeigvecs = np.zeros(3 * self.natoms, dtype=complex)
        cmplxeigvecs[:len(real)] = real + 1j * imag
        cmplxeigvecs = cmplxeigvecs.reshape(self.natoms, -1)

        cell = self.get_cell()
        elements = self.get_elements()
        positions = self.get_cartesian_positions()
        base = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True)

        traj = Trajectory(f'{self.outdir}/q{qid}_m{mode}.traj', 'w')
        try:
            for x in np.linspace(0, 2 * np.pi, nframes, endpoint=False):
                disps = ((np.exp(1.j * x) * cmplxeigvecs).real) * factor
                atoms = base.copy()
                atoms.set_positions(atoms.get_positions() + disps)
                atoms.set_velocities(disps / 2)
                traj.write(atoms * repeat)
        finally:
            traj.close()

        return disps

    def get_pdf(self, x, w, bandwidth=0.05, xmin=0, xmax=None, npts=1000):
        '''
        Weighted Gaussian KDE of samples `x` with weights `w`.

        bandwidth is used directly as the scipy KDE bandwidth factor.
        Returns (f, g): npts grid points from xmin to xmax (default max(x))
        and the density there, with the last density value set to 0.
        '''
        from scipy.stats import gaussian_kde
        import numpy as np

        x = np.asarray(x, dtype=float)
        w = np.asarray(w, dtype=float)
        # A scalar bw_method is used as kde.factor (public equivalent of
        # overriding covariance_factor and recomputing the covariance).
        kde = gaussian_kde(x, bw_method=bandwidth, weights=w)

        if xmax is None:
            xmax = np.max(x)

        f = np.linspace(xmin, xmax, npts)
        g = kde(f)
        # Presumably so plotted curves close at the upper edge.
        g[-1] = 0
        return f, g

    def get_pdos(self, bandwidth=0.05, xmin=0, xmax=None, npts=1000):
        '''
        Atom-projected DOS from pdos.dat (columns 1..natoms used as weights
        of the mode frequencies) smoothed with get_pdf.

        Returns (frequencies, pdos[natoms, npts], tdos) and caches them as
        pfs, pdos, tdos with dos_done = True; the raw weights go to pdosfile.
        '''
        import numpy as np

        natoms = self.natoms
        freqs = self.get_frequencies()
        pdos = self.load_data(f'{self.outdir}/pdos.dat', head=1)[:, 1:natoms + 1].T
        self.pdosfile = pdos

        spdos = []
        for i in range(natoms):
            sfreqs, density = self.get_pdf(freqs, pdos[i], xmin=xmin, xmax=xmax,
                                           npts=npts, bandwidth=bandwidth)
            spdos.append(density)
        spdos = np.array(spdos)
        tdos = np.sum(spdos, axis=0)

        self.dos_done = True
        self.pfs = np.array(sfreqs)
        self.pdos = spdos
        self.tdos = tdos

        return np.array(sfreqs), np.array(spdos), np.array(tdos)

    def _scattering_phase_space(self, order):
        '''Column 2 of scattering_phase_space_{order}.dat, normalised by 2/3/prod(fqcell).'''
        phase = self.load_data(f"{self.outdir}/scattering_phase_space_{order}.dat", head=0)
        return phase[:, 2] * 2 / 3 / self.fqcell[0] / self.fqcell[1] / self.fqcell[2]

    def get_scattering_phase_space3(self):
        '''Normalised three-phonon scattering phase space.'''
        return self._scattering_phase_space(3)

    def get_scattering_phase_space4(self):
        '''Normalised four-phonon scattering phase space.'''
        return self._scattering_phase_space(4)

    def _lifetimes_for(self, life_id=None):
        '''Total lifetimes when life_id is None, else column life_id of phonon_RTA.dat.'''
        if life_id is None:
            return self.get_total_lifetimes()
        return self.get_lifetimes(index=life_id)

    def get_kappas(self, indices=(0, 0), life_id=None):
        '''
        Per-mode conductivity contributions C * tau * v_i * v_j for
        (i, j) = indices, restricted to modes with lifetime > 0 (so the result
        is shorter than the mode list when some lifetimes are <= 0).
        '''
        times = self._lifetimes_for(life_id)
        mask = times > 0
        vels = self.get_group_velocities()[mask]
        heats = self.get_heat_capacities()[mask]
        return heats * times[mask] * vels[:, indices[0]] * vels[:, indices[1]]

    def get_coherent_kappas(self, direction='xx'):
        '''
        Column `direction` ('xx', 'xy', ..., 'zz') of phonon_coherent.dat;
        returns 0 when the file or column cannot be read.
        '''
        try:
            aa = ['a', 'a', 'xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
            cd = self.load_data(f"{self.outdir}/phonon_coherent.dat", head=1)
            return cd[:, aa.index(direction.lower())]
        except (OSError, ValueError, IndexError):
            return 0

    def _cumulative_conductivity(self, keys, indices, life_id):
        '''
        Sort the per-mode `keys` (masked exactly like get_kappas) and return
        (sorted keys, cumulative sum of the matching kappas).
        '''
        import numpy as np

        # get_kappas drops modes with lifetime <= 0; apply the same mask to
        # the keys so both arrays stay aligned.
        mask = self._lifetimes_for(life_id) > 0
        kk = self.get_kappas(indices=indices, life_id=life_id)
        keys = np.asarray(keys)[mask]
        order = keys.argsort()
        return keys[order], np.cumsum(kk[order])

    def get_cummulative_conductivity_with_frequency(self, indices=(0, 0), life_id=None):
        '''(sorted frequencies, cumulative conductivity).'''
        return self._cumulative_conductivity(self.get_frequencies(), indices, life_id)

    def get_cummulative_conductivity_with_mean_free_paths(self, indices=(0, 0), life_id=None):
        '''(sorted mean free paths, cumulative conductivity).'''
        return self._cumulative_conductivity(self.get_mean_free_paths(), indices, life_id)

    def get_cummulative_conductivity_with_lifetimes(self, indices=(0, 0), life_id=None):
        '''(sorted total lifetimes, cumulative conductivity).'''
        return self._cumulative_conductivity(self.get_total_lifetimes(), indices, life_id)

    def get_force_constants_3(self):
        '''Last column of red_3.dat.'''
        fcs = self.load_data(f"{self.outdir}/red_3.dat", head=1)
        return fcs[:, -1]

    def get_force_constants_4(self):
        '''Last column of red_4.dat.'''
        fcs = self.load_data(f"{self.outdir}/red_4.dat", head=1)
        return fcs[:, -1]

    def get_thermal_conductivity(self, indices=(0, 0)):
        '''Sum of get_kappas(indices) over all modes.'''
        import numpy as np
        return np.sum(self.get_kappas(indices=indices))

    def write_pickle(self, filename='data'):
        '''
        Pickle the main per-mode quantities to {filename}.pickle. Band
        structure data is included only if it can be read.
        '''
        import pickle

        real, imag = self.get_eigvecs()
        data = {}
        data['group_velocities'] = self.get_group_velocities().tolist()
        data['specific_heats'] = self.get_heat_capacities().tolist()
        data['lifetimes'] = self.get_lifetimes().tolist()
        data['phase_space3'] = self.get_scattering_phase_space3().tolist()
        data['phase_space4'] = self.get_scattering_phase_space4().tolist()
        data['force_constants3'] = self.get_force_constants_3().tolist()
        data['force_constants4'] = self.get_force_constants_4().tolist()
        data['eigen_vectors_real'] = real.tolist()
        data['eigen_vectors_img'] = imag.tolist()
        data['frequencies'] = self.get_frequencies().tolist()

        try:
            dists, bands = self.get_band_structure()[:2]
            data['band_dists'] = dists.tolist()
            data['band_freqs'] = bands.tolist()
        except (OSError, ValueError, IndexError, TypeError):
            # Band-structure output is optional.
            pass

        # Open only after everything is computed so a failure does not leave
        # a truncated pickle behind.
        with open(f"{filename}.pickle", 'wb') as f:
            pickle.dump(data, f)

    def get_volume(self):
        '''Cell volume a . (b x c) in the units of get_cell().'''
        import numpy as np
        cell = self.get_cell()
        return np.dot(cell[0], np.cross(cell[1], cell[2]))

    def get_density(self):
        '''Mass density assuming angstrom lengths and amu masses (kg/m^3).'''
        import numpy as np
        vol = self.get_volume() * 1e-30
        mass = np.sum(self.get_masses()) * 1.66e-27
        return mass / vol

    def get_msds(self):
        '''
        Parse the lines of {outdir}/../log.dat containing both 'MSD' and
        'Symbol' into an array of shape (n_lines, 3, 3). Prints a message and
        returns None if the file cannot be read or parsed.
        '''
        import numpy as np

        logfile = self.outdir + '/../log.dat'
        try:
            with open(logfile, 'r') as f:
                msd_lines = [line.replace('\n', '') for line in f
                             if 'MSD' in line and 'Symbol' in line]

            rows = []
            for line in msd_lines:
                cleaned = line.split(':')[-1].replace(',', '').replace(']', '').replace('[', '')
                values = []
                for token in cleaned.split():
                    try:
                        values.append(float(token))
                    except ValueError:
                        continue
                rows.append(values)
            return np.array(rows).reshape(len(rows), 3, 3)
        except (OSError, ValueError):
            print(f'Error occured while reading {self.outdir}/../log.dat')
            return None

    def get_reciprocal_lattice_vectors(self):
        '''Reciprocal vectors b_i = 2*pi/V * (a_j x a_k), as rows.'''
        import numpy as np
        cell = self.get_cell()
        v = self.get_volume()
        b1 = 2 * np.pi / v * np.cross(cell[1], cell[2])
        b2 = 2 * np.pi / v * np.cross(cell[2], cell[0])
        b3 = 2 * np.pi / v * np.cross(cell[0], cell[1])
        return np.array([b1, b2, b3])

    def get_heat_capacity_weighted_vgs(self, direction=0):
        '''
        sqrt(sum(C * v_d^2) / sum(C)) for velocity component `direction`;
        with direction=None the squares of all three components are summed.
        '''
        import numpy as np
        cps = self.get_heat_capacities()
        vgs = self.get_group_velocities()

        if direction is None:
            # cps[:, None] broadcasts the (N,) weights over the (N, 3) velocities.
            return np.sqrt(np.sum(cps[:, None] * vgs ** 2) / np.sum(cps))
        return np.sqrt(np.sum(cps * vgs[:, direction] ** 2) / np.sum(cps))

    def get_heat_capacity_weighted_gps(self, direction=0):
        '''
        sum(C * |gamma_d|) / sum(C) for Gruneisen column `direction`; with
        direction=None all columns are summed.
        '''
        import numpy as np
        cps = self.get_heat_capacities()
        gps = np.asarray(self.get_gruneissen_parameters(), dtype=float)

        if direction is None:
            return np.sum(cps[:, None] * np.abs(gps)) / np.sum(cps)
        return np.sum(cps * np.abs(gps[:, direction])) / np.sum(cps)

    def get_debye_temperature(self):
        '''
        From Reference 29 in https://doi.org/10.1063/1.4893185

        Uses the maximum |group velocity| of the first three modes (rows
        0, 1, 2 of every q-point block) as the acoustic sound velocities.
        '''
        import numpy as np
        hbar = 1.0545718e-34
        kb = 1.38e-23

        vgs = self.get_group_velocities()
        stride = 3 * self.natoms
        vg0 = np.max(np.abs(vgs[0::stride]))
        vg1 = np.max(np.abs(vgs[1::stride]))
        vg2 = np.max(np.abs(vgs[2::stride]))
        vd = (3 / ((1 / vg0 ** 3) + (1 / vg1 ** 3) + (1 / vg2 ** 3))) ** (1 / 3)

        return (hbar * vd / kb) * (6 * np.pi * np.pi * self.natoms / self.get_volume() * 1e30) ** (1 / 3)

    def set_name(self, name=None):
        '''Set `name`; by default a formula like "Si2" built from the elements.'''
        import numpy as np

        if name is None:
            elm = self.get_elements()
            element_list = elm.tolist()
            self.name = ''.join(f'{el}{element_list.count(el)}' for el in np.unique(elm))
        else:
            self.name = name

    def get_name(self):
        '''Return the name set by set_name.'''
        return self.name

    def get_atoms(self):
        '''Periodic ase.Atoms built from elements, positions and cell.'''
        from ase import Atoms
        return Atoms(symbols=self.get_elements(),
                     positions=self.get_cartesian_positions(),
                     cell=self.get_cell(),
                     pbc=True)

    def _spglib_cell(self):
        '''(cell, scaled positions, atomic numbers) tuple for spglib.'''
        atoms = self.get_atoms()
        return (atoms.get_cell(), atoms.get_scaled_positions(), atoms.get_atomic_numbers())

    def get_unique_atoms(self, symprec=1e-2):
        '''Indices of symmetry-inequivalent atoms (spglib equivalent_atoms).'''
        import spglib as sg
        import numpy as np
        eqa = sg.get_symmetry(self._spglib_cell(), symprec=symprec)['equivalent_atoms']
        return np.unique(eqa)

    def get_symmetry(self, symprec=1e-2):
        '''spglib symmetry dataset of the primitive cell.'''
        import spglib as sg
        return sg.get_symmetry_dataset(self._spglib_cell(), symprec=symprec)

    def get_slack_factor(self):
        '''
        From References 30, 29 in https://doi.org/10.1063/1.4893185
        considered amu = 1, Angstrom = 1 and T in K
        '''
        import numpy as np
        mbar = np.mean(self.get_masses())
        a = (self.get_volume() * 1e-30) ** 0.3333333
        gamma = np.mean([self.get_heat_capacity_weighted_gps(0),
                         self.get_heat_capacity_weighted_gps(1),
                         self.get_heat_capacity_weighted_gps(2)])
        dt = self.get_debye_temperature()
        T = float(self.get_parameter('temperature'))

        return (mbar * a * dt ** 3) / (T * gamma ** 2)

    def get_group_speeds(self):
        '''Euclidean norm of each mode's group velocity.'''
        import numpy as np
        return np.linalg.norm(self.get_group_velocities(), axis=1)

    def get_mean_gps(self):
        '''|mean| of each mode's Gruneisen parameters.'''
        import numpy as np
        gps = np.asarray(self.get_gruneissen_parameters(), dtype=float)
        return np.abs(np.mean(gps, axis=1))

    def plot_bands_with_pdos(self, npoints=100, filename=None, symprec=1e-2, bandpath=None, bandwidth=0.05):
        '''
        Bands (left) and projected DOS of each symmetry-inequivalent atom
        multiplied by its multiplicity (right). Returns (fig, ax).
        '''
        import matplotlib.pyplot as plt
        import numpy as np

        fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [3, 1]})
        dists, bands, _, _, _ = self.get_band_structure()
        for x, y in zip(dists, bands):
            ax[0].plot(x, y, 'k-')
        self._decorate_band_axis(ax[0], dists, bands, npoints, bandpath)

        if self.dos_done:
            pfs = self.pfs
            pd = self.pdos
        else:
            pfs, pd, _ = self.get_pdos(xmax=max(bands[-1]) + 1, bandwidth=bandwidth)

        uatoms = self.get_unique_atoms(symprec=symprec)
        ealist = self.get_symmetry(symprec=symprec)['equivalent_atoms'].tolist()
        elements = self.get_elements()
        max_pdos = 0
        for sim in uatoms:
            pdoss = pd[sim] * ealist.count(sim)
            max_pdos = max(max_pdos, np.max(pdoss))
            ax[1].plot(pdoss, pfs, label=f'{elements[sim]}_{sim}')

        ax[1].legend()
        ax[1].set_xlim([0, max_pdos + 0.02])
        ax[1].set_ylim([0, np.max(bands[-1]) + 1])
        ax[1].set_yticks([])

        fig.tight_layout()
        if filename is not None:
            fig.savefig(filename, dpi=600)
        return fig, ax

    def get_harmonic_force_constants(self, flfrcpath=None):
        '''
        Returns; fc, positions, masses_by_positions, cell2, cart_cell
        Harmonic force constants as array with indices atom1, atom2, dir1, dir2, cell1, cell2, cell3
        (as returned by read_flfrc from .tools). Without `flfrcpath`,
        {outdir}/../save/harmonic_flfrc.dat is tried and None is returned
        (after a message) if it cannot be read.
        '''
        if flfrcpath is not None:
            return read_flfrc(flfrcpath)

        try:
            return read_flfrc(f"{self.outdir}/../save/harmonic_flfrc.dat")
        except Exception:
            # Best-effort default location; read_flfrc's failure modes are
            # not visible here.
            print("No valid flfrc file present in outdir/../save/")

    def get_linewidth(self, i=-1):
        '''Returns linewidth (in THz) as per 10.1038/ncomms7400, by Prof. Marzari.
        i is index of required lifetimes in phonon_RTA.dat; modes with
        lifetime <= 0 are ignored.'''
        import numpy as np

        lifetimes = self.get_lifetimes(i)
        mask = lifetimes > 0
        c = self.get_heat_capacities()[mask]
        return np.sum(c * 2 * np.pi / lifetimes[mask]) / np.sum(c) * 1e-12

    def get_secondsound_lifetime(self, rl):
        """
        Returns second sound lifetime given resistive lifetimes `rl`, which
        must have the same length as get_heat_capacities(). Modes with
        rl <= 0 are ignored.
        """
        import numpy as np

        hbar = 1.0545 * 1e-34
        kb = 1.3806 * 1e-23
        T = float(self.get_parameter('temperature'))

        rl = np.asarray(rl)
        mask = rl > 0
        inv_rl = 1 / rl[mask]
        frs = self.get_frequencies()[mask]
        # One q-point per block of 3*natoms modes.
        qpts = np.repeat(self.get_cartesian_qpts(), 3 * self.natoms, axis=0)[mask]
        c = self.get_heat_capacities()[mask]
        v = self.get_group_velocities()[mask]
        vq = np.einsum('ij,ij->i', v, qpts)

        prefacs = c / hbar / frs * kb * (T ** 2) * vq
        return np.sum(prefacs) / np.sum(prefacs * inv_rl)

    def get_secondsound_speed(self):
        '''sqrt(sum(C * |v|^2 / 2) / sum(C)).'''
        import numpy as np

        v = self.get_group_velocities()
        vdot = np.einsum('ij,ij->i', v, v)
        c = self.get_heat_capacities()
        return (np.sum(c * vdot / 2) / np.sum(c)) ** 0.5

    def get_secondsound_length(self, rl):
        '''Second-sound speed times second-sound lifetime for resistive lifetimes rl.'''
        return self.get_secondsound_speed() * self.get_secondsound_lifetime(rl)

    def is_layered(self, cutoff=2.5):
        """
        Returns 0 if not layered. 1 for layered in X-dir, 2 for layered in Y-dir etc.

        Layering along an axis is detected as an empty histogram bin of the
        scaled coordinates, using int(axis_length / cutoff) bins (at least 1).
        """
        import numpy as np

        atoms = self.get_atoms()
        lengths = atoms.cell.cellpar()[:3]
        # Use scaled coordinates on all three axes (the original mixed scaled
        # x with Cartesian y/z).
        scaled = atoms.get_scaled_positions()

        for axis in range(3):
            nbins = max(1, int(lengths[axis] / cutoff))
            counts, _ = np.histogram(scaled[:, axis], bins=nbins)
            if 0 in counts:
                return axis + 1
        return 0
#########################################################################################################


