#!/usr/bin/env python

import argparse
import pysam
import sys
import threading
import time
from os import path
from typing import Dict, List, Tuple


class Spinner:
    """Console activity indicator usable as a context manager.

    While the ``with`` block runs, a background thread repeatedly writes
    ``Work in progress`` followed by a rotating cursor character to
    ``sys.stdout`` and then erases it again with backspaces.
    """

    busy = False
    delay = 0.1
    _thread = None

    def __init__(self, delay=None):
        """
        Create a spinner.

        Arguments:
            delay -- seconds between two frames; a falsy value keeps the
                     class default of 0.1
        """
        self.spinner_generator = self.spinning_cursor()
        if delay and float(delay):
            self.delay = delay

    @staticmethod
    def spinning_cursor():
        """Yield the four cursor frames (bar, slash, dash, backslash) forever."""
        while 1:
            for cursor in '|/-\\':
                yield cursor

    def spinner_task(self):
        """Draw and erase one frame per ``delay`` until ``busy`` is cleared."""
        while self.busy:
            frame = "Work in progress " + next(self.spinner_generator)
            sys.stdout.write(frame)
            sys.stdout.flush()
            time.sleep(self.delay)
            # Move the cursor back over the frame so the next write replaces it.
            sys.stdout.write('\b' * len(frame))
            sys.stdout.flush()

    def __enter__(self):
        """Start the background thread and return the spinner itself."""
        self.busy = True
        # Daemon thread: a spinner must never keep the interpreter alive.
        self._thread = threading.Thread(target=self.spinner_task, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exception, value, tb):
        """Stop the spinner and wait for the thread; never swallow exceptions."""
        self.busy = False
        if self._thread is not None:
            # Joining (instead of sleeping for a fixed time) guarantees the
            # thread has written its last backspaces before the caller prints.
            self._thread.join()
            self._thread = None
        return False

    @staticmethod
    def remove_spinner_leftovers():
        """Write a single backspace to ``sys.stdout`` and flush it."""
        sys.stdout.write('\b')
        sys.stdout.flush()


class FormatterClass(argparse.RawDescriptionHelpFormatter):
    """Help formatter for the script's parser; adds nothing to its base class."""


class Arguments:
    """Command-line options of the script; populated by ``parse``."""

    def __init__(self):
        # The next attribute is not referenced anywhere in the visible code.
        self.counts = ""
        self.control = ""   # control BAM file name (-c)
        self.treated = ""   # treated BAM file name (-t)
        self.output = ""    # output table file name (-o)
        self.ids = None     # list of requested transcript ids (-id) or None
        self.all = False    # -a: keep every position
        # The next attribute is not referenced anywhere in the visible code.
        self.max_y_filename = ""
        self.coverage = 0   # -v: minimal coverage; 0 disables the filter
        # The next attribute is not referenced anywhere in the visible code.
        self.pvalue = 1.0
        self.flag = True    # True until the output header row has been written

    def parse(self) -> None:
        """
        Parse arguments from the command line using argparse.

        Prints the help text and exits when no argument is given. Validates
        the BAM inputs and empties an already existing output file.
        """
        parser = argparse.ArgumentParser(
            prog='input probNORM',
            description='Script for counting stops and preparation of probNORM input file. By default, it saves data that meets the quality filter.',
            formatter_class=FormatterClass,
        )

        req_args = parser.add_argument_group("required arguments")
        req_args.add_argument('-t', help='sorted BAM file for treated sample', required=True, type=str, dest='treated')
        req_args.add_argument('-c', help='sorted BAM file for control sample', required=True, type=str, dest='control')
        req_args.add_argument('-o', help='output file name', required=True, type=str, dest='output')
        parser.add_argument('-id', help='transcript id list in "". Default: all transcripts are prepared', required=False, type=str, dest='ids')
        # -v and -a select different filters, so only one of them is accepted.
        filter_group = parser.add_mutually_exclusive_group()
        filter_group.add_argument('-v', help='coverage filtering. Save data that, in addition to the quality filter, meets the coverage threshold.', required=False, dest='coverage', type=int)
        filter_group.add_argument('-a', help='save all data regardless of the quality filter.', required=False, dest='all', action='store_true')

        if not sys.argv[1:]:
            parser.print_help()
            parser.exit()

        args = parser.parse_args()

        self.control, self.treated = self.check_bam_file(args.control, args.treated)
        if args.ids:
            # Tolerate blanks around the commas ("a, b"), which would
            # otherwise become part of the name and never match.
            self.ids = [name.strip() for name in args.ids.split(",")]

        if args.coverage:
            self.coverage = args.coverage
        if args.all:
            self.all = args.all

        self.output = self.check_output_files(args.output)

    def check_output_files(self, filename: str) -> str:
        """
        Checks whether the resulting file with the given name exists.
        If so, its content is removed; a missing file is not created.

        Arguments & Returns:
            filename {str} -- counts filename
        """
        if path.exists(filename):
            # Opening in 'w' mode truncates the file; results are appended later.
            with open(filename, 'w'):
                pass
        return filename

    def check_bam_file(self, control_file: str, treated_file: str) -> Tuple[str, str]:
        """
        Validates the given bam files and their index files.

        Both ``<file>.bai`` indexes must exist and must not be older than
        their BAM file, and both BAM files must be readable by pysam.
        On any failure a message is written to stderr and the script exits
        with status 1.

        Arguments:
            control_file {str} -- control bam filename
            treated_file {str} -- treated bam filename

        Returns:
            Tuple[str, str] -- the control and treated file names, unchanged
        """
        if not path.exists(control_file + ".bai") or not path.exists(treated_file + ".bai"):
            sys.stderr.write("[ERROR] Indexed bam files (.bai) must be located in the same folder as the input files.\n        Use 'samtools index <bam_file>' to create index files\n")
            sys.exit(1)
        sys.stdout.write("[INFO] Indexed bam files (.bai) detected.\n")

        for bam in (control_file, treated_file):
            if not path.exists(bam):
                # Without this check getmtime() below would raise an
                # unhandled FileNotFoundError.
                sys.stderr.write("[ERROR] Could not open alignment file {}: No such file.\n".format(bam))
                sys.exit(1)
            # Compare the numeric modification times; their ctime() strings
            # would be ordered by weekday and month name instead of by time.
            if path.getmtime(bam + ".bai") < path.getmtime(bam):
                sys.stderr.write("[ERROR] The index file is older than the data file: {}\n        Use 'samtools index {}' to create new index file\n".format(bam + ".bai", bam))
                sys.exit(1)
            try:
                # Open the file only to prove it can be parsed, then close it.
                with pysam.AlignmentFile(bam, "rb"):
                    pass
            except FileNotFoundError:
                sys.stderr.write("[ERROR] Could not open alignment file {}: No such file.\n".format(bam))
                sys.exit(1)
            except ValueError:
                sys.stderr.write("[ERROR] Could not parse alignment file {}: Check if the file is correct\n".format(bam))
                sys.exit(1)
        return control_file, treated_file


class Counter:
    """Counts RT stops and coverage in BAM files and writes the counts table.

    The methods read the parsed command-line options from the module-level
    ``arg`` object.
    """

    def get_stops_bam(self, bam_file: str, transcript_id: str) -> Tuple[List[int], List[int]]:
        """
        Counts occurrences of RT stops at given transcript positions.
        A stop is counted at the 0-based start position of every forward
        strand read. Additionally counts coverage for each position.

        Arguments:
            bam_file {str} -- bam file name
            transcript_id {str} -- name of analyzed transcript

        Returns:
            List[int] -- number of identified stops for a given transcript position
            List[int] -- coverage for a given transcript position
        """
        # The context manager closes the BAM handle even if counting fails.
        with pysam.AlignmentFile(bam_file, "rb") as samfile:
            return self._count_stops(samfile, transcript_id)

    @staticmethod
    def _count_stops(samfile, transcript_id):
        """Count per-position stops and coverage on an already opened BAM file."""
        tr_reads = samfile.fetch(transcript_id, until_eof=True)
        reference_len = samfile.get_reference_length(transcript_id)
        stops = [0] * reference_len
        coverage = [0] * reference_len
        for read in tr_reads:
            if read.is_reverse:
                continue
            start = read.reference_start  # 0-based left start
            if 0 <= start < reference_len:
                stops[start] += 1
            # NOTE(review): the covered span is query_length positions from
            # the read start; this presumably equals the aligned span only
            # for reads without indels or clipping -- confirm it is intended.
            # Positions outside the reference are ignored.
            for nt in range(max(start, 0), min(start + read.query_length, reference_len)):
                coverage[nt] += 1
        return stops, coverage

    def get_ids(self, samfile: "pysam.AlignmentFile") -> Dict[str, int]:
        """Get transcript ids and lengths from SAM header.

        Only the ids listed in ``arg.ids`` are returned when that list is
        set; otherwise every ``SQ`` header entry is returned.

        Arguments:
            samfile {pysam.AlignmentFile} -- file handler from pysam

        Returns:
            Dict[str, int] -- names and length of transcripts
        """
        # A set keeps the membership test cheap for long id lists.
        wanted = set(arg.ids) if arg.ids else None
        return {
            entry['SN']: entry['LN']
            for entry in samfile.header["SQ"]
            if wanted is None or entry['SN'] in wanted
        }

    def get_transcripts_names(self, control_file: str, treated_file: str) -> Dict[str, int]:
        """Retrieves transcript names from given bam files. Checks if bam files match and transcript ids are present.

        Exits with status 1 when the two headers differ or no transcript is
        left; writes a warning when only some requested ids are missing.

        Arguments:
            control_file {str} -- control bam file name
            treated_file {str} -- treated bam file name

        Returns:
            Dict[str, int] -- names and length of transcripts
        """
        with pysam.AlignmentFile(control_file, "rb") as cfile, \
                pysam.AlignmentFile(treated_file, "rb") as tfile:
            control_ids = self.get_ids(cfile)
            treated_ids = self.get_ids(tfile)

        if control_ids != treated_ids:
            sys.stderr.write("[ERROR] Provided BAM files do not match. Make sure BAM files are correct.\n")
            sys.exit(1)
        if not control_ids:
            if arg.ids:
                sys.stderr.write("[ERROR] No selected transcripts found: {}\n        Make sure the names are correct and separated by a comma.\n".format(",".join(arg.ids)))
            else:
                sys.stderr.write("[ERROR] Could not find any transcripts. Make sure BAM files are correct.\n")
            sys.exit(1)
        if arg.ids:
            # dict.fromkeys() drops duplicated ids but keeps the user's order,
            # so a repeated id alone no longer produces an empty warning.
            missing = [name for name in dict.fromkeys(arg.ids) if name not in control_ids]
            if len(missing) == 1:
                sys.stderr.write("[WARNING] Transcript '{}' was not found in BAM files. Make sure the names are correct and separated by a comma.\n".format(",".join(missing)))
            elif missing:
                sys.stderr.write("[WARNING] Transcripts '{}' were not found in BAM files. Make sure the names are correct and separated by a comma.\n".format(",".join(missing)))
        return control_ids

    def prepare_counts_file(self):
        """
        Reads and parses bam input files for the control and treated sample.
        Writes the output file according to the filtering type.

        With ``arg.all`` every position is written. Otherwise a position
        needs a non-zero stop count in both samples and, when
        ``arg.coverage`` is set, at least that coverage in both samples.
        Transcripts without any remaining position are omitted.
        """
        trans_ids = self.get_transcripts_names(arg.control, arg.treated)
        omitted = 0
        row_format = "{}\t{}\t{}\t{}\t{}\t{}\n"
        with Spinner():
            # Open each BAM file once instead of once per transcript.
            with pysam.AlignmentFile(arg.treated, "rb") as treated_bam, \
                    pysam.AlignmentFile(arg.control, "rb") as control_bam:
                for tid in trans_ids:
                    stops_treated, trtd_coverage = self._count_stops(treated_bam, tid)
                    stops_control, con_coverage = self._count_stops(control_bam, tid)
                    result = []
                    columns = zip(stops_control, stops_treated, con_coverage, trtd_coverage)
                    for index, (n_control, n_treated, cov_control, cov_treated) in enumerate(columns):
                        if not arg.all:
                            if n_control == 0 or n_treated == 0:
                                continue
                            if arg.coverage and (cov_treated < arg.coverage or cov_control < arg.coverage):
                                continue
                        # Positions are reported 1-based.
                        result.append(row_format.format(tid, index + 1, n_control, n_treated, cov_control, cov_treated))

                    if not result:
                        omitted += 1
                        continue
                    with open(arg.output, 'a') as out:
                        if arg.flag:
                            # The header row is written only once per run.
                            out.write("transcript_id\tposition\tstops_control\tstops_treated\tcov_control\tcov_treated\n")
                            arg.flag = False
                        out.write("".join(result))

        exported = len(trans_ids) - omitted
        if exported:
            # Singular/plural follows the number of exported transcripts.
            if exported != 1:
                sys.stdout.write("[INFO] Data from '{}' transcripts has been prepared and exported to {}\n".format(exported, arg.output))
            else:
                sys.stdout.write("[INFO] Data from '{}' transcript has been prepared and exported to {}\n".format(exported, arg.output))
        else:
            sys.stderr.write("[WARNING] The transcripts do not have enough positions with stops for further analysis.\nThe resulting file was not created\n")



# Script entry point: parse the command line, then build the counts table.
# ``arg`` has to stay a module-level name because the Counter methods read
# the parsed options from it as a global.
arg = Arguments()
arg.parse()
Counter().prepare_counts_file()
