"""
IO module for CREST.

Developed by Dan Burrill and Michael Taylor
"""

# Imports
import ase.io as ase_io
import shutil
import copy
import subprocess as sub
import numpy as np
from architector import io_ptable

import architector.io_obabel as io_obabel 
import architector.io_molecule as io_molecule
import architector.arch_context_manage as arch_context_manage
import architector.io_xtb_calc as io_xtb_calc
from ase.optimize import BFGSLineSearch

def isint(str):
    '''
    Check whether a value can be converted with int().

    Parameters
    ----------
    str : object
        Value to test (typically a whitespace-split token from a text file).
        The parameter name is kept for backward compatibility even though it
        shadows the builtin inside this function.

    Returns
    ----------
    out : bool
        True if int(str) succeeds, False otherwise.
    '''
    try:
        int(str)
    # Only swallow conversion failures; a bare except would also hide
    # KeyboardInterrupt / SystemExit.
    except (TypeError, ValueError, OverflowError):
        return False
    return True

# Functions
def read_conformers(fileName):
    '''
    Read conformers and their energies from a multi-frame xyz file.

    Parameters
    ----------
    fileName : str
        File path to read in conformers generated by crest

    Returns
    ----------
    molList : list (ase.atoms.Atoms)
        Conformers read from crest as ASE atoms, each translated so that its
        center of mass is at the origin
    xtb_energies : list (float)
        Energies parsed from the file, in file order, multiplied by 27.2114
        (Hartree to eV). No sorting is done here.
    '''
    # Read conformers to ASE Atoms objects
    molList = ase_io.read(fileName,index=":")

    # Center all conformer COMs at the origin
    for mol in molList:
        mol.translate(-mol.get_center_of_mass())

    # Atom count used to recognise frame headers. Taken from the last frame,
    # and None when no frame was read so that the parser below never refers
    # to an undefined loop variable.
    n_atoms = len(molList[-1]) if molList else None

    xtb_energies = []
    expect_energy = False
    with open(fileName,'r') as file1:
        for line in file1: # Stream the file instead of loading every line
            sline = line.split()
            if len(sline) != 1:
                continue # Coordinate lines (or blank lines) carry no energy
            token = sline[0]
            if isint(token):
                # A lone integer equal to the atom count starts a new frame;
                # the next lone non-integer token is taken as its energy.
                if int(token) == n_atoms:
                    expect_energy = True
            elif expect_energy:
                xtb_energies.append(float(token)*27.2114) # Hartee to eV
                expect_energy = False
            else:
                print('Warning - messed up!')
    return molList,xtb_energies

def crest_conformers(structure, charge=None, uhf=None, method='GFN2/GFNFF',
                    solvent='none', read_charge_spin=False):
    '''
    Find conformers of a given structure with CREST.

    Parameters
    ----------
    structure : ase.atoms.Atoms/str
        structure passed; converted with io_molecule.convert_io_molecule
    charge : int, optional
        charge of the species, default None (use the charge stored on the
        converted molecule, detecting it if it is not set)
    uhf : int, optional
        number of unpaired electrons in the system, default None (use the
        molecule's xtb_uhf if set, otherwise the lowest spin state)
    method : str, optional
        which gfn family method to use, one of 'GFN2/GFNFF', 'GFN2-xTB' or
        'GFN-FF', default GFN2/GFNFF
    solvent : str, optional
        solvent passed to crest via --alpb, default 'none' (no solvent)
    read_charge_spin : bool, optional
        accepted for interface compatibility; not used by this function,
        default False

    Returns
    ----------
    conformerList : list (ase.atoms.Atoms)
        Conformers generated via crest as ASE atoms
    xtb_energies : list (float)
        List of xtb energies output from CREST

    Raises
    ----------
    ValueError
        If method is not one of the recognized options.
    FileNotFoundError
        If no crest executable is found on the PATH.
    subprocess.CalledProcessError
        If crest exits with a non-zero status.
    '''
    # Validate the requested method before doing any work
    method_flags = {
        'GFN2/GFNFF': '--gfn2//gfnff',
        'GFN2-xTB': '--gfn2',
        'GFN-FF': '--gfnff',
    }
    if method not in method_flags:
        raise ValueError("This crest method is not recognized {}.".format(method))
    method_string = method_flags[method]

    crestPath = shutil.which('crest')
    if crestPath is None:
        raise FileNotFoundError("Could not find the 'crest' executable on the PATH.")

    mol = io_molecule.convert_io_molecule(structure)
    if charge is not None:
        mol.charge = charge
    elif mol.charge is None:
        mol.detect_charge_spin()
    mol_charge = mol.charge

    # Parity of the total electron count: 1 -> odd, 0 -> even
    even_odd_electrons = (np.sum([atom.number for atom in mol.ase_atoms])-mol_charge) % 2
    if uhf is not None:
        # Make the requested number of unpaired electrons consistent with
        # the electron-count parity.
        if (even_odd_electrons == 1) and (uhf % 2 == 0):
            uhf = uhf + 1 if uhf < 7 else uhf - 1
        elif (even_odd_electrons == 0) and (uhf % 2 == 1):
            uhf = uhf - 1
    elif mol.xtb_uhf is not None:
        uhf = mol.xtb_uhf # Used as stored, without parity adjustment
    else:
        # Set spin to LS by default: 1 for odd electron counts, else 0
        uhf = 1 if even_odd_electrons == 1 else 0

    xyzstr = io_molecule.convert_ase_xyz(mol.ase_atoms)

    with arch_context_manage.make_temp_directory() as _:
        # Write xyz file
        with open("structure.xyz",'w') as outFile:
            outFile.write(xyzstr)

        # Build the command as an argument list (no shell involved), so the
        # executable path and solvent name are passed through verbatim.
        command = [crestPath, 'structure.xyz', method_string, '--chrg', str(mol_charge)]
        if uhf != 0:
            command += ['--uhf', str(uhf)]
        if solvent != 'none':
            command += ['--alpb', str(solvent)]
        command += ['--notopo', '--quick']

        # Run CREST, capturing its stdout in output.crest
        with open("output.crest",'w') as logFile:
            sub.run(command,stdout=logFile,check=True)

        # Read conformers from file
        conformerList, xtb_energies = read_conformers("crest_conformers.xyz")

    return conformerList, xtb_energies

def crest_conformers_lig(smiles,ase_atoms=None,charge=None,uhf=0,method='GFN2/GFNFF',
                    solvent='none',neutralize=False,functionalizations=None):
    '''
    Find conformers of a given smiles with CREST.

    Parameters
    ----------
    smiles : str
        smiles string for ligand
    ase_atoms : ase.atoms.Atoms, optional
        structure passed, default None (build the structure from smiles).
        When given, the charge is the rounded sum of its initial charges and
        the charge argument is not used.
    charge : int, optional
        charge of the species, default None (use the total charge of the
        structure built from the smiles)
    uhf : int, optional
        number of unpaired electrons in the system, default to 0
    method : str, optional
        which gfn family method to use, one of 'GFN2/GFNFF', 'GFN2-xTB' or
        'GFN-FF', default GFN2/GFNFF
    solvent : str, optional
        solvent passed to crest via --alpb, default 'none' (no solvent)
    neutralize: bool, optional
        neutralize ligand before evaluating conformers? default False.
    functionalizations: optional
        functionalizations forwarded to io_obabel.get_obmol_smiles,
        default None

    Returns
    ----------
    conformerList : list (ase.atoms.Atoms)
        Conformers generated via crest as ASE atoms
    xtb_energies : list (float)
        List of xtb energies output from CREST

    Raises
    ----------
    ValueError
        If method is not one of the recognized options.
    FileNotFoundError
        If no crest executable is found on the PATH.
    subprocess.CalledProcessError
        If crest exits with a non-zero status.
    '''
    # Validate the requested method before doing any work
    method_flags = {
        'GFN2/GFNFF': '--gfn2//gfnff',
        'GFN2-xTB': '--gfn2',
        'GFN-FF': '--gfnff',
    }
    if method not in method_flags:
        raise ValueError("This crest method is not recognized {}.".format(method))
    method_string = method_flags[method]

    crestPath = shutil.which('crest')
    if crestPath is None:
        raise FileNotFoundError("Could not find the 'crest' executable on the PATH.")

    if ase_atoms is None:
        # Convert smiles to a 3D structure
        struct = io_obabel.get_obmol_smiles(
            smiles,
            neutralize=neutralize,
            functionalizations=functionalizations
            )
        # "is not None" so that an explicit charge of 0 is honoured
        if charge is not None:
            mol_charge = charge
        else:
            mol_charge = struct.GetTotalCharge()
        ase_atoms = io_obabel.convert_obmol_ase(struct)
    else:
        # The summed initial charges are floating point; round to an integer
        # so the value handed to --chrg is formatted as e.g. "0", not "0.0".
        mol_charge = int(round(np.sum(ase_atoms.get_initial_charges())))

    # Parity of the total electron count: 1 -> odd, 0 -> even
    even_odd_electrons = (np.sum([atom.number for atom in ase_atoms])-mol_charge) % 2
    if uhf is not None:
        # Make the requested number of unpaired electrons consistent with
        # the electron-count parity.
        if (even_odd_electrons == 1) and (uhf % 2 == 0):
            uhf = uhf + 1 if uhf < 7 else uhf - 1
        elif (even_odd_electrons == 0) and (uhf % 2 == 1):
            uhf = uhf - 1
    else:
        # Set spin to LS by default: 1 for odd electron counts, else 0
        uhf = 1 if even_odd_electrons == 1 else 0

    xyzstr = io_molecule.convert_ase_xyz(ase_atoms)

    with arch_context_manage.make_temp_directory() as _:
        # Write xyz file
        with open("structure.xyz",'w') as outFile:
            outFile.write(xyzstr)

        # Build the command as an argument list (no shell involved), so the
        # executable path and solvent name are passed through verbatim.
        command = [crestPath, 'structure.xyz', method_string, '--chrg', str(mol_charge)]
        if uhf != 0:
            command += ['--uhf', str(uhf)]
        if solvent != 'none':
            command += ['--alpb', str(solvent)]
        command += ['--notopo', '--quick']

        # Run CREST, capturing its stdout in output.crest
        with open("output.crest",'w') as logFile:
            sub.run(command,stdout=logFile,check=True)

        # Read conformers from file
        conformerList, xtb_energies = read_conformers("crest_conformers.xyz")

    return conformerList, xtb_energies


def add_explicit_solvents(complex_mol2string,n_solvents=6,solvent='water'):
    """
    THIS IS HIGHLY EXPERIMENTAL:
    To use this xtb_IFF is needed: https://github.com/grimme-lab/xtbiff/releases/tag/v1.1 

    We were attempting this type of solvent addition protocol:
    https://xtb-docs.readthedocs.io/en/latest/crestqcg.html to add solvents!

    Parameters
    ----------
    complex_mol2string : str
        mol2 string of the complex (read with readstring=True)
    n_solvents : int, optional
        number of solvent molecules passed to crest via --nsolv, default 6
    solvent : str, optional
        either a key of io_ptable.solvents_dict or one of its values
        (a solvent smiles), default 'water'

    Returns
    ----------
    conformerList : list (ase.atoms.Atoms)
        Structures read from final_ensemble.xyz
    xtb_energies : list (float)
        Energies read from final_ensemble.xyz

    Raises
    ----------
    ValueError
        If the solvent is neither a key nor a value of io_ptable.solvents_dict.
    subprocess.CalledProcessError
        If crest exits with a non-zero status.
    """
    
    if solvent in io_ptable.solvents_dict:
        solvent_smiles = io_ptable.solvents_dict[solvent]
    elif solvent in io_ptable.solvents_dict.values():
        # A smiles was passed: map it back to the solvent name for --alpb
        solvent_smiles = solvent
        for key, value in io_ptable.solvents_dict.items():
            if value == solvent_smiles:
                solvent = key
    else:
        raise ValueError('Crest/XTB does not know this solvent!')

    mol = io_molecule.Molecule()
    charge,spin = mol.read_mol2(complex_mol2string,readstring=True,read_charge_spin=True)
    
    info_dict = dict()
    if charge is None:
        # Fall back to the tabulated metal charge plus the ligand charges
        _,_,info_dict = io_obabel.obmol_lig_split(complex_mol2string,return_info=True)
        charge = int(io_ptable.metal_charge_dict[info_dict['metal']] + np.sum(info_dict['lig_charges']))
    if spin is None:
        # Parity of the total electron count: 1 -> odd, 0 -> even
        even_odd_electrons = (np.sum([atom.number for atom in mol.ase_atoms])-charge) % 2
        if len(info_dict) <= 1: # No split info yet, so calc the metal id!
            _,_,info_dict = io_obabel.obmol_lig_split(complex_mol2string,return_info=True)
        uhf = io_ptable.metal_spin_dict[info_dict['metal']]
        # Make the tabulated metal spin consistent with the electron-count
        # parity. Only an even uhf is shifted for an odd electron count, so
        # an already-consistent odd value is left untouched.
        if (even_odd_electrons == 1) and (uhf % 2 == 0):
            uhf = uhf + 1 if uhf < 7 else uhf - 1
        elif (even_odd_electrons == 0) and (uhf % 2 == 1):
            uhf = uhf - 1
        spin = int(uhf)

    charge = int(charge)
    spin = int(spin)
    # Argument list for crest (run without a shell); stdout goes to crest.out
    command = ["crest","solute.xyz","--qcg","solvent.xyz",
               "--chrg",str(charge),"--uhf",str(spin),
               "--nsolv",str(n_solvents),"--T","12","--ensemble",
               "--alpb",str(solvent),"--mdtime","50","--mddump","200"]

    solvent_xyz_str = io_obabel.smiles2xyz(solvent_smiles)
    solute_xyz_str = mol.write_xyz('cool.xyz',writestring=True)

    # Constrain the first metal found together with every atom bonded to it
    metal_ind = [i for i,x in enumerate(mol.ase_atoms.get_chemical_symbols()) if x in io_ptable.all_metals][0]
    coord_atoms = np.nonzero(np.ravel(mol.graph[metal_ind]))[0]
    # NOTE(review): these are 0-based positions in mol.ase_atoms; confirm the
    # index convention expected by the $constrain block.
    freeze_string = ','.join([str(x) for x in sorted(coord_atoms.tolist() + [metal_ind])])

    molcontrol_str = """$constrain
  atoms: {}
$end""".format(freeze_string)

    with arch_context_manage.make_temp_directory() as _:
        with open('solute.xyz','w') as file1:
            file1.write(solute_xyz_str)
        with open('solvent.xyz','w') as file1:
            file1.write(solvent_xyz_str)
        with open('.xcontrol','w') as file1:
            file1.write(molcontrol_str)

        # Run crest and read its ensemble while still inside the temporary
        # directory, where the input files above were written.
        with open('crest.out','w') as logFile:
            sub.run(command,stdout=logFile,check=True)

        conformerList, xtb_energies = read_conformers("final_ensemble.xyz")
        
    return conformerList, xtb_energies
        

def obmol_xtb_conformers(smiles,charge=None,uhf=0,method='GFN2-xTB',total_confs=3000,
                    solvent='none',neutralize=False,functionalizations=None):
    '''
    Relax ligand structure with xTB. (Sampling done with openbabel)

    Parameters
    ----------
    smiles : str
        smiles string for ligand
    charge : int, optional
        charge of the species, default None (use the total charge of the
        structure built from the smiles)
    uhf : int, optional
        number of unpaired electrons in the system, default to 0
    method : str, optional
        method handed to io_xtb_calc.set_XTB_calc_lig, default GFN2-xTB
    total_confs : int, optional
        passed as conf_cutoff to the openbabel conformer generation,
        default 3000
    solvent : str, optional
        solvent handed to io_xtb_calc.set_XTB_calc_lig, default 'none'
    neutralize: bool, optional
        neutralize ligand before evaluating conformers? default False.
    functionalizations : None, optional
        functionalizations option for smiles

    Returns
    ----------
    conformerList : list (ase.atoms.Atoms)
        Conformers generated via openbabel as ASE atoms, after relaxation.
        Conformers whose relaxation raised an exception are left out.
    xtb_energies : list (float)
        Total energies of the conformers in conformerList
    '''

    # Build the reference structure with the same functionalizations as the
    # conformers, so the charge read from it belongs to the same species.
    struct = io_obabel.get_obmol_smiles(
        smiles,
        neutralize=neutralize,
        functionalizations=functionalizations
        )
    conf_list = io_obabel.generate_obmol_conformers(smiles, neutralize=neutralize,
        functionalizations=functionalizations,conf_cutoff=total_confs)
    # "is not None" so that an explicit charge of 0 is honoured
    if charge is not None:
        mol_charge = charge
    else:
        mol_charge = struct.GetTotalCharge()

    conformerList = []
    xtb_energies = []

    with arch_context_manage.make_temp_directory() as _:
        for conf in conf_list:
            mol = io_molecule.convert_io_molecule(conf)
            atoms = mol.ase_atoms
            io_xtb_calc.set_XTB_calc_lig(atoms,charge=mol_charge,method=method,uhf=uhf,
                                         solvent=solvent)
            try:
                dyn = BFGSLineSearch(atoms)
                dyn.run()
            except Exception:
                # Best effort: drop conformers whose relaxation fails, but
                # let KeyboardInterrupt / SystemExit propagate.
                continue
            conformerList.append(atoms)
            xtb_energies.append(atoms.get_total_energy())

    return conformerList, xtb_energies


# Main (Unit Tests)
if __name__ == '__main__':
    # Ligand used for the manual check below
    smiles = "n1ccccc1-c2ccccn2"

    # Location of the crest executable (None when it is not on the PATH)
    crestPath = shutil.which('crest')

    # Generate and relax the conformers inside a scratch directory
    with arch_context_manage.make_temp_directory() as _:
        conformerList, energies = obmol_xtb_conformers(smiles, method='GFN2/GFNFF')

    print(conformerList)

    # Write one xyz file per conformer, named after its energy
    for conf_index, conformer in enumerate(conformerList):
        out_name = "con_{}.xyz".format(energies[conf_index])
        conformer.write(out_name)