/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AucPR;
import de.jstacs.classifiers.performanceMeasures.AucROC;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DoubleSymbolException;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.FileManager;
import de.jstacs.io.StringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Time;
import de.jstacs.utils.ToolBox;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.zip.GZIPInputStream;
import projects.dream2016.Aggregation_multi2;
import projects.dream2016.DataParser;

public class ExtractGenomeWideScan {
    private static Time time;
    private static AlphabetContainer con;

    private static void fill(String fName, boolean fg, HashMap<String, ArrayList<Region>> region) throws IOException {
        String line;
        BufferedReader r = new BufferedReader(new FileReader(fName));
        while ((line = r.readLine()) != null) {
            Region reg = new Region(line, fg);
            ArrayList<Region> c = region.get(reg.chr);
            if (c == null) {
                c = new ArrayList();
                region.put(reg.chr, c);
            }
            c.add(reg);
        }
        r.close();
    }

    private static int sort(HashMap<String, ArrayList<Region>> region) {
        int windows = -1;
        for (String chr : region.keySet()) {
            ArrayList<Region> list = region.get(chr);
            if (windows < 0) {
                Region r = list.get(0);
                windows = (int)Math.round((Double.parseDouble(r.split[2]) - Double.parseDouble(r.split[1])) / 50.0);
            }
            Collections.sort(list);
        }
        return windows;
    }

    private static DataParser getDataParser(int offset, String[] args, String confOut) throws IllegalArgumentException, IOException, DoubleSymbolException {
        int col = 0;
        LinkedList<String> conf = new LinkedList<String>();
        int f = offset;
        while (f < args.length) {
            if (args[f].endsWith("hg19.fa_other.txt.gz")) {
                conf.add("Percent\t" + col + "," + (col + 1) + "\tEach\tMeanCenter3");
                conf.add("Entropy\t" + ((col += 2) + 2) + "\tEach\tMaxCenter3");
                col += 3;
            } else if (args[f].endsWith("hg19.fa_tracts.txt.gz")) {
                conf.add("Length\t" + col + "," + (col + 1) + "," + (col + 2) + "," + (col + 3) + "\tEach\tMaxCenter3");
                col += 4;
            } else if (args[f].endsWith("gencode.v19.types.txt.gz")) {
                conf.add("Region\t" + col + "\tEach\tMin");
                ++col;
            } else if (args[f].endsWith(".bigwig-interval.txt.gz")) {
                conf.add("Coverage\t" + col + "\tEach\tMinCenter3");
                conf.add("Coverage\t" + (col + 2) + "," + (col + 3) + "\tEach\tMin");
                conf.add("Coverage\t" + (col + 1) + "\tEach\tEach");
                col += 4;
            } else if (args[f].endsWith(".xml_winscores2.txt.gz")) {
                conf.add("Score\t" + col + "\tEach\tMaxCenter3");
                conf.add("Score\t" + (col + 1) + "\tEach\tCenter");
                conf.add("Score\t" + col + "\tEach\tLogSum");
                col += 2;
                col += 2;
                ++col;
            } else if (args[f].endsWith("_winscores2.txt.gz")) {
                conf.add("Score\t" + col + "\tEach\tMax");
                col += 2;
                col += 2;
                ++col;
            } else if (args[f].endsWith("hg19.genome.fa.seqs.gz")) {
                conf.add("Seq\t" + col + "\tEach\tCenter3");
                ++col;
            } else if (args[f].endsWith("expression-interval.txt.gz")) {
                conf.add("Coverage\t" + col + "\tEach\tCenter");
                ++col;
            } else {
                throw new RuntimeException("Unknown file type");
            }
            ++f;
        }
        conf.addFirst("5\t" + col);
        if (confOut != null) {
            PrintWriter confw = new PrintWriter(String.valueOf(confOut) + ".conf");
            confw.println("0\t" + col);
            int c = 1;
            while (c < conf.size()) {
                confw.println((String)conf.get(c));
                ++c;
            }
            confw.close();
        }
        return new DataParser(conf.toArray(new String[0]));
    }

    public static void main(String[] args) throws Exception {
        time = Time.getTimeInstance(null);
        HashMap<String, ArrayList<Region>> region = new HashMap<String, ArrayList<Region>>();
        ExtractGenomeWideScan.fill(args[0], true, region);
        ExtractGenomeWideScan.fill(args[1], false, region);
        int windows = ExtractGenomeWideScan.sort(region);
        System.out.println(windows);
        DataParser pars = ExtractGenomeWideScan.getDataParser(3, args, args[2]);
        ExtractGenomeWideScan.extract(pars, args, region, windows, null, null);
    }

    private static void extract(DataParser pars, String[] args, HashMap<String, ArrayList<Region>> region, int windows, AbstractScoreBasedClassifier cl, ArrayList<String> chrs) throws Exception {
        int f;
        time.reset();
        String outpath = args[2];
        int offset = 3;
        int l = args.length - offset;
        BufferedReader[] reader = new BufferedReader[l];
        int f2 = 0;
        while (f2 < l) {
            GZIPInputStream stream = new GZIPInputStream(new FileInputStream(args[f2 + offset]));
            reader[f2] = new BufferedReader(new InputStreamReader(stream));
            ++f2;
        }
        int c = 0;
        int start = 0;
        int last = -1;
        int a = -1;
        int p = -1000;
        String chr = null;
        Object[] line = new String[l];
        String[] li = new String[l];
        int[] remove = null;
        Object[] split = null;
        double[] scores = new double[2];
        double ls = 0.0;
        boolean[] read = new boolean[l];
        Arrays.fill(read, true);
        BufferedWriter[] w = cl == null ? new BufferedWriter[]{new BufferedWriter(new FileWriter(String.valueOf(outpath) + "_positives.txt")), new BufferedWriter(new FileWriter(String.valueOf(outpath) + "_negatives.txt")), new BufferedWriter(new FileWriter(String.valueOf(outpath) + "_positives.txt.weights")), new BufferedWriter(new FileWriter(String.valueOf(outpath) + "_negatives.txt.weights"))} : null;
        BufferedWriter gws = cl != null ? new BufferedWriter(new FileWriter(String.valueOf(outpath) + "_positives.txt.gws")) : null;
        ArrayList<Region> reg = null;
        int r = 0;
        boolean use = false;
        block1: while (true) {
            int f3;
            boolean any = false;
            f = 0;
            while (f < l) {
                if (read[f]) {
                    line[f] = reader[f].readLine();
                    boolean bl = read[f] = line[f] != null && ((String)line[f]).charAt(0) != '[';
                }
                if (!read[f]) {
                    if (args[f + offset].contains("DNASE") || args[f + offset].contains("tracts")) {
                        li[f] = "0\t0\t0\t0";
                    } else if (args[f + offset].contains("winscores2")) {
                        li[f] = "Inf\tInf\t0\t0\t..0";
                    } else if (args[f + offset].contains("types")) {
                        li[f] = "chr\tpos\t----------";
                    } else if (args[f + offset].contains("_other")) {
                        li[f] = "0.5\t0.125\t0.0\t0.0\t0.0";
                    } else if (args[f + offset].endsWith("expression-interval.txt.gz")) {
                        li[f] = "0.0";
                    } else if (args[f + offset].endsWith("hg19.genome.fa.seqs.gz")) {
                        li[f] = "CTTAGCGGAAATAGGAGAAACTGTACTAGACGTCCTTGATCGTTATTCGG";
                    } else {
                        System.out.println("WARNING: no default values for " + args[f + offset]);
                    }
                } else {
                    li[f] = line[f];
                }
                any |= read[f];
                ++f;
            }
            if (!any) {
                if (use) {
                    System.out.println("last window\t" + chr + "\t" + p + "\t" + (p + 50) + "\t" + time.getElapsedTime());
                }
                System.out.println(Arrays.toString(line));
                boolean changed = false;
                int f4 = 0;
                while (f4 < l) {
                    int n = 0;
                    while (line[f4] != null && (((String)line[f4]).charAt(0) != '[' || ((String)line[f4]).indexOf(95) >= 0)) {
                        line[f4] = reader[f4].readLine();
                        ++n;
                    }
                    changed |= n > 0;
                    if (line[f4] == null) break block1;
                    if (!((String)line[f4]).equals(line[0])) {
                        throw new Exception("Mismatch:\nFile 0: " + args[offset] + "\nLine: " + (String)line[0] + "\nFile " + f4 + ": " + args[f4 + offset] + "\nLine: " + (String)line[f4]);
                    }
                    read[f4] = true;
                    ++f4;
                }
                if (changed) {
                    System.out.println("=> " + Arrays.toString(line));
                }
                chr = ((String)line[0]).substring(1, ((String)line[0]).length() - 1);
                reg = region == null ? null : region.get(chr);
                r = 0;
                boolean bl = use = chrs == null || chrs.contains(chr);
                if (!use) {
                    System.out.println("skip " + chr);
                }
                c = 0;
                start = 0;
                p = -100000;
                continue;
            }
            if (!use) continue;
            if (remove == null) {
                remove = new int[l];
                f = 0;
                while (f < l) {
                    remove[f] = 0;
                    if (li[f].startsWith(String.valueOf(chr) + "\t")) {
                        int pos = li[f].indexOf(9) + 1;
                        li[f] = li[f].substring(pos);
                        int n = f;
                        remove[n] = remove[n] + 1;
                        if (li[f].startsWith(String.valueOf(start) + "\t")) {
                            pos = li[f].indexOf(9) + 1;
                            li[f] = li[f].substring(pos);
                            int n2 = f;
                            remove[n2] = remove[n2] + 1;
                        }
                    }
                    ++f;
                }
            } else {
                f = 0;
                while (f < l) {
                    int re = 0;
                    while (re < remove[f]) {
                        int pos = li[f].indexOf(9) + 1;
                        li[f] = li[f].substring(pos);
                        ++re;
                    }
                    ++f;
                }
            }
            ++c;
            p = (start += 50) - (windows + 1) / 2 * 50;
            boolean out = false;
            int index = -1;
            Region current = null;
            if (reg != null && r < reg.size() && start == reg.get((int)r).pos) {
                current = reg.get(r);
                index = current.fg ? 0 : 1;
                out = true;
            }
            if (split == null) {
                a = 0;
                f3 = 0;
                while (f3 < li.length) {
                    a += li[f3].split("\t").length;
                    ++f3;
                }
                split = new String[5 + windows * a];
                last = 5 + a * (windows - 1);
            } else {
                System.arraycopy(split, 5 + a, split, 5, last - 5);
            }
            f3 = 0;
            int b = last;
            while (f3 < l) {
                String[] help = li[f3].split("\t");
                System.arraycopy(help, 0, split, b, help.length);
                b += help.length;
                ++f3;
            }
            if (!out && cl == null || c < windows) continue;
            if (current != null) {
                System.arraycopy(current.split, 0, split, 0, 5);
            }
            ArbitrarySequence s = pars.parse(con, (String[])split);
            if (con == null) {
                con = s.getAlphabetContainer();
                FileManager.writeFile(String.valueOf(outpath) + "_positives.txt.alpha", (CharSequence)con.toXML());
            }
            Arrays.fill(split, 0, 5, null);
            if (cl == null) {
                do {
                    current = reg.get(r);
                    w[index].append(s.toString("\t", 0, s.getLength()));
                    w[index].newLine();
                    w[index + 2].append(current.split[4]);
                    w[index + 2].newLine();
                } while (++r < reg.size() && reg.get((int)r).fg == current.fg && reg.get((int)r).pos == current.pos);
                continue;
            }
            gws.append(String.valueOf(chr) + "\t" + p);
            scores[0] = cl.getScore(s, 0);
            scores[1] = cl.getScore(s, 1);
            ls = Normalisation.getLogSum(scores);
            gws.append("\t" + Math.exp(scores[0] - ls));
            gws.newLine();
        }
        System.out.println(String.valueOf(cl == null ? "extract" : "gws") + " - elapsed time: " + time.getElapsedTime());
        if (w != null) {
            int i = 0;
            while (i < w.length) {
                w[i].close();
                ++i;
            }
        }
        if (gws != null) {
            gws.close();
        }
        f = 0;
        while (f < l) {
            reader[f].close();
            ++f;
        }
    }

    private static AbstractScoreBasedClassifier create(int threads) throws Exception {
        int l = con.getPossibleLength();
        GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, l, 10, 1.0E-6, 1.0E-7, 1.0, false, OptimizableFunction.KindOfParameter.ZEROS, true, threads);
        DoesNothingLogPrior prior = DoesNothingLogPrior.defaultInstance;
        int[][] struc = new int[1][0];
        ArrayList<GaussianNetwork> funs = new ArrayList<GaussianNetwork>();
        int i = 0;
        while (i < con.getPossibleLength()) {
            funs.add(new GaussianNetwork(con.getSubContainer(i, 1), struc));
            ++i;
        }
        IndependentProductDiffSM model = new IndependentProductDiffSM(1.0, true, funs.toArray(new DifferentiableStatisticalModel[0]));
        GenDisMixClassifier cl = new GenDisMixClassifier(ps, (LogPrior)prior, LearningPrinciple.MCL, model, model);
        cl.setOutputStream(null);
        return cl;
    }

    private static void train(String pFile, String nFile, int idx, AbstractScoreBasedClassifier cl) throws Exception {
        time.reset();
        DataSet[] data = new DataSet[]{new DataSet(con, (AbstractStringExtractor)new StringExtractor(new File(pFile), 1000, '#'), " "), new DataSet(con, (AbstractStringExtractor)new StringExtractor(new File(nFile), 1000, '#'), " ")};
        double[][] weights = new double[][]{DataParser.getWeights(String.valueOf(pFile) + ".weights", DataParser.Weighting.ONE), DataParser.getWeights(String.valueOf(nFile) + ".weights", DataParser.Weighting.DIRECT)};
        int i = 0;
        while (i < data.length) {
            System.out.println(String.valueOf(i) + ": #=" + data[i].getNumberOfElements() + ", length=" + data[i].getElementLength() + ", " + data[i].getAnnotation());
            ++i;
        }
        cl.train(data, weights);
        System.out.println("train - elapsed time: " + time.getElapsedTime());
        if (idx >= 0) {
            StringBuffer xml = new StringBuffer();
            XMLParser.appendObjectWithTags(xml, cl, "classifier");
            FileManager.writeFile(String.valueOf(pFile) + "-classifier-" + idx + ".xml", (CharSequence)xml);
        }
    }

    private static void evaluate(HashMap<String, ArrayList<int[]>> hash, ArrayList<String> chrs, String gwsFileName, String truth, int col, int wi, int anz, HashMap<String, ArrayList<Region>> region, int windows) throws IOException {
        double th = ExtractGenomeWideScan.agg(hash, chrs, gwsFileName, truth, col, wi, Double.NaN, anz, null, 0);
        System.out.println("threshold: " + th);
        ExtractGenomeWideScan.agg(hash, chrs, gwsFileName, truth, col, wi, th, 0, region, windows);
    }

    private static double agg(HashMap<String, ArrayList<int[]>> hash, ArrayList<String> chrs, String gwsFileName, String truth, int col, int wi, double th, int anz, HashMap<String, ArrayList<Region>> region, int bins) throws NumberFormatException, IOException {
        String line;
        BufferedReader reader = new BufferedReader(new FileReader(gwsFileName));
        int windows = 2 * wi;
        int offset = (wi - 2) * 50;
        String chr = null;
        ArrayList<int[]> intervals = null;
        int[] inter = null;
        HashMap<String, ArrayList> predictions = new HashMap<String, ArrayList>();
        ArrayList currentPred = null;
        double[] p = null;
        double[] values = null;
        int idx = -1;
        int h = -1;
        while ((line = reader.readLine()) != null) {
            int start;
            String[] split = line.split("\t");
            if (chr == null || !chr.equals(split[0])) {
                chr = split[0];
                intervals = hash.get(chr);
                if (intervals != null) {
                    currentPred = new ArrayList();
                    predictions.put(chr, currentPred);
                    idx = 0;
                    inter = intervals.get(idx);
                    p = new double[(inter[1] - inter[0]) / 50 + 1 + 2 * (wi - 2)];
                    currentPred.add(p);
                    h = 0;
                } else {
                    inter = null;
                    idx = -1;
                    p = null;
                }
            }
            if (intervals == null || (start = Integer.parseInt(split[1])) < inter[0] - offset) continue;
            if (start <= inter[1] + offset) {
                if (values == null) {
                    values = new double[split.length - 2];
                }
                int i = 2;
                while (i < split.length) {
                    values[i - 2] = Double.parseDouble(split[i]);
                    ++i;
                }
                p[h] = values[0];
                ++h;
                continue;
            }
            if (idx + 1 >= intervals.size()) continue;
            if (++idx < intervals.size()) {
                inter = intervals.get(idx);
                p = new double[(inter[1] - inter[0]) / 50 + 1 + 2 * (wi - 2)];
                currentPred.add(p);
                h = 0;
                continue;
            }
            h = -1;
        }
        reader.close();
        Aggregation_multi2.Aggregate a = Aggregation_multi2.Aggregate.Prod;
        DoubleList pos = new DoubleList();
        DoubleList neg = new DoubleList();
        GZIPInputStream stream = new GZIPInputStream(new FileInputStream(truth));
        BufferedReader r = new BufferedReader(new InputStreamReader(stream));
        String first = r.readLine();
        if (Double.isNaN(th)) {
            System.out.println("Compute performance for cell type: " + first.split("\t")[col]);
        }
        int x = 50 * (bins - 1) / 2;
        int num = 0;
        int c = 0;
        while (c < chrs.size()) {
            chr = chrs.get(c);
            currentPred = (ArrayList)predictions.get(chr);
            if (currentPred != null) {
                intervals = hash.get(chr);
                int i = 0;
                while (i < currentPred.size()) {
                    inter = intervals.get(i);
                    p = (double[])currentPred.get(i);
                    int start = inter[0];
                    int j = 0;
                    while (j < p.length - windows) {
                        double ag = -1.0;
                        switch (a) {
                            case Max: {
                                ag = ToolBox.max(j, j + windows, p);
                                break;
                            }
                            case Mean: {
                                ag = ToolBox.mean(j, j + windows, p);
                                break;
                            }
                            case Prod: {
                                ag = 1.0;
                                int k = j;
                                while (k < j + windows) {
                                    ag *= 1.0 - p[k];
                                    ++k;
                                }
                                ag = 1.0 - ag;
                                break;
                            }
                            case Median: {
                                ag = ToolBox.median(j, j + windows, p);
                                break;
                            }
                            default: {
                                throw new IllegalArgumentException("not implemented: " + (Object)((Object)a));
                            }
                        }
                        if (r != null) {
                            String s = r.readLine();
                            String[] split = s.split("\t");
                            if (!(j != 0 || chr.equals(split[0]) && split[1].equals("" + start))) {
                                System.out.println();
                                System.out.println(s);
                                System.out.println(Arrays.toString(inter));
                                System.out.println(String.valueOf(chr) + "\t" + start + "\t" + (start + 200));
                                System.exit(1);
                            }
                            if (region == null) {
                                switch (split[col].charAt(0)) {
                                    case 'B': 
                                    case 'b': {
                                        pos.add(ag);
                                        break;
                                    }
                                    case 'U': 
                                    case 'u': {
                                        neg.add(ag);
                                    }
                                }
                            } else if (ag >= th && (split[col].charAt(0) == 'U' || split[col].charAt(0) == 'u')) {
                                ArrayList<Region> list = region.get(chr);
                                list.add(new Region(String.valueOf(chr) + "\t" + (start - x) + "\t" + (start + x + 50) + "\tn\t1", false));
                                ++num;
                            }
                        }
                        start += 50;
                        ++j;
                    }
                    ++i;
                }
            } else {
                System.out.println("WARNING: Did not find predictions for " + chr);
            }
            ++c;
        }
        r.close();
        if (Double.isNaN(th)) {
            double[] po = pos.toArray();
            double[] ne = neg.toArray();
            Arrays.sort(po);
            Arrays.sort(ne);
            System.out.println("#positives: " + po.length + "\t" + po[0] + " .. " + po[po.length - 1]);
            System.out.println("#negatives: " + ne.length + "\t" + ne[0] + " .. " + ne[ne.length - 1]);
            System.out.println("random: " + (double)po.length / (double)(po.length + ne.length));
            System.out.println();
            System.out.println(String.valueOf(windows) + "\t" + (Object)((Object)a));
            System.out.println(new AucROC().compute(po, ne));
            System.out.println(new AucPR().compute(po, ne));
            System.out.println(String.valueOf(ne.length) + " vs " + anz);
            return Math.max(ne[Math.max(ne.length - anz, 0)], po[0]);
        }
        System.out.println("add " + num + " negatives");
        return Double.NaN;
    }

    static class Region
    implements Comparable<Region> {
        int pos;
        String chr;
        String line;
        boolean fg;
        String[] split;

        Region(String line, boolean fg) {
            this.line = line;
            this.split = line.split("\t");
            this.chr = this.split[0];
            this.pos = (int)Math.round(Double.parseDouble(this.split[2]));
            this.fg = fg;
        }

        @Override
        public int compareTo(Region o) {
            return Integer.compare(this.pos, o.pos);
        }
    }
}

