/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deeplearning;

import hex.genmodel.algos.deeplearning.ActivationUtils;
import hex.genmodel.algos.deeplearning.DeeplearningMojoModel;
import java.util.Arrays;
import java.util.List;

/**
 * One fully-connected neural-network layer, evaluated in the forward direction for scoring.
 *
 * <p>The layer combines {@link #_inputs} with the weights and biases in {@link #_weightsAndBias}
 * to form one pre-activation value per output node. For the Maxout activations it forms
 * {@link #_maxK} values per output node. It then hands those values to the activation function
 * named by {@link #_activation}.
 */
public class NeuralNetwork {
    /** Name of the activation function; expected to be one of {@link #_validActivation}. */
    public String _activation;
    /** Drop-out ratio passed through to the activation function's {@code eval}. */
    double _drop_out_ratio;
    /** Weight and bias storage for this layer. */
    public DeeplearningMojoModel.StoreWeightsBias _weightsAndBias;
    /** Input vector of this layer (kept by reference, not copied). */
    public double[] _inputs;
    /** Buffer of length {@link #_outSize}; allocated by the constructor, not written by this class. */
    public double[] _outputs;
    /** Number of output nodes of this layer. */
    public int _outSize;
    /** Number of inputs, i.e. {@code _inputs.length}. */
    public int _inSize;
    /** Channels per output node: {@code biasLength / outSize} for Maxout activations, 1 otherwise. */
    public int _maxK = 1;
    /** Activation names accepted by {@link #validateInputs} and {@link #createActFuns}. */
    List<String> _validActivation = Arrays.asList("Linear", "Softmax", "ExpRectifierWithDropout", "ExpRectifier", "Rectifier", "RectifierWithDropout", "MaxoutWithDropout", "Maxout", "TanhWithDropout", "Tanh");

    /**
     * Creates a layer and validates its configuration (validation uses {@code assert}, so it is
     * only active when assertions are enabled).
     *
     * @param activation     activation function name
     * @param drop_out_ratio drop-out ratio, expected in {@code [0, 1)}
     * @param weightsAndBias weights and biases of this layer
     * @param inputs         input vector; its length becomes {@link #_inSize}
     * @param outSize        number of output nodes
     */
    public NeuralNetwork(String activation, double drop_out_ratio, DeeplearningMojoModel.StoreWeightsBias weightsAndBias, double[] inputs, int outSize) {
        this.validateInputs(activation, drop_out_ratio, weightsAndBias._wValues.length, weightsAndBias._bValues.length, inputs.length, outSize);
        this._activation = activation;
        this._drop_out_ratio = drop_out_ratio;
        this._weightsAndBias = weightsAndBias;
        this._inputs = inputs;
        this._outSize = outSize;
        this._inSize = this._inputs.length;
        this._outputs = new double[this._outSize];
        // Maxout keeps several bias (and weight) channels per output node.
        if ("Maxout".equals(this._activation) || "MaxoutWithDropout".equals(this._activation)) {
            this._maxK = weightsAndBias._bValues.length / outSize;
        }
    }

    /**
     * Runs forward propagation through this layer.
     *
     * <p>Forms the pre-activation values with {@link #formNNInputs()} when {@code _maxK == 1},
     * otherwise with {@link #formNNInputsMaxOut()}. It then returns the result of the
     * activation's {@code eval(preActivations, _drop_out_ratio, _maxK)}, presumably the
     * activated layer output (the semantics of {@code eval} live in {@code ActivationUtils}).
     *
     * @return the array produced by the activation function
     */
    public double[] fprop1Layer() {
        double[] input2ActFun = this._maxK == 1 ? this.formNNInputs() : this.formNNInputsMaxOut();
        ActivationUtils.ActivationFunctions createActivations = this.createActFuns(this._activation);
        return createActivations.eval(input2ActFun, this._drop_out_ratio, this._maxK);
    }

    /**
     * Computes {@code W * inputs + b} for a single-channel layer.
     *
     * <p>Weights are read row-major: the weight for output {@code row} and input {@code col} is
     * {@code _wValues[row * _inputs.length + col]}. The dot product is unrolled by 8 with
     * independent partial sums. The leftover columns are added one at a time, then the bias
     * {@code _bValues[row]}.
     *
     * @return an array of length {@link #_outSize} with one pre-activation value per output node
     */
    public double[] formNNInputs() {
        double[] input2ActFun = new double[this._outSize];
        int cols = this._inputs.length;
        int rows = input2ActFun.length;
        // Largest multiple of 8 not exceeding cols: [0, unrolled) uses the unrolled loop,
        // [unrolled, cols) uses the scalar tail loop.
        int unrolled = cols - cols % 8;
        int idx = 0; // offset of the current row inside the row-major weight array
        for (int row = 0; row < rows; ++row) {
            double psum0 = 0.0;
            double psum1 = 0.0;
            double psum2 = 0.0;
            double psum3 = 0.0;
            double psum4 = 0.0;
            double psum5 = 0.0;
            double psum6 = 0.0;
            double psum7 = 0.0;
            for (int col = 0; col < unrolled; col += 8) {
                int off = idx + col;
                psum0 += (double) this._weightsAndBias._wValues[off] * this._inputs[col];
                psum1 += (double) this._weightsAndBias._wValues[off + 1] * this._inputs[col + 1];
                psum2 += (double) this._weightsAndBias._wValues[off + 2] * this._inputs[col + 2];
                psum3 += (double) this._weightsAndBias._wValues[off + 3] * this._inputs[col + 3];
                psum4 += (double) this._weightsAndBias._wValues[off + 4] * this._inputs[col + 4];
                psum5 += (double) this._weightsAndBias._wValues[off + 5] * this._inputs[col + 5];
                psum6 += (double) this._weightsAndBias._wValues[off + 6] * this._inputs[col + 6];
                psum7 += (double) this._weightsAndBias._wValues[off + 7] * this._inputs[col + 7];
            }
            // Combine partial sums in two halves (this fixes the floating-point summation order).
            input2ActFun[row] += psum0 + psum1 + psum2 + psum3;
            input2ActFun[row] += psum4 + psum5 + psum6 + psum7;
            for (int col = unrolled; col < cols; ++col) {
                input2ActFun[row] += (double) this._weightsAndBias._wValues[idx + col] * this._inputs[col];
            }
            input2ActFun[row] += this._weightsAndBias._bValues[row];
            idx += cols;
        }
        return input2ActFun;
    }

    /**
     * Computes the pre-activation values for a Maxout layer with {@link #_maxK} channels per node.
     *
     * <p>The value for output {@code row} and channel {@code k} is stored at index
     * {@code _maxK * row + k}. It equals the sum over {@code col} of
     * {@code inputs[col] * _wValues[_maxK * (row * _inSize + col) + k]}, plus
     * {@code _bValues[_maxK * row + k]}.
     *
     * @return an array of length {@code _outSize * _maxK}
     */
    public double[] formNNInputsMaxOut() {
        double[] input2ActFun = new double[this._outSize * this._maxK];
        for (int k = 0; k < this._maxK; ++k) {
            for (int row = 0; row < this._outSize; ++row) {
                int countInd = this._maxK * row + k;
                for (int col = 0; col < this._inSize; ++col) {
                    input2ActFun[countInd] += this._inputs[col] * (double) this._weightsAndBias._wValues[this._maxK * (row * this._inSize + col) + k];
                }
                input2ActFun[countInd] += this._weightsAndBias._bValues[countInd];
            }
        }
        return input2ActFun;
    }

    /**
     * Checks the layer configuration with {@code assert} statements, so the checks only run when
     * assertions are enabled ({@code java -ea}).
     *
     * <p>{@code outSize} is checked before the divisibility checks that divide by it. This way a
     * zero or negative node count is reported by its own assertion rather than by an
     * {@code ArithmeticException}.
     *
     * @param activation     activation function name; must be in {@link #_validActivation}
     * @param drop_out_ratio drop-out ratio; must be in {@code [0, 1)}
     * @param weightLen      length of the weight array; must be a multiple of {@code inSize * outSize}
     * @param biasLen        length of the bias array; must be a multiple of {@code outSize}
     * @param inSize         number of inputs
     * @param outSize        number of output nodes; must be positive
     */
    public void validateInputs(String activation, double drop_out_ratio, int weightLen, int biasLen, int inSize, int outSize) {
        assert (this._validActivation.contains(activation)) : "activation must be one of \"Linear\", \"Softmax\", \"ExpRectifierWithDropout\", \"ExpRectifier\", \"Rectifier\", \"RectifierWithDropout\", \"MaxoutWithDropout\", \"Maxout\", \"TanhWithDropout\", \"Tanh\"";
        assert (outSize > 0) : "number of nodes in neural network must exceed 0.";
        // long product avoids int overflow; with no inputs the only multiple of 0 is 0.
        assert (inSize == 0 ? weightLen == 0 : weightLen % ((long) inSize * outSize) == 0) : "Your neural network layer number of input * number of outputs should equal length of your weight vector";
        assert (biasLen % outSize == 0) : "Number of bias should equal number of nodes in your nerual network layer.";
        assert (drop_out_ratio >= 0.0 && drop_out_ratio < 1.0) : "drop_out_ratio must be >=0 and < 1.";
    }

    /**
     * Creates the activation-function object for the given name.
     *
     * @param activation activation function name
     * @return a new activation-function instance
     * @throws UnsupportedOperationException if {@code activation} is not a known name
     */
    public ActivationUtils.ActivationFunctions createActFuns(String activation) {
        switch (activation) {
            case "Linear":
                return new ActivationUtils.LinearOut();
            case "Softmax":
                return new ActivationUtils.SoftmaxOut();
            case "ExpRectifierWithDropout":
                return new ActivationUtils.ExpRectifierDropoutOut();
            case "ExpRectifier":
                return new ActivationUtils.ExpRectifierOut();
            case "Rectifier":
                return new ActivationUtils.RectifierOut();
            case "RectifierWithDropout":
                return new ActivationUtils.RectifierDropoutOut();
            case "MaxoutWithDropout":
                return new ActivationUtils.MaxoutDropoutOut();
            case "Maxout":
                return new ActivationUtils.MaxoutOut();
            case "TanhWithDropout":
                return new ActivationUtils.TanhDropoutOut();
            case "Tanh":
                return new ActivationUtils.TanhOut();
            default:
                throw new UnsupportedOperationException("Unexpected activation function: " + activation);
        }
    }
}

