#!/usr/bin/env python3

import argparse
import numpy as np
import sys
import os.path as path
import PloidPy.process_bam as pb
import PloidPy.ploidy_model as pm
import PloidPy.plot_read_data as plot
import PloidPy.nbinom as nb


def save_tsv(aicllh, pld, out, uniform_com, bin_err):
    """Write one tab-separated statistics row per ploidy model to ``out``.

    Parameters
    ----------
    aicllh : indexable
        ``aicllh[0][i]`` is written to the ``Log_Likelihood`` column,
        ``aicllh[1][i]`` to the ``AIC`` column and ``aicllh[2][i]`` holds the
        weights of model ``i``.
    pld : sequence
        Ploidy values, one per row.
    out : file-like
        Any object with a ``write`` method (an open file, ``sys.stdout``).
    uniform_com : bool
        Adds the ``Uniform_Weight`` column when true.
    bin_err : bool
        Adds the ``Binomial_Err_Weight`` column when true.
    """
    columns = ["Ploidy", "Log_Likelihood", "AIC", "Het_Weights"]
    if uniform_com:
        columns.append("Uniform_Weight")
    if bin_err:
        columns.append("Binomial_Err_Weight")
    out.write("\t".join(columns) + "\n")

    # Number of trailing weights that get a column of their own.
    n_extra = uniform_com + bin_err
    for idx, ploidy in enumerate(pld):
        # The weights are taken from their printed form: drop the enclosing
        # brackets and split on whitespace.
        tokens = str(aicllh[2][idx])[1:-1].split()
        fields = [str(ploidy), str(aicllh[0][idx]), str(aicllh[1][idx])]
        if n_extra == 2:
            # The last weight goes under the first extra header and the
            # second-to-last under the second one.
            fields += [",".join(tokens[:-2]), tokens[-1], tokens[-2]]
        elif n_extra == 1:
            fields += [",".join(tokens[:-1]), tokens[-1]]
        else:
            fields.append(",".join(tokens))
        out.write("\t".join(fields) + "\n")

def _parse_bool(value):
    """Return True only when ``value`` spells "true" (case-insensitive)."""
    return str(value).lower() == 'true'


def _build_parser():
    """Create the PloidPy argument parser with its four subcommands."""
    desc = "Toolkit for ploidy evaluation"
    parser = argparse.ArgumentParser(prog="PloidPy", description=desc)
    subparsers = parser.add_subparsers(dest='subparser')

    # help descriptions
    bam_desc = "BAM file"
    count_desc = "file containing MAC and TRC values"
    bed_desc = "coordinates for areas to be assessed in BED format"
    qty_desc = "minimum Phred mapping quality (inclusive)"
    b_desc = "minimum base quality"

    process_bam = subparsers.add_parser("process_bam")
    process_bam.add_argument("--bam", required=True, help=bam_desc)
    process_bam.add_argument("--out", required=True, help=count_desc)
    process_bam.add_argument("--bed", default=False, help=bed_desc)
    process_bam.add_argument("--quality", default=15, type=int, help=qty_desc)
    process_bam.add_argument("--base_qual", default=15, type=int, help=b_desc)
    process_bam.add_argument("--max_np", type=int, default=200000000,
                             help="maximum NumPy array size")

    denoise = subparsers.add_parser("filter")
    denoise.add_argument("--count_file", required=True, help=count_desc)
    denoise.add_argument("--out", required=True, help="filtered count file")
    denoise.add_argument("--pthresh", default=1.0, type=float,
                         help="posterior probability threshold")

    jdist = subparsers.add_parser("jdist")
    jdist.add_argument("--count_file", required=True, help=count_desc)
    jdist.add_argument("--out", required=True, help="output image")

    assess = subparsers.add_parser("assess")
    assess.add_argument("--count_file", required=True, help=count_desc)
    assess.add_argument("--min_cov", type=int, default=1,
                        help="minimum TRC to be evaluated")
    assess.add_argument("--max_cov", type=int, default=0,
                        help="maximum TRC to be evaluated")
    assess.add_argument("--ploidies", nargs='+', type=int, required=True,
                        help="Ploidy models to be compared.")
    assess.add_argument("--out", default=None,
                        help="statistics table in tab-separated format")
    assess.add_argument("--uniform_weight", default=False, type=_parse_bool,
                        help="[bool] True if weights are set to uniform")
    assess.add_argument("--uniform_com", default=True, type=_parse_bool,
                        help="""[bool] True if the uniform component of the
                        mixture model is included""")
    return parser


def _run_process_bam(args):
    """Run the ``process_bam`` subcommand: hand the BAM file to ``pb``."""
    print("Now processing BAM file %s..." % args.bam)
    if args.bed:
        print("\tprocessing subset found in %s" % args.bed)
    pb.get_biallelic_coverage(args.bam, args.out, args.bed, args.quality,
                              args.base_qual, args.max_np)
    print("Success!")


def _run_filter(args):
    """Run the ``filter`` subcommand and save the result as a ``.gz`` file.

    Raises IOError when the (``.gz``-suffixed) output path already exists.
    """
    print("Filtering count file %s..." % args.count_file)
    out = args.out if args.out.endswith(".gz") else args.out + ".gz"
    if path.exists(out):
        raise IOError("File %s exists. Please choose another output name."
                      % out)
    filtered = pb.denoise_reads(args.count_file, args.pthresh)[1]
    np.savetxt(out, filtered, fmt='%d')
    print("Filtered data sucessfully saved in %s" % out)


def _run_jdist(args):
    """Run the ``jdist`` subcommand: plot the counts to ``<out>.pdf``."""
    cnts = np.loadtxt(args.count_file)
    plot.plot_joint_dist(cnts, args.out + ".pdf")
    print("Files saved in %s.pdf!" % args.out)


def _read_error_rate(count_file):
    """Return the ``p_err`` value stored next to ``count_file``, or 0.

    The companion file is ``count_file`` with a trailing ``.gz`` removed and
    ``.info`` appended. When it exists, the second tab-separated field of its
    first line is returned as a float.
    """
    stem = count_file[:-3] if count_file.endswith(".gz") else count_file
    info_file = stem + ".info"
    if not path.exists(info_file):
        return 0
    # Context manager so the handle is released even if parsing fails.
    with open(info_file) as info:
        return float(info.readline().strip().split("\t")[1])


def _run_assess(args):
    """Run the ``assess`` subcommand: compare ploidy models and report.

    The statistics table goes to ``args.out`` when given, otherwise to
    standard output. Raises IOError when ``args.out`` already exists.
    """
    # Refuse to overwrite before doing any work, so an unusable output name
    # does not cost the user a full model fit.
    if args.out and path.exists(args.out):
        raise IOError("File %s exists. Please choose another output name."
                      % args.out)

    cnts = np.loadtxt(args.count_file)
    coverage = cnts[:, 1]
    keep = coverage >= args.min_cov
    if args.max_cov != 0:
        # A max_cov of 0 means "no upper bound".
        keep = np.logical_and(keep, coverage <= args.max_cov)
    cnts = cnts[keep]

    r, p_nb = nb.fit_nbinom(cnts[:, 1])
    pld = np.array(args.ploidies)
    p_err = _read_error_rate(args.count_file)
    aicllh = pm.get_Log_Likelihood_AIC(cnts, pld, r, p_nb, p_err,
                                       args.uniform_weight,
                                       args.uniform_com)

    bin_err = p_err != 0
    if args.out:
        with open(args.out, "w") as f:
            save_tsv(aicllh, pld, f, args.uniform_com, bin_err)
    else:
        save_tsv(aicllh, pld, sys.stdout, args.uniform_com, bin_err)
    print("The most likely model is %s-ploid." %
          pm.select_model(aicllh[1], pld))


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and run the subcommand.

    Exits with status 2 and a usage message on standard error when no
    subcommand is supplied.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    handlers = {
        'process_bam': _run_process_bam,
        'filter': _run_filter,
        'jdist': _run_jdist,
        'assess': _run_assess,
    }
    handler = handlers.get(args.subparser)
    if handler is None:
        # parser.error() writes to stderr and raises SystemExit(2), so a
        # missing subcommand is visible to calling scripts as a failure.
        parser.error("one of the following subcommands is required: " +
                     "process_bam, filter, jdist, assess")
    handler(args)


if __name__ == '__main__':
    main()
