/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.klime;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glm.GlmMojoModel;
import java.util.EnumSet;

/**
 * MOJO scorer for a K-LIME model: a clustering model assigns each row to a cluster, and a
 * per-cluster regression model (or a global fallback) produces the prediction plus
 * per-feature coefficient output.
 *
 * <p>Prediction layout written by {@link #score0(double[], double[])}:
 * <ul>
 *   <li>{@code preds[0]} - left as written by the selected regression model's {@code score0}
 *       (presumably the regression prediction; TODO confirm against {@link GlmMojoModel});</li>
 *   <li>{@code preds[1]} - the cluster index chosen by the clustering model;</li>
 *   <li>{@code preds[2..]} - reset to {@code NaN}, then filled by
 *       {@link GlmMojoModel#applyCoefficients} starting at offset 2.</li>
 * </ul>
 */
public class KLimeMojoModel
extends MojoModel {
    /** Model used to assign a row to a cluster; slot 0 of its predictions is read as the cluster index. */
    MojoModel _clusteringModel;
    /** Fallback regression model, used when a cluster has no dedicated model; must be a {@link GlmMojoModel}. */
    MojoModel _globalRegressionModel;
    /** Regression models indexed by cluster; a {@code null} entry means "use the global model". */
    MojoModel[] _clusterRegressionModels;
    /** For each input column of the clustering model, the index of that column in the full input row. */
    int[] _rowSubsetMap;

    /** Reports this model as both a regression model and a K-LIME model. */
    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return EnumSet.of(ModelCategory.Regression, ModelCategory.KLime);
    }

    KLimeMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Scores one row: assigns it to a cluster, then scores it with that cluster's regression model.
     *
     * @param row   the full input row
     * @param preds output buffer; expected to have {@code row.length + 2} slots (checked only by assertion)
     * @return {@code preds}, filled as described in the class documentation
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        assert (preds.length == row.length + 2);

        // Step 1: cluster assignment, using only the columns the clustering model consumes.
        final int nClusterFeatures = this._clusteringModel.nfeatures();
        double[] rowSubset = new double[nClusterFeatures];
        // Bound the copy by the buffer being filled. Iterating over _clusteringModel._names.length
        // (as before) would overflow rowSubset if _names ever held more entries than nfeatures()
        // (e.g. if it also listed a response column - TODO confirm for the clustering models used).
        for (int featureIdx = 0; featureIdx < rowSubset.length; ++featureIdx) {
            rowSubset[featureIdx] = row[this._rowSubsetMap[featureIdx]];
        }
        double[] predsSubset = new double[nClusterFeatures + 2];
        this._clusteringModel.score0(rowSubset, predsSubset);
        final int cluster = (int) predsSubset[0];

        // Step 2: regression scoring with the model for that cluster (or the global fallback).
        GlmMojoModel regressionModel = this.getRegressionModel(cluster);
        regressionModel.score0(row, preds);
        preds[1] = cluster;
        // Clear the coefficient slots so any slot applyCoefficients does not write reads as NaN.
        for (int slot = 2; slot < preds.length; ++slot) {
            preds[slot] = Double.NaN;
        }
        regressionModel.applyCoefficients(row, preds, 2);
        return preds;
    }

    /**
     * Returns the regression model for {@code cluster}, falling back to the global model when
     * no per-cluster model is stored for that index.
     *
     * @param cluster cluster index; must be a valid index into the per-cluster model array
     * @return the selected model, cast to {@link GlmMojoModel}
     */
    public GlmMojoModel getRegressionModel(int cluster) {
        return (GlmMojoModel)(this._clusterRegressionModels[cluster] != null ? this._clusterRegressionModels[cluster] : this._globalRegressionModel);
    }

    /** Returns the prediction-buffer size: {@code nfeatures() + 2}. */
    @Override
    public int getPredsSize() {
        return this.nfeatures() + 2;
    }
}

