/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.DataInfo;
import hex.ModelMetrics;
import hex.glm.GLMModel;
import java.util.Arrays;
import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.FrameUtils;

public class GLMScore
extends MRTask<GLMScore> {
    final GLMModel _m;
    final Job _j;
    ModelMetrics.MetricBuilder _mb;
    final DataInfo _dinfo;
    final boolean _sparse;
    final String[] _domain;
    final boolean _computeMetrics;
    final boolean _generatePredictions;
    transient double[][] _vcov;
    transient double[] _tmp;
    transient double[] _eta;
    final int _nclasses;
    private final double[] _beta;
    private final double[][] _beta_multinomial;
    private final double _defaultThreshold;

    public GLMScore(Job j2, GLMModel m4, DataInfo dinfo, String[] domain, boolean computeMetrics, boolean generatePredictions) {
        this._j = j2;
        this._m = m4;
        this._computeMetrics = computeMetrics;
        this._sparse = FrameUtils.sparseRatio(dinfo._adaptedFrame) < 0.5;
        this._domain = domain;
        this._generatePredictions = generatePredictions;
        this._m._parms = m4._parms;
        this._nclasses = ((GLMModel.GLMOutput)m4._output).nclasses();
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            this._beta = null;
            this._beta_multinomial = ((GLMModel.GLMOutput)m4._output)._global_beta_multinomial;
        } else {
            double[] beta = m4.beta();
            int[] ids = new int[beta.length - 1];
            int k2 = 0;
            for (int i2 = 0; i2 < beta.length - 1; ++i2) {
                if (beta[i2] == 0.0) continue;
                ids[k2++] = i2;
            }
            if (k2 < beta.length - 1) {
                ids = Arrays.copyOf(ids, k2);
                dinfo = dinfo.filterExpandedColumns(ids);
                double[] beta2 = MemoryManager.malloc8d(ids.length + 1);
                int l2 = 0;
                for (int x2 : ids) {
                    beta2[l2++] = beta[x2];
                }
                beta2[l2] = beta[beta.length - 1];
                beta = beta2;
            }
            this._beta_multinomial = null;
            this._beta = beta;
        }
        this._dinfo = dinfo;
        this._dinfo._valid = true;
        this._defaultThreshold = m4.defaultThreshold();
    }

    public double[] scoreRow(DataInfo.Row r2, double o2, double[] preds) {
        int lastClass = this._nclasses - 1;
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            double[][] bm = this._beta_multinomial;
            Arrays.fill(preds, 0.0);
            double previousCDF = 0.0;
            for (int cInd = 0; cInd < lastClass; ++cInd) {
                double eta = r2.innerProduct(bm[cInd]) + o2;
                double currCDF = 1.0 / (1.0 + Math.exp(-eta));
                preds[cInd + 1] = currCDF - previousCDF;
                previousCDF = currCDF;
            }
            preds[this._nclasses] = 1.0 - previousCDF;
            preds[0] = ArrayUtils.maxIndex(preds) - 1;
        } else if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial) {
            int c2;
            double[] eta = this._eta;
            double[][] bm = this._beta_multinomial;
            double sumExp = 0.0;
            double maxRow = 0.0;
            for (c2 = 0; c2 < bm.length; ++c2) {
                eta[c2] = r2.innerProduct(bm[c2]) + o2;
                if (!(eta[c2] > maxRow)) continue;
                maxRow = eta[c2];
            }
            for (c2 = 0; c2 < bm.length; ++c2) {
                eta[c2] = Math.exp(eta[c2] - maxRow);
                sumExp += eta[c2];
            }
            sumExp = 1.0 / sumExp;
            for (c2 = 0; c2 < bm.length; ++c2) {
                preds[c2 + 1] = eta[c2] * sumExp;
            }
            preds[0] = ArrayUtils.maxIndex(eta);
        } else {
            double mu = ((GLMModel.GLMParameters)this._m._parms).linkInv(r2.innerProduct(this._beta) + o2);
            if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.binomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.quasibinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.fractionalbinomial) {
                preds[0] = mu >= this._defaultThreshold ? 1.0 : 0.0;
                preds[1] = 1.0 - mu;
                preds[2] = mu;
            } else {
                preds[0] = mu;
            }
        }
        return preds;
    }

    private void processRow(DataInfo.Row r2, float[] res, double[] ps, NewChunk[] preds, int ncols) {
        if (this._dinfo._responses != 0) {
            res[0] = (float)r2.response[0];
        }
        if (r2.predictors_bad) {
            Arrays.fill(ps, Double.NaN);
        } else if (r2.weight == 0.0) {
            Arrays.fill(ps, 0.0);
        } else {
            this.scoreRow(r2, r2.offset, ps);
            if (this._computeMetrics && !r2.response_bad) {
                this._mb.perRow(ps, res, r2.weight, r2.offset, this._m);
            }
        }
        if (this._generatePredictions) {
            for (int c2 = 0; c2 < ncols; ++c2) {
                preds[c2].addNum(ps[c2]);
            }
            if (this._vcov != null) {
                preds[ncols].addNum(Math.sqrt(r2.innerProduct(r2.mtrxMul(this._vcov, this._tmp))));
            }
        }
    }

    @Override
    public void map(Chunk[] chks, NewChunk[] preds) {
        int ncols;
        double[] ps;
        if (this.isCancelled() || this._j != null && this._j.stop_requested()) {
            return;
        }
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            this._eta = MemoryManager.malloc8d(this._nclasses);
        }
        this._vcov = ((GLMModel.GLMOutput)this._m._output)._vcov;
        if (this._generatePredictions && this._vcov != null) {
            this._tmp = MemoryManager.malloc8d(this._vcov.length);
        }
        if (this._computeMetrics) {
            this._mb = this._m.makeMetricBuilder(this._domain);
            ps = this._mb._work;
        } else {
            ps = new double[((GLMModel.GLMOutput)this._m._output)._nclasses + 1];
        }
        float[] res = new float[1];
        int nc = ((GLMModel.GLMOutput)this._m._output).nclasses();
        int n2 = ncols = nc == 1 ? 1 : nc + 1;
        if (this._sparse) {
            for (DataInfo.Row r2 : this._dinfo.extractSparseRows(chks)) {
                this.processRow(r2, res, ps, preds, ncols);
            }
        } else {
            DataInfo.Row r3 = this._dinfo.newDenseRow();
            for (int rid = 0; rid < chks[0]._len; ++rid) {
                this._dinfo.extractDenseRow(chks, rid, r3);
                this.processRow(r3, res, ps, preds, ncols);
            }
        }
        if (this._j != null) {
            this._j.update(1L);
        }
    }

    @Override
    public void reduce(GLMScore bs) {
        if (this._mb != null) {
            this._mb.reduce(bs._mb);
        }
    }

    @Override
    protected void postGlobal() {
        if (this._mb != null) {
            this._mb.postGlobal();
        }
    }
}

