
import numpy as np

# Package fns.
from epsproc.basicPlotters import molPlot
from epsproc.util.conversion import (conv_ev_atm, conv_ev_nm)

#*************** Summary & display functions

# Print some jobInfo stuff & plot molecular structure
def jobSummary(jobInfo = None, molInfo = None, tolConv = 1e-2):
    """
    Print some jobInfo stuff & plot molecular structure. (Currently very basic.)

    Parameters
    ------------
    jobInfo : dict, default = None
        Dictionary of job data, as generated by :py:func:`epsproc.IO.headerFileParse()` from source ePS output file.
        Reads keys 'comments', 'Convert', 'OrbOccInit', 'OrbOcc' and 'IPot'.

    molInfo : dict, default = None
        Dictionary of molecule data, as generated by :py:func:`epsproc.IO.molInfoParse()` from source ePS output file.
        If given, the molecular structure is plotted via molPlot(); if given together with jobInfo,
        the ionizing-orbital properties are computed via getOrbInfo().

    tolConv : float, default = 1e-2
        Used to check for convergence in ExpOrb outputs, which defines single-center expansion of orbitals.
        A warning is printed for any row of molInfo['orbTable'] whose column-8 value
        (presumably the ExpOrb expansion normalisation - confirm against molInfoParse) lies
        outside [1 - tolConv, 1 + tolConv].

    Returns
    -------
    jobInfo : dict
        The input jobInfo, returned unchanged (kept for back-compatibility).

    orbInfo : dict
        Properties of ionizing orbital, as determined from (jobInfo, molInfo) by getOrbInfo().
        Empty dict if either jobInfo or molInfo is None.


    History
    -------
    20/09/20    v2  Added orbInfo dict, and use this to hold all orbital related outputs for return. May break old codes (pre v1.2.6-dev).
                    Moved orbInfo to a separate function.

    """

    # Default return value: orbInfo is only populated when both jobInfo and molInfo
    # are supplied; without this the final return raised UnboundLocalError.
    orbInfo = {}

    # Pull job summary info
    if jobInfo is not None:
        print('\n*** Job summary data')
        for line in jobInfo['comments'][0:4]:
            print(line.strip('#'))
        print(f"\nElectronic structure input: {jobInfo['Convert'][0].split()[1].strip()}")
        print(f"Initial state occ:\t\t {jobInfo['OrbOccInit']}")
        print(f"Final state occ:\t\t {jobInfo['OrbOcc']}")
        print(f"IPot (input vertical IP, eV):\t\t {jobInfo['IPot']}")

        # Additional orb info
        print("\n*** Additional orbital info (SymProd)")
        # Occupation difference; nonzero entries mark the ionized orbital group(s).
        iList = jobInfo['OrbOccInit'] - jobInfo['OrbOcc']
        print(f"Ionizing orb:\t\t\t {iList}")

        if molInfo is not None:
            # Orbital properties (symmetry, energies, threshold) are computed in getOrbInfo().
            _, orbInfo = getOrbInfo(jobInfo, molInfo)

            print(f"Ionizing orb sym:\t\t {orbInfo['orbSym']}")
            print(f"Orb energy (eV):\t\t {orbInfo['orbIP']}")
            print(f"Orb energy (H):\t\t\t {orbInfo['orbIPH']}")
            print(f"Orb energy (cm^-1):\t\t {orbInfo['orbIPcm']}")
            print(f"Threshold wavelength (nm):\t {orbInfo['threshold']}")

            # Check ExpOrb outputs: flag rows whose column-8 value deviates from 1 by more than tolConv.
            convVals = molInfo['orbTable'][:,8].values
            ind = (convVals < 1-tolConv) | (convVals > 1+tolConv)
            if ind.any():
                print(f"\n*** Warning: some orbital convergences outside single-center expansion convergence tolerance ({tolConv}):")
                print(molInfo['orbTable'][ind, [0, 8]].values)

    # Display structure
    if molInfo is not None:
        print('\n*** Molecular structure\n')
        molPlot(molInfo)

    # jobInfo return kept for back-compatibility; orbInfo added 20/09/20 to allow further use of values determined above.
    return jobInfo, orbInfo


# Pull orbital-specific properties for job.
def getOrbInfo(jobInfo, molInfo):
    """
    Pull orbital information for job from (jobInfo, molInfo) structures.

    Parameters
    ----------
    jobInfo : dict
        Job data; reads keys 'OrbOccInit' and 'OrbOcc' (initial and final orbital occupations),
        which must support element-wise subtraction (e.g. numpy arrays).

    molInfo : dict
        Molecule data; reads 'orbTable', which needs 'OrbGrp', 'Sym' and 'orb' coordinates
        and 'props' labels 'E' and 'EH'.

    Returns
    -------
    orbX : Xarray
        Subset of molInfo['orbTable'] for the ionized orbital group(s).

    orbInfo : dict
        Orbital properties, with keys:
        'OrbOccInit', 'OrbOccFinal' : occupations copied from jobInfo.
        'iList' : occupation difference (OrbOccInit - OrbOcc).
        'iOrbGrp' : OrbGrp number(s) of the ionized orbital(s).
        'orbN' : first 'orb' coordinate value in orbX.
        'orbSym' : unique symmetries of the ionized orbital(s).
        'orbIP', 'orbIPH' : unique orbital energies, 'E' (eV) and 'EH' (Hartree) props.
        'orbIPnm' : orbIP converted with conv_ev_nm.
        'orbIPcm' : 1e7 / orbIPnm, i.e. wavenumber in cm^-1.
        'threshold' : abs() of the first orbIPnm entry.

    History
    -------
    20/09/20 v1 Adapted from routines in jobSummary().

    """
    # Init dict.
    orbInfo = {}

    # Find which orb is ionized using ePS orb numbering.
    orbInfo['OrbOccInit'] = jobInfo['OrbOccInit']
    orbInfo['OrbOccFinal'] = jobInfo['OrbOcc']

    # Nonzero entries flag the ionized orbital group(s).
    orbInfo['iList'] = jobInfo['OrbOccInit'] - jobInfo['OrbOcc']

    # Get orbGrp for event; +1 maps 0-based array positions onto the OrbGrp numbering (presumably 1-based).
    orbInfo['iOrbGrp'] = np.flatnonzero(orbInfo['iList']) + 1

    # Select the ionized orbital group(s) from orbTable once, and reuse for all properties below.
    orbTable = molInfo['orbTable']
    orbX = orbTable.where(orbTable.coords['OrbGrp'] == orbInfo['iOrbGrp'], drop = True)

    # Set orb properties to structure - mainly just taking relevant entries from orbTable, but will make for easier reference later.
    # NOTE(review): if no orbital is ionized, orbX is empty and the [0] index below raises IndexError.
    orbInfo['orbN'] = orbX.orb.data[0]
    orbInfo['orbSym'] = np.unique(orbX.coords['Sym'])

    orbInfo['orbIP'] = np.unique(orbX.sel(props = 'E'))
    orbInfo['orbIPH'] = np.unique(orbX.sel(props = 'EH'))

    orbInfo['orbIPnm'] = conv_ev_nm(orbInfo['orbIP'])
    # 1 cm = 1e7 nm, so 1e7/lambda[nm] gives wavenumber in cm^-1.
    orbInfo['orbIPcm'] = 1/orbInfo['orbIPnm']*1e7
    orbInfo['threshold'] = np.abs(orbInfo['orbIPnm'])[0]

    return orbX, orbInfo


# Print (LM) and symmetry sets with Pandas tables
def lmSymSummary(data):
    """Print the symmetry and (L,M) index sets of ``data`` as Pandas tables.

    Each of ``data.Sym`` and ``data.LM`` is converted with ``.to_pandas()`` and unstacked
    into a 2D table; the (L,M) table is additionally transposed. Tables are written with
    ``print``, so they appear as plain text rather than as rich notebook (HTML) tables.

    For a more sophisticated Pandas conversion, see :py:func:`epsproc.util.conversion.multiDimXrToPD`

    Parameters
    ----------
    data : object with ``Sym`` and ``LM`` attributes
        Presumably an Xarray with ``Sym`` and ``LM`` (multi-index) coordinates - each must
        provide ``.to_pandas()`` returning something with ``.unstack()``.

    Returns
    -------
    None
        Output is printed only.

    """

    print('\n*** Index summary\n')

    # (attribute on data, heading printed before the table, transpose the unstacked table?)
    sections = (
        ('Sym', 'Symmetry sets', False),
        ('LM', '\n(L,M) sets', True),
    )

    for coordName, heading, transpose in sections:
        series = getattr(data, coordName).to_pandas()
        print(heading)
        table = series.unstack()
        print(table.T if transpose else table)
