import matplotlib.pyplot as plt
import numpy as np
def getEfSCF(scfFile):
    """Extract the Fermi level (or highest occupied level) from an SCF output file.

    Two kinds of line are recognised:

    * ``highest occupied level (ev):   E`` or
      ``highest occupied, lowest unoccupied level (ev):   E_homo   E_lumo``
      -> ``E`` / ``E_homo`` is taken as the reference energy.
    * ``the Fermi energy is   E ev`` -> ``E`` is taken as the Fermi level.

    If several matching lines are present, the last one wins.

    Parameters
    ----------
    scfFile : str
        Filename (or path to file and filename) of the SCF output.

    Returns
    -------
    list
        ``[Ef, highestOccupied]``. ``Ef`` is the energy as printed in the file
        (0.0 if no matching line was found). ``highestOccupied`` is True
        unless a "the Fermi energy is" line was seen, in which case it is
        False and ``Ef`` is truly the Fermi level.
    """
    highestOccupied = True
    Ef = 0.0
    with open(scfFile, "r") as f:
        for raw in f:
            line = raw.strip()
            if line.startswith("highest occupied"):
                # The highest occupied energy is always the first number after
                # the colon. The "highest occupied, lowest unoccupied" variant
                # has extra words, so a fixed whitespace-token index would
                # pick up "level" instead of the number.
                Ef = float(line.split(":", 1)[1].split()[0])
            elif line.startswith("the Fermi"):
                # "the Fermi energy is   E ev"
                Ef = float(line.split()[4])
                highestOccupied = False

    return [Ef, highestOccupied]
def importBands(file):
    """Import a band structure in a form ready for plotting.

    Parameters
    ----------
    file : str
        Plottable-bands data file written by bands.x (the file named on the
        "Plottable bands ... written to file" line of the bands.x output).
        Each non-blank line holds a k-path coordinate and an energy as its
        first two whitespace-separated fields. Bands are separated by one or
        more blank lines, and the final band need not be followed by one.

    Returns
    -------
    list
        ``[k, Energies]`` where ``k`` is a 1-D float64 array of k-path
        coordinates (read from the first band) and ``Energies`` is a 2-D
        float64 array in which each column is one band.
    """
    k_path = []   # x-coordinates, read from the first band only
    bands = []    # one list of energies per completed band
    current = []  # energies of the band currently being read
    # Energies are used as-is (with bands.x they are already in eV).
    with open(file, "r") as f:
        for raw in f:
            fields = raw.split()
            if not fields:
                # A blank line ends a band. Repeated or leading blank lines are
                # skipped so they cannot add empty (ragged) bands.
                if current:
                    bands.append(current)
                    current = []
                continue
            if not bands:
                # The k-path is assumed identical for every band, so only the
                # first band's x-values are recorded.
                k_path.append(float(fields[0]))
            current.append(float(fields[1]))
    # Keep the last band even if the file does not end with a blank line.
    if current:
        bands.append(current)

    arrY = np.array(bands).transpose()
    k = np.array(k_path)
    return [k, arrY]

def buildBands(file2):
    """Extract energy vs k for plotting, plus the high-symmetry points.

    Reads the bands.x output file for its high-symmetry point lines and for
    the name of the plottable-bands data file, then loads that data file with
    ``importBands``.

    Parameters
    ----------
    file2 : str
        bands.x output file (e.g. ``[prefix]_bandx.out``).

    Returns
    -------
    list
        ``[k, Energies, x2, hsp]`` where

        k : numpy.ndarray
            1-D array of k-path coordinates (x-axis values for the plot).
        Energies : numpy.ndarray
            2-D array in which each column is one band.
        x2 : list of float
            k-path coordinates of the high-symmetry points.
        hsp : list of list of float
            ``[kx, ky, kz]`` for each high-symmetry point.

    Raises
    ------
    ValueError
        If no "Plottable bands ... written to file" line is found, or if a
        high-symmetry line contains fewer than four numbers.
    """
    import os
    import re

    # Fixed-width numeric fields can run together when a value is negative
    # (e.g. "0.5000-0.5000"), so numbers are scanned with a regex instead of
    # being split on whitespace.
    number = re.compile(r"[-+]?\d*\.\d+|[-+]?\d+")

    x2 = []
    hsp = []
    fname = None
    with open(file2, "r") as f:
        for raw in f:
            line = raw.strip()
            if line.startswith("high-symmetry"):
                # "high-symmetry point:  kx ky kz   x coordinate   x"
                values = [float(v) for v in number.findall(line)]
                if len(values) < 4:
                    raise ValueError(f"Cannot parse high-symmetry point line: {line!r}")
                hsp.append(values[0:3])
                x2.append(values[3])
            elif line.startswith("Plottable"):
                # The data-file name is the last token, whether or not the
                # line includes "(eV)".
                fname = line.split()[-1]

    if fname is None:
        raise ValueError(
            f"No 'Plottable bands ... written to file' line found in {file2!r}"
        )
    # The recorded name may be relative to wherever bands.x was run. If it
    # does not resolve from here, try the directory that holds file2.
    if not os.path.isabs(fname) and not os.path.exists(fname):
        candidate = os.path.join(os.path.dirname(file2), fname)
        if os.path.exists(candidate):
            fname = candidate

    [k, Energies] = importBands(fname)
    return [k, Energies, x2, hsp]
        
    

def plotBands(file, Ef=0, Kcoord=True, lbls=None, figsize=None, pad=1.08):
    r"""Plot the band structure described by a bands.x output file.

    Examples::

        plotBands(file)
        plotBands(file, getEfSCF('si_scf.out')[0], False, ['L', r'$\Gamma$', 'X'])
        plotBands(file, 0, False, ['L', r'$\Gamma$', 'X'])

    Parameters
    ----------
    file : str
        bands.x output file (e.g. ``[prefix]_bandx.out``).
    Ef : float, optional
        Reference energy subtracted from every band (default 0), drawn as a
        dashed horizontal line at y=0. It can be obtained with
        ``getEfSCF(scfFile)[0]``.
    Kcoord : bool, optional
        Default True: label the k-axis ticks with the high-symmetry point
        coordinates [kx, ky, kz]. If False, ``lbls`` supplies the labels and
        must have one entry per high-symmetry point, e.g.
        ['L', r'$\Gamma$', 'X'].
    lbls : list of str, optional
        Tick labels, one per high-symmetry point. Used only when
        ``Kcoord`` is False. Default None (no labels).
    figsize : (float, float), optional
        Figure width and height in inches. Default None, which uses
        ``rcParams['figure.figsize']`` as it stands at call time.
    pad : float, optional
        Padding around the figure passed to ``tight_layout``. Default 1.08.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    import warnings

    if lbls is None:
        lbls = []
    [k, Energies, xSpec, hsp] = buildBands(file)

    plt.figure(figsize=figsize)
    plt.ylabel('Energy (eV)')
    # Energies has one column per band. plot() draws one line per column,
    # all in the same colour.
    plt.plot(k, Energies - Ef, color='C0')
    plt.axhline(y=0, color='k', linestyle='--', linewidth=1, alpha=0.5, label=None)

    # Mark the high-symmetry points and restrict the k-axis to the path.
    plt.xlabel('k-point path')
    for xs in xSpec:
        plt.axvline(x=xs, color='k', linestyle='--', linewidth=1, alpha=0.5, label=None)
    if xSpec:
        plt.xlim((min(xSpec), max(xSpec)))

    if Kcoord:
        labels = hsp
    elif len(lbls) == len(xSpec):
        labels = lbls
    else:
        # Fall back to coordinate labels when the user labels do not match.
        labels = hsp
        warnings.warn('Ensure that list lbls contains the proper number of special point labels')
    plt.xticks(ticks=xSpec, labels=labels)

    plt.tight_layout(pad=pad)
    plt.show()