#!/usr/bin/env python3

import shlex
from functools import reduce
from pathlib import Path
from typing import List, Optional

import delegator
import fire
import pandas as pd

from snpScore import VAR_DENSITY_PLOT

# Column names looked up in the per-sample variant tables.
# CHROM is presumably the chromosome column -- confirm against the pickle
# schema produced upstream.
CHROM = 'Chr'
# POS is the position column that chrom_bin_snp_number_df bins with pd.cut.
POS = 'Pos'


class VaCountTable:
    """Column names of the windowed variant-count table."""

    # Chromosome a window belongs to.
    CHROM = "chrom"
    # Left edge of the window (taken from the pd.cut interval).
    START = "start"
    # Right edge of the window (taken from the pd.cut interval).
    END = "end"
    # Generic name for a count column.
    COUNT = "variant_count"


def chrom_bin_snp_number_df(start: int, window: int, chr_len: int, chrom: str,
                            label: str, df: pd.DataFrame) -> pd.DataFrame:
    """Count the variants of one chromosome in consecutive fixed-size windows.

    Args:
        start: Position of the first bin edge.
        window: Width of every bin.
        chr_len: Chromosome length; bin edges are generated up to
            ``chr_len + window`` (exclusive) so the chromosome end is covered.
        chrom: Chromosome name written into every output row.
        label: Name given to the count column.
        df: Variant table; only its ``POS`` column is read.

    Returns:
        A frame with the columns chrom, start, end and ``label`` (one row
        per bin, ordered by bin), indexed 0..n-1.
    """
    edges = range(start, chr_len + window, window)
    # value_counts on the categorical result of pd.cut reports every bin;
    # sort_index puts the bins back into positional order.
    per_bin = pd.cut(df[POS], edges).value_counts().sort_index()
    table = per_bin.to_frame(name=label)

    intervals = table.index
    table[VaCountTable.CHROM] = chrom
    table[VaCountTable.START] = [interval.left for interval in intervals]
    table[VaCountTable.END] = [interval.right for interval in intervals]

    ordered_columns = [
        VaCountTable.CHROM,
        VaCountTable.START,
        VaCountTable.END,
        label,
    ]
    return table[ordered_columns].reset_index(drop=True)


def var_density_stats(
    chrom: str,
    label: str,
    chr_size: int,
    var_df: pd.DataFrame,
    window: int = 1000 * 1000,
    step: Optional[int] = None,
) -> pd.DataFrame:
    """Build the windowed variant-count table for one chromosome.

    The bin grid is laid out once per offset in ``range(0, window, step)``,
    so a ``step`` smaller than ``window`` yields overlapping windows.  With
    ``step=None`` the step equals the window and a single grid is used.

    Args:
        chrom: Chromosome name (converted with ``str``).
        label: Name of the count column in the result.
        chr_size: Chromosome length.
        var_df: Variant table of this chromosome.
        window: Window width.
        step: Offset between successive grids; defaults to ``window``.

    Returns:
        The per-offset tables concatenated in offset order (indices are not
        reset across the pieces).
    """
    stride = window if step is None else step
    tables = [
        chrom_bin_snp_number_df(
            start=offset,
            window=window,
            chr_len=chr_size,
            chrom=str(chrom),
            label=label,
            df=var_df,
        )
        for offset in range(0, window, stride)
    ]
    return pd.concat(tables)


def stats_plot(stats_file: Path) -> None:
    """Run the variant-density R plotting script on a statistics file.

    The output prefix handed to the script is ``stats_file`` with its suffix
    replaced by ``.plot``.

    Args:
        stats_file: Path of the variant-density table to plot.
    """
    plot_prefix = stats_file.with_suffix('.plot')
    # The command is passed to delegator as one string, so every path is
    # quoted: spaces or shell metacharacters in the output directory would
    # otherwise split or alter the command.
    plot_cmd = (f"Rscript {shlex.quote(str(VAR_DENSITY_PLOT))} "
                f"--var_density_file {shlex.quote(str(stats_file))} "
                f"--out_prefix {shlex.quote(str(plot_prefix))}")
    delegator.run(plot_cmd)


def va_density_compare(samples: str,
                       va_dir: str,
                       chr_size: str,
                       outdir: str,
                       min_depth: int = 1,
                       min_alt_freq: float = 0,
                       window: int = 1000_000,
                       step: Optional[int] = None):
    """Compare windowed variant density across samples and plot the result.

    For every sample, ``<va_dir>/<sample>.pkl`` is loaded, variants are
    filtered by depth and alternative-allele frequency, and the remaining
    variants are counted per window on each chromosome.  The per-sample
    tables are merged on (chrom, start, end), written to
    ``<outdir>/variant-density.csv`` and handed to ``stats_plot``.

    Args:
        samples: Comma-separated sample names.  A list/tuple of names is
            accepted as well, because a CLI front-end may already have
            split the value before calling this function.
        va_dir: Directory holding one ``<sample>.pkl`` per sample.
        chr_size: Tab-separated file without header: chromosome name,
            chromosome length.
        outdir: Output directory.  If it already exists the function
            returns immediately without doing anything.
        min_depth: Keep variants with ``ref_count + alt_count >= min_depth``.
        min_alt_freq: Keep variants with ``alt_count / depth > min_alt_freq``.
        window: Window width.
        step: Offset between overlapping window grids; ``None`` means
            non-overlapping windows.

    Raises:
        ValueError: If no sample name is given or a sample has no variants.
        KeyError: If a chromosome of a sample is missing from ``chr_size``.

    Note:
        Assumes each pickle holds a frame with two-level columns whose first
        level contains ``ref_count`` and ``alt_count``, and an index that
        provides the ``Chr`` and ``Pos`` fields -- TODO confirm against the
        producer of these files.
    """
    out_path = Path(outdir)
    # An existing output directory is treated as "already processed".
    if out_path.is_dir():
        return

    if isinstance(samples, str):
        raw_names = samples.split(',')
    elif isinstance(samples, (list, tuple)):
        raw_names = [str(name) for name in samples]
    else:
        raw_names = [str(samples)]
    sample_list = [name.strip() for name in raw_names if name.strip()]
    if not sample_list:
        raise ValueError("no sample names given")

    chr_df = pd.read_csv(chr_size, sep='\t', index_col=0, names=['chr_len'])
    # Chromosome names are handled as strings everywhere below, so purely
    # numeric names ("1", "2", ...) must not stay integers in the index.
    chr_df.index = chr_df.index.astype(str)

    va_path = Path(va_dir)
    all_stats_list = []
    for sample_i in sample_list:
        sample_i_file = va_path / f"{sample_i}.pkl"
        # NOTE: unpickling executes code from the file; only load pickles
        # produced by a trusted pipeline.
        sample_i_df = pd.read_pickle(sample_i_file)
        sample_i_df.columns = sample_i_df.columns.droplevel(1)
        sample_i_df = sample_i_df.reset_index()

        sample_df_list = []
        for chr_i, sample_chr_va_df in sample_i_df.groupby(CHROM):
            chrom_name = str(chr_i)
            if chrom_name not in chr_df.index:
                raise KeyError(
                    f"chromosome {chrom_name!r} of sample {sample_i!r} "
                    f"is missing from {chr_size}")
            chr_i_size = int(chr_df.loc[chrom_name, 'chr_len'])

            # assign() returns new frames, which avoids writing into
            # slices of the grouped / filtered data.
            depth = sample_chr_va_df.ref_count + sample_chr_va_df.alt_count
            sample_chr_va_df = sample_chr_va_df.assign(depth=depth)
            sample_chr_va_df = sample_chr_va_df[
                sample_chr_va_df.depth >= min_depth]
            alt_freq = sample_chr_va_df.alt_count / sample_chr_va_df.depth
            sample_chr_va_df = sample_chr_va_df.assign(alt_freq=alt_freq)
            sample_chr_va_df = sample_chr_va_df[
                sample_chr_va_df.alt_freq > min_alt_freq]

            sample_df_list.append(
                var_density_stats(chrom=chrom_name,
                                  chr_size=chr_i_size,
                                  var_df=sample_chr_va_df,
                                  window=window,
                                  step=step,
                                  label=sample_i))
        if not sample_df_list:
            raise ValueError(f"no variants found for sample {sample_i!r} "
                             f"in {sample_i_file}")
        all_stats_list.append(pd.concat(sample_df_list))

    # pd.merge defaults to an inner join: a window is kept only when it is
    # present in the table of every sample.
    stats_df = reduce(
        lambda x, y: pd.merge(
            x,
            y,
            on=[VaCountTable.CHROM, VaCountTable.START, VaCountTable.END]),
        all_stats_list,
    )
    out_path.mkdir(parents=True, exist_ok=True)
    stats_file = out_path / "variant-density.csv"
    stats_df.to_csv(stats_file, index=False)
    stats_plot(stats_file)


# Script entry point: expose va_density_compare as a command-line interface
# through Fire.
if __name__ == "__main__":
    fire.Fire(va_density_compare)
