/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.classifiers.AbstractClassifier;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.assessment.KFoldCrossValidation;
import de.jstacs.classifiers.assessment.KFoldCrossValidationAssessParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AbstractPerformanceMeasureParameterSet;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.StringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.results.ListResult;
import de.jstacs.results.ResultSet;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ConstantDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.FixedStructure;
import de.jstacs.utils.DefaultProgressUpdater;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ProgressUpdater;
import de.jstacs.utils.Time;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import projects.dream2016.DataParser;
import projects.dream2016.mix.NewMixtureClassifier;
import projects.dream2016.mix.SubDiffSM;

/**
 * Command-line driver that assembles a product of sub-models over the feature columns described by an
 * alphabet file, wraps them in one or more candidate classifiers, optionally picks the best candidate by
 * k-fold cross validation (AUC-PR) and optionally trains it on all data.
 */
public class NewDreamTrainingW {
    /** Length of the discrete block that follows the leading continuous features. */
    private static final int DISCRETE_BLOCK_LENGTH = 10;
    /** Length of the GaussianNetwork block that follows the discrete block. */
    private static final int SECOND_CONTINUOUS_BLOCK_LENGTH = 7;

    /** Performance measures from {@code createFilledParameters()}, used for cross-validation assessment. */
    static NumericalPerformanceMeasureParameterSet p;
    /** Equivalent sample size passed to every {@link IndependentProductDiffSM} built here. */
    static double ess;

    static {
        try {
            p = AbstractPerformanceMeasureParameterSet.createFilledParameters();
        } catch (Exception e) {
            // No assessment is possible without performance measures, so abort immediately.
            e.printStackTrace();
            System.exit(1);
        }
        ess = 1.0;
    }

    /**
     * Entry point.
     *
     * <p>Arguments: {@code <posFile> <negFile> <threads> <epigram> <type> <useDeps>}. Runs {@link #test} with
     * 5-fold cross validation, the MCL learning principle, {@code eps = 1e-6} and final training enabled, then
     * prints the selected classifier and writes it as XML to
     * {@code <posFile>-classifier-<type>_<useDeps>.xml}.
     *
     * @param args the six command-line arguments listed above
     * @throws IllegalArgumentException if fewer than six arguments are given
     * @throws Exception if reading data, training or writing the result fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 6) {
            throw new IllegalArgumentException(
                    "usage: NewDreamTrainingW <posFile> <negFile> <threads> <epigram> <type> <useDeps>");
        }
        AbstractScoreBasedClassifier cl = test(args[0], args[1], 5, LearningPrinciple.getBeta(LearningPrinciple.MCL),
                0, 1.0E-6, true, Integer.parseInt(args[2]), Integer.parseInt(args[3]), args[4],
                Boolean.parseBoolean(args[5]));
        if (cl != null) {
            System.out.println(cl);
            StringBuffer xml = new StringBuffer();
            XMLParser.appendObjectWithTags(xml, cl, "classifier");
            // The cast keeps the CharSequence overload that the original code selected.
            FileManager.writeFile(args[0] + "-classifier-" + args[4] + "_" + args[5] + ".xml", (CharSequence) xml);
        }
    }

    /**
     * Builds candidate classifiers, optionally selects the best one by cross validation and optionally trains it.
     *
     * <p>The feature layout is derived from the alphabet in {@code pFile + ".alpha"}:
     * <ol>
     * <li>all positions before the first discrete one, modelled by a GaussianNetwork;</li>
     * <li>{@value #DISCRETE_BLOCK_LENGTH} positions modelled by a BayesianNetworkDiffSM;</li>
     * <li>{@value #SECOND_CONTINUOUS_BLOCK_LENGTH} positions modelled by a GaussianNetwork;</li>
     * <li>{@code epigram} positions modelled by one GaussianNetwork (only if {@code epigram > 0});</li>
     * <li>every remaining position modelled by its own one-dimensional GaussianNetwork.</li>
     * </ol>
     *
     * @param pFile      positive data file; its alphabet is read from {@code pFile + ".alpha"} and its weights from
     *                   {@code pFile + ".weights"}
     * @param nFile      negative data file; its weights are read from {@code nFile + ".weights"}
     * @param k          number of cross-validation folds; if {@code k <= 0} no cross validation is run and the first
     *                   candidate is selected
     * @param beta       learning-principle weights handed to the GenDisMix-based classifiers
     * @param order      unused; kept for interface compatibility
     * @param eps        epsilon handed to {@link GenDisMixClassifierParameterSet}
     * @param finalTrain whether to train the selected classifier on all data and return it
     * @param threads    number of threads for the classifiers
     * @param epigram    number of additional continuous features modelled jointly (0 for none)
     * @param type       candidate kind: {@code single}, {@code ipsf}, {@code mix_each}, {@code mix_dnase_single} or
     *                   {@code mix_dnase_block} (case-insensitive)
     * @param useDeps    whether to use the hard-coded dependency structures instead of independent features
     * @return the trained selected classifier if {@code finalTrain} is set, otherwise {@code null}
     * @throws IllegalArgumentException if {@code type} is unknown, or if {@code useDeps} is set and fewer than four
     *                                  leading continuous features exist
     * @throws IllegalStateException    if {@code finalTrain} is set but cross validation produced no usable AUC-PR
     * @throws Exception                if reading data, assessment or training fails
     */
    static AbstractScoreBasedClassifier test(String pFile, String nFile, int k, double[] beta, int order, double eps,
            boolean finalTrain, int threads, int epigram, String type, boolean useDeps) throws Exception {
        Time t = Time.getTimeInstance(null);
        AlphabetContainer con = new AlphabetContainer(FileManager.readFile(pFile + ".alpha"));
        System.out.println(con);

        DataSet[] data = {
            new DataSet(con, (AbstractStringExtractor) new StringExtractor(new File(pFile), 1000, '#'), "\t"),
            new DataSet(con, (AbstractStringExtractor) new StringExtractor(new File(nFile), 1000, '#'), "\t")
        };
        double[][] weights = {
            DataParser.getWeights(pFile + ".weights", DataParser.Weighting.SIGNAL),
            DataParser.getWeights(nFile + ".weights", DataParser.Weighting.DIRECT)
        };
        for (int i = 0; i < data.length; i++) {
            System.out.println(i + ": #=" + data[i].getNumberOfElements() + ", length="
                    + data[i].getElementLength() + ", " + data[i].getAnnotation());
        }

        List<AbstractDifferentiableStatisticalModel> fun = new ArrayList<>();
        IntList lens = new IntList();

        // d = index of the first discrete position = number of leading continuous features.
        int d = 0;
        while (!con.isDiscreteAt(d)) {
            ++d;
        }
        fun.add(new GaussianNetwork(con.getSubContainer(0, d), leadingContinuousStructure(d, useDeps)));
        lens.add(d);

        fun.add(new BayesianNetworkDiffSM(con.getSubContainer(d, DISCRETE_BLOCK_LENGTH), DISCRETE_BLOCK_LENGTH,
                4.0, true, new FixedStructure(discreteBlockStructure(useDeps))));
        lens.add(DISCRETE_BLOCK_LENGTH);

        fun.add(new GaussianNetwork(
                con.getSubContainer(d + DISCRETE_BLOCK_LENGTH, SECOND_CONTINUOUS_BLOCK_LENGTH),
                secondContinuousStructure(useDeps)));
        lens.add(SECOND_CONTINUOUS_BLOCK_LENGTH);

        if (epigram > 0) {
            fun.add(new GaussianNetwork(new int[epigram][0]));
            lens.add(epigram);
        }

        int le = 0;
        for (int i = 0; i < lens.length(); i++) {
            le += lens.get(i);
        }
        // Every position not covered by the blocks above gets its own independent one-dimensional model.
        int l = con.getPossibleLength();
        while (le < l) {
            fun.add(new GaussianNetwork(con.getSubContainer(le, 1), new int[1][0]));
            lens.add(1);
            ++le;
        }
        System.out.println(le + "\t" + l);

        GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, l, 10, eps, 1.0E-6, 1.0, false,
                OptimizableFunction.KindOfParameter.ZEROS, true, threads);
        // Locale.ROOT: the default locale could otherwise alter letters (e.g. Turkish dotless i).
        List<AbstractScoreBasedClassifier> cl =
                createClassifiers(type.toLowerCase(Locale.ROOT), con, l, fun, lens, ps, beta, threads);

        int idx = -1;
        double auc = -1.0;
        if (k > 0) {
            KFoldCrossValidationAssessParameterSet params = new KFoldCrossValidationAssessParameterSet(
                    DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, l, true, k);
            KFoldCrossValidation cross =
                    new KFoldCrossValidation((AbstractClassifier[]) cl.toArray(new AbstractScoreBasedClassifier[0]));
            ListResult lr = cross.assess(p, params, (ProgressUpdater) new DefaultProgressUpdater(), data, weights);
            System.out.println(lr);
            ResultSet[] r = lr.getValue();
            for (int i = 0; i < r.length; i++) {
                double a = (Double) r[i].getResultForName("AUC-PR (Integral)").getValue();
                // NaN never compares greater, so idx stays -1 if no candidate produced a number.
                if (a > auc) {
                    auc = a;
                    idx = i;
                }
            }
        } else {
            idx = 0;
        }
        System.out.println("time: " + t.getElapsedTime());
        System.out.println("classifier index: " + idx);
        System.out.println("AUC-PR: " + auc);
        if (finalTrain) {
            if (idx < 0) {
                throw new IllegalStateException(
                        "cross validation yielded no valid AUC-PR; cannot select a classifier for final training");
            }
            AbstractScoreBasedClassifier best = cl.get(idx);
            best.train(data, weights);
            return best;
        }
        return null;
    }

    /**
     * Creates the candidate classifiers for the requested {@code type}.
     *
     * @param type already lower-cased candidate kind
     * @param con  alphabet of the full feature vector
     * @param l    length of the full feature vector
     * @param fun  component models in feature order
     * @param lens number of features covered by each entry of {@code fun}
     * @param ps   parameter set shared by all GenDisMix-based classifiers
     * @param beta learning-principle weights
     * @param threads number of threads for mixture classifiers
     * @return the list of candidates (never empty)
     * @throws IllegalArgumentException if {@code type} is unknown
     * @throws Exception if constructing or cloning a model fails
     */
    private static List<AbstractScoreBasedClassifier> createClassifiers(String type, AlphabetContainer con, int l,
            List<AbstractDifferentiableStatisticalModel> fun, IntList lens, GenDisMixClassifierParameterSet ps,
            double[] beta, int threads) throws Exception {
        List<AbstractScoreBasedClassifier> cl = new ArrayList<>();
        DoesNothingLogPrior prior = DoesNothingLogPrior.defaultInstance;
        switch (type) {
            case "single": {
                // One classifier per component, each looking only at its own slice of the feature vector.
                int start = 0;
                for (int f = 0; f < fun.size(); f++) {
                    SubDiffSM model = new SubDiffSM(con, l, (DifferentiableStatisticalModel) fun.get(f), start);
                    GenDisMixClassifier gdsm = new GenDisMixClassifier(ps, (LogPrior) prior, beta,
                            new DifferentiableStatisticalModel[]{model, model});
                    gdsm.setOutputStream(null);
                    cl.add(gdsm);
                    start += lens.get(f);
                }
                break;
            }
            case "ipsf": {
                IndependentProductDiffSM ipsf =
                        new IndependentProductDiffSM(ess, true, fun.toArray(new DifferentiableStatisticalModel[0]));
                cl.add(new GenDisMixClassifier(ps, (LogPrior) prior, beta, ipsf, ipsf));
                break;
            }
            case "mix_each": {
                AbstractScoreBasedClassifier[] comp = new NewMixtureClassifier.OptimizableMSPClassifier[fun.size()];
                DifferentiableStatisticalModel[] mix = new DifferentiableStatisticalModel[comp.length];
                int start = 0;
                for (int i = 0; i < comp.length; i++) {
                    SubDiffSM s = new SubDiffSM(con, l, (DifferentiableStatisticalModel) fun.get(i), start);
                    NewMixtureClassifier.OptimizableMSPClassifier o =
                            new NewMixtureClassifier.OptimizableMSPClassifier(ps, (LogPrior) prior, s, s);
                    o.setOutputStream(null);
                    comp[i] = o;
                    mix[i] = new ConstantDiffSM(con, l);
                    start += lens.get(i);
                }
                cl.add(new NewMixtureClassifier(threads, NewMixtureClassifier.Training.COMBINED, 3, mix, comp,
                        NewMixtureClassifier.Vote.VOC, prior));
                break;
            }
            case "mix_dnase_single":
            case "mix_dnase_block": {
                // "single": mixture models use only feature 0; "block": they use the whole leading continuous block.
                SubDiffSM mixTemplate = type.endsWith("single")
                        ? new SubDiffSM(con, l, new GaussianNetwork(new int[1][0]), 0)
                        : new SubDiffSM(con, l, (DifferentiableStatisticalModel) fun.get(0), 0);
                DifferentiableStatisticalModel[] mix =
                        (DifferentiableStatisticalModel[]) ArrayHandler.createArrayOf((Cloneable) mixTemplate, 2);
                IndependentProductDiffSM model =
                        new IndependentProductDiffSM(ess, true, fun.toArray(new DifferentiableStatisticalModel[0]));
                NewMixtureClassifier.OptimizableMSPClassifier o = new NewMixtureClassifier.OptimizableMSPClassifier(
                        ps, (LogPrior) prior, new DifferentiableStatisticalModel[]{model, model});
                o.setOutputStream(null);
                AbstractScoreBasedClassifier[] components =
                        (AbstractScoreBasedClassifier[]) ArrayHandler.createArrayOf((Cloneable) o, mix.length);
                cl.add(new NewMixtureClassifier(threads, NewMixtureClassifier.Training.COMBINED, 3, mix, components,
                        NewMixtureClassifier.Vote.VOC, ((LogPrior) prior).getNewInstance()));
                break;
            }
            default:
                throw new IllegalArgumentException("Unknown type: " + type);
        }
        return cl;
    }

    /**
     * Structure for the leading {@code d} continuous features.
     *
     * <p>Entry {@code i} lists the feature indices that feature {@code i} is linked to (presumably its parents in
     * the GaussianNetwork — confirm against jstacs). With {@code useDeps}: 1 has none, 2 links to 1, 0 links to 2,
     * the centre {@code c = 3 + (d - 3) / 2} links to 0, and features on either side of {@code c} (down to 3 and up
     * to {@code d - 1}) form chains linking to their neighbour towards {@code c}. Both chains are filled
     * independently, so every row is set for odd {@code d} as well.
     *
     * @throws IllegalArgumentException if {@code useDeps} is set and {@code d < 4}
     */
    private static int[][] leadingContinuousStructure(int d, boolean useDeps) {
        if (!useDeps) {
            return new int[d][0];
        }
        if (d < 4) {
            throw new IllegalArgumentException(
                    "useDeps requires at least 4 leading continuous features, found " + d);
        }
        int[][] structure = new int[d][];
        structure[1] = new int[0];
        structure[2] = new int[]{1};
        structure[0] = new int[]{2};
        int c = 3 + (d - 3) / 2;
        structure[c] = new int[]{0};
        for (int i = c + 1; i < d; i++) {
            structure[i] = new int[]{i - 1};
        }
        for (int i = c - 1; i >= 3; i--) {
            structure[i] = new int[]{i + 1};
        }
        return structure;
    }

    /** Fixed structure for the {@value #DISCRETE_BLOCK_LENGTH}-position discrete block (independent if !useDeps). */
    private static int[][] discreteBlockStructure(boolean useDeps) {
        if (!useDeps) {
            return new int[DISCRETE_BLOCK_LENGTH][0];
        }
        return new int[][]{{2}, {3}, {3}, {4}, {}, {7}, {8}, {8}, {9}, {4}};
    }

    /** Structure for the {@value #SECOND_CONTINUOUS_BLOCK_LENGTH}-feature GaussianNetwork block. */
    private static int[][] secondContinuousStructure(boolean useDeps) {
        if (!useDeps) {
            return new int[SECOND_CONTINUOUS_BLOCK_LENGTH][0];
        }
        return new int[][]{{}, {0}, {}, {5}, {6}, {0}, {0}};
    }
}

