#!/usr/bin/env python
import os
import sys
import logging
from copy import deepcopy
from functools import partial
import numpy as np
import pandas as pd

import torch
import pyro
import pyro.infer
import pyro.optim
import pickle as pkl

import bean.model.model as m
from bean.model.readwrite import write_result_table
from bean.preprocessing.data_class import (
    VariantSortingScreenData,
    VariantSortingReporterScreenData,
    TilingSortingReporterScreenData,
)
from bean.preprocessing.utils import (
    _obtain_effective_edit_rate,
    _obtain_n_guides_alleles_per_variant,
)

import bean as be
from bean.model.utils import (
    run_inference,
    _get_guide_target_info,
    parse_args,
    check_args,
)

# Send every record at INFO level or above to stderr, using a timestamped
# layout that puts the message on its own indented line.
logging.basicConfig(
    stream=sys.stderr,
    # ``filemode`` is only consulted when ``filename`` is given, so it has no
    # effect alongside ``stream``.
    filemode="w",
    level=logging.INFO,
    datefmt="%a, %d %b %Y %H:%M:%S",
    format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n",
)
# Short module-level aliases for the root-logger helpers. Note that ``error``
# logs at CRITICAL level.
error, warn, debug, info = (
    logging.critical,
    logging.warning,
    logging.debug,
    logging.info,
)
# Seed Pyro's random number generators with a fixed value at import time.
pyro.set_rng_seed(101)

# Model label (as returned by ``identify_model_guide``) -> data class used to
# wrap the screen for that model. Entries are grouped by data class; insertion
# order is the same as listing every label individually.
DATACLASS_DICT = {
    **dict.fromkeys(
        (
            "Normal",
            "MixtureNormal",
            "_MixtureNormal+Acc",  # TODO: old
            "MixtureNormal+Acc",
        ),
        VariantSortingReporterScreenData,
    ),
    "MixtureNormalConstPi": VariantSortingScreenData,
    **dict.fromkeys(
        ("MultiMixtureNormal", "MultiMixtureNormal+Acc"),
        TilingSortingReporterScreenData,
    ),
}


def identify_model_guide(args):
    """Pick the model label, model and variational guide for the requested run.

    Args:
        args: Parsed argument namespace. Reads ``mode``, ``scale_by_acc``,
            ``ignore_bcmatch``, ``dont_fit_noise``, ``uniform_edit``,
            ``const_pi`` and ``guide_activity_col``.

    Returns:
        A ``(label, model, guide)`` tuple. ``label`` is the model name used as
        a key of ``DATACLASS_DICT`` and in output file names; ``model`` and
        ``guide`` are callables taken from ``bean.model.model``, wrapped in
        ``functools.partial`` when options are pre-bound.

    Raises:
        ValueError: If ``uniform_edit`` is combined with a guide activity
            column, or if ``const_pi`` is requested without one.
    """
    # Use ``not`` rather than ``~``: bitwise inversion of a Python bool yields
    # -1 or -2, which are both truthy, so the flags could never be disabled.
    use_bcmatch = not args.ignore_bcmatch

    if args.mode == "tiling":
        info("Using Mixture Normal model...")
        return (
            f"MultiMixtureNormal{'+Acc' if args.scale_by_acc else ''}",
            partial(
                m.MultiMixtureNormalModel,
                scale_by_accessibility=args.scale_by_acc,
                use_bcmatch=use_bcmatch,
            ),
            partial(
                m.MultiMixtureNormalGuide,
                scale_by_accessibility=args.scale_by_acc,
                fit_noise=not args.dont_fit_noise,
            ),
        )

    if args.uniform_edit:
        if args.guide_activity_col is not None:
            raise ValueError(
                "Can't use the guide activity column while constraining uniform edit."
            )
        info("Using Normal model...")
        return (
            "Normal",
            partial(m.NormalModel, use_bcmatch=use_bcmatch),
            m.NormalGuide,
        )

    if args.const_pi:
        # The activity column supplies the constant pi, so it must be present.
        # The previous check was inverted and rejected runs that provided it.
        if args.guide_activity_col is None:
            raise ValueError(
                "--guide-activity-col to be used as constant pi is not provided."
            )
        info("Using Mixture Normal model with constant weight ...")
        return (
            "MixtureNormalConstPi",
            partial(m.MixtureNormalConstPiModel, use_bcmatch=use_bcmatch),
            m.MixtureNormalGuide,
        )

    info(
        f"Using Mixture Normal model {'with accessibility normalization' if args.scale_by_acc else ''}..."
    )
    return (
        f"{'_' if args.dont_fit_noise else ''}MixtureNormal{'+Acc' if args.scale_by_acc else ''}",
        partial(
            m.MixtureNormalModel,
            scale_by_accessibility=args.scale_by_acc,
            use_bcmatch=use_bcmatch,
        ),
        partial(
            m.MixtureNormalGuide,
            scale_by_accessibility=args.scale_by_acc,
            fit_noise=not args.dont_fit_noise,
        ),
    )


def _negctrl_row_idx(info_df, args):
    """Return positional indices of the negative-control rows in ``info_df``.

    A row counts as a negative control when its ``args.negctrl_col`` entry
    equals ``args.negctrl_col_value``, compared case-insensitively.
    """
    labels = info_df[args.negctrl_col].map(lambda s: s.lower())
    return np.where(labels == args.negctrl_col_value.lower())[0]


def main(args, bdata):
    """Fit the selected model to a screen and write the results to disk.

    Outputs go to the directory
    ``{args.outdir}/bean_run_result.{basename of args.bdata_path, last extension stripped}``:
    a pickle of the parameter history plus whatever ``write_result_table``
    emits.

    Args:
        args: Parsed command-line namespace (see ``parse_args`` / ``check_args``).
        bdata: Screen object, e.g. as loaded by ``be.read_h5ad``. Its
            ``samples`` and ``guides`` tables are modified during
            preprocessing.

    Raises:
        ValueError: If ``args.mode == "variant"`` and the target column has
            null values, or if ``identify_model_guide`` rejects the argument
            combination.
    """
    if args.cuda:
        # NOTE(review): the visible CUDA device is hard-coded to "1" — confirm
        # this is intended rather than something that should be configurable.
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    prefix = (
        args.outdir
        + "/bean_run_result."
        + os.path.basename(args.bdata_path).rsplit(".", 1)[0]
    )
    os.makedirs(prefix, exist_ok=True)
    model_label, model, guide = identify_model_guide(args)
    # Captured before the guides are re-sorted by target below.
    guide_index = bdata.guides.index
    info("Done loading data. Preprocessing...")
    bdata.samples["rep"] = bdata.samples["rep"].astype("category")
    # Keep only the first occurrence of any duplicated guide column.
    bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy()
    if args.mode == "variant":
        if bdata.guides[args.target_col].isnull().any():
            raise ValueError(
                f"Some target column (bdata.guides[{args.target_col}]) value is null. Check your input file."
            )
        # Order guides by their target value.
        bdata = bdata[bdata.guides[args.target_col].argsort(), :]
    ndata = DATACLASS_DICT[model_label](
        screen=bdata,
        device=args.device,
        repguide_mask=args.repguide_mask,
        sample_mask_column=args.sample_mask_col,
        accessibility_col=args.acc_col,
        accessibility_bw_path=args.acc_bw_path,
        use_const_pi=args.const_pi,
        condition_column=args.condition_col,
        allele_df_key=args.allele_df_key,
        control_guide_tag=args.control_guide_tag,
        target_col=args.target_col,
        shrink_alpha=args.shrink_alpha,
        replicate_col=args.replicate_col,
    )
    adj_negctrl_idx = None
    if args.mode == "variant":
        if "edit_rate" not in bdata.guides.columns:
            bdata.get_edit_from_allele()
            bdata.get_edit_mat_from_uns(rel_pos_is_reporter=True)
            bdata.get_guide_edit_rate()
        target_info_df = _get_guide_target_info(
            ndata.screen, args, cols_include=[args.negctrl_col]
        )
        debug(f"Target info columns: {list(target_info_df.columns)}")
        if args.adjust_confidence_by_negative_control:
            adj_negctrl_idx = _negctrl_row_idx(target_info_df, args)
    else:
        splice_site = None
        if args.splice_site_path is not None:
            splice_site = pd.read_csv(args.splice_site_path).pos
        edit_df = (
            pd.DataFrame(pd.Series(ndata.edit_index))
            .reset_index()
            .rename(columns={"index": "edit"})
        )
        target_info_df = be.an.translate_allele.annotate_edit(
            edit_df,
            splice_sites=splice_site,
        )
        target_info_df["effective_edit_rate"] = _obtain_effective_edit_rate(ndata).cpu()
        target_info_df["n_guides"] = _obtain_n_guides_alleles_per_variant(ndata).cpu()
        if args.adjust_confidence_by_negative_control:
            # Rows whose reference and alternative are identical serve as the
            # negatives here.
            adj_negctrl_idx = np.where(target_info_df.ref == target_info_df.alt)[0]

    guide_info_df = ndata.screen.guides
    del bdata

    info(f"Running inference for {model_label}...")

    if args.load_existing:
        # SECURITY: unpickling can execute arbitrary code; only load result
        # files that this pipeline produced.
        # NOTE(review): this path does not include ``args.result_suffix``,
        # unlike the path written below — confirm that is intended.
        with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle:
            param_history_dict = pkl.load(handle)
    else:
        param_history_dict = deepcopy(run_inference(model, guide, ndata))
        if args.fit_negctrl:
            negctrl_idx = _negctrl_row_idx(guide_info_df, args)
            debug(f"Fitting negative-control model on {len(negctrl_idx)} guides.")
            ndata_negctrl = ndata[negctrl_idx]
            param_history_dict["negctrl"] = run_inference(
                m.ControlNormalModel, m.ControlNormalGuide, ndata_negctrl
            )

    outfile_path = (
        f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv"
    )
    info(f"Done running inference. Writing result at {outfile_path}...")
    with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle:
        try:
            pkl.dump(param_history_dict, handle)
        except TypeError as exc:
            # Best effort: a failed pickle must not prevent the result tables
            # from being written. Python 3 exceptions have no ``.message``
            # attribute, so format the exception itself.
            warn(f"Could not pickle the parameter history: {exc}")

    guide_acc = getattr(ndata, "guide_accessibility", None)
    if guide_acc is not None:
        guide_acc = guide_acc.cpu().numpy()
    write_result_table(
        target_info_df,
        param_history_dict,
        model_label=model_label,
        prefix=f"{prefix}/",
        suffix=args.result_suffix,
        guide_index=guide_index,
        guide_acc=guide_acc,
        adjust_confidence_by_negative_control=args.adjust_confidence_by_negative_control,
        adjust_confidence_negatives=adj_negctrl_idx,
    )
    info("Done!")


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the screen, validate both
    # together, then run the fit.
    args = parse_args()
    args, bdata = check_args(args, be.read_h5ad(args.bdata_path))
    main(args, bdata)
