/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.parser.nndep.Config;
import edu.stanford.nlp.parser.nndep.Dataset;
import edu.stanford.nlp.parser.nndep.Example;
import edu.stanford.nlp.parser.nndep.Util;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

public class Classifier {
    private static final Redwood.RedwoodChannels log = Redwood.channels(Classifier.class);
    private final double[][] W1;
    private final double[][] W2;
    private final double[][] E;
    private final double[] b1;
    private double[][] gradSaved;
    private double[][] eg2W1;
    private double[][] eg2W2;
    private double[][] eg2E;
    private double[] eg2b1;
    private double[][] saved;
    private final Map<Integer, Integer> preMap;
    private boolean isTraining;
    private final Dataset dataset;
    private final MulticoreWrapper<Pair<Collection<Example>, FeedforwardParams>, Cost> jobHandler;
    private final Config config;
    private final int numLabels;

    public Classifier(Config config, double[][] E, double[][] W1, double[] b1, double[][] W2, List<Integer> preComputed) {
        this(config, null, E, W1, b1, W2, preComputed);
    }

    public Classifier(Config config, Dataset dataset, double[][] E, double[][] W1, double[] b1, double[][] W2, List<Integer> preComputed) {
        this.config = config;
        this.dataset = dataset;
        this.E = E;
        this.W1 = W1;
        this.b1 = b1;
        this.W2 = W2;
        this.initGradientHistories();
        this.numLabels = W2.length;
        this.preMap = new HashMap<Integer, Integer>();
        for (int i = 0; i < preComputed.size() && i < config.numPreComputed; ++i) {
            this.preMap.put(preComputed.get(i), i);
        }
        this.isTraining = dataset != null;
        this.jobHandler = this.isTraining ? new MulticoreWrapper<Pair<Collection<Example>, FeedforwardParams>, Cost>(config.trainingThreads, new CostFunction(), false) : null;
    }

    private Set<Integer> getToPreCompute(List<Example> examples) {
        HashSet<Integer> featureIDs = new HashSet<Integer>();
        block0: for (Example ex : examples) {
            List<Integer> feature = ex.getFeature();
            int j = 0;
            while (true) {
                if (j >= 48) continue block0;
                int tok = feature.get(j);
                int index = tok * 48 + j;
                if (this.preMap.containsKey(index)) {
                    featureIDs.add(index);
                }
                ++j;
            }
        }
        double percentagePreComputed = (float)featureIDs.size() / (float)this.config.numPreComputed;
        log.info(String.format("Percent actually necessary to pre-compute: %f%%%n", percentagePreComputed * 100.0));
        return featureIDs;
    }

    public Cost computeCostFunction(int batchSize, double regParameter, double dropOutProb) {
        this.validateTraining();
        List<Example> examples = Util.getRandomSubList(this.dataset.examples, batchSize);
        Set<Integer> toPreCompute = this.getToPreCompute(examples);
        this.preCompute(toPreCompute);
        FeedforwardParams params = new FeedforwardParams(batchSize, dropOutProb);
        this.gradSaved = new double[this.preMap.size()][this.config.hiddenSize];
        int numChunks = this.config.trainingThreads;
        List<List<Example>> chunks = CollectionUtils.partitionIntoFolds(examples, numChunks);
        for (Collection collection : chunks) {
            this.jobHandler.put(new Pair<Collection, FeedforwardParams>(collection, params));
        }
        this.jobHandler.join(false);
        Cost cost = null;
        while (this.jobHandler.peek()) {
            Cost cost2 = this.jobHandler.poll();
            if (cost == null) {
                cost = cost2;
                continue;
            }
            cost.merge(cost2);
        }
        if (cost == null) {
            return null;
        }
        cost.backpropSaved(toPreCompute);
        cost.addL2Regularization(regParameter);
        return cost;
    }

    public void takeAdaGradientStep(Cost cost, double adaAlpha, double adaEps) {
        int j;
        int i;
        this.validateTraining();
        double[][] gradW1 = cost.getGradW1();
        double[][] gradW2 = cost.getGradW2();
        double[][] gradE = cost.getGradE();
        double[] gradb1 = cost.getGradb1();
        for (i = 0; i < this.W1.length; ++i) {
            for (j = 0; j < this.W1[i].length; ++j) {
                double[] dArray = this.eg2W1[i];
                int n = j;
                dArray[n] = dArray[n] + gradW1[i][j] * gradW1[i][j];
                double[] dArray2 = this.W1[i];
                int n2 = j;
                dArray2[n2] = dArray2[n2] - adaAlpha * gradW1[i][j] / Math.sqrt(this.eg2W1[i][j] + adaEps);
            }
        }
        for (i = 0; i < this.b1.length; ++i) {
            int n = i;
            this.eg2b1[n] = this.eg2b1[n] + gradb1[i] * gradb1[i];
            int n3 = i;
            this.b1[n3] = this.b1[n3] - adaAlpha * gradb1[i] / Math.sqrt(this.eg2b1[i] + adaEps);
        }
        for (i = 0; i < this.W2.length; ++i) {
            for (j = 0; j < this.W2[i].length; ++j) {
                double[] dArray = this.eg2W2[i];
                int n = j;
                dArray[n] = dArray[n] + gradW2[i][j] * gradW2[i][j];
                double[] dArray3 = this.W2[i];
                int n4 = j;
                dArray3[n4] = dArray3[n4] - adaAlpha * gradW2[i][j] / Math.sqrt(this.eg2W2[i][j] + adaEps);
            }
        }
        if (this.config.doWordEmbeddingGradUpdate) {
            for (i = 0; i < this.E.length; ++i) {
                for (j = 0; j < this.E[i].length; ++j) {
                    double[] dArray = this.eg2E[i];
                    int n = j;
                    dArray[n] = dArray[n] + gradE[i][j] * gradE[i][j];
                    double[] dArray4 = this.E[i];
                    int n5 = j;
                    dArray4[n5] = dArray4[n5] - adaAlpha * gradE[i][j] / Math.sqrt(this.eg2E[i][j] + adaEps);
                }
            }
        }
    }

    private void initGradientHistories() {
        this.eg2E = new double[this.E.length][this.E[0].length];
        this.eg2W1 = new double[this.W1.length][this.W1[0].length];
        this.eg2b1 = new double[this.b1.length];
        this.eg2W2 = new double[this.W2.length][this.W2[0].length];
    }

    public void clearGradientHistories() {
        this.validateTraining();
        this.initGradientHistories();
    }

    private void validateTraining() {
        if (!this.isTraining) {
            throw new IllegalStateException("Not training, or training was already finalized");
        }
    }

    public void finalizeTraining() {
        this.validateTraining();
        this.jobHandler.join(true);
        this.isTraining = false;
    }

    public void preCompute() {
        this.preCompute(this.preMap.keySet());
    }

    public void preCompute(Set<Integer> toPreCompute) {
        long startTime = System.currentTimeMillis();
        this.saved = new double[this.preMap.size()][this.config.hiddenSize];
        for (int x : toPreCompute) {
            int mapX = this.preMap.get(x);
            int tok = x / 48;
            int pos = x % 48;
            for (int j = 0; j < this.config.hiddenSize; ++j) {
                for (int k = 0; k < this.config.embeddingSize; ++k) {
                    double[] dArray = this.saved[mapX];
                    int n = j;
                    dArray[n] = dArray[n] + this.W1[j][pos * this.config.embeddingSize + k] * this.E[tok][k];
                }
            }
        }
        log.info("PreComputed " + toPreCompute.size() + ", Elapsed Time: " + (double)(System.currentTimeMillis() - startTime) / 1000.0 + " (s)");
    }

    double[] computeScores(int[] feature) {
        return this.computeScores(feature, this.preMap);
    }

    private double[] computeScores(int[] feature, Map<Integer, Integer> preMap) {
        double[] hidden = new double[this.config.hiddenSize];
        int offset = 0;
        for (int j = 0; j < feature.length; ++j) {
            int tok = feature[j];
            int index = tok * 48 + j;
            if (preMap.containsKey(index)) {
                int id = preMap.get(index);
                for (int i = 0; i < this.config.hiddenSize; ++i) {
                    int n = i;
                    hidden[n] = hidden[n] + this.saved[id][i];
                }
            } else {
                for (int i = 0; i < this.config.hiddenSize; ++i) {
                    for (int k = 0; k < this.config.embeddingSize; ++k) {
                        int n = i;
                        hidden[n] = hidden[n] + this.W1[i][offset + k] * this.E[tok][k];
                    }
                }
            }
            offset += this.config.embeddingSize;
        }
        for (int i = 0; i < this.config.hiddenSize; ++i) {
            int n = i;
            hidden[n] = hidden[n] + this.b1[i];
            hidden[i] = hidden[i] * hidden[i] * hidden[i];
        }
        double[] scores = new double[this.numLabels];
        for (int i = 0; i < this.numLabels; ++i) {
            for (int j = 0; j < this.config.hiddenSize; ++j) {
                int n = i;
                scores[n] = scores[n] + this.W2[i][j] * hidden[j];
            }
        }
        return scores;
    }

    public double[][] getW1() {
        return this.W1;
    }

    public double[] getb1() {
        return this.b1;
    }

    public double[][] getW2() {
        return this.W2;
    }

    public double[][] getE() {
        return this.E;
    }

    private static void addInPlace(double[][] m1, double[][] m2) {
        for (int i = 0; i < m1.length; ++i) {
            for (int j = 0; j < m1[0].length; ++j) {
                double[] dArray = m1[i];
                int n = j;
                dArray[n] = dArray[n] + m2[i][j];
            }
        }
    }

    private static void addInPlace(double[] a1, double[] a2) {
        for (int i = 0; i < a1.length; ++i) {
            int n = i;
            a1[n] = a1[n] + a2[i];
        }
    }

    public class Cost {
        private double cost;
        private double percentCorrect;
        private final double[][] gradW1;
        private final double[] gradb1;
        private final double[][] gradW2;
        private final double[][] gradE;

        private Cost(double cost, double percentCorrect, double[][] gradW1, double[] gradb1, double[][] gradW2, double[][] gradE) {
            this.cost = cost;
            this.percentCorrect = percentCorrect;
            this.gradW1 = gradW1;
            this.gradb1 = gradb1;
            this.gradW2 = gradW2;
            this.gradE = gradE;
        }

        public void merge(Cost otherCost) {
            this.cost += otherCost.getCost();
            this.percentCorrect += otherCost.getPercentCorrect();
            Classifier.addInPlace(this.gradW1, otherCost.getGradW1());
            Classifier.addInPlace(this.gradb1, otherCost.getGradb1());
            Classifier.addInPlace(this.gradW2, otherCost.getGradW2());
            Classifier.addInPlace(this.gradE, otherCost.getGradE());
        }

        private void backpropSaved(Set<Integer> featuresSeen) {
            for (int x : featuresSeen) {
                int mapX = (Integer)Classifier.this.preMap.get(x);
                Classifier.this.config;
                int tok = x / 48;
                Classifier.this.config;
                int offset = x % 48 * ((Classifier)Classifier.this).config.embeddingSize;
                for (int j = 0; j < ((Classifier)Classifier.this).config.hiddenSize; ++j) {
                    double delta = Classifier.this.gradSaved[mapX][j];
                    for (int k = 0; k < ((Classifier)Classifier.this).config.embeddingSize; ++k) {
                        double[] dArray = this.gradW1[j];
                        int n = offset + k;
                        dArray[n] = dArray[n] + delta * Classifier.this.E[tok][k];
                        double[] dArray2 = this.gradE[tok];
                        int n2 = k;
                        dArray2[n2] = dArray2[n2] + delta * Classifier.this.W1[j][offset + k];
                    }
                }
            }
        }

        private void addL2Regularization(double regularizationWeight) {
            int j;
            int i;
            for (i = 0; i < Classifier.this.W1.length; ++i) {
                for (j = 0; j < Classifier.this.W1[i].length; ++j) {
                    this.cost += regularizationWeight * Classifier.this.W1[i][j] * Classifier.this.W1[i][j] / 2.0;
                    double[] dArray = this.gradW1[i];
                    int n = j;
                    dArray[n] = dArray[n] + regularizationWeight * Classifier.this.W1[i][j];
                }
            }
            for (i = 0; i < Classifier.this.b1.length; ++i) {
                this.cost += regularizationWeight * Classifier.this.b1[i] * Classifier.this.b1[i] / 2.0;
                int n = i;
                this.gradb1[n] = this.gradb1[n] + regularizationWeight * Classifier.this.b1[i];
            }
            for (i = 0; i < Classifier.this.W2.length; ++i) {
                for (j = 0; j < Classifier.this.W2[i].length; ++j) {
                    this.cost += regularizationWeight * Classifier.this.W2[i][j] * Classifier.this.W2[i][j] / 2.0;
                    double[] dArray = this.gradW2[i];
                    int n = j;
                    dArray[n] = dArray[n] + regularizationWeight * Classifier.this.W2[i][j];
                }
            }
            for (i = 0; i < Classifier.this.E.length; ++i) {
                for (j = 0; j < Classifier.this.E[i].length; ++j) {
                    this.cost += regularizationWeight * Classifier.this.E[i][j] * Classifier.this.E[i][j] / 2.0;
                    double[] dArray = this.gradE[i];
                    int n = j;
                    dArray[n] = dArray[n] + regularizationWeight * Classifier.this.E[i][j];
                }
            }
        }

        public double getCost() {
            return this.cost;
        }

        public double getPercentCorrect() {
            return this.percentCorrect;
        }

        public double[][] getGradW1() {
            return this.gradW1;
        }

        public double[] getGradb1() {
            return this.gradb1;
        }

        public double[][] getGradW2() {
            return this.gradW2;
        }

        public double[][] getGradE() {
            return this.gradE;
        }
    }

    private static class FeedforwardParams {
        private final int batchSize;
        private final double dropOutProb;

        private FeedforwardParams(int batchSize, double dropOutProb) {
            this.batchSize = batchSize;
            this.dropOutProb = dropOutProb;
        }

        public int getBatchSize() {
            return this.batchSize;
        }

        public double getDropOutProb() {
            return this.dropOutProb;
        }
    }

    private class CostFunction
    implements ThreadsafeProcessor<Pair<Collection<Example>, FeedforwardParams>, Cost> {
        private double[][] gradW1;
        private double[] gradb1;
        private double[][] gradW2;
        private double[][] gradE;

        private CostFunction() {
        }

        @Override
        public Cost process(Pair<Collection<Example>, FeedforwardParams> input) {
            Collection<Example> examples = input.first();
            FeedforwardParams params = input.second();
            ThreadLocalRandom random = ThreadLocalRandom.current();
            this.gradW1 = new double[Classifier.this.W1.length][Classifier.this.W1[0].length];
            this.gradb1 = new double[Classifier.this.b1.length];
            this.gradW2 = new double[Classifier.this.W2.length][Classifier.this.W2[0].length];
            this.gradE = new double[Classifier.this.E.length][Classifier.this.E[0].length];
            double cost = 0.0;
            double correct = 0.0;
            block0: for (Example ex : examples) {
                int nodeIndex;
                int tok;
                List<Integer> feature = ex.getFeature();
                List<Integer> label = ex.getLabel();
                double[] scores = new double[Classifier.this.numLabels];
                double[] hidden = new double[((Classifier)Classifier.this).config.hiddenSize];
                double[] hidden3 = new double[((Classifier)Classifier.this).config.hiddenSize];
                int[] ls = IntStream.range(0, ((Classifier)Classifier.this).config.hiddenSize).filter(n -> random.nextDouble() > params.getDropOutProb()).toArray();
                int offset = 0;
                int j22 = 0;
                while (true) {
                    Classifier.this.config;
                    if (j22 >= 48) break;
                    tok = feature.get(j22);
                    Classifier.this.config;
                    int index = tok * 48 + j22;
                    if (Classifier.this.preMap.containsKey(index)) {
                        int id = (Integer)Classifier.this.preMap.get(index);
                        int[] nArray = ls;
                        int n2 = nArray.length;
                        for (int i = 0; i < n2; ++i) {
                            int nodeIndex2;
                            int n3 = nodeIndex2 = nArray[i];
                            hidden[n3] = hidden[n3] + Classifier.this.saved[id][nodeIndex2];
                        }
                    } else {
                        for (int nodeIndex3 : ls) {
                            for (int k = 0; k < ((Classifier)Classifier.this).config.embeddingSize; ++k) {
                                int n4 = nodeIndex3;
                                hidden[n4] = hidden[n4] + Classifier.this.W1[nodeIndex3][offset + k] * Classifier.this.E[tok][k];
                            }
                        }
                    }
                    offset += ((Classifier)Classifier.this).config.embeddingSize;
                    ++j22;
                }
                int[] j22 = ls;
                tok = j22.length;
                for (int i = 0; i < tok; ++i) {
                    int n5 = nodeIndex = j22[i];
                    hidden[n5] = hidden[n5] + Classifier.this.b1[nodeIndex];
                    hidden3[nodeIndex] = Math.pow(hidden[nodeIndex], 3.0);
                }
                int optLabel = -1;
                for (int i = 0; i < Classifier.this.numLabels; ++i) {
                    if (label.get(i) < 0) continue;
                    int[] nArray = ls;
                    nodeIndex = nArray.length;
                    for (int j = 0; j < nodeIndex; ++j) {
                        int nodeIndex4 = nArray[j];
                        int n6 = i;
                        scores[n6] = scores[n6] + Classifier.this.W2[i][nodeIndex4] * hidden3[nodeIndex4];
                    }
                    if (optLabel >= 0 && !(scores[i] > scores[optLabel])) continue;
                    optLabel = i;
                }
                double sum1 = 0.0;
                double sum2 = 0.0;
                double maxScore = scores[optLabel];
                for (int i = 0; i < Classifier.this.numLabels; ++i) {
                    if (label.get(i) < 0) continue;
                    scores[i] = Math.exp(scores[i] - maxScore);
                    if (label.get(i) == 1) {
                        sum1 += scores[i];
                    }
                    sum2 += scores[i];
                }
                cost += (Math.log(sum2) - Math.log(sum1)) / (double)params.getBatchSize();
                if (label.get(optLabel) == 1) {
                    correct += 1.0 / (double)params.getBatchSize();
                }
                double[] gradHidden3 = new double[((Classifier)Classifier.this).config.hiddenSize];
                for (int i = 0; i < Classifier.this.numLabels; ++i) {
                    if (label.get(i) < 0) continue;
                    double delta = -((double)label.get(i).intValue() - scores[i] / sum2) / (double)params.getBatchSize();
                    for (int nodeIndex5 : ls) {
                        double[] dArray = this.gradW2[i];
                        int n7 = nodeIndex5;
                        dArray[n7] = dArray[n7] + delta * hidden3[nodeIndex5];
                        int n8 = nodeIndex5;
                        gradHidden3[n8] = gradHidden3[n8] + delta * Classifier.this.W2[i][nodeIndex5];
                    }
                }
                double[] gradHidden = new double[((Classifier)Classifier.this).config.hiddenSize];
                for (int nodeIndex6 : ls) {
                    gradHidden[nodeIndex6] = gradHidden3[nodeIndex6] * 3.0 * hidden[nodeIndex6] * hidden[nodeIndex6];
                    int n9 = nodeIndex6;
                    this.gradb1[n9] = this.gradb1[n9] + gradHidden[nodeIndex6];
                }
                offset = 0;
                int j = 0;
                while (true) {
                    Classifier.this.config;
                    if (j >= 48) continue block0;
                    int tok2 = feature.get(j);
                    Classifier.this.config;
                    int index = tok2 * 48 + j;
                    if (Classifier.this.preMap.containsKey(index)) {
                        int id = (Integer)Classifier.this.preMap.get(index);
                        int[] nArray = ls;
                        int n10 = nArray.length;
                        for (int i = 0; i < n10; ++i) {
                            int nodeIndex7 = nArray[i];
                            double[] dArray = Classifier.this.gradSaved[id];
                            int n11 = nodeIndex7;
                            dArray[n11] = dArray[n11] + gradHidden[nodeIndex7];
                        }
                    } else {
                        for (int nodeIndex8 : ls) {
                            for (int k = 0; k < ((Classifier)Classifier.this).config.embeddingSize; ++k) {
                                double[] dArray = this.gradW1[nodeIndex8];
                                int n12 = offset + k;
                                dArray[n12] = dArray[n12] + gradHidden[nodeIndex8] * Classifier.this.E[tok2][k];
                                double[] dArray2 = this.gradE[tok2];
                                int n13 = k;
                                dArray2[n13] = dArray2[n13] + gradHidden[nodeIndex8] * Classifier.this.W1[nodeIndex8][offset + k];
                            }
                        }
                    }
                    offset += ((Classifier)Classifier.this).config.embeddingSize;
                    ++j;
                }
            }
            return new Cost(cost, correct, this.gradW1, this.gradb1, this.gradW2, this.gradE);
        }

        @Override
        public ThreadsafeProcessor<Pair<Collection<Example>, FeedforwardParams>, Cost> newInstance() {
            return new CostFunction();
        }
    }
}

