#!/usr/bin/env python3

import json
import shlex
import sys
from pathlib import Path

import click
import delegator
import pandas as pd
from loguru import logger

from snpScore import sample_and_group_for_web
from snpScore import async_batch_sh_jobs
from snpScore import snpScoreBox, qtlSeqr, CHR_SIZE, CHR_WINDOW
from snpScore import outdir_suffix_from_params
from snpScore import merge_split_file, score_plot, wrap_param_arg
from snpScore import circos_suffix, circos_cfg, circos_plot
from snpScore import split_qtlseqr_results, snp_density_stats
from snpScore import (
    add_default_params,
    params_cfg,
    is_new_cmd,
    cp_if_not_exist,
    make_chr_window,
)
from snpScore import VarScoreOutDirName
from snpScore import VarScoreDocName

# Command name of the per-chromosome worker; main() builds one
# "<BY_CHR_BIN> <forwarded args> --chrom <name>" command line per chromosome.
BY_CHR_BIN = "snpScore-bychr"


@click.command()
@click.option(
    "-p",
    "--parameters",
    help="snpScore parameters json string.",
    required=True,
    type=click.STRING,
)
@click.option(
    "--vcf_dir",
    help=(
        "vcf table directory, can be more than one " "[--vcf_dir dir1 --vcf_dir dir2]."
    ),
    type=click.Path(exists=True, file_okay=False),
    multiple=True,
)
@click.option(
    "--snpeff_cfg",
    help="snpEff config file.",
    type=click.Path(exists=True, dir_okay=False),
)
@click.option("--snpeff_db", help="snpEff database name.", type=click.STRING)
@click.option(
    "--chr_size",
    help="chr size file, default is chr size of bread wheat.",
    type=click.Path(exists=True, dir_okay=False),
    default=CHR_SIZE,
)
@click.option(
    "-o", "--outdir", help="results directory.", required=True, type=click.Path()
)
@click.option("-t", "--thread", help="paralle number", default=4, type=click.INT)
@click.option("--circos", help="generate circos plot", is_flag=True)
@click.option("--plant", help="for general plant", is_flag=True)
def main(
    parameters, vcf_dir, snpeff_cfg, snpeff_db, chr_size, outdir, thread, circos, plant
):
    """Run the per-chromosome score jobs, then merge, plot and package results.

    The final results directory is printed on stdout.
    \f

    Args:
        parameters: JSON string with the analysis parameters.
        vcf_dir: vcf table directories. Not read here; the raw command line
            is forwarded to the per-chromosome jobs.
        snpeff_cfg: snpEff config file, handed to ``snpScoreBox``.
        snpeff_db: snpEff database name, handed to ``snpScoreBox``.
        chr_size: Tab-separated, header-less file whose first column holds
            the chromosome names; one job is launched per row.
        outdir: Base results directory (resolved to an absolute path).
        thread: Passed as ``thread`` to ``async_batch_sh_jobs``.
        circos: When set, also prepare and (if missing) draw the circos plot.
        plant: When set, results go directly into ``outdir``, chromosome
            windows are built from the ``snp_density_window`` /
            ``snp_density_step`` parameters, and the download directory is
            zipped. Otherwise a parameter-derived sub-directory and the
            packaged ``CHR_WINDOW`` are used.
    """
    chr_size_df = pd.read_csv(chr_size, sep="\t", header=None, index_col=0)

    parameters_obj = json.loads(parameters)
    # Only the group list is needed below; the sample list is discarded.
    _sample_list, group_list = sample_and_group_for_web(parameters_obj)

    input_params = add_default_params(parameters_obj)

    outdir = Path(outdir).resolve()
    if plant:
        full_outdir = outdir
    else:
        full_outdir = outdir / outdir_suffix_from_params(parameters_obj)
    results_dir = full_outdir / "analysis"
    cmd_history_dir = results_dir / "history"
    download_dir = full_outdir / "varBscore-results"

    # Nothing to do if the same command was already run: just report where
    # its results are.
    if not is_new_cmd(input_params, cmd_history_dir):
        print(full_outdir)
        return

    # The forwarded argument line is identical for every chromosome, so it is
    # built once. "--circos" is stripped before forwarding: the circos plot is
    # handled here, after the per-chromosome results are merged.
    arg_line = " ".join(wrap_param_arg(sys.argv[1:])).replace("--circos", "")
    cmd_list = [
        f"{BY_CHR_BIN} {arg_line} --chrom {chr_i}" for chr_i in chr_size_df.index
    ]

    logger.info("Calculating score ...")
    async_batch_sh_jobs(cmd_list, thread=thread)
    logger.info("Ploting score ...")
    snp_score_methods = ["var"]

    # Built without an input table (alt_freq_df=None); below it is only used
    # for its score_prefix / group_label naming attributes.
    snpscore_obj = snpScoreBox(
        alt_freq_df=None,
        snpEff_cfg=snpeff_cfg,
        snpEff_db=snpeff_db,
        grp_list=group_list,
        method_list=snp_score_methods,
        outdir=results_dir,
        chr_size=chr_size,
        min_depth=input_params.get("min_depth"),
        snp_number_window=input_params.get("snp_number_window"),
        snp_number_step=input_params.get("snp_number_step"),
        ref_freq=input_params.get("ref_freq"),
        p_ref_freq=input_params.get("p_ref_freq"),
        background_ref_freq=input_params.get("background_ref_freq"),
        mutant_alt_exp=input_params.get("mutant_alt_exp"),
        wild_alt_exp=input_params.get("wild_alt_exp"),
        filter_method=input_params.get("filter_method", "nonsymmetrical"),
    )

    # Shell commands for all plots; entries may be falsy and are filtered out
    # before submission.
    plot_cmd = []

    # varBScore
    varBScore_dir = download_dir / VarScoreOutDirName.var_score.value
    snp_score_file_pattern = f"{snpscore_obj.score_prefix}.var.score.csv"
    snp_score_file = merge_split_file(
        results_dir, snp_score_file_pattern, out_dir=varBScore_dir
    )
    plot_cmd.append(
        score_plot(snp_score_file, "var", snp_score_file.stem, chr_size, "web")
    )
    # The annotated score table is merged too; its path is not needed here.
    snp_score_ann_file_pattern = f"{snpscore_obj.score_prefix}.var.score.ann.csv"
    merge_split_file(results_dir, snp_score_ann_file_pattern, out_dir=varBScore_dir)

    # SNP DENSITY
    snp_freq_dir = download_dir / VarScoreOutDirName.snp_density.value
    snp_freq_file = merge_split_file(
        results_dir, f"{snpscore_obj.group_label}.snp.freq.csv", out_dir=snp_freq_dir
    )

    # SNP DENSITY bed
    snp_freq_bed = merge_split_file(
        results_dir,
        f"{snpscore_obj.group_label}.snp.plot.bed",
        input_header=None,
        input_sep="\t",
        out_header=None,
        out_sep="\t",
    )

    snp_freq_stats = snp_freq_file.with_suffix(".stats.csv")
    if plant:
        snp_density_window = input_params.get("snp_density_window")
        snp_density_step = input_params.get("snp_density_step")
        chr_window = make_chr_window(
            chr_size, snp_density_window, snp_density_step, results_dir
        )
    else:
        chr_window = CHR_WINDOW

    snp_density_stats(chr_window, snp_freq_bed, snp_freq_stats)
    plot_cmd.append(
        score_plot(snp_freq_stats, "density-new", snp_freq_stats.stem, chr_size, "web")
    )

    # QTLSEQR
    # input_table is a placeholder value; the object is only queried for its
    # file names (filePath, qtlseqrFileName, edFileName) below.
    qtlseqr_obj = qtlSeqr(
        input_table="test",
        window=input_params.get("qtlseqr_window", 1e7),
        ref_freq=input_params.get("qtlseqr_ref_freq"),
        pop_stru=input_params.get("pop_stru"),
        min_sample_dp=input_params.get("qtlseqr_min_depth"),
        out_dir=results_dir,
        run_qtlseqr=input_params.get("qtlseqr"),
        run_ed=input_params.get("ed"),
        web=True,
    )
    qtlseqr_file = merge_split_file(results_dir, qtlseqr_obj.filePath.name)
    qtlSeqrDir = download_dir / VarScoreOutDirName.qtlseqr.value
    qtlSeqrDir.mkdir(parents=True, exist_ok=True)
    edDir = download_dir / VarScoreOutDirName.ed.value
    edDir.mkdir(parents=True, exist_ok=True)

    qtlSeqrDirFile = qtlSeqrDir / qtlseqr_obj.qtlseqrFileName
    edFile = edDir / qtlseqr_obj.edFileName

    split_qtlseqr_results(qtlseqr_file, qtlSeqrDirFile, edFile)
    plot_cmd.append(
        score_plot(qtlSeqrDirFile, "Gprime", qtlSeqrDirFile.stem, chr_size, "web")
    )
    plot_cmd.append(score_plot(edFile, "ED", edFile.stem, chr_size, "web"))
    plot_cmd.append(
        score_plot(qtlSeqrDirFile, "snpIndex", qtlSeqrDirFile.stem, chr_size, "web")
    )
    if circos:
        circos_name = circos_suffix(
            snpscore_obj.score_prefix, qtlseqr_obj.filePath.stem
        )
        circos_dir = results_dir / "circos_data" / circos_name
        circos_outdir = download_dir / VarScoreOutDirName.circos.value
        circos_outdir.mkdir(parents=True, exist_ok=True)
        cp_if_not_exist(VarScoreDocName.circos.value, circos_outdir)
        circos_plot_path = circos_cfg(circos_dir, circos_path=circos_outdir)
        # Only schedule the circos job when its plot does not exist yet.
        if not circos_plot_path.is_file():
            plot_cmd.append(
                circos_plot(snp_score_file, qtlseqr_file, snp_freq_file, circos_dir)
            )
    plot_cmd = list(filter(None, plot_cmd))
    async_batch_sh_jobs(plot_cmd, thread=thread)

    # cp readme
    cp_if_not_exist(VarScoreDocName.var_score.value, varBScore_dir)
    cp_if_not_exist(VarScoreDocName.snp_density.value, snp_freq_dir)
    cp_if_not_exist(VarScoreDocName.qtlseqr.value, qtlSeqrDir)
    cp_if_not_exist(VarScoreDocName.ed.value, edDir)

    # write params config
    params_cfg_file = download_dir / "parameters.txt"
    params_cfg(params_cfg_file, input_params, cmd_history_dir)

    # compress results
    if plant:
        # Quote the directory so paths with spaces or shell metacharacters
        # work, and chain with "&&" so zip never runs in the wrong directory
        # when the cd fails.
        zip_cmd = (
            f"cd {shlex.quote(str(outdir))} && "
            f"zip -r {download_dir.name}.zip {download_dir.name}"
        )
        zip_result = delegator.run(zip_cmd)
        # Packaging stays best-effort, but a failure is reported instead of
        # being silently dropped.
        if zip_result.return_code != 0:
            logger.error(
                "Compressing results failed (exit code {}): {}",
                zip_result.return_code,
                zip_result.err,
            )
    logger.info("Fin.")
    print(full_outdir)


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
