#!/usr/bin/env python

import threading
import time
import sys
import os
import argparse
import csv
import math
from os import path, system
from typing import Dict, List, Tuple

try:
    import pysam
except ImportError:
    sys.stderr.write("ERROR: Module pysam is curently not installed")
    sys.exit()

try:
    import numpy as np
except ImportError:
    sys.stderr.write("ERROR: Module numpy is curently not installed")
    sys.exit()

try:
    from scipy import stats
except ImportError:
    sys.stderr.write("ERROR: Module scipy is curently not installed")
    sys.exit()


class Spinner:
    """Console spinner shown while work runs in the calling thread.

    Used as a context manager: entering starts a background thread that keeps
    writing "Work in progress <frame>" and erasing it with backspaces; exiting
    stops that thread and waits for it to finish.
    """
    busy = False
    delay = 0.1

    def __init__(self, delay=None):
        """
        Arguments:
            delay {float} -- seconds between frames; a falsy value keeps the class default (0.1)
        """
        self.spinner_generator = self.spinning_cursor()
        self._thread = None
        if delay and float(delay):
            self.delay = delay

    @staticmethod
    def spinning_cursor():
        """Endlessly yield the spinner frames '|', '/', '-', '\\'."""
        while True:
            for cursor in '|/-\\':
                yield cursor

    def spinner_task(self):
        """Draw frames until ``busy`` is cleared; runs in the background thread."""
        while self.busy:
            sys.stdout.write("Work in progress " + next(self.spinner_generator))
            sys.stdout.flush()
            time.sleep(self.delay)
            # erase the 18 characters written above ("Work in progress " + frame)
            sys.stdout.write('\b' * 18)
            sys.stdout.flush()

    def __enter__(self):
        self.busy = True
        # daemon: the spinner must never keep the interpreter alive on its own
        self._thread = threading.Thread(target=self.spinner_task, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exception, value, tb):
        self.busy = False
        # Join rather than sleeping a fixed delay: a sleep does not guarantee the
        # thread has written its final backspaces before the caller prints again.
        if self._thread is not None:
            self._thread.join()
        # never suppress exceptions raised inside the with-block
        return False

    @staticmethod
    def remove_spinner_leftovers():
        """Write a single backspace to stdout and pause briefly."""
        sys.stdout.write('\b')
        sys.stdout.flush()
        time.sleep(0.01)


class FormatterClass(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    """argparse help formatter that shows argument defaults and keeps description text unwrapped."""



class Arguments:
    """Command-line configuration for probNORM.

    The constructor sets defaults; ``parse`` overwrites them from ``sys.argv``
    and validates the input/output files.
    """
    def __init__(self):
        self.counts = False          # counts-file path in counts mode, False otherwise
        self.control = ""            # control BAM path (BAM mode)
        self.treated = ""            # treated BAM path (BAM mode)
        self.output = ""             # output file name
        self.ids = None              # selected transcript ids (BAM mode); None means all
        self.coverage = 0            # minimum coverage a position needs to pass filtering
        self.pvalue = 1.0            # p-value cutoff; 1.0 keeps every position
        self.flag = True             # output header still has to be written
        self.flag_m = True
        self.mutation_rate = False
        self.tr_size = 20            # percent of the transcript required to be reactive
        self.constrains = False      # also write RNA-folding constraint files

    def parse(self) -> None:
        """
        Parse arguments from the command line using argparse.

        Prints the help and exits when fewer than two command-line tokens are
        given. Otherwise validates the input files for the selected mode
        ('counts' or 'bam') and stores the options on ``self``.
        """

        parser = argparse.ArgumentParser(prog='probNORM', description='probNORM is a method of signal calculation that eliminate read distribution bias and prevent underestimation of reactivity.', formatter_class=FormatterClass)
        subparsers = parser.add_subparsers(help='Choose one of the built-in modes for loading input files.\n\n', metavar="", dest='option')

        parser_counts = subparsers.add_parser('counts', add_help=False, help='The input file with RT polymerase stops counts for both treated and control samples')
        req_counts = parser_counts.add_argument_group("Required arguments")
        req_counts.add_argument('-i', '--input', help='The input file listing RT polymerase stops counts for both treated and control samples. The file structure is as follows:\n1.  Transcript ID\n2.  Position\n3.  Stops count in the control sample\n4.  Stops count in the treated sample', required=True, type=str, dest='counts')
        req_counts.add_argument('-o', '--output', help='The name for the probNORM output file', required=True, type=str, dest='output')
        opt_counts = parser_counts.add_argument_group("Optional arguments")
        opt_counts.add_argument('-p', '--pvalue', help='P-value is the probability that a nucleotide belongs to the background distribution and is not statistically significant. Range [0-1]. All positions with a p-value higher than the provided one are rejected from the result. Default: 1 -> showing all positions.', required=False, dest='pvalue', type=float)
        # hidden option: mutation rates given as counts; result positions are not shifted
        opt_counts.add_argument('-m', '--mutation-rate', help=argparse.SUPPRESS, required=False, dest='mrate', action='store_true')
        opt_counts.add_argument('-s', '--transcript-size', help="Set the percentage of transcript covered with reactive nucleotides. Those positions are further use to set normalization parameters. Default: 20 percent of provided transcript.", required=False, dest='tr_size', type=int)
        opt_counts.add_argument('-f', '--constrain-files', action='store_true', help="The constrain files with normalized reactivities for RNAfold [ViennaRNA] and Fold [RNAStructure] will be prepared for each normalized transcript. Files will be saved in <output-file-name>_constrains", required=False, dest='constrains')
        opt_counts.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, help='Show this help message and exit.')

        parser_bam = subparsers.add_parser('bam', add_help=False, help='Reads aligned to transcriptome in sorted BAM file format for the control and treated sample.')
        req_bam = parser_bam.add_argument_group("required arguments")
        req_bam.add_argument('-t', '--treated', help='Sorted BAM file for treated sample', required=True, type=str, dest='treated')
        req_bam.add_argument('-c', '--control', help='Sorted BAM file for control sample', required=True, type=str, dest='control')
        req_bam.add_argument('-o', '--output', help='The name for the probNORM output file', required=True, type=str, dest='output')
        opt_bam = parser_bam.add_argument_group("Optional arguments")
        opt_bam.add_argument('-id', help='probNORM can normalize only selected transcript. Quote transcripts ids in the comma separated list e.g. "RDN18,RDN25" . Default: all transcripts are normalized.', required=False, type=str, dest='ids')
        opt_bam.add_argument('-v', '--coverage', help='Coverage filtering. Only positions with coverage higher than specified are normalized.', required=False, dest='coverage', type=int)
        opt_bam.add_argument('-p', '--pvalue', help='P-value is the probability that a nucleotide belongs to the background distribution and is not statistically significant. Range [0-1]. All positions with a p-value higher than the provided one are rejected from the result. Default: 1 -> showing all positions.', required=False, dest='pvalue', type=float)
        opt_bam.add_argument('-s', '--transcript-size', help="Set the percentage of transcript covered with reactive nucleotides. Those positions are further use to set normalization parameters. Default: 20 percent of provided transcript.", required=False, dest='tr_size', type=int)
        opt_bam.add_argument('-f', '--constrain-files', action='store_true', help="The constrain files with normalized reactivities for RNAfold [ViennaRNA] and Fold [RNAStructure] will be prepared for each normalized transcript. Files will be saved in <output-file-name>_constrains", required=False, dest='constrains')
        opt_bam.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, help='Show this help message and exit.')

        args = parser.parse_args()
        if len(sys.argv[1:]) <= 1:
            parser.print_help()
            parser.exit()

        if args.option == 'counts':
            self.counts = self.check_input_counts(args.counts)
            self.mutation_rate = args.mrate
        else:
            self.control, self.treated = self.check_bam_file(args.control, args.treated)
            if args.ids:
                self.ids = args.ids.split(",")
            if args.coverage:
                self.coverage = args.coverage
            self.mutation_rate = False

        self.constrains = args.constrains

        # Compare with None: a truthiness test would silently ignore "-p 0".
        if args.pvalue is not None:
            if 0.0 <= args.pvalue <= 1.0:
                self.pvalue = args.pvalue
            else:
                sys.stderr.write("[ERROR] P-value parameter '{}' should be in the range [0-1].\n".format(args.pvalue))
                sys.exit()

        if args.tr_size is not None:
            self.tr_size = int(args.tr_size)

        self.output = self.check_output_files(args.output)

    def check_output_files(self, filename: str) -> str:
        """
        Truncate the output file if it already exists.

        Arguments & Returns:
            filename {str} -- output filename (returned unchanged)
        """
        if path.exists(filename):
            with open(filename, 'w'):
                pass
        return filename

    @staticmethod
    def create_bam_index_files(bam_input_file):
        """
        Sort ``bam_input_file`` into ``<name>.sorted.bam``, index the sorted file
        and return its name.

        The new name is built by dropping the last four characters, so the
        input is presumably expected to end in ".bam".
        """
        sorted_bam = bam_input_file[:-4] + ".sorted.bam"
        pysam.sort("-o", sorted_bam, bam_input_file)
        pysam.index(sorted_bam)
        return sorted_bam

    @staticmethod
    def _index_is_outdated(bam_input_file: str) -> bool:
        """Return True when ``<bam_input_file>.bai`` was last modified before the BAM file."""
        # numeric mtimes: comparing time.ctime() strings orders them by weekday name
        return path.getmtime(bam_input_file + ".bai") < path.getmtime(bam_input_file)

    @staticmethod
    def check_bam_file_integrity(bam_input_file):
        """
        Exit with an error if the BAM index is older than the BAM file or the
        BAM file cannot be opened/parsed by pysam; otherwise return the name.
        """
        if Arguments._index_is_outdated(bam_input_file):
            sys.stderr.write("[ERROR] The index file is older than the data file: \nUse 'samtools index' to create new index file\n")
            sys.exit()
        try:
            # open only to validate; the context manager closes the handle
            with pysam.AlignmentFile(bam_input_file, "rb"):
                pass
        except FileNotFoundError:
            sys.stderr.write("[ERROR] Could not open alignment file {}: No such file\n".format(bam_input_file))
            sys.exit()
        except ValueError:
            sys.stderr.write("[ERROR] Could not parse alignment file {}: Check if the file is correct\n".format(bam_input_file))
            sys.exit()
        return bam_input_file

    @staticmethod
    def quick_check_bam(bam_file):
        """Exit with an error if ``pysam.quickcheck`` rejects ``bam_file``."""
        try:
            pysam.quickcheck(bam_file)
        except pysam.utils.SamtoolsError:
            sys.stderr.write("[ERROR] File '{}' was not identified as sequence data or is truncated.\n".format(bam_file))
            sys.exit()

    def check_bam_file(self, control_file: str, treated_file: str) -> Tuple[str, str]:
        """
        Validate both BAM files, sorting and indexing any that has no ``.bai``.

        Arguments:
            control_file {str} -- control bam filename
            treated_file {str} -- treated bam filename

        Returns:
            Tuple[str, str] -- control and treated filenames to use (the
            ``.sorted.bam`` name when an index had to be created)
        """
        for bam_file in (control_file, treated_file):
            Arguments.quick_check_bam(bam_file)
        if not path.exists(control_file + ".bai"):
            control_file = self.create_bam_index_files(control_file)
            sys.stdout.write("[INFO] Indexed control bam file (.bai) created.\n")
        if not path.exists(treated_file + ".bai"):
            treated_file = self.create_bam_index_files(treated_file)
            sys.stdout.write("[INFO] Indexed treated bam file (.bai) created.\n")

        control_file = self.check_bam_file_integrity(control_file)
        treated_file = self.check_bam_file_integrity(treated_file)
        return control_file, treated_file

    def check_input_counts(self, counts_file: str) -> str:
        """
        Validate that the counts file exists and is a regular file.

        Arguments & Returns:
            counts_file {str} -- counts filename
        """
        # isfile rather than exists: a directory would pass exists() and only
        # fail later when it is opened for reading
        if path.isfile(counts_file):
            return counts_file
        sys.stderr.write("[ERROR] Could not open counts file {}: No such file or directory\n".format(counts_file))
        sys.exit()


class Transcript:
    """Per-transcript stop counts and the probNORM normalization steps.

    The ``Input`` runners in this file call ``norm_mean`` ->
    ``norm_kernel_densities`` -> ``reactivity_c`` -> ``norm2_8`` in that order.
    Several methods read the module-level ``arg`` (presumably the
    ``Arguments`` instance built at start-up).
    """
    def __init__(self):
        self.stops_control = []       # RT stop counts, control sample
        self.stops_treated = []       # RT stop counts, treated sample
        self.stops_control_norm = []  # control counts multiplied by norm_c
        self.norm_c = 0.0             # control scaling factor, 2 ** (FC at the density peak)
        self.FC = []                  # log2(treated / control) of positions passing the filters
        self.reactivity = []
        self.trans_id = ""
        self.FC_lim = 0
        self.length = 0
        self.trtd_coverage = []       # per-position coverage, treated sample
        self.con_coverage = []        # per-position coverage, control sample

    def norm_mean(self) -> None:
        """
        Calculation of log2 fold-change values for the control and treated sample.
        Skips positions for which:
            the control or treated stops value is zero
            any of the values is NaN
            either coverage is lower than ``arg.coverage``
        """
        self.FC = []
        for i in range(0, len(self.stops_treated)):
            if self.stops_control[i] != 0 and self.stops_treated[i] != 0:
                if not math.isnan(self.stops_control[i]) and not math.isnan(self.stops_treated[i]):
                    if  self.con_coverage[i] >= arg.coverage and self.trtd_coverage[i] >= arg.coverage:
                        self.FC.append(math.log(float(self.stops_treated[i]) / self.stops_control[i], 2))

    def norm_kernel_densities(self, tr_name: str) -> None:
        """
        Estimate the FC density with a Gaussian KDE, use the FC with the
        highest density as the normalization shift and scale the control
        counts by it.

        Arguments:
            tr_name {str} -- name of analyzed transcript (currently unused)
        """
        density = stats.gaussian_kde(self.FC)
        den = density(self.FC)
        index = np.argmax(den)
        # FC values are log2 ratios, so the multiplicative factor is 2 ** FC
        self.norm_c = math.pow(2, self.FC[index])
        # rebuilt rather than appended to, so repeated calls do not grow the list
        self.stops_control_norm = [stops * self.norm_c for stops in self.stops_control]

    def reactivity_c(self) -> None:
        """
        Reactivity per position: stops_treated - stops_control_norm, clipped at zero.
        """
        self.reactivity = [max(self.stops_treated[i] - self.stops_control_norm[i], 0)
                           for i in range(0, len(self.stops_treated))]

    def get_shifted_FC(self) -> Tuple[List[float], List[float], List[Tuple[float, float]]]:
        """
        Calculation of log2 fold-change values for the normalized control and treated sample.

        Returns:
            Tuple of three lists:
                FC of positions that also pass the coverage filter (used to fit the background distribution),
                FC of every position (0 where either value is zero or both are NaN),
                (normalized control, treated) pair of every position
        """
        FC_norm_filtered = []
        FC_normalized = []
        sample_tuple = []
        for i in range(0, len(self.stops_treated)):
            if self.stops_control_norm[i] != 0 and self.stops_treated[i] != 0:
                # NOTE(review): this is `or`, whereas norm_mean uses `and`. A
                # position where only one value is NaN therefore gets a NaN FC
                # here; fit_norm_distribution ignores it and norm2_8 drops it
                # (its p-value is NaN). Left as is pending confirmation.
                if not math.isnan(self.stops_control_norm[i]) or not math.isnan(self.stops_treated[i]):
                    fold_change = math.log(float(self.stops_treated[i]) / self.stops_control_norm[i], 2)
                    if  self.con_coverage[i] >= arg.coverage and self.trtd_coverage[i] >= arg.coverage:
                        FC_norm_filtered.append(fold_change)
                    FC_normalized.append(fold_change)
                else:
                    FC_normalized.append(0)
            else:
                FC_normalized.append(0)
            sample_tuple.append((self.stops_control_norm[i], self.stops_treated[i]))
        return FC_norm_filtered, FC_normalized, sample_tuple

    def fit_norm_distribution(self, FC_norm: List[float]) -> Tuple[float, float]:
        """
        Mirror the FC values that are <= 0 around zero (zeros are kept once)
        and fit a normal distribution to the result.

        Arguments:
            FC_norm {List[float]} -- list of shifted FC values

        Returns:
            Tuple[float, float] -- mean and std values of the fitted distribution
        """
        left_side_mapped = []
        for i in FC_norm:
            if i < 0:
                left_side_mapped.extend([i, abs(i)])
            elif i == 0:
                left_side_mapped.append(0)
        mean, std = stats.norm.fit(left_side_mapped)
        return mean, std

    def norm2_8(self) -> None:
        """
        Normalize reactivities with the 2/8 rule and write the per-position results.

        For every reported position the right-tail probability of its FC under
        the fitted background normal distribution is used as its p-value. Rows
        with p-value <= ``arg.pvalue`` are appended to ``arg.output``. The
        ``passed_quality_filter`` column is "Y" when both raw stop counts are
        non-zero and both coverages reach ``arg.coverage``. When
        ``arg.constrains`` is set, a per-transcript constraint file is written
        as well (-999 for positions that are filtered out).
        """
        FC_distribution, FC_norm, _ = self.get_shifted_FC()
        mean, std = self.fit_norm_distribution(FC_distribution)
        avg_8p_range, max_val = self.get_2_8_params()

        # Counts input without the mutation-rate flag starts at index 1 and
        # reports index i as position i (one less than the input position,
        # since add_missing_points stores position 1 at index 0). Every other
        # mode reports index i as position i + 1.
        if arg.mutation_rate or not arg.counts:
            first_index, position_offset = 0, 1
        else:
            first_index, position_offset = 1, 0

        # one vectorized scipy call instead of one call per position
        probabilities = 1 - stats.norm.cdf(np.asarray(FC_norm, dtype=float), mean, std)

        constrain_out = None
        if arg.constrains:
            constrain_folder = self.create_constrain_dictionary()
            constrain_out = open(constrain_folder + "/" + self.trans_id + ".txt", 'w')
        try:
            with open(arg.output, 'a') as out:
                if arg.flag:
                    out.write("transcript_id\tposition\tstops_treated\tstops_control\tstops_norm_control\treactivity\tfold_change\tp_value\tpassed_quality_filter\n")
                    arg.flag = False
                for key in range(first_index, len(self.reactivity)):
                    position = str(key + position_offset)
                    FC_probability = probabilities[key]
                    # 2/8 scaling: cap at the 98th-percentile value, divide by the 90-98% mean
                    if self.reactivity[key] > max_val:
                        self.reactivity[key] = max_val / avg_8p_range
                    else:
                        self.reactivity[key] = self.reactivity[key] / avg_8p_range
                    # written as "<=" (not "not >") so NaN p-values fall into the else branch
                    if FC_probability <= arg.pvalue:
                        passed = (self.stops_treated[key] != 0 and self.stops_control[key] != 0) and self.con_coverage[key] >= arg.coverage and self.trtd_coverage[key] >= arg.coverage
                        output_line = "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(
                            self.trans_id, position, str(self.stops_treated[key]), str(self.stops_control[key]),
                            str(self.stops_control_norm[key]), str(self.reactivity[key]), str(FC_norm[key]),
                            str(FC_probability), "Y" if passed else "N")
                        if constrain_out is not None:
                            if passed:
                                constrain_out.write("{} {}\n".format(position, str(self.reactivity[key])))
                            else:
                                constrain_out.write("{} -999\n".format(position))
                        out.write(output_line)
                    elif constrain_out is not None:
                        constrain_out.write("{} -999\n".format(position))
        finally:
            # previously this handle was never closed
            if constrain_out is not None:
                constrain_out.close()

    def create_constrain_dictionary(self):
        """
        Ensure the folder ``<arg.output>_constrains`` exists and return its path.
        """
        constrain_folder = arg.output + "_constrains"
        # os.makedirs instead of shelling out to `mkdir`: the user-supplied name
        # is never parsed by a shell (spaces, metacharacters) and failures raise
        os.makedirs(constrain_folder, exist_ok=True)
        return constrain_folder

    def get_2_8_params(self) -> Tuple[float, float]:
        """
        Calculates normalization parameters by the 2/8 method.

        With 10 or fewer reactivity values the 90-98% window is empty and the
        division raises ZeroDivisionError.

        Returns:
            Tuple[float, float] -- mean of the values between the 90th and 98th
            percentile positions, and the value at the 98th percentile position
        """
        sort_react = sorted(self.reactivity)
        ten_percent = int(math.floor(len(sort_react) * 0.9))
        two_percent = int(math.floor(len(sort_react) * 0.98))
        avg_8p_range = sum(sort_react[ten_percent:two_percent]) / (two_percent - ten_percent)
        max_val = sort_react[two_percent]
        return avg_8p_range, max_val


class Input:

    def get_stops_bam(self, bam_file: str, transcript_id: str) -> Tuple[List[float], List[float]]:
        """
        Counts occurrences of RT stops at given transcript positions.
        Stop refers to one nucleotide before the read start position.
        Additionaly count coverage for each position.

        Arguments:
            bam_file {str} -- bam fila name
            transcript_id {str} -- name of analyzed transcript

        Returns:
            List[float] -- number of identified stops for a given transcript position
            List[float] -- coverage for a given transcript position
        """
        stops = {}
        coverage = {}
        samfile = pysam.AlignmentFile(bam_file, "rb")
        tr_reads = samfile.fetch(transcript_id, until_eof=True)
        reference_len = samfile.get_reference_length(transcript_id)
        for read in tr_reads:
            if not read.is_reverse:
                position = read.reference_start # 0-based left start
                try:
                    stops[position] += 1
                except KeyError:
                    stops[position] = 1
                for nt in range(read.reference_start, read.reference_start+read.query_length):
                    try:
                        coverage[nt] += 1
                    except KeyError:
                        coverage[nt] = 1
        stops_list = []
        cov_list = []
        for i in range(0, reference_len):
            try:
                stops[i]
                stops_list.append(stops[i])
            except KeyError:
                stops_list.append(0)
        for i in range(0, reference_len):
            try:
                coverage[i]
                cov_list.append(coverage[i])
            except KeyError:
                cov_list.append(0)
        return stops_list, cov_list

    def add_zero_coverage(self, stops_control, stops_treated, t_class):
        """
        fill coverage lists with zeros when counts mode is on

        Arguments:
            stops_control {list[float]} -- number of stops in treated sample
            stops_treated {list[float]} -- number of stops in control sample
            t_class {Transcript} -- Transcript class object
        """
        t_class.con_coverage = [0 for i in stops_control]
        t_class.trtd_coverage = [0 for i in stops_treated]


    def run_counts_normalization(self, d_stops_treated: Dict[int, float], d_stops_control: Dict[int, float], transcript_name: str) -> int:
        """
        Build a Transcript from counts-file data and normalize it.

        The transcript is skipped when fewer than ``arg.tr_size`` percent of its
        positions have a non-zero fold change.

        Arguments:
            d_stops_treated {Dict[int, float]} -- stops in the treated sample, keyed by position
            d_stops_control {Dict[int, float]} -- stops in the control sample, keyed by position
            transcript_name {str} -- transcript name
        Return
            int -- 1 if the transcript was omitted, 0 otherwise
        """
        transcript = Transcript()
        self.add_missing_points(d_stops_treated, d_stops_control, transcript)
        self.add_zero_coverage(transcript.stops_control, transcript.stops_treated, transcript)
        transcript.trans_id = transcript_name
        transcript.length = len(transcript.stops_control)
        transcript.norm_mean()
        nonzero_fc = sum(1 for fc in transcript.FC if fc != 0)
        if nonzero_fc < (arg.tr_size / 100) * transcript.length:
            return 1
        transcript.norm_kernel_densities(transcript_name)
        transcript.reactivity_c()
        transcript.norm2_8()
        return 0

    def add_missing_points(self, d_stops_treated, d_stops_control, t: Transcript):
        """
        Adds missing entries in the data.

        Arguments:
            d_stops_treated {[type]} -- number of stops in treated sample
            d_stops_control {[type]} -- number of stops in control sample
            t {Transcript} -- Transcript class object
        """
        for i in range(1, max(d_stops_treated.keys())+1):
            if i in d_stops_treated.keys():
                t.stops_treated.append(d_stops_treated[i])
                t.stops_control.append(d_stops_control[i])
            else:
                t.stops_treated.append(0)
                t.stops_control.append(0)

    def input_file_mode(self) -> None:
        """
        Dispatch to the counts parser or the BAM parser depending on the parsed arguments.
        """
        if arg.counts:
            self.counts_input_parser()
            return
        if arg.control and arg.treated:
            self.bam_input_parser()

    def get_ids(self, samfile:pysam.AlignmentFile) -> Dict[str, int]:
        """Map transcript name to length from the SAM header's SQ records.

        When ``arg.ids`` is set, only those transcripts are kept.

        Arguments:
            samfile {pysam.AlignmentFile} -- file handler from pysam

        Returns:
            Dict[str, int] -- names and length of transcripts
        """
        return {
            record['SN']: record['LN']
            for record in samfile.header["SQ"]
            if not arg.ids or record['SN'] in arg.ids
        }

    def get_transcripts_names(self, control_file: str, treated_file:str) -> Dict[str, int]:
        """Retrieve transcript names and lengths from both BAM files.

        Exits with an error when the two files' transcript sets differ or no
        (selected) transcript is found. Warns about requested ids that are
        missing from the files.

        Arguments:
            control_file {str} -- control bam file name
            treated_file {str} -- treated bam file name

        Returns:
            Dict[str, int] -- names and length of transcripts
        """
        # context managers close both handles (they were previously leaked)
        with pysam.AlignmentFile(control_file, "rb") as cfile:
            control_ids = self.get_ids(cfile)
        with pysam.AlignmentFile(treated_file, "rb") as tfile:
            treated_ids = self.get_ids(tfile)
        if control_ids != treated_ids:
            sys.stderr.write("[ERROR] Provided BAM files do not match. Make sure BAM files are correct.\n")
            sys.exit()
        if not control_ids and arg.ids:
            sys.stderr.write("[ERROR] No selected transcripts found: {}\n        Make sure the names are correct and separated by a comma.\n".format(",".join(arg.ids)))
            sys.exit()
        elif not control_ids:
            sys.stderr.write("[ERROR] Could not find any transcripts. Make sure BAM files are correct.\n")
            sys.exit()
        elif arg.ids:
            # Compute the missing set directly: comparing lengths warned (about
            # '') for duplicated ids. Sorted for a deterministic message.
            missing = sorted(set(arg.ids) - set(control_ids))
            if len(missing) == 1:
                sys.stderr.write("[WARNING] Transcript '{}' was not found in BAM files. Make sure the names are correct and separated by a comma.\n".format(",".join(missing)))
            elif missing:
                sys.stderr.write("[WARNING] Transcripts '{}' were not found in BAM files. Make sure the names are correct and separated by a comma.\n".format(",".join(missing)))
        return control_ids

    def run_bam_normalization(self, t_class: Transcript) -> int:
        """
        Normalize one BAM-derived transcript.

        The transcript is skipped when fewer than ``arg.tr_size`` percent of its
        positions yield a fold change.

        Arguments:
            t_class {Transcript} -- Transcript class object

        Returns:
            int -- 1 if the transcript was omitted for too few treated nucleotides, 0 otherwise
        """
        t_class.norm_mean()
        required = (arg.tr_size / 100) * len(t_class.stops_treated)
        if len(t_class.FC) < required:
            return 1
        t_class.norm_kernel_densities(t_class.trans_id)
        t_class.reactivity_c()
        t_class.norm2_8()
        return 0

    def check_counts_position_integer(self, line, line_nmbr, stops_treated, stops_control):
        """
        check if position value is integer or float with .0

        Arguments:
            line {str} -- line from counts input file
            line_nmbr {int} -- number of line
            stops_treated {list[float]} -- treated stops list
            stops_control {list[float]} -- control stops list

        Returns:
            int -- position value
        """
        try:
            position = float(line[1])
        except ValueError:
            Spinner.remove_spinner_leftovers()
            sys.stderr.write("[ERROR] Position must be an integer. Please check the input format. Line: '{}', value: '{}'\n".format(line_nmbr, line[1]))
            sys.exit()
        if position.is_integer():
            if int(position) in stops_treated or int(position) in stops_control:
                Spinner.remove_spinner_leftovers()
                sys.stderr.write("[ERROR] Positions in '{}' transcript are duplicated. Position: '{}'\n".format(line[0], position))
                sys.exit()
        else:
            Spinner.remove_spinner_leftovers()
            sys.stderr.write("[ERROR] Position must be an integer. Please check the input format. Line: '{}', value: '{}'\n".format(line_nmbr, position))
            sys.exit()
        return position

    def parse_counts_line(self, line, stops_treated, stops_control, line_nmbr):
        """
        check the correctness of input counts file line and gets the data

        The values are stored under the integer position taken from column 2:
        column 4 goes to the treated stops and column 3 to the control stops.
        If either value is not a number, an error is written to stderr and the
        program exits with status 1.

        Arguments:
            line {list[str]} -- whitespace-split line from counts input file
            stops_treated {dict[int, float]} -- treated stops collected so far (updated in place)
            stops_control {dict[int, float]} -- control stops collected so far (updated in place)
            line_nmbr {int} -- number of line (used in error messages)

        Returns:
            stops_treated {dict[int, float]} -- the updated treated stops mapping
            stops_control {dict[int, float]} -- the updated control stops mapping
        """
        position = int(self.check_counts_position_integer(line, line_nmbr, stops_treated, stops_control))
        try:
            treated = float(line[3])
        except ValueError:
            Spinner.remove_spinner_leftovers()
            sys.stderr.write("[ERROR] Illegal character in a treated column. Please check the input format. Line: '{}', value: '{}'\n".format(line_nmbr, line[3]))
            sys.exit(1)
        try:
            control = float(line[2])
        except ValueError:
            Spinner.remove_spinner_leftovers()
            sys.stderr.write("[ERROR] Illegal character in a control column. Please check the input format. Line: '{}', value: '{}'\n".format(line_nmbr, line[2]))
            sys.exit(1)
        # store only after both columns validated, so a bad line never leaves a half-filled entry
        stops_treated[position] = treated
        stops_control[position] = control
        return stops_treated, stops_control

    def check_column_number_counts(self, line, line_nmbr):
        """
        check if counts file has at least 4 columns

        If the line is too short, an error is written to stderr and the program
        exits with status 1.

        Arguments:
            line {list[str]} -- whitespace-split line from counts input file
            line_nmbr {int} -- number of line (used in the error message)
        """
        if len(line) >= 4:
            return
        Spinner.remove_spinner_leftovers()
        sys.stderr.write("[ERROR] Insufficient number of columns in '{}'. Line: '{}', columns: '{}', required: 4.\n        Check the input file format in <man file>\n".format(arg.counts, line_nmbr, len(line)))
        sys.exit(1)

    def check_counts_points_limit(self, stops, min_points=20):
        """
        check if transcript has at least `min_points` positions defined

        Arguments:
            stops {dict[int, float]} -- control stops collected for the transcript
            min_points {int} -- minimal number of positions required (default: 20)

        Returns:
            bool -- True if transcript has at least `min_points` positions
        """
        return len(stops) >= min_points

    def counts_input_parser(self) -> None:
        """
        Reads and parses the input file with the stops number for the control and treated sample.
        Starts normalization and calculation the reactivity of nucleotides together with the FC filter.
        Works for files with one or more transcripts.

        Rows are grouped by transcript ID (column 1); a new transcript starts whenever
        the ID differs from the previous row. Blank lines are skipped. An optional
        header line is detected with csv.Sniffer. If the file contains no data rows,
        an error is written to stderr and the program exits with status 1.
        """
        transcript_id = None
        stops_treated = {}
        stops_control = {}
        trans_ids = []
        omitted_transcripts = 0

        with open(arg.counts) as inp:
            first_line_nmbr = 1
            if self._counts_has_header(inp):
                sys.stdout.write("[INFO] Header was detected in {}\n".format(arg.counts))
                inp.readline()
                first_line_nmbr = 2
            with Spinner():
                for line_nmbr, line in enumerate(inp, start=first_line_nmbr):
                    tab = line.split()
                    if not tab:
                        # tolerate blank lines (e.g. a trailing empty line)
                        continue
                    self.check_column_number_counts(tab, line_nmbr)
                    if tab[0] != transcript_id:
                        # the previous transcript is complete: run its analysis
                        if transcript_id is not None:
                            omitted_transcripts += self._finish_counts_transcript(stops_treated, stops_control, transcript_id)
                        transcript_id = tab[0]
                        trans_ids.append(transcript_id)
                        stops_treated, stops_control = {}, {}
                    stops_treated, stops_control = self.parse_counts_line(tab, stops_treated, stops_control, line_nmbr)
                # run analysis for the last transcript in the file
                if transcript_id is not None:
                    omitted_transcripts += self._finish_counts_transcript(stops_treated, stops_control, transcript_id)

        Spinner.remove_spinner_leftovers()
        if not trans_ids:
            sys.stderr.write("[ERROR] No data lines found in '{}'.\n".format(arg.counts))
            sys.exit(1)
        if omitted_transcripts < len(trans_ids):
            self.prepare_normalization_summary(trans_ids, omitted_transcripts)
        else:
            sys.stderr.write("[WARNING] The transcripts did not meet the normalization criteria.\n          The resulting file was not created\n")

    @staticmethod
    def _counts_has_header(inp):
        """Return True if csv.Sniffer finds a header in the first 2048 chars of `inp`; rewinds `inp`."""
        sample = inp.read(2048)
        inp.seek(0)
        try:
            return csv.Sniffer().has_header(sample)
        except csv.Error:
            # the sniffer raises when it cannot determine a delimiter (e.g. empty file);
            # treat that as "no header" and let the row validation report real problems
            return False

    def _finish_counts_transcript(self, stops_treated, stops_control, transcript_id):
        """Normalize one fully-read transcript; return the number of omitted transcripts (0 or 1)."""
        if self.check_counts_points_limit(stops_control):
            return self.run_counts_normalization(stops_treated, stops_control, transcript_id)
        # too few positions: the transcript is discarded silently
        Spinner.remove_spinner_leftovers()
        return 1

    def bam_input_parser(self):
        """
        Reads and parses bam input files for the control and treated sample.
        Starts normalization and calculation the reactivity of nucleotides together with the FC filter.
        Works for files with multiple transcripts and for selected transcripts.

        Prints the normalization summary when at least one transcript was normalized,
        otherwise writes a warning to stderr.
        """
        omitted = 0
        # mapping: transcript id -> transcript length (as used below)
        trans_ids = self.get_transcripts_names(arg.control, arg.treated)
        with Spinner():
            for tid in trans_ids.keys():
                t = Transcript()
                t.trans_id = tid
                t.length = trans_ids[tid]
                t.stops_treated, t.trtd_coverage = self.get_stops_bam(arg.treated, t.trans_id)
                t.stops_control, t.con_coverage = self.get_stops_bam(arg.control, t.trans_id)
                omitted += self.run_bam_normalization(t)
        if omitted < len(trans_ids):
            self.prepare_normalization_summary(trans_ids, omitted)
        else:
            Spinner.remove_spinner_leftovers()
            sys.stderr.write("[WARNING] The transcripts did not meet the normalization criteria.\n          The resulting file was not created\n")

    def write_count_file(self, T):
        """
        Append the per-position stop counts of transcript `T` to a tab-separated file
        (debugging helper).

        The file is named after the control input with its extension replaced by
        "_counts.txt". Rows follow the counts-input column order: transcript id,
        1-based position, control stops, treated stops.

        Arguments:
            T {Transcript} -- transcript with trans_id, stops_control and stops_treated set
        """
        # splitext strips only the final extension, keeping dots in the name or in
        # directories ("./run.1/ctrl.bam" -> "./run.1/ctrl_counts.txt")
        count_name = path.splitext(arg.control)[0] + "_counts.txt"
        with open(count_name, 'a') as out:
            out.writelines(
                "{}\t{}\t{}\t{}\n".format(T.trans_id, position, control, treated)
                for position, (control, treated) in enumerate(zip(T.stops_control, T.stops_treated), start=1)
            )

    def prepare_normalization_summary(self, trans_ids, omitted):
        """
        Print a human-readable summary of the normalization run to stdout.

        Arguments:
            trans_ids -- collection of all input transcript ids (only its length is used)
            omitted {int} -- number of transcripts that were not normalized
        """
        Spinner.remove_spinner_leftovers()
        bam_mode = not arg.counts
        if bam_mode:
            mode_name = "BAM"
            input_desc = "control: {} treated: {}".format(arg.control, arg.treated)
        else:
            mode_name = "COUNTS"
            input_desc = arg.counts
        total = len(trans_ids)

        emit = sys.stdout.write
        emit(" "*20 + "\n***** SUMMARY *****\n\n")
        emit("    input mode: {}\n".format(mode_name))
        emit("    input file/s: {}\n".format(input_desc))
        emit("    output file: {}\n".format(arg.output))
        if bam_mode:
            emit("    min coverage: {}\n".format(arg.coverage))
        emit("    max p-value: {}\n".format(arg.pvalue))
        emit("    min reactive positions per transcript: {}%\n".format(arg.tr_size))
        if bam_mode:
            selection = ",".join(arg.ids) if arg.ids else "all"
            emit("    selected transcripts:  {}\n".format(selection))
        emit("    total number of input transcripts: {}\n".format(total))
        emit("    transcripts omitted due to low reactivity: {}\n".format(omitted))
        emit("    transcripts normalized: {}\n\n".format(total - omitted))
        emit("*******************\n")

# Script entry point. There is no `if __name__ == "__main__":` guard, so these
# statements run on import as well as on direct execution.
start = time.time()  # wall-clock start; in this section only the commented-out print below uses it
arg = Arguments()  # must be a module-level global: the methods above read their settings via `arg`
arg.parse()  # presumably fills `arg` from the command line (Arguments.parse is not in view) -- confirm
Input().input_file_mode()  # presumably dispatches to counts_input_parser / bam_input_parser -- confirm
stop = time.time()
# Uncomment to report the total run time:
#print("Time: ", stop - start)
