import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
import logging
from .logging_configuration import QueueLogger, configure_main_logger
from .editor import MultiProcessEditor
from .models import Gff, HaplotypeTable
from .modifications.annotate import (HaplotypePosition,
                                     PairwiseAlignmentAnnotation,
                                     EffectAnnotation,
                                     AddGuideFilter,
                                     ProteinPrediction,
                                     AddHaploTypeName)
from .modifications.collapse import Collapse
from .modifications.aggregate import LocusAggregation
from .modifications.discretize import Discretize
from .modifications.modification import WriteOperation
from .plotting import VariationRangePlot
from math import inf
from textwrap import dedent
from copy import deepcopy

# Predefined nuclease settings, selectable with --cas_protein. When no explicit
# --cas_offset is given, the 'offset' of the chosen nuclease is copied into
# parsed_args.cas_offset by parse_args.
# NOTE(review): 17 presumably positions the Cas9 cut site relative to the gRNA
# binding site — confirm.
NUCLEASE_CONFIG = {'CAS9': {'offset': 17}}


def parse_args(args):
    """Parse the command line and enforce the constraints between options.

    On top of plain argparse parsing this function:

    * requires exactly one of ``--gene_annotation`` and
      ``--disable_protein_prediction``;
    * requires ``--gRNAs`` when protein prediction is disabled;
    * resolves ``--cas_protein`` into ``cas_offset`` using ``NUCLEASE_CONFIG``
      (removing the ``cas_protein`` attribute in that case);
    * requires a gRNA .gff and a Cas offset to be given together;
    * sets ``logging_level`` from ``--debug`` and, when not debugging,
      suppresses tracebacks process-wide via ``sys.tracebacklimit``.

    :param args: list of command-line arguments, without the program name.
    :return: the validated :class:`argparse.Namespace`.
    :raises ValueError: if the combination of options is inconsistent.
    """
    parser = get_arg_parser()
    parsed_args = parser.parse_args(args)
    if not parsed_args.local_gff_file and not parsed_args.disable_protein_prediction:
        raise ValueError("Please provide an annotation .gff or use --disable_protein_prediction.")
    if parsed_args.local_gff_file and parsed_args.disable_protein_prediction:
        raise ValueError(("--gene_annotation and --disable_protein_prediction "
                          "are mutually exclusive."))
    if parsed_args.disable_protein_prediction and not parsed_args.gRNAs:
        raise ValueError("--disable_protein_prediction requires --gRNAs.")
    # Compare against None rather than relying on truthiness: an explicit
    # ``--cas_offset 0`` is a real offset and must not be treated as missing.
    if parsed_args.cas_offset is None and parsed_args.cas_protein:
        parsed_args.cas_offset = NUCLEASE_CONFIG[parsed_args.cas_protein]['offset']
        delattr(parsed_args, "cas_protein")
    # A Cas offset is only meaningful together with a gRNA .gff, and vice versa.
    if (parsed_args.cas_offset is not None) != bool(parsed_args.gRNAs):
        raise ValueError("A gRNAs .gff needs to be specified together with "
                         "'--cas_offset' or '--cas_protein'.")
    if parsed_args.debug:
        parsed_args.logging_level = logging.DEBUG
    else:
        parsed_args.logging_level = logging.INFO
        sys.tracebacklimit = 0  # Suppress traceback information on errors.
    return parsed_args


def get_arg_parser():
    """Build the argument parser for the ``effect-prediction`` command.

    Options are organised in argument groups: input/output, gRNA information,
    filtering, system resources, alignment, discrete calls and protein effect
    prediction. Constraints between options that argparse cannot express are
    checked afterwards in ``parse_args``.

    :return: the configured :class:`argparse.ArgumentParser`.
    """
    parser = ArgumentParser("effect-prediction")
    input_output_group = parser.add_argument_group(title='Input and output information')

    input_output_group.add_argument("frequency_table",
                                    type=Path,
                                    help='Tab-delimited file containing haplotype '
                                         'frequencies per locus, as generated by '
                                         'SMAP haplotype-window.')
    input_output_group.add_argument("genome",
                                    type=Path,
                                    help='Reference genome sequence (in .fasta format) '
                                         'onto which the .fastq files were mapped.')
    input_output_group.add_argument("borders",
                                    type=Path,
                                    help='.gff file containing the border '
                                         'coordinates, must contain NAME=<> in '
                                         'column 9.')
    input_output_group.add_argument("-a", "--gene_annotation",
                                    dest="local_gff_file",
                                    type=str,
                                    help='.gff file containing the gene and '
                                         'CDS annotation. Must be provided unless you use '
                                         '--disable_protein_prediction.')
    input_output_group.add_argument("--debug", help="Enable verbose logging.", action="store_true")

    gRNA_group = parser.add_argument_group(title='gRNA information')
    gRNA_group.add_argument("-u", "--gRNAs",
                            type=Path,
                            help='.gff file containing the gRNA '
                                 'coordinates, must contain NAME=<> in '
                                 'column 9.')
    # store_false: the attribute is True by default and passing the flag
    # switches it off.
    gRNA_group.add_argument("-g", "--no_gRNA_relative_naming",
                            help='Disable gRNA-relative naming (enabled by default). '
                                 'Only has an effect together with --gRNAs.',
                            default=True,
                            action='store_false')
    cas_threshold = gRNA_group.add_mutually_exclusive_group(required=False)
    cas_threshold.add_argument("-p", "--cas_protein",
                               type=str,
                               choices=NUCLEASE_CONFIG.keys(),
                               default=None,
                               help="Name of the nuclease used in the experiment. "
                                    "Used to select a predefined offset "
                                    "(alternative to --cas_offset).")
    cas_threshold.add_argument("-f", "--cas_offset",
                               type=int,
                               help="Cas offset to use together with --gRNAs "
                                    "(alternative to selecting a predefined "
                                    "offset with --cas_protein).")
    filtering_group = parser.add_argument_group("Filtering parameters")
    filtering_group.add_argument("-s", "--cut_site_range_upper_bound",
                                 dest="cut_site_range",
                                 type=int,
                                 default=inf,
                                 help='Upper bound for selecting variations from a range '
                                      'around the cut site. Defined in the direction from '
                                      'the cut-site towards the PAM.')
    filtering_group.add_argument("-r", "--cut_site_range_lower_bound",
                                 dest="cut_site_range_lower",
                                 type=int,
                                 default=-inf,
                                 help='Lower bound for selecting variations from a range '
                                      'around the cut site. Defined in the direction from '
                                      'the cut-site towards the start of the gRNA binding site.')
    resources_group = parser.add_argument_group(title='System resources')
    resources_group.add_argument("-c", "--cpu",
                                 dest="cpu",
                                 help="Maximum number of allowed processes.",
                                 type=int,
                                 default=1)
    mapping_group = parser.add_argument_group("Alignment parameters",
                                              description="Define the parameters to "
                                                          "align the haplotype sequences to "
                                                          "the reference sequence.")
    mapping_group.add_argument("--match_score",
                               type=int,
                               default=1,
                               help="Score for a matching position "
                                    "(default: %(default)s).")
    mapping_group.add_argument("--mismatch_penalty",
                               type=int,
                               default=-100,
                               help="Penalty for a mismatching position "
                                    "(default: %(default)s).")
    mapping_group.add_argument("--gap_open_penalty",
                               type=int,
                               default=-100,
                               help="Penalty for opening a gap "
                                    "(default: %(default)s).")
    mapping_group.add_argument("--gap_extension",
                               type=int,
                               default=-10,
                               help="Penalty for extending a gap "
                                    "(default: %(default)s).")
    discrete_calls_group = parser.add_argument_group(
        title='Discrete calls options',
        description=('Use thresholds to transform '
                     'haplotype frequencies into discrete calls '
                     'using fixed intervals. '
                     'The assigned intervals are indicated '
                     'by a running integer. This is only '
                     'informative for individual samples '
                     'and not for Pool-Seq data.')
    )
    discrete_calls_group.add_argument(
        '-e', '--discrete_calls',
        choices=['dominant', 'dosage'],
        dest='discrete_calls',
        default=None,
        help=('Set to "dominant" to transform haplotype frequency values '
              'into presence(1)/absence(0) calls per allele, or "dosage" '
              'to indicate the allele copy number.'))
    discrete_calls_group.add_argument(
        '-i', '--frequency_interval_bounds',
        nargs='+',
        dest='frequency_bounds',
        help=('Frequency interval bounds for transforming haplotype '
              'frequencies into discrete calls. Custom thresholds can be '
              'defined by passing one or more space-separated values (relative '
              'frequencies in percentage). For dominant calling, one value '
              'should be specified. For dosage calling, an even total number '
              'of four or more thresholds should be specified. Default values '
              'are invoked by passing either "diploid" or "tetraploid". The '
              'default value for dominant calling (see discrete_calls '
              'argument) is 10, both for "diploid" and "tetraploid". For '
              'dosage calling, the default for diploids is "10, 10, 90, 90" '
              'and for tetraploids "12.5, 12.5, 37.5, 37.5, 62.5, 62.5, 87.5, '
              '87.5".'))
    protein_effect_group = parser.add_argument_group('Protein effect prediction')
    protein_effect_group.add_argument('--disable_protein_prediction',
                                      action='store_true',
                                      help=("Disable the estimation of the protein from "
                                            "the haplotypes sequences. All variations "
                                            "within range (-s and -r) of the cut-site will be "
                                            "considered as relevant sequence variant with an "
                                            "effect. This option requires --gRNAs."))
    # '%%' is required: argparse %-formats help strings.
    protein_effect_group.add_argument('-t', '--effect_threshold',
                                      type=float,
                                      help=("Threshold to determine whether a protein is affected "
                                            "by the haplotype variant sequence or not. For each "
                                            "haplotype, a protein identity score is calculated "
                                            "compared to the reference. Haplotypes for which the "
                                            "protein identity is below the effect threshold, "
                                            "will be marked as encoding an affected protein. "
                                            "For instance, a protein with 10%% identity to the "
                                            "reference, is below an effect threshold of 50%%, "
                                            "and will be marked as loss-of-function (LOF)."))
    return parser


def set_default_frequency_thresholds(parsed_args: Namespace) -> Namespace:
    """Resolve and validate the frequency interval bounds for discrete calling.

    Does nothing unless ``parsed_args.discrete_calls`` is set ('dominant' or
    'dosage'). If the first element of ``parsed_args.frequency_bounds`` is a
    preset keyword ('diploid' or 'tetraploid'), the preset thresholds for the
    calling mode are used. Otherwise the values are treated as custom
    thresholds and their count is checked:

    * dominant: exactly one threshold;
    * dosage: an even number of thresholds, at least four.

    On success ``frequency_bounds`` is replaced by a list of floats.

    :param parsed_args: namespace with ``discrete_calls`` and
        ``frequency_bounds``; modified in place.
    :return: the same namespace.
    :raises ValueError: if no bounds are given, if the number of custom
        thresholds does not fit the calling mode, or if a custom threshold is
        neither a number nor a preset keyword.
    """
    if not parsed_args.discrete_calls:
        return parsed_args

    default_thresholds_options = {
        'dominant': {
            'diploid': [10],
            'tetraploid': [10],
        },
        'dosage': {
            'diploid': [10, 10, 90, 90],
            'tetraploid': [12.5, 12.5, 37.5, 37.5, 62.5, 62.5, 87.5, 87.5],
        },
    }
    if not parsed_args.frequency_bounds:
        raise ValueError('If discrete calling is enabled, please define '
                         'the interval bounds using the '
                         '--frequency_interval_bounds (-i) parameter '
                         '(see --help for more information).')
    defaults_for_type = default_thresholds_options[parsed_args.discrete_calls]
    keyword = parsed_args.frequency_bounds[0]
    if keyword in defaults_for_type:
        # Keyword is used to define thresholds
        bounds = defaults_for_type[keyword]
    else:
        # User has chosen to define own thresholds
        manual_threshold_conditions = {
            'dominant': ('1 threshold', lambda x: len(x) == 1),
            'dosage': ('Even number of thresholds (at least 4)',
                       lambda x: len(x) >= 4 and len(x) % 2 == 0),
        }
        wording, condition = manual_threshold_conditions[parsed_args.discrete_calls]
        if not condition(parsed_args.frequency_bounds):
            raise ValueError('If setting the thresholds manually in '
                             f'{parsed_args.discrete_calls} mode, '
                             'the thresholds must adhere to the '
                             f'following condition: {wording}')
        bounds = parsed_args.frequency_bounds

    try:
        parsed_args.frequency_bounds = [float(i) for i in bounds]
    except ValueError as err:
        # E.g. an unknown keyword such as 'triploid' passed in dominant mode.
        presets = ' or '.join(repr(key) for key in defaults_for_type)
        raise ValueError('Frequency interval bounds must be numbers or one of '
                         f'the presets {presets}, got: '
                         f'{parsed_args.frequency_bounds}') from err
    return parsed_args


def log_args(parsed_args, logger):
    """Write the run configuration to *logger* as a single INFO record.

    The attributes of *parsed_args* are substituted by name into a fixed,
    sectioned template. Extra attributes are ignored; a missing one raises
    KeyError.

    :param parsed_args: parsed command-line namespace (see ``parse_args``).
    :param logger: logger that receives the formatted summary.
    """
    template = dedent("""
    Running SMAP effect predictor using the following options:

    Input & output:
        Frequency table: {frequency_table}
        Genome: {genome}
        Borders gff: {borders}
        Annotation gff: {local_gff_file}
        gRNAs gff: {gRNAs}
        Cas offset: {cas_offset}

    Alignment parameters:
        Match score: {match_score}
        Mismatch penalty: {mismatch_penalty}
        Gap open penalty: {gap_open_penalty}
        Gap extension penalty: {gap_extension}

    Discrete calls options:
        Discrete call mode: {discrete_calls}
        Frequency bounds: {frequency_bounds}

    Filtering options:
        Cut site range upper bound: {cut_site_range}
        Cut site range lower bound: {cut_site_range_lower}

    Protein effect prediction:
        Protein effect threshold: {effect_threshold}

    System resources:
        Number of processes: {cpu}
    """)
    option_values = vars(parsed_args)
    summary = template.format_map(option_values)
    logger.info(summary)


def main(args: List[str]) -> None:
    """Run the SMAP effect-prediction pipeline from the command line.

    Parses and validates *args*, logs the configuration, reads the haplotype
    frequency table and runs it through groups of modifications on a
    :class:`MultiProcessEditor` using ``--cpu`` processes:

    1. annotations (haplotype position, pairwise alignment, gRNA filtering or
       haplotype naming, optional protein prediction, effect annotation),
       written to ``annotate.tsv``;
    2. collapse, applied to a deep copy of the annotated table and written to
       ``collapsed.tsv``;
    3. locus aggregation (``LocusAggregation('Effect')``), written to
       ``aggregated.tsv``;
    4. optionally, discrete calling on the table returned by step 3, written
       to ``discretized.tsv``.

    All output files use fixed names relative to the current working directory.

    :param args: command-line arguments, without the program name.
    :raises ValueError: propagated from argument validation in ``parse_args``
        and ``set_default_frequency_thresholds``.
    """
    parsed_args = parse_args(args)
    logger = configure_main_logger(parsed_args.logging_level)
    log_args(parsed_args, logger)
    # NOTE(review): QueueLogger presumably relays log records from the worker
    # processes started by MultiProcessEditor through ``logging_queue`` —
    # confirm in logging_configuration.
    with QueueLogger() as queue_logger:
        logging_queue = queue_logger.logging_queue
        logger.info('SMAP effect-predictor started.')
        # Resolves preset keywords / validates custom thresholds; raises
        # ValueError on invalid input.
        parsed_args = set_default_frequency_thresholds(parsed_args)
        haplotype_table = HaplotypeTable.read_smap_output(parsed_args.frequency_table)
        editor = MultiProcessEditor(logging_queue, parsed_args.cpu)
        # 'discretized.tsv' is opened in 'w' mode (and therefore created or
        # truncated) even when discrete calling is disabled, which leaves an
        # empty file behind in that case.
        with parsed_args.borders.open('r') as borders, \
             parsed_args.genome.open('r') as genome, \
             Path('annotate.tsv').open('w') as annotated_table, \
             Path('aggregated.tsv').open('w') as aggregated_table, \
             Path('discretized.tsv').open('w') as discritized_table, \
             Path('collapsed.tsv').open('w') as collapsed_table:
            # Rebinds ``borders`` from the open file handle to the parsed Gff.
            borders = Gff.read_file(borders)

            # Annotations
            # NOTE(review): the same open ``genome`` handle is also passed to
            # ProteinPrediction below; if either consumer reads the stream to
            # the end, the other may see EOF — confirm.
            haplo_modification = HaplotypePosition(borders, genome)
            pairwise_modification = PairwiseAlignmentAnnotation(parsed_args.match_score,
                                                                parsed_args.mismatch_penalty,
                                                                parsed_args.gap_open_penalty,
                                                                parsed_args.gap_extension)
            effect_modification = EffectAnnotation('pairwiseProteinIdentity (%)',
                                                   parsed_args.effect_threshold)
            annotated_write = WriteOperation(annotated_table)
            aggregation = LocusAggregation('Effect')
            aggregated_write = WriteOperation(aggregated_table)
            # Base pipeline; the optional steps below are spliced in with
            # insert() at fixed positions, so the order of these blocks matters.
            all_annotation_operations = [haplo_modification,
                                         pairwise_modification,
                                         effect_modification,
                                         annotated_write]
            if not parsed_args.disable_protein_prediction:
                local_gff = str(parsed_args.local_gff_file)
                add_protein_modification = ProteinPrediction(local_gff, genome,
                                                             parsed_args.cut_site_range_lower,
                                                             parsed_args.cut_site_range,
                                                             with_gRNAs=bool(parsed_args.gRNAs))
                all_annotation_operations.insert(2, add_protein_modification)

            if parsed_args.gRNAs:
                logger.info('gRNAs .gff file passed, adding gRNA information '
                            'and enabling filtering based on cut-site.')
                with parsed_args.gRNAs.open('r') as gRNA_file:
                    gRNAs = Gff.read_file(gRNA_file)
                    add_gRNA_modification = AddGuideFilter(gRNAs,
                                                           parsed_args.cas_offset,
                                                           parsed_args.cut_site_range_lower,
                                                           parsed_args.cut_site_range,
                                                           parsed_args.no_gRNA_relative_naming)
                    all_annotation_operations.insert(2, add_gRNA_modification)
                    range_plot = VariationRangePlot()
                    all_annotation_operations.insert(3, range_plot)
            else:
                # Without gRNAs, haplotype naming is done by AddHaploTypeName.
                add_gRNA_modification = AddHaploTypeName()
                all_annotation_operations.insert(2, add_gRNA_modification)

            # Resulting annotation order:
            #   gRNAs + protein prediction:
            #     position, alignment, AddGuideFilter, VariationRangePlot,
            #     ProteinPrediction, EffectAnnotation, write
            #   protein prediction, no gRNAs:
            #     position, alignment, AddHaploTypeName, ProteinPrediction,
            #     EffectAnnotation, write
            #   gRNAs, protein prediction disabled:
            #     position, alignment, AddGuideFilter, VariationRangePlot,
            #     EffectAnnotation, write

            # Aggregation
            all_aggregation_operations = (aggregation,
                                          aggregated_write)

            # Discrete calling
            # Stays an empty list (skipped in the loop below) when disabled.
            all_discrete_calling_operations = []
            if parsed_args.discrete_calls:
                logger.info('Requested discrete calling with mode %s and frequency bounds %s.',
                            parsed_args.discrete_calls, parsed_args.frequency_bounds)
                discretize_modification = Discretize(parsed_args.discrete_calls,
                                                     parsed_args.frequency_bounds)
                write_op = WriteOperation(discritized_table)
                all_discrete_calling_operations = [discretize_modification, write_op]

            # Collapse
            collapse_modification = Collapse(parsed_args.cut_site_range_lower,
                                             parsed_args.cut_site_range)
            collapse_write = WriteOperation(collapsed_table)
            all_collapse_operations = [collapse_modification,
                                       collapse_write]

            # Perform annotations
            editor.queue_modification(all_annotation_operations)
            haplotype_table = editor.edit(haplotype_table)
            # Collapse works on a deep copy so its edits cannot leak into the
            # table used for aggregation/discretization below; the return
            # value of this edit is discarded (output goes via collapse_write).
            to_be_collapsed = deepcopy(haplotype_table)

            editor.queue_modification(all_collapse_operations)
            editor.edit(to_be_collapsed)

            # Aggregation first, then discrete calling on the table returned by
            # the aggregation edit.
            for modification_group in (all_aggregation_operations,
                                       all_discrete_calling_operations):
                if modification_group:
                    editor.queue_modification(modification_group)
                    haplotype_table = editor.edit(haplotype_table)
    logger.info('Finished.')
