#!/usr/bin/env python3

import click
import pandas as pd
from loguru import logger
from pathlib import Path
from snpScore import tableFromVcf, sample_and_group
from snpScore import tableFromSelectTable
from snpScore import snpTable, async_batch_sh_jobs
from snpScore import snpFilterBox, qtlSeqr, CHR_SIZE
from snpScore import score_plot
from snpScore import cp_files
from snpScore import split_qtlseqr_results, snp_density_stats
from snpScore import VarScoreDocName, VarScoreOutDirName
from snpScore import format_outfile, make_chr_window
from snpScore import window_number_format, add_snp_ann, check_output
from snpScore import var_density_stats, var_density_file_suffix


def _collect_table_dirs(vcf_file, vcf_table, vcf_dir, vcf_split_dir, outdir,
                        thread):
    """Convert raw inputs to tables and return the directories to read.

    Args:
        vcf_file: paths of vcf files to convert with ``tableFromVcf``.
        vcf_table: paths of vcf tables to convert with
            ``tableFromSelectTable``.
        vcf_dir: directories supplied by the user; returned unchanged at
            the head of the result.
        vcf_split_dir: output directory for the converted tables, or
            ``None`` to use ``<outdir>/pickle``.
        outdir: results directory, only used to derive the default
            ``vcf_split_dir``.
        thread: thread count forwarded to the converters.

    Returns:
        A new list holding every entry of ``vcf_dir`` plus, when at least
        one file was converted, the split directory (added only once).
    """
    table_dirs = list(vcf_dir)
    if not (vcf_file or vcf_table):
        return table_dirs

    if vcf_split_dir is None:
        vcf_split_dir = Path(outdir) / 'pickle'

    for vcf_i in vcf_file:
        vcf2tb_obj = tableFromVcf(vcf=vcf_i,
                                  out_dir=vcf_split_dir,
                                  thread=thread)
        # NOTE(review): accessed, not called, exactly as before; presumably
        # ``make_table`` is a property that runs the conversion - confirm.
        vcf2tb_obj.make_table
    for st_i in vcf_table:
        vcf2tb_obj = tableFromSelectTable(vcf=st_i,
                                          out_dir=vcf_split_dir,
                                          thread=thread)
        vcf2tb_obj.make_table

    # Every conversion writes to the same directory, so register it a single
    # time instead of once per input file (and skip it if the user already
    # listed it through --vcf_dir).
    split_path = Path(vcf_split_dir)
    if all(Path(dir_i) != split_path for dir_i in table_dirs):
        table_dirs.append(vcf_split_dir)
    return table_dirs


@click.command()
@click.option('--vcf_file',
              help=('vcf file path, can be more than one '
                    '[--vcf_file vcf1 --vcf_file vcf2]'),
              type=click.Path(exists=True, dir_okay=False),
              multiple=True)
@click.option('--vcf_table',
              help=('vcf table path, can be more than one '
                    '[--vcf_table vcf1 --vcf_table vcf2]'),
              type=click.Path(exists=True, dir_okay=False),
              multiple=True)
@click.option('--vcf_dir',
              help=('vcf table directory, can be more than one '
                    '[--vcf_dir dir1 --vcf_dir dir2].'),
              type=click.Path(exists=True, file_okay=False),
              multiple=True)
@click.option('-m',
              '--mutant',
              help='mutant sample ids, separated with comma.',
              type=click.STRING,
              required=True)
@click.option('-w',
              '--wild',
              help='wild sample ids, separated with comma.',
              type=click.STRING,
              required=True)
@click.option('-o',
              '--outdir',
              help='results directory.',
              required=True,
              type=click.Path())
@click.option('-t',
              '--thread',
              help='Max thread for this program to use.',
              default=4)
@click.option('-mp',
              '--mutant_parent',
              help='mutant parent sample ids, separated with comma.',
              type=click.STRING,
              default='')
@click.option('-wp',
              '--wild_parent',
              help='wild parent sample ids, separated with comma.',
              type=click.STRING,
              default='')
@click.option('--vcf_ann_file',
              help='snp annotation file.',
              type=click.Path(exists=True, dir_okay=False),
              required=True)
@click.option('--chr_size',
              help='chr size file, default is chr size of bread wheat.',
              type=click.Path(exists=True, dir_okay=False),
              default=CHR_SIZE)
@click.option('--snp_density_window',
              help='window size to calculate snp density.',
              type=click.INT,
              default=1000_000)
@click.option('--snp_density_step',
              help='window step to calculate snp density.',
              type=click.INT,
              default=None)
@click.option('--vcf_split_dir',
              help=('Directory to save snp pickle file for each sample. '
                    'Default is outdir/pickle'),
              default=None)
@click.option('--mut_freq',
              help=('filter out mutant SNPs with a '
                    'Reference Allele Frequency '
                    'less than refAlleleFreq '
                    'and greater than 1 - refAlleleFreq.'),
              default=0.4,
              type=click.FLOAT)
@click.option('--wild_freq',
              help=('filter out wild SNPs with a '
                    'Reference Allele Frequency '
                    'less than refAlleleFreq '
                    'and greater than 1 - refAlleleFreq.'),
              default=0.4,
              type=click.FLOAT)
@click.option('--p_mut_freq',
              help=('filter out mutant parent SNPs with a '
                    'Reference Allele Frequency '
                    'less than refAlleleFreq '
                    'and greater than 1 - refAlleleFreq.'),
              default=0,
              type=click.FLOAT)
@click.option('--p_wild_freq',
              help=('filter out wild parent SNPs with a '
                    'Reference Allele Frequency '
                    'less than refAlleleFreq '
                    'and greater than 1 - refAlleleFreq.'),
              default=0,
              type=click.FLOAT)
@click.option('--afd',
              help='mutant & wild afd cutoff.',
              type=click.FLOAT,
              required=True)
@click.option('--afd_deviation',
              help='mutant & wild afd deviation.',
              type=click.FLOAT,
              default=0.05)
@click.option('--p_afd',
              help='mutant parent & wild parent afd cutoff.',
              type=click.FLOAT,
              default=1)
@click.option('--p_afd_deviation',
              help='mutant parent & wild parent afd deviation.',
              type=click.FLOAT,
              default=0.05)
@click.option('--min_depth',
              help='minimal read depth for a site to include in analysis.',
              default=5,
              type=click.INT)
def main(vcf_file, vcf_table, vcf_dir, mutant, wild, outdir, thread,
         vcf_split_dir, mutant_parent, wild_parent, mut_freq, wild_freq,
         p_mut_freq, p_wild_freq, afd, afd_deviation, p_afd, p_afd_deviation,
         min_depth, vcf_ann_file, chr_size, snp_density_window,
         snp_density_step):
    """Filter variants of the mutant/wild groups and plot their density.

    Inputs are taken from --vcf_file, --vcf_table and/or --vcf_dir; the
    formatted filter table, its windowed density table and the density plot
    are written to <outdir>/<outdir name>-results.
    """
    # NOTE: click prints the docstring above as the command description in
    # ``--help``; each option is documented by its own ``help=`` text.
    if not (vcf_file or vcf_table or vcf_dir):
        raise click.UsageError(
            'at least one of --vcf_file, --vcf_table or --vcf_dir '
            'is required.')

    table_dirs = _collect_table_dirs(vcf_file, vcf_table, vcf_dir,
                                     vcf_split_dir, outdir, thread)

    sample_list, group_list = sample_and_group(mutant, wild, mutant_parent,
                                               wild_parent)
    # Unique group labels in first-seen order; a plain ``set`` would yield a
    # different order from run to run because of string hash randomisation.
    filter_grp = list(dict.fromkeys(group_list))
    snp_table_obj = snpTable(out_dir=outdir,
                             table_dirs=table_dirs,
                             samples=sample_list,
                             sample_label=group_list,
                             min_depth=min_depth,
                             filter_dp_grp=filter_grp)

    snpFilter_obj = snpFilterBox(alt_freq_df=snp_table_obj.alt_freq_df,
                                 outdir=outdir,
                                 min_depth=min_depth,
                                 mutant_freq=mut_freq,
                                 wild_freq=wild_freq,
                                 pat_mutant_freq=p_mut_freq,
                                 pat_wild_freq=p_wild_freq,
                                 afd=afd,
                                 afd_deviation=afd_deviation,
                                 parent_afd=p_afd,
                                 parent_afd_deviation=p_afd_deviation)
    # NOTE(review): attribute access kept from the original; presumably a
    # property that triggers the filtering and writes
    # ``alt_filter_freq_file`` - confirm against snpScore.
    snpFilter_obj.alt_filter_freq_df

    # Organise the results.
    chr_df = pd.read_csv(chr_size,
                         sep='\t',
                         header=None,
                         names=['chrom', 'size'])
    outPath = Path(outdir)
    results_dir = outPath / f'{outPath.name}-results'

    # SECURITY NOTE: unpickling executes arbitrary code embedded in the file;
    # only pass annotation files that come from a trusted source.
    snp_ann_df = pd.read_pickle(vcf_ann_file)
    # Plain column assignment replaces the column together with its dtype;
    # the former ``.loc[:, col] = ...`` form sets values in place on recent
    # pandas and is deprecated when the dtype has to change.
    snp_ann_df['#CHROM'] = snp_ann_df['#CHROM'].astype('str')

    # snpDensity results
    logger.info('copy snpDensity files...')
    results_dir.mkdir(parents=True, exist_ok=True)

    fmt_var_filter_file = format_outfile(snpFilter_obj.alt_filter_freq_file,
                                         results_dir,
                                         ann_df=snp_ann_df,
                                         chr_list=chr_df.chrom.astype('str'))

    fmt_var_filter_df = pd.read_csv(fmt_var_filter_file)
    var_filter_density_df = var_density_stats(
        chr_size,
        fmt_var_filter_df,
        window=snp_density_window,
        step=snp_density_step,
    )
    var_filter_density_suffix = var_density_file_suffix(
        snp_density_window, snp_density_step)
    var_filter_density_file = results_dir / (
        f'{fmt_var_filter_file.stem}{var_filter_density_suffix}')
    var_filter_density_df.to_csv(var_filter_density_file, index=False)

    # Exactly one plot command is produced, so it is launched directly.
    plot_cmds = [
        score_plot(var_filter_density_file, 'density-new',
                   var_filter_density_file.stem, chr_size)
    ]
    logger.info('Launch plot jobs...')
    async_batch_sh_jobs(plot_cmds, thread=thread)


# Script entry point: ``main`` is a click command, so calling it without
# arguments lets click read the options from the command line.
if __name__ == '__main__':
    main()
