/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNodeStat;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * Encodes, decodes and applies auxiliary per-node weights for XGBoost {@link RegTree}s.
 *
 * <p>The byte layout written by {@link #toBytes(double[][])} and read by
 * {@link #fromBytes(byte[])} is:
 * <ol>
 *   <li>one {@code int}: the number of weight arrays;</li>
 *   <li>for each array: one {@code int} length, followed by that many {@code double} values.</li>
 * </ol>
 * All values use {@link ByteOrder#nativeOrder()}.
 *
 * <p>NOTE(review): native byte order makes the encoded bytes platform-dependent. Bytes written on
 * a little-endian host will not decode correctly on a big-endian host, and vice versa. It is kept
 * unchanged here for compatibility with previously serialized data; confirm whether a fixed order
 * was intended.
 */
public class AuxNodeWeightsHelper {
    private static final int DOUBLE_BYTES = 8;
    private static final int INTEGER_BYTES = 4;

    /**
     * Encodes a jagged array of node weights into a byte array (layout described on the class).
     *
     * @param auxNodeWeights weight arrays to encode; neither the outer array nor any inner array
     *                       may be {@code null}
     * @return a freshly allocated byte array containing the encoding
     * @throws IllegalArgumentException if the encoding would not fit in a single Java array
     */
    public static byte[] toBytes(double[][] auxNodeWeights) {
        // Accumulate in long: an int sum/product can silently wrap for very large inputs.
        long elementCount = 0;
        for (double[] weights : auxNodeWeights) {
            elementCount += weights.length;
        }
        // One int for the array count plus one int length prefix per array.
        long headerBytes = (1L + auxNodeWeights.length) * INTEGER_BYTES;
        // Check elementCount first so the multiplication below can never overflow a long.
        if (elementCount > Integer.MAX_VALUE / DOUBLE_BYTES
                || headerBytes + elementCount * DOUBLE_BYTES > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Auxiliary node weights are too large to serialize: "
                    + auxNodeWeights.length + " arrays, " + elementCount + " values.");
        }
        int totalBytes = (int) (headerBytes + elementCount * DOUBLE_BYTES);

        ByteBuffer buffer = ByteBuffer.allocate(totalBytes).order(ByteOrder.nativeOrder());
        buffer.putInt(auxNodeWeights.length);
        for (double[] weights : auxNodeWeights) {
            buffer.putInt(weights.length);
            for (double weight : weights) {
                buffer.putDouble(weight);
            }
        }
        return buffer.array();
    }

    /**
     * Decodes bytes produced by {@link #toBytes(double[][])} back into a jagged weight array.
     *
     * <p>The input is not validated. Truncated input makes {@link ByteBuffer} throw
     * {@link java.nio.BufferUnderflowException}, and a negative length prefix makes the array
     * allocation throw {@link NegativeArraySizeException}. Trailing bytes beyond the encoded data
     * are ignored.
     *
     * @param auxNodeWeightBytes encoded weights, in native byte order
     * @return the decoded weight arrays
     */
    static double[][] fromBytes(byte[] auxNodeWeightBytes) {
        ByteBuffer buffer = ByteBuffer.wrap(auxNodeWeightBytes).order(ByteOrder.nativeOrder());
        double[][] auxNodeWeights = new double[buffer.getInt()][];
        for (int i = 0; i < auxNodeWeights.length; ++i) {
            double[] weights = new double[buffer.getInt()];
            for (int j = 0; j < weights.length; ++j) {
                weights[j] = buffer.getDouble();
            }
            auxNodeWeights[i] = weights;
        }
        return auxNodeWeights;
    }

    /**
     * Overwrites the private {@code sum_hess} field of each tree node's {@link RegTreeNodeStat}
     * via reflection. Each {@code double} weight is narrowed to {@code float} before it is stored.
     *
     * <p>{@code nodeWeights[t][n]} is written to node {@code n} of {@code trees[t]}. Only the first
     * {@code nodeWeights.length} trees are touched, so {@code trees} must be at least that long;
     * otherwise an {@link ArrayIndexOutOfBoundsException} is thrown. With assertions enabled, each
     * tree's stats count must equal its weight count. With assertions disabled, a longer weight
     * array throws {@link ArrayIndexOutOfBoundsException}, and a shorter one leaves the remaining
     * nodes unchanged. Presumably {@code sum_hess} is the node's hessian sum ("cover"); verify
     * against the XGBoost predictor library.
     *
     * @param trees       trees whose node statistics are updated in place
     * @param nodeWeights per-tree, per-node replacement values
     * @throws IllegalStateException if {@code RegTreeNodeStat} has no {@code sum_hess} field
     */
    static void updateNodeWeights(RegTree[] trees, double[][] nodeWeights) {
        Field sumHessField;
        try {
            sumHessField = RegTreeNodeStat.class.getDeclaredField("sum_hess");
            sumHessField.setAccessible(true);
        }
        catch (NoSuchFieldException e) {
            // Keep the cause so a library-version mismatch is diagnosable from the stack trace.
            throw new IllegalStateException("Unable to access field 'sum_hess'.", e);
        }
        try {
            for (int t = 0; t < nodeWeights.length; ++t) {
                RegTreeNodeStat[] stats = trees[t].getStats();
                double[] weights = nodeWeights[t];
                assert stats.length == weights.length
                        : "Tree " + t + " has " + stats.length + " node stats but "
                        + weights.length + " node weights";
                for (int n = 0; n < weights.length; ++n) {
                    sumHessField.setFloat(stats[n], (float) weights[n]);
                }
            }
        }
        catch (IllegalAccessException e) {
            // Should not happen after setAccessible(true) succeeded.
            throw new RuntimeException(e);
        }
    }
}

