/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.StringExtractor;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.DifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SilentEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous.GaussianEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.io.File;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Command-line experiment that reads continuous-valued sequences, median-smooths them, trains a
 * nine-state {@link DifferentiableHigherOrderHMM} with Gaussian emissions on them, and prints
 * diagnostics: the position-wise mean, the model before and after training, Viterbi paths of the
 * first sequences, and the mean observed value per decoded state.
 */
public class DNaseHMMFewStates {

    /** Number of consecutive input values merged into one smoothed value. */
    private static final int SMOOTHING_WINDOW = 5;

    /** Maximum number of Viterbi paths printed to standard output. */
    private static final int NUM_PATHS_TO_PRINT = 100;

    /**
     * Downsamples every sequence by replacing each complete, non-overlapping window of
     * {@code window} values with its median. It then divides the downsampled values by the
     * median of the whole downsampled sequence. If that median is 0, the divisor falls back to 1.
     *
     * <p>Trailing values that do not fill a complete window are dropped.
     *
     * @param data   data set of continuous-valued sequences
     * @param window number of consecutive values combined into one output value; must be positive
     * @return a new data set with one smoothed, normalized sequence per input sequence
     * @throws IllegalArgumentException if {@code window} is not positive
     * @throws Exception if a smoothed sequence or the resulting data set cannot be created
     */
    private static DataSet medianSmoothing(DataSet data, int window) throws Exception {
        if (window <= 0) {
            throw new IllegalArgumentException("window must be positive: " + window);
        }
        LinkedList<Sequence> seqs = new LinkedList<Sequence>();
        double[] windowValues = new double[window];
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            double[] res = new double[seq.getLength() / window];
            for (int j = 0; j < res.length; j++) {
                for (int k = 0; k < window; k++) {
                    windowValues[k] = seq.continuousVal(j * window + k);
                }
                res[j] = ToolBox.median(windowValues);
            }
            // Scale by the median of the whole smoothed sequence, not only of its last window.
            // A copy is passed because it is not visible here whether median() reorders its input.
            double med = res.length == 0 ? 0.0 : ToolBox.median(res.clone());
            if (med == 0.0) {
                // Avoid division by zero, e.g. for sequences that are mostly zero.
                med = 1.0;
            }
            for (int j = 0; j < res.length; j++) {
                res[j] /= med;
            }
            seqs.add(new ArbitrarySequence(seq.getAlphabetContainer(), res));
        }
        return new DataSet("", seqs);
    }

    /**
     * Computes the position-wise mean over all sequences of {@code data}.
     *
     * <p>NOTE(review): assumes every sequence has length {@code data.getElementLength()}; longer
     * sequences would index past the result array. Confirm the behavior for variable-length data.
     *
     * @param data data set of continuous-valued sequences
     * @return array of length {@code data.getElementLength()} whose entry {@code j} is the mean of
     *         the {@code j}-th values (NaN for every entry if {@code data} is empty)
     */
    private static double[] getMean(DataSet data) {
        double[] mean = new double[data.getElementLength()];
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            for (int j = 0; j < seq.getLength(); j++) {
                mean[j] += seq.continuousVal(j);
            }
        }
        double n = data.getNumberOfElements();
        for (int j = 0; j < mean.length; j++) {
            mean[j] /= n;
        }
        return mean;
    }

    /**
     * Runs the experiment on the file given as first argument.
     *
     * <p>The file is read with a {@link StringExtractor} (arguments {@code 1000} and {@code '#'},
     * presumably initial capacity and comment character — confirm against the jstacs docs). The
     * resulting {@link DataSet} uses {@code "\t"} as value separator.
     *
     * @param args {@code args[0]}: path of the input file
     * @throws IllegalArgumentException if no input file is given
     * @throws Exception if reading, smoothing, training or decoding fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("usage: DNaseHMMFewStates <data file>");
        }
        AlphabetContainer ca = new AlphabetContainer((Alphabet) new ContinuousAlphabet());
        DataSet data = new DataSet(ca,
                (AbstractStringExtractor) new StringExtractor(new File(args[0]), 1000, '#'), "\t");
        data = medianSmoothing(data, SMOOTHING_WINDOW);

        System.out.println(Arrays.toString(getMean(data)));
        System.out.println("num: " + data.getNumberOfElements());
        System.out.println("len: " + data.getElementLength());

        String[] stateNames = new String[]{"S", "I11", "C1", "I12", "I21", "C2", "I22", "E", "F"};
        DifferentiableHigherOrderHMM hmm = createHmm(ca, stateNames, data.getElementLength());
        System.out.println(hmm.getGraphvizRepresentation(null));
        hmm.train(data);
        System.out.println(hmm);
        System.out.println(hmm.getGraphvizRepresentation(new DecimalFormat("0.000")));

        double[] stats = meanValuePerViterbiState(hmm, data, stateNames.length, NUM_PATHS_TO_PRINT);
        System.out.println(Arrays.toString(stats));
    }

    /**
     * Builds the untrained HMM.
     *
     * <p>Topology, as encoded by the transition elements below: the initial state is S. S may loop
     * or enter one of two chains, I11 -> C1 -> I12 -> E or I21 -> C2 -> I22 -> E. E leads to the
     * silent final state F. Every non-final state has a self-transition weighted with
     * {@code locEss}, and the forward transitions are weighted with {@code ess}.
     *
     * @param ca         alphabet of the observations
     * @param stateNames names of the nine states, in the index order used by the transitions
     * @param seqLen     sequence length; scales the self-transition / emission hyper-parameters
     * @return the configured, untrained model
     * @throws Exception if the jstacs model cannot be constructed
     */
    private static DifferentiableHigherOrderHMM createHmm(AlphabetContainer ca, String[] stateNames,
            int seqLen) throws Exception {
        // NOTE(review): passed to the HMM constructor; presumably selects likelihood vs. posterior
        // training — confirm against the DifferentiableHigherOrderHMM docs.
        boolean useLikelihood = false;
        double ess = 4.0;
        double locEss = ess / 5.0 * seqLen;

        // The third GaussianEmission argument differs between emissions (0, 3, 8, 1, 2) and is
        // presumably the prior mean on the smoothed, median-normalized scale — confirm.
        DifferentiableEmission[] emissions = new DifferentiableEmission[]{
            new GaussianEmission(locEss, ca, 0.0, 1.0, 1.0, true),
            new GaussianEmission(locEss, ca, 3.0, 1.0, 1.0, true),
            new GaussianEmission(locEss, ca, 8.0, 1.0, 1.0, true),
            new GaussianEmission(locEss, ca, 1.0, 1.0, 1.0, true),
            new GaussianEmission(locEss, ca, 2.0, 1.0, 1.0, true),
            new SilentEmission()
        };
        // State -> emission index. S/E share emission 0, I11/I12 share 1, I21/I22 share 3;
        // F (index 8) uses the silent emission.
        int[] emissionIdx = new int[]{0, 1, 2, 1, 3, 4, 3, 0, 5};

        LinkedList<TransitionElement> el = new LinkedList<TransitionElement>();
        el.add(new TransitionElement(new int[0], new int[]{0}, new double[]{ess}));
        el.add(new TransitionElement(new int[]{0}, new int[]{0, 1, 4}, new double[]{locEss, ess, ess}));
        el.add(new TransitionElement(new int[]{1}, new int[]{1, 2}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{2}, new int[]{2, 3}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{3}, new int[]{3, 7}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{4}, new int[]{4, 5}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{5}, new int[]{5, 6}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{6}, new int[]{6, 7}, new double[]{locEss, ess}));
        el.add(new TransitionElement(new int[]{7}, new int[]{7, 8}, new double[]{locEss, ess}));

        boolean[] forward = new boolean[stateNames.length];
        Arrays.fill(forward, true);
        ViterbiParameterSet trainingParameterSet = new ViterbiParameterSet(10,
                new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6), 4);
        return new DifferentiableHigherOrderHMM(trainingParameterSet, stateNames, emissionIdx,
                forward, emissions, useLikelihood, ess, el.toArray(new TransitionElement[0]));
    }

    /**
     * Decodes every sequence once with the Viterbi algorithm, prints the paths of the first
     * {@code numPathsToPrint} sequences, and averages the observed values per decoded state.
     *
     * @param hmm             trained model
     * @param data            data set to decode
     * @param numStates       number of states of {@code hmm}
     * @param numPathsToPrint how many leading paths to print (bounded by the data-set size)
     * @return for each state, the mean value of the positions assigned to it; NaN for states
     *         that never occur in any path
     * @throws Exception if decoding fails
     */
    private static double[] meanValuePerViterbiState(DifferentiableHigherOrderHMM hmm, DataSet data,
            int numStates, int numPathsToPrint) throws Exception {
        double[] sums = new double[numStates];
        double[] counts = new double[numStates];
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            Pair<IntList, Double> best = hmm.getViterbiPathFor(seq);
            IntList path = best.getFirstElement();
            if (i < numPathsToPrint) {
                System.out.println(path);
            }
            for (int j = 0; j < seq.getLength(); j++) {
                int state = path.get(j);
                sums[state] += seq.continuousVal(j);
                counts[state] += 1.0;
            }
        }
        for (int s = 0; s < numStates; s++) {
            // 0/0 yields NaN for unvisited states, marking "no data" in the printed summary.
            sums[s] /= counts[s];
        }
        return sums;
    }
}

