#! /usr/bin/env python
"""
block2-gaopt wrapper.

Author: Huanchen Zhai (Jan 22, 2021)
"""

from block2 import FCIDUMP, OrbitalOrdering, VectorDouble, VectorUInt16, Random
import numpy as np
import sys

try:
    from pyblock2.driver.parser import parse_gaopt, read_integral
except ImportError:
    from parser import parse_gaopt, read_integral

# Debug switch; nothing in the visible part of this script reads it.
DEBUG = True

# Help text raised as the error message when the command line is unusable.
_USAGE = """
        Usage:
            (A) python gaopt -integral FCIDUMP
            (B) python gaopt -config gaopt.conf -integral FCIDUMP
            (C) python gaopt -config gaopt.conf -integral ints.h5
            (D) python gaopt -s -config gaopt.conf -integral kmat (read kmat)
            (E) python gaopt ... -w kmat (write kmat)
            (F) python gaopt ... -wint FCIDUMP.NEW (write reordered FCIDUMP)
            (G) python gaopt ... -ord RestartReorder.dat -wint FCIDUMP.NEW
                (read from StackBlock reorder file, write reordered FCIDUMP)
            (H) python gaopt ... -fiedler -wint FCIDUMP.NEW
                (write fiedler reordered FCIDUMP)
    """

# Options that are plain switches and therefore take no value.
_SWITCHES = ('-s', '-fiedler')


def parse_cli_args(argv):
    """Turn a ``sys.argv``-style list into an ``{option: value}`` dict.

    Every argument starting with ``-`` becomes a key (leading dash removed).
    The switches ``-s`` and ``-fiedler`` map to the empty string; any other
    option maps to the argument that follows it. Arguments are scanned one
    by one without skipping consumed values, so a value that itself starts
    with ``-`` is also interpreted as an option.

    Args:
        argv: full argument list, program name at index 0.

    Returns:
        dict mapping option names to their string values.

    Raises:
        ValueError: if no arguments were supplied, or if a value-taking
            option is the last argument and so has no value.
    """
    if len(argv) <= 1:
        raise ValueError(_USAGE)
    parsed = {}
    for pos in range(1, len(argv)):
        arg = argv[pos]
        if arg in _SWITCHES:
            parsed[arg[1:]] = ''
        elif arg.startswith('-'):
            # Without this check a trailing option crashed with IndexError.
            if pos + 1 >= len(argv):
                raise ValueError(
                    "Option '%s' expects a value but none was given.%s"
                    % (arg, _USAGE))
            parsed[arg[1:]] = argv[pos + 1]
    return parsed


arg_dic = parse_cli_args(sys.argv)

# MPI
from mpi4py import MPI as MPI
# World communicator, this process's rank in it, and the number of ranks.
comm = MPI.COMM_WORLD
mrank = comm.Get_rank()
msize = comm.Get_size()


def _print(*args, **kwargs):
    """Forward all arguments to ``print``, but only on MPI rank 0.

    On every other rank the call does nothing.
    """
    if mrank != 0:
        return
    print(*args, **kwargs)

# read integrals/kmat
def _load_integrals(path):
    """Load an integral file, choosing the reader from the file's magic bytes.

    A file whose first four bytes are the HDF5 signature (0x89 'H' 'D' 'F')
    is passed to ``read_integral``; anything else is read as a text FCIDUMP.
    The probe handle is closed before the real reader opens the file.

    Args:
        path: path of the integral file.

    Returns:
        The object produced by ``FCIDUMP.read`` / ``read_integral``.
    """
    with open(path, 'rb') as probe:
        is_hdf5 = probe.read(4) == b'\x89HDF'
    if is_hdf5:
        # The two zeros are forwarded as-is; their meaning is defined by
        # read_integral, which is not part of this file.
        return read_integral(path, 0, 0)
    loaded = FCIDUMP()
    loaded.read(path)
    return loaded


fints = arg_dic["integral"]
fcidump = None
if 's' not in arg_dic:
    # Build the n_sites x n_sites cost matrix (flattened row-major) from the
    # integrals: exchange matrix plus the one-electron matrix scaled by 1E-7.
    fcidump = _load_integrals(fints)
    n_sites = fcidump.n_sites
    hmat = fcidump.abs_h1e_matrix()
    xmat = fcidump.abs_exchange_matrix()
    kmat = VectorDouble(np.array(hmat) * 1E-7 + np.array(xmat))
    if 'w' in arg_dic and mrank == 0:
        # Only rank 0 writes, so parallel runs do not clobber the file.
        with open(arg_dic['w'], 'w') as kfin:
            kfin.write("%d\n" % n_sites)
            for i in range(n_sites):
                row = kmat[i * n_sites:(i + 1) * n_sites]
                kfin.write("".join(" %24.16e" % v for v in row))
                kfin.write("\n")
else:
    # "-s": the -integral argument names a kmat text file, optionally in the
    # form "kmat:integrals" so that integrals are still available for -wint.
    parts = fints.split(':')
    if len(parts) == 2:
        fints, fb = parts
        fcidump = _load_integrals(fb)
    with open(fints, 'r') as kfin:
        klines = kfin.readlines()
    # First line holds n_sites, followed by n_sites rows of matrix entries.
    n_sites = int(klines[0].strip())
    kmat = VectorDouble([float(x) for kline in klines[1:n_sites + 1]
                         for x in kline.split()])
    if len(kmat) != n_sites * n_sites:
        raise ValueError(
            "kmat file '%s' should contain %d values for n_sites = %d, "
            "but %d were found." % (fints, n_sites * n_sites, n_sites,
                                    len(kmat)))

# options
# Settings come from the optional -config file; without one every default
# below applies.
if "config" in arg_dic:
    dic = parse_gaopt(arg_dic["config"])
else:
    dic = {}

# Only the "gauss" method with a scale of exactly 1.0 is accepted.
assert dic.get("method", "gauss") == "gauss"
assert float(dic.get("scale", 1.0)) == 1.0

# Number of independent GA runs, distributed over the MPI ranks.
n_tasks = int(dic.get("maxcomm", 32))

# Keyword arguments forwarded to OrbitalOrdering.ga_opt.
opts = {
    "n_generations": int(dic.get("maxgen", 10000)),
    "n_configs": int(dic.get("maxcell", n_sites * 2)),
    "n_elite": int(dic.get("elite", 8)),
    # The config key "cloning" is stored as its complement.
    "clone_rate": 1.0 - float(dic.get("cloning", 0.9)),
    "mutate_rate": float(dic.get("mutation", 0.1)),
}

# original cost
# Rank 0 reports the cost of the unpermuted ordering 0, 1, ..., n_sites - 1.
if mrank == 0:
    identity_order = VectorUInt16(range(n_sites))
    origf = OrbitalOrdering.evaluate(n_sites, kmat, identity_order)
    print("Default order : f = %20.12f" % origf, flush=True)

# run
# Best ordering found so far on this rank (stored 1-based) and its cost.
midx, mf = None, None

# Both shortcuts below produce the ordering directly and set n_tasks to 0,
# which turns the genetic-algorithm loop that follows into a no-op.
if "ord" in arg_dic:
    print("read reordering from", arg_dic["ord"])
    # Only the first line is used: whitespace-separated integer indices that
    # are handed to evaluate() unchanged. The handle is closed right away.
    with open(arg_dic["ord"], "r") as ford:
        idx = [int(x) for x in ford.readline().split()]
    f = OrbitalOrdering.evaluate(n_sites, kmat, VectorUInt16(idx))
    idx = np.array(idx) + 1
    midx, mf = idx, f
    n_tasks = 0
elif "fiedler" in arg_dic:
    print("use fiedler reorder")
    idx = OrbitalOrdering.fiedler(n_sites, kmat)
    f = OrbitalOrdering.evaluate(n_sites, kmat, idx)
    idx = np.array(idx) + 1
    midx, mf = idx, f
    n_tasks = 0

for i_task in range(n_tasks):
    # The seed is reset on every rank for every task, so the seed in effect
    # when a task runs depends only on the task index.
    Random.rand_seed(1234 + i_task)
    # Tasks are dealt out in contiguous chunks: task i goes to rank
    # floor(i * msize / n_tasks); all other ranks skip it.
    if i_task * msize // n_tasks != mrank:
        continue
    idx = OrbitalOrdering.ga_opt(n_sites, kmat, **opts)
    f = OrbitalOrdering.evaluate(n_sites, kmat, idx)
    # Orderings are reported and stored 1-based.
    idx = np.array(idx) + 1
    print("Order # %4d : %r / f = %20.12f" % (i_task, idx, f), flush=True)
    # Keep the lowest-cost ordering seen on this rank.
    if mf is None or f < mf:
        midx, mf = idx, f

# Wait until every rank has finished its share of the work.
if msize != 0:
    comm.barrier()

# collect
# Every non-root rank ships its local best (ordering, cost) pair to rank 0,
# which keeps the overall minimum and prints it.
if mrank != 0:
    comm.send((midx, mf), dest=0, tag=11)
else:
    for irank in range(1, msize):
        idx, f = comm.recv(source=irank, tag=11)
        # A rank that ran no task reports a cost of None; ignore it.
        if f is not None and (mf is None or f < mf):
            midx, mf = idx, f
    if mf is None:
        # Previously this case crashed in the '%20.12f' formatting below
        # with an uninformative TypeError.
        raise RuntimeError(
            "No orbital ordering was produced (n_tasks = %d); "
            "nothing to report." % n_tasks)
    print('MINIMUM / f = %20.12f' % mf)
    print('DMRG REORDER FORMAT')
    print(','.join(map(str, midx)))

# write FCIDUMP
# Only rank 0 writes, and only when -wint was given. Integrals must have been
# loaded earlier, otherwise the assertion fails.
if mrank == 0 and 'wint' in arg_dic:
    assert fcidump is not None
    # midx is kept 1-based, so shift it back by one before reordering.
    zero_based = midx - 1
    fcidump.reorder(VectorUInt16(zero_based))
    fcidump.write(arg_dic['wint'])
