/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.gam;

import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.gam.GamRowToRawDataConverter;
import hex.genmodel.algos.gam.GamUtilsCubicRegression;
import hex.genmodel.algos.gam.GamUtilsThinPlateRegression;
import hex.genmodel.algos.gam.ISplines;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.RowToRawDataConverter;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.Map;

/**
 * Common scoring support for GAM MOJO models.
 *
 * <p>Holds the parameters read from the MOJO and implements the pieces shared by the concrete
 * GAM scorers:
 * <ul>
 *   <li>optional mean imputation of missing values;</li>
 *   <li>inverse-link evaluation;</li>
 *   <li>linear-predictor computation;</li>
 *   <li>"gamification" of {@link RowData} rows, i.e. expanding the raw GAM predictors into their
 *       spline basis columns.</li>
 * </ul>
 * Subclasses supply {@link #gamScore0(double[], double[])}.
 *
 * <p>{@link #internal_threadSafeInstance()} returns a clone whose derived state (the arrays and
 * spline helpers built by {@link #init()}) is rebuilt rather than shared.
 */
public abstract class GamMojoModelBase
extends MojoModel
implements ConverterFactoryProvidingModel,
Cloneable {
    public LinkFunctionType _link_function;
    boolean _useAllFactorLevels;
    int _cats;
    int[] _catNAFills;
    int[] _catOffsets;
    int _nums;
    int _numsCenter;
    double[] _numNAFillsCenter;
    boolean _meanImputation;
    double[] _beta_no_center;
    double[] _beta_center;
    double[][] _beta_multinomial;
    double[][] _beta_multinomial_no_center;
    double[][] _beta_multinomial_center;
    int[] _spline_orders;
    int[] _spline_orders_sorted;
    DistributionFamily _family;
    String[][] _gam_columns;
    String[][] _gam_columns_sorted;
    int[] _d;
    int[] _m;
    int[] _M;
    int[] _gamPredSize;
    int _num_gam_columns;
    int[] _bs;
    // Spline type per sorted gam column: 0 = cubic regression, 1 = thin plate, 2 = I-spline.
    int[] _bs_sorted;
    int[] _num_knots;
    int[] _num_knots_sorted;
    int[] _num_knots_sorted_minus1; // derived in init()
    int[] _numBasisSize;            // derived in init(): I-spline basis size per I-spline column
    int[] _num_knots_TP;
    double[][][] _knots;
    double[][][] _binvD;
    double[][][] _zTranspose;
    double[][][] _zTransposeCS;
    String[][] _gamColNames;
    String[][] _gamColNamesCenter;
    String[] _names_no_centering;
    int _totFeatureSize;
    int _betaSizePerClass;
    int _betaCenterSizePerClass;
    double _tweedieLinkPower;
    double[][] _hj;                 // derived in init(): knot spacings per cubic-regression column
    int _numExpandedGamCols;
    int _numExpandedGamColsCenter;
    int _lastClass;                 // derived in init(): _nclasses - 1
    int[][][] _allPolyBasisList;
    int _numTPCol;
    int _numCSCol;
    int _numISCol;
    int[] _tpDistzCSSize;           // derived in init()
    boolean[] _dEven;               // derived in init()
    double[] _constantTerms;        // derived in init()
    double[][] _gamColMeansRaw;
    double[][] _oneOGamColStd;
    boolean _standardize;
    ISplines[] _iSplineBasis;       // derived in init()

    GamMojoModelBase(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Scores one row.
     *
     * <p>When mean imputation is enabled, NaN entries of {@code row} are first replaced in place.
     * Scoring is then delegated to {@link #gamScore0(double[], double[])}.
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (_meanImputation) {
            imputeMissingWithMeans(row);
        }
        return gamScore0(row, preds);
    }

    /**
     * Builds the derived lookup arrays and spline helpers from the loaded parameters.
     *
     * <p>Sorted gam columns are laid out as all cubic-regression columns first, then I-spline
     * columns, then thin-plate columns. The absolute index computations below rely on that order.
     */
    void init() {
        _num_knots_sorted_minus1 = new int[_num_knots_sorted.length];
        for (int i = 0; i < _num_knots_sorted.length; i++) {
            _num_knots_sorted_minus1[i] = _num_knots_sorted[i] - 1;
        }
        if (_numCSCol > 0) {
            _hj = new double[_numCSCol][];
            for (int i = 0; i < _numCSCol; i++) {
                _hj[i] = ArrayUtils.eleDiff(_knots[i][0]);
            }
        }
        if (_numISCol > 0) {
            _numBasisSize = new int[_numISCol];
            _iSplineBasis = new ISplines[_numISCol];
            for (int i = 0; i < _numISCol; i++) {
                int absIndex = i + _numCSCol; // I-spline columns follow the cubic-regression ones
                _numBasisSize[i] = _num_knots_sorted[absIndex] + _spline_orders_sorted[absIndex] - 2;
                _iSplineBasis[i] = new ISplines(_spline_orders_sorted[absIndex], _knots[absIndex][0]);
            }
        }
        if (_numTPCol > 0) {
            _tpDistzCSSize = new int[_numTPCol];
            _dEven = new boolean[_numTPCol];
            _constantTerms = new double[_numTPCol];
            for (int i = 0; i < _numTPCol; i++) {
                int absIndex = i + _numCSCol + _numISCol; // thin-plate columns come last
                _tpDistzCSSize[i] = _num_knots_sorted[absIndex] - _M[i];
                _dEven[i] = _d[absIndex] % 2 == 0;
                _constantTerms[i] = GamUtilsThinPlateRegression.calTPConstantTerm(_m[i], _d[absIndex], _dEven[i]);
            }
        }
        _lastClass = _nclasses - 1;
    }

    /**
     * Returns a shallow clone with freshly built derived state.
     *
     * <p>{@link Object#clone()} shares every array with this instance. The subsequent
     * {@link #init()} call reassigns the derived arrays and spline helpers, so the clone gets its
     * own copies of those.
     */
    @Override
    public GenModel internal_threadSafeInstance() {
        try {
            GamMojoModelBase copy = (GamMojoModelBase) clone();
            copy.init();
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Model-specific scoring of an (optionally imputed) row.
     *
     * @param row   input row
     * @param preds output buffer
     * @return the predictions
     */
    abstract double[] gamScore0(double[] row, double[] preds);

    /**
     * Replaces NaN values in place.
     *
     * <p>The first {@code _cats} entries receive their categorical fill value. The next
     * {@code _numsCenter} entries receive their numeric fill value.
     */
    private void imputeMissingWithMeans(double[] data) {
        for (int i = 0; i < _cats; i++) {
            if (Double.isNaN(data[i])) {
                data[i] = _catNAFills[i];
            }
        }
        for (int i = 0; i < _numsCenter; i++) {
            int pos = i + _cats;
            if (Double.isNaN(data[pos])) {
                data[pos] = _numNAFillsCenter[i];
            }
        }
    }

    /**
     * Applies the inverse of the configured link function to a linear predictor.
     *
     * @throws UnsupportedOperationException if the link function is not handled here
     */
    double evalLink(double val) {
        switch (_link_function) {
            case identity:
                return GenModel.GLM_identityInv(val);
            case logit:
                return GenModel.GLM_logitInv(val);
            case log:
                return GenModel.GLM_logInv(val);
            case inverse:
                return GenModel.GLM_inverseInv(val);
            case tweedie:
                return GenModel.GLM_tweedieInv(val, _tweedieLinkPower);
            default:
                break;
        }
        throw new UnsupportedOperationException("Unexpected link function " + _link_function);
    }

    /**
     * Maps a categorical level to its coefficient index.
     *
     * <p>When {@code _useAllFactorLevels} is false, level 0 has no coefficient, so the level is
     * shifted down by one.
     *
     * @param data      level encoded as a double
     * @param dataIndex categorical column index
     * @return the coefficient index, or -1 when the (shifted) level is negative
     */
    int readCatVal(double data, int dataIndex) {
        int ival = _useAllFactorLevels ? (int) data : (int) data - 1;
        if (ival < 0) {
            return -1;
        }
        return ival + _catOffsets[dataIndex];
    }

    /**
     * Computes the linear predictor {@code beta . data + intercept}.
     *
     * <p>Categorical columns contribute their level's coefficient. Numeric columns contribute
     * coefficient times value. The last element of {@code beta} is added unscaled.
     */
    double generateEta(double[] beta, double[] data) {
        double eta = 0.0;
        int numCatCols = _catOffsets.length - 1;
        for (int i = 0; i < numCatCols; i++) {
            int ival = readCatVal(data[i], i);
            // skip negative indices and levels outside this column's coefficient range
            if (ival >= 0 && ival < _catOffsets[i + 1]) {
                eta += beta[ival];
            }
        }
        int noff = _catOffsets[_cats] - _cats;
        int numColLen = beta.length - 1 - noff;
        for (int i = _cats; i < numColLen; i++) {
            eta += beta[noff + i] * data[i];
        }
        return eta + beta[beta.length - 1];
    }

    /** Returns true only when every entry from {@code gamColStart} onward is NaN. */
    private boolean gamificationNeeded(double[] rawData, int gamColStart) {
        for (int i = gamColStart; i < rawData.length; i++) {
            if (!Double.isNaN(rawData[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts a non-null RowData cell to a double.
     *
     * <p>Strings are parsed with {@link Double#parseDouble(String)}. Any {@link Number} is
     * widened, so Double as well as Integer, Long or Float are accepted.
     */
    private static double toDouble(Object value) {
        if (value instanceof String) {
            return Double.parseDouble((String) value);
        }
        return ((Number) value).doubleValue();
    }

    /**
     * Writes the centered cubic-regression basis of gam column {@code cind} starting at
     * {@code dataIndEnd}.
     *
     * <p>Nothing is written when the predictor is absent from {@code rowData}.
     *
     * @return {@code dataIndEnd}, unchanged; the caller advances the offset
     */
    int addCSGamification(RowData rowData, int cind, int dataIndEnd, double[] dataWithGamifiedColumns) {
        Object dataObject = rowData.get(_gam_columns_sorted[cind][0]);
        if (dataObject == null) {
            return dataIndEnd;
        }
        double gamColData = toDouble(dataObject);
        double[] basisVals = new double[_num_knots_sorted[cind]];
        double[] basisValsCenter = new double[_num_knots_sorted_minus1[cind]];
        GamUtilsCubicRegression.expandOneGamCol(gamColData, _binvD[cind], basisVals, _hj[cind], _knots[cind][0]);
        ArrayUtils.multArray(basisVals, _zTranspose[cind], basisValsCenter);
        System.arraycopy(basisValsCenter, 0, dataWithGamifiedColumns, dataIndEnd, _num_knots_sorted_minus1[cind]);
        return dataIndEnd;
    }

    /**
     * Writes the I-spline basis of gam column {@code cind} starting at {@code dataIndEnd}.
     *
     * <p>{@code _numBasisSize[isCounter]} values are written. Nothing is written when the
     * predictor is absent from {@code rowData}.
     *
     * @param isCounter index of this column among the I-spline columns
     * @return {@code dataIndEnd}, unchanged; the caller advances the offset
     */
    int addISGamification(RowData rowData, int cind, int isCounter, int dataIndEnd, double[] dataWithGamifiedColumns) {
        Object dataObject = rowData.get(_gam_columns_sorted[cind][0]);
        if (dataObject == null) {
            return dataIndEnd;
        }
        double gamColData = toDouble(dataObject);
        double[] basisVals = new double[_numBasisSize[isCounter]];
        _iSplineBasis[isCounter].gamifyVal(basisVals, gamColData);
        System.arraycopy(basisVals, 0, dataWithGamifiedColumns, dataIndEnd, _numBasisSize[isCounter]);
        return dataIndEnd;
    }

    /**
     * Appends the expanded gam columns to {@code rawData}.
     *
     * <p>Expansion only happens when every gam slot (from
     * {@code _nfeatures - _numExpandedGamColsCenter} on) is still NaN. Otherwise
     * {@code rawData} is returned as-is. Slots of predictors missing from {@code rowData} stay
     * NaN.
     *
     * @return a new array of length {@code _nfeatures}, or {@code rawData} itself
     * @throws IllegalArgumentException for an unknown spline type
     */
    double[] addExpandGamCols(double[] rawData, RowData rowData) {
        int dataIndEnd = _nfeatures - _numExpandedGamColsCenter;
        if (!gamificationNeeded(rawData, dataIndEnd)) {
            return rawData;
        }
        double[] dataWithGamifiedColumns = ArrayUtils.nanArray(_nfeatures);
        System.arraycopy(rawData, 0, dataWithGamifiedColumns, 0, dataIndEnd);
        int tpCounter = 0;
        int isCounter = 0;
        for (int cind = 0; cind < _num_gam_columns; cind++) {
            int width; // number of expanded columns occupied by this gam column
            if (_bs_sorted[cind] == 0) {
                dataIndEnd = addCSGamification(rowData, cind, dataIndEnd, dataWithGamifiedColumns);
                width = _num_knots_sorted_minus1[cind];
            } else if (_bs_sorted[cind] == 1) {
                addTPGamification(rowData, cind, tpCounter, dataIndEnd, dataWithGamifiedColumns);
                width = _num_knots_sorted_minus1[cind];
                ++tpCounter;
            } else if (_bs_sorted[cind] == 2) {
                addISGamification(rowData, cind, isCounter, dataIndEnd, dataWithGamifiedColumns);
                // I-splines write _numBasisSize values (num_knots + order - 2), not num_knots - 1;
                // advance by what was written so the following columns line up.
                width = _numBasisSize[isCounter];
                ++isCounter;
            } else {
                throw new IllegalArgumentException("spline type not implemented!");
            }
            dataIndEnd += width;
        }
        return dataWithGamifiedColumns;
    }

    /**
     * Writes the centered thin-plate basis of gam column {@code cind} starting at
     * {@code dataIndEnd}.
     *
     * <p>The basis consists of the distance terms followed by the polynomial terms, multiplied by
     * {@code _zTranspose[cind]}. Nothing is written when any predictor of the column is absent
     * from {@code rowData}.
     *
     * @param tpCounter index of this column among the thin-plate columns
     * @return {@code dataIndEnd}, unchanged; the caller advances the offset
     */
    int addTPGamification(RowData rowData, int cind, int tpCounter, int dataIndEnd, double[] dataWithGamifiedColumns) {
        String[] gamCols = _gam_columns_sorted[cind];
        double[] gamPred = grabPredictorVals(gamCols, rowData);
        if (gamPred == null) {
            return dataIndEnd;
        }
        double[] tpDistance = new double[_num_knots_sorted[cind]];
        GamUtilsThinPlateRegression.calculateDistance(tpDistance, gamPred, _num_knots_sorted[cind], _knots[cind], _d[cind], _m[tpCounter], _dEven[tpCounter], _constantTerms[tpCounter], _oneOGamColStd[tpCounter], _standardize);
        double[] tpDistzCS = new double[_tpDistzCSSize[tpCounter]];
        ArrayUtils.multArray(tpDistance, _zTransposeCS[tpCounter], tpDistzCS);
        double[] tpPoly = new double[_M[tpCounter]];
        GamUtilsThinPlateRegression.calculatePolynomialBasis(tpPoly, gamPred, _d[cind], _M[tpCounter], _allPolyBasisList[tpCounter], _gamColMeansRaw[tpCounter], _oneOGamColStd[tpCounter], _standardize);
        double[] tpDistzCSPoly = new double[_num_knots_sorted[cind]];
        double[] tpDistzCSPolyzT = new double[_num_knots_sorted_minus1[cind]];
        System.arraycopy(tpDistzCS, 0, tpDistzCSPoly, 0, tpDistzCS.length);
        System.arraycopy(tpPoly, 0, tpDistzCSPoly, tpDistzCS.length, _M[tpCounter]);
        ArrayUtils.multArray(tpDistzCSPoly, _zTranspose[cind], tpDistzCSPolyzT);
        System.arraycopy(tpDistzCSPolyzT, 0, dataWithGamifiedColumns, dataIndEnd, tpDistzCSPolyzT.length);
        return dataIndEnd;
    }

    /**
     * Reads the named predictors from {@code rowData}.
     *
     * @return their values in {@code gamCols} order, or {@code null} if any of them is absent
     */
    double[] grabPredictorVals(String[] gamCols, RowData rowData) {
        double[] predVals = new double[gamCols.length];
        for (int i = 0; i < gamCols.length; i++) {
            Object value = rowData.get(gamCols[i]);
            if (value == null) {
                return null;
            }
            predVals[i] = toDouble(value);
        }
        return predVals;
    }

    @Override
    public RowToRawDataConverter makeConverterFactory(Map<String, Integer> modelColumnNameToIndexMap, Map<Integer, CategoricalEncoder> domainMap, EasyPredictModelWrapper.ErrorConsumer errorConsumer, EasyPredictModelWrapper.Config config) {
        return new GamRowToRawDataConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    }
}

