#!/usr/bin/env python3

import numpy as np
from dftpy.interface import ConfigParser
from dftpy.td.real_time_runner import RealTimeRunner
import time
from dftpy.time_data import TimeData
from dftpy.mpi import mp, sprint
from dftpy import __version__
from dftpy.cui.main import GetConf
from dftpy.functional import Functional
from dftpy.functional.external_potential import ExternalPotential
from dftpy.inverter import Inverter


def runner(config, rho, functionals):
    """Install an ``'ext'`` functional and run a real-time propagation.

    The ``'ext'`` entry handed to ``functionals.UpdateFunctional`` is chosen
    from ``config['KEDF2']['kedf']``:

    * ``'KS'``: the result of calling an ``Inverter`` instance with ``rho``
      and ``functionals``.
    * anything else: an ``ExternalPotential`` whose ``v`` is the potential of
      the KEDF built from ``config['KEDF2']`` minus the potential of
      ``functionals.funcDict['KineticEnergyFunctional']``, both evaluated on
      ``rho.field``.

    Afterwards a ``RealTimeRunner`` is built from ``rho``, ``config`` and the
    updated ``functionals`` and invoked once.
    """
    kedf_options = config['KEDF2']
    kedf_name = kedf_options['kedf']

    if kedf_name != 'KS':
        # Potential of the kinetic functional already present in `functionals`.
        current_kedf = functionals.funcDict['KineticEnergyFunctional']
        current_potential = current_kedf(rho=rho.field, calcType={'V'}).potential

        # Potential of the KEDF requested by the KEDF2 section.
        target_kedf = Functional(type='KEDF', name=kedf_name, **kedf_options)
        target_potential = target_kedf(rho=rho.field, calcType={'V'}).potential

        ext = ExternalPotential(v=target_potential - current_potential)
    else:
        inverter = Inverter()
        ext = inverter(rho, functionals)

    functionals.UpdateFunctional(newFuncDict={'ext': ext})

    propagate = RealTimeRunner(rho, config, functionals)
    propagate()


def RunJob(args):
    """Run ``runner`` once for every configuration file listed in ``args``.

    ``args.confs`` is the list of configuration file names; when it is empty,
    ``args.input`` is appended to it (in place) and used instead. For each
    file the configuration is parsed with ``ConfigParser``, ``runner`` is
    executed inside a ``"TOTAL"`` ``TimeData`` section, and the timing report
    is written with ``TimeData.output``. Start/finish banners and the
    parallel/serial notice are emitted through ``sprint``.
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    hash_rule = "#" * 80

    started = time.strftime(time_format, time.localtime())
    sprint("DFTpy {} Begin on : {}".format(__version__, started))

    template = (
        'Parallel version (MPI) on {0:>8d} processors'
        if mp.is_mpi
        else 'Serial version on {0:>8d} processors'
    )
    sprint(template.format(mp.comm.size))

    # Fall back to the single input file when no configuration list was given.
    if not len(args.confs):
        args.confs.append(args.input)

    for conf_file in args.confs:
        config, others = ConfigParser(conf_file, mp=mp)
        sprint(hash_rule)

        TimeData.Begin("TOTAL")
        runner(config, others["field"], others["E_v_Evaluator"])
        TimeData.End("TOTAL")

        TimeData.output(config)
        sprint("-" * 80)

    sprint(hash_rule)
    finished = time.strftime(time_format, time.localtime())
    sprint("DFTpy {} Finished on : {}".format(__version__, finished))


def main():
    """Command-line entry point: check the interpreter, then run the job.

    If the running interpreter is older than Python 3.6, an upgrade notice is
    printed through ``sprint`` and the process exits with status 1. Otherwise
    the arguments returned by ``GetConf()`` are passed to ``RunJob``.

    Raises:
        SystemExit: with code 1 when the Python version is below 3.6.
    """
    import sys
    from dftpy.mpi.utils import sprint

    if sys.version_info < (3, 6):
        sprint('Please upgrade your python to version 3.6 or higher.')
        # A bare sys.exit() reports status 0 (success); exit non-zero so that
        # shells and job schedulers can detect the failure.
        sys.exit(1)
    args = GetConf()
    RunJob(args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()