#!/usr/bin/env python

from __future__ import print_function

from sys import argv
import sys
import operator
from collections import defaultdict

import numpy as np
from scipy.stats import poisson, rankdata, spearmanr, pearsonr

from epic2.main import _main

from epic2.version import __version__
from epic2.src.find_islands import differential_count_reads_on_islands

from collections import OrderedDict
from natsort import natsorted
import argparse
import os

import pyranges as pr
import pandas as pd

# The input/output options are mandatory for a normal run, but they must not
# be demanded when the user only asks for the version or the example command.
required_args = not any(
    flag in sys.argv for flag in ("--version", "-v", "-ex", "--example"))

# Command-line interface of epic2-df; the program name shown in usage and
# error messages is the basename of this script file.
parser = argparse.ArgumentParser(
    prog=os.path.basename(__file__),
    description=(
        "epic2-df, version: {}\n"
        "(Visit github.com/endrebak/epic2 for examples and help. "
        "Run epic2-example for a simple example command.)\n"
        "    ").format(__version__))

# Input libraries, one or more files each. The two treatment options are
# mandatory for a normal run (see required_args); the controls are optional.
parser.add_argument(
    "--treatment-knockout",
    "-tk",
    nargs="+",
    type=str,
    required=required_args,
    help=("Treatment (pull-down) file(s) for knockout in one of these "
          "formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. "
          "Mixing file formats is allowed."))

parser.add_argument(
    "--control-knockout",
    "-ck",
    nargs="+",
    type=str,
    help=("Control (input) file(s) for knockout in one of these "
          "formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. "
          "Mixing file formats is allowed."))

parser.add_argument(
    "--treatment-wildtype",
    "-tw",
    nargs="+",
    type=str,
    required=required_args,
    help=("Treatment (pull-down) file(s) for wildtype in one of these "
          "formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. "
          "Mixing file formats is allowed."))

parser.add_argument(
    "--control-wildtype",
    "-cw",
    nargs="+",
    type=str,
    help=("Control (input) file(s) for wildtype in one of these "
          "formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. "
          "Mixing file formats is allowed."))

# Genome selection and island-calling parameters; all of them are optional.
parser.add_argument(
    "--genome",
    "-gn",
    type=str,
    default="hg19",
    help=("Which genome to analyze. Default: hg19. If --chromsizes and --egf "
          "flag is given, --genome is not required."))

parser.add_argument(
    "--keep-duplicates",
    "-kd",
    action="store_true",
    help="""Keep reads mapping to the same position on the same strand within a library. Default: False.
                   """)

parser.add_argument(
    "--original-algorithm",
    "-oa",
    action="store_true",
    help=("Use the original SICER algorithm, without the epic2 fix. This will "
          "use all reads in your files to compute the p-values, including "
          "those falling outside the genome boundaries."))

parser.add_argument(
    "--bin-size",
    "-bin",
    type=int,
    default=200,
    help="""Size of the windows to scan the genome. BIN-SIZE is the smallest possible island. Default 200.
                   """)

parser.add_argument(
    "--gaps-allowed",
    "-g",
    type=int,
    default=3,
    help="""This number is multiplied by the window size to determine the number of gaps
                   (ineligible windows) allowed between two eligible windows.
                   Must be an integer. Default: 3. """)

parser.add_argument(
    "--fragment-size",
    "-fs",
    type=int,
    default=150,
    help=("(Single end reads only) Size of the sequenced fragment. Each read "
          "is extended half the fragment size from the 5' end. Default 150 "
          "(i.e. extend by 75)."))

parser.add_argument(
    "--false-discovery-rate-cutoff",
    "-fdr",
    type=float,
    default=0.05,
    help="""Remove all islands with an FDR above cutoff. Default 0.05.
                   """)

parser.add_argument(
    "--false-discovery-rate-comparison",
    "-fdrc",
    type=float,
    default=0.05,
    help=("Remove all merged islands with an FDR above this value for both "
          "up- and downregulation. Default 0.05."))

parser.add_argument(
    "--effective-genome-fraction",
    "-egf",
    type=float,
    help=("Use a different effective genome fraction than the one included "
          "in epic2. The default value depends on the genome and readlength, "
          "but is a number between 0 and 1."))

parser.add_argument(
    "--chromsizes",
    "-cs",
    type=str,
    help=("Set the chromosome lengths yourself in a file with two columns: "
          "chromosome names and sizes. Useful to analyze custom genomes, "
          "assemblies or simulated data. Only chromosomes included in the "
          "file will be analyzed."))

parser.add_argument(
    "--e-value",
    "-e",
    type=int,
    default=1000,
    help=("The E-value controls the genome-wide error rate of identified "
          "islands under the random background assumption. Should be used "
          "when not using a control library. Default: 1000."))

# Read filters documented as applying to bam input only, followed by a switch
# that the help text describes as a debugging aid. All are optional.
parser.add_argument(
    "--required-flag",
    "-f",
    type=int,
    default=0,
    help="""(bam only.) Keep reads with these bits set in flag. Same as `samtools view -f`. Default 0
                   """)

parser.add_argument(
    "--filter-flag",
    "-F",
    type=int,
    default=1540,
    help="""(bam only.) Discard reads with these bits set in flag. Same as `samtools view -F`. Default 1540 (hex: 0x604).
See https://broadinstitute.github.io/picard/explain-flags.html for more info.
                   """)

parser.add_argument(
    "--mapq",
    "-m",
    type=int,
    default=5,
    help="""(bam only.) Discard reads with mapping quality lower than this. Default 5.
                   """)

# The help names the companion option by its full, declared spelling
# (--discard-chromosomes-pattern) so users can find it in this same listing.
parser.add_argument(
    "--autodetect-chroms",
    "-a",
    action="store_true",
    help=("(bam only.) Autodetect chromosomes from bam file. Use with "
          "--discard-chromosomes-pattern flag to avoid non-canonical "
          "chromosomes."))

parser.add_argument(
    "--discard-chromosomes-pattern",
    "-d",
    type=str,
    default="_",
    help="""(bam only.) Discard reads from chromosomes matching this pattern. Default
                   '_'. Note that if you are not interested in the results from
                   non-canonical chromosomes, you should ensure they are
                   removed with this flag, otherwise they will make the
                   statistical analysis too stringent.""")

# Long option only; there is no short alias for this one.
parser.add_argument(
    "--original-statistics",
    action="store_true",
    help=("Use the original SICER way of computing the statistics. Like SICER "
          "itself, this method raises an error on large datasets. Only "
          "included for debugging-purposes."))

# Output destinations (mandatory for a normal run, see required_args) and
# the informational switches.
parser.add_argument(
    "--output-knockout",
    "-ok",
    type=str,
    required=required_args,
    help="File to write knockout results to.")

parser.add_argument(
    "--output-wildtype",
    "-ow",
    type=str,
    required=required_args,
    help="File to write wildtype results to.")

parser.add_argument(
    "--quiet",
    "-q",
    action="store_true",
    help="Do not write output messages to stderr.")

parser.add_argument(
    "--example",
    "-ex",
    action="store_true",
    help="Show the paths of the example data and an example command.")

# argparse's built-in "version" action prints the version and exits.
parser.add_argument("--version", "-v", action="version", version=__version__)

# Scalar (one island at a time) form of the vectorized `_pvalue` defined
# below. It is not executed; it is kept only as a readable reference.
#
# def _pvalue(chip_read_count, control_read_count, scaling_factor):
#     if control_read_count > 0:
#         average = control_read_count * scaling_factor
#     else:
#         average = 1 * scaling_factor

#     if chip_read_count > average:
#         pvalue = poisson.sf(chip_read_count, average)
#     else:
#         pvalue = 1

#     return pvalue


def _pvalue(chip, control, scaling_factor):
    """Return one-sided Poisson p-values for enrichment of chip over control.

    For every position the expected chip count is ``control * scaling_factor``;
    where the control count is zero a pseudo-count of one read is used, i.e.
    the expectation is ``scaling_factor``. When the chip count exceeds the
    expectation the p-value is the Poisson survival function
    ``poisson.sf(chip, expected)``; otherwise it is 1.

    The computation is positional, so the inputs may carry any index (or be
    plain arrays/lists). Counts are expected to be non-negative.

    Parameters
    ----------
    chip : array-like or pandas.Series
        Read counts of the library tested for enrichment.
    control : array-like or pandas.Series
        Read counts of the library used as background, same length as `chip`.
    scaling_factor : float
        Factor converting control counts to the depth of the chip library
        (chip library size divided by control library size).

    Returns
    -------
    pandas.Series
        Float p-values, indexed like `chip` when `chip` is a Series and with
        a default RangeIndex otherwise.
    """
    # Remember the caller's labels before dropping to positional arrays.
    index = getattr(chip, "index", None)
    chip = np.asarray(chip)
    control = np.asarray(control)

    # Expected chip count under the null; one pseudo-read where control is 0.
    average = np.where(control > 0, control * scaling_factor,
                       scaling_factor).astype(float)

    # Positions that are not enriched keep a p-value of exactly 1.
    pvalues = np.ones(len(chip), dtype=float)
    enriched = chip > average
    pvalues[enriched] = poisson.sf(chip[enriched], average[enriched])

    return pd.Series(pvalues, index=index)


if __name__ == "__main__":
    # Script entry point: run epic2 on the knockout and on the wildtype
    # libraries, merge the two island sets, count reads of both conditions
    # on the merged islands and test each island for a difference.

    args = vars(parser.parse_args())
    # Passed on to epic2's _main. NOTE(review): presumably this selects the
    # mode in which _main returns the bin counts used below -- confirm in
    # epic2.main.
    args["df"] = True

    if args["example"]:
        # Only print where the packaged example files are and how to run
        # them, then stop.
        import pkg_resources
        treatment = pkg_resources.resource_filename("epic2",
                                                    "examples/test.bed.gz")
        control = pkg_resources.resource_filename("epic2",
                                                  "examples/control.bed.gz")
        print("Knockout: " + treatment)
        print("Wildtype: " + control)
        print(
            "Example command: epic2-df -fdrc 1 -tk {} -tw {} -ok deleteme_ko.txt -ow deleteme_wt.txt > deleteme.txt"
            .format(treatment, control))
        sys.exit(0)

    # Each _main call is handed its libraries through the generic
    # treatment/control/output keys: knockout first, wildtype second.
    args["treatment"] = args["treatment_knockout"]
    args["control"] = args["control_knockout"]
    args["output"] = args["output_knockout"]

    print("Running epic2 on KO.", file=sys.stderr)
    bins_counts_ko = _main(args)

    args["treatment"] = args["treatment_wildtype"]
    args["control"] = args["control_wildtype"]
    args["output"] = args["output_wildtype"]

    print("Running epic2 on WT.", file=sys.stderr)
    bins_counts_wt = _main(args)

    print("Comparing islands.", file=sys.stderr)

    # The values of the returned mappings are pairs whose second item holds
    # the counts; their grand total serves as the library size.
    counts_wildtype = sum(sum(counts) for _, counts in bins_counts_wt.values())
    counts_knockout = sum(sum(counts) for _, counts in bins_counts_ko.values())

    if not counts_wildtype or not counts_knockout:
        # Without reads in both libraries the scaling factor is undefined;
        # fail with a clear message instead of dividing by zero.
        sys.exit("epic2-df: cannot compare the libraries, no reads were "
                 "counted for the knockout and/or the wildtype library.")

    # Read back the two per-condition island files: skip the header line and
    # keep only the first five columns under fixed names.
    read_options = dict(
        sep="\t",
        names=["Chromosome", "Start", "End", "ChIPCount", "Score"],
        usecols=[0, 1, 2, 3, 4],
        skiprows=1,
        dtype={"Chromosome": "category", "Start": int, "End": int},
        header=None)
    ko = pr.PyRanges(pd.read_csv(args["output_knockout"], **read_options))
    wt = pr.PyRanges(pd.read_csv(args["output_wildtype"], **read_options))

    # Merged islands of both conditions, ignoring strand.
    union = ko.set_union(wt, strandedness=False)

    ko_counts = differential_count_reads_on_islands(union.dfs, bins_counts_ko)
    wt_counts = differential_count_reads_on_islands(union.dfs, bins_counts_wt)

    # Concatenate the per-key count arrays in natural sort order of the keys
    # before attaching them to the merged islands as columns.
    ko_counts = pd.concat(
        [pd.Series(arr) for _, arr in natsorted(ko_counts.items())])
    wt_counts = pd.concat(
        [pd.Series(arr) for _, arr in natsorted(wt_counts.items())])

    union.KO = ko_counts
    union.WT = wt_counts

    df = union.df.reset_index(drop=True)

    # Library-size ratio KO/WT. float() forces true division even where "/"
    # on integers would floor (this file only imports print_function from
    # __future__).
    scaling_factor = float(counts_knockout) / counts_wildtype

    # Depth-normalised fold changes with a pseudo-count of one read.
    fc_ko = ((df.KO + 1) / (df.WT + 1)) / scaling_factor
    fc_ko.name = "FC_KO"
    fc_wt = ((df.WT + 1) / (df.KO + 1)) * scaling_factor
    fc_wt.name = "FC_WT"

    # _pvalue multiplies the background counts by the factor it is given, so
    # the factor must be (tested library size / background library size):
    # KO/WT when testing KO against WT, and its inverse when testing WT
    # against KO -- mirroring the fold-change computation above.
    p_ko = _pvalue(df.KO, df.WT, scaling_factor)
    p_ko.name = "P_KO"
    p_wt = _pvalue(df.WT, df.KO, 1 / scaling_factor)
    p_wt.name = "P_WT"

    # Rank-based adjustment p * n / rank, capped at 1.
    fdr_ko = p_ko * len(p_ko) / rankdata(p_ko)
    fdr_wt = p_wt * len(p_wt) / rankdata(p_wt)

    fdr_ko[fdr_ko > 1] = 1
    fdr_ko.name = "FDR_KO"
    fdr_wt[fdr_wt > 1] = 1
    fdr_wt.name = "FDR_WT"

    # Append the statistics as the last columns, in this order.
    for column in [fc_ko, fc_wt, p_ko, p_wt, fdr_ko, fdr_wt]:
        df.insert(df.shape[1], column.name, column)

    # Correlations are computed on all merged islands, before FDR filtering.
    p = pearsonr(df.KO, df.WT)
    s = spearmanr(df.KO, df.WT)

    # Keep islands that pass the comparison FDR in at least one direction.
    diff_fdr = args["false_discovery_rate_comparison"]
    df = df[(df.FDR_KO <= diff_fdr) | (df.FDR_WT <= diff_fdr)]

    # The table goes to stdout; the correlation summary goes to stderr.
    print(df.to_csv(sep="\t", index=False), end="")

    print(
        "Pearson's correlation is:",
        p[0],
        "and the p-value is",
        p[1],
        file=sys.stderr)
    print(
        "Spearman's correlation is:",
        s[0],
        "and the p-value is",
        s[1],
        file=sys.stderr)
