#! /usr/bin/env python
"""
block2 wrapper.

Author:
    Huanchen Zhai
    Zhi-Hao Cui
"""

from block2 import SZ, SU2, SZK, SU2K, SGF, DoubleFPCodec as FPCodec
from block2 import Global, OpNamesSet, NoiseTypes, DecompositionTypes, Threading, ThreadingTypes
from block2 import init_memory, release_memory, set_mkl_num_threads, read_occ, TruncationTypes
from block2 import VectorUInt8, VectorUBond, VectorVectorUBond, VectorDouble, PointGroup
from block2 import Random, FCIDUMP, QCTypes, SeqTypes, TETypes, OpNames, VectorInt, VectorUInt16
from block2 import MatrixFunctions, KuhnMunkres, Matrix, DyallFCIDUMP, FinkFCIDUMP, ConvergenceTypes
from block2 import HubbardKSpaceFCIDUMP, HubbardFCIDUMP, HeisenbergFCIDUMP, ExpectationAlgorithmTypes
from block2 import SpinOrbitalFCIDUMP, MRCISFCIDUMP, VectorVectorInt
import numpy as np
import time
import os
import sys

# Default floating-point vector type (double precision); it is rebound to
# VectorFloat further down when "single_prec" is requested.
VectorFP = VectorDouble

# Prefer the parser shipped with the installed pyblock2 package; fall back to a
# plain ``parser`` module on sys.path (presumably when running this script
# from the source tree — confirm).
try:
    from pyblock2.driver.parser import parse, orbital_reorder, read_integral, format_schedule
except ImportError:
    from parser import parse, orbital_reorder, read_integral, format_schedule

# When True, echo the parsed input and the selected symmetry/precision mode.
DEBUG = True

# Command line:
#   argv[1]  input file (dmrg.conf or an FCIDUMP), or "-v" to print the version
#   argv[2]  optional two-step reduced-memory mode:
#              "pre" / "para-pre"  -> step 1 (save serial / parallel MPO)
#              "run" / "para-run"  -> step 2 (load serial / parallel MPO)
if len(sys.argv) < 2:
    raise ValueError("""
        Usage: any of:
            (A) python block2main dmrg.conf
            (B) reduced memory mode (save/load serial mpo):
                Step 1: python block2main dmrg.conf pre
                Step 2: python block2main dmrg.conf run
            (C) extra reduced memory mode (save/load parallel mpo):
                Step 1: python block2main dmrg.conf para-pre
                Step 2: python block2main dmrg.conf para-run
            (D) python block2main FCIDUMP
    """)

fin = sys.argv[1]
if len(sys.argv) == 2 and fin == '-v':
    print('Block 2.0')
    # sys.exit instead of quit(): quit() is injected by the site module for
    # interactive use and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit()

# Run mode from the optional second argument; every flag is False without it.
_run_mode = sys.argv[2] if len(sys.argv) > 2 else None
pre_run = _run_mode in ("pre", "para-pre")
para_pre_run = _run_mode == "para-pre"
no_pre_run = _run_mode in ("run", "para-run")
para_no_pre_run = _run_mode == "para-run"

dic = parse(fin)
# Bind the symmetry- and precision-specific block2 classes to generic names
# (MPS, MPO, DMRG, HamiltonianQC, Expect, ...) so the rest of the script does
# not depend on the selected mode.  The branch is chosen from input keys:
#   "single_prec"      -> single precision (block2.sp.*), VectorFloat,
#                         FloatFPCodec and init_memory_float
#   "use_complex"      -> complex arithmetic (block2.cpx.* / block2.sp.cpx.*)
#   "nonspinadapted"   -> SZ symmetry instead of spin-adapted SU2
#   "k_symmetry"       -> SZK / SU2K symmetry (asserted absent with "single_prec")
#   "use_general_spin" -> SGF symmetry
# SX is the active symmetry class.  TrSX, trans_si, trans_mi, TrMPSInfo and
# TrVectorStateInfo refer to the partner symmetry (SU2 <-> SZ, SU2K <-> SZK)
# and are not bound in the general-spin (SGF) branches; trans_mps is bound
# only in the spin-adapted (SU2 / SU2K) branches.
if "single_prec" in dic and "use_complex" not in dic:
    # --- single precision, real arithmetic ---
    assert "k_symmetry" not in dic
    from block2 import FloatFPCodec as FPCodec, init_memory_float
    from block2 import VectorFloat
    VectorFP = VectorFloat
    init_memory = init_memory_float
    from block2.sp import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP
    # MPSInfo-like classes and CG come from the regular modules; operator /
    # MPS / MPO classes come from block2.sp.*.  Expect doubles as ComplexExpect.
    if "nonspinadapted" in dic and "use_general_spin" not in dic:
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.sz import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.su2 import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "use_general_spin" in dic:
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.sgf import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False

    # Optional big-site classes; silently skipped if this block2 build lacks them.
    # NOTE(review): always taken from block2.sz / block2.su2, regardless of
    # precision, complex, k_symmetry or general-spin settings — confirm intended.
    try:
        if "nonspinadapted" in dic:
            from block2.sz import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.sz import DMRGBigSite, LinearBigSite
            from block2.sz import SCIFockBigSite
        else:
            from block2.su2 import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.su2 import DMRGBigSite, LinearBigSite
            from block2.su2 import CSFBigSite
    except ImportError:
        pass
elif "single_prec" in dic and "use_complex" in dic:
    # --- single precision, complex arithmetic ---
    assert "k_symmetry" not in dic
    from block2 import FloatFPCodec as FPCodec, init_memory_float
    from block2 import VectorFloat
    VectorFP = VectorFloat
    init_memory = init_memory_float
    from block2.sp.cpx import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP
    # Same layout as the real single-precision branch, but operator / MPS / MPO
    # classes come from block2.sp.cpx.*.  Expect doubles as ComplexExpect.
    if "nonspinadapted" in dic and "use_general_spin" not in dic:
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.cpx.sz import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.cpx.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.cpx.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.cpx.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.cpx.su2 import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.cpx.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.cpx.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.cpx.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "use_general_spin" in dic:
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.sp.cpx.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sp.cpx.sgf import HamiltonianQC, MPS, ParallelRuleQC
        from block2.sp.cpx.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sp.cpx.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.sp.cpx.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sp.cpx.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sp.cpx.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False

    # Optional big-site classes; silently skipped if this block2 build lacks them.
    # NOTE(review): always taken from block2.sz / block2.su2 — confirm intended.
    try:
        if "nonspinadapted" in dic:
            from block2.sz import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.sz import DMRGBigSite, LinearBigSite
            from block2.sz import SCIFockBigSite
        else:
            from block2.su2 import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.su2 import DMRGBigSite, LinearBigSite
            from block2.su2 import CSFBigSite
    except ImportError:
        pass
elif "use_complex" not in dic:
    # --- double precision, real arithmetic (the default) ---
    # FCIDUMP / SpinOrbitalFCIDUMP / MRCISFCIDUMP keep the top-level block2
    # imports.  Every class comes from a single symmetry module, which also
    # provides a separate ComplexExpect.
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sz import HamiltonianQC, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, ParallelRuleQC
        from block2.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO
        from block2.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.su2 import HamiltonianQC, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, ParallelRuleQC
        from block2.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO
        from block2.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        from block2 import VectorSZK as VectorSL
        from block2.szk import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.szk import HamiltonianQC, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, ParallelRuleQC
        from block2.szk import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.szk import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO
        from block2.szk import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.szk import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.szk import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.szk import trans_state_info_to_su2k as trans_si
        from block2.su2k import MPSInfo as TrMPSInfo
        from block2.su2k import trans_mps_info_to_szk as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZK
        TrSX = SU2K
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        from block2 import VectorSU2K as VectorSL
        from block2.su2k import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.su2k import HamiltonianQC, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, ParallelRuleQC
        from block2.su2k import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.su2k import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO
        from block2.su2k import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.su2k import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.su2k import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        from block2.su2k import trans_state_info_to_szk as trans_si, trans_unfused_mps_to_szk as trans_mps
        from block2.szk import MPSInfo as TrMPSInfo
        from block2.szk import trans_mps_info_to_su2k as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2K
        TrSX = SZK
    elif "use_general_spin" in dic:
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPS, MultiMPSInfo, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.sgf import HamiltonianQC, MPS, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, ParallelRuleQC
        from block2.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, CG, TensorFunctions, MPO
        from block2.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, ComplexExpect
        SX = SGF
    else:
        assert False

    # Optional big-site classes; silently skipped if this block2 build lacks them.
    # NOTE(review): always taken from block2.sz / block2.su2, even with
    # k_symmetry or use_general_spin — confirm intended.
    try:
        if "nonspinadapted" in dic:
            from block2.sz import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.sz import DMRGBigSite, LinearBigSite
            from block2.sz import SCIFockBigSite
        else:
            from block2.su2 import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.su2 import DMRGBigSite, LinearBigSite
            from block2.su2 import CSFBigSite
    except ImportError:
        pass
else:
    # --- double precision, complex arithmetic ---
    from block2.cpx import FCIDUMP, SpinOrbitalFCIDUMP, MRCISFCIDUMP
    # MPSInfo-like classes and CG come from the real symmetry modules; operator /
    # MPS / MPO classes come from block2.cpx.*.  Expect doubles as ComplexExpect.
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSZ as VectorSL
        from block2.sz import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.sz import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.cpx.sz import HamiltonianQC, MPS, ParallelRuleQC
        from block2.cpx.sz import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.cpx.sz import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.cpx.sz import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.sz import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.sz import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.sz import trans_state_info_to_su2 as trans_si
        from block2.su2 import MPSInfo as TrMPSInfo
        from block2.su2 import trans_mps_info_to_sz as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZ
        TrSX = SU2
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2 import VectorSU2 as VectorSL
        from block2.su2 import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.su2 import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.cpx.su2 import HamiltonianQC, MPS, ParallelRuleQC
        from block2.cpx.su2 import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.cpx.su2 import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.cpx.su2 import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.su2 import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.su2 import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2 import trans_state_info_to_sz as trans_si, trans_unfused_mps_to_sz as trans_mps
        from block2.sz import MPSInfo as TrMPSInfo
        from block2.sz import trans_mps_info_to_su2 as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2
        TrSX = SZ
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        from block2 import VectorSZK as VectorSL
        from block2.szk import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.szk import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.cpx.szk import HamiltonianQC, MPS, ParallelRuleQC
        from block2.cpx.szk import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.cpx.szk import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.cpx.szk import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.szk import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.szk import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.szk import trans_state_info_to_su2k as trans_si
        from block2.su2k import MPSInfo as TrMPSInfo
        from block2.su2k import trans_mps_info_to_szk as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SZK
        TrSX = SU2K
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        from block2 import VectorSU2K as VectorSL
        from block2.su2k import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.su2k import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.cpx.su2k import HamiltonianQC, MPS, ParallelRuleQC
        from block2.cpx.su2k import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.cpx.su2k import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.cpx.su2k import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.su2k import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.su2k import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        from block2.su2k import trans_state_info_to_szk as trans_si, trans_unfused_mps_to_szk as trans_mps
        from block2.szk import MPSInfo as TrMPSInfo
        from block2.szk import trans_mps_info_to_su2k as trans_mi, VectorStateInfo as TrVectorStateInfo
        SX = SU2K
        TrSX = SZK
    elif "use_general_spin" in dic:
        from block2 import VectorSGF as VectorSL
        from block2.sgf import MultiMPSInfo, MPSInfo, CASCIMPSInfo, MRCIMPSInfo, NEVPTMPSInfo, CG
        from block2.cpx.sgf import MultiMPS, CSRSparseMatrix, FusedMPO, CSROperatorFunctions
        from block2.cpx.sgf import HamiltonianQC, MPS, ParallelRuleQC
        from block2.cpx.sgf import PDM1MPOQC, NPC1MPOQC, SimplifiedMPO, Rule, RuleQC, MPOQC, NoTransposeRule
        from block2.cpx.sgf import Expect, DMRG, MovingEnvironment, OperatorFunctions, TensorFunctions, MPO
        from block2.cpx.sgf import ParallelRuleQC, ParallelMPO, ParallelMPS, IdentityMPO, VectorMPS, PDM2MPOQC
        from block2.cpx.sgf import ParallelRulePDM1QC, ParallelRulePDM2QC, ParallelRuleIdentity, ParallelRuleOneBodyQC
        from block2.cpx.sgf import AntiHermitianRuleQC, TimeEvolution, Linear, DeterminantTRIE, UnfusedMPS, Expect as ComplexExpect
        SX = SGF
    else:
        assert False

    # Optional big-site classes; silently skipped if this block2 build lacks them.
    # NOTE(review): always taken from the real block2.sz / block2.su2 modules,
    # even in this complex branch — confirm intended.
    try:
        if "nonspinadapted" in dic:
            from block2.sz import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.sz import DMRGBigSite, LinearBigSite
            from block2.sz import SCIFockBigSite
        else:
            from block2.su2 import HamiltonianQCBigSite, SimplifiedBigSite, ParallelBigSite
            from block2.su2 import DMRGBigSite, LinearBigSite
            from block2.su2 import CSFBigSite
    except ImportError:
        pass

# Pick the MPI communicator that matches the active symmetry.  If this block2
# build has no MPI support, the import fails and the script runs serially.
# Both _print variants read the module-level `outputlevel` at call time
# (it is assigned later in the script, before the first _print call).
try:
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2.sz import MPICommunicator
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2.su2 import MPICommunicator
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        from block2.szk import MPICommunicator
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        from block2.su2k import MPICommunicator
    elif "use_general_spin" in dic:
        from block2.sgf import MPICommunicator
    MPI = MPICommunicator()

    def _print(*args, **kwargs):
        """Print with flush on MPI rank 0 only; silent when outputlevel <= -1."""
        if MPI.rank == 0 and outputlevel > -1:
            kwargs["flush"] = True
            print(*args, **kwargs)
except ImportError:
    MPI = None

    def _print(*args, **kwargs):
        """Serial fallback: print with flush; silent when outputlevel <= -1.

        Mirrors the MPI variant so that "outputlevel -1" suppresses output
        whether or not MPI is available.
        """
        if outputlevel > -1:
            kwargs["flush"] = True
            print(*args, **kwargs)


# Wall-clock reference for the whole run.
tx = time.perf_counter()

# input parameters
Random.rand_seed(1234)
outputlevel = int(dic.get("outputlevel", 2))

# Echo every parsed option (the schedule is expanded to one row per entry,
# with the key shown only on the first row), then the active mode.
if DEBUG:
    star_bar = "*" * 34
    _print("\n" + star_bar + " INPUT START " + star_bar)
    for opt_name, opt_value in dic.items():
        shown_rows = format_schedule(opt_value) if opt_name == "schedule" else [opt_value]
        for row_idx, row in enumerate(shown_rows):
            label = opt_name if row_idx == 0 else ""
            _print("%-25s %40s" % (label, row))
    _print(star_bar + " INPUT END   " + star_bar + "\n")

    if "use_general_spin" in dic:
        spin_tag = "GENERAL SPIN - "
    elif "nonspinadapted" in dic:
        spin_tag = "NON SPIN ADAPTED - "
    else:
        spin_tag = "SPIN ADAPTED - "
    _print(spin_tag, end='')
    _print("COMPLEX DOMAIN - " if "use_complex" in dic else "REAL DOMAIN - ", end='')
    _print("DOUBLE PREC" if "single_prec" not in dic else "SINGLE PREC")

# Run-time options read from the parsed input (dic), with their defaults.
scratch = dic.get("prefix", "./nodex/")
restart_dir = dic.get("restart_dir", None)
# "hf_occ integral" selects StackBlock-compatible behaviour; among other
# things it lowers the default thread count to 1.
stackblock_compat = dic.get("hf_occ", None) == "integral"
restart_dir_per_sweep = dic.get("restart_dir_per_sweep", None)
n_threads = int(dic.get("num_thrds", 1 if stackblock_compat else 28))
mkl_threads = int(dic.get("mkl_thrds", 1))

# Unpack the four schedule components produced by the parser.
bond_dims, dav_thrds, noises, site_dependent_bdims = dic["schedule"]
# Site-dependent bond dimensions are optional: when every entry is empty,
# pass an empty list.  all() is True for an empty sequence as well, whereas
# max() over an empty list would raise ValueError.
if all(len(x) == 0 for x in site_dependent_bdims):
    site_dependent_bdims = []
site_dependent_bdims = VectorVectorUBond(
    [VectorUBond(x) for x in site_dependent_bdims])

store_wfn_spectra = "store_wfn_spectra" in dic
sweep_tol = float(dic.get("sweep_tol", 1e-6))
cached_contraction = int(dic.get("cached_contraction", 1)) == 1
singlet_embedding = "singlet_embedding" in dic
integral_rescale = dic.get("integral_rescale", "auto")
init_mps_center = int(dic.get("init_mps_center", 0))
# An empty value for "symmetrize_ints" means the default tolerance 1E-10.
siv = dic.get("symmetrize_ints", 1E-10)
symmetrize_ints_tol = 1E-10 if siv == "" else float(siv)

# Optional dynamic-correlation method: the first key of the list below found
# in dic wins, and its value is parsed as whitespace-separated integers.
# The bare names mrci / nevpt2 / mrrept2 are aliases of their "sd" variants.
dynamic_corr_method = None
sub_spaces = ['ijrs', 'ij', 'rs', 'ijr', 'rsi', 'ir', 'i', 'r']
_dyn_aliases = {'mrci': 'mrcisd', 'nevpt2': 'nevpt2sd', 'mrrept2': 'mrrept2sd'}
for dyn_key in ["dmrgfci", "mrci", "mrcis", "mrcisd", "mrcisdt",
                "casci", "nevpt2", "nevpt2s", "nevpt2sd",
                "mrrept2", "mrrept2s", "mrrept2sd",
                *["nevpt2-" + x for x in sub_spaces],
                *["mrrept2-" + x for x in sub_spaces]]:
    if dyn_key in dic:
        dynamic_corr_method = [_dyn_aliases.get(dyn_key, dyn_key),
                               [int(x) for x in dic[dyn_key].split()]]
        break
big_site_method = dic.get("big_site", None)
n_cas = 0

# QC MPO construction type: explicit "qc_mpo_type", otherwise Conventional for
# plain DMRG / dmrgfci and NC for the other dynamic-correlation methods.
qctstr = dic.get("qc_mpo_type", "auto")
if qctstr == "auto":
    if dynamic_corr_method is not None and dynamic_corr_method[0] != "dmrgfci":
        qc_type = QCTypes.NC
    else:
        qc_type = QCTypes.Conventional
elif qctstr == "conventional":
    qc_type = QCTypes.Conventional
elif qctstr == "nc":
    qc_type = QCTypes.NC
elif qctstr == "cn":
    qc_type = QCTypes.CN
else:
    raise RuntimeError("invalid qc_mpo_type: %s" % qctstr)

_print('qc mpo type = ', qc_type)

# Truncation type: "physical" (default), "keep <n>", or anything else -> Reduced;
# optionally combined with the RealDensityMatrix flag.
trunc_spec = dic.get("trunc_type", "physical")
if trunc_spec == "physical":
    trunc_type = TruncationTypes.Physical
elif trunc_spec.startswith("keep "):
    keep_count = int(trunc_spec[len("keep "):].strip())
    trunc_type = TruncationTypes.KeepOne * keep_count
else:
    trunc_type = TruncationTypes.Reduced
if "real_density_matrix" in dic:
    trunc_type = trunc_type | TruncationTypes.RealDensityMatrix

# Any value other than the listed default selects the alternative.
decomp_type = (DecompositionTypes.SVD
               if dic.get("decomp_type", "density_matrix") != "density_matrix"
               else DecompositionTypes.DensityMatrix)
te_type = TETypes.TangentSpace if dic.get("te_type", "rk4") != "rk4" else TETypes.RK4

expt_spec = dic.get("expt_algo_type", "auto")
if expt_spec == "auto":
    algo_type = ExpectationAlgorithmTypes.Automatic
elif expt_spec == "fast":
    algo_type = ExpectationAlgorithmTypes.Fast
else:
    algo_type = ExpectationAlgorithmTypes.Normal

# Which optional properties / features the input requests.
has_tran = any(opt in dic for opt in (
    "restart_tran_onepdm", "tran_onepdm",
    "restart_tran_twopdm", "tran_twopdm",
    "restart_tran_oh", "tran_oh", "compression"))
has_2pdm = any(opt in dic for opt in (
    "restart_tran_twopdm", "tran_twopdm", "restart_twopdm", "twopdm"))
has_1npc = any(opt in dic for opt in (
    "restart_correlation", "correlation", "restart_diag_twopdm", "diag_twopdm"))
anti_herm = "orbital_rotation" in dic
one_body_only = anti_herm or "one_body_parallel_rule" in dic
complex_mps = "complex_mps" in dic
full_integral = "full_integral" in dic
XExpect = Expect if not complex_mps else ComplexExpect

# MPO simplification rule: the QC rule, wrapped by NoTransposeRule when any
# transition quantity is requested and by AntiHermitianRuleQC for orbital rotation.
simpl_rule = RuleQC()
if has_tran:
    simpl_rule = NoTransposeRule(simpl_rule)
if anti_herm:
    simpl_rule = AntiHermitianRuleQC(simpl_rule)

# Hybrid complex mode: the main classes bound earlier stay the real
# double-precision ones (enforced by the asserts below).  This block
# additionally binds complex-arithmetic rule / Hamiltonian / MPO classes from
# block2.cpx.* under *CPX names, plus MovingEnvironmentX.
if "use_hybrid_complex" in dic:
    assert "use_complex" not in dic
    assert "single_prec" not in dic # todo
    _print("USE HYBRID COMPLEX MPO")
    # Cached contraction is forced off in this mode (presumably unsupported
    # with mixed real/complex operators — confirm).
    cached_contraction = False
    from block2.cpx import FCIDUMP as FCIDUMPCPX
    if "nonspinadapted" in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2.cpx.sz import RuleQC as RuleQCCPX, NoTransposeRule as NoTransposeRuleCPX
        from block2.cpx.sz import AntiHermitianRuleQC as AntiHermitianRuleQCCPX, MovingEnvironmentX
        from block2.cpx.sz import HamiltonianQC as HamiltonianQCCPX, ParallelMPO as ParallelMPOCPX
        from block2.cpx.sz import MPOQC as MPOQCCPX, MPO as MPOCPX, SimplifiedMPO as SimplifiedMPOCPX
        from block2.cpx.sz import ParallelRuleQC as ParallelRuleQCCPX, ParallelRuleOneBodyQC as ParallelRuleOneBodyQCCPX
    elif "nonspinadapted" not in dic and "k_symmetry" not in dic and "use_general_spin" not in dic:
        from block2.cpx.su2 import RuleQC as RuleQCCPX, NoTransposeRule as NoTransposeRuleCPX
        from block2.cpx.su2 import AntiHermitianRuleQC as AntiHermitianRuleQCCPX, MovingEnvironmentX
        from block2.cpx.su2 import HamiltonianQC as HamiltonianQCCPX, ParallelMPO as ParallelMPOCPX
        from block2.cpx.su2 import MPOQC as MPOQCCPX, MPO as MPOCPX, SimplifiedMPO as SimplifiedMPOCPX
        from block2.cpx.su2 import ParallelRuleQC as ParallelRuleQCCPX, ParallelRuleOneBodyQC as ParallelRuleOneBodyQCCPX
    elif "nonspinadapted" in dic and "k_symmetry" in dic:
        from block2.cpx.szk import RuleQC as RuleQCCPX, NoTransposeRule as NoTransposeRuleCPX
        from block2.cpx.szk import AntiHermitianRuleQC as AntiHermitianRuleQCCPX, MovingEnvironmentX
        from block2.cpx.szk import HamiltonianQC as HamiltonianQCCPX, ParallelMPO as ParallelMPOCPX
        from block2.cpx.szk import MPOQC as MPOQCCPX, MPO as MPOCPX, SimplifiedMPO as SimplifiedMPOCPX
        from block2.cpx.szk import ParallelRuleQC as ParallelRuleQCCPX, ParallelRuleOneBodyQC as ParallelRuleOneBodyQCCPX
    elif "nonspinadapted" not in dic and "k_symmetry" in dic:
        from block2.cpx.su2k import RuleQC as RuleQCCPX, NoTransposeRule as NoTransposeRuleCPX
        from block2.cpx.su2k import AntiHermitianRuleQC as AntiHermitianRuleQCCPX, MovingEnvironmentX
        from block2.cpx.su2k import HamiltonianQC as HamiltonianQCCPX, ParallelMPO as ParallelMPOCPX
        from block2.cpx.su2k import MPOQC as MPOQCCPX, MPO as MPOCPX, SimplifiedMPO as SimplifiedMPOCPX
        from block2.cpx.su2k import ParallelRuleQC as ParallelRuleQCCPX, ParallelRuleOneBodyQC as ParallelRuleOneBodyQCCPX
    elif "use_general_spin" in dic:
        from block2.cpx.sgf import RuleQC as RuleQCCPX, NoTransposeRule as NoTransposeRuleCPX
        from block2.cpx.sgf import AntiHermitianRuleQC as AntiHermitianRuleQCCPX, MovingEnvironmentX
        from block2.cpx.sgf import HamiltonianQC as HamiltonianQCCPX, ParallelMPO as ParallelMPOCPX
        from block2.cpx.sgf import MPOQC as MPOQCCPX, MPO as MPOCPX, SimplifiedMPO as SimplifiedMPOCPX
        from block2.cpx.sgf import ParallelRuleQC as ParallelRuleQCCPX, ParallelRuleOneBodyQC as ParallelRuleOneBodyQCCPX
    # Same wrapping as the real simplification rule (NoTranspose when any
    # transition quantity is requested, AntiHermitian for orbital rotation),
    # built from the complex rule classes.  The name suggests it is used for
    # the one-electron part — TODO confirm at the use site.
    simpl_rule_cpx_h1e = RuleQCCPX()
    if has_tran:
        simpl_rule_cpx_h1e = NoTransposeRuleCPX(simpl_rule_cpx_h1e)
    if anti_herm:
        simpl_rule_cpx_h1e = AntiHermitianRuleQCCPX(simpl_rule_cpx_h1e)

# Scratch-directory setup. Only the root rank (or the single process when
# MPI is absent) creates directories; with StackBlock compatibility the
# working directory is "<scratch>/node0". Every rank exports the final
# scratch path as TMPDIR.
if MPI is None or MPI.rank == 0:
    if not os.path.isdir(scratch):
        os.mkdir(scratch)
if stackblock_compat:
    scratch = scratch + "/node0"
if MPI is None or MPI.rank == 0:
    if not os.path.isdir(scratch):
        os.mkdir(scratch)
    if restart_dir is not None and not os.path.isdir(restart_dir):
        os.mkdir(restart_dir)
os.environ['TMPDIR'] = scratch
# wait until the root rank has created the directories
if MPI is not None:
    MPI.barrier()

# global settings


def _parse_mem_bytes(text):
    """Convert a memory keyword value given in GB into a byte count.

    Only the first whitespace-separated token of ``text`` is used. Any
    ``g``/``G`` unit letter is removed, and fractional sizes such as
    ``"2.5g"`` are accepted. Integer inputs give exactly the same result as
    ``int(int(token) * 1e9)``.
    """
    value = text.split()[0].lower().replace('g', '')
    return int(float(value) * 1e9)


memory = _parse_mem_bytes(dic.get("mem", "2"))
mem_ratio = float(dic.get("mem_ratio", 0.4))
min_mpo_mem = dic.get("min_mpo_mem", "auto")
fp_cps_cutoff = float(dic.get("fp_cps_cutoff", 1E-16))
if "intmem" in dic:
    # "intmem" uses the same format as "mem" (a "g" suffix is now accepted)
    intmemory = _parse_mem_bytes(dic["intmem"])
    init_memory(isize=int(intmemory), dsize=int(memory),
                save_dir=scratch, dmain_ratio=mem_ratio)
else:
    # no explicit integer memory: split the total 10% / 90% (int / double)
    init_memory(isize=int(memory * 0.1),
                dsize=int(memory * 0.9), save_dir=scratch,
                dmain_ratio=mem_ratio)
# ZHC NOTE nglobal_threads, nop_threads, MKL_NUM_THREADS
Global.threading = Threading(
    ThreadingTypes.OperatorBatchedGEMM | ThreadingTypes.Global,
    n_threads * mkl_threads, n_threads, mkl_threads)
Global.threading.seq_type = SeqTypes.Tasked if big_site_method is None else SeqTypes.Nothing
# single-precision runs use the float data frame
gframe = Global.frame if "single_prec" not in dic else Global.frame_float
gframe.fp_codec = FPCodec(fp_cps_cutoff, 1024)
gframe.load_buffering = False
gframe.save_buffering = False
gframe.use_main_stack = False
gframe.minimal_disk_usage = True
if restart_dir is not None:
    gframe.restart_dir = restart_dir
if restart_dir_per_sweep is not None:
    gframe.restart_dir_per_sweep = restart_dir_per_sweep
_print(gframe)
_print(Global.threading)

# MPI parallel rules; they are only created when running under MPI.
if MPI is not None:
    prule = ParallelRuleQC(MPI)
    prule_one_body = ParallelRuleOneBodyQC(MPI)
    prule_pdm1 = ParallelRulePDM1QC(MPI)
    prule_pdm2 = ParallelRulePDM2QC(MPI)
    prule_ident = ParallelRuleIdentity(MPI)
    # extra rules for the complex one-body part in hybrid real/complex mode
    if "use_hybrid_complex" in dic:
        prule_cpx_h1e = ParallelRuleQCCPX(MPI)
        prule_one_body_cpx_h1e = ParallelRuleOneBodyQCCPX(MPI)

# prepare hamiltonian
# In a normal run (or the "pre" step) the integrals are read, the orbitals
# are reordered, the target quantum numbers are built and the Hamiltonian is
# constructed. In the "run" step only the saved orbital reordering is
# reloaded (see the else-branch at the end).
if pre_run or not no_pre_run:
    nelec = [int(x) for x in dic["nelec"].split()]
    spin = [int(x) for x in dic.get("spin", "0").split()]
    isym = [int(x) for x in dic.get("irrep", "1").split()]
    iksym = [int(x) for x in dic.get("k_irrep", "0").split()]
    if "orbital_rotation" in dic:
        # orbital-rotation mode: the one-body "integral" is the kappa matrix
        # stored, together with the orbital symmetries, in the scratch dir
        orb_sym = np.load(scratch + "/nat_orb_sym.npy")
        if "k_symmetry" in dic:
            orb_sym = VectorUInt32(orb_sym)
        else:
            orb_sym = VectorUInt8(orb_sym)
        kappa = np.load(scratch + "/nat_kappa.npy")
        kappa = kappa.flatten()
        n_sites = len(orb_sym)
        fcidump = FCIDUMP()
        fcidump.initialize_h1e(n_sites, nelec[0], spin[0], isym[0], 0.0, kappa)
        # orbital reordering must be disabled in this mode
        assert "nofiedler" in dic or "noreorder" in dic
        if "target_t" not in dic:
            dic["target_t"] = "1"
    elif "model" in dic:  # model hamiltonians
        # syntax: model <name> <n_sites> <t> <u> [per-site]
        fmods = dic["model"].split()
        if fmods[0] in ["hubbard", "hubbard_periodic", "hubbard_kspace", "hubbard_rspace"]:
            assert len(fmods) in [4, 5]
            n_sites, const_t, const_u = int(
                fmods[1]), float(fmods[2]), float(fmods[3])
            if len(fmods) == 5 and fmods[4] == "per-site":
                const_t /= n_sites
                const_u /= n_sites
            _print("1D %s model : L = %d T = %.5f U = %.5f" %
                   (fmods[0], n_sites, const_t, const_u))
            if fmods[0] == "hubbard_kspace":
                fcidump = HubbardKSpaceFCIDUMP(n_sites, const_t, const_u)
            else:
                fcidump = HubbardFCIDUMP(n_sites, const_t, const_u,
                                         fmods[0] in ["hubbard_periodic", "hubbard_rspace"])
            orb_sym = None
        else:
            # the model name is a string: "%d" here used to raise TypeError
            # instead of reporting the unsupported model
            raise RuntimeError("Model %s not supported!" % fmods[0])
    else:
        orb_sym = None
        fints = dic["orbitals"]
        # HDF5 files start with b'\x89HDF'; anything else is read as FCIDUMP.
        # Read the signature inside "with" so the handle is closed promptly.
        with open(fints, 'rb') as _fints_handle:
            _fints_magic = _fints_handle.read(4)
        if _fints_magic != b'\x89HDF':
            # separate fcidump into real and complex parts
            if "use_hybrid_complex" in dic:
                fd_cpx = FCIDUMPCPX()
                fd_cpx.read(fints)
                fh1e = np.array(fd_cpx.h1e_matrix())
                rg2e = np.array(fd_cpx.g2e_1fold())
                # two-electron integrals and the constant must be real
                assert np.abs(np.linalg.norm(np.imag(rg2e))) < 1E-20
                assert np.abs(np.imag(fd_cpx.const_e)) < 1E-20
                rg2e = np.real(rg2e).copy() # make contig
                rh1e = np.real(fh1e).copy()
                ch1e = fh1e.copy()
                # real part keeps h1e elements with (numerically) zero
                # imaginary part; complex part keeps all the other elements
                rh1e[np.abs(np.imag(fh1e)) >= 1E-20] = 0.0
                ch1e[np.abs(np.imag(fh1e)) < 1E-20] = 0.0
                assert not fd_cpx.uhf
                fcidump = FCIDUMP()
                if fd_cpx.uhf:
                    fcidump.initialize_sz(
                        fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, np.real(fd_cpx.const_e), rh1e, rg2e)
                else:
                    fcidump.initialize_su2(
                        fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, np.real(fd_cpx.const_e), rh1e, rg2e)
                fcidump.orb_sym = fd_cpx.orb_sym
                fd_cpx_h1e = FCIDUMPCPX()
                fd_cpx_h1e.initialize_h1e(fd_cpx.n_sites, fd_cpx.n_elec, fd_cpx.twos, fd_cpx.isym, np.imag(fd_cpx.const_e), ch1e)
                fd_cpx_h1e.orb_sym = fd_cpx.orb_sym
            else:
                fcidump = FCIDUMP()
                fcidump.read(fints)
            integral_tol = float(dic.get("integral_tol", 0.0))
            if integral_tol != 0.0:
                int_tc_error = fcidump.truncate_small(integral_tol)
                _print("integral truncation error = ", int_tc_error)
            # take nelec / ms2 / isym from the input rather than the file
            fcidump.params["nelec"] = str(nelec[0])
            fcidump.params["ms2"] = str(spin[0])
            fcidump.params["isym"] = str(isym[0])
        else:
            integral_tol = float(dic.get("integral_tol", 1E-12))
            fcidump = read_integral(fints, nelec[0], spin[0], isym=isym[0],
                                    tol=integral_tol)
        if integral_rescale == "auto" and "single_prec" in dic:
            _print("original integral const = %20.10f" % fcidump.e())
            fcidump.rescale(0)
            _print("rescaled integral const = %20.10f" % fcidump.e())
        elif integral_rescale != "none" and integral_rescale != "auto":
            _print("original integral const = %20.10f" % fcidump.e())
            fcidump.rescale(float(integral_rescale))
            _print("rescaled integral const = %20.10f" % fcidump.e())
    n_orbs = fcidump.n_sites
    if "trans_integral_to_spin_orbital" in dic:
        n_orbs = n_orbs * 2
    # dynamic correlation: split orbitals into inactive / cas / external,
    # given either as (n_cas, n_elec_cas) or as explicit counts
    if dynamic_corr_method is not None:
        if len(dynamic_corr_method[1]) == 2:
            assert len(nelec) == 1
            n_cas, n_elec_cas = dynamic_corr_method[1]
            assert (nelec[0] - n_elec_cas) % 2 == 0
            n_inactive = (nelec[0] - n_elec_cas) // 2
            n_external = n_orbs - n_inactive - n_cas
        else:
            n_inactive, n_cas, n_external = dynamic_corr_method[1]
            assert n_orbs == n_inactive + n_cas + n_external
        _print("dynamic correlation space : inactive = %d, cas = %d, external = %d"
               % (n_inactive, n_cas, n_external))
    # orbital reordering: reuse a saved ordering on full restart, identity
    # when disabled, otherwise one of the orbital_reorder methods
    if "fullrestart" in dic and os.path.isfile(scratch + '/orbital_reorder.npy'):
        orb_idx = np.load(scratch + '/orbital_reorder.npy')
        _print("loading reorder for restarting = ", orb_idx)
        fcidump.reorder(VectorUInt16(orb_idx))
    elif "nofiedler" in dic or "noreorder" in dic:
        orb_idx = None
        np.save(scratch + '/orbital_reorder.npy',
                np.arange(0, fcidump.n_sites, dtype=int))
    else:
        if "gaopt" in dic:
            orb_idx = orbital_reorder(fcidump, method='gaopt ' + dic["gaopt"])
            _print("using gaopt reorder = ", orb_idx)
        elif "reorder" in dic:
            orb_idx = orbital_reorder(
                fcidump, method='manual ' + dic["reorder"])
            _print("using manual reorder = ", orb_idx)
        elif "irrep_reorder" in dic:
            orb_idx = orbital_reorder(
                fcidump, method='irrep ' + dic.get("sym", "d2h"))
            _print("using irrep reorder = ", orb_idx)
            _print("reordered irrep = ", fcidump.orb_sym)
        else:
            orb_idx = orbital_reorder(fcidump, method='fiedler')
            _print("using fiedler reorder = ", orb_idx)
        if dynamic_corr_method is not None:
            # keep inactive, cas and external orbitals in their own
            # contiguous ranges (reordering only within each space)
            orb_idx = np.concatenate((orb_idx[orb_idx < n_inactive],
                orb_idx[(orb_idx >= n_inactive) & (orb_idx < n_cas + n_inactive)],
                orb_idx[orb_idx >= n_cas + n_inactive]), axis=0)
            _print("reorder indices adjusted for dynamic correlation = ", orb_idx)
        fcidump.reorder(VectorUInt16(orb_idx))
        np.save(scratch + '/orbital_reorder.npy', orb_idx)
    if "use_hybrid_complex" in dic and orb_idx is not None:
        fd_cpx_h1e.reorder(VectorUInt16(orb_idx))
    if "full_integral" not in dic and dynamic_corr_method is not None and \
        dynamic_corr_method[0] in ["nevpt2s", "mrcis", "mrrept2s",
            "nevpt2-i", "nevpt2-r", "mrrept2-i", "mrrept2-r"]:
        _print("use mrcis ficdump")
        fcidump = MRCISFCIDUMP(fcidump, n_inactive, n_external)
    if "trans_integral_to_spin_orbital" in dic:
        fcidump = SpinOrbitalFCIDUMP(fcidump)
    if "heisenberg" in dic:
        fcidump = HeisenbergFCIDUMP(fcidump)

    swap_pg = getattr(PointGroup, "swap_" + dic.get("sym", "d2h"))

    _print("read integral finished", time.perf_counter() - tx)

    # target quantum numbers
    vacuum = SX(0)
    if "k_symmetry" in dic:
        if "k_mod" in dic:
            fcidump.k_mod = int(dic["k_mod"])
            if fcidump.k_mod != 0:
                fcidump.k_sym = VectorInt(
                    [x % fcidump.k_mod for x in fcidump.k_sym])
            fcidump.k_isym = fcidump.k_isym % fcidump.k_mod
            iksym = [x % fcidump.k_mod for x in iksym]
        target = SX(fcidump.n_elec, fcidump.twos,
                    SX.pg_combine(swap_pg(fcidump.isym), fcidump.k_isym, fcidump.k_mod))
    else:
        target = SX(fcidump.n_elec, fcidump.twos, swap_pg(fcidump.isym))
    # one target per combination of nelec x spin x irrep (x k_irrep)
    targets = []
    for inelec in nelec:
        for ispin in spin:
            for iisym in isym:
                if "k_symmetry" in dic:
                    for iiksym in iksym:
                        targets.append(SX(inelec, ispin,
                                          SX.pg_combine(swap_pg(iisym), iiksym, fcidump.k_mod)))
                else:
                    targets.append(SX(inelec, ispin, swap_pg(iisym)))
    targets = VectorSL(targets)
    if len(targets) == 0:
        targets = VectorSL([target])
    if singlet_embedding and SX == SU2:
        # singlet embedding: replace each open-shell target by a singlet
        # with n + twos particles
        for it, targ in enumerate(targets):
            if targ.twos != 0:
                targets[it] = SX(targ.n + targ.twos, 0, targ.pg)
        assert len(spin) == 1
        singlet_embedding_spin = spin[0]
    if len(targets) == 1:
        target = targets[0]
    n_sites = n_orbs
    if orb_sym is None:
        orb_sym = VectorUInt8(map(swap_pg, fcidump.orb_sym))
    sym_error = fcidump.symmetrize(orb_sym)
    _print("integral sym error = %12.4g" % sym_error)
    if "k_symmetry" in dic:
        pure_pg_sym = orb_sym
        k_sym = fcidump.k_sym
        k_mod = fcidump.k_mod
        k_sym_error = fcidump.symmetrize(k_sym, k_mod)
        _print("integral k sym error = %12.4g" % k_sym_error)
        sym_error += k_sym_error
        orb_sym = HamiltonianQC.combine_orb_sym(orb_sym, k_sym, k_mod)
    if sym_error > symmetrize_ints_tol:
        raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                            + "please check point group symmetry and FCIDUMP or set"
                            + " a higher tolerance for the keyword '%s'") % (
            symmetrize_ints_tol, "symmetrize_ints"))
    hamil_np = None
    hamil_cpx_h1e = None
    if big_site_method is None:
        hamil = HamiltonianQC(vacuum, n_sites, orb_sym, fcidump)
        if "use_hybrid_complex" in dic:
            hamil_cpx_h1e = HamiltonianQCCPX(vacuum, n_sites, orb_sym, fd_cpx_h1e)
    elif big_site_method == "folding":
        # fold the inactive / external sites of a normal MPO into two big
        # boundary sites, restricted by the FCI dims of the chosen space
        hamil = HamiltonianQC(vacuum, n_sites, orb_sym, fcidump)
        mpo_fold = MPOQC(hamil, qc_type)
        assert dynamic_corr_method is not None
        if dynamic_corr_method[0] in ["casci"]:
            mps_info_fold = CASCIMPSInfo(
                n_orbs, vacuum, target, hamil.basis, n_inactive, n_cas, n_external)
        elif dynamic_corr_method[0] in ["mrcis", "mrcisd", "mrcisdt"]:
            ci_order = len(dynamic_corr_method[0]) - 4
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["nevpt2sd", "nevpt2s"]:
            ci_order = len(dynamic_corr_method[0]) - 6
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["mrrept2sd", "mrrept2s"]:
            ci_order = len(dynamic_corr_method[0]) - 7
            mps_info_fold = MRCIMPSInfo(
                n_orbs, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi",
                                        "nevpt2-ir", "nevpt2-i", "nevpt2-r"]:
            sub_space = dynamic_corr_method[0][7:]
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            mps_info_fold = NEVPTMPSInfo(
                n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum, target, hamil.basis)
        elif dynamic_corr_method[0] in ["mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi",
                                        "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
            sub_space = dynamic_corr_method[0][8:]
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            mps_info_fold = NEVPTMPSInfo(
                n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum, target, hamil.basis)
        for i in range(n_external - 1):
            _print("fold right %d / %d" % (i, n_external))
            mpo_fold = FusedMPO(mpo_fold, hamil.basis, mpo_fold.n_sites - 2,
                                mpo_fold.n_sites - 1, mps_info_fold.right_dims_fci[mpo_fold.n_sites - 2])
            hamil.basis = mpo_fold.basis
            hamil.n_sites = mpo_fold.n_sites
        for i in range(n_inactive - 1):
            _print("fold left %d / %d" % (i, n_inactive))
            mpo_fold = FusedMPO(mpo_fold, hamil.basis, 0, 1,
                                mps_info_fold.left_dims_fci[i + 2])
            hamil.basis = mpo_fold.basis
            hamil.n_sites = mpo_fold.n_sites
        # store the two boundary tensors as CSR matrices; operators with
        # sparsity > 0.75 are converted, others are wrapped as dense
        for k, op in mpo_fold.tensors[0].ops.items():
            smat = CSRSparseMatrix()
            if op.sparsity() > 0.75:
                smat.from_dense(op)
                op.deallocate()
            else:
                smat.wrap_dense(op)
            mpo_fold.tensors[0].ops[k] = smat
        mpo_fold.sparse_form = 'S' + mpo_fold.sparse_form[1:]
        mpo_fold.tf = TensorFunctions(CSROperatorFunctions(hamil.opf.cg))
        for k, op in mpo_fold.tensors[-1].ops.items():
            smat = CSRSparseMatrix()
            if op.sparsity() > 0.75:
                smat.from_dense(op)
                op.deallocate()
            else:
                smat.wrap_dense(op)
            mpo_fold.tensors[-1].ops[k] = smat
        mpo_fold.sparse_form = mpo_fold.sparse_form[:-1] + 'S'
        mpo_fold.tf = TensorFunctions(CSROperatorFunctions(hamil.opf.cg))
    else:
        # explicit big-site methods: xl / xr bound the particle-number
        # change (alpha, beta, total) allowed in the inactive (left) and
        # external (right) big sites
        assert dynamic_corr_method is not None
        if dynamic_corr_method[0] in ['mrcisdt']:
            xl = -3, -3, -3
            xr = 3, 3, 3
        elif dynamic_corr_method[0] in ['mrcisd', 'nevpt2sd', 'mrrept2sd']:
            xl = -2, -2, -2
            xr = 2, 2, 2
        elif dynamic_corr_method[0] in ['mrcis', 'nevpt2s', 'mrrept2s']:
            xl = -1, -1, -1
            xr = 1, 1, 1
        elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr",
                                        "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                                        "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                                        "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
            # this is not correct yet
            if dynamic_corr_method[0].startswith('nevpt2-'):
                sub_space = dynamic_corr_method[0][7:]
            else:
                sub_space = dynamic_corr_method[0][8:]
            n_ex_inactive = sub_space.count('i') + sub_space.count('j')
            n_ex_external = sub_space.count('r') + sub_space.count('s')
            xl = -n_ex_inactive, -n_ex_inactive, -n_ex_inactive
            xr = n_ex_external, n_ex_external, n_ex_external
        elif dynamic_corr_method[0] in ['casci']:
            xl = 0, 0, 0
            xr = 0, 0, 0
        else:
            assert len(nelec) == 1 and len(spin) == 1
            assert (nelec[0] + spin[0]) % 2 == 0
            n_alpha = (nelec[0] + spin[0]) // 2
            n_beta = (nelec[0] - spin[0]) // 2
            xl = -min(n_inactive, n_alpha), -min(n_inactive, n_beta), -min(n_inactive, nelec[0])
            xr = min(n_external, n_alpha), min(n_external, n_beta), min(n_external, nelec[0])
        if big_site_method == "fock":
            assert "nonspinadapted" in dic
            # special treatment for all-alpha / all-beta spatial orbitals
            if len(nelec) == 1 and len(spin) == 1 and nelec[0] == abs(spin[0]):
                if nelec[0] == spin[0]:
                    ref = VectorInt([i + i for i in range(n_inactive)])
                    xl = abs(xl[0]), 0, abs(xl[2])
                else:
                    ref = VectorInt([i + i + 1 for i in range(n_inactive)])
                    xl = 0, abs(xl[1]), abs(xl[2])
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *xl, ref)
            else:
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *[abs(x) for x in xl], VectorInt([]))
            poccr = SCIFockBigSite.ras_space(True, n_external, *xr, VectorInt([]))
            # need to include casci ref state for the first step even for nevpt2-r
            big_left_orig = SCIFockBigSite(n_orbs, n_inactive, False, fcidump, orb_sym, poccl, True)
            big_right_orig = SCIFockBigSite(n_orbs, n_external, True, fcidump, orb_sym, poccr, True)
        elif big_site_method == "csf":
            assert "nonspinadapted" not in dic
            big_left_orig = CSFBigSite(n_inactive, abs(xl[-1]), False, fcidump, orb_sym[:n_inactive])
            # slice from an explicit start index: orb_sym[-0:] would be the
            # whole vector (not an empty one) when n_external == 0
            big_right_orig = CSFBigSite(n_external, abs(xr[-1]), True, fcidump,
                                        orb_sym[len(orb_sym) - n_external:])
        else:
            raise NotImplementedError
        big_left = SimplifiedBigSite(big_left_orig, simpl_rule)
        big_right = SimplifiedBigSite(big_right_orig, simpl_rule)
        if MPI is not None:
            # keep non-parallel copies for hamil_np
            big_left_np = big_left
            big_right_np = big_right
            if one_body_only:
                big_left = ParallelBigSite(big_left, prule_one_body)
                big_right = ParallelBigSite(big_right, prule_one_body)
            else:
                big_left = ParallelBigSite(big_left, prule)
                big_right = ParallelBigSite(big_right, prule)
            hamil_np = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                    None if n_inactive == 0 else big_left_np, None if n_external == 0 else big_right_np)
        hamil = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                    None if n_inactive == 0 else big_left, None if n_external == 0 else big_right)
    n_sites = hamil.n_sites

else:
    # "run" step: integrals are not re-read; only the saved orbital order
    orb_idx = np.load(scratch + '/orbital_reorder.npy')
    n_orbs = len(orb_idx)
    if "nofiedler" in dic or "noreorder" in dic:
        orb_idx = None
    orb_sym = None
    fcidump = None
    spin = [int(x) for x in dic.get("spin", "0").split()]
    if singlet_embedding and SX == SU2:
        assert len(spin) == 1
        singlet_embedding_spin = spin[0]

# Minimal-memory MPO mode: "auto" enables it for 120 or more orbitals;
# otherwise a value whose first character is '1' or 't' (e.g. "true")
# enables it.
gframe.minimal_memory_usage = (n_orbs >= 120 if min_mpo_mem == "auto"
                               else min_mpo_mem.lower()[0] in ('1', 't'))

_print('MinMPOMemUsage = ', gframe.minimal_memory_usage)

# "run" step of a two-step calculation: the number of sites (and, with k
# symmetry, the k modulus) is recovered from the identity MPO saved in the
# scratch directory, and the target quantum numbers are rebuilt from input.
if no_pre_run:
    impo = MPO(0)
    impo.load_data(scratch + '/mpo-ident.bin', minimal=True)
    n_sites = impo.n_sites
    if "k_symmetry" in dic:
        # combine pg_k_mod of every site quantum with bitwise OR
        # NOTE(review): presumably all sites share one k_mod (or 0) so the
        # OR recovers it -- confirm
        k_mod = 0
        for b in impo.basis:
            for bb in b.quanta:
                k_mod = k_mod | bb.pg_k_mod
    nelec = [int(x) for x in dic["nelec"].split()]
    spin = [int(x) for x in dic.get("spin", "0").split()]
    isym = [int(x) for x in dic.get("irrep", "1").split()]
    iksym = [int(x) for x in dic.get("k_irrep", "0").split()]
    # one target per combination of nelec x spin x irrep (x k_irrep)
    targets = []
    swap_pg = getattr(PointGroup, "swap_" + dic.get("sym", "d2h"))
    for inelec in nelec:
        for ispin in spin:
            for iisym in isym:
                if "k_symmetry" in dic:
                    for iiksym in iksym:
                        targets.append(SX(inelec, ispin,
                                          SX.pg_combine(swap_pg(iisym), iiksym, k_mod)))
                else:
                    targets.append(SX(inelec, ispin, swap_pg(iisym)))
    assert len(targets) != 0
    if singlet_embedding and SX == SU2:
        # singlet embedding: open-shell targets become singlets with
        # n + twos particles
        for it, targ in enumerate(targets):
            if targ.twos != 0:
                targets[it] = SX(targ.n + targ.twos, 0, targ.pg)
    if len(targets) == 1:
        target = targets[0]

# Parallelization over sites (requires MPI). Keyword forms:
#   conn_centers auto 5        -> 5 equal segments (5 = number of segments)
#   conn_centers 10 20 30 40   -> explicit connection site indices
# The MPI ranks are split evenly among the len(conn_centers) + 1 segments.
if "conn_centers" not in dic:
    conn_centers = None
else:
    assert MPI is not None
    cc = dic["conn_centers"].split()
    if cc[0] != "auto":
        conn_centers = [int(xcc) for xcc in cc]
    else:
        ncc = int(cc[1])
        conn_centers = list(
            np.arange(0, n_sites * ncc, n_sites, dtype=int) // ncc)[1:]
        assert len(conn_centers) == ncc - 1
    _print("using connection sites: ", conn_centers)
    assert MPI.size % (len(conn_centers) + 1) == 0
    mps_prule = prule
    prule = prule.split(MPI.size // (len(conn_centers) + 1))

# Occupation-number warm-up ("warmup occ"): the initial MPS bond dimensions
# are set from per-orbital occupations, optionally pulled towards the middle
# of their range by "cbias".
if dic.get("warmup", None) != "occ":
    occs = None
else:
    _print("using occ init")
    assert "occ" in dic
    # a single token is treated as a file name; its first line holds the
    # occupation numbers
    if len(dic["occ"].split()) == 1:
        with open(dic["occ"], 'r') as ofin:
            dic["occ"] = ofin.readlines()[0]
    occs = VectorDouble([float(occ)
                         for occ in dic["occ"].split() if len(occ) != 0])
    if orb_idx is not None:
        # follow the orbital reordering
        occs = FCIDUMP.array_reorder(occs, VectorUInt16(orb_idx))
        _print("using reordered occ init")
    # one, two or four numbers per site
    assert len(occs) in (n_sites, n_sites * 2, n_sites * 4)
    cbias = float(dic.get("cbias", 0.0))
    if cbias != 0.0:
        if len(occs) == n_sites and "use_general_spin" not in dic:
            # one value per spatial orbital: split at 1
            occs = VectorDouble(
                [c - cbias if c >= 1 else c + cbias for c in occs])
        elif len(occs) in (n_sites, 2 * n_sites):
            # general-spin or per-spin-orbital values: split at 0.5
            occs = VectorDouble(
                [c - cbias if c >= 0.5 else c + cbias for c in occs])
        elif len(occs) == 4 * n_sites:
            # four values per site: rescale each row to sum to 1 - cbias,
            # then add cbias / 4 to every entry
            moccs = np.array(occs).reshape((n_sites, 4))
            f = (1 - cbias) / moccs.sum(axis=1)[:, None]
            moccs = moccs * f + cbias / 4
            occs = VectorDouble(moccs.flatten())
        else:
            assert False
    bias = float(dic.get("bias", 1.0))

# one-site sweeps for "onedot"/"zerodot", two-site sweeps otherwise
dot = 2
if "onedot" in dic or "zerodot" in dic:
    dot = 1
nroots = int(dic.get("nroots", 1))
# MPS tags: written ("mps_tags"), read ("read_mps_tags"), projected out
mps_tags = dic.get("mps_tags", "KET").split()
read_tags = dic.get("read_mps_tags", "KET").split()
proj_tags = dic.get("proj_mps_tags", "").split()
soc = "soc" in dic
overlap = "overlap" in dic


def fmt_size(i, suffix='B'):
    """Return a short human-readable representation of the size ``i``.

    Values below 1000 are printed as integers (``"999 B"``). Larger values
    are divided by the largest fitting power of 1024 and printed with 2, 1 or
    0 decimals for mantissas below 10, 100 or 1000 (``"2.00 KB"``,
    ``"20.0 KB"``, ``"200 KB"``). Values beyond the ``Y`` prefix give
    ``"??? " + suffix``.
    """
    if i < 1000:
        return "%d %s" % (i, suffix)
    base = 1024
    for prefix in "KMGTPEZY":
        for decimals, limit in zip((2, 1, 0), (10, 100, 1000)):
            if i < limit * base:
                # "%.*f" takes the precision from the argument tuple
                return "%.*f %s%s" % (decimals, i / base, prefix, suffix)
        base *= 1024
    return "??? " + suffix


# Compression and time evolution write a new MPS, so the output tags must
# differ from the tags of the MPS being read.
if mps_tags == read_tags and any(
        key in dic for key in ("compression", "stopt_compression", "delta_t")):
    raise RuntimeError("""For compression and time evolution, the input MPS
            tags "read_mps_tags" and the output MPS tags "mps_tags" cannot
            be the same!""")

# prepare mps
if len(mps_tags) > 1 or ("compression" in dic and "random_mps_init" not in dic) \
   or "stopt_sampling" in dic:
    nroots = len(mps_tags)
    mps = None
    mps_info = None
    forward = False
elif "fullrestart" in dic:
    _print("full restart")
    mps_info = MPSInfo(0) if nroots == 1 and len(
        targets) == 1 and not complex_mps and "use_hybrid_complex" not in dic else MultiMPSInfo(0)
    if len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0]):
        mps_info.load_data(scratch + "/%s-mps_info.bin" % mps_tags[0])
    else:
        mps_info.load_data(scratch + "/mps_info.bin")
    mps_info.tag = mps_tags[0]
    mps_info.load_mutable()
    max_bdim = max([x.n_states_total for x in mps_info.left_dims])
    if mps_info.bond_dim < max_bdim:
        mps_info.bond_dim = max_bdim
    max_bdim = max([x.n_states_total for x in mps_info.right_dims])
    if mps_info.bond_dim < max_bdim:
        mps_info.bond_dim = max_bdim
    mps = MPS(mps_info) if nroots == 1 and len(
        targets) == 1 and not complex_mps and "use_hybrid_complex" not in dic else MultiMPS(mps_info)
    mps.load_data()
    if mps.dot != dot:
        if MPI is not None:
            MPI.barrier()
        mps.dot = dot
        mps.save_data()
        if MPI is not None:
            MPI.barrier()
    if "use_hybrid_complex" in dic:
        mps.nroots = nroots * 2
        mps.wfns = mps.wfns[:nroots * 2]
        mps.weights = mps.weights[:nroots * 2]
    elif nroots != 1 and not complex_mps:
        mps.nroots = nroots
        mps.wfns = mps.wfns[:nroots]
        mps.weights = mps.weights[:nroots]
    weights = dic.get("weights", None)
    if weights is not None:
        mps.weights = VectorFP([float(x) for x in weights.split()])
    mps.load_mutable()
    forward = mps.center == 0
    if mps.canonical_form[mps.center] == 'L' and mps.center != mps.n_sites - mps.dot:
        mps.center += 1
        forward = True
        if mps.canonical_form[mps.center] in "ST" and mps.dot == 2:
            if MPI is not None:
                MPI.barrier()
            cg = CG(200)
            cg.initialize()
            mps.flip_fused_form(
                mps.center, cg, prule if MPI is not None else None)
            mps.save_data()
            if MPI is not None:
                MPI.barrier()
            mps.load_mutable()
            mps.info.load_mutable()
            if MPI is not None:
                MPI.barrier()
    elif mps.canonical_form[mps.center] in "CMKJST" and mps.center != 0:
        if mps.canonical_form[mps.center] in "KJ" and mps.dot == 2:
            if MPI is not None:
                MPI.barrier()
            cg = CG(200)
            cg.initialize()
            mps.flip_fused_form(
                mps.center, cg, prule if MPI is not None else None)
            mps.save_data()
            if MPI is not None:
                MPI.barrier()
            mps.load_mutable()
            mps.info.load_mutable()
            if MPI is not None:
                MPI.barrier()
        if not mps.canonical_form[mps.center:mps.center + 2] == "CC" and mps.dot == 2:
            mps.center -= 1
        forward = False
    elif mps.center == mps.n_sites - 1 and mps.dot == 2:
        if MPI is not None:
            MPI.barrier()
        if mps.canonical_form[mps.center] in "KJ":
            cg = CG(200)
            cg.initialize()
            mps.flip_fused_form(
                mps.center, cg, prule if MPI is not None else None)
        mps.center = mps.n_sites - 2
        mps.save_data()
        forward = False
        if MPI is not None:
            MPI.barrier()
        mps.load_mutable()
        mps.info.load_mutable()
        if MPI is not None:
            MPI.barrier()
    elif mps.center == 0 and mps.dot == 2:
        if MPI is not None:
            MPI.barrier()
        if mps.canonical_form[mps.center] in "ST":
            cg = CG(200)
            cg.initialize()
            mps.flip_fused_form(
                mps.center, cg, prule if MPI is not None else None)
        mps.save_data()
        forward = True
        if MPI is not None:
            MPI.barrier()
        mps.load_mutable()
        mps.info.load_mutable()
        if MPI is not None:
            MPI.barrier()
elif pre_run or not no_pre_run:
    if "trans_mps_info" in dic:
        assert nroots == 1 and len(targets) == 1
        tr_vacuum = TrSX(vacuum.n, abs(vacuum.twos), vacuum.pg)
        tr_target = TrSX(target.n, abs(target.twos), target.pg)
        tr_basis = TrVectorStateInfo([trans_si(b) for b in hamil.basis])
        tr_mps_info = TrMPSInfo(n_sites, tr_vacuum, tr_target, tr_basis)
        assert "full_fci_space" not in dic
        tr_mps_info.tag = mps_tags[0]
        if occs is None:
            tr_mps_info.set_bond_dimension(bond_dims[0])
        else:
            tr_mps_info.set_bond_dimension_using_occ(
                bond_dims[0], occs, bias=bias)
        mps_info = trans_mi(tr_mps_info, target)
    else:
        if nroots == 1 and len(targets) == 1 and "use_hybrid_complex" not in dic:
            if dynamic_corr_method is not None:
                if dynamic_corr_method[0] in ["casci", "nevpt2s", "nevpt2sd", "nevpt2-ijrs", "nevpt2-ij",
                        "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                        "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                        "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                    if big_site_method is not None:
                        mps_info = CASCIMPSInfo(n_sites, vacuum, target, hamil.basis,
                                                1 if n_inactive != 0 else 0, n_cas, 1 if n_external != 0 else 0)
                    else:
                        mps_info = CASCIMPSInfo(
                            n_sites, vacuum, target, hamil.basis, n_inactive, n_cas, n_external)
                elif dynamic_corr_method[0] in ["mrcis", "mrcisd", "mrcisdt"]:
                    if big_site_method is not None:
                        mps_info = MPSInfo(
                            n_sites, vacuum, target, hamil.basis)
                    else:
                        ci_order = len(dynamic_corr_method[0]) - 4
                        mps_info = MRCIMPSInfo(
                            n_sites, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
                else:
                    mps_info = MPSInfo(n_sites, vacuum, target, hamil.basis)
            else:
                mps_info = MPSInfo(n_sites, vacuum, target, hamil.basis)
        else:
            assert dynamic_corr_method is None
            _print('TARGETS = ', list(targets))
            mps_info = MultiMPSInfo(n_sites, vacuum, targets, hamil.basis)
        if singlet_embedding and SX == SU2:
            left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
            right_vacuum = vacuum
            if "full_fci_space" in dic:
                mps_info.set_bond_dimension_full_fci(left_vacuum, right_vacuum)
            else:
                mps_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
        else:
            if "full_fci_space" in dic:
                mps_info.set_bond_dimension_full_fci()
        mps_info.tag = mps_tags[0]
        if occs is None:
            mps_info.set_bond_dimension(bond_dims[0])
        else:
            mps_info.set_bond_dimension_using_occ(
                bond_dims[0], occs, bias=bias)
    if "skip_inact_ext_sites" in dic:
        assert dynamic_corr_method is not None
        mps_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)
    if MPI is None or MPI.rank == 0:
        mps_info.save_data(scratch + '/mps_info.bin')
        mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    if conn_centers is not None:
        assert nroots == 1
        mps = ParallelMPS(mps_info.n_sites, init_mps_center, dot, mps_prule)
    elif nroots != 1 or len(targets) != 1 or "use_hybrid_complex" in dic:
        if "use_hybrid_complex" in dic:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots * 2)
        else:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots)
        weights = dic.get("weights", None)
        if weights is not None:
            mps.weights = VectorFP([float(x) for x in weights.split()])
    else:
        mps = MPS(n_sites, init_mps_center, dot)
    mps.initialize(mps_info)
    mps.random_canonicalize()
    if nroots == 1 and "use_hybrid_complex" not in dic:
        mps.tensors[mps.center].normalize()
    else:
        for xwfn in mps.wfns:
            xwfn.normalize()
    if "skip_inact_ext_sites" in dic:
        mps.set_inact_ext_identity(n_inactive, n_external)
    forward = mps.center == 0
else:
    mps_info = MPSInfo(0) if nroots == 1 and len(
        targets) == 1 and "use_hybrid_complex" not in dic else MultiMPSInfo(0)
    if len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0]):
        mps_info.load_data(scratch + "/%s-mps_info.bin" % mps_tags[0])
    else:
        mps_info.load_data(scratch + "/mps_info.bin")
    mps_info.tag = mps_tags[0]
    if occs is None:
        mps_info.set_bond_dimension(bond_dims[0])
    else:
        mps_info.set_bond_dimension_using_occ(
            bond_dims[0], occs, bias=bias)
    if "skip_inact_ext_sites" in dic:
        assert dynamic_corr_method is not None
        mps_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)

    if conn_centers is not None:
        assert nroots == 1
        mps = ParallelMPS(mps_info.n_sites, init_mps_center, dot, mps_prule)
    elif nroots != 1 or len(targets) != 1 or "use_hybrid_complex" in dic:
        if "use_hybrid_complex" in dic:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots * 2)
        else:
            mps = MultiMPS(n_sites, init_mps_center, dot, nroots)
        weights = dic.get("weights", None)
        if weights is not None:
            mps.weights = VectorFP([float(x) for x in weights.split()])
    else:
        mps = MPS(mps_info.n_sites, init_mps_center, dot)
    mps.initialize(mps_info)
    mps.random_canonicalize()
    if nroots == 1 and "use_hybrid_complex" not in dic:
        mps.tensors[mps.center].normalize()
    else:
        for xwfn in mps.wfns:
            xwfn.normalize()
    if "skip_inact_ext_sites" in dic:
        mps.set_inact_ext_identity(n_inactive, n_external)
    forward = mps.center == 0

if mps is not None:
    # report the freshly built/loaded MPS and its left bond dimensions
    _print("MPS = ", mps.canonical_form, mps.center,
           mps.dot, mps.info.target)
    _print("GS INIT MPS BOND DIMS = ",
           "".join("%6d" % left_dim.n_states_total
                   for left_dim in mps_info.left_dims))

if conn_centers is not None and "fullrestart" in dic:
    # Full restart of a multi-center (conn_centers) run: wrap the loaded
    # two-dot MPS in a ParallelMPS, then relabel a boundary 'C' site:
    # leading "CR..." becomes "KR...", trailing "...LC" becomes "...LS"
    # (the latter also moves the center to the last site).
    assert mps.dot == 2
    mps = ParallelMPS(mps, mps_prule)
    if mps.canonical_form.startswith("CR"):
        mps.canonical_form = "K" + mps.canonical_form[1:]
    elif mps.canonical_form.endswith("LC"):
        mps.canonical_form = mps.canonical_form[:-1] + "S"
        mps.center = mps.n_sites - 1

try:
    import psutil
except ImportError:
    pass  # psutil is optional; without it the memory report is skipped
else:
    # resident set size of this process, before any MPO is built
    mem = psutil.Process(os.getpid()).memory_info().rss
    _print("pre-mpo memory usage = ", fmt_size(mem))

# prepare mpo
if pre_run or not no_pre_run:
    # Build path: taken by the default single-step run and by the "pre" /
    # "para-pre" steps. Every MPO is built and simplified here, and rank 0
    # writes the serial copies (mpo*.bin) that a later "run" step reloads.
    # "para-pre" additionally writes per-rank copies (mpo*.bin.<rank>) that
    # a later "para-run" step reloads.

    # mpo for dmrg
    _print("build mpo start ...")
    txx = time.perf_counter()
    if big_site_method == "folding":
        mpo = mpo_fold
    else:
        mpo = MPOQC(hamil, qc_type)
    _print("build mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f" % (mpo.tread, mpo.twrite, time.perf_counter() - txx))
    _print("simpl mpo start ...")
    txx = time.perf_counter()
    mpo = SimplifiedMPO(mpo, simpl_rule, True, True,
                        OpNamesSet((OpNames.R, OpNames.RD)))
    _print("simpl mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f" % (mpo.tread, mpo.twrite, time.perf_counter() - txx))

    # left operator names are loaded and unloaded one site at a time to read
    # the MPO bond dimension (presumably so they need not all be resident)
    mpo_bdims = [None] * len(mpo.left_operator_names)
    for ix in range(len(mpo.left_operator_names)):
        mpo.load_left_operators(ix)
        x = mpo.left_operator_names[ix]
        mpo_bdims[ix] = x.m * x.n
        mpo.unload_left_operators(ix)
    _print('GS MPO BOND DIMS = ', ''.join(["%6d" % x for x in mpo_bdims]))

    if MPI is None or MPI.rank == 0:
        mpo.save_data(scratch + '/mpo.bin')

    if "use_hybrid_complex" in dic:
        # separate complex-valued MPO built from hamil_cpx_h1e
        txx = time.perf_counter()
        mpo_cpx_h1e = MPOQCCPX(hamil_cpx_h1e, qc_type)
        mpo_cpx_h1e = SimplifiedMPOCPX(mpo_cpx_h1e, simpl_rule_cpx_h1e, True, True,
                        OpNamesSet((OpNames.R, OpNames.RD)))
        _print("cpx h1e mpo finished ... Tread = %.3f Twrite = %.3f T = %.3f"
            % (mpo_cpx_h1e.tread, mpo_cpx_h1e.twrite, time.perf_counter() - txx))

        if MPI is None or MPI.rank == 0:
            mpo_cpx_h1e.save_data(scratch + '/mpo_cpx_h1e.bin')

    # mpo for 1pdm
    _print("build 1pdm mpo", time.perf_counter() - tx)
    pmpo = PDM1MPOQC(hamil_np or hamil, 1 if soc else 0)
    pmpo = SimplifiedMPO(pmpo,
                         NoTransposeRule(RuleQC()) if has_tran else RuleQC(),
                         True, True, OpNamesSet((OpNames.R, OpNames.RD)))

    if MPI is None or MPI.rank == 0:
        pmpo.save_data(scratch + '/mpo-1pdm.bin')

    if has_2pdm:
        # mpo for 2pdm
        _print("build 2pdm mpo", time.perf_counter() - tx)
        p2mpo = PDM2MPOQC(hamil_np or hamil)
        p2mpo = SimplifiedMPO(p2mpo,
                              NoTransposeRule(
                                  RuleQC()) if has_tran else RuleQC(),
                              True, True, OpNamesSet((OpNames.R, OpNames.RD)))

        if MPI is None or MPI.rank == 0:
            p2mpo.save_data(scratch + '/mpo-2pdm.bin')

    if has_1npc:
        # mpo for particle number correlation
        _print("build 1npc mpo", time.perf_counter() - tx)
        nmpo = NPC1MPOQC(hamil_np or hamil)
        nmpo = SimplifiedMPO(nmpo, RuleQC(), True, True,
                             OpNamesSet((OpNames.R, OpNames.RD)))

        if MPI is None or MPI.rank == 0:
            nmpo.save_data(scratch + '/mpo-1npc.bin')

    # mpo for identity operator
    _print("build identity mpo", time.perf_counter() - tx)
    impo = IdentityMPO(hamil_np or hamil)
    impo = SimplifiedMPO(impo,
                         NoTransposeRule(RuleQC()) if has_tran else RuleQC(),
                         True, True, OpNamesSet((OpNames.R, OpNames.RD)))

    if MPI is None or MPI.rank == 0:
        impo.save_data(scratch + '/mpo-ident.bin')

    if para_pre_run:
        # "para-pre": distribute every MPO over ranks now and save one reduced
        # copy per rank; "para-run" reloads exactly these files.
        if MPI is not None:
            if one_body_only:
                mpo = ParallelMPO(mpo, prule_one_body)
            else:
                mpo = ParallelMPO(mpo, prule)
            pmpo = ParallelMPO(pmpo, prule_pdm1)
            if has_2pdm:
                p2mpo = ParallelMPO(p2mpo, prule_pdm2)
            if has_1npc:
                nmpo = ParallelMPO(nmpo, prule_pdm1)
            impo = ParallelMPO(impo, prule_ident)
            if "use_hybrid_complex" in dic:
                # same rule choice as used when parallelizing after a serial load
                if one_body_only:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_one_body_cpx_h1e)
                else:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_cpx_h1e)

        _print("para mpo finished", time.perf_counter() - tx)
        try:
            import psutil
            mem = psutil.Process(os.getpid()).memory_info().rss
            _print("memory usage = ", fmt_size(mem))
        except ImportError:
            pass

        mrank = MPI.rank if MPI is not None else 0
        mpo.reduce_data()
        mpo.save_data(scratch + '/mpo.bin.%d' % mrank)
        pmpo.reduce_data()
        pmpo.save_data(scratch + '/mpo-1pdm.bin.%d' % mrank)
        if has_2pdm:
            p2mpo.reduce_data()
            p2mpo.save_data(scratch + '/mpo-2pdm.bin.%d' % mrank)
        if has_1npc:
            nmpo.reduce_data()
            nmpo.save_data(scratch + '/mpo-1npc.bin.%d' % mrank)
        impo.reduce_data()
        impo.save_data(scratch + '/mpo-ident.bin.%d' % mrank)
        if "use_hybrid_complex" in dic:
            # "para-run" loads mpo_cpx_h1e.bin.<rank>; without this it would
            # try to read a file that was never written
            mpo_cpx_h1e.reduce_data()
            mpo_cpx_h1e.save_data(scratch + '/mpo_cpx_h1e.bin.%d' % mrank)

else:

    if not para_no_pre_run:
        # "run" step: reload the serial MPOs written by "pre"

        mpo = MPO(0)
        mpo.load_data(scratch + '/mpo.bin')

        _print('GS MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in mpo.left_operator_names]))

        if "use_hybrid_complex" in dic:
            mpo_cpx_h1e = MPOCPX(0)
            mpo_cpx_h1e.load_data(scratch + '/mpo_cpx_h1e.bin')

        pmpo = MPO(0)
        pmpo.load_data(scratch + '/mpo-1pdm.bin')

        _print('1PDM MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in pmpo.left_operator_names]))

        if has_2pdm:
            p2mpo = MPO(0)
            p2mpo.load_data(scratch + '/mpo-2pdm.bin')

            _print('2PDM MPO BOND DIMS = ', ''.join(
                ["%6d" % (x.m * x.n) for x in p2mpo.left_operator_names]))

        if has_1npc:
            nmpo = MPO(0)
            nmpo.load_data(scratch + '/mpo-1npc.bin')

            _print('1NPC MPO BOND DIMS = ', ''.join(
                ["%6d" % (x.m * x.n) for x in nmpo.left_operator_names]))

        impo = MPO(0)
        impo.load_data(scratch + '/mpo-ident.bin')

        _print('IDENT MPO BOND DIMS = ', ''.join(
            ["%6d" % (x.m * x.n) for x in impo.left_operator_names]))

    else:
        # "para-run" step: reload the per-rank MPOs written by "para-pre"

        if MPI is not None:
            if one_body_only:
                mpo = ParallelMPO(0, prule_one_body)
            else:
                mpo = ParallelMPO(0, prule)
            pmpo = ParallelMPO(0, prule_pdm1)
            if has_2pdm:
                p2mpo = ParallelMPO(0, prule_pdm2)
            if has_1npc:
                nmpo = ParallelMPO(0, prule_pdm1)
            impo = ParallelMPO(0, prule_ident)
        else:
            mpo = MPO(0)
            pmpo = MPO(0)
            if has_2pdm:
                p2mpo = MPO(0)
            if has_1npc:
                nmpo = MPO(0)
            impo = MPO(0)

        mrank = MPI.rank if MPI is not None else 0
        mpo.load_data(scratch + '/mpo.bin.%d' % mrank, minimal=True)
        pmpo.load_data(scratch + '/mpo-1pdm.bin.%d' % mrank, minimal=True)
        if has_2pdm:
            p2mpo.load_data(scratch + '/mpo-2pdm.bin.%d' % mrank, minimal=True)
        if has_1npc:
            nmpo.load_data(scratch + '/mpo-1npc.bin.%d' % mrank, minimal=True)
        impo.load_data(scratch + '/mpo-ident.bin.%d' % mrank, minimal=True)

        if "use_hybrid_complex" in dic:
            if MPI is not None:
                if one_body_only:
                    mpo_cpx_h1e = ParallelMPOCPX(0, prule_one_body_cpx_h1e)
                else:
                    mpo_cpx_h1e = ParallelMPOCPX(0, prule_cpx_h1e)
            else:
                # complex-valued container, matching the serial loader above
                mpo_cpx_h1e = MPOCPX(0)
            mpo_cpx_h1e.load_data(scratch + '/mpo_cpx_h1e.bin.%d' % mrank, minimal=True)

try:
    import psutil
except ImportError:
    pass  # psutil is optional; without it the memory report is skipped
else:
    # resident set size of this process once all MPOs are available
    mem = psutil.Process(os.getpid()).memory_info().rss
    _print("memory usage = ", fmt_size(mem))


def _init_split_mps_info(smps_info, mps_info, iroot):
    """Configure a freshly constructed per-root MPS info from ``mps_info``.

    Shared by both branches of ``split_mps``. It:
      * applies the singlet-embedding / ``full_fci_space`` FCI bond-dimension
        setup;
      * tags the info as ``<mps_info.tag>-<iroot>``;
      * copies ``bond_dim`` and every left/right StateInfo from ``mps_info``;
      * saves the mutable part.
    """
    if singlet_embedding and SX == SU2:
        left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
        right_vacuum = vacuum
        if "full_fci_space" in dic:
            smps_info.set_bond_dimension_full_fci(
                left_vacuum, right_vacuum)
        else:
            smps_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
    else:
        if "full_fci_space" in dic:
            smps_info.set_bond_dimension_full_fci()
    smps_info.tag = mps_info.tag + "-%d" % iroot
    smps_info.bond_dim = mps_info.bond_dim
    # take the original bond StateInfos site by site (n_sites + 1 bonds)
    for i in range(0, smps_info.n_sites + 1):
        smps_info.left_dims[i] = mps_info.left_dims[i]
        smps_info.right_dims[i] = mps_info.right_dims[i]
    smps_info.save_mutable()


def split_mps(iroot, mps, mps_info, mpi=MPI):
    """Extract root ``iroot`` of a multi-root MPS as a stand-alone MPS.

    * If ``mps_info`` has several targets, the result is a ``MultiMPS``. It
      holds only ``mps.wfns[iroot]``, with weight 1 and ``nroots = 1``.
    * Otherwise the result is a plain ``MPS``. The root's wavefunction
      ``mps.wfns[iroot][0]`` is placed at ``center`` (or ``center + 1`` if the
      center tensor is already set), and 'T'/'J' in the canonical form are
      replaced by 'S'/'K'.

    The new MPS is then adapted to the global ``dot`` type, and its center is
    shifted according to its canonical form. The source ``mps`` is
    deallocated and ``mps_info``'s mutable data is released.

    Args:
        iroot: index of the root to extract.
        mps: the multi-root MPS.
        mps_info: its info.
        mpi: communicator used for barriers; the default is bound to the
            global ``MPI`` at definition time.

    Returns:
        ``(smps, smps_info, forward)`` where ``forward`` is the sweep
        direction to use with ``smps``.
    """
    mps.load_data()  # this will avoid memory sharing
    mps_info.load_mutable()
    mps.load_mutable()

    # break up a MultiMPS to single MPSs
    if len(mps_info.targets) != 1:
        smps_info = MultiMPSInfo(mps_info.n_sites, mps_info.vacuum,
                                 mps_info.targets, mps_info.basis)
        _init_split_mps_info(smps_info, mps_info, iroot)
        smps = MultiMPS(smps_info)
        smps.n_sites = mps.n_sites
        smps.center = mps.center
        smps.dot = mps.dot
        smps.canonical_form = '' + mps.canonical_form
        smps.tensors = mps.tensors[:]
        smps.wfns = mps.wfns[iroot:iroot + 1]
        smps.weights = mps.weights[iroot:iroot + 1]
        smps.weights[0] = 1
        smps.nroots = 1
        smps.save_mutable()
    else:
        smps_info = MPSInfo(mps_info.n_sites, mps_info.vacuum,
                            mps_info.targets[0], mps_info.basis)
        _init_split_mps_info(smps_info, mps_info, iroot)
        smps = MPS(smps_info)
        smps.n_sites = mps.n_sites
        smps.center = mps.center
        smps.dot = mps.dot
        smps.canonical_form = '' + mps.canonical_form
        smps.canonical_form = smps.canonical_form.replace(
            'T', 'S').replace('J', 'K')
        smps.tensors = mps.tensors[:]
        # drop this root's wavefunction into the first empty center slot
        if smps.tensors[smps.center] is None:
            smps.tensors[smps.center] = mps.wfns[iroot][0]
        else:
            assert smps.center + 1 < smps.n_sites
            assert smps.tensors[smps.center + 1] is None
            smps.tensors[smps.center + 1] = mps.wfns[iroot][0]
        smps.save_mutable()

    if smps.center == 0 and dot == 2:
        if mpi is not None:
            mpi.barrier()
        if smps.canonical_form[smps.center] in "ST":
            cg = CG(200)
            cg.initialize()
            smps.flip_fused_form(
                smps.center, cg, prule if mpi is not None else None)
        smps.save_data()
        if mpi is not None:
            mpi.barrier()
        smps.load_mutable()
        smps.info.load_mutable()
        if mpi is not None:
            mpi.barrier()

    smps.dot = dot
    forward = smps.center == 0
    if smps.canonical_form[smps.center] == 'L' and smps.center != smps.n_sites - smps.dot:
        smps.center += 1
        forward = True
    elif (smps.canonical_form[smps.center] == 'C' or smps.canonical_form[smps.center] == 'M') and smps.center != 0:
        smps.center -= 1
        forward = False
    if smps.canonical_form[smps.center] == 'M' and not isinstance(smps, MultiMPS):
        smps.canonical_form = smps.canonical_form[:smps.center] + \
            'C' + smps.canonical_form[smps.center + 1:]
    if smps.canonical_form[-1] == 'M' and not isinstance(smps, MultiMPS):
        smps.canonical_form = smps.canonical_form[:-1] + 'C'
    if dot == 1:
        if smps.canonical_form[0] == 'C' and smps.canonical_form[1] == 'R':
            smps.canonical_form = 'K' + smps.canonical_form[1:]
        elif smps.canonical_form[-1] == 'C' and smps.canonical_form[-2] == 'L':
            smps.canonical_form = smps.canonical_form[:-1] + 'S'
            smps.center = smps.n_sites - 1
        if smps.canonical_form[0] == 'M' and smps.canonical_form[1] == 'R':
            smps.canonical_form = 'J' + smps.canonical_form[1:]
        elif smps.canonical_form[-1] == 'M' and smps.canonical_form[-2] == 'L':
            smps.canonical_form = smps.canonical_form[:-1] + 'T'
            smps.center = smps.n_sites - 1

    mps.deallocate()
    mps_info.deallocate_mutable()
    smps.save_data()
    return smps, smps_info, forward


def get_mps_from_tags(iroot, proj_mps=False, ref_center=0):
    """Load a saved MPS by tag and prepare a working copy of it.

    Tag selection: ``proj_tags[iroot]`` when ``proj_mps`` is True, otherwise
    ``mps_tags[iroot]`` for ``iroot >= 0``, otherwise ``read_tags[0]`` (the
    cps/te initial state). The MPS (a ``MultiMPS`` when ``complex_mps``) is
    deep-copied under ``<stored tag>-<iroot>``. Presumably this leaves the
    files under the original tag untouched.

    The copy is then adjusted as follows:
      * ``info.bond_dim`` is raised to at least the largest stored left/right
        bond dimension.
      * Its dot type is made to match the global ``dot``:
        - a two-dot MPS with the center at an end is relabeled one-dot when
          ``dot == 1``;
        - a one-dot MPS becomes two-dot when ``dot == 2``, using
          ``move_right``/``move_left`` when the end site is in 'S'/'K' form;
        - a complex two-dot MPS centered on the last site is first run through
          an identity-MPO sweep and a fused-form flip.
      * If ``(center == 0) != (ref_center == 0)``, an identity-MPO expectation
        sweep is run to change its canonical form / center.

    Args:
        iroot: index into ``proj_tags``/``mps_tags``; a negative value selects
            ``read_tags[0]``. Also used as the suffix of the copy's tag.
        proj_mps: take the tag from ``proj_tags`` (projection states).
        ref_center: center of a reference MPS whose end (site 0 or not) the
            returned MPS should match.

    Returns:
        ``(smps, smps.info, forward)`` with ``forward = smps.center == 0``.
    """
    if proj_mps:
        _print('----- proj = %3d tag = %s -----' % (iroot, proj_tags[iroot]))
        tag = proj_tags[iroot]
    elif iroot >= 0:
        _print('----- root = %3d tag = %s -----' % (iroot, mps_tags[iroot]))
        tag = mps_tags[iroot]
    else:
        _print('----- cps/te init tag = %s -----' % read_tags[0])
        tag = read_tags[0]
    # complex MPSs are stored as MultiMPS, so they need a MultiMPSInfo
    smps_info = MPSInfo(0) if not complex_mps else MultiMPSInfo(0)
    smps_info.load_data(scratch + "/%s-mps_info.bin" % tag)
    if MPI is not None:
        MPI.barrier()
    # work on a copy tagged "<stored tag>-<iroot>"
    if not complex_mps:
        smps = MPS(smps_info).deep_copy(smps_info.tag + "-%d" % iroot)
    else:
        smps = MultiMPS(smps_info).deep_copy(smps_info.tag + "-%d" % iroot)
    if MPI is not None:
        MPI.barrier()
    smps_info = smps.info
    smps_info.load_mutable()
    # bond_dim must cover every stored left and right bond dimension
    max_bdim = max([x.n_states_total for x in smps_info.left_dims])
    if smps_info.bond_dim < max_bdim:
        smps_info.bond_dim = max_bdim
    max_bdim = max([x.n_states_total for x in smps_info.right_dims])
    if smps_info.bond_dim < max_bdim:
        smps_info.bond_dim = max_bdim
    smps.load_data()
    # two-dot MPS with the center at either end: relabel as one-dot if needed
    if smps.dot == 2 and smps.center == smps.n_sites - 1 and dot == 1:
        smps.dot = 1
    elif smps.dot == 2 and smps.center == 0 and dot == 1:
        smps.dot = 1
    elif smps.dot == 2 and smps.center == smps.n_sites - 1:
        # complex two-dot MPS centered on the last site (and dot == 2): run a
        # one-dot identity sweep, then flip an 'S'/'T' fused form at the center
        if complex_mps:
            _print('change canonical form ...')
            cf = str(smps.canonical_form)
            smps.dot = 1
            ime = MovingEnvironment(impo, smps, smps, "IEX")
            ime.delayed_contraction = OpNamesSet.normal_ops()
            ime.cached_contraction = cached_contraction
            ime.init_environments(False)
            expect = ComplexExpect(ime, smps.info.bond_dim, smps.info.bond_dim)
            expect.iprint = max(min(outputlevel, 3), 0)
            expect.solve(True, smps.center == 0)
            if MPI is not None:
                MPI.barrier()
            smps.dot = 2
            smps.save_data()
            if MPI is not None:
                MPI.barrier()
            if smps.canonical_form[smps.center] in "ST":
                cg = CG(200)
                cg.initialize()
                smps.flip_fused_form(
                    smps.center, cg, prule if MPI is not None else None)
            smps.save_data()
            if MPI is not None:
                MPI.barrier()
            _print(cf + ' -> ' + smps.canonical_form)
    # one-dot MPS but a two-dot sweep is requested: convert the end site
    if smps.dot == 1 and dot == 2:
        if smps.center == 0 and smps.canonical_form[0] == 'S':
            cg = CG(200)
            cg.initialize()
            smps.move_right(cg, prule if MPI is not None else None)
            smps.center = 0
        elif smps.center == smps.n_sites - 1 and smps.canonical_form[smps.center] == 'K':
            cg = CG(200)
            cg.initialize()
            smps.move_left(cg, prule if MPI is not None else None)
            smps.center = smps.n_sites - 2
        smps.dot = dot
        if MPI is not None:
            MPI.barrier()
        smps.save_data()
        if MPI is not None:
            MPI.barrier()
    # center is at the opposite end from the reference: run an identity-MPO
    # expectation sweep to change the canonical form
    if (smps.center == 0) != (ref_center == 0):
        _print('change canonical form ...')
        cf = str(smps.canonical_form)
        ime = MovingEnvironment(impo, smps, smps, "IEX")
        ime.delayed_contraction = OpNamesSet.normal_ops()
        ime.cached_contraction = cached_contraction
        ime.init_environments(False)
        if not complex_mps:
            expect = Expect(ime, smps.info.bond_dim, smps.info.bond_dim)
        else:
            expect = ComplexExpect(ime, smps.info.bond_dim, smps.info.bond_dim)
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, smps.center == 0)
        if MPI is not None:
            MPI.barrier()
        smps.save_data()
        if MPI is not None:
            MPI.barrier()
        _print(cf + ' -> ' + smps.canonical_form)
    forward = smps.center == 0
    return smps, smps.info, forward


def get_state_specific_mps(iroot, mps_info):
    """Load the state-specific MPS saved for root ``iroot``.

    Reads ``<scratch>/mps_info-ss-<iroot>.bin`` and deep-copies the MPS under
    ``<mps_info.tag>-<iroot>``. It then raises the copy's ``bond_dim`` to at
    least the largest stored left/right bond dimension, and converts its dot
    type to match the global ``dot``.

    Args:
        iroot: root index of the state-specific MPS.
        mps_info: info of the parent MPS; only its ``tag`` is used.

    Returns:
        ``(smps, smps_info, forward)`` with ``forward = smps.center == 0``.
    """
    smps_info = MPSInfo(0)
    smps_info.load_data(scratch + "/mps_info-ss-%d.bin" % iroot)
    if MPI is not None:
        MPI.barrier()
    smps = MPS(smps_info).deep_copy(mps_info.tag + "-%d" % iroot)
    if MPI is not None:
        MPI.barrier()
    smps_info = smps.info
    smps_info.load_mutable()
    # bond_dim must cover every stored left and right bond dimension
    max_bdim = max(max(x.n_states_total for x in smps_info.left_dims),
                   max(x.n_states_total for x in smps_info.right_dims))
    if smps_info.bond_dim < max_bdim:
        smps_info.bond_dim = max_bdim
    smps.load_data()
    if smps.dot == 2 and smps.center == smps.n_sites - 1 and dot == 1:
        smps.dot = 1
    elif smps.dot == 2 and smps.center == 0 and dot == 1:
        smps.dot = 1
    elif (smps.dot == 1 or smps.center == smps.n_sites - 1) and dot == 2:
        # the parallel rule is only passed in MPI runs, as in the other
        # move/flip call sites of this script
        if smps.center == 0 and smps.canonical_form[0] == 'S':
            cg = CG(200)
            cg.initialize()
            smps.move_right(cg, prule if MPI is not None else None)
            smps.center = 0
        elif smps.center == smps.n_sites - 1 and smps.canonical_form[smps.center] == 'K':
            cg = CG(200)
            cg.initialize()
            smps.move_left(cg, prule if MPI is not None else None)
            smps.center = smps.n_sites - 2
        elif smps.center == smps.n_sites - 1 and smps.canonical_form[smps.center] == 'S':
            smps.center = smps.n_sites - 2
        smps.dot = dot
        if MPI is not None:
            MPI.barrier()
        smps.save_data()
        if MPI is not None:
            MPI.barrier()
    forward = smps.center == 0
    _print('ss-mps', smps.center, smps.dot, dot, smps.canonical_form)
    return smps, smps_info, forward


if not pre_run:

    if not para_no_pre_run:

        if MPI is not None:
            if one_body_only:
                mpo = ParallelMPO(mpo, prule_one_body)
            else:
                mpo = ParallelMPO(mpo, prule)
            pmpo = ParallelMPO(pmpo, prule_pdm1)
            if has_2pdm:
                p2mpo = ParallelMPO(p2mpo, prule_pdm2)
            if has_1npc:
                nmpo = ParallelMPO(nmpo, prule_pdm1)
            impo = ParallelMPO(impo, prule_ident)

            if "use_hybrid_complex" in dic:
                if one_body_only:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_one_body_cpx_h1e)
                else:
                    mpo_cpx_h1e = ParallelMPOCPX(mpo_cpx_h1e, prule_cpx_h1e)

        _print("para mpo finished", time.perf_counter() - tx)

    if mps is not None:
        mps.save_mutable()
        mps.deallocate()
        mps_info.save_mutable()
        mps_info.deallocate_mutable()

    if conn_centers is not None:
        mps.conn_centers = VectorInt(conn_centers)

    # state-specific DMRG
    if "statespecific" in dic and "restart_onepdm" not in dic \
            and "restart_correlation" not in dic and "restart_tran_twopdm" not in dic \
            and "restart_oh" not in dic and "restart_twopdm" not in dic \
            and "restart_tran_onepdm" not in dic and "restart_tran_oh" not in dic \
            and "restart_copy_mps" not in dic and "restart_sample" not in dic:
        assert isinstance(mps, MultiMPS)
        assert nroots != 1

        ext_mpss = []
        for iroot in range(nroots):
            tx = time.perf_counter()
            _print('----- root = %3d / %3d -----' % (iroot, nroots))
            ext_mpss.append(mps.extract(iroot, mps.info.tag + "-%d" % iroot)
                               .make_single(mps.info.tag + "-S%d" % iroot))
            for iex, ext_mps in enumerate(ext_mpss):
                _print(iex, ext_mpss[iex].canonical_form, ext_mpss[iex].center)
                if (ext_mps.dot == 1 or ext_mps.center == ext_mps.n_sites - 1) and dot == 2:
                    if ext_mps.center == 0 and ext_mps.canonical_form[0] == 'S':
                        cg = CG(200)
                        cg.initialize()
                        ext_mps.move_right(cg, prule)
                        ext_mps.center = 0
                    elif ext_mps.center == ext_mps.n_sites - 1 and ext_mps.canonical_form[ext_mps.center] == 'K':
                        cg = CG(200)
                        cg.initialize()
                        ext_mps.move_left(cg, prule)
                        ext_mps.center = ext_mps.n_sites - 2
                    elif ext_mps.center == ext_mps.n_sites - 1 and ext_mps.canonical_form[ext_mps.center] == 'S':
                        ext_mps.center = ext_mps.n_sites - 2
                    ext_mps.dot = dot
                    ext_mps.save_data()
                _print(iex, ext_mpss[iex].canonical_form, ext_mpss[iex].center)
            if ext_mpss[0].center != ext_mpss[iroot].center:
                _print('change canonical form ...')
                cf = str(ext_mpss[iroot].canonical_form)
                ime = MovingEnvironment(
                    impo, ext_mpss[iroot], ext_mpss[iroot], "IEX")
                ime.delayed_contraction = OpNamesSet.normal_ops()
                ime.cached_contraction = cached_contraction
                ime.init_environments(False)
                expect = Expect(
                    ime, ext_mpss[iroot].info.bond_dim, ext_mpss[iroot].info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                expect.solve(True, ext_mpss[iroot].center == 0)
                ext_mpss[iroot].save_data()
                _print(cf + ' -> ' + ext_mpss[iroot].canonical_form)

            me = MovingEnvironment(
                mpo, ext_mpss[iroot], ext_mpss[iroot], "DMRG")
            me.delayed_contraction = OpNamesSet.normal_ops()
            me.cached_contraction = cached_contraction
            me.save_partition_info = True
            me.init_environments(outputlevel >= 2)

            _print("env init finished", time.perf_counter() - tx)

            dmrg = DMRG(me, VectorUBond(bond_dims), VectorFP(noises))
            dmrg.ext_mpss = VectorMPS(ext_mpss[:iroot])
            dmrg.state_specific = True
            proj_weights = dic.get("proj_weights", None)
            if proj_weights is not None:
                dmrg.projection_weights = VectorFP([float(x) for x in proj_weights.split()][:iroot])
            dmrg.iprint = max(min(outputlevel, 3), 0)
            for ext_mps in dmrg.ext_mpss:
                ext_me = MovingEnvironment(
                    impo, ext_mpss[iroot], ext_mps, "EX" + ext_mps.info.tag)
                ext_me.delayed_contraction = OpNamesSet.normal_ops()
                ext_me.init_environments(outputlevel >= 2)
                dmrg.ext_mes.append(ext_me)
            if "lowmem_noise" in dic:
                dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
            if "dm_noise" in dic:
                dmrg.noise_type = NoiseTypes.DensityMatrix
            elif decomp_type != DecompositionTypes.SVD:
                dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollected
            else:
                dmrg.noise_type = NoiseTypes.ReducedPerturbative
            dmrg.cutoff = float(dic.get("cutoff", 1E-14))
            dmrg.decomp_type = decomp_type
            dmrg.trunc_type = trunc_type
            dmrg.davidson_conv_thrds = VectorFP(dav_thrds)
            dmrg.davidson_max_iter = int(dic.get("davidson_max_iter", 5000))
            dmrg.davidson_soft_max_iter = int(
                dic.get("davidson_soft_max_iter", -1))
            dmrg.store_wfn_spectra = store_wfn_spectra
            dmrg.site_dependent_bond_dims = site_dependent_bdims

            sweep_energies = []
            discarded_weights = []
            if "twodot_to_onedot" not in dic:
                E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
            else:
                tto = int(dic["twodot_to_onedot"])
                assert len(bond_dims) > tto
                dmrg.solve(tto, forward, 0)
                # save the twodot part energies and discarded weights
                sweep_energies.append(np.array(dmrg.energies))
                discarded_weights.append(np.array(dmrg.discarded_weights))
                dmrg.me.dot = 1
                for ext_me in dmrg.ext_mes:
                    ext_me.dot = 1
                dmrg.bond_dims = VectorUBond(bond_dims[tto:])
                dmrg.noises = VectorFP(noises[tto:])
                dmrg.davidson_conv_thrds = VectorFP(dav_thrds[tto:])
                E_dmrg = dmrg.solve(len(bond_dims) - tto,
                                    ext_mpss[iroot].center == 0, sweep_tol)
                ext_mpss[iroot].dot = 1

            if MPI is None or MPI.rank == 0:
                for ir in range(iroot + 1):
                    ext_mpss[ir].save_data()

            if conn_centers is not None:
                me.finalize_environments()

            sweep_energies.append(np.array(dmrg.energies))
            discarded_weights.append(np.array(dmrg.discarded_weights))
            sweep_energies = np.vstack(sweep_energies)
            discarded_weights = np.hstack(discarded_weights)

            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_dmrg-%d.npy" % iroot, E_dmrg)
                np.save(scratch + "/bond_dims-%d.npy" %
                        iroot, bond_dims[:len(discarded_weights)])
                np.save(scratch + "/sweep_energies-%d.npy" %
                        iroot, sweep_energies)
                np.save(scratch + "/discarded_weights-%d.npy" %
                        iroot, discarded_weights)
            _print("DMRG Energy for root %4d = %20.15f" % (iroot, E_dmrg))

            if MPI is None or MPI.rank == 0:
                ext_mpss[iroot].info.save_data(
                    scratch + '/mps_info-ss-%d.bin' % iroot)
                ext_mpss[iroot].info.save_data(
                    scratch + '/%s-mps_info-ss-%d.bin' % (mps_tags[0], iroot))
        
        if "twodot_to_onedot" in dic:
            dot = 1

    # GS DMRG
    # Main DMRG optimization of `mps` with `mpo`. This section is skipped when
    # any of the restart / state-specific / compression / time-evolution /
    # sampling keywords listed in the condition is present.
    if "restart_onepdm" not in dic and "restart_twopdm" not in dic \
            and "restart_correlation" not in dic and "restart_tran_twopdm" not in dic \
            and "restart_oh" not in dic and "statespecific" not in dic \
            and "restart_tran_onepdm" not in dic and "restart_tran_oh" not in dic \
            and "restart_copy_mps" not in dic and "restart_sample" not in dic \
            and "delta_t" not in dic and "compression" not in dic and "stopt_sampling" not in dic:

        # environment for <mps| mpo |mps>
        me = MovingEnvironment(mpo, mps, mps, "DMRG")
        # delayed/cached contraction is only configured outside hybrid-complex mode
        if "use_hybrid_complex" not in dic:
            me.delayed_contraction = OpNamesSet.normal_ops()
            me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        # hybrid-complex mode: an extra environment built from mpo_cpx_h1e,
        # attached to the solver below as dmrg.cpx_me
        if "use_hybrid_complex" in dic:
            cpx_me = MovingEnvironmentX(mpo_cpx_h1e, mps, mps, "DMRG-CPX")
            cpx_me.cached_contraction = False
            cpx_me.init_environments(outputlevel >= 2)

        # with connected centers, take the sweep direction from the current center
        if conn_centers is not None:
            forward = mps.center == 0

        _print("env init finished", time.perf_counter() - tx)

        # big-site solver when big_site_method is set and n_cas != 0;
        # plain DMRG otherwise (including big-site runs with n_cas == 0)
        if big_site_method is not None and n_cas != 0:
            dmrg = DMRGBigSite(me, VectorUBond(bond_dims), VectorFP(noises))
            dmrg.last_site_svd = True
            dmrg.last_site_1site = dot == 2
            dmrg.decomp_last_site = False
        else:
            dmrg = DMRG(me, VectorUBond(bond_dims), VectorFP(noises))
        # verbosity clamped to the range 0..3
        dmrg.iprint = max(min(outputlevel, 3), 0)

        # restrict the sweep range: start at site n_inactive and stop at
        # me.n_sites - n_external
        if "skip_inact_ext_sites" in dic:
            dmrg.sweep_start_site = n_inactive
            dmrg.sweep_end_site = me.n_sites - n_external

        if "use_hybrid_complex" in dic:
            dmrg.cpx_me = cpx_me

        # projection
        # one weight per tag in proj_tags (read from "proj_weights"); each
        # projected MPS gets its own environment <mps| impo |ext_mps>
        # (impo is presumably the identity MPO - confirm)
        if len(proj_tags) != 0:
            proj_weights = dic.get("proj_weights", None)
            assert proj_weights is not None
            proj_weights = VectorFP([float(x) for x in proj_weights.split()])
            assert len(proj_weights) == len(proj_tags)
            ext_mpss = []
            for ipj in range(len(proj_weights)):
                xmps, xmps_info, _ = get_mps_from_tags(ipj, True, mps.center)
                ext_mpss.append(xmps)
            dmrg.projection_weights = proj_weights
            dmrg.ext_mpss = VectorMPS(ext_mpss)
            for ext_mps in dmrg.ext_mpss:
                ext_me = MovingEnvironment(impo, mps, ext_mps, "PJ" + ext_mps.info.tag)
                ext_me.delayed_contraction = OpNamesSet.normal_ops()
                ext_me.init_environments(outputlevel >= 2)
                dmrg.ext_mes.append(ext_me)

        # noise flavour: low-memory collected variant on request, collected
        # variant for non-SVD decompositions, plain reduced-perturbative otherwise
        if "lowmem_noise" in dic:
            dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
        elif decomp_type != DecompositionTypes.SVD:
            dmrg.noise_type = NoiseTypes.ReducedPerturbativeCollected
        else:
            dmrg.noise_type = NoiseTypes.ReducedPerturbative
        dmrg.cutoff = float(dic.get("cutoff", 1E-14))
        dmrg.davidson_max_iter = int(dic.get("davidson_max_iter", 5000))
        dmrg.davidson_soft_max_iter = int(
            dic.get("davidson_soft_max_iter", -1))
        dmrg.decomp_type = decomp_type
        dmrg.trunc_type = trunc_type
        dmrg.davidson_conv_thrds = VectorFP(dav_thrds)
        dmrg.store_wfn_spectra = store_wfn_spectra
        dmrg.site_dependent_bond_dims = site_dependent_bdims
        # per-stage sweep energies / discarded weights, merged after the solve
        sweep_energies = []
        discarded_weights = []
        if big_site_method is not None and n_cas == 0:
            if "twodot_to_onedot" in dic:
                _print("WARNING: twodot_to_onedot is ignored for n_cas = 0")
            E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
        elif "twodot_to_onedot" not in dic:
            E_dmrg = dmrg.solve(len(bond_dims), forward, sweep_tol)
        else:
            # two-stage run: the first `tto` sweeps use the initial dot setting
            # (two-dot, per the keyword name) with tolerance 0; the remaining
            # part of the schedule then runs with one-dot
            tto = int(dic["twodot_to_onedot"])
            assert len(bond_dims) > tto
            dmrg.solve(tto, forward, 0)
            # save the twodot part energies and discarded weights
            # (each sweep's energy list is zero-padded to nroots so they stack)
            for x in dmrg.energies:
                while len(x) < nroots:
                    x.append(0.0)
            sweep_energies.append(np.array(dmrg.energies))
            discarded_weights.append(np.array(dmrg.discarded_weights))
            if big_site_method is not None and n_cas != 0:
                dmrg.last_site_1site = False
                dmrg.me.center = mps.center
            # remaining part of the schedule for the one-dot stage
            dmrg.bond_dims = VectorUBond(bond_dims[tto:])
            dmrg.noises = VectorFP(noises[tto:])
            dmrg.davidson_conv_thrds = VectorFP(dav_thrds[tto:])
            dmrg.site_dependent_bond_dims = site_dependent_bdims[tto:]
            dmrg.me.dot = 1
            for ext_me in dmrg.ext_mes:
                ext_me.dot = 1
            # sweep forward iff the center sits at the first swept site
            E_dmrg = dmrg.solve(len(bond_dims) - tto,
                                mps.center == dmrg.sweep_start_site, sweep_tol)
            mps.dot = 1
            dot = 1
            if MPI is None or MPI.rank == 0:
                mps.save_data()

        if conn_centers is not None:
            me.finalize_environments()

        _print("Final canonical form = ", mps.canonical_form, mps.center)
        # zero-pad to nroots and merge with the two-dot stage (if any):
        # rows of sweep_energies are sweeps, discarded_weights is one flat array
        for x in dmrg.energies:
            while len(x) < nroots:
                x.append(0.0)
        sweep_energies.append(np.array(dmrg.energies))
        discarded_weights.append(np.array(dmrg.discarded_weights))
        sweep_energies = np.vstack(sweep_energies)
        discarded_weights = np.hstack(discarded_weights)

        if MPI is None or MPI.rank == 0:
            # bond dimension used in each recorded sweep; when more sweeps were
            # run than the schedule lists, the last entry is repeated
            bdims = bond_dims[:len(discarded_weights)]
            if len(bdims) < len(discarded_weights):
                bdims = bdims + bdims[-1:] * \
                    (len(discarded_weights) - len(bdims))
            np.save(scratch + "/E_dmrg.npy", E_dmrg)
            if stackblock_compat:
                # StackBlock-compatible binary energy file: one double per root
                dmrg_energies = [E_dmrg] if nroots == 1 else list(sweep_energies[-1])
                with open(os.path.join(scratch, "dmrg.e"), "wb") as f:
                    import struct
                    f.write(struct.pack('d' * nroots, *dmrg_energies))
            np.save(scratch + "/bond_dims.npy", bdims)
            np.save(scratch + "/sweep_energies.npy", sweep_energies)
            np.save(scratch + "/discarded_weights.npy", discarded_weights)
            if store_wfn_spectra:
                np.save(scratch + "/sweep_wfn_spectra.npy",
                        np.array([np.array(x) for x in dmrg.sweep_wfn_spectra], dtype=object))
            if "extrapolation" in dic:
                # twodot_to_onedot is ignored (and `tto` never assigned) for
                # big-site runs with n_cas == 0; every sweep is then used
                tto_applied = "twodot_to_onedot" in dic and not (
                    big_site_method is not None and n_cas == 0)
                # after a two-dot -> one-dot switch only the first `tto`
                # (two-dot) sweeps enter the extrapolation
                llsw = tto if tto_applied else len(sweep_energies)
                # keep the last sweep for each distinct bond dimension
                ext_eners = []
                ext_dws = []
                ext_bdims = []
                for iext in range(llsw):
                    if bdims[iext] not in ext_bdims:
                        ext_bdims.append(bdims[iext])
                        ext_dws.append(discarded_weights[iext])
                        ext_eners.append(sweep_energies[iext, 0])
                    else:
                        ii = ext_bdims.index(bdims[iext])
                        ext_dws[ii] = discarded_weights[iext]
                        ext_eners[ii] = sweep_energies[iext, 0]
                ext_eners = np.array(ext_eners)
                ext_dws = np.array(ext_dws)
                ext_bdims = np.array(ext_bdims)
                _print('EXTRAP discarded weights = ', ext_dws)
                _print('EXTRAP oh energies (au) = ', ext_eners)
                _print('EXTRAP bond dimensions = ', ext_bdims)
                if len(np.unique(ext_dws)) < 2:
                    # linregress cannot fit a line through fewer than two
                    # distinct x values; do not abort the run here
                    _print('WARNING: extrapolation skipped '
                           '(needs at least two distinct discarded weights)')
                else:
                    import scipy.stats
                    reg = scipy.stats.linregress(ext_dws, ext_eners)
                    # crude error bar: 1/5 of the distance between the
                    # extrapolated energy and the closest sweep energy
                    ext_err = np.min(np.abs(reg.intercept - ext_eners)) / 5
                    _print('EXTRAP Energy = %20.15f (+/-) %20.15f' %
                           (reg.intercept, ext_err))
                    _print('EXTRAP R^2 = %20.15f' % (reg.rvalue ** 2))
                    emin, emax = min(ext_eners), max(ext_eners)
                    de = emax - emin
                    xmin, xmax = min(ext_dws), max(ext_dws)
                    ddw = xmax - xmin
                    import matplotlib.pyplot as plt
                    x_reg = np.array([0, xmax + ddw / 12])
                    plt.plot(x_reg, reg.intercept + reg.slope * x_reg,
                             '--', linewidth=1, color='#5FA8AB')
                    plt.plot(ext_dws, ext_eners, 'o', color='#38686A',
                             markerfacecolor='white', markersize=5)
                    # "\\pm" (not "\pm"): avoids the invalid-escape warning
                    plt.text(ddw / 12, emax, "$E(M=\\infty) = %.6f \\pm %.6f \\mathrm{\\ Hartree}$" %
                             (reg.intercept, ext_err), color='#38686A', fontsize=12)
                    plt.text(ddw / 12, emax - de / 12, "$R^2 = %.6f$" % (reg.rvalue ** 2),
                             color='#38686A', fontsize=12)
                    plt.xlim((0, xmax + ddw / 12))
                    plt.ylim((emin - de / 12, emax + de / 12))
                    plt.xlabel("Largest Discarded Weight")
                    plt.ylabel("Sweep Energy (Hartree)")
                    plt.subplots_adjust(left=0.16, bottom=0.1,
                                        right=0.95, top=0.95)
                    plt.savefig(scratch + "/extrapolation.png", dpi=600)
                    plt.close()

        # report the final energy, labelled by the dynamic-correlation method
        # (if any) that this DMRG run is the reference step of
        if dynamic_corr_method is not None:
            if dynamic_corr_method[0] in ['casci', 'nevpt2s', 'nevpt2sd', "nevpt2-ijrs", "nevpt2-ij",
                        "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                        "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                        "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                _print("DMRG-CASCI Energy = %20.15f" % E_dmrg)
            elif dynamic_corr_method[0] in ['dmrgfci']:
                _print("DMRG-FCI Energy = %20.15f" % E_dmrg)
            elif dynamic_corr_method[0] in ['mrcis', 'mrcisd', 'mrcisdt']:
                _print("DMRG-%s Energy = %20.15f" %
                       (dynamic_corr_method[0].upper(), E_dmrg))
            else:
                # unlisted method name: still report the energy rather than
                # finishing the sweep without printing any energy
                _print("DMRG Energy = %20.15f" % E_dmrg)
        else:
            _print("DMRG Energy = %20.15f" % E_dmrg)

        if MPI is None or MPI.rank == 0:
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    # Compression
    # Fit `mps` (bond dimensions from the schedule) to the MPS loaded from
    # tags (`lmps`) through the environment <mps| impo or mpo |lmps>, then
    # save the fitted MPS and the resulting overlap.
    if "compression" in dic:
        lmps, lmps_info, _ = get_mps_from_tags(-1)
        # start from a copy of the input state unless a random initial MPS is requested
        if "random_mps_init" not in dic:
            mps = lmps.deep_copy(mps_tags[0])
            mps_info = mps.info

        # stopt_compression: shift the MPO constant by the saved DMRG energy
        # for the fit; the shift is undone at the end of this section
        if "stopt_compression" in dic:
            E_dmrg = float(np.load(scratch + "/E_dmrg.npy"))
            mpo.const_e -= E_dmrg

        # `overlap` selects impo (presumably the identity MPO) instead of mpo
        me = MovingEnvironment(impo if overlap else mpo, mps, lmps, "CPS")
        me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        # bra (mps) follows the bond-dimension schedule; ket keeps lmps's bond dimension
        cps = Linear(me, VectorUBond(bond_dims),
                     VectorUBond([lmps.info.bond_dim]))
        cps.iprint = max(min(outputlevel, 3), 0)
        cps.cutoff = float(dic.get("cutoff", 1E-14))
        cps.decomp_type = decomp_type
        cps.trunc_type = trunc_type
        if "stopt_compression" in dic:
            cps.conv_type = ConvergenceTypes.LastMaximal
        if "twodot_to_onedot" not in dic:
            ovl = cps.solve(len(bond_dims), mps.center == 0, sweep_tol)
        else:
            # first `tto` sweeps with the initial dot setting (tolerance 0),
            # then one-dot sweeps with the remaining bra bond dimensions
            tto = int(dic["twodot_to_onedot"])
            assert len(bond_dims) > tto
            cps.solve(tto, mps.center == 0, 0)
            cps.bra_bond_dims = VectorUBond(bond_dims[tto:])
            cps.rme.dot = 1
            ovl = cps.solve(len(bond_dims) - tto, mps.center == 0, sweep_tol)
            mps.dot = 1
            lmps.dot = 1
            if MPI is None or MPI.rank == 0:
                mps.save_data()
                lmps.save_data()
        _print("Final canonical form = ", mps.canonical_form, mps.center)
        _print("Compression overlap = %20.15f" % ovl)

        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/cps_overlap.npy", ovl)
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

        # undo the stopt_compression shift applied above
        if "stopt_compression" in dic:
            mpo.const_e += E_dmrg

    # Time Evolution
    # Propagate a copy of the input MPS (stored under mps_tags[0]) from 0 to
    # "target_t" in steps of "delta_t"; a purely real step means imaginary-time
    # evolution. Per-step times, energies, norms and discarded weights are
    # saved to scratch on the root rank.
    if "delta_t" in dic:

        if len(read_tags) == 0:
            _print("Time Evolution START FROM RANDOM MPS !!!")
        else:
            mps, mps_info, _ = get_mps_from_tags(-1)

        # accept "a+bi" style complex input
        dt = complex(dic["delta_t"].replace(' ', '').replace('i', 'j'))
        tt = complex(dic["target_t"].replace(' ', '').replace('i', 'j'))
        # the target time must be an integer multiple of the step
        n_steps = int(abs(tt) / abs(dt) + 0.1)
        assert np.abs(abs(n_steps * dt) - abs(tt)) < 1E-10
        is_imag_te = abs(np.imag(dt)) < 1E-10
        if is_imag_te:
            dt = np.real(dt)
            tt = np.real(tt)
            _print("Time Evolution  DELTA T = %15.8f" % dt)
            _print("Time Evolution TARGET T = %15.8f" % tt)
        else:
            _print("Time Evolution  DELTA T = RE %15.8f + IM %15.8f" %
                   (np.real(dt), np.imag(dt)))
            _print("Time Evolution TARGET T = RE %15.8f + IM %15.8f" %
                   (np.real(tt), np.imag(tt)))

        # a MultiMPS input must carry exactly two wavefunctions and requires a
        # complex run (presumably real/imaginary parts); a plain MPS requires a
        # real run. Either way, evolve a copy so the input MPS is untouched.
        if isinstance(mps, MultiMPS):
            assert len(mps.wfns) == 2
            assert complex_mps
        else:
            assert not complex_mps
        assert mps.info.tag != mps_tags[0]
        mps = mps.deep_copy(mps_tags[0])
        mps_info = mps.info

        _print("Time Evolution   NSTEPS = %d" % n_steps)
        _print("    with %s wavefunction" %
               ("complex" if complex_mps else "real"))
        _print("    with %s step (%s TE)" % (
            "real" if is_imag_te else "complex", "imag" if is_imag_te else "real"))
        _print("    init canonical form = %s" % mps.canonical_form)

        me = MovingEnvironment(mpo, mps, mps, "DMRG")
        me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        te = TimeEvolution(me, VectorUBond(bond_dims), te_type)
        te.hermitian = not anti_herm
        te.iprint = max(min(outputlevel, 3), 0)
        te.n_sub_sweeps = 1
        if te.mode != TETypes.TangentSpace:
            te.n_sub_sweeps = int(dic.get("n_sub_sweeps", 2))
        te.normalize_mps = "normalize_mps" in dic
        te_times = []
        te_energies = []
        te_normsqs = []
        te_discarded_weights = []
        for i in range(n_steps):
            # tangent-space mode: two half steps per time step
            if te.mode == TETypes.TangentSpace:
                te.solve(2, dt / 2, mps.center == 0)
            else:
                te.solve(1, dt, mps.center == 0)
            if is_imag_te:
                _print("T = %10.5f <E> = %20.15f <Norm^2> = %20.15f" %
                       ((i + 1) * dt, te.energies[-1], te.normsqs[-1]))
            else:
                _print("T = RE %10.5f + IM %10.5f <E> = %20.15f <Norm^2> = %20.15f" %
                       ((i + 1) * np.real(dt), (i + 1) * np.imag(dt), te.energies[-1], te.normsqs[-1]))
            te_times.append((i + 1) * dt)
            te_energies.append(te.energies[-1])
            te_normsqs.append(te.normsqs[-1])
            te_discarded_weights.append(te.discarded_weights[-1])
        # default covers n_steps == 0 (target_t == 0), where no step was taken
        _print("Max Discarded Weight = %9.5g" %
               max(te_discarded_weights, default=0.0))

        _print("   mps final tag = %s" % mps_tags[0])
        _print("   mps final canonical form = %s" % mps.canonical_form)

        # only the root rank writes, as in every other section
        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/te_times.npy", np.array(te_times))
            np.save(scratch + "/te_energies.npy", np.array(te_energies))
            np.save(scratch + "/te_normsqs.npy", np.array(te_normsqs))
            np.save(scratch + "/te_discarded_weights.npy",
                    np.array(te_discarded_weights))
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    def do_onepdm(bmps, kmps):
        """Run a 1PDM expectation sweep for <bmps| ... |kmps>.

        Returns, on the root rank only, the density matrix with a leading
        component axis: one component when "use_general_spin" is set,
        otherwise the two diagonal spin blocks (spin index 0,0 and 1,1).
        Rows/columns are mapped back through the inverse of `orb_idx` when an
        orbital reordering is active. Non-root MPI ranks get None.
        """
        env = MovingEnvironment(pmpo, bmps, kmps, "1PDM")
        # delayed contraction does not work with ExpectationAlgorithmTypes.Fast yet
        if algo_type == ExpectationAlgorithmTypes.Normal:
            env.delayed_contraction = OpNamesSet.normal_ops()
        env.cached_contraction = cached_contraction
        env.save_partition_info = True
        env.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(env, bmps.info.bond_dim, kmps.info.bond_dim)
        expect.zero_dot_algo = "zerodot" in dic
        expect.algo_type = algo_type
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, kmps.center == 0)
        _print("Final canonical form = ", bmps.canonical_form, bmps.center)

        if MPI is not None and MPI.rank != 0:
            return None

        raw = expect.get_1pdm(n_orbs)
        pdm = np.array(raw).copy()
        raw.deallocate()
        if "use_general_spin" in dic:
            pdm = pdm[None]
        else:
            # split each spin-orbital index into (orbital, spin) and keep the
            # two spin-diagonal blocks
            n_row, n_col = pdm.shape[0] // 2, pdm.shape[1] // 2
            pdm = pdm.reshape((n_row, 2, n_col, 2)).transpose(0, 2, 1, 3)
            pdm = np.stack([pdm[:, :, 0, 0], pdm[:, :, 1, 1]])
        if orb_idx is None:
            return pdm
        inv_idx = np.argsort(orb_idx)
        if "trans_integral_to_spin_orbital" in dic:
            # expand orbital permutation to interleaved spin orbitals (2p, 2p+1)
            inv_idx = np.array(list(zip(inv_idx*2, inv_idx*2+1))).ravel()
        assert pdm.shape[-1] == len(inv_idx)
        pdm[:, :, :] = pdm[:, inv_idx, :][:, :, inv_idx]
        return pdm

    # ONEPDM
    # One-particle density matrix of the ground-state MPS (nroots == 1) or of
    # each root separately (the `else` branch further below).
    if "restart_onepdm" in dic or "onepdm" in dic:

        if nroots == 1:

            # with skip_inact_ext_sites, when the canonical center is at
            # neither end of the chain, run an expectation sweep with impo
            # ("IEX") whose only purpose here is to move the center (its
            # result is not used); presumably required before the 1PDM
            # sweep - confirm
            if "skip_inact_ext_sites" in dic and mps.center != 0 and mps.center != mps.n_sites - mps.dot:
                if MPI is not None:
                    MPI.barrier()
                _print('change canonical form ...')
                _print('original cf = ', mps.canonical_form)
                ime = MovingEnvironment(impo, mps, mps, "IEX")
                ime.delayed_contraction = OpNamesSet.normal_ops()
                ime.cached_contraction = cached_contraction
                ime.init_environments(False)
                expect = XExpect(ime, mps.info.bond_dim, mps.info.bond_dim)
                expect.iprint = max(min(outputlevel, 3), 0)
                expect.solve(True, mps.center != n_inactive)
                _print('final cf = ', mps.canonical_form)
                if MPI is not None:
                    MPI.barrier()

            # do_onepdm returns None on non-root MPI ranks
            dm = do_onepdm(mps, mps)
            if MPI is None or MPI.rank == 0:
                # occupations: diagonal of the component-summed 1PDM
                dmocc = np.diag(np.sum(dm, axis=0))
                if dmocc.dtype == np.complex128:
                    _print("DMRG OCC = ", "".join(["(%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                else:
                    _print("DMRG OCC = ", "".join(["%6.3f" % x for x in dmocc]))
                # 1pdm.npy is not written for big-site "folding" runs
                if big_site_method != "folding":
                    np.save(scratch + "/1pdm.npy", dm)
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

                # natural orbital generation
                # Diagonalize the component-summed 1PDM (in the reordered
                # orbital basis) within each point-group irrep: eigenvalues go
                # to `nat_occs`, eigenvectors form `rot` with (old, new) indices.
                if "nat_orbs" in dic:
                    spdm = np.sum(dm, axis=0)
                    # need pdm after orbital rotation
                    if orb_idx is not None:
                        spdm[:, :] = spdm[orb_idx, :][:, orb_idx]
                    # untouched copy, used to verify the decomposition below
                    xpdm = spdm.copy()
                    _print("REORDERED OCC = ", "".join(
                        ["%6.3f" % x for x in np.diag(spdm)]))
                    spdm = spdm.flatten()
                    nat_occs = np.zeros((n_sites, ))
                    if orb_sym is None:
                        raise ValueError(
                            "Need FCIDUMP construction (namely, not a pre run) for 'nat_orbs'!")
                    # in place: fills nat_occs and overwrites spdm, which is
                    # read back as the eigenvector matrix below
                    MatrixFunctions.block_eigs(spdm, nat_occs, orb_sym)
                    _print("NAT OCC = ", "".join(
                        ["%9.6f" % x for x in nat_occs]))
                    # (old, new)
                    rot = np.array(spdm.reshape(
                        (n_sites, n_sites)).T, copy=True)
                    np.save(scratch + "/nat_orb_sym.npy", np.array(orb_sym))
                    for isym in set(orb_sym):
                        mask = np.array(orb_sym) == isym
                        # optional: permute natural orbitals within the irrep to
                        # best match the original ones (Kuhn-Munkres on
                        # 1 - rot^2; "init"/"final" print the sum of squared
                        # diagonal overlaps before/after)
                        if "nat_km_reorder" in dic:
                            kmidx = np.argsort(KuhnMunkres(
                                1 - rot[mask, :][:, mask] ** 2).solve()[1])
                            _print("init = ", np.sum(
                                np.diag(rot[mask, :][:, mask]) ** 2))
                            rot[:, mask] = rot[:, mask][:, kmidx]
                            nat_occs[mask] = nat_occs[mask][kmidx]
                            _print("final = ", np.sum(
                                np.diag(rot[mask, :][:, mask]) ** 2))
                        # sign fixing: either make every leading principal
                        # minor of the irrep block positive (column by column),
                        # or only flip the first column if the block
                        # determinant is negative
                        if "nat_positive_def" in dic:
                            for j in range(len(nat_occs[mask])):
                                mrot = rot[mask, :][:j + 1,
                                                    :][:, mask][:, :j + 1]
                                mrot_det = np.linalg.det(mrot)
                                _print("ISYM = %d J = %d MDET = %15.10f" %
                                       (isym, j, mrot_det))
                                if mrot_det < 0:
                                    mask0 = np.arange(len(mask), dtype=int)[
                                        mask][j]
                                    rot[:, mask0] = -rot[:, mask0]
                        else:
                            mrot = rot[mask, :][:, mask]
                            mrot_det = np.linalg.det(mrot)
                            _print("ISYM = %d MDET = %15.10f" %
                                   (isym, mrot_det))
                            if mrot_det < 0:
                                mask0 = np.arange(len(mask), dtype=int)[
                                    mask][0]
                                rot[:, mask0] = -rot[:, mask0]
                    if "nat_km_reorder" in dic:
                        _print("REORDERED NAT OCC = ", "".join(
                            ["%9.6f" % x for x in nat_occs]))
                    # consistency: rot diag(nat_occs) rot^T must reproduce the 1PDM
                    assert np.linalg.norm(rot @ np.diag(
                        nat_occs) @ rot.T - xpdm) < 1E-10
                    np.save(scratch + "/nat_occs.npy", nat_occs)
                    rot_det = np.linalg.det(rot)
                    _print("DET = %15.10f" % rot_det)
                    assert rot_det > 0
                    np.save(scratch + "/nat_rotation.npy", rot)

                    def my_logm(mrot):
                        """Return a real antisymmetric `kappa` with expm(kappa) == mrot.

                        Expects `mrot` to be a proper orthogonal matrix
                        (asserted: det > 0). The symmetric part
                        mrot + mrot.T is diagonalized; in that eigenbasis each
                        consecutive index pair is read as a 2x2 rotation block
                        whose angle becomes an entry of the generator.
                        """
                        sym_part = mrot + mrot.T
                        evals, evecs = np.linalg.eigh(sym_part)
                        assert np.linalg.norm(
                            sym_part - evecs @ np.diag(evals) @ evecs.T) < 1E-10
                        blk = evecs.T @ mrot @ evecs
                        # determinant of blk via the continuant (tridiagonal) recurrence
                        det_prev, det_cur = 1, blk[0, 0]
                        for k in range(1, len(blk)):
                            det_prev, det_cur = det_cur, blk[k, k] * det_cur - \
                                blk[k - 1, k] * blk[k, k - 1] * det_prev
                        assert det_cur > 0
                        gen = np.zeros_like(blk)
                        n_pairs = len(blk) // 2
                        for k in range(0, 2 * n_pairs, 2):
                            cos_avg = (blk[k, k] + blk[k + 1, k + 1]) / 2
                            sin_avg = (blk[k, k + 1] - blk[k + 1, k]) / 2
                            angle = np.arctan2(sin_avg, cos_avg)
                            gen[k, k + 1], gen[k + 1, k] = angle, -angle
                        return evecs @ gen @ evecs.T

                    # Antisymmetric generator kappa of the natural-orbital
                    # rotation (expm(kappa) == rot), assembled irrep by irrep
                    # with my_logm instead of scipy.linalg.logm.
                    import scipy.linalg
                    # kappa = scipy.linalg.logm(rot)
                    kappa = np.zeros_like(rot)
                    for isym in set(orb_sym):
                        mask = np.array(orb_sym) == isym
                        mrot = rot[mask, :][:, mask]
                        mkappa = my_logm(mrot)
                        # mkappa = scipy.linalg.logm(mrot)
                        # assert mkappa.dtype == float
                        # scatter the irrep block back into the full matrix
                        gkappa = np.zeros((kappa.shape[0], mkappa.shape[1]))
                        gkappa[mask, :] = mkappa
                        kappa[:, mask] = gkappa
                    assert np.linalg.norm(
                        scipy.linalg.expm(kappa) - rot) < 1E-10
                    assert np.linalg.norm(kappa + kappa.T) < 1E-10

                    # rot is (old, new) => kappa should be minus
                    np.save(scratch + "/nat_kappa.npy", kappa)

                    # integral rotation
                    # a non-empty "nat_orbs" value names the output FCIDUMP:
                    # rotate, re-symmetrize and write the integrals there
                    nat_fname = dic["nat_orbs"].strip()
                    if len(nat_fname) > 0:
                        if fcidump is None:
                            raise ValueError(
                                "Need FCIDUMP construction (namely, not a pre run) for 'nat_orbs'!")
                        # the following code will not check values inside fcidump
                        # since all MPOs are already constructed
                        _print("rotating integrals to natural orbitals ...")
                        # (old, new)
                        fcidump.rotate(VectorFP(rot.flatten()))
                        _print("finished.")
                        rot_sym_error = fcidump.symmetrize(orb_sym)
                        _print("rotated integral sym error = %12.4g" %
                               rot_sym_error)
                        if rot_sym_error > symmetrize_ints_tol:
                            raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                                                + "please check point group symmetry and FCIDUMP or set"
                                                + " a higher tolerance for the keyword '%s'") % (
                                symmetrize_ints_tol, "symmetrize_ints"))
                        _print("writing natural orbital integrals ...")
                        fcidump.write(nat_fname)
                        _print("finished.")

        else:
            # nroots > 1: one 1PDM per root, taking each root from tagged MPSs,
            # from state-specific MPSs, or by splitting the multi-root MPS
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                if len(mps_tags) > 1:
                    smps, smps_info, forward = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                dm = do_onepdm(smps, smps)
                if MPI is None or MPI.rank == 0:
                    dmocc = np.diag(np.sum(dm, axis=0))
                    if dmocc.dtype == np.complex128:
                        _print("DMRG OCC = ", "".join(["(%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                    else:
                        _print("DMRG OCC = ", "".join(["%6.3f" % x for x in dmocc]))
                    np.save(scratch + "/1pdm-%d-%d.npy" % (iroot, iroot), dm)
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # Transition ONEPDM
    # note that there can be an undetermined +1/-1 factor due to the relative phase in two MPSs
    if "restart_tran_onepdm" in dic or "tran_onepdm" in dic:

        assert nroots != 1
        # bra/ket root ranges; "tran_bra_range" / "tran_ket_range" hold range() arguments
        brar, ketr = range(nroots), range(nroots)
        if "tran_bra_range" in dic:
            tbr = [int(x) for x in dic["tran_bra_range"].split()]
            brar = range(*tbr)
        if "tran_ket_range" in dic:
            tkr = [int(x) for x in dic["tran_ket_range"].split()]
            ketr = range(*tkr)
        for iroot in brar:
            for jroot in ketr:
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                # "tran_triangular": only pairs with iroot >= jroot
                if "tran_triangular" in dic and iroot < jroot:
                    continue
                tx = time.perf_counter()
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                if jroot == iroot:
                    sjmps = simps
                # do_onepdm returns None on non-root MPI ranks, so the SOC
                # rescaling must be skipped there (it would raise TypeError)
                dm = do_onepdm(simps, sjmps)
                if soc and dm is not None:
                    if SX == SU2:
                        if hasattr(simps.info, "targets"):
                            qsbra = simps.info.targets[0].twos
                        else:
                            qsbra = simps.info.target.twos
                        # fix different Wigner–Eckart theorem convention
                        dm *= np.sqrt(qsbra + 1)
                    dm = dm / np.sqrt(2)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/1pdm-%d-%d.npy" % (iroot, jroot), dm)
                    if iroot == jroot:
                        # sum over all components (one for general spin, two
                        # otherwise), as in the ONEPDM section
                        dmocc = np.diag(np.sum(dm, axis=0))
                        if dmocc.dtype == np.complex128:
                            _print("DMRG OCC (state %4d) = " % iroot, "".join(
                                ["(%6.3f,%6.3f)" % (np.real(x), np.imag(x)) for x in dmocc]))
                        else:
                            _print("DMRG OCC (state %4d) = " % iroot, "".join(
                                ["%6.3f" % x for x in dmocc]))
                        simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
                _print("tran 1pdm finished", time.perf_counter() - tx)

    # Particle Number Correlation
    # Expectation sweep over nmpo ("1NPC"); the two spatial matrices returned by
    # get_1npc_spatial(0) and get_1npc_spatial(1) are stacked as dm[0] ("pure")
    # and dm[1] ("mix"), mapped back through the inverse of orb_idx, and saved
    # to 1npc.npy on the root rank.
    if "restart_correlation" in dic or "correlation" in dic:
        assert nroots == 1
        me = MovingEnvironment(nmpo, mps, mps, "1NPC")
        # delayed contraction only for the Normal expectation algorithm
        if algo_type == ExpectationAlgorithmTypes.Normal:
            me.delayed_contraction = OpNamesSet.normal_ops()
        me.cached_contraction = cached_contraction
        me.save_partition_info = True
        me.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(me, mps.info.bond_dim, mps.info.bond_dim)
        expect.algo_type = algo_type
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, mps.center == 0)
        _print("Final canonical form = ", mps.canonical_form, mps.center)

        if MPI is None or MPI.rank == 0:

            # copy each result to numpy, then release the native matrix
            dmr = expect.get_1npc_spatial(0, n_orbs)
            dm_pure = np.array(dmr).copy()
            dmr.deallocate()
            dmr = expect.get_1npc_spatial(1, n_orbs)
            dm_mix = np.array(dmr).copy()
            dmr.deallocate()
            dm = np.concatenate(
                [dm_pure[None, :, :], dm_mix[None, :, :]], axis=0)
            if orb_idx is not None:
                rev_idx = np.argsort(orb_idx)
                dm[:, :, :] = dm[:, rev_idx, :][:, :, rev_idx]

            np.save(scratch + "/1npc.npy", dm)
            mps_info.save_data(scratch + '/mps_info.bin')
            mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])

    # diag twopdm
    # Derived from files already in scratch (1npc.npy and 1pdm.npy, written by
    # the sections above in this or an earlier run). With D the diagonal
    # matrix of the component-summed 1PDM:
    #   e_pqqp = npc[0] - D,   e_pqpq = 2 D - npc[1]
    if "restart_diag_twopdm" in dic or "diag_twopdm" in dic:
        if MPI is None or MPI.rank == 0:
            assert nroots == 1
            dm_npc = np.load(scratch + "/1npc.npy")
            dm_pdm = np.load(scratch + "/1pdm.npy").sum(axis=0)
            dm_e_pqqp = dm_npc[0] - np.diag(np.diag(dm_pdm))
            dm_e_pqpq = -dm_npc[1] + 2 * np.diag(np.diag(dm_pdm))
            np.save(scratch + "/e_pqqp.npy", dm_e_pqqp)
            np.save(scratch + "/e_pqpq.npy", dm_e_pqpq)

    def do_twopdm(bmps, kmps):
        """Run a 2PDM expectation sweep for <bmps| ... |kmps>.

        On the root rank, returns an array whose leading axis holds three
        spin blocks of the spin-orbital 2PDM, selected by spin labels
        (0,0,0,0), (0,1,1,0) and (1,1,1,1) (presumably aa|aa, ab|ba and
        bb|bb); the four orbital axes are mapped back through the inverse of
        `orb_idx` when a reordering is active. Non-root MPI ranks get None.
        """
        env = MovingEnvironment(p2mpo, bmps, kmps, "2PDM")
        if algo_type == ExpectationAlgorithmTypes.Normal:
            env.delayed_contraction = OpNamesSet.normal_ops()
        env.cached_contraction = cached_contraction
        env.save_partition_info = True
        env.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(env, bmps.info.bond_dim, kmps.info.bond_dim)
        expect.algo_type = algo_type
        expect.zero_dot_algo = "zerodot" in dic
        expect.iprint = max(min(outputlevel, 3), 0)
        expect.solve(True, kmps.center == 0)
        _print("Final canonical form = ", bmps.canonical_form, bmps.center)

        if MPI is not None and MPI.rank != 0:
            return None

        tensor = expect.get_2pdm(n_orbs)
        full = np.array(tensor, copy=True)
        # split every spin-orbital index into (orbital, spin), then move the
        # four spin labels behind the four orbital axes
        full = full.reshape((n_orbs, 2) * 4).transpose(0, 2, 4, 6, 1, 3, 5, 7)
        pdm = np.stack([full[..., 0, 0, 0, 0], full[..., 0, 1, 1, 0],
                        full[..., 1, 1, 1, 1]])
        if orb_idx is not None:
            inv_idx = np.argsort(orb_idx)
            pdm[:, :, :, :, :] = pdm[:, inv_idx][:, :, inv_idx][:, :, :, inv_idx][:, :, :, :, inv_idx]
        return pdm
    
    def save_twopdm_stackblock_format(dm, fn):
        """Write a spin-summed spatial 2PDM to ``scratch/fn`` in StackBlock text form.

        ``dm`` is the 3-block array produced by ``do_twopdm``. The written
        tensor is ``(dm[0] + dm[2] + 2 * dm[1]) / 2``. The first line holds the
        orbital count; every following line is ``i j k l value``, with indices
        in C (row-major) order.
        """
        spatial = (dm[0] + dm[2] + 2 * dm[1]) / 2
        norb = spatial.shape[0]
        line_fmt = "%d %d %d %d %20.14f\n"
        with open(scratch + "/" + fn, 'w') as fout:
            fout.write("%d\n" % norb)
            # np.ndindex walks (i, j, k, l) exactly like four nested loops
            fout.writelines(line_fmt % (idx + (spatial[idx], ))
                            for idx in np.ndindex(norb, norb, norb, norb))

    # TWOPDM
    # Compute the 2PDM of each root with do_twopdm and save it on rank 0:
    # 2pdm.npy for a single root, 2pdm-<i>-<i>.npy per root otherwise. With
    # stackblock_compat, also write the spin-summed StackBlock text file.
    if "restart_twopdm" in dic or "twopdm" in dic:

        if nroots == 1:
            dm = do_twopdm(mps, mps)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/2pdm.npy", dm)
                if stackblock_compat:
                    save_twopdm_stackblock_format(dm, "spatial_twopdm.0.0.txt")
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
        else:
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                # obtain the single-root MPS: from its own tag, from the
                # state-specific MPS, or by splitting the multi-root MPS
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                dm = do_twopdm(smps, smps)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/2pdm-%d-%d.npy" % (iroot, iroot), dm)
                    if stackblock_compat:
                        save_twopdm_stackblock_format(dm, "spatial_twopdm.%d.%d.txt" % (iroot, iroot))
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    # Transition TWOPDM
    # note that there can be an undetermined +1/-1 factor due to the relative phase in two MPSs
    if "restart_tran_twopdm" in dic or "tran_twopdm" in dic:

        assert nroots != 1
        # bra / ket root ranges; optional keywords give the arguments of range()
        brar, ketr = range(nroots), range(nroots)
        if "tran_bra_range" in dic:
            tbr = [int(x) for x in dic["tran_bra_range"].split()]
            brar = range(*tbr)
        if "tran_ket_range" in dic:
            tkr = [int(x) for x in dic["tran_ket_range"].split()]
            ketr = range(*tkr)
        for iroot in brar:
            # Reset per bra root. If every ket root is skipped (empty ket range,
            # or "tran_triangular" with all jroot > iroot), no MPS info exists
            # for this root. Without the reset, the previous root's info would
            # be saved under this root's file name (or NameError on the first).
            simps_info = None
            for jroot in ketr:
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                if "tran_triangular" in dic:
                    if iroot < jroot:
                        continue
                tx = time.perf_counter()
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                # same root on both sides: use one MPS object for bra and ket
                if jroot == iroot:
                    sjmps = simps
                dm = do_twopdm(simps, sjmps)
                if MPI is None or MPI.rank == 0:
                    np.save(scratch + "/2pdm-%d-%d.npy" % (iroot, jroot), dm)
                _print("tran 2pdm finished", time.perf_counter() - tx)
            if simps_info is not None and (MPI is None or MPI.rank == 0):
                simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)

    def do_oh(bmps, kmps):
        """Evaluate the matrix element <bmps| O |kmps>.

        O is ``impo`` when ``overlap`` is set (presumably the identity MPO, so
        the result is an overlap -- TODO confirm), otherwise the Hamiltonian
        MPO ``mpo``.

        Returns:
            The value from ``expect.solve`` on the root rank (or when MPI is
            disabled), otherwise ``None``.
        """
        op_mpo = impo if overlap else mpo
        env = MovingEnvironment(op_mpo, bmps, kmps, "OH")
        env.delayed_contraction = OpNamesSet.normal_ops()
        env.cached_contraction = cached_contraction
        env.save_partition_info = True
        env.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        expect = XExpect(env, bmps.info.bond_dim, kmps.info.bond_dim)
        expect.iprint = max(min(outputlevel, 3), 0)
        value = expect.solve(False, kmps.center == 0)

        return value if (MPI is None or MPI.rank == 0) else None

    # OH (Hamiltonian expectation on MPS)
    # For complex MPS, do_oh returns a complex number. The storage dtype and
    # the printout therefore handle the imaginary part; formatting a complex
    # value with "%f" or storing it in a float array raises TypeError.
    if "restart_oh" in dic or "oh" in dic:

        if nroots == 1:
            E_oh = do_oh(mps, mps)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_oh.npy", E_oh)
                if complex_mps:
                    _print("OH Energy = RE %20.15f + IM %20.15f" %
                           (np.real(E_oh), np.imag(E_oh)))
                else:
                    _print("OH Energy = %20.15f" % E_oh)
                mps_info.save_data(scratch + '/mps_info.bin')
                mps_info.save_data(scratch + '/%s-mps_info.bin' % mps_tags[0])
        else:
            mat_oh = np.zeros((nroots, ),
                              dtype=np.complex128 if complex_mps else float)
            for iroot in range(nroots):
                _print('----- root = %3d / %3d -----' % (iroot, nroots))
                # obtain the single-root MPS: from its own tag, from the
                # state-specific MPS, or by splitting the multi-root MPS
                if len(mps_tags) > 1:
                    smps, smps_info, _ = get_mps_from_tags(iroot)
                elif "statespecific" in dic:
                    smps, smps_info, forward = get_state_specific_mps(
                        iroot, mps_info)
                else:
                    smps, smps_info, forward = split_mps(iroot, mps, mps_info)
                E_oh = do_oh(smps, smps)
                if MPI is None or MPI.rank == 0:
                    mat_oh[iroot] = E_oh
                    if complex_mps:
                        print("OH Energy %4d - %4d = RE %20.15f + IM %20.15f" %
                              (iroot, iroot, np.real(E_oh), np.imag(E_oh)))
                    else:
                        print("OH Energy %4d - %4d = %20.15f" %
                              (iroot, iroot, E_oh))
                    smps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
            if MPI is None or MPI.rank == 0:
                np.save(scratch + "/E_oh.npy", mat_oh)

    # Transition OH (OH between different MPS roots)
    # note that there can be a undetermined +1/-1 factor due to the relative phase in two MPSs
    # only mat_oh[i, j] with i >= j are filled
    # The matrix is complex for complex MPS, real otherwise. It is written to
    # E_oh.npy on rank 0, together with one mps_info-<i>.bin per bra root.
    if "restart_tran_oh" in dic or "tran_oh" in dic:

        assert nroots != 1
        mat_oh = np.zeros((nroots, nroots),
                          dtype=np.complex128 if complex_mps else float)
        for iroot in range(nroots):
            # lower triangle including the diagonal (jroot <= iroot)
            for jroot in range(iroot + 1):
                _print('----- root = %3d -> %3d / %3d -----' %
                       (jroot, iroot, nroots))
                if len(mps_tags) > 1:
                    simps, simps_info, _ = get_mps_from_tags(iroot)
                    sjmps, sjmps_info, _ = get_mps_from_tags(jroot)
                elif "statespecific" in dic:
                    simps, simps_info, _ = get_state_specific_mps(
                        iroot, mps_info)
                    sjmps, sjmps_info, _ = get_state_specific_mps(
                        jroot, mps_info)
                else:
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info)
                    sjmps, sjmps_info, _ = split_mps(jroot, mps, mps_info)
                E_oh = do_oh(simps, sjmps)
                if MPI is None or MPI.rank == 0:
                    mat_oh[iroot, jroot] = E_oh
                    if complex_mps:
                        print("OH Energy %4d - %4d = RE %20.15f + IM %20.15f" %
                              (iroot, jroot, np.real(E_oh), np.imag(E_oh)))
                    else:
                        print("OH Energy %4d - %4d = %20.15f" %
                              (iroot, jroot, E_oh))
            # the inner range is never empty, so simps_info belongs to iroot here
            if MPI is None or MPI.rank == 0:
                simps_info.save_data(scratch + '/mps_info-%d.bin' % iroot)
        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/E_oh.npy", mat_oh)

    # NEVPT2 / MRREPT2
    # Second-order multireference perturbation correction on top of the
    # DMRG-CASCI reference `mps`:
    #   * zeroth-order Hamiltonian: DyallFCIDUMP initialized from the saved
    #     1PDM for "nevpt2*", FinkFCIDUMP for "mrrept2*";
    #   * "s"/"sd" suffix: single / single+double excitation space (MRCIMPSInfo);
    #   * "-ijrs"-style suffix: restricted excitation subspace (NEVPTMPSInfo),
    #     where i/j count inactive holes and r/s count external particles;
    #   * the first-order wavefunction `bra` is found by a DMRG linear solve.
    #     The left MPO is -(H0 - E_casci); the right MPO is (H - E_casci)
    #     applied to the reference `mps` (presumably const_e is the MPO's
    #     additive energy constant -- TODO confirm).
    #     linear.solve returns the correction energy.
    if dynamic_corr_method is not None and dynamic_corr_method[0] in ["nevpt2s", "nevpt2sd",
            "nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir",
            "nevpt2-i", "nevpt2-r", "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
            "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:

        if fcidump is None:
            raise ValueError(
                "Need FCIDUMP construction (namely, not a pre run) for 'nevpt2/mrrept2'!")
        if MPI is not None:
            MPI.barrier()
        # zeroth-order Hamiltonian integrals
        if dynamic_corr_method[0].startswith("nevpt2"):
            fd_dyall = DyallFCIDUMP(fcidump, n_inactive, n_external)
            dm = np.load(scratch + "/1pdm.npy")
            if fcidump.uhf:
                # interleave the two spin blocks into a spin-orbital 1PDM
                dmx = np.zeros((dm.shape[0] * 2, dm.shape[1] * 2))
                dmx[0::2, 0::2] = dm[0]
                dmx[1::2, 1::2] = dm[1]
                fd_dyall.initialize_from_1pdm_sz(dmx)
            else:
                fd_dyall.initialize_from_1pdm_su2(dm[0] + dm[1])
            fd_zero = fd_dyall
        else:
            fd_fink = FinkFCIDUMP(fcidump, n_inactive, n_external)
            fd_zero = fd_fink
        # reference (CASCI) energy from the earlier DMRG step
        e_casci = float(np.load(scratch + "/E_dmrg.npy"))

        sym_error = fd_zero.symmetrize(orb_sym)
        _print("integral sym error = %12.4g" % sym_error)
        if sym_error > symmetrize_ints_tol:
            raise RuntimeError(("Integral symmetrization error larger than %10.5g, "
                                + "please check point group symmetry and FCIDUMP or set"
                                + " a higher tolerance for the keyword '%s'") % (
                symmetrize_ints_tol, "symmetrize_ints"))

        # Rebuild the full Hamiltonian `hamil` (used for the right MPO) with
        # big sites wrapped in NoTransposeRule; "folding" keeps the existing one.
        if big_site_method == "folding":
            pass
        elif big_site_method is not None:
            big_left = SimplifiedBigSite(
                big_left_orig, NoTransposeRule(simpl_rule))
            big_right = SimplifiedBigSite(
                big_right_orig, NoTransposeRule(simpl_rule))
            if MPI is not None:
                if one_body_only:
                    big_left = ParallelBigSite(big_left, prule_one_body)
                    big_right = ParallelBigSite(big_right, prule_one_body)
                else:
                    big_left = ParallelBigSite(big_left, prule)
                    big_right = ParallelBigSite(big_right, prule)
            hamil = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fcidump,
                                         None if n_inactive == 0 else big_left, None if n_external == 0 else big_right)
        # Build the zeroth-order Hamiltonian `hm_zero` (used for the left MPO)
        if big_site_method is None:
            hm_zero = HamiltonianQC(vacuum, n_sites, orb_sym, fd_zero)
        elif big_site_method == "folding":
            assert dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s', "nevpt2-ijrs", "nevpt2-ij",
                "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                "mrrept2s", "mrrept2sd", "mrrept2-ijrs", "mrrept2-ij",
                "mrrept2-rs", "mrrept2-ijr", "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]
            hm_zero = HamiltonianQC(vacuum, n_orbs, orb_sym, fd_zero)
            lmpo_fold = MPOQC(hm_zero, qc_type)
            # restricted-space MPSInfo supplying the FCI dims for each fused site
            if dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s']:
                # 'nevpt2s' -> 1, 'nevpt2sd' -> 2
                ci_order = len(dynamic_corr_method[0]) - 6
                lmps_info_fold = MRCIMPSInfo(
                    n_orbs, n_inactive, n_external, ci_order, vacuum, target, hm_zero.basis)
            elif dynamic_corr_method[0] in ['mrrept2sd', 'mrrept2s']:
                # 'mrrept2s' -> 1, 'mrrept2sd' -> 2
                ci_order = len(dynamic_corr_method[0]) - 7
                lmps_info_fold = MRCIMPSInfo(
                    n_orbs, n_inactive, n_external, ci_order, vacuum, target, hm_zero.basis)
            else:
                # strip the 'nevpt2-' / 'mrrept2-' prefix and count the letters
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                lmps_info_fold = NEVPTMPSInfo(n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum,
                                              target, hm_zero.basis)
            # fuse the last two sites repeatedly until the external space is one site
            for i in range(n_external - 1):
                _print("fold right %d / %d" % (i, n_external))
                lmpo_fold = FusedMPO(lmpo_fold, hm_zero.basis, lmpo_fold.n_sites - 2,
                                     lmpo_fold.n_sites - 1, lmps_info_fold.right_dims_fci[lmpo_fold.n_sites - 2])
                hm_zero.basis = lmpo_fold.basis
                hm_zero.n_sites = lmpo_fold.n_sites
            # fuse the first two sites repeatedly until the inactive space is one site
            for i in range(n_inactive - 1):
                _print("fold left %d / %d" % (i, n_inactive))
                lmpo_fold = FusedMPO(
                    lmpo_fold, hm_zero.basis, 0, 1, lmps_info_fold.left_dims_fci[i + 2])
                hm_zero.basis = lmpo_fold.basis
                hm_zero.n_sites = lmpo_fold.n_sites
            # convert the boundary (fused) site operators to CSR: truly sparse
            # ones (sparsity > 0.75) are converted, the rest only wrapped
            for k, op in lmpo_fold.tensors[0].ops.items():
                smat = CSRSparseMatrix()
                if op.sparsity() > 0.75:
                    smat.from_dense(op)
                    op.deallocate()
                else:
                    smat.wrap_dense(op)
                lmpo_fold.tensors[0].ops[k] = smat
            for k, op in lmpo_fold.tensors[-1].ops.items():
                smat = CSRSparseMatrix()
                if op.sparsity() > 0.75:
                    smat.from_dense(op)
                    op.deallocate()
                else:
                    smat.wrap_dense(op)
                lmpo_fold.tensors[-1].ops[k] = smat
            lmpo_fold.sparse_form = 'S' + lmpo_fold.sparse_form[1:-1] + 'S'
            lmpo_fold.tf = TensorFunctions(
                CSROperatorFunctions(hm_zero.opf.cg))
        else:
            assert dynamic_corr_method is not None
            # excitation limits (xl for the left/inactive big site, xr for the
            # right/external big site) passed to the big-site constructors
            if dynamic_corr_method[0] in ['nevpt2sd', 'mrrept2sd']:
                xl = -2, -2, -2
                xr = 2, 2, 2
            elif dynamic_corr_method[0] in ['nevpt2s', 'mrrept2s']:
                xl = -1, -1, -1
                xr = 1, 1, 1
            elif dynamic_corr_method[0] in ["nevpt2-ijrs", "nevpt2-ij", "nevpt2-rs", "nevpt2-ijr",
                                            "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                                            "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                                            "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]:
                # this is not correct yet
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                xl = -n_ex_inactive, -n_ex_inactive, -n_ex_inactive
                xr = n_ex_external, n_ex_external, n_ex_external
            else:
                assert False
            if big_site_method == "fock":
                assert "nonspinadapted" in dic
                poccl = SCIFockBigSite.ras_space(False, n_inactive, *[abs(x) for x in xl], VectorInt([]))
                poccr = SCIFockBigSite.ras_space(True, n_external, *xr, VectorInt([]))
                if '-' in dynamic_corr_method[0]:
                    # restricted subspace: keep only occupation patterns with
                    # exactly the requested number of holes / particles
                    maxl = max([len(x) for x in poccl])
                    poccl = VectorVectorInt([x for x in poccl if len(x) == maxl - abs(xl[0])])
                    poccr = VectorVectorInt([x for x in poccr if len(x) == abs(xr[0])])
                big_left_orig = SCIFockBigSite(n_orbs, n_inactive, False, fd_zero, orb_sym, poccl, True)
                big_right_orig = SCIFockBigSite(n_orbs, n_external, True, fd_zero, orb_sym, poccr, True)
            elif big_site_method == "csf":
                assert "nonspinadapted" not in dic
                big_left_orig = CSFBigSite(n_inactive, abs(
                    xl[-1]), False, fd_zero, orb_sym[:n_inactive])
                big_right_orig = CSFBigSite(n_external, abs(
                    xr[-1]), True, fd_zero, orb_sym[-n_external:])
            else:
                raise NotImplementedError
            big_left = SimplifiedBigSite(big_left_orig, simpl_rule)
            big_right = SimplifiedBigSite(big_right_orig, simpl_rule)
            if MPI is not None:
                if one_body_only:
                    big_left = ParallelBigSite(big_left, prule_one_body)
                    big_right = ParallelBigSite(big_right, prule_one_body)
                else:
                    big_left = ParallelBigSite(big_left, prule)
                    big_right = ParallelBigSite(big_right, prule)
            hm_zero = HamiltonianQCBigSite(vacuum, n_orbs, orb_sym, fd_zero,
                                            None if n_inactive == 0 else big_left, None if n_external == 0 else big_right)

        # left mpo: -(H0 - E_casci), the operator acting on the bra
        _print("build left mpo", time.perf_counter() - tx)
        if big_site_method == "folding":
            lmpo = lmpo_fold
        else:
            lmpo = MPOQC(hm_zero, qc_type, "LQC")
        _print("simpl left mpo", time.perf_counter() - tx)
        lmpo = SimplifiedMPO(lmpo, simpl_rule, True, True,
                             OpNamesSet((OpNames.R, OpNames.RD)))
        _print("simpl left mpo finished", time.perf_counter() - tx)
        lmpo.const_e -= e_casci
        lmpo = lmpo * -1

        # report the MPO bond dimensions (one left operator block per site)
        mpo_bdims = [None] * len(lmpo.left_operator_names)
        for ix in range(len(lmpo.left_operator_names)):
            lmpo.load_left_operators(ix)
            x = lmpo.left_operator_names[ix]
            mpo_bdims[ix] = x.m * x.n
            lmpo.unload_left_operators(ix)
        _print('LEFT MPO BOND DIMS = ', ''.join(["%6d" % x for x in mpo_bdims]))

        if MPI is None or MPI.rank == 0:
            lmpo.save_data(scratch + '/lmpo.bin')

        # right mpo: H - E_casci, applied to the reference mps
        _print("build right mpo", time.perf_counter() - tx)
        if big_site_method == "folding":
            # NOTE(review): mpo_fold is not defined in this block (only
            # lmpo_fold is); verify it exists in the enclosing scope for "folding"
            rmpo = mpo_fold
        else:
            rmpo = MPOQC(hamil, qc_type, "RQC")
        _print("simpl right mpo", time.perf_counter() - tx)
        rmpo = SimplifiedMPO(rmpo, NoTransposeRule(simpl_rule),
                             True, True, OpNamesSet((OpNames.R, OpNames.RD)))
        _print("simpl right mpo finished", time.perf_counter() - tx)
        rmpo.const_e -= e_casci

        mpo_bdims = [None] * len(rmpo.left_operator_names)
        for ix in range(len(rmpo.left_operator_names)):
            rmpo.load_left_operators(ix)
            x = rmpo.left_operator_names[ix]
            mpo_bdims[ix] = x.m * x.n
            rmpo.unload_left_operators(ix)
        _print('RIGHT MPO BOND DIMS = ', ''.join(["%6d" % x for x in mpo_bdims]))

        if MPI is None or MPI.rank == 0:
            rmpo.save_data(scratch + '/rmpo.bin')

        if not para_no_pre_run:

            if MPI is not None:
                if one_body_only:
                    lmpo = ParallelMPO(lmpo, prule_one_body)
                    rmpo = ParallelMPO(rmpo, prule_one_body)
                else:
                    lmpo = ParallelMPO(lmpo, prule)
                    rmpo = ParallelMPO(rmpo, prule)

            _print("para left/right mpo finished", time.perf_counter() - tx)

        # a two-dot MPS cannot have its center on the last site
        mps.dot = dot
        if mps.center == mps.n_sites - 1 and mps.dot == 2:
            mps.center = mps.n_sites - 2

        # bra mps: the first-order wavefunction, restricted to the excitation space
        if big_site_method is not None:
            bra_info = MPSInfo(n_sites, vacuum, target, hm_zero.basis)
        else:
            assert dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s', "nevpt2-ijrs", "nevpt2-ij",
                "nevpt2-rs", "nevpt2-ijr", "nevpt2-rsi", "nevpt2-ir", "nevpt2-i", "nevpt2-r",
                "mrrept2sd", "mrrept2s", "mrrept2-ijrs", "mrrept2-ij", "mrrept2-rs", "mrrept2-ijr",
                "mrrept2-rsi", "mrrept2-ir", "mrrept2-i", "mrrept2-r"]
            if dynamic_corr_method[0] in ['nevpt2sd', 'nevpt2s']:
                ci_order = len(dynamic_corr_method[0]) - 6
                bra_info = MRCIMPSInfo(
                    n_sites, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
            elif dynamic_corr_method[0] in ["mrrept2sd", "mrrept2s"]:
                ci_order = len(dynamic_corr_method[0]) - 7
                bra_info = MRCIMPSInfo(
                    n_sites, n_inactive, n_external, ci_order, vacuum, target, hamil.basis)
            else:
                # NOTE(review): uses n_orbs here but n_sites in the MRCI branches
                # above; presumably equal when no big sites are used -- confirm
                if dynamic_corr_method[0].startswith('nevpt2-'):
                    sub_space = dynamic_corr_method[0][7:]
                else:
                    sub_space = dynamic_corr_method[0][8:]
                n_ex_inactive = sub_space.count('i') + sub_space.count('j')
                n_ex_external = sub_space.count('r') + sub_space.count('s')
                bra_info = NEVPTMPSInfo(n_orbs, n_inactive, n_external, n_ex_inactive, n_ex_external, vacuum,
                                        target, hamil.basis)

        if singlet_embedding and SX == SU2:
            left_vacuum = SX(singlet_embedding_spin, singlet_embedding_spin, 0)
            right_vacuum = vacuum
            if "full_fci_space" in dic:
                bra_info.set_bond_dimension_full_fci(left_vacuum, right_vacuum)
            else:
                bra_info.set_bond_dimension_fci(left_vacuum, right_vacuum)
        else:
            if "full_fci_space" in dic:
                bra_info.set_bond_dimension_full_fci()
        bra_info.tag = 'BRA'

        # make the bra tag distinct from the reference MPS tag
        while bra_info.tag == mps_info.tag:
            bra_info.tag += 'X'
        bra_info.set_bond_dimension(bond_dims[0])
        if "skip_inact_ext_sites" in dic:
            bra_info.set_bond_dimension_inact_ext_fci(bond_dims[0], n_inactive, n_external)
        if dynamic_corr_method[0].startswith('nevpt2'):
            mrpt_method_name = 'nevpt2'
        else:
            mrpt_method_name = 'mrrept2'
        if MPI is None or MPI.rank == 0:
            bra_info.save_data(scratch + '/%s_mps_info.bin' % mrpt_method_name)
            bra_info.save_data(scratch + '/%s-mps_info.bin' % bra_info.tag)
        # random, normalized initial guess for the bra
        bra = MPS(n_sites, mps.center, dot)
        bra.initialize(bra_info)
        bra.random_canonicalize()
        bra.tensors[bra.center].normalize()
        if "skip_inact_ext_sites" in dic:
            bra.set_inact_ext_identity(n_inactive, n_external)
        bra.save_mutable()
        bra.deallocate()
        bra_info.save_mutable()
        bra_info.deallocate_mutable()

        if MPI is not None:
            MPI.barrier()
        # shift the bra center by one site in a boundary two-dot case,
        # then give it the same center as the reference mps
        if bra.center == 0 and bra.dot == 2:
            bra.move_left(hamil.opf.cg, prule if MPI is not None else None)
        elif bra.center == bra.n_sites - 2 and bra.dot == 2:
            bra.move_right(hamil.opf.cg, prule if MPI is not None else None)
        bra.center = mps.center
        if MPI is not None:
            MPI.barrier()

        _print("BRA MPS = ", bra.canonical_form,
               bra.center, bra.dot, bra.info.target)
        _print("BRA INIT MPS BOND DIMS = ", ''.join(
            ["%6d" % x.n_states_total for x in bra_info.left_dims]))

        tx = time.perf_counter()

        # lme: <bra| lmpo |bra>; rme: <bra| rmpo |mps>
        lme = MovingEnvironment(lmpo, bra, bra, "LME")
        lme.init_environments(outputlevel >= 2)
        lme.delayed_contraction = OpNamesSet.normal_ops()
        lme.cached_contraction = False
        rme = MovingEnvironment(rmpo, bra, mps, "RME")
        rme.init_environments(outputlevel >= 2)

        _print("env init finished", time.perf_counter() - tx)

        # NOTE(review): fixed margin of 400 over the reference bond dimension
        # for the second bond-dimension list passed to Linear -- confirm intent
        right_bdims = VectorUBond([mps_info.bond_dim + 400])
        if big_site_method is not None and n_cas != 0:
            linear = LinearBigSite(lme, rme, None, VectorUBond(
                bond_dims), right_bdims, VectorFP(noises))
            linear.last_site_svd = True
            linear.last_site_1site = dot == 2
            linear.decomp_last_site = False
        else:
            linear = Linear(lme, rme, None, VectorUBond(
                bond_dims), right_bdims, VectorFP(noises))
        linear.iprint = max(min(outputlevel, 3), 0)

        # optionally sweep only over the active-space sites
        if "skip_inact_ext_sites" in dic:
            linear.sweep_start_site = n_inactive
            linear.sweep_end_site = lme.n_sites - n_external

        if "lowmem_noise" in dic:
            linear.noise_type = NoiseTypes.ReducedPerturbativeCollectedLowMem
        elif decomp_type != DecompositionTypes.SVD:
            linear.noise_type = NoiseTypes.ReducedPerturbativeCollected
        else:
            linear.noise_type = NoiseTypes.ReducedPerturbative
        linear.cutoff = float(dic.get("cutoff", 1E-14))
        linear.decomp_type = decomp_type
        linear.trunc_type = trunc_type
        linear.linear_soft_max_iter = int(
                dic.get("linear_soft_max_iter", -1))
        # linear-solver thresholds are 1/50 of the Davidson thresholds
        linear.linear_conv_thrds = VectorFP([x / 50 for x in dav_thrds])

        e_corr = linear.solve(len(bond_dims),
            mps.center == linear.sweep_start_site, sweep_tol)
        nevpt_sweep_energies = np.array(linear.targets)
        nevpt_discarded_weights = np.array(linear.discarded_weights)

        if MPI is None or MPI.rank == 0:
            # pad the bond-dimension schedule with its last value to match the sweep count
            bdims = bond_dims[:len(nevpt_discarded_weights)]
            if len(bdims) < len(nevpt_discarded_weights):
                bdims = bdims + bdims[-1:] * \
                    (len(nevpt_discarded_weights) - len(bdims))
            np.save(scratch + "/E_%s.npy" % mrpt_method_name, e_casci + e_corr)
            np.save(scratch + "/%s_e_corr.npy" % mrpt_method_name, e_corr)
            np.save(scratch + "/%s_bond_dims.npy" % mrpt_method_name, bdims)
            np.save(scratch + "/%s_sweep_energies.npy" % mrpt_method_name,
                    nevpt_sweep_energies)
            np.save(scratch + "/%s_discarded_weights.npy" % mrpt_method_name,
                    nevpt_discarded_weights)

        _print("DMRG-CASCI  Energy     = %20.15f" % e_casci)
        _print("DMRG-%s Correction = %20.15f" % (mrpt_method_name.upper(), e_corr))
        _print("DMRG-%s Energy     = %20.15f" % (mrpt_method_name.upper(), e_corr + e_casci))

    # MPS split
    # With copy_mps/restart_copy_mps plus split_states and/or trans_mps_to_complex:
    # on rank 0, split the multi-root MPS into one MPS per root (optionally
    # made complex and copied to "<copy_tag>-<i>"), or make a complex copy of
    # the whole MPS.
    if ("restart_copy_mps" in dic or "copy_mps" in dic) and (
            "split_states" in dic or "trans_mps_to_complex" in dic):

        if "restart_copy_mps" in dic:
            copy_tag = dic["restart_copy_mps"]
        else:
            copy_tag = dic["copy_mps"]

        if MPI is None or MPI.rank == 0:
            if "split_states" in dic:
                assert nroots != 1

                for iroot in range(nroots):
                    _print('----- root = %3d / %3d -----' % (iroot, nroots))
                    simps, simps_info, _ = split_mps(iroot, mps, mps_info, mpi=None)
                    if "trans_mps_to_complex" in dic:
                        simps = MultiMPS.make_complex(
                            simps, "%s-CPX-%d" % (mps_info.tag, iroot))
                        simps_info = simps.info
                    # deep-copy to "<copy_tag>-<i>" unless it already has that tag
                    if copy_tag != '' and "%s-%d" % (copy_tag, iroot) != simps_info.tag:
                        final_tag = "%s-%d" % (copy_tag, iroot)
                        simps = simps.deep_copy(final_tag)
                        simps_info = simps.info
                    else:
                        final_tag = simps_info.tag
                    _print("   final tag = %s" % final_tag)
                    _print("   final canonical form = %s" %
                           simps.canonical_form)
                    simps_info.save_data(
                        scratch + '/%s-mps_info.bin' % final_tag)

            elif "trans_mps_to_complex" in dic:
                if copy_tag == '':
                    copy_tag = "%s-CPX" % mps_info.tag
                assert copy_tag != mps_info.tag
                simps = MultiMPS.make_complex(mps, copy_tag)
                simps_info = simps.info
                _print("   final tag = %s" % copy_tag)
                _print("   final canonical form = %s" % simps.canonical_form)
                simps_info.save_data(scratch + '/%s-mps_info.bin' % copy_tag)

    # MPS copy/transform
    # Copy `mps` under a new tag, optionally transforming it (SU2 -> SZ,
    # to/from singlet embedding). The copy replaces `mps` for the rest of the run.
    elif "restart_copy_mps" in dic or "copy_mps" in dic:

        if "restart_copy_mps" in dic:
            copy_tag = dic["restart_copy_mps"]
        else:
            copy_tag = dic["copy_mps"]

        mps.dot = dot
        if MPI is not None:
            MPI.barrier()
        # rank 0: if the center is at a boundary site in 'S' / 'K' form,
        # flip its fused form and save the MPS
        if MPI is None or MPI.rank == 0:
            if (mps.center == 0 and mps.canonical_form[0] == 'S') or \
                    (mps.center == mps.n_sites - 1 and mps.canonical_form[-1] == 'K'):
                cg = CG(200)
                cg.initialize()
                mps.flip_fused_form(mps.center, cg, None)
                mps.save_data()
        if MPI is not None:
            MPI.barrier()
        if mps.center == mps.n_sites - 1 and mps.dot == 2:
            mps.center = mps.n_sites - 2
        _print("Copy-init canonical form = ", mps.canonical_form, mps.center)

        if copy_tag == '':
            raise ValueError(
                "A tag name must be given for the keyword copy_mps/restart_copy_mps!")
        if "trans_mps_to_sz" in dic:
            # SU2 -> SZ transformation of the (unfused) MPS
            assert "nonspinadapted" not in dic
            mps.info.load_mutable()
            mps.load_mutable()
            cp_mps = trans_mps(UnfusedMPS(mps), copy_tag, mpo.tf.opf.cg)
            if "resolve_twosz" in dic:
                res_twosz = int(dic["resolve_twosz"])
                cp_mps.resolve_singlet_embedding(res_twosz)
            cp_mps = cp_mps.finalize()
            # dot/center are temporarily changed below; restored after canonicalization
            dot_bk, center_bk = cp_mps.dot, cp_mps.center
            if MPI is not None:
                MPI.barrier()
            # remembers the original last-site form if it is flipped below
            last_flipped = None
            if dot == 2:
                if cp_mps.center == 0 and cp_mps.canonical_form[cp_mps.center + 1] == 'R':
                    cp_mps.dot = 1
                elif cp_mps.center == cp_mps.n_sites - 2 and cp_mps.canonical_form[cp_mps.center] == 'L':
                    cp_mps.center = cp_mps.n_sites - 1
                    cp_mps.dot = 1
                    if cp_mps.canonical_form[cp_mps.center] in ['C', 'S']:
                        last_flipped = cp_mps.canonical_form[cp_mps.center]
                        from block2.sz import MPICommunicator as MPICommunicatorX
                        from block2.sz import CG as CGX, ParallelRuleQC as ParallelRuleQCX
                        prulex = ParallelRuleQCX(MPICommunicatorX())
                        cgx = CGX(200)
                        cgx.initialize()
                        cp_mps.canonical_form = cp_mps.canonical_form[:-1] + 'S'
                        cp_mps.flip_fused_form(cp_mps.center, cgx, prulex)
            cp_mps.info.load_mutable()
            cp_mps.load_mutable()
            cp_mps.dynamic_canonicalize()
            if "normalize_mps" in dic:
                cp_mps.tensors[cp_mps.center].normalize()
            cp_mps.dot, cp_mps.center = dot_bk, center_bk
            if MPI is not None:
                MPI.barrier()
            if MPI is None or MPI.rank == 0:
                cp_mps.info.save_mutable()
                cp_mps.save_mutable()
                cp_mps.save_data()
            if MPI is not None:
                MPI.barrier()
            # undo the temporary flip of the last site and restore its form letter
            if last_flipped is not None:
                cp_mps.flip_fused_form(cp_mps.n_sites - 1, cgx, prulex)
                cp_mps.canonical_form = cp_mps.canonical_form[: -
                                                              1] + last_flipped
                if MPI is None or MPI.rank == 0:
                    cp_mps.save_data()
                if MPI is not None:
                    MPI.barrier()
        else:
            cp_mps = mps.deep_copy(copy_tag)

        if "trans_mps_to_singlet_embedding" in dic or "trans_mps_from_singlet_embedding" in dic:
            assert "nonspinadapted" not in dic
            # NOTE(review): this re-copies from `mps`, discarding the cp_mps
            # produced above (including any trans_mps_to_sz result)
            cp_mps = mps.deep_copy(copy_tag)
            # normalize the center to the first/last site, then move it to site 0
            if cp_mps.canonical_form[0] == 'C' and cp_mps.canonical_form[1] == 'R':
                cp_mps.canonical_form = 'K' + cp_mps.canonical_form[1:]
                cp_mps.center = 0
            elif cp_mps.canonical_form[-1] == 'C' and cp_mps.canonical_form[-2] == 'L':
                cp_mps.canonical_form = cp_mps.canonical_form[:-1] + 'S'
                cp_mps.center = cp_mps.n_sites - 1
            elif cp_mps.center == cp_mps.n_sites - 2 and cp_mps.canonical_form[-2] == 'L':
                cp_mps.center = cp_mps.n_sites - 1
            cg = CG(200)
            cg.initialize()
            while cp_mps.center > 0:
                cp_mps.move_left(cg, prule)
            if "trans_mps_to_singlet_embedding" in dic:
                cp_mps.to_singlet_embedding_wfn(cg, prule)
            if "trans_mps_from_singlet_embedding" in dic:
                cp_mps.from_singlet_embedding_wfn(cg, prule)
            if MPI is None or MPI.rank == 0:
                cp_mps.save_data()

        if MPI is None or MPI.rank == 0:
            cp_mps.info.save_data(scratch + '/mps_info.bin')
            cp_mps.info.save_data(scratch + '/%s-mps_info.bin' % copy_tag)

        mps = cp_mps
        _print("Copy-final canonical form = ", mps.canonical_form, mps.center)

    # stoptDMRG sampling (sampling CSF is not supported yet)
    # Stochastic perturbative correction (SPDMRG) on top of the DMRG state.
    # Keyword value: number of samples; a bare keyword uses 10000 samples.
    # Saves E_stopt.npy / stopt_e_corr.npy / stopt_std_corr.npy on rank 0.
    if "stopt_sampling" in dic:

        if fcidump is None:
            raise ValueError(
                "Need FCIDUMP construction (namely, not a pre run) for 'stoptDMRG'!")
        if MPI is not None:
            MPI.barrier()

        try:
            from pyblock2.driver.stopt import SPDMRG
        except ImportError:
            from stopt import SPDMRG

        # The key is known to be present here, so ``dic.get(key, 10000)``
        # could never apply its default. A bare keyword appears to be stored
        # as an empty string (cf. the ``sample_phase`` handling), which
        # ``int`` rejects; fall back to 10000 explicitly instead.
        stopt_nsample_str = str(dic["stopt_sampling"]).strip()
        nsample = int(stopt_nsample_str) if stopt_nsample_str else 10000
        sp_dmrg = SPDMRG(su2="nonspinadapted" not in dic,
                         scratch=scratch, fcidump=fcidump, mps_tags=mps_tags, verbose=outputlevel)
        e_corr, std_corr = sp_dmrg.kernel(nsample)
        # corrected energy = reference DMRG energy + stochastic correction
        stopt_e_total = sp_dmrg.Edmrg + e_corr

        if MPI is None or MPI.rank == 0:
            np.save(scratch + "/E_stopt.npy", stopt_e_total)
            np.save(scratch + "/stopt_e_corr.npy", e_corr)
            np.save(scratch + "/stopt_std_corr.npy", std_corr)

        _print("            DMRG Energy     = %25.15f" % sp_dmrg.Edmrg)
        _print("stochastic PDMRG Correction = %25.15f (%20.15f)" %
               (e_corr, std_corr))
        _print("stochastic PDMRG Energy     = %25.15f (%20.15f)" %
               (stopt_e_total, std_corr))

    # CSF/DET coefficients
    # Collect the CSFs (SU2 MPS) or determinants (SZ MPS) whose coefficients
    # pass the sampling cutoff into a DeterminantTRIE.
    if "restart_sample" in dic or "sample" in dic:

        # Coefficient cutoff; ``restart_sample`` takes precedence over
        # ``sample``. The key is known to be present, so ``dic.get(key, 0)``
        # could never use its default. A bare keyword (apparently stored as
        # an empty string, cf. ``sample_phase`` handling) now means cutoff 0
        # instead of crashing in ``float("")``.
        sample_cutoff_str = str(
            dic["restart_sample" if "restart_sample" in dic else "sample"]).strip()
        sample_cutoff = float(sample_cutoff_str) if sample_cutoff_str else 0.0

        if "trans_mps_to_sz" in dic:
            # presumably the MPS was converted to SZ earlier; use SZ classes
            from block2.sz import DeterminantTRIE, UnfusedMPS
            su2mps = False
        else:
            su2mps = SX == SU2

        if "sample_reference" not in dic:
            # no reference: rank limit passed to evaluate() is nelec[0]
            max_rank = nelec[0]
            sample_ref = []
        else:
            # format: "<max_rank> <occupation string>" with one digit per
            # site; digit 3 counts two electrons, 1 or 2 count one electron
            sample_params = dic["sample_reference"].split()
            if len(sample_params) < 2:
                raise ValueError(
                    "sample_reference expects '<max_rank> <occupation string>'!")
            max_rank = int(sample_params[0])
            if len(sample_params[1]) != n_sites:
                raise ValueError(
                    "sample_reference occupation string must have %d digits, got %d!"
                    % (n_sites, len(sample_params[1])))
            sample_ref = [int(i) for i in sample_params[1]]
            nelec_ref = sum(2 if occ == 3 else 1 if occ in (1, 2) else 0
                            for occ in sample_ref)
            if nelec_ref != nelec[0]:
                raise ValueError(
                    "sample_reference has %d electrons, expected %d!"
                    % (nelec_ref, nelec[0]))

        tx = time.perf_counter()
        dtrie = DeterminantTRIE(n_sites, True)
        mps.info.load_mutable()
        mps.load_mutable()
        dtrie.evaluate(UnfusedMPS(mps), sample_cutoff, max_rank, VectorUInt8(sample_ref))

        if "sample_phase" in dic:
            # Site ordering used by convert_phase: an empty value uses orb_idx
            # (presumably the orbital reordering, if any); otherwise the value
            # must be an explicit permutation of 0 .. n_sites - 1.
            ref_idx = dic["sample_phase"].split()
            if len(ref_idx) == 0:
                ref_idx = [] if orb_idx is None else orb_idx
            else:
                ref_idx = [int(i) for i in ref_idx]
                if len(ref_idx) != n_sites or set(ref_idx) != set(range(n_sites)):
                    raise ValueError(
                        "sample_phase must be a permutation of 0 .. %d!" % (n_sites - 1))
            dtrie.convert_phase(VectorInt(ref_idx))
        _print("dtrie finished", time.perf_counter() - tx)

        if MPI is None or MPI.rank == 0:
            # Report and save the sampled CSF/DET coefficients (rank 0 only).
            dname = "CSF" if su2mps else "DET"
            _print("Number of %s = %10d (cutoff = %9.5g, max_rank = %3d)" %
                   (dname, len(dtrie), sample_cutoff, max_rank))
            # one display character per local site state value in dtrie[i]
            ddstr = "0+-2" if su2mps else "0ab2"
            dvals = np.array(dtrie.vals)
            # indices of the (at most) 50 largest-|coefficient| entries
            gidx = np.argsort(np.abs(dvals))[::-1][:50]
            _print("Sum of weights of sampled %s = %20.15f\n" %
                   (dname, (dvals ** 2).sum()))
            for ii, idx in enumerate(gidx):
                det = ''.join([ddstr[x] for x in dtrie[idx]])
                val = dvals[idx]
                _print(dname, "%10d" % ii, det, " = %20.15f" % val)
            if len(dvals) > 50:
                _print(" ... and more ... ")
            np.save(scratch + "/sample-vals.npy", dvals)
            dets = np.zeros((len(dtrie), n_sites), dtype=np.uint8)
            for i in range(len(dtrie)):
                dets[i] = np.array(dtrie[i])
            np.save(scratch + "/sample-dets.npy", dets)

            nq = 4 if "use_general_spin" not in dic else 2
            state_occ = np.array(
                dtrie.get_state_occupation()).reshape(n_sites, nq)
            # Normalize each site's row to sum to 1. A row summing to 0 (e.g.
            # nothing sampled) previously produced NaN via 1 / 0; such rows
            # are now left at 0. Nonzero rows get the same x * (1 / s) result.
            state_occ_sum = state_occ.sum(axis=1)
            state_occ_inv = np.zeros_like(state_occ_sum)
            np.divide(1, state_occ_sum, out=state_occ_inv,
                      where=state_occ_sum != 0)
            state_occ *= state_occ_inv[:, None]
            _print("STATE OCC = ", "".join(
                ["%8.5f" % x for x in state_occ.flatten()]))
            np.save(scratch + "/sample-stocc.npy", state_occ)
