/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AbstractPerformanceMeasureParameterSet;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.classifiers.performanceMeasures.PerformanceMeasure;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.ListResult;
import de.jstacs.results.MeanResultSet;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.ResultSet;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ConstantDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.DirichletDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ExpGammaDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.PoissonDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.SingleGaussianDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.MarkovModelDiffSM;
import de.jstacs.utils.IntList;
import java.util.ArrayList;
import java.util.Arrays;
import projects.dream2016.DataParser;

public class DreamTraining {
    public static void main(String[] args) throws Exception {
        String[] conf = new String[args.length - 2];
        System.arraycopy(args, 2, conf, 0, args.length - 2);
        DreamTraining.featureSelection(args[0], args[1], 0.1, 1.0E-5, LearningPrinciple.getBeta(LearningPrinciple.MCL), 5, false, 0, conf);
    }

    public static void featureSelection(String pFile, String nFile, double perc, double eps, double[] beta, int r, boolean constant, int order, String[] conf) throws Exception {
        NumericalPerformanceMeasureParameterSet p = AbstractPerformanceMeasureParameterSet.createFilledParameters();
        ResultSet[] resTrain = new MeanResultSet[conf.length];
        ResultSet[] resTest = new MeanResultSet[conf.length];
        int i = 0;
        int max = -1;
        double ess = 1.0;
        String[] stringArray = conf;
        int n = conf.length;
        int n2 = 0;
        while (n2 < n) {
            String c = stringArray[n2];
            System.out.println(c);
            CategoricalResult cr = new CategoricalResult("conf", "", c);
            resTrain[i] = new MeanResultSet(cr);
            resTest[i] = new MeanResultSet(cr);
            DataParser parser = new DataParser(c);
            DataSet pos = parser.parseData(pFile, (double)max);
            DataSet neg = parser.parseData(nFile, (double)max);
            AlphabetContainer con = pos.getAlphabetContainer();
            System.out.println(String.valueOf(pos.getElementLength()) + ", #= " + pos.getNumberOfElements());
            System.out.println(String.valueOf(neg.getElementLength()) + ", #= " + neg.getNumberOfElements());
            int j = 0;
            int l = pos.getElementLength();
            ArrayList<AbstractDifferentiableStatisticalModel> fun = new ArrayList<AbstractDifferentiableStatisticalModel>();
            IntList lens = new IntList();
            while (j < l) {
                if (con.isDiscreteAt(j)) {
                    int start = j;
                    while (j < l && con.isDiscreteAt(j)) {
                        ++j;
                    }
                    fun.add(new MarkovModelDiffSM(con.getSubContainer(start, j - start), j - start, ess, false, order, null));
                    lens.add(j - start);
                    continue;
                }
                if (con.getAlphabetLengthAt(j) == 1.0) {
                    DirichletDiffSM dir = new DirichletDiffSM(con.getSubContainer(j, 1), 1, new double[]{ess / 2.0, ess / 2.0}, ess, 1);
                    dir.initializeFunctionRandomly(false);
                    fun.add(dir);
                } else if (con.getMin(j) == 0.0) {
                    if (con.getAlphabetLengthAt(j) == 2.147483647E9) {
                        fun.add(new PoissonDiffSM(ess));
                    } else {
                        ExpGammaDiffSM gam = new ExpGammaDiffSM(con.getSubContainer(j, 1), 1, ess, new double[]{2.0}, new double[]{2.0}, false);
                        gam.initializeFunctionRandomly(false);
                        fun.add(gam);
                    }
                } else if (con.getMin(j) == -1.7976931348623157E308) {
                    fun.add(new SingleGaussianDiffSM(con.getSubContainer(j, 1), ess, 1.0, 1.0, 1.0, false));
                }
                System.out.println(fun.get(fun.size() - 1));
                lens.add(1);
                ++j;
            }
            IndependentProductDiffSM fg = new IndependentProductDiffSM(1.0, false, fun.toArray(new DifferentiableStatisticalModel[0]), lens.toArray());
            AbstractDifferentiableSequenceScore bg = constant ? new ConstantDiffSM(fg.getLength()) : fg;
            GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, pos.getElementLength(), 10, eps, 1.0E-6, 1.0, false, OptimizableFunction.KindOfParameter.LAST, true, 4);
            GenDisMixClassifier cl = new GenDisMixClassifier(ps, (LogPrior)DoesNothingLogPrior.defaultInstance, beta, new DifferentiableStatisticalModel[]{fg, bg});
            int repeat = 0;
            while (repeat < r) {
                DataSet[] posSplit = pos.partition(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, perc, 1.0 - perc);
                DataSet[] negSplit = neg.partition(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, perc, 1.0 - perc);
                cl.train(posSplit[0], negSplit[0]);
                System.out.println(cl);
                ((MeanResultSet)resTrain[i]).addResults((NumericalResultSet)cl.evaluate((AbstractPerformanceMeasureParameterSet<? extends PerformanceMeasure>)p, true, posSplit[0], negSplit[0]));
                ((MeanResultSet)resTest[i]).addResults((NumericalResultSet)cl.evaluate((AbstractPerformanceMeasureParameterSet<? extends PerformanceMeasure>)p, true, posSplit[1], negSplit[1]));
                ++repeat;
            }
            ++i;
            ++n2;
        }
        System.out.println();
        System.out.println(String.valueOf(Arrays.toString(beta)) + "\t" + constant);
        System.out.println();
        System.out.println("Train " + perc);
        System.out.println(new ListResult("", "", null, resTrain));
        System.out.println();
        System.out.println("Test " + (1.0 - perc));
        System.out.println(new ListResult("", "", null, resTest));
    }
}

