#!/usr/bin/env python3

import shlex
import sys
from functools import reduce
from pathlib import Path
from random import sample
from typing import List, Optional

import delegator
import fire
import pandas as pd

from snpScore import VAR_DENSITY_PLOT

# Column names expected in the per-sample, per-chromosome variant CSVs.
# POS is the position column binned by chrom_bin_snp_number_df; CHROM is not
# referenced in this file (NOTE(review): presumably the chromosome column —
# confirm or remove).
CHROM, POS = "Chr", "Pos"


class VaCountTable:
    """Column names of the window-count tables built in this module."""

    CHROM: str = "chrom"  # chromosome name of the window
    START: str = "start"  # left bin edge (bins are right-closed, so exclusive)
    END: str = "end"  # right bin edge (inclusive)
    COUNT: str = "variant_count"  # not referenced in this file


def chrom_bin_snp_number_df(
    start: int, window: int, chr_len: int, chrom: str, label: str, df: pd.DataFrame
) -> pd.DataFrame:
    """Count the variants of one chromosome in consecutive fixed-width bins.

    Bin edges are ``start, start + window, start + 2 * window, ...`` up to and
    including the first edge that is >= ``chr_len``. Bins are right-closed
    ``(left, right]`` intervals (the ``pd.cut`` default), so a position equal
    to ``start`` is not counted. Empty bins get a count of 0.

    Args:
        start: Left edge of the first bin.
        window: Bin width.
        chr_len: Chromosome length; bins extend until they reach it.
        chrom: Value written into the chromosome column of every row.
        label: Name of the count column.
        df: Variant table with a ``POS`` column of positions.

    Returns:
        One row per bin in ascending order, with columns
        ``chrom, start, end, <label>`` and a fresh 0-based index.
    """
    edges = range(start, chr_len + window, window)
    # value_counts on the categorical produced by pd.cut keeps every bin,
    # including empty ones; sort_index restores ascending bin order.
    per_bin = pd.cut(df[POS], edges).value_counts().sort_index()
    intervals = list(per_bin.index)
    return pd.DataFrame(
        {
            VaCountTable.CHROM: chrom,
            VaCountTable.START: [interval.left for interval in intervals],
            VaCountTable.END: [interval.right for interval in intervals],
            label: per_bin.to_numpy(),
        }
    )


def var_density_stats(
    chrom: str,
    label: str,
    chr_size: int,
    var_df: pd.DataFrame,
    window: int = 1000 * 1000,
    step: Optional[int] = None,
) -> pd.DataFrame:
    """Count the variants of one chromosome in possibly overlapping windows.

    A full tiling of ``window``-wide bins is produced for every offset
    ``0, step, 2 * step, ...`` below ``window``. When ``step < window`` this
    yields sliding windows that overlap by ``window - step``. Rows are grouped
    by offset: all bins of offset 0 come first, then those of offset ``step``,
    and so on. They are not sorted by position.

    Args:
        chrom: Chromosome name written into every row.
        label: Name of the count column.
        chr_size: Chromosome length.
        var_df: Variant table passed to ``chrom_bin_snp_number_df``.
        window: Window width; must be positive.
        step: Offset between window tilings; defaults to ``window``, which
            gives non-overlapping windows. Must be positive.

    Returns:
        Concatenated ``chrom, start, end, <label>`` table with a fresh
        0-based index.

    Raises:
        ValueError: If ``window`` or ``step`` is not positive.
    """
    if step is None:
        step = window
    # Without these checks a zero step fails inside range() and a negative
    # one silently yields no tilings, so pd.concat fails with a confusing
    # "No objects to concatenate".
    if window <= 0:
        raise ValueError(f"window must be a positive integer, got {window!r}")
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step!r}")
    stats_df_list = [
        chrom_bin_snp_number_df(
            start=start,
            chr_len=chr_size,
            window=window,
            chrom=str(chrom),
            df=var_df,
            label=label,
        )
        for start in range(0, window, step)
    ]
    # Each per-offset table is indexed from 0; ignore_index avoids returning
    # duplicate index labels.
    return pd.concat(stats_df_list, ignore_index=True)


def stats_plot(stats_file: Path) -> None:
    """Plot a variant-density CSV with the R script ``VAR_DENSITY_PLOT``.

    Runs ``Rscript VAR_DENSITY_PLOT --var_density_file <stats_file>
    --out_prefix <stats_file with suffix ".plot">`` through ``delegator``.
    This is best-effort: a non-zero exit status is reported on stderr
    instead of being raised.

    Args:
        stats_file: Path of the variant-density CSV to plot.
    """
    out_prefix = stats_file.with_suffix(".plot")
    # delegator hands the command string to a shell, so quote every
    # interpolated path to survive spaces and shell metacharacters.
    plot_cmd = " ".join(
        [
            "Rscript",
            shlex.quote(str(VAR_DENSITY_PLOT)),
            "--var_density_file",
            shlex.quote(str(stats_file)),
            "--out_prefix",
            shlex.quote(str(out_prefix)),
        ]
    )
    result = delegator.run(plot_cmd)
    if result.return_code != 0:
        print(
            f"[stats_plot] plotting failed (exit {result.return_code}): {result.err}",
            file=sys.stderr,
        )


def _split_list_arg(value) -> List[str]:
    """Normalize a comma-separated CLI argument into a list of strings.

    python-fire may already have parsed ``a,b`` into a tuple (or ``1`` into an
    int), so strings, lists/tuples and scalars are all accepted.
    """
    if isinstance(value, str):
        return value.split(",")
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def _sample_label(sample_id: str, plant: bool) -> str:
    """Return the count-column label; plant mode drops the first '-' field."""
    if plant:
        return "-".join(sample_id.split("-")[1:])
    return sample_id


def _find_variant_file(
    va_paths: List[Path], sample_id: str, chrom: str
) -> Optional[Path]:
    """Return ``<dir>/<sample>/<chrom>.csv`` from the first directory that has it."""
    for va_path in va_paths:
        candidate = va_path / sample_id / f"{chrom}.csv"
        if candidate.is_file():
            return candidate
    return None


def _load_filtered_variants(
    va_file: Path, min_depth: int, min_alt_freq: float
) -> pd.DataFrame:
    """Read a variant CSV and keep sites passing the depth and alt-frequency filters.

    Adds ``depth = ref_count + alt_count`` and ``alt_freq = alt_count / depth``.
    Keeps rows with ``depth >= min_depth`` and ``alt_freq >= min_alt_freq``.
    """
    va_df = pd.read_csv(va_file)
    # assign() builds new frames instead of writing into a filtered slice,
    # which would trigger pandas' SettingWithCopyWarning.
    va_df = va_df.assign(depth=va_df.ref_count + va_df.alt_count)
    va_df = va_df[va_df.depth >= min_depth]
    va_df = va_df.assign(alt_freq=va_df.alt_count / va_df.depth)
    return va_df[va_df.alt_freq >= min_alt_freq]


def va_density_compare(
    samples: str,
    va_dir: str,
    chr_size: str,
    outdir: str,
    min_depth: int = 1,
    min_alt_freq: float = 0.1,
    window: int = 1000_000,
    step: Optional[int] = None,
    plant=False,
):
    """Compare per-window variant counts across samples and export them.

    For every sample and every chromosome listed in ``chr_size``, the first
    existing ``<va_dir entry>/<sample>/<chrom>.csv`` is read and filtered by
    depth and alt frequency. Its variants are then counted per window. The
    per-sample tables are inner-merged on chrom/start/end, so windows missing
    from any sample are dropped. The result is written to
    ``<outdir>/variant-density.csv`` and plotted, and the output directory is
    zipped into ``varDensity.zip``. Nothing is done if ``outdir`` already
    exists.

    Args:
        samples: Comma-separated sample names (a tuple is also accepted).
        va_dir: Comma-separated directories holding ``<sample>/<chrom>.csv``.
        chr_size: Tab-separated file without a header; the first column is
            the chromosome name and the second is its length.
        outdir: Output directory; it must not exist yet.
        min_depth: Minimum ``ref_count + alt_count`` for a site to count.
        min_alt_freq: Minimum ``alt_count / depth`` for a site to count.
        window: Window width.
        step: Window step; defaults to ``window``.
        plant: If True, drop the first '-'-separated field from sample names
            when labelling count columns.

    Raises:
        ValueError: If no sample names are given.
        FileNotFoundError: If a sample has no variant file for any chromosome.
    """
    out_path = Path(str(outdir))
    if out_path.is_dir():
        # An existing output directory is treated as "already done".
        print(f"{out_path} already exists; skipping.", file=sys.stderr)
        return
    sample_ids = _split_list_arg(samples)
    if not sample_ids:
        raise ValueError("no samples given")
    chr_df = pd.read_csv(chr_size, sep="\t", index_col=0, names=["chr_len"])
    va_paths = [Path(dir_i) for dir_i in _split_list_arg(va_dir)]
    all_stats_list = []
    for sample_i in sample_ids:
        sample_label = _sample_label(sample_i, plant)
        sample_df_list = []
        for chr_i in chr_df.index:
            # Use only the first directory holding this chromosome: counting it
            # twice would duplicate window keys and make the merge below
            # multiply rows.
            va_file = _find_variant_file(va_paths, sample_i, chr_i)
            if va_file is None:
                continue
            sample_df_list.append(
                var_density_stats(
                    chrom=str(chr_i),
                    chr_size=chr_df.loc[chr_i, "chr_len"],
                    var_df=_load_filtered_variants(va_file, min_depth, min_alt_freq),
                    window=window,
                    step=step,
                    label=sample_label,
                )
            )
        if not sample_df_list:
            raise FileNotFoundError(
                f"no variant files found for sample {sample_i!r} in {va_dir!r}"
            )
        all_stats_list.append(pd.concat(sample_df_list, ignore_index=True))
    stats_df = reduce(
        lambda x, y: pd.merge(
            x, y, on=[VaCountTable.CHROM, VaCountTable.START, VaCountTable.END]
        ),
        all_stats_list,
    )
    out_path.mkdir(parents=True, exist_ok=True)
    stats_file = out_path / "variant-density.csv"
    stats_df.to_csv(stats_file, index=False)
    stats_plot(stats_file)
    # Quote the directory and use && so a failed cd cannot zip the current
    # working directory instead. ./* stays unquoted so the shell expands it.
    zip_cmd = f"cd {shlex.quote(str(out_path))} && zip -r varDensity.zip ./*"
    zip_result = delegator.run(zip_cmd)
    if zip_result.return_code != 0:
        print(
            f"[va_density_compare] zip failed (exit {zip_result.return_code}): "
            f"{zip_result.err}",
            file=sys.stderr,
        )


if __name__ == "__main__":
    # Expose va_density_compare's parameters as command-line flags via python-fire.
    fire.Fire(component=va_density_compare)
