/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import cern.colt.Arrays;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.AbstractStringExtractor;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.StringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ConstantDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.RandomNumberGenerator;
import java.io.File;
import java.text.NumberFormat;

/**
 * A Gaussian network over a fixed number of continuous positions.
 *
 * <p>Position {@code i} of a sequence is scored with a univariate Gaussian log density whose mean
 * depends linearly on the values of its parent positions {@code structure[i]}:
 *
 * <pre>
 *   m_i   = mu[i] + sum_j bij[i][j] * (x[structure[i][j]] - mu[structure[i][j]])
 *   score = sum_i ( 0.5 * lambda[i] - 0.5 * log(2*pi) - exp(lambda[i]) / 2 * (x[i] - m_i)^2 )
 * </pre>
 *
 * <p>So {@code lambda[i]} is the log precision (log inverse variance) of position {@code i}.
 *
 * <p>Layout of the flat parameter vector used by {@link #setParameters(double[], int)} and
 * {@link #getCurrentParameterValues()}. Let {@code L} be the number of positions:
 * <ul>
 *   <li>{@code [0, L)} holds {@code mu};</li>
 *   <li>{@code [L, 2L)} holds {@code lambda};</li>
 *   <li>{@code bij[i][j]} is at index {@code boff[i] + j}.</li>
 * </ul>
 *
 * <p>NOTE(review): acyclicity of {@code structure} is not checked. {@link #getLogNormalizationConstant()}
 * returns 0, which is only correct if the structure is acyclic and has no self-loops.
 */
public class GaussianNetwork
extends AbstractDifferentiableStatisticalModel {
    /** The constant term {@code 0.5 * log(2*pi)} of every univariate Gaussian log density. */
    private static final double HALF_LOG_2PI = 0.5 * Math.log(Math.PI * 2);

    /** Unconditional means, one per position. */
    private double[] mu;
    /** Log precisions, one per position. */
    private double[] lambda;
    /** Regression coefficients; {@code bij[i][j]} weights parent {@code structure[i][j]} of position {@code i}. */
    private double[][] bij;
    /** Parent indices of each position. */
    private int[][] structure;
    /** Offset of {@code bij[i][0]} within the flat parameter vector. */
    private int[] boff;

    /**
     * Creates a Gaussian network over a single {@link ContinuousAlphabet}.
     *
     * @param structure {@code structure[i]} lists the parent positions of position {@code i}
     * @throws IllegalArgumentException if a parent index is outside {@code [0, structure.length)}
     *         or a row of {@code structure} is {@code null}
     * @throws CloneNotSupportedException if the structure could not be cloned
     */
    public GaussianNetwork(int[][] structure) throws IllegalArgumentException, CloneNotSupportedException {
        this(new AlphabetContainer((Alphabet)new ContinuousAlphabet()), structure);
    }

    /**
     * Creates a Gaussian network with the given alphabet and structure. All parameters start at 0.
     *
     * @param con the alphabet container; the model length is {@code structure.length}
     * @param structure {@code structure[i]} lists the parent positions of position {@code i}; it is cloned
     * @throws IllegalArgumentException if a parent index is outside {@code [0, structure.length)}
     *         or a row of {@code structure} is {@code null}
     * @throws CloneNotSupportedException if the structure could not be cloned
     */
    public GaussianNetwork(AlphabetContainer con, int[][] structure) throws IllegalArgumentException, CloneNotSupportedException {
        super(con, structure.length);
        // Fail fast: an invalid parent index would otherwise surface only later,
        // as an ArrayIndexOutOfBoundsException during scoring.
        validateStructure(structure);
        this.structure = (int[][])ArrayHandler.clone((Cloneable[])structure);
        int n = structure.length;
        this.mu = new double[n];
        this.lambda = new double[n];
        this.bij = new double[n][];
        this.boff = new int[n];
        // The regression coefficients follow the mu block and the lambda block.
        int off = this.mu.length + this.lambda.length;
        for (int i = 0; i < n; i++) {
            this.boff[i] = off;
            this.bij[i] = new double[structure[i].length];
            off += this.bij[i].length;
        }
    }

    /**
     * Checks that every row of {@code structure} is non-null and that every parent index lies in
     * {@code [0, structure.length)}.
     */
    private static void validateStructure(int[][] structure) {
        for (int i = 0; i < structure.length; i++) {
            int[] parents = structure[i];
            if (parents == null) {
                throw new IllegalArgumentException("structure[" + i + "] is null; use an empty array for a position without parents");
            }
            for (int p : parents) {
                if (p < 0 || p >= structure.length) {
                    throw new IllegalArgumentException("structure[" + i + "] contains parent index " + p + ", which is outside [0, " + structure.length + ")");
                }
            }
        }
    }

    /**
     * Returns a deep copy of this model; all parameter and structure arrays are copied.
     */
    @Override
    public GaussianNetwork clone() throws CloneNotSupportedException {
        GaussianNetwork clone = (GaussianNetwork)super.clone();
        clone.structure = new int[this.structure.length][];
        for (int i = 0; i < this.structure.length; i++) {
            clone.structure[i] = this.structure[i].clone();
        }
        clone.mu = this.mu.clone();
        clone.lambda = this.lambda.clone();
        clone.bij = new double[this.bij.length][];
        for (int i = 0; i < this.bij.length; i++) {
            clone.bij[i] = this.bij[i].clone();
        }
        clone.boff = this.boff.clone();
        return clone;
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML could not be parsed
     */
    public GaussianNetwork(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /** Returns 0, i.e. the model is treated as already normalized. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /** Returns negative infinity for every parameter, consistent with a constant normalization. */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** This model uses no prior; the log prior term is always 0. */
    @Override
    public double getLogPriorTerm() {
        return 0.0;
    }

    /** This model uses no prior, so nothing is added to {@code grad}. */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
    }

    @Override
    public double getESS() {
        return 0.0;
    }

    /**
     * Initializes the parameters randomly; {@code index}, {@code data} and {@code weights} are
     * currently ignored.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Draws every {@code mu[i]} and {@code bij[i][j]} from the generator's {@code nextGaussian()} and
     * sets every {@code lambda[i]} to {@code log(0.1)}. {@code freeParams} is ignored.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        RandomNumberGenerator rng = new RandomNumberGenerator();
        for (int i = 0; i < this.mu.length; i++) {
            this.mu[i] = rng.nextGaussian();
            this.lambda[i] = Math.log(0.1);
            for (int j = 0; j < this.bij[i].length; j++) {
                this.bij[i][j] = rng.nextGaussian();
            }
        }
    }

    /**
     * Returns the mean of position {@code node} given its parents' values,
     * {@code mu[node] + sum_j bij[node][j] * (x[parent_j] - mu[parent_j])}.
     */
    private double conditionalMean(Sequence seq, int start, int node) {
        int[] parents = this.structure[node];
        double mean = this.mu[node];
        for (int j = 0; j < parents.length; j++) {
            int p = parents[j];
            mean += this.bij[node][j] * (seq.continuousVal(start + p) - this.mu[p]);
        }
        return mean;
    }

    /**
     * Computes the log score of {@code seq} starting at {@code start}. For every position it also
     * appends the partial derivatives w.r.t. {@code lambda[i]}, {@code mu[i]}, the parents' {@code mu}
     * and {@code bij[i][*]} to {@code indices}/{@code partialDer}. These may contain repeated indices,
     * since a {@code mu} is also differentiated through each of its children.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double val = 0.0;
        for (int i = 0; i < this.structure.length; i++) {
            int[] parents = this.structure[i];
            double temp = seq.continuousVal(start + i) - this.conditionalMean(seq, start, i);
            double expl = Math.exp(this.lambda[i]);
            val += 0.5 * this.lambda[i] - HALF_LOG_2PI - expl / 2.0 * temp * temp;
            // d/d lambda[i]
            indices.add(this.mu.length + i);
            partialDer.add(0.5 - expl / 2.0 * temp * temp);
            // d/d mu[i]
            indices.add(i);
            partialDer.add(expl * temp);
            // d/d mu[parent]: each parent mean enters the conditional mean with weight -bij[i][j].
            for (int j = 0; j < parents.length; j++) {
                indices.add(parents[j]);
                partialDer.add(-expl * temp * this.bij[i][j]);
            }
            // d/d bij[i][j]
            for (int j = 0; j < parents.length; j++) {
                indices.add(this.boff[i] + j);
                partialDer.add(expl * temp * (seq.continuousVal(start + parents[j]) - this.mu[parents[j]]));
            }
        }
        return val;
    }

    /**
     * Returns the total number of parameters: {@code 2 * L} for {@code mu} and {@code lambda}, plus
     * one coefficient per parent edge. Also valid for a model with zero positions.
     */
    @Override
    public int getNumberOfParameters() {
        int n = this.mu.length + this.lambda.length;
        for (double[] b : this.bij) {
            n += b.length;
        }
        return n;
    }

    /** Returns a new array holding {@code mu}, {@code lambda} and all {@code bij} in the documented layout. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] res = new double[this.getNumberOfParameters()];
        System.arraycopy(this.mu, 0, res, 0, this.mu.length);
        System.arraycopy(this.lambda, 0, res, this.mu.length, this.lambda.length);
        int off = this.mu.length + this.lambda.length;
        for (double[] b : this.bij) {
            System.arraycopy(b, 0, res, off, b.length);
            off += b.length;
        }
        return res;
    }

    /** Copies parameters in the documented layout from {@code params}, beginning at index {@code start}. */
    @Override
    public void setParameters(double[] params, int start) {
        int off = start;
        System.arraycopy(params, off, this.mu, 0, this.mu.length);
        off += this.mu.length;
        System.arraycopy(params, off, this.lambda, 0, this.lambda.length);
        off += this.lambda.length;
        for (double[] b : this.bij) {
            System.arraycopy(params, off, b, 0, b.length);
            off += b.length;
        }
    }

    @Override
    public String getInstanceName() {
        return "GaussianNetwork";
    }

    /** Returns the log score of {@code seq} starting at {@code start} (see the class documentation). */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        double val = 0.0;
        for (int i = 0; i < this.structure.length; i++) {
            double temp = seq.continuousVal(start + i) - this.conditionalMean(seq, start, i);
            val += 0.5 * this.lambda[i] - HALF_LOG_2PI - Math.exp(this.lambda[i]) / 2.0 * temp * temp;
        }
        return val;
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    /**
     * Lists {@code mu}, {@code lambda} and each {@code bij[i]} on separate lines. The
     * {@code NumberFormat} argument is currently not used.
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuffer sb = new StringBuffer();
        sb.append("mus: " + Arrays.toString(this.mu) + "\n");
        sb.append("lambdas: " + Arrays.toString(this.lambda) + "\n");
        for (int i = 0; i < this.bij.length; i++) {
            sb.append("b_" + i + ": " + Arrays.toString(this.bij[i]) + "\n");
        }
        return sb.toString();
    }

    /** Serializes parameters, offsets and structure inside a {@code GaussNet} tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.bij, "bij");
        XMLParser.appendObjectWithTags(xml, this.boff, "boff");
        XMLParser.appendObjectWithTags(xml, this.lambda, "lambda");
        XMLParser.appendObjectWithTags(xml, this.mu, "mu");
        XMLParser.appendObjectWithTags(xml, this.structure, "structure");
        XMLParser.addTags(xml, "GaussNet");
        return xml;
    }

    /**
     * Restores the fields written by {@link #toXML()}.
     *
     * <p>NOTE(review): the alphabet is not part of the XML. Any {@code AlphabetContainer} passed to
     * the two-argument constructor is replaced by a plain {@link ContinuousAlphabet} here.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "GaussNet");
        this.bij = (double[][])XMLParser.extractObjectForTags(xml, "bij");
        this.boff = (int[])XMLParser.extractObjectForTags(xml, "boff");
        this.lambda = (double[])XMLParser.extractObjectForTags(xml, "lambda");
        this.mu = (double[])XMLParser.extractObjectForTags(xml, "mu");
        this.structure = (int[][])XMLParser.extractObjectForTags(xml, "structure");
        this.alphabets = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
        this.length = this.mu.length;
    }

    /**
     * Demo entry point. Reads tab-separated continuous data from {@code args[0]} (at most 1000
     * lines, {@code '#'} starting ignored lines — assumed from the {@code StringExtractor} arguments,
     * confirm). It then trains a {@code GenDisMixClassifier} (ML principle) with a 3-position network
     * against a constant model on that data and prints the classifier.
     */
    public static void main(String[] args) throws Exception {
        AlphabetContainer ca = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
        DataSet data = new DataSet(ca, (AbstractStringExtractor)new StringExtractor(new File(args[0]), 1000, '#'), "\t");
        GaussianNetwork gn = new GaussianNetwork(new int[][]{{1, 2}, {2}, new int[0]});
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(data.getAlphabetContainer(), data.getElementLength(), 20, 1.0E-6, 1.0E-6, 1.0E-4, false, OptimizableFunction.KindOfParameter.LAST, true, 1);
        GenDisMixClassifier cl = new GenDisMixClassifier(params, (LogPrior)DoesNothingLogPrior.defaultInstance, LearningPrinciple.ML, gn, new ConstantDiffSM(data.getElementLength()));
        cl.train(data, data);
        System.out.println(cl);
    }
}

