#!/usr/bin/env python

import ulysses as uls


if __name__=="__main__":
    import optparse
    import os
    import sys

    # Command-line interface. No module docstring is visible in this file, so
    # __doc__ is None here and optparse falls back to its default usage line.
    op = optparse.OptionParser(usage=__doc__)
    op.add_option("-o", "--output",    dest="OUTPUT",      default="nestout", type=str, help="Prefix for outputs (default: %default)")
    op.add_option("-v", "--debug",     dest="DEBUG",       default=False, action="store_true", help="Turn on some debug messages")
    op.add_option("-q", "--quiet",     dest="QUIET",       default=False, action="store_true", help="Turn off messages")
    op.add_option("-m", "--model",     dest="MODEL",       default="1DME", help="Selection of of model (default: %default)")
    op.add_option("--zrange",          dest="ZRANGE",      default="0.1,100,1000", help="Ranges and steps of the evolution variable (default: %default)")
    op.add_option("--inv",             dest="INVORDERING", default=False, action='store_true', help="Use inverted mass ordering (default: %default)")
    op.add_option("--loop",            dest="LOOP",        default=False, action='store_true', help="Use loop-corrected Yukawa (default: %default)")
    op.add_option("--sigma",           dest="SIGMA",       default=1, type=int,              help="Switch to inflate the error in likelihood (default: %default)")
    op.add_option("--mn-seed",         dest="SEED",        default=-1, type=int,              help="Multinest seed (default: %default)")
    op.add_option("--mn-resume",       dest="RESUME",      default=False, action='store_true', help="Resume on previous run.")
    op.add_option("--mn-multi-mod",    dest="MULTIMODE",   default=False, action='store_true', help="Set multimodal to true.")
    op.add_option("--mn-update",       dest="UPDATE",      default=1000, type=int, help="Update inteval (default: %default iterations)")
    op.add_option("--mn-tol",          dest="TOL",         default=0.5, type=float, help="Evidence tolerance (default: %default)")
    op.add_option("--mn-eff",          dest="EFF",         default=0.8, type=float, help="Sampling efficiency (default: %default)")
    op.add_option("--mn-points",       dest="POINTS",      default=400, type=int,              help="Number of live points in PyMultinest (default: %default)")
    op.add_option("--mn-imax",         dest="ITMAX",       default=0, type=int, help="Max number of iterations PyMultinest, 0 is infinite (default: %default)")
    op.add_option("--mn-multimodal",   dest="MULTIMODAL",  default=False, action='store_true', help="Run in multimodal mode.")
    op.add_option("--mn-no-importance",dest="NOIMPSAMP",   default=False, action='store_true', help="Turn off importance sampling.")
    opts, args = op.parse_args()

    # Dissect the zrange string "zmin,zmax,zsteps".
    # Malformed input is reported through op.error() (usage message, exit
    # status 2) instead of assert, because asserts are stripped under
    # `python -O` and a raw traceback is unhelpful for a command-line typo.
    try:
        zmin, zmax, zsteps = opts.ZRANGE.split(",")
        zmin = float(zmin)
        zmax = float(zmax)
        zsteps = int(zsteps)
    except ValueError:
        op.error("--zrange expects 'zmin,zmax,zsteps' (float,float,int), got '{}'".format(opts.ZRANGE))

    if not zmin < zmax:
        op.error("--zrange requires zmin < zmax, got zmin={} and zmax={}".format(zmin, zmax))
    if zsteps <= 0:
        op.error("--zrange requires a positive number of steps, got {}".format(zsteps))

    # Split the positional arguments into the parameter card (later handed to
    # uls.readConfig) and a dict that is forwarded as keyword arguments to the
    # model constructor.
    pfile, gdict = uls.tools.parseArgs(args)

    # ordering is passed as an int: 0 without --inv, 1 with --inv.
    LEPTO = uls.selectModel(opts.MODEL,
            zmin=zmin, zmax=zmax, zsteps=zsteps,
            ordering=int(opts.INVORDERING),
            loop=opts.LOOP,
            debug=opts.DEBUG,
            **gdict
            )

    # Read parameter card and very explicit checks on parameter names.
    # RNG holds the floating parameters (name -> range), FIX the fixed ones.
    RNG, FIX, isCasIb = uls.readConfig(pfile)
    if isCasIb:
        nsupplied = len(FIX) + len(RNG)
        if nsupplied != len(LEPTO.pnames):
            print("Error, the number of parameters needs to be {}, user supplied {}, exiting".format(len(LEPTO.pnames), nsupplied))
            sys.exit(1)

        # The name checks must run when the count is correct. They used to be
        # nested after the sys.exit(1) above and were therefore unreachable.
        # The reported file is pfile, i.e. the file that was actually read.
        for p in FIX.keys():
            if p not in LEPTO.pnames:
                print("Parameter {} in input file {} not recognised, exiting".format(p, pfile))
                sys.exit(1)

        for p in RNG.keys():
            if p not in LEPTO.pnames:
                print("Parameter {} in input file {} not recognised, exiting".format(p, pfile))
                sys.exit(1)
    else:
        print("beware!")
        # NOTE(review): the attribute name is spelled with three r's; it is
        # kept byte-for-byte -- confirm it matches what the model reads.
        LEPTO.isCasasIbarrra = False

    # Number of dimensions our problem has
    NP = len(RNG)

    if NP==0:
        print("No floating parameters specified in {}, exiting.".format(pfile))
        sys.exit(1)


    # Optional MPI setup: rank and size describe a serial run unless mpi4py
    # can be imported and queried.
    rank=0
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        size = comm.Get_size()
        rank = comm.Get_rank()
        if opts.DEBUG:
            print("[%i]/[%i] reporting for duty"%(rank, size))
    except Exception as e:
        print("MPI (mpi4py) is not available (try e.g. pip install mpi4py), proceeding serially:", e)
        size = 1

    # Output directory. Creation stays best effort: an existing directory
    # (previous run, or another process creating it first) is fine, but a
    # directory that is still missing afterwards is reported instead of being
    # silently ignored. Only OSError is caught so that KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        os.makedirs(opts.OUTPUT)
    except OSError as e:
        if not os.path.isdir(opts.OUTPUT):
            print("[%i] Warning: could not create output directory %s: %s" % (rank, opts.OUTPUT, e))

    if rank == 0: print("[%i] Writing output to"%rank, opts.OUTPUT)
    # Parameter names (floating parameters only, in the order of RNG)
    PNAMES = list(RNG.keys())

    # Scaling info: lower edge, upper edge and width of every floating
    # parameter, index-aligned with PNAMES.
    PMIN   = [x[0] for x in RNG.values()]
    PMAX   = [x[1] for x in RNG.values()]
    PLEN   = [hi - lo for lo, hi in zip(PMIN, PMAX)]

    def scaleParam(p, idx):
        """Rescale p linearly onto the range of floating parameter idx.

        Returns PMIN[idx] + p * PLEN[idx], so p=0 gives the lower edge and
        p=1 the upper edge of that parameter's range.
        """
        lower = PMIN[idx]
        width = PLEN[idx]
        return lower + p * width

    def myprior(cube, ndim, nparams):
        """Prior callback handed to pymultinest.run.

        Overwrites the first ndim entries of cube in place with their
        rescaled values (see scaleParam); nparams is unused. Returns None.
        """
        for dim in range(ndim):
            unit_value = cube[dim]
            cube[dim] = scaleParam(unit_value, dim)

    def loglike(cube, ndim, nparams):
        """Log-likelihood callback handed to pymultinest.run.

        Builds the full parameter dictionary from the fixed parameters, the
        mass ordering flag and the current point in cube, evaluates the model
        and returns a Gaussian log-likelihood of the model response around
        ydata with width yerr (scaled by --sigma). Points for which the model
        reports isPerturbative as false are vetoed with -1e101. nparams is
        unused.
        """
        import sys

        point = [cube[j] for j in range(ndim)]

        # Fixed parameters first, then the ordering flag, then the floating
        # parameters taken from the current point.
        params = FIX.copy()
        params["ordering"] = int(opts.INVORDERING)
        for idx, name in enumerate(PNAMES):
            params[name] = cube[idx]

        if opts.DEBUG:
            print( "Testing point", point)

        LEPTO.setParams(params)
        if not LEPTO.isPerturbative:
            if opts.DEBUG:
                print( "Vetoing point", point)
            # NOTE(review): presumably chosen to lie below MultiNest's
            # log-zero threshold so the point is discarded -- confirm.
            return -1e101

        response = LEPTO(params)

        if opts.DEBUG:
            print("response", response)
        sys.stdout.flush()

        # Target value and its uncertainty; presumably the measured
        # baryon-to-photon ratio -- TODO confirm the source of these numbers.
        ydata = 6.10e-10
        yerr = 0.04e-10*opts.SIGMA

        pull = (response - ydata) / yerr
        return -0.5 * pull**2

    # Run MultiNest
    try:
        import pymultinest
    except ImportError:
        msg="""
        Could not load the multinest module for python.
        MultiNest requires libmultinest.so to be present in the
        LD_LIBRARY_PATH.
        The python package is easiest installed with pip or conda. Detailed
        installation instructions can be found here:
            https://johannesbuchner.github.io/PyMultiNest/ """
        print(msg)
        sys.exit(1)


    # Both --mn-multimodal and --mn-multi-mod are documented as switching on
    # multimodal mode; the latter used to be parsed but silently ignored, so
    # either flag now enables it.
    # NOTE(review): init_MPI=False presumably relies on mpi4py having
    # initialised MPI already -- confirm.
    pymultinest.run(loglike, myprior, NP,
            importance_nested_sampling = not opts.NOIMPSAMP,
            verbose = not opts.QUIET,
            multimodal = opts.MULTIMODAL or opts.MULTIMODE,
            resume=opts.RESUME,
            n_iter_before_update=opts.UPDATE,
            evidence_tolerance=opts.TOL,
            sampling_efficiency = opts.EFF,
            init_MPI=False,
            n_live_points = opts.POINTS,
            max_iter=opts.ITMAX,
            seed=opts.SEED,
            outputfiles_basename='%s/ULSNEST'%opts.OUTPUT)

    # Query MPI again after the run so that only rank 0 does the analysis.
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        size = comm.Get_size()
        rank = comm.Get_rank()
        if rank==0: print("Rank:", rank, "Size:", size)
    except Exception:
        # Best effort: without a working mpi4py the rank set earlier is kept.
        pass

    if rank==0:
        print()
        print("Now analyzing output from {}/ULSNEST.txt".format(opts.OUTPUT))
        a = pymultinest.Analyzer(n_params = NP, outputfiles_basename='%s/ULSNEST'%opts.OUTPUT)
        a.get_data()
        try:
            s = a.get_stats()
        except Exception as e:
            # The hint names the option this script actually defines
            # (--mn-imax, where 0 means no iteration limit).
            print("There was an error accumulating statistics ({}). Try increasing the number of iterations, e.g. --mn-imax 0".format(e))
            sys.exit(1)

        from collections import OrderedDict
        # Best-fit point: floating parameters first (in PNAMES order), then
        # the fixed parameters.
        resraw = a.get_best_fit()["parameters"]
        PP = OrderedDict((pname, resraw[num]) for num, pname in enumerate(PNAMES))
        for k in FIX:
            PP[k] = FIX[k]

        LEPTO.setParams(PP)
        BESTVAL=LEPTO.EtaB

        # Assemble the best-fit parameter card. Every header line carries its
        # own newline so the optional loop-correction comment no longer ends
        # up glued to the "# Model:" line.
        out = "# Model: {}\n".format(opts.MODEL)
        if opts.LOOP:
            out += "# with loop correction to neutrino tree-level mass\n"
        if opts.INVORDERING:
            out += "# Mass hierarchy: inverted\n"
        else:
            out += "# Mass hierarchy: normal\n"

        out += "# Best fit value: %e\n" % BESTVAL
        out += "# Best fit point:\n"
        for k in PP:
            if k not in FIX:
                out += "%s %.16f\n" % (k, PP[k])

        out += "# Fixed parameters were:\n"
        # %f raises TypeError for complex values; those are written as
        # "name real imag i" instead.
        for k in FIX:
            try:
                out += "%s %f\n" % (k, FIX[k])
            except TypeError:
                out += "%s %f %f i\n" % (k, FIX[k].real, FIX[k].imag)
        with open("%sconfig.best"%a.outputfiles_basename, "w") as f:
            f.write(out)

        import json
        # store name of parameters, always useful
        with open('%sparams.json' % a.outputfiles_basename, 'w') as f:
            json.dump(PNAMES, f, indent=4)
        with open('%sparams.info' % a.outputfiles_basename, 'w') as f:
            for p in PNAMES:
                f.write("%s\n"%p)
        # store derived stats
        with open('%sstats.json' % a.outputfiles_basename, mode='w') as f:
            json.dump(s, f, indent=4)
        print()
        print("-" * 30, 'ANALYSIS', "-" * 30)
        print("Global Evidence:\n\t%.15e +- %.15e" % ( s['nested sampling global log-evidence'], s['nested sampling global log-evidence error'] ))

        print("Done!")
        sys.exit(0)
