/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.genmodel.GenModel;

/**
 * Adapts a single H2O input row to the XGBoost predictor's {@link FVec} interface, expanding
 * categorical columns into one-hot columns on the fly instead of materializing the expanded row.
 *
 * <p>The expanded feature layout is: every one-hot column of every categorical (spans delimited
 * by {@code DataInfo._catOffsets}), followed by the numeric columns. Load a row with
 * {@link #setInput(double[])} before querying {@link #fvalue(int)}. Instances keep per-row
 * mutable state without any synchronization, so a single instance should not be shared across
 * threads that set different rows concurrently.
 */
public class MutableOneHotEncoderFVec
implements FVec {
    /** Column layout of the encoded data: categorical count, one-hot offsets, numeric count. */
    private final DataInfo _di;
    /** When {@code true}, zero values are reported to the predictor as NaN (missing). */
    private final boolean _treatsZeroAsNA;
    /** Maps each one-hot column index to the index of the categorical column that owns it. */
    private final int[] _catMap;
    /** For each categorical of the current row, the one-hot column index considered "hot". */
    private final int[] _catValues;
    /** Numeric values of the current row, narrowed to float (zeros possibly replaced by NaN). */
    private final float[] _numValues;
    /** Value reported for a one-hot column that is not hot: 0, or NaN when zero means missing. */
    private final float _notHot;

    /**
     * Creates an encoder for rows laid out according to {@code di}.
     *
     * @param di             describes the categorical/numeric column layout of incoming rows
     * @param treatsZeroAsNA if {@code true}, numeric zeros and non-hot one-hot columns are
     *                       reported as {@link Float#NaN} instead of {@code 0}
     */
    public MutableOneHotEncoderFVec(DataInfo di, boolean treatsZeroAsNA) {
        this._di = di;
        this._catValues = new int[di._cats];
        this._treatsZeroAsNA = treatsZeroAsNA;
        this._notHot = treatsZeroAsNA ? Float.NaN : 0.0f;
        this._catMap = buildCatMap(di);
        this._numValues = new float[di._nums];
    }

    /**
     * Builds the one-hot-column to owning-categorical lookup table. The result is empty when the
     * layout has no categorical offsets.
     */
    private static int[] buildCatMap(DataInfo di) {
        if (di._catOffsets == null) {
            return new int[0];
        }
        int[] catMap = new int[di._catOffsets[di._cats]];
        for (int c = 0; c < di._cats; ++c) {
            for (int j = di._catOffsets[c]; j < di._catOffsets[c + 1]; ++j) {
                catMap[j] = c;
            }
        }
        return catMap;
    }

    /**
     * Loads a row into this vector.
     *
     * @param input row values: the first {@code _di._cats} entries are categorical values that
     *              are passed to {@link GenModel#setCats}, followed by {@code _di._nums} numeric
     *              values
     */
    public void setInput(double[] input) {
        GenModel.setCats(input, this._catValues, this._di._cats, this._di._catOffsets, this._di._useAllFactorLevels);
        for (int i = 0; i < this._numValues.length; ++i) {
            float val = (float) input[this._di._cats + i];
            // Optionally surface zeros as missing values to the predictor.
            this._numValues[i] = this._treatsZeroAsNA && val == 0.0f ? Float.NaN : val;
        }
    }

    /**
     * Returns the value of feature {@code index} in the expanded (one-hot) layout.
     *
     * @param index feature index; indices below the number of one-hot columns address categorical
     *              indicator columns, the rest address numeric columns in order
     * @return {@code 1.0f} for the hot column of a categorical, {@link #_notHot} for its other
     *         columns, or the stored numeric value
     */
    @Override
    public final float fvalue(int index) {
        if (index >= this._catMap.length) {
            return this._numValues[index - this._catMap.length];
        }
        // A one-hot column is hot when its owning categorical's selected index equals it.
        boolean isHot = this._catValues[this._catMap[index]] == index;
        return isHot ? 1.0f : this._notHot;
    }

    /**
     * Collapses a per-feature array in the expanded (one-hot) layout back to the original column
     * layout. Each categorical receives the sum of its one-hot entries, and numeric entries are
     * copied through unchanged.
     *
     * @param encoded values indexed by expanded feature index (one-hot columns, then numerics)
     * @param output  destination indexed by original column; must hold at least
     *                {@code _di._cats + _di._nums} entries
     */
    public void decodeAggregate(float[] encoded, float[] output) {
        for (int c = 0; c < this._di._cats; ++c) {
            float sum = 0.0f;
            for (int i = this._di._catOffsets[c]; i < this._di._catOffsets[c + 1]; ++i) {
                sum += encoded[i];
            }
            output[c] = sum;
        }
        // Numerics start right after the last one-hot column. _catMap.length is that offset
        // (== _catOffsets[_cats]) and, unlike _catOffsets, is also valid when no offsets exist.
        int numStart = this._catMap.length;
        if (this._di._nums > 0) {
            System.arraycopy(encoded, numStart, output, this._di._cats, this._di._nums);
        }
    }
}

