#!/usr/bin/env python

import argparse
import torch
import pandas as pd
import yaml
import torch

import os
import sys
import rootpath
import subprocess

# Make the in-repo ``astir`` package importable when this script is run from a
# source checkout: both the detected project root and its ``astir``
# sub-directory are appended to ``sys.path`` (if not already present).
#
# NOTE(review): ``rootpath.detect()`` is expected to return None when no
# project root can be found -- confirm against the rootpath documentation.
# In that case the path setup is skipped, so an installed ``astir`` can still
# be imported instead of the script failing in ``os.path.join(None, ...)``.
module_path = rootpath.detect()
if module_path is not None:
    if module_path not in sys.path:
        sys.path.append(module_path)

    module_path = os.path.join(module_path, "astir")
    if module_path not in sys.path:
        sys.path.append(module_path)


from astir import Astir
from astir.data import from_csv_yaml


# Top-level command-line parser; every mode of operation is a sub-command.
parser = argparse.ArgumentParser(description="Run astir")

subparsers = parser.add_subparsers(
    title="Criterion",
    help="specify if the cells are to be classified by types or states or if the user wants convert a SingleCellExperiment data.",
)
type_parser = subparsers.add_parser("type")
state_parser = subparsers.add_parser("state")

# Flag that exists only on the ``type`` sub-command: when given, the handler
# picks the "max" assignment mode instead of the default one.
type_parser.add_argument(
    "--prob_max",
    "-pm",
    action="store_true",
    default=False,
    help="For cell type only. Classifies celltypes according to the max probability values",
)


def parser_setup(parser):
    """Register the arguments shared by the ``type`` and ``state`` commands.

    Adds three positional arguments (expression CSV, marker YAML, output CSV)
    followed by the optional model/training settings. The parser is modified
    in place; nothing is returned.

    Args:
        parser: the ``argparse`` parser (or sub-parser) to extend.
    """
    parser.add_argument("expr_csv", help="Path to CSV expression matrix with cells as rows and proteins as columns. First column is cell ID")
    parser.add_argument("marker_yaml", help="Path to YAML file of cell markers")
    parser.add_argument("output_csv", help="Output CSV of cell assignments")
    # The string "None" is a sentinel; the command handlers translate it to
    # the Python value None before loading the data.
    parser.add_argument("--design", "-d",
                        help="Path to design matrix CSV, default to None",
                        type=str,
                        default="None")
    parser.add_argument("--random_seed", "-s",
                        help="random seed for variable initialization, default to 1234",
                        type=int,
                        default=1234)
    # The dtype is received as text and mapped to the torch dtype object by
    # the command handlers. Restricting the choices rejects unknown names at
    # parse time instead of forwarding an arbitrary string to the model. The
    # help text now states the real default (torch.float64).
    parser.add_argument("--dtype", "-t",
                        help="dtype of torch tensors, valid input include torch.float32 or torch.float64, default to torch.float64",
                        type=str,
                        choices=("torch.float32", "torch.float64"),
                        default="torch.float64")
    parser.add_argument("--max_epochs", "-m",
                        help="Number of training epochs, default to 50",
                        type=int,
                        default=50)
    parser.add_argument("--learning_rate", "-r",
                        help="learning rate, default to 1e-2",
                        type=float,
                        default=1e-2)
    parser.add_argument("--batch_size", "-b",
                        help="batch size, default to 1024",
                        type=int,
                        default=1024)
    parser.add_argument("--delta_loss", "-l",
                        type=float,
                        help="the model will stop if delta loss between epochs is reached, default to 1e-3",
                        default=1e-3)
    parser.add_argument("--n_init", "-n",
                        help="number of models to initialize before choosing the best one to finish the training, default to 3",
                        type=int,
                        default=3)
    parser.add_argument("--n_init_epochs", "-i",
                        help="number of initial epochs before the actual training, default to 5",
                        type=int,
                        default=5)

def _str_to_bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so the option could never be
    switched off explicitly.

    Args:
        value: the raw string given on the command line.

    Returns:
        True for true/t/yes/y/1, False for false/f/no/n/0 or an empty string
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True or False), got {!r}".format(value))


# Both analysis sub-commands accept the common set of arguments.
parser_setup(type_parser)

parser_setup(state_parser)
# Options that only apply to the ``state`` sub-command.
state_parser.add_argument("--delta_loss_batch",
    help="the batch size  to consider delta loss, default to 10",
    type=int, default=10)
state_parser.add_argument("--const", "-c",
    help="constant, default to 2",
    type=int, default=2)
state_parser.add_argument("--dropout_rate",
    help="dropout rate, default to 0",
    type=float, default=0.0)
# Parsed with _str_to_bool so that "--batch_norm False" really yields False.
state_parser.add_argument("--batch_norm",
    help="apply batch normalization if set to True, default to False",
    type=_str_to_bool, default=False)


def run_type(args):
    """Handler for the ``type`` sub-command.

    Loads the expression/marker data, fits the cell-type model and writes two
    files: the assignments to ``args.output_csv`` and the cell-type
    probabilities to the same path with its extension replaced by
    ``.probabilities.csv``.

    Note that ``args.design`` and ``args.dtype`` are normalised in place.

    Args:
        args: the parsed ``argparse`` namespace of the ``type`` sub-command.
    """
    # Command-line values arrive as text; turn the sentinels into objects.
    if args.design == "None":
        args.design = None

    torch_dtypes = {
        "torch.float64": torch.float64,
        "torch.float32": torch.float32,
    }
    # Unrecognised values are passed through unchanged.
    args.dtype = torch_dtypes.get(args.dtype, args.dtype)

    model = from_csv_yaml(
        args.expr_csv,
        args.marker_yaml,
        design_csv=args.design,
        random_seed=args.random_seed,
        dtype=args.dtype,
    )
    model.fit_type(
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        delta_loss=args.delta_loss,
        n_init=args.n_init,
        n_init_epochs=args.n_init_epochs,
    )

    assignment_mode = "max" if args.prob_max else "threshold"
    model.type_to_csv(args.output_csv, assignment_type=assignment_mode)

    # Probabilities are written next to the assignments file.
    output_stem, _extension = os.path.splitext(args.output_csv)
    model.get_celltype_probabilities().to_csv(output_stem + ".probabilities.csv")


def run_state(args):
    """Handler for the ``state`` sub-command.

    Loads the expression/marker data, fits the cell-state model with the
    common and the state-specific training options, and writes the result to
    ``args.output_csv``.

    Note that ``args.design`` and ``args.dtype`` are normalised in place.

    Args:
        args: the parsed ``argparse`` namespace of the ``state`` sub-command.
    """
    # Command-line values arrive as text; turn the sentinels into objects.
    if args.design == "None":
        args.design = None

    torch_dtypes = {
        "torch.float64": torch.float64,
        "torch.float32": torch.float32,
    }
    # Unrecognised values are passed through unchanged.
    args.dtype = torch_dtypes.get(args.dtype, args.dtype)

    model = from_csv_yaml(
        args.expr_csv,
        args.marker_yaml,
        design_csv=args.design,
        random_seed=args.random_seed,
        dtype=args.dtype,
    )
    model.fit_state(
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        delta_loss=args.delta_loss,
        n_init=args.n_init,
        n_init_epochs=args.n_init_epochs,
        delta_loss_batch=args.delta_loss_batch,
        const=args.const,
        dropout_rate=args.dropout_rate,
        batch_norm=args.batch_norm,
    )
    model.state_to_csv(args.output_csv)

# Attach the handler that is invoked when each analysis sub-command is chosen.
type_parser.set_defaults(func=run_type)
state_parser.set_defaults(func=run_state)

# ``convert`` sub-command: takes an input rds file and an output csv path,
# plus options that are forwarded to the conversion step.
convert_parser = subparsers.add_parser("convert")
convert_parser.add_argument("in_rds", type=str, help="input rds file")
convert_parser.add_argument("out_csv", type=str, help="output csv file")
convert_parser.add_argument(
    "--assay", "-a",
    type=str,
    default="logcounts",
    help="the assay to convert, default to logcounts",
)
# NOTE: the help texts of the two design options say "None", but the actual
# default value of both is the empty string.
convert_parser.add_argument(
    "--design_col",
    type=str,
    default="",
    help="design column, default to None",
)
convert_parser.add_argument(
    "--design_csv",
    type=str,
    default="",
    help="output design csv, default to None",
)
convert_parser.add_argument(
    "--winsorize", "-w",
    type=float,
    default=0.0,
    help="the winsorize limit will be (<w>, 1-<w>), default to 0.0",
)

def convert_rds(args):
    """Handler for the ``convert`` sub-command.

    Runs ``Rscript`` on ``../astir/data/rds_reader.R`` (resolved relative to
    the directory containing this script) and forwards the input/output paths
    and the conversion options to it. The child process is started in the
    current working directory, so relative paths given by the user resolve as
    expected.

    If the R script exits with a non-zero status, this process exits with the
    same status instead of silently reporting success.

    Args:
        args: the parsed ``argparse`` namespace of the ``convert``
            sub-command (``in_rds``, ``out_csv``, ``assay``, ``design_col``,
            ``design_csv``, ``winsorize``).
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    reader_script = os.path.join(
        script_dir, os.pardir, "astir", "data", "rds_reader.R")
    cmd = [
        "Rscript", reader_script,
        args.in_rds, args.out_csv,
        "--assay", args.assay,
        "--design_col", args.design_col,
        "--design_csv", args.design_csv,
        # subprocess arguments must be strings; winsorize is parsed as float.
        "--winsorize", str(args.winsorize),
    ]
    returncode = subprocess.call(cmd, cwd=os.getcwd())
    if returncode != 0:
        # Propagate the failure so callers (shells, pipelines) can detect it.
        sys.exit(returncode)
# Attach the handler for the ``convert`` sub-command.
convert_parser.set_defaults(func=convert_rds)

args = parser.parse_args()

# Each sub-command stores its handler under ``func``; when no sub-command was
# given the attribute is absent and the usage text is printed instead.
if not hasattr(args, 'func'):
    parser.print_help()
else:
    args.func(args)

# Example usage:
#   astir type ./basel_subset_expression.csv ./jackson-2020-markers.yml ./output.csv

