#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import logging
import pandas as pd
import bean as be
from bean.plotting.allele_stats import plot_n_alleles_per_guide, plot_n_guides_per_edit
import matplotlib.pyplot as plt

plt.style.use("default")
# Log INFO and above to stderr. (`filemode` is intentionally not passed:
# logging.basicConfig only honours it together with `filename`.)
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n",
    datefmt="%a, %d %b %Y %H:%M:%S",
    stream=sys.stderr,
)
# Short aliases for the root-logger helpers used throughout this script.
# NOTE: `error` is bound to logging.critical (CRITICAL level), not logging.error.
error = logging.critical
warn = logging.warning
debug = logging.debug
info = logging.info


def parse_args(argv=None):
    """Build the ``bean-filter`` command-line parser and parse the arguments.

    Prints an ASCII banner and a one-line description to stdout before parsing.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]``, which matches the
            original no-argument behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    print(
        r"""
    _ _         
  /  \ '\       __ _ _ _           
  |   \  \     / _(_) | |_ ___ _ _ 
   \   \  |   |  _| | |  _/ -_) '_|
    `.__|/    |_| |_|_|\__\___|_|  
    """
    )
    print("bean-filter: filter alleles")
    parser = argparse.ArgumentParser(
        prog="allele_filter",
        description="Filter alleles based on edit position in spacer and frequency across samples.",
    )
    parser.add_argument(
        "bdata_path",
        type=str,
        help="Input ReporterScreen file of which allele will be filtered out.",
    )
    parser.add_argument(
        "--output-prefix",
        "-o",
        type=str,
        default=None,
        help="Output prefix for log and ReporterScreen file with allele assignment",
    )
    parser.add_argument(
        "--plasmid-path",
        "-p",
        type=str,
        default=None,
        help="Plasmid ReporterScreen object path. If provided, alleles are filtered based on if a nucleotide edit is more significantly enriched in sample compared to the plasmid data. Negative control data where no edit is expected can be fed in instead of plasmid library.",
    )
    # Edit window bounds, 0-based and end-exclusive, relative to the spacer start.
    # They only take effect together with --filter-window.
    parser.add_argument(
        "--edit-start-pos",
        "-s",
        type=int,
        default=2,
        help="0-based start position (inclusive) of edit relative to the start of guide spacer.",
    )
    parser.add_argument(
        "--edit-end-pos",
        "-e",
        type=int,
        default=7,
        help="0-based end position (exclusive) of edit relative to the start of guide spacer.",
    )
    parser.add_argument(
        "--jaccard-threshold",
        "-j",
        type=float,
        help="Jaccard Index threshold when the alleles are mapped to the most similar alleles. In each filtering step, allele counts of filtered out alleles will be mapped to the most similar allele only if they have Jaccard Index of shared edit higher than this threshold.",
        default=0.3,
    )
    parser.add_argument(
        "--filter-window",
        "-w",
        help="Only consider edit within window provided by (edit-start-pos, edit-end-pos). If this flag is not provided, `--edit-start-pos` and `--edit-end-pos` flags are ignored.",
        action="store_true",
    )
    parser.add_argument(
        "--filter-target-basechange",
        "-b",
        help="Only consider target edit (stored in bdata.uns['target_base_change'])",
        action="store_true",
    )
    parser.add_argument(
        "--translate", "-t", help="Translate alleles", action="store_true"
    )
    # --translate-fasta and --translate-fastas-csv are mutually exclusive;
    # check_args enforces this.
    parser.add_argument(
        "--translate-fasta",
        "-f",
        type=str,
        help="fasta file path with exon positions. If not provided, LDLR hg19 coordinates will be used.",
        default=None,
    )
    parser.add_argument(
        "--translate-fastas-csv",
        "-fs",
        type=str,
        help=".csv with two columns with gene IDs and FASTA file path corresponding to each gene.",
        default=None,
    )
    parser.add_argument(
        "--filter-allele-proportion",
        "-ap",
        type=float,
        default=None,
        help="If provided, alleles that exceed `filter_allele_proportion` in `filter-sample-proportion` will be retained.",
    )
    parser.add_argument(
        "--filter-allele-count",
        "-ac",
        type=int,
        default=5,
        help="If provided, alleles that exceed `filter_allele_proportion` AND `filter_allele_count` in `filter-sample-proportion` will be retained.",
    )
    parser.add_argument(
        "--filter-sample-proportion",
        "-sp",
        type=float,
        default=0.2,
        help="If `filter_allele_proportion` is provided, alleles that exceed `filter_allele_proportion` in `filter-sample-proportion` will be retained.",
    )
    return parser.parse_args(argv)


def check_args(args):
    """Validate parsed arguments and fill in derived values in place.

    Mutates ``args``:
    - sets ``output_prefix`` from ``bdata_path`` (without ``.h5ad``, plus
      ``_alleleFiltered``) when it was not given;
    - when ``--filter-window`` is set, replaces a missing start/end position
      with 0 / 20 respectively.

    Args:
        args: Namespace returned by ``parse_args``.

    Raises:
        ValueError: If option values are inconsistent or out of range, or the
            ``--translate-fastas-csv`` table is empty or not two columns wide.
        FileNotFoundError: If a FASTA path listed in ``--translate-fastas-csv``
            does not exist.
    """
    if args.output_prefix is None:
        args.output_prefix = args.bdata_path.rsplit(".h5ad", 1)[0] + "_alleleFiltered"
    info(f"Saving results to {args.output_prefix}")

    if args.filter_window:
        if args.edit_start_pos is None and args.edit_end_pos is None:
            raise ValueError(
                "Invalid arguments: --filter-window option set but none of --edit-start-pos and --edit-end-pos specified."
            )
        if args.edit_start_pos is None:
            warn(
                "--filter-window option set but --edit-start-pos not provided. Using 0 as its value."
            )
            args.edit_start_pos = 0
        if args.edit_end_pos is None:
            warn(
                "--filter-window option set but --edit-end-pos not provided. Using 20 as its value."
            )
            args.edit_end_pos = 20
        # The window is [start, end); start >= end would select no position at all.
        if args.edit_start_pos >= args.edit_end_pos:
            raise ValueError(
                "Invalid arguments: --edit-start-pos should be smaller than --edit-end-pos (0-based, end-exclusive window)."
            )

    if args.filter_allele_proportion is not None and not (
        0 <= args.filter_allele_proportion <= 1
    ):
        raise ValueError(
            "Invalid arguments: filter-allele-proportion should be in range [0, 1]."
        )
    if not 0 <= args.filter_sample_proportion <= 1:
        raise ValueError(
            "Invalid arguments: filter-sample-proportion should be in range [0, 1]."
        )

    if args.translate_fasta and args.translate_fastas_csv:
        raise ValueError(
            "Invalid arguments: You can only specify one of --translate-fasta (single gene) or --translate-fastas-csv (multiple genes)."
        )
    if args.translate_fastas_csv:
        tbl = pd.read_csv(
            args.translate_fastas_csv,
            header=None,
        )
        # Compare the column *count* to 2. (Taking len() of the boolean mask
        # `tbl.columns != 2` is truthy for any non-empty table.)
        if len(tbl) == 0 or len(tbl.columns) != 2:
            raise ValueError(
                "Invalid arguments: Table should have two columns and more than 0 entry"
            )
        # The FASTA path is the second column (positional index 1).
        for path in tbl.iloc[:, 1].tolist():
            # str() so that a blank cell (read as NaN) gives FileNotFoundError, not TypeError.
            if not os.path.isfile(str(path)):
                raise FileNotFoundError(
                    f"Invalid input file: {path} does not exist. Check your input in {args.translate_fastas_csv}"
                )


if __name__ == "__main__":
    args = parse_args()
    check_args(args)
    bdata = be.read_h5ad(args.bdata_path)
    # `.uns` keys of the allele tables produced by each filtering step, in order.
    # Every step reads the table under the last key and appends its own output key.
    allele_df_keys = ["allele_counts"]
    info(
        f"Starting from .uns['allele_counts'] with {len(bdata.uns['allele_counts'])} alleles."
    )

    if args.plasmid_path is not None:
        info(
            "Filtering significantly more edited nucleotide per guide compared to plasmid library..."
        )
        plasmid_adata = be.read_h5ad(args.plasmid_path)
        plasmid_key = allele_df_keys[-1]
        plasmid_alleles = plasmid_adata.uns[plasmid_key]
        # Drop plasmid alleles whose string form is empty (presumably the
        # unedited allele) before comparing against the sample.
        plasmid_adata.uns[plasmid_key] = plasmid_alleles.loc[
            plasmid_alleles.allele.map(str) != "", :
        ]

        # q_val_each is not used further by this script.
        q_val_each, sig_allele_df = be.an.filter_alleles.filter_alleles(
            bdata, plasmid_adata, filter_each_sample=True, run_parallel=True
        )
        bdata.uns["sig_allele_counts"] = sig_allele_df.reset_index(drop=True)
        allele_df_keys.append("sig_allele_counts")
        info(f"Filtered down to {len(bdata.uns['sig_allele_counts'])} alleles.")

    if len(bdata.uns[allele_df_keys[-1]]) > 0:
        info("Filtering out edits outside spacer position...")
        spacer_key = f"{allele_df_keys[-1]}_spacer"
        # NOTE(review): this step uses a hard-coded jaccard_threshold of 0.2,
        # while the later steps use --jaccard-threshold; confirm it is intended.
        bdata.uns[spacer_key] = bdata.filter_allele_counts_by_pos(
            rel_pos_start=0,
            rel_pos_end=20,
            rel_pos_is_reporter=False,
            map_to_filtered=True,
            allele_uns_key=allele_df_keys[-1],
            jaccard_threshold=0.2,
        ).reset_index(drop=True)
        info(f"Filtered down to {len(bdata.uns[spacer_key])} alleles.")
        allele_df_keys.append(spacer_key)

    if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.filter_window:
        info(
            f"Filtering out edits based on relatvie position in spacer: 0-based [{args.edit_start_pos},{args.edit_end_pos})..."
        )
        filtered_key = f"{allele_df_keys[-1]}_{args.edit_start_pos}_{args.edit_end_pos}"
        bdata.uns[filtered_key] = bdata.filter_allele_counts_by_pos(
            rel_pos_start=args.edit_start_pos,
            rel_pos_end=args.edit_end_pos,
            rel_pos_is_reporter=False,
            map_to_filtered=True,
            allele_uns_key=allele_df_keys[-1],
            jaccard_threshold=args.jaccard_threshold,
        ).reset_index(drop=True)
        allele_df_keys.append(filtered_key)
        info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")

    if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.filter_target_basechange:
        filtered_key = (
            f"{allele_df_keys[-1]}_{bdata.base_edited_from}.{bdata.base_edited_to}"
        )
        info(f"Filtering out non-{bdata.uns['target_base_change']} edits...")
        # NOTE(review): unlike the other steps, this one passes
        # map_to_filtered=False; confirm this is intentional.
        bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base(
            bdata.base_edited_from,
            bdata.base_edited_to,
            map_to_filtered=False,
            allele_uns_key=allele_df_keys[-1],
        ).reset_index(drop=True)
        info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")
        allele_df_keys.append(filtered_key)

    if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.translate:
        if args.translate_fastas_csv:
            fasta_df = pd.read_csv(
                args.translate_fastas_csv,
                header=None,
            )
            # Column 0: gene ID, column 1: FASTA path (validated in check_args).
            fasta_dict = dict(zip(fasta_df.iloc[:, 0], fasta_df.iloc[:, 1]))
        else:
            fasta_dict = None
        info(
            "Translating alleles..."
        )  # TODO: Check & document custom fasta file for translation
        filtered_key = f"{allele_df_keys[-1]}_translated"
        bdata.uns[filtered_key] = be.translate_allele_df(
            bdata.uns[allele_df_keys[-1]],
            fasta_file=args.translate_fasta,
            fasta_file_dict=fasta_dict,
        ).rename(columns={"allele": "aa_allele"})
        allele_df_keys.append(filtered_key)
        info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")

    if (
        len(bdata.uns[allele_df_keys[-1]]) > 0
        and args.filter_allele_proportion is not None
    ):
        info(
            f"Filtering alleles for those have allele fraction {args.filter_allele_proportion} in at least {args.filter_sample_proportion*100}% of samples..."
        )
        filtered_key = f"{allele_df_keys[-1]}_prop{args.filter_allele_proportion}_{args.filter_sample_proportion}"
        bdata.uns[filtered_key] = be.an.filter_alleles.filter_allele_prop(
            bdata,
            allele_df_keys[-1],
            allele_prop_thres=args.filter_allele_proportion,
            allele_count_thres=args.filter_allele_count,
            sample_prop_thres=args.filter_sample_proportion,
            map_to_filtered=True,
            retain_max=True,
            # Second column holds the allele representation (nucleotide or, after
            # translation, amino-acid allele).
            allele_col=bdata.uns[allele_df_keys[-1]].columns[1],
            distribute=True,
            jaccard_threshold=args.jaccard_threshold,
        )
        allele_df_keys.append(filtered_key)
        info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")
    info("Done filtering!")

    info(f"Saving ReporterScreen with filtered alleles at {args.output_prefix}.h5ad...")
    bdata.write(f"{args.output_prefix}.h5ad")

    info("Plotting allele stats for each filtering step...")
    n_steps = len(allele_df_keys)
    # squeeze=False keeps `ax` 2-D (n_steps x 2) for any n_steps, so ax[i, j] is
    # always a valid index.
    fig, ax = plt.subplots(n_steps, 2, figsize=(6, 3 * n_steps), squeeze=False)
    for i, key in enumerate(allele_df_keys):
        allele_df = bdata.uns[key]
        if len(allele_df) > 0:
            allele_col = allele_df.columns[1]
            plot_n_alleles_per_guide(bdata, key, allele_col, ax[i, 0])
            plot_n_guides_per_edit(bdata, key, allele_col, ax[i, 1])
    plt.tight_layout()
    plt.savefig(f"{args.output_prefix}.filtered_allele_stats.pdf", bbox_inches="tight")
    plt.close(fig)
    info(
        f"Saving plotting result and log at {args.output_prefix}.[filtered_allele_stats.pdf, filter_log.txt]."
    )
    # One line per filtering step: "<uns key>\t<number of alleles>".
    with open(f"{args.output_prefix}.filter_log.txt", "w", encoding="utf-8") as out_log:
        out_log.writelines(f"{key}\t{len(bdata.uns[key])}\n" for key in allele_df_keys)
