#!/usr/bin/env python3

import sys
import fire
import numpy as np
import pandas as pd
from pathlib import PurePath, Path


def alt_allele_num(reads_inf, allele_num=2):
    """Return the alternate-allele read count from a comma-separated depth value.

    Args:
        reads_inf: Per-allele read depths as text, e.g. ``"10,5"``. The value
            is passed through ``str`` before splitting, so non-string scalars
            are accepted.
        allele_num: Largest number of comma-separated entries that is still
            parsed.

    Returns:
        The second entry as an ``int``; ``0`` when only one entry is present;
        ``np.nan`` when there are more than ``allele_num`` entries or when the
        value is missing (NaN/None, or any entry equal to ``"."``).

    Raises:
        ValueError: If an entry is neither an integer nor ``"."``.
    """
    # Missing cells arrive as NaN/None from pandas; str() would turn them into
    # "nan"/"None" and int() would then raise.
    if pd.isna(reads_inf):
        return np.nan
    fields = str(reads_inf).split(",")
    # "." is the VCF marker for a missing value; report it as NaN like the
    # over-`allele_num` case instead of crashing in int().
    if len(fields) > allele_num or "." in fields:
        return np.nan
    counts = [int(field) for field in fields]
    return counts[1] if len(counts) > 1 else 0


def ref_allele_num(reads_inf, allele_num=2):
    """Return the reference-allele read count from a comma-separated depth value.

    Args:
        reads_inf: Per-allele read depths as text, e.g. ``"10,5"``. The value
            is passed through ``str`` before splitting, so non-string scalars
            are accepted.
        allele_num: Largest number of comma-separated entries that is still
            parsed.

    Returns:
        The first entry as an ``int``; ``0`` when only one entry is present;
        ``np.nan`` when there are more than ``allele_num`` entries or when the
        value is missing (NaN/None, or any entry equal to ``"."``).

    Raises:
        ValueError: If an entry is neither an integer nor ``"."``.
    """
    # Missing cells arrive as NaN/None from pandas; str() would turn them into
    # "nan"/"None" and int() would then raise.
    if pd.isna(reads_inf):
        return np.nan
    fields = str(reads_inf).split(",")
    # "." is the VCF marker for a missing value; report it as NaN like the
    # over-`allele_num` case instead of crashing in int().
    if len(fields) > allele_num or "." in fields:
        return np.nan
    counts = [int(field) for field in fields]
    # NOTE(review): a single-entry value yields 0 rather than that entry;
    # confirm this is intended.
    return counts[0] if len(counts) > 1 else 0


def allele_depth(reads_inf, allele_num=2):
    """Return the total read depth summed over all alleles of one depth value.

    Args:
        reads_inf: Per-allele read depths as text, e.g. ``"10,5"``. The value
            is passed through ``str`` before splitting, so non-string scalars
            are accepted.
        allele_num: Largest number of comma-separated entries that is still
            parsed.

    Returns:
        The sum of all entries as an ``int``, or ``np.nan`` when there are
        more than ``allele_num`` entries or when the value is missing
        (NaN/None, or any entry equal to ``"."``).

    Raises:
        ValueError: If an entry is neither an integer nor ``"."``.
    """
    # Missing cells arrive as NaN/None from pandas; str() would turn them into
    # "nan"/"None" and int() would then raise.
    if pd.isna(reads_inf):
        return np.nan
    fields = str(reads_inf).split(",")
    # "." is the VCF marker for a missing value; report it as NaN like the
    # over-`allele_num` case instead of crashing in int().
    if len(fields) > allele_num or "." in fields:
        return np.nan
    return sum(int(field) for field in fields)


def ad_col_index(info_val):
    """Return the zero-based position of the ``AD`` key in a colon-separated string.

    Raises ``ValueError`` (from ``list.index``) when no ``AD`` key is present.
    """
    keys = info_val.split(":")
    return keys.index("AD")


def load_table_from_vcf(input_vcf, sample_id):
    """Load the allele-depth (AD) values of the first sample from a VCF.

    Args:
        input_vcf: Path or file-like object holding tab-separated VCF text;
            lines starting with ``#`` are skipped by the parser.
        sample_id: Name given to the output column that holds the AD values.

    Returns:
        A DataFrame with columns ``Chr`` (as strings), ``Pos``, ``Ref``,
        ``Alt`` and ``sample_id``. The last column holds the raw AD sub-field
        of VCF column 9 (zero-based), e.g. ``"10,5"``.

    Raises:
        ValueError: If a record's FORMAT string has no ``AD`` key.
    """
    table_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None)
    # Explicit copy so the assignments below never target a slice of table_df.
    out_df = table_df.loc[:, [0, 1, 3, 4, 9]].copy()
    out_df.columns = ['Chr', 'Pos', 'Ref', 'Alt', sample_id]
    # Plain column assignment replaces the column (and its dtype);
    # `.loc[:, col] = ...` tries to write into the existing array in place.
    out_df['Chr'] = out_df['Chr'].astype('str')
    # FORMAT (column 8) is defined per record in VCF, so resolve the AD
    # position once per distinct FORMAT string, not from the first row only.
    formats = table_df[8]
    ad_index_by_format = {fmt: ad_col_index(fmt) for fmt in formats.unique()}
    out_df[sample_id] = [
        sample.split(":")[ad_index_by_format[fmt]]
        for fmt, sample in zip(formats, table_df[9])
    ]
    return out_df


def reformat_table(table_df, sample_name):
    """Replace the raw depth column with integer ``ref_count``/``alt_count`` columns.

    Args:
        table_df: Input table; it is not modified.
        sample_name: Name of the column holding comma-separated allele depths.

    Returns:
        A new DataFrame without the ``sample_name`` column and with
        ``ref_count`` and ``alt_count`` added as integer columns. Rows that
        contain NaN in any column are dropped, which includes rows whose
        depth value could not be resolved to counts.
    """
    depth_values = table_df[sample_name]
    out_table_df = table_df.copy()
    out_table_df["ref_count"] = depth_values.map(ref_allele_num)
    out_table_df["alt_count"] = depth_values.map(alt_allele_num)
    out_table_df = out_table_df.drop(columns=[sample_name]).dropna()
    # Cast via astype on the frame: `.loc[:, col] = ints` sets values in place
    # on pandas >= 2.0, which leaves a float64 column (and "10.0" in the CSV).
    return out_table_df.astype({"ref_count": "int", "alt_count": "int"})


def df2pkl(df, csv_dir):
    """Write one CSV file per ``Chr`` group of *df* into *csv_dir*.

    The directory (including missing parents) is created if needed. Each
    group is saved as ``<Chr>.csv`` without the index column.
    """
    out_dir = Path(csv_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for chrom, chrom_df in df.groupby("Chr"):
        chrom_df.to_csv(out_dir.joinpath(f"{chrom}.csv"), index=False)


def table2pkl(table_file, csv_dir, sep="\t"):
    """Convert a delimited allele-depth table into per-chromosome CSV files.

    The table is read with ``sep`` as delimiter, and its last column is
    treated as the sample column holding the depth values.
    """
    raw_table = pd.read_csv(table_file, sep=sep)
    last_column = raw_table.columns[-1]
    df2pkl(reformat_table(raw_table, last_column), csv_dir)


def table2pkl_stdin(sample_id, csv_dir):
    """Convert a VCF piped on standard input into per-chromosome CSV files.

    Args:
        sample_id: Name used for the sample column while loading the VCF.
        csv_dir: Output directory for the per-chromosome CSV files.

    Exits with an error message when standard input is missing or is an
    interactive terminal, i.e. when nothing was piped or redirected in.
    """
    # sys.stdin is a (truthy) file object even when nothing is piped in, so
    # check for a missing or interactive stream instead of its truthiness.
    if sys.stdin is None or sys.stdin.isatty():
        sys.exit("VCF stdin input is needed!")
    table_df = load_table_from_vcf(sys.stdin, sample_id)
    sample_name = table_df.columns[-1]

    out_table_df = reformat_table(table_df, sample_name)
    df2pkl(out_table_df, csv_dir)


if __name__ == "__main__":
    # Expose the two entry points as `from_file` / `from_stdin` sub-commands.
    commands = {"from_file": table2pkl, "from_stdin": table2pkl_stdin}
    fire.Fire(commands)
