/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.classifiers.AbstractClassifier;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.assessment.KFoldCrossValidation;
import de.jstacs.classifiers.assessment.KFoldCrossValidationAssessParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AbstractPerformanceMeasureParameterSet;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.StringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.results.ListResult;
import de.jstacs.results.ResultSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.SingleGaussianDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.MarkovModelDiffSM;
import de.jstacs.utils.DefaultProgressUpdater;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ProgressUpdater;
import de.jstacs.utils.Time;
import java.io.File;
import java.util.ArrayList;
import projects.dream2016.mix.NewMixtureClassifier;
import projects.dream2016.mix.SubDiffSM;

public class NewDreamTraining {
    static NumericalPerformanceMeasureParameterSet p;
    static double ess;

    static {
        try {
            p = AbstractPerformanceMeasureParameterSet.createFilledParameters();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        ess = 1.0;
    }

    public static void main(String[] args) throws Exception {
        AbstractScoreBasedClassifier cl = NewDreamTraining.test(args[0], args[1], 5, LearningPrinciple.getBeta(LearningPrinciple.MCL), 0, 1.0E-6, false, false, Integer.parseInt(args[2]));
        if (cl != null) {
            System.out.println(cl);
            StringBuffer xml = new StringBuffer();
            XMLParser.appendObjectWithTags(xml, cl, "classifier");
            FileManager.writeFile(String.valueOf(args[0]) + "-classifier.xml", (CharSequence)xml);
        }
    }

    static AbstractScoreBasedClassifier test(String pFile, String nFile, int k, double[] beta, int order, double eps, boolean additional, boolean finalTrain, int threads) throws Exception {
        Time t = Time.getTimeInstance(null);
        AlphabetContainer con = new AlphabetContainer(FileManager.readFile(String.valueOf(pFile) + ".alpha"));
        DataSet[] data = new DataSet[]{new DataSet(con, (AbstractStringExtractor)new StringExtractor(new File(pFile), 1000, '#'), "\t"), new DataSet(con, (AbstractStringExtractor)new StringExtractor(new File(nFile), 1000, '#'), "\t")};
        int i = 0;
        while (i < data.length) {
            System.out.println(String.valueOf(i) + ": #=" + data[i].getNumberOfElements() + ", length=" + data[i].getElementLength() + ", " + data[i].getAnnotation());
            ++i;
        }
        ArrayList<AbstractDifferentiableStatisticalModel> fun = new ArrayList<AbstractDifferentiableStatisticalModel>();
        IntList lens = new IntList();
        int j = 0;
        int l = data[0].getElementLength();
        while (j < l) {
            if (con.isDiscreteAt(j)) {
                int start = j;
                while (j < l && con.isDiscreteAt(j)) {
                    ++j;
                }
                fun.add(new MarkovModelDiffSM(con.getSubContainer(start, j - start), j - start, ess, false, order, null));
                lens.add(j - start);
                continue;
            }
            fun.add(new SingleGaussianDiffSM(con.getSubContainer(j, 1), ess, 1.0, 1.0, 1.0, false));
            lens.add(1);
            ++j;
        }
        GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, l, 10, eps, 1.0E-6, 1.0, false, OptimizableFunction.KindOfParameter.ZEROS, true, threads);
        DoesNothingLogPrior prior = DoesNothingLogPrior.defaultInstance;
        AbstractClassifier[] cl = new AbstractScoreBasedClassifier[(additional ? 2 : 1) * fun.size()];
        int start = 0;
        DifferentiableStatisticalModel[] mix2 = (DifferentiableStatisticalModel[])ArrayHandler.createArrayOf((Cloneable)new SubDiffSM(con, l, (DifferentiableStatisticalModel)fun.get(0), 0), (int)2);
        int f = 0;
        while (f < fun.size()) {
            SubDiffSM model = new SubDiffSM(con, l, (DifferentiableStatisticalModel)fun.get(f), start);
            cl[f] = new NewMixtureClassifier.OptimizableMSPClassifier(ps, (LogPrior)prior, model, model);
            ((NewMixtureClassifier.OptimizableMSPClassifier)cl[f]).setOutputStream(null);
            start += lens.get(f);
            if (additional) {
                AbstractScoreBasedClassifier[] comp = (AbstractScoreBasedClassifier[])ArrayHandler.createArrayOf((Cloneable)cl[f], (int)mix2.length);
                cl[fun.size() + f] = new NewMixtureClassifier(threads, NewMixtureClassifier.Training.COMBINED, 3, mix2, comp, NewMixtureClassifier.Vote.VOC, ((LogPrior)prior).getNewInstance());
            }
            ++f;
        }
        KFoldCrossValidationAssessParameterSet params = new KFoldCrossValidationAssessParameterSet(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, l, true, k);
        KFoldCrossValidation cross = new KFoldCrossValidation(cl);
        ListResult lr = cross.assess(p, params, (ProgressUpdater)new DefaultProgressUpdater(), data);
        System.out.println(lr);
        ResultSet[] r = lr.getValue();
        int idx = -1;
        double auc = -1.0;
        int i2 = 0;
        while (i2 < r.length) {
            double a = (Double)r[i2].getResultForName("AUC-PR (Integral)").getValue();
            if (a > auc) {
                auc = a;
                idx = i2;
            }
            ++i2;
        }
        System.out.println("time: " + t.getElapsedTime());
        System.out.println("classifier index: " + idx);
        System.out.println("AUC-PR: " + auc);
        if (finalTrain) {
            cl[idx].train(data);
            return cl[idx];
        }
        return null;
    }
}

