import argparse
import logging
import os
import re
import sys
import traceback
from datetime import datetime
from pathlib import Path

from clip import __version__
from .bamCLIP import bamCLIP
from .countCLIP import countCLIP
from .createMatrix import MatrixConverter
from .gffCLIP import FeatureOrderException, gffCLIP
from .trimAnnotation import Trimmer


"""
--------------------------------------------------
htseq-clip main
Authors: Marko Fritz, marko.fritz@embl.de
         Thomas Schwarzl, schwarzl@embl.de
         Nadia Ashraf, nadia.ashraf@embl.de
Modified by: Sudeep Sahadevan, sudeep.sahadevan@embl.de
Institution: EMBL Heidelberg
Date: October 2015
--------------------------------------------------
"""


class bcolors:
    """
    ANSI SGR escape sequences used to colour the command line help text.

    source: https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
    """

    # green foreground + underline: the program name in the help header
    HEADER = "\x1b[32;4m"
    # bright black ("gray") foreground: labels such as "version:"
    GRAY = "\x1b[90m"
    # green foreground: section titles
    GREEN = "\x1b[32m"
    # reset all attributes back to the terminal default
    ENDC = "\x1b[0m"


def _annotation(args):
    """
    Flatten the GFF annotation file given in ``args.gff`` using gffCLIP.

    The file is parsed in the mode requested by ``args.unsorted``. If that
    first pass raises FeatureOrderException (presumably: features are not in
    sorted order) and unsorted mode was not already requested, parsing is
    retried once in unsorted mode, which is more memory hungry.

    Raises:
        FeatureOrderException: when parsing fails in unsorted mode, either
            because ``--unsorted`` was given or on the automatic retry.
    """
    logging.info("Parsing annotations")
    logging.info("GFF file {}, output file {}".format(args.gff, args.output))
    gffc = gffCLIP(args)
    try:
        gffc.process(args.unsorted)
        return
    except FeatureOrderException as err:
        if args.unsorted:
            # already parsing in unsorted mode: there is nothing to fall back to
            raise
        logging.warning(str(err))
        logging.warning(
            'Trying to parse {} with "--unsorted" option.'.format(args.gff)
        )
        logging.warning("This step is memory hungry")
    # Retry outside the except clause so that a failure here is reported on
    # its own, not as happening "during handling" of the first exception.
    gffc.process(True)


def _createSlidingWindows(args):
    """
    Build sliding windows from the flattened annotation file ``args.input``.

    Window length and step are read from ``args.windowSize`` and
    ``args.windowStep``; the work itself is delegated to
    gffCLIP.slidingWindow.
    """
    source, target = args.input, args.output
    logging.info("Create sliding windows")
    logging.info("input file {}, output file {}".format(source, target))
    logging.info("Window size {} step size {}".format(args.windowSize, args.windowStep))
    windower = gffCLIP(args)
    windower.slidingWindow(source)


def _mapToId(args):
    """
    Map the "name" column entries of an annotation file to unique ids.

    Logs the input (``args.annotation``) and output (``args.output``) paths,
    then delegates to countCLIP.annotationToIDs.
    """
    for message in (
        "Creating mapping file from annotations",
        "Input file {} output file {}".format(args.annotation, args.output),
    ):
        logging.info(message)
    mapper = countCLIP(args)
    mapper.annotationToIDs()


def _extract(args):
    """
    Extract cross-link sites
    """
    if args.choice == "s":
        logging.info("Extracting start sites")
        logging.info(
            "Bam file : {}, output file: {}, offset: {}".format(
                args.input, args.output, args.offset
            )
        )
        with bamCLIP(args) as bh:
            bh.extract_start_sites(offset=args.offset)
    elif args.choice == "i":
        logging.info("Extracting insertion sites")
        logging.info("Bam file : {}, output file: {}".format(args.input, args.output))
        with bamCLIP(args) as bh:
            bh.extract_insertion_sites()
    elif args.choice == "d":
        logging.info("Extracting deletion sites")
        logging.info("Bam file : {}, output file: {}".format(args.input, args.output))
        with bamCLIP(args) as bh:
            bh.extract_deletion_sites()
    elif args.choice == "m":
        logging.info("Extracting middle sites")
        logging.info("Bam file : {}, output file: {}".format(args.input, args.output))
        with bamCLIP(args) as bh:
            bh.extract_middle_sites()
    elif args.choice == "e":
        logging.info("Extracting end sites")
        logging.info(
            "Bam file : {}, output file: {}, offset: {}".format(
                args.input, args.output, args.offset
            )
        )
        with bamCLIP(args) as bh:
            bh.extract_end_sites(offset=args.offset)


def _count(args):
    """
    Count crosslink sites per sliding window
    """
    logging.info("Count crosslink sites")
    logging.info(
        "Annotation file {} crosslink sites file {} output file {}".format(
            args.annotation, args.input, args.output
        )
    )
    # sanity check temp dir exists
    if (args.cpTmp) and (args.tmp is not None):
        tmpAbs = Path(args.tmp).absolute()
        if not tmpAbs.exists():
            raise RuntimeError(
                "Folder {} given under '--tmp' parameter does not exists!".format(
                    str(tmpAbs)
                )
            )
        # not all necessary but for completeness
        args.tmp = str(tmpAbs)
    countC = countCLIP(args)
    stranded = True
    if args.unstranded:
        stranded = False
    countC.count(stranded)


def _countMatrix(args):
    """
    Combine the "count" output files in folder ``args.input`` into one matrix.

    Input files are filtered by ``args.prefix`` / ``args.postfix``; all four
    settings are handed to MatrixConverter, which reads the samples and
    writes the matrix.
    """
    folder, prefix, postfix, target = args.input, args.prefix, args.postfix, args.output
    logging.info("Generate count matrix from  input files")
    logging.info("Input folder {}, output file {}".format(folder, target))
    converter = MatrixConverter(folder, prefix, postfix, target)
    converter.read_samples()
    converter.write_matrix()


def _maxCountMatrix(args):
    """
    Combine the "count" output files in ``args.input`` into a max-count matrix.

    Same as the plain count matrix, except MatrixConverter.read_samples is
    called with colNr=5; presumably that is the index of the
    ``crosslink_count_position_max`` column — verify against MatrixConverter.
    """
    folder, prefix, postfix, target = args.input, args.prefix, args.postfix, args.output
    logging.info("Generate max count matrix from input files")
    logging.info("Input folder {}, output file {}".format(folder, target))
    converter = MatrixConverter(folder, prefix, postfix, target)
    converter.read_samples(colNr=5)
    converter.write_matrix()


def _trimAnnotation(args):
    """
    Trim the annotation file ``args.annotation`` down to the ids in ``args.matrix``.

    The result goes to ``args.output``. The first annotation row is treated
    as a header unless ``--no_header`` (``args.noheader``) was given.
    """
    logging.info(f"Trimming down annotations in {args.annotation}")
    trimmer = Trimmer(
        inputMatrix=args.matrix,
        inputAnn=args.annotation,
        outputAnn=args.output,
    )
    has_header = not args.noheader
    trimmer.trim_annotation(header=has_header)


# Handle on the root logger (no name given); main() sets its level and handlers.
logger = logging.getLogger(name=None)


def _add_verbose_argument(subparser, loglevels):
    """
    Add the ``-v/--verbose`` log level option shared by every subcommand.

    Args:
        subparser: argparse parser to extend.
        loglevels: allowed level names; the choice is stored in ``args.log``.
    """
    subparser.add_argument(
        "-v",
        "--verbose",
        metavar="Verbose level",
        dest="log",
        help="Allowed choices: " + ", ".join(loglevels) + " (default: info)",
        choices=loglevels,
        default="info",
    )


def _add_matrix_arguments(subparser, prog, loglevels):
    """
    Add the options shared by the ``createMatrix`` and ``createMaxCountMatrix`` subcommands.

    Args:
        subparser: argparse parser to extend.
        prog: program name used inside the help text.
        loglevels: allowed ``--verbose`` level names.
    """
    subparser.add_argument(
        "-i",
        "--inputFolder",
        dest="input",
        metavar="input folder",
        help='Folder name with output files from count function, see "{} count -h ", supports .gz (gzipped files)'.format(
            prog
        ),
        required=True,
    )
    subparser.add_argument(
        "-b",
        "--prefix",
        dest="prefix",
        metavar="file name prefix",
        help="Use files only with this given file name prefix (default: None)",
        default="",
        type=str,
    )
    subparser.add_argument(
        "-e",
        "--postfix",
        dest="postfix",
        metavar="file name postfix",
        help='Use files only with this given file name postfix (default: None). WARNING! either "--prefix" or "--postfix" argument must be given!',
        default="",
        type=str,
    )
    subparser.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="output matrix file (.txt[.gz], default: print to console)",
        default=None,
        type=str,
    )
    _add_verbose_argument(subparser, loglevels)


def _build_parser(prog):
    """
    Build the htseq-clip argument parser with one subparser per function.

    Args:
        prog: program name shown in usage and help texts.

    Returns:
        (parser, matrix_parsers): the top level parser and a dict mapping
        "createMatrix" / "createMaxCountMatrix" to their own subparsers, so
        the matching help can be shown when their argument check fails.
    """
    description = f"""
    {bcolors.HEADER}{prog}{bcolors.ENDC}  A flexible toolset for the analysis of iCLIP and eCLIP sequencing data

    The function (as a positional argument) should be one of:

    {bcolors.GREEN}Annotation{bcolors.ENDC}
        annotation              flattens a gff formatted annotation file
        createSlidingWindows    creates sliding windows based on given annotation file
        mapToId                 map entries in "name" column to unique ids and write in tab separated format
    
    {bcolors.GREEN}Extraction{bcolors.ENDC}
        extract                 extracts crosslink sites, insertions or deletions
    
    {bcolors.GREEN}Counting{bcolors.ENDC}
        count                   count sites in annotation
    
    {bcolors.GREEN}Helpers{bcolors.ENDC}
        createMatrix            create R friendly matrix from "count" function output files
        createMaxCountMatrix    create R friendly matrix from `crosslink_count_position_max` column in  "count" function output files
        trimAnnotation          trim annotations from "mapToId" function based on unique ids in output matrix from "createMatrix" function
    
    {bcolors.GRAY}version:{bcolors.ENDC} {__version__}
    {bcolors.GRAY}Issues/Bug reports:{bcolors.ENDC} https://github.com/EMBL-Hentze-group/htseq-clip/issues
    """
    epilog = "For command line options of each argument, use: {} <positional argument> -h".format(
        prog
    )
    parser = argparse.ArgumentParser(
        prog=prog,
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # version
    parser.add_argument(
        "-v", "--version", action="version", version=f"{prog} {__version__}"
    )
    # log levels
    loglevels = ["debug", "info", "warn", "quiet"]
    # subparsers; the chosen function name ends up in args.subparser
    subps = parser.add_subparsers(help="Need positional arguments", dest="subparser")

    # ____________________ [Annotation] ___________________
    # annotation
    ahelp = (
        "annotation: flattens (to BED format) the given annotation file (in GFF format)"
    )
    annotation = subps.add_parser(
        "annotation", description=ahelp, formatter_class=argparse.RawTextHelpFormatter
    )
    annotation.add_argument(
        "-g",
        "--gff",
        metavar="annotation",
        dest="gff",
        help="GFF formatted annotation file, supports gzipped (.gz) files",
        required=True,
    )
    annotation.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="output file (.bed[.gz], default: print to console)",
        default=None,
        type=str,
    )
    annotation.add_argument(
        "-u",
        "--geneid",
        metavar="gene id",
        dest="id",
        help="Gene id attribute in GFF file (default: gene_id for gencode gff files)",
        default="gene_id",
        type=str,
    )
    annotation.add_argument(
        "-n",
        "--genename",
        metavar="gene name",
        dest="name",
        help="Gene name attribute in GFF file (default: gene_name for gencode gff files)",
        default="gene_name",
        type=str,
    )
    annotation.add_argument(
        "-t",
        "--genetype",
        metavar="gene type",
        dest="type",
        help="Gene type attribute in GFF file (default: gene_type for gencode gff files)",
        default="gene_type",
        type=str,
    )
    annotation.add_argument(
        "--splitExons",
        dest="splitExons",
        help="use this flag to split exons into exonic features such as 5'UTR, CDS and 3' UTR",
        action="store_true",
    )
    annotation.add_argument(
        "--unsorted",
        dest="unsorted",
        help="use this flag if the GFF file is unsorted",
        action="store_true",
    )
    _add_verbose_argument(annotation, loglevels)

    # createSlidingWindows
    cshelp = "createSlidingWindows: creates sliding windows out of the flattened annotation file"
    createSlidingWindows = subps.add_parser(
        "createSlidingWindows",
        description=cshelp,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    createSlidingWindows.add_argument(
        "-i",
        "--input",
        metavar="input file",
        dest="input",
        help='flattened annotation file, see "{} annotation -h"'.format(prog),
        required=True,
    )
    createSlidingWindows.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="annotation sliding windows file (.bed[.gz], default: print to console)",
        default=None,
        type=str,
    )
    createSlidingWindows.add_argument(
        "-w",
        "--windowSize",
        metavar="window size",
        dest="windowSize",
        help="window size (in number of base pairs) for sliding window (default: 50)",
        default=50,
        type=int,
    )
    createSlidingWindows.add_argument(
        "-s",
        "--windowStep",
        metavar="step size",
        dest="windowStep",
        help="window step size for sliding window (default: 20)",
        default=20,
        type=int,
    )
    _add_verbose_argument(createSlidingWindows, loglevels)

    # mapToIds
    maphelp = 'mapToId: extract "name" column from the annotation file and map the entries to unique id and print out in tab separated format'
    mapToId = subps.add_parser(
        "mapToId", description=maphelp, formatter_class=argparse.RawTextHelpFormatter
    )
    mapToId.add_argument(
        "-a",
        "--annotation",
        metavar="annotation file",
        help='flattened annotation file from "{0} annotation -h" or sliding window file from "{0} createSlidingWindows -h"'.format(
            prog
        ),
        required=True,
    )
    mapToId.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="region/window annotation mapped to a unique id (.txt[.gz], default: print to console)",
        default=None,
        type=str,
    )
    _add_verbose_argument(mapToId, loglevels)

    # ____________________ [Extraction] ___________________
    # extract
    ehelp = "extract:  extracts crosslink sites, insertions or deletions"
    echoices = ["s", "i", "d", "m", "e"]
    mates = [1, 2]
    extract = subps.add_parser(
        "extract", description=ehelp, formatter_class=argparse.RawTextHelpFormatter
    )
    extract.add_argument(
        "-i",
        "--input",
        metavar="input file",
        dest="input",
        help="input file (.bam, MUST be co-ordinate sorted and indexed)",
        required=True,
    )
    extract.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="output file (.bed, default: print to console)",
        default=None,
        type=str,
    )
    extract.add_argument(
        "-e",
        "--mate",
        dest="mate",
        help="for paired end sequencing, select the read/mate to extract the crosslink sites from.\n Must be one of: {}".format(
            ", ".join([str(i) for i in mates])
        ),
        type=int,
        choices=mates,
        required=True,
    )
    extract.add_argument(
        "-s",
        "--site",
        dest="choice",
        help="Crosslink site choices, must be one of: {0}\n s: start site \n i: insertion site \n d: deletion site \n m: middle site \n e: end site (default: e).".format(
            ", ".join(echoices)
        ),
        choices=echoices,
        default="e",
    )
    extract.add_argument(
        "-g",
        "--offset",
        metavar="offset length",
        dest="offset",
        help="Number of nucleotides to offset for crosslink sites (default: 0)",
        type=int,
        default=0,
    )
    extract.add_argument(
        "--ignore",
        dest="ignore",
        help="flag to ignore crosslink sites outside of genome",
        action="store_true",
    )
    extract.add_argument(
        "--ignore_PCR_duplicates",
        dest="pcr",
        help="flag to ignore PCR duplicates (only if bam file has PCR duplicate flag in alignment)",
        action="store_true",
    )
    extract.add_argument(
        "-q",
        "--minAlignmentQuality",
        metavar="min. alignment quality",
        dest="minAlignmentQuality",
        help="minimum alignment quality (default: 10)",
        type=int,
        default=10,
    )
    extract.add_argument(
        "-m",
        "--minReadLength",
        metavar="min. read length",
        dest="minReadLength",
        help="minimum read length (default: 0)",
        type=int,
        default=0,
    )
    extract.add_argument(
        "-x",
        "--maxReadLength",
        metavar="max. read length",
        dest="maxReadLength",
        help="maximum read length (default: 500)",
        type=int,
        default=500,
    )
    extract.add_argument(
        "-l",
        "--maxReadInterval",
        metavar="max. read interval",
        dest="maxReadIntervalLength",
        help="maximum read interval length (default: 10000)",
        type=int,
        default=10000,
    )
    extract.add_argument(
        "--primary",
        dest="primary",
        help="flag to use only primary positions of multimapping reads",
        action="store_true",
    )
    extract.add_argument(
        "-c",
        "--cores",
        dest="cores",
        metavar="cpus",
        help="Number of cores to use for alignment parsing (default: 5)",
        default=5,
        type=int,
    )
    extract.add_argument(
        "-f",
        "--chrom",
        metavar="chromosomes list",
        dest="chromFile",
        help="Extract crosslink sites only from chromosomes given in this file (one chromosome per line, default: None)",
        type=str,
        default=None,
    )
    extract.add_argument(
        "-t",
        "--tmp",
        dest="tmp",
        metavar="tmp",
        help='Path to create and store temp files (default behavior: use folder from "--output" parameter)',
        default=None,
        type=str,
    )
    _add_verbose_argument(extract, loglevels)

    # ____________________ [Counting] ___________________
    # count
    chelp = "count: counts the number of crosslink/deletion/insertion sites"
    count = subps.add_parser(
        "count", description=chelp, formatter_class=argparse.RawTextHelpFormatter
    )
    count.add_argument(
        "-i",
        "--input",
        metavar="input bed",
        dest="input",
        help='extracted crosslink, insertion or deletion sites (.bed[.gz]), see "{} extract -h"'.format(
            prog
        ),
        required=True,
    )
    count.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="output count file (.txt[.gz], default: print to console)",
        default=None,
        type=str,
    )
    count.add_argument(
        "-a",
        "--ann",
        metavar="annotation",
        dest="annotation",
        help='''flattened annotation file (.bed[.gz]) 
    See "{0} annotation -h" OR sliding window annotation file (.bed[.gz]), see "{0} createSlidingWindows -h"'''.format(
            prog
        ),
        required=True,
    )
    count.add_argument(
        "--unstranded",
        dest="unstranded",
        help="""crosslink site counting is strand specific by default. 
    Use this flag for non strand specific crosslink site counting""",
        action="store_true",
    )
    count.add_argument(
        "--copy_tmp",
        dest="cpTmp",
        help="""In certain cases, gzip crashes while running "htseq-clip count" with a combination of Slurm and Snakemake.
    Copying files to the local temp. folder seems to get rid of the issue. Use this flag to copy files to a tmp. folder. 
    Default: use system specific "tmp" folder, use argument "--tmp" to specify a custom one""",
        action="store_true",
    )
    count.add_argument(
        "-t",
        "--tmp",
        metavar="temp. directory",
        dest="tmp",
        help="temp. directory path to copy files (default: None, use system tmp directory)",
        default=None,
        type=str,
    )
    _add_verbose_argument(count, loglevels)

    # ____________________ [Helpers] ___________________
    # createMatrix
    cmhelp = "createMatrix: create R friendly output matrix file from count function output files"
    createMatrix = subps.add_parser(
        "createMatrix",
        description=cmhelp,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    _add_matrix_arguments(createMatrix, prog, loglevels)

    # createMaxCountMatrix
    maxcounthelp = 'createMaxCountMatrix: create R friendly matrix from `crosslink_count_position_max` column in  "count" function output files'
    createMaxCountMatrix = subps.add_parser(
        "createMaxCountMatrix",
        description=maxcounthelp,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    _add_matrix_arguments(createMaxCountMatrix, prog, loglevels)

    # trimAnnotation
    trimAnnhelp = "trimAnnotation: trim down large annotation file based on output from 'createMatrix' "
    trimAnnotation = subps.add_parser(
        "trimAnnotation",
        description=trimAnnhelp,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    trimAnnotation.add_argument(
        "-i",
        "--matrix",
        dest="matrix",
        metavar="crosslink matrix",
        help="Crosslink count matrix, output from the function 'createMatrix'",
        required=True,
    )
    trimAnnotation.add_argument(
        "-a",
        "--annotation",
        dest="annotation",
        metavar="annotation",
        help="Annotation file, output from the function 'mapToId'",
        required=True,
    )
    trimAnnotation.add_argument(
        "-o",
        "--output",
        metavar="output file",
        dest="output",
        help="output trimmed annotations to file (.txt[.gz], default: print to console)",
        default=None,
        type=str,
    )
    header_help = "First row in the annotation file is assumed to be header by default. Use this flag if the first row in annotation file is not a header"
    trimAnnotation.add_argument(
        "--no_header", dest="noheader", help=header_help, action="store_true"
    )
    _add_verbose_argument(trimAnnotation, loglevels)

    matrix_parsers = {
        "createMatrix": createMatrix,
        "createMaxCountMatrix": createMaxCountMatrix,
    }
    return parser, matrix_parsers


def _require_prefix_or_postfix(args, subparser):
    """
    Reject a matrix command that was given neither ``--prefix`` nor ``--postfix``.

    Args:
        args: parsed arguments with ``prefix`` and ``postfix`` attributes.
        subparser: the subcommand's parser, whose help is printed on failure.

    Raises:
        argparse.ArgumentTypeError: if both ``args.prefix`` and ``args.postfix`` are empty.
    """
    if args.prefix == "" and args.postfix == "":
        # stderr, so the help text never mixes with results written to stdout
        subparser.print_help(sys.stderr)
        raise argparse.ArgumentTypeError(
            'Input values for both arguments "--prefix" and "--postfix" cannot be empty! Either one of the values MUST be given'
        )


def _configure_logging(level):
    """
    Configure the module level root ``logger`` for the ``--verbose`` choice *level*.

    "quiet" only attaches a NullHandler. Any other level name sets the logger
    level and replaces all existing handlers with a single stderr handler at
    that level.
    """
    if level == "quiet":
        logger.addHandler(logging.NullHandler())
        return
    numeric_level = logging.getLevelName(level.upper())
    logger.setLevel(numeric_level)
    # drop previously attached handlers so messages are not emitted twice
    logger.handlers = []
    consHandle = logging.StreamHandler(sys.stderr)
    consHandle.setLevel(numeric_level)
    consHandle.setFormatter(logging.Formatter(" [%(levelname)s]  %(message)s"))
    logger.addHandler(consHandle)


def main():
    """
    Command line entry point of htseq-clip.

    Parses the command line, configures logging and runs the selected
    function. Exits with status 0 on success and 1 when no function is
    given, on keyboard interrupt, or when the function raises an exception
    (whose traceback is printed to stderr).

    Raises:
        SystemExit: always; the exit status is as described above, or the
            status argparse uses for invalid command lines.
    """
    prog = "htseq-clip"
    parser, matrix_parsers = _build_parser(prog)
    # function name (positional argument) -> implementation
    commands = {
        "annotation": _annotation,
        "createSlidingWindows": _createSlidingWindows,
        "mapToId": _mapToId,
        "extract": _extract,
        "count": _count,
        "createMatrix": _countMatrix,
        "createMaxCountMatrix": _maxCountMatrix,
        "trimAnnotation": _trimAnnotation,
    }

    # Now read in arguments and process
    try:
        args = parser.parse_args()
        if args.subparser is None:
            parser.print_help(sys.stderr)
            sys.exit(1)
        _configure_logging(args.log)
        logging.info(
            "run started at {}".format(datetime.now().strftime("%Y-%m-%d %H:%M"))
        )
        if args.subparser in matrix_parsers:
            # show the help of the subcommand that was actually invoked
            _require_prefix_or_postfix(args, matrix_parsers[args.subparser])
        command = commands.get(args.subparser)
        if command is None:
            raise NotImplementedError(f"Function {args.subparser} not implemented!")
        command(args)
        logging.info(
            "run completed at {}".format(datetime.now().strftime("%Y-%m-%d %H:%M"))
        )
    except KeyboardInterrupt:
        sys.stderr.write("Keyboard interrupt... good bye\n")
        sys.exit(1)
    except Exception:
        # stderr, not stdout: stdout carries the results whenever no output
        # file is given, and a traceback there would corrupt piped output
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    # Allow running this module directly as a script; main() always ends in sys.exit.
    main()
