/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016;

import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.StringExtractor;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.DifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SilentEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous.GaussianEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.io.File;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Command-line experiment that fits a left-to-right {@link DifferentiableHigherOrderHMM} with
 * Gaussian emissions to continuous-valued profiles read from a tab-separated file (presumably
 * DNase-seq coverage profiles, one profile per line -- confirm against the input data).
 *
 * <p>The profiles are median-smoothed and normalized, then the HMM is trained by Viterbi training.
 * The HMM has a symmetric state layout: {@code S, F0..F(n-1), I, B(n-1)..B0, E, End}. In this
 * layout, {@code F_i} and {@code B_i} share an emission, {@code S} and {@code E} share the flank
 * emission, and {@code End} is silent. The trained model, the Viterbi paths of the first
 * sequences and the mean observed value per state are printed to standard output.
 */
public class DNaseHMM {

    /** Width of the non-overlapping median windows applied to the raw profiles. */
    private static final int SMOOTHING_WINDOW = 5;

    /** Maximum number of leading sequences whose Viterbi path is printed. */
    private static final int NUM_PATHS_TO_PRINT = 100;

    /**
     * Down-samples every sequence of {@code data} and normalizes it.
     *
     * <p>Each non-overlapping window of {@code window} consecutive values is replaced by its
     * median. The smoothed sequence is then divided by its own median, or by 1 if that median is
     * 0 or the smoothed sequence is empty. Trailing values that do not fill a complete window are
     * dropped.
     *
     * @param data sequences over a continuous alphabet; every element is cast to
     *        {@link ArbitrarySequence}
     * @param window width of the windows; must be at least 1
     * @return a new data set with one smoothed, normalized sequence per input sequence
     * @throws IllegalArgumentException if {@code window} is smaller than 1
     * @throws Exception if a sequence or the resulting data set cannot be created
     */
    private static DataSet medianSmoothing(DataSet data, int window) throws Exception {
        if (window < 1) {
            throw new IllegalArgumentException("window must be at least 1, but was " + window);
        }
        LinkedList<Sequence> seqs = new LinkedList<Sequence>();
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            ArbitrarySequence seq = (ArbitrarySequence) data.getElementAt(i);
            double[] temp = new double[window];
            double[] res = new double[seq.getLength() / window];
            for (int j = 0; j < res.length; j++) {
                for (int k = 0; k < window; k++) {
                    temp[k] = seq.continuousVal(j * window + k);
                }
                res[j] = ToolBox.median(temp);
            }
            // Normalize by the median of the whole smoothed profile. The former code used
            // median(temp), which only covers the LAST window of the sequence. The array is cloned
            // in case ToolBox.median reorders its argument.
            double med = res.length == 0 ? 1.0 : ToolBox.median(res.clone());
            if (med == 0.0) {
                // sparse profiles may have median 0; avoid dividing by zero
                med = 1.0;
            }
            for (int j = 0; j < res.length; j++) {
                res[j] /= med;
            }
            seqs.add(new ArbitrarySequence(seq.getAlphabetContainer(), res));
        }
        return new DataSet("", seqs);
    }

    /**
     * Computes the position-wise mean over all sequences in {@code data}.
     *
     * <p>Assumes every sequence has length {@code data.getElementLength()}, which is also the
     * length of the result.
     *
     * @param data sequences over a continuous alphabet
     * @return for each position, the sum of the values at that position divided by the number of
     *         sequences
     */
    private static double[] getMean(DataSet data) {
        double[] mean = new double[data.getElementLength()];
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            for (int j = 0; j < seq.getLength(); j++) {
                mean[j] += seq.continuousVal(j);
            }
        }
        double n = data.getNumberOfElements();
        for (int j = 0; j < mean.length; j++) {
            mean[j] /= n;
        }
        return mean;
    }

    /**
     * Reads the profiles, trains the HMM and prints the model and per-state statistics.
     *
     * @param args {@code args[0]} is the path of the tab-separated input file
     * @throws Exception if reading the data, building or training the HMM fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("usage: java projects.dream2016.DNaseHMM <tab-separated profile file>");
            return;
        }
        AlphabetContainer ca = new AlphabetContainer((Alphabet) new ContinuousAlphabet());
        DataSet data = new DataSet(ca, (AbstractStringExtractor) new StringExtractor(new File(args[0]), 1000, '#'), "\t");
        data = medianSmoothing(data, SMOOTHING_WINDOW);
        double[] mean = getMean(data);
        double ma = ToolBox.max(mean);
        System.out.println(Arrays.toString(mean));
        System.out.println("num: " + data.getNumberOfElements());
        System.out.println("len: " + data.getElementLength());

        int seqLen = data.getElementLength();
        // one third of the (smoothed) length is covered by each motif half F and B
        int nMotif = seqLen / 3;
        boolean useLikelihood = false;
        double ess = 4.0;
        // hyper-parameter for the self-loops of S and E and for their shared flank emission,
        // proportional to the number of positions not covered by F/B states
        double loopEss = (double) (seqLen - 2 * nMotif) * ess / 2.0;

        DifferentiableEmission[] emissions = createEmissions(ca, nMotif, ess, loopEss, ma);
        String[] stateNames = createStateNames(nMotif);
        int[] emissionIdx = createEmissionIndices(nMotif, emissions.length);
        boolean[] forward = new boolean[stateNames.length];
        Arrays.fill(forward, true);
        TransitionElement[] transitions = createTransitions(nMotif, ess, loopEss);

        ViterbiParameterSet trainingParameterSet = new ViterbiParameterSet(10, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6), 4);
        DifferentiableHigherOrderHMM hmm = new DifferentiableHigherOrderHMM(trainingParameterSet, stateNames, emissionIdx, forward, emissions, useLikelihood, ess, transitions);
        System.out.println(hmm.getGraphvizRepresentation(null));
        hmm.train(data);
        System.out.println(hmm);
        DecimalFormat nf = new DecimalFormat("0.000");
        System.out.println(hmm.getGraphvizRepresentation(nf));

        printViterbiStatistics(hmm, data, stateNames.length);
    }

    /**
     * Creates the emissions referenced by {@link #createEmissionIndices(int, int)}.
     *
     * <p>The array has {@code nMotif + 3} entries:
     * <ul>
     * <li>index 0 is the flank emission shared by S and E;</li>
     * <li>indices 1..nMotif are shared by {@code F_i}/{@code B_i};</li>
     * <li>index nMotif+1 belongs to the central state I;</li>
     * <li>the last index is a silent emission for End.</li>
     * </ul>
     * The Gaussian prior means grow linearly with the index, from 0 towards {@code ma}. Each
     * index and its prior mean are printed.
     */
    private static DifferentiableEmission[] createEmissions(AlphabetContainer ca, int nMotif, double ess, double loopEss, double ma) {
        DifferentiableEmission[] emissions = new DifferentiableEmission[nMotif + 3];
        for (int i = 0; i < emissions.length - 1; i++) {
            // the flank emission is used by the self-looping S and E states and therefore gets
            // the loop hyper-parameter; all other Gaussian emissions get 2 * ess
            double locEss = i == 0 ? loopEss * 2.0 : ess * 2.0;
            double priorMu = (double) i / ((double) emissions.length - 1.0) * ma;
            System.out.println(String.valueOf(i) + " " + priorMu);
            double expPrec = 1.0;
            double sdPrec = 1.0;
            emissions[i] = new GaussianEmission(locEss, ca, priorMu, expPrec, sdPrec, true);
        }
        emissions[emissions.length - 1] = new SilentEmission();
        return emissions;
    }

    /**
     * Creates the {@code 2 * nMotif + 4} state names in state-index order:
     * {@code S, F0..F(nMotif-1), I, B(nMotif-1)..B0, E, End}.
     */
    private static String[] createStateNames(int nMotif) {
        String[] stateNames = new String[2 * nMotif + 4];
        stateNames[0] = "S";
        stateNames[nMotif + 1] = "I";
        stateNames[2 * nMotif + 2] = "E";
        stateNames[2 * nMotif + 3] = "End";
        for (int i = 0; i < nMotif; i++) {
            stateNames[i + 1] = "F" + i;
            stateNames[2 * nMotif + 1 - i] = "B" + i;
        }
        return stateNames;
    }

    /**
     * Maps every state (same order as {@link #createStateNames(int)}) to its emission index.
     *
     * <p>S and E share emission 0. {@code F_i} and {@code B_i} share emission {@code i + 1}. I uses
     * {@code numEmissions - 2}, and End uses the last (silent) emission.
     */
    private static int[] createEmissionIndices(int nMotif, int numEmissions) {
        int[] emissionIdx = new int[2 * nMotif + 4];
        emissionIdx[0] = 0;
        emissionIdx[nMotif + 1] = numEmissions - 2;
        emissionIdx[2 * nMotif + 2] = 0;
        emissionIdx[2 * nMotif + 3] = numEmissions - 1;
        for (int i = 0; i < nMotif; i++) {
            emissionIdx[i + 1] = i + 1;
            emissionIdx[2 * nMotif + 1 - i] = i + 1;
        }
        return emissionIdx;
    }

    /**
     * Creates the left-to-right transition structure:
     * {@code S -> S | F0 -> ... -> I -> B(n-1) -> ... -> B0 -> E -> E | End}.
     * Only S and E have self-loops. The element order is the same as in the original
     * implementation.
     */
    private static TransitionElement[] createTransitions(int nMotif, double ess, double loopEss) {
        int start = 0;
        int center = nMotif + 1;
        int flankEnd = 2 * nMotif + 2;
        int terminal = 2 * nMotif + 3;
        LinkedList<TransitionElement> tel = new LinkedList<TransitionElement>();
        // empty context (presumably the initial transition in Jstacs HMMs) -> S
        tel.add(new TransitionElement(new int[0], new int[]{start}, new double[]{ess}));
        // S: stay in S or enter the first forward motif state
        tel.add(new TransitionElement(new int[]{start}, new int[]{start, start + 1}, new double[]{loopEss, ess}));
        // I -> first backward motif state
        tel.add(new TransitionElement(new int[]{center}, new int[]{center + 1}, new double[]{ess}));
        // E: stay in E or finish in End
        tel.add(new TransitionElement(new int[]{flankEnd}, new int[]{flankEnd, terminal}, new double[]{loopEss, ess}));
        for (int i = 0; i < nMotif; i++) {
            // forward chain F_i -> next (the last one leads into I)
            tel.add(new TransitionElement(new int[]{i + 1}, new int[]{i + 2}, new double[]{ess}));
            // backward chain -> next (the last one leads into E)
            tel.add(new TransitionElement(new int[]{center + 1 + i}, new int[]{center + 2 + i}, new double[]{ess}));
        }
        return tel.toArray(new TransitionElement[0]);
    }

    /**
     * Computes the Viterbi path of every sequence and prints the statistics.
     *
     * <p>Prints the paths of at most {@link #NUM_PATHS_TO_PRINT} leading sequences. It then prints
     * the mean observed value per state over all paths. States that are never visited get NaN
     * (0/0).
     */
    private static void printViterbiStatistics(DifferentiableHigherOrderHMM hmm, DataSet data, int numStates) throws Exception {
        double[] stats = new double[numStates];
        double[] ns = new double[numStates];
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            IntList vit = hmm.getViterbiPathFor(seq).getFirstElement();
            if (i < NUM_PATHS_TO_PRINT) {
                // bounded by the loop condition, so fewer than 100 sequences no longer overflow
                System.out.println(vit);
            }
            // assumes the path holds one (emitting) state per sequence position
            for (int j = 0; j < seq.getLength(); j++) {
                int state = vit.get(j);
                stats[state] += seq.continuousVal(j);
                ns[state] += 1.0;
            }
        }
        for (int s = 0; s < stats.length; s++) {
            stats[s] /= ns[s];
        }
        System.out.println(Arrays.toString(stats));
    }
}

