/*
 * Decompiled with CFR 0.152.
 */
package hex.gam.MatrixFrameUtils;

import hex.Model;
import hex.gam.GAM;
import hex.gam.GAMModel;
import hex.gam.GamSplines.ThinPlateRegressionUtils;
import hex.gam.MatrixFrameUtils.GAMModelUtils;
import hex.glm.GLMModel;
import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import water.DKV;
import water.Key;
import water.MemoryManager;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

public class GamUtils {
    public static double[][][] allocate3DArrayCS(int num2DArrays, GAMModel.GAMParameters parms, AllocateType fileMode) {
        double[][][] array3D = new double[num2DArrays][][];
        int gamColCount = 0;
        for (int frameIdx = 0; frameIdx < num2DArrays; ++frameIdx) {
            if (parms._gam_columns_sorted[frameIdx].length != 1) continue;
            int numKnots = parms._num_knots_sorted[frameIdx];
            array3D[gamColCount++] = GamUtils.allocate2DArray(fileMode, numKnots);
        }
        return array3D;
    }

    public static double[][][] allocate3DArray(int num2DArrays, GAMModel.GAMParameters parms, AllocateType fileMode) {
        double[][][] array3D = new double[num2DArrays][][];
        for (int frameIdx = 0; frameIdx < num2DArrays; ++frameIdx) {
            if (parms._bs_sorted[frameIdx] != 2) {
                array3D[frameIdx] = GamUtils.allocate2DArray(fileMode, parms._num_knots_sorted[frameIdx]);
                continue;
            }
            int totBasis = parms._num_knots_sorted[frameIdx] + parms._spline_orders_sorted[frameIdx] - 2;
            array3D[frameIdx] = GamUtils.allocate2DArray(fileMode, totBasis);
        }
        return array3D;
    }

    public static void removeCenteringIS(double[][][] penaltyMatCenter, GAMModel.GAMParameters parms) {
        int numGamCol = parms._bs_sorted.length;
        for (int index = 0; index < numGamCol; ++index) {
            if (parms._bs_sorted[index] != 2) continue;
            int numBasis = parms._num_knots_sorted[index] + parms._spline_orders_sorted[index] - 2;
            penaltyMatCenter[index] = GamUtils.allocate2DArray(AllocateType.sameOrig, numBasis);
        }
    }

    public static double[][][] allocate3DArrayTP(int num2DArrays, GAMModel.GAMParameters parms, int[] secondDim, int[] thirdDim) {
        double[][][] array3D = new double[num2DArrays][][];
        int gamColCount = 0;
        int numGamCols = parms._gam_columns.length;
        for (int frameIdx = 0; frameIdx < numGamCols; ++frameIdx) {
            if (parms._bs_sorted[frameIdx] != 1) continue;
            array3D[gamColCount] = MemoryManager.malloc8d(secondDim[gamColCount], thirdDim[gamColCount]);
            ++gamColCount;
        }
        return array3D;
    }

    public static double[][] allocate2DArray(AllocateType fileMode, int numKnots) {
        double[][] array2D;
        switch (fileMode) {
            case firstOneLess: {
                array2D = MemoryManager.malloc8d(numKnots - 1, numKnots);
                break;
            }
            case sameOrig: {
                array2D = MemoryManager.malloc8d(numKnots, numKnots);
                break;
            }
            case bothOneLess: {
                array2D = MemoryManager.malloc8d(numKnots - 1, numKnots - 1);
                break;
            }
            case firstTwoLess: {
                array2D = MemoryManager.malloc8d(numKnots - 2, numKnots);
                break;
            }
            default: {
                throw new IllegalArgumentException("fileMode can only be firstOneLess, sameOrig, bothOneLess or firstTwoLess.");
            }
        }
        return array2D;
    }

    public static Integer[] sortCoeffMags(int arrayLength, final double[] coeffMags) {
        Integer[] indices = new Integer[arrayLength];
        for (int i2 = 0; i2 < indices.length; ++i2) {
            indices[i2] = i2;
        }
        Arrays.sort(indices, new Comparator<Integer>(){

            @Override
            public int compare(Integer o1, Integer o2) {
                if (coeffMags[o1] < coeffMags[o2]) {
                    return 1;
                }
                if (coeffMags[o1] > coeffMags[o2]) {
                    return -1;
                }
                return 0;
            }
        });
        return indices;
    }

    public static boolean equalColNames(String[] name1, String[] standardN, String response_column) {
        boolean equalNames;
        boolean name1ContainsResp = ArrayUtils.contains(name1, response_column);
        boolean standarNContainsResp = ArrayUtils.contains(standardN, response_column);
        boolean bl = equalNames = name1.length == standardN.length;
        if (name1ContainsResp && !standarNContainsResp) {
            equalNames = name1.length == standardN.length + 1;
        } else if (!name1ContainsResp && standarNContainsResp) {
            boolean bl2 = equalNames = name1.length + 1 == standardN.length;
        }
        if (equalNames) {
            for (String name : name1) {
                if (name == response_column || ArrayUtils.contains(standardN, name)) continue;
                return false;
            }
            return true;
        }
        return equalNames;
    }

    public static void copy2DArray(double[][] src_array, double[][] dest_array) {
        int numRows = src_array.length;
        for (int colIdx = 0; colIdx < numRows; ++colIdx) {
            System.arraycopy(src_array[colIdx], 0, dest_array[colIdx], 0, src_array[colIdx].length);
        }
    }

    public static void copy2DArray(int[][] src_array, int[][] dest_array) {
        int numRows = src_array.length;
        for (int colIdx = 0; colIdx < numRows; ++colIdx) {
            System.arraycopy(src_array[colIdx], 0, dest_array[colIdx], 0, src_array[colIdx].length);
        }
    }

    public static void copyCVGLMtoGAMModel(GAMModel model, GLMModel glmModel, GAMModel.GAMParameters parms, String foldColumn) {
        ((GAMModel.GAMModelOutput)model._output)._cross_validation_metrics = ((GLMModel.GLMOutput)glmModel._output)._cross_validation_metrics;
        ((GAMModel.GAMModelOutput)model._output)._cross_validation_metrics_summary = GAMModelUtils.copyTwoDimTable(((GLMModel.GLMOutput)glmModel._output)._cross_validation_metrics_summary, "GLM cross-validation metrics summary");
        int nFolds = ((GLMModel.GLMOutput)glmModel._output)._cv_scoring_history.length;
        ((GAMModel.GAMModelOutput)model._output)._glm_cv_scoring_history = new TwoDimTable[nFolds];
        if (parms._keep_cross_validation_predictions) {
            ((GAMModel.GAMModelOutput)model._output)._cross_validation_predictions = new Key[nFolds];
        }
        for (int fInd = 0; fInd < nFolds; ++fInd) {
            ((GAMModel.GAMModelOutput)model._output)._glm_cv_scoring_history[fInd] = GAMModelUtils.copyTwoDimTable(((GLMModel.GLMOutput)glmModel._output)._cv_scoring_history[fInd], ((GLMModel.GLMOutput)glmModel._output)._cv_scoring_history[fInd].getTableHeader());
            if (!parms._keep_cross_validation_predictions) continue;
            Frame pred = (Frame)DKV.getGet(((GLMModel.GLMOutput)glmModel._output)._cross_validation_predictions[fInd]);
            Frame newPred = pred.deepCopy(Key.make().toString());
            DKV.put(newPred);
            ((GAMModel.GAMModelOutput)model._output)._cross_validation_predictions[fInd] = newPred.getKey();
        }
        if (parms._keep_cross_validation_models) {
            ((GAMModel.GAMModelOutput)model._output)._cross_validation_models = GamUtils.buildCVGamModels(model, glmModel, parms, foldColumn);
        }
        if (parms._keep_cross_validation_predictions) {
            Frame cvPred = (Frame)DKV.getGet(((GLMModel.GLMOutput)glmModel._output)._cross_validation_holdout_predictions_frame_id);
            Frame newPred = cvPred.deepCopy(Key.make().toString());
            DKV.put(newPred);
            ((GAMModel.GAMModelOutput)model._output)._cross_validation_holdout_predictions_frame_id = newPred.getKey();
        }
        if (parms._keep_cross_validation_fold_assignment) {
            Frame foldAssignment = (Frame)DKV.getGet(((GLMModel.GLMOutput)glmModel._output)._cross_validation_fold_assignment_frame_id);
            Frame newFold = foldAssignment.deepCopy(Key.make().toString());
            DKV.put(newFold);
            ((GAMModel.GAMModelOutput)model._output)._cross_validation_fold_assignment_frame_id = newFold.getKey();
        }
    }

    public static Key[] buildCVGamModels(GAMModel model, GLMModel glmModel, GAMModel.GAMParameters parms, String foldColumn) {
        int nFolds = ((GLMModel.GLMOutput)glmModel._output)._cross_validation_models.length;
        Key[] cvModelKeys = new Key[nFolds];
        for (int fInd = 0; fInd < nFolds; ++fInd) {
            GLMModel cvModel = (GLMModel)DKV.getGet(((GLMModel.GLMOutput)glmModel._output)._cross_validation_models[fInd]);
            GAMModel.GAMParameters gamParams = GamUtils.makeGAMParameters(parms);
            if (foldColumn != null) {
                if (gamParams._ignored_columns != null) {
                    ArrayList<String> ignoredCols = new ArrayList<String>(Arrays.asList(gamParams._ignored_columns));
                    ignoredCols.add(foldColumn);
                    gamParams._ignored_columns = ignoredCols.toArray(new String[0]);
                } else {
                    gamParams._ignored_columns = new String[]{foldColumn};
                }
            }
            int maxIterations = gamParams._max_iterations;
            gamParams._max_iterations = 1;
            GAMModel gamModel = (GAMModel)new GAM(gamParams).trainModel().get();
            gamParams._max_iterations = maxIterations;
            GAMModelUtils.copyGLMCoeffs(cvModel, gamModel, gamParams, model._nclass);
            GAMModelUtils.copyGLMtoGAMModel(gamModel, cvModel, parms, true);
            cvModelKeys[fInd] = gamModel.getKey();
            DKV.put(gamModel);
        }
        return cvModelKeys;
    }

    public static GAMModel.GAMParameters makeGAMParameters(GAMModel.GAMParameters parms) {
        GAMModel.GAMParameters gamParams = new GAMModel.GAMParameters();
        Field[] field1 = GAMModel.GAMParameters.class.getDeclaredFields();
        Field[] field2 = Model.Parameters.class.getDeclaredFields();
        GamUtils.setParamField(parms, gamParams, false, field1, Collections.emptyList());
        GamUtils.setParamField(parms, gamParams, true, field2, Collections.emptyList());
        gamParams._nfolds = 0;
        gamParams._keep_cross_validation_predictions = false;
        gamParams._keep_cross_validation_fold_assignment = false;
        gamParams._keep_cross_validation_models = false;
        gamParams._train = parms._train;
        return gamParams;
    }

    public static void setParamField(Model.Parameters parms, Model.Parameters glmParam, boolean superClassParams, Field[] gamFields, List<String> excludeList) {
        boolean emptyExcludeList = excludeList == null || excludeList.size() == 0;
        for (Field oneField : gamFields) {
            try {
                if (!emptyExcludeList && excludeList.contains(oneField.getName())) continue;
                Field glmField = superClassParams ? glmParam.getClass().getSuperclass().getDeclaredField(oneField.getName()) : glmParam.getClass().getDeclaredField(oneField.getName());
                glmField.set(glmParam, oneField.get(parms));
            }
            catch (IllegalAccessException | NoSuchFieldException reflectiveOperationException) {
                // empty catch block
            }
        }
    }

    public static void keepFrameKeys(List<Key> keep, Key<Frame> ... keyNames) {
        for (Key<Frame> keyName : keyNames) {
            Frame loadingFrm = (Frame)DKV.getGet(keyName);
            if (loadingFrm == null) continue;
            for (Vec vec : loadingFrm.vecs()) {
                keep.add(vec._key);
            }
        }
    }

    public static void setDefaultBSType(GAMModel.GAMParameters parms) {
        parms._bs = new int[parms._gam_columns.length];
        for (int index = 0; index < parms._bs.length; ++index) {
            parms._bs[index] = parms._gam_columns[index].length > 1 ? 1 : 0;
        }
    }

    public static void setThinPlateParameters(GAMModel.GAMParameters parms, int thinPlateNum) {
        int numGamCols = parms._gam_columns.length;
        parms._m = MemoryManager.malloc4(thinPlateNum);
        parms._M = MemoryManager.malloc4(thinPlateNum);
        int countThinPlate = 0;
        for (int index = 0; index < numGamCols; ++index) {
            if (parms._bs[index] != 1) continue;
            int d2 = parms._gam_columns[index].length;
            parms._m[countThinPlate] = ThinPlateRegressionUtils.calculatem(d2);
            parms._M[countThinPlate] = ThinPlateRegressionUtils.calculateM(d2, parms._m[countThinPlate]);
            ++countThinPlate;
        }
    }

    public static void setGamPredSize(GAMModel.GAMParameters parms, int csOffset) {
        int numGamCols = parms._gam_columns.length;
        int tpCount = csOffset;
        int csCount = 0;
        parms._gamPredSize = MemoryManager.malloc4(numGamCols);
        for (int index = 0; index < numGamCols; ++index) {
            if (parms._gam_columns[index].length == 1) {
                parms._gamPredSize[csCount++] = 1;
                continue;
            }
            parms._gamPredSize[tpCount++] = parms._gam_columns[index].length;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static double[] generateKnotsOneColumn(Frame gamFrame, int knotNum) {
        double[] knots = MemoryManager.malloc8d(knotNum);
        try {
            Scope.enter();
            Frame tempFrame = new Frame(gamFrame);
            DKV.put(tempFrame);
            double[] prob = MemoryManager.malloc8d(knotNum);
            assert (knotNum > 1);
            double stepProb = 1.0 / (double)(knotNum - 1);
            for (int knotInd = 0; knotInd < knotNum; ++knotInd) {
                prob[knotInd] = (double)knotInd * stepProb;
            }
            QuantileModel.QuantileParameters parms = new QuantileModel.QuantileParameters();
            parms._train = tempFrame._key;
            parms._probs = prob;
            QuantileModel qModel = (QuantileModel)new Quantile(parms).trainModel().get();
            DKV.remove(tempFrame._key);
            Scope.track_generic(qModel);
            System.arraycopy(((QuantileModel.QuantileOutput)qModel._output)._quantiles[0], 0, knots, 0, knotNum);
        }
        finally {
            Scope.exit(new Key[0]);
        }
        return knots;
    }

    public static Frame prepareGamVec(int gam_column_index, GAMModel.GAMParameters parms, Frame fr) {
        Vec weights_column = parms._weights_column == null ? Scope.track(Vec.makeOne(fr.numRows())) : fr.vec(parms._weights_column);
        Frame predictVec = new Frame(new Vec[0]);
        int numPredictors = parms._gam_columns_sorted[gam_column_index].length;
        for (int colInd = 0; colInd < numPredictors; ++colInd) {
            predictVec.add(parms._gam_columns_sorted[gam_column_index][colInd], fr.vec(parms._gam_columns_sorted[gam_column_index][colInd]));
        }
        predictVec.add("weights_column", weights_column);
        return predictVec;
    }

    public static String[] generateGamColNames(int gamColIndex, GAMModel.GAMParameters parms) {
        String[] newColNames = null;
        newColNames = parms._bs_sorted[gamColIndex] == 0 ? new String[parms._num_knots_sorted[gamColIndex]] : new String[parms._num_knots_sorted[gamColIndex] + parms._spline_orders_sorted[gamColIndex] - 2];
        String stubName = parms._gam_columns_sorted[gamColIndex][0] + "_";
        stubName = parms._bs_sorted[gamColIndex] == 0 ? stubName + "cr_" : (parms._bs_sorted[gamColIndex] == 2 ? stubName + "is_" : stubName + "tp_");
        for (int knotIndex = 0; knotIndex < newColNames.length; ++knotIndex) {
            newColNames[knotIndex] = stubName + knotIndex;
        }
        return newColNames;
    }

    public static String[] generateGamColNamesThinPlateKnots(int gamColIndex, GAMModel.GAMParameters parms, int[][] polyBasisDegree, String nameStub) {
        int index;
        int num_knots = parms._num_knots_sorted[gamColIndex];
        int polyBasisSize = polyBasisDegree.length;
        String[] gamColNames = new String[num_knots + polyBasisSize];
        for (index = 0; index < num_knots; ++index) {
            gamColNames[index] = nameStub + index;
        }
        for (index = 0; index < polyBasisSize; ++index) {
            gamColNames[index + num_knots] = GamUtils.genPolyBasisNames(parms._gam_columns_sorted[gamColIndex], polyBasisDegree[index]);
        }
        return gamColNames;
    }

    public static String genPolyBasisNames(String[] gam_columns, int[] oneBasis) {
        StringBuffer polyBasisName = new StringBuffer();
        int numGamCols = gam_columns.length;
        int beforeLastIndex = numGamCols - 1;
        for (int index = 0; index < numGamCols; ++index) {
            polyBasisName.append(gam_columns[index]);
            polyBasisName.append("_");
            polyBasisName.append(oneBasis[index]);
            if (index >= beforeLastIndex) continue;
            polyBasisName.append("_");
        }
        return polyBasisName.toString();
    }

    public static Frame buildGamFrame(GAMModel.GAMParameters parms, Frame train, Key<Frame>[] gamFrameKeysCenter, String foldColumn) {
        Vec responseVec = train.remove(parms._response_column);
        List<Object> ignored_cols = parms._ignored_columns == null ? new ArrayList() : Arrays.asList(parms._ignored_columns);
        Vec weightsVec = null;
        Vec offsetVec = null;
        Vec foldVec = null;
        if (parms._offset_column != null) {
            offsetVec = train.remove(parms._offset_column);
        }
        if (parms._weights_column != null) {
            weightsVec = train.remove(parms._weights_column);
        }
        if (foldColumn != null) {
            foldVec = train.remove(foldColumn);
        }
        for (int colIdx = 0; colIdx < parms._gam_columns_sorted.length; ++colIdx) {
            Frame gamFrame = Scope.track(gamFrameKeysCenter[colIdx].get());
            train.add(gamFrame.names(), gamFrame.removeAll());
            if (!ignored_cols.contains(parms._gam_columns_sorted[colIdx])) continue;
            train.remove(parms._gam_columns_sorted[colIdx]);
        }
        if (foldColumn != null) {
            train.add(foldColumn, foldVec);
        }
        if (weightsVec != null) {
            train.add(parms._weights_column, weightsVec);
        }
        if (offsetVec != null) {
            train.add(parms._offset_column, offsetVec);
        }
        if (responseVec != null) {
            train.add(parms._response_column, responseVec);
        }
        return train;
    }

    public static Frame concateGamVecs(Key<Frame>[] gamFrameKeysCenter) {
        Frame gamVecs = new Frame(Key.make());
        for (int index = 0; index < gamFrameKeysCenter.length; ++index) {
            Frame tempCols = Scope.track(gamFrameKeysCenter[index].get());
            gamVecs.add(tempCols.names(), tempCols.removeAll());
        }
        return gamVecs;
    }

    public static void sortGAMParameters(GAMModel.GAMParameters parms, int csGamCol, int isGamCol) {
        int gamColNum = parms._gam_columns.length;
        int csIndex = 0;
        int isIndex = csGamCol;
        int tpIndex = csGamCol + isGamCol;
        parms._gam_columns_sorted = new String[gamColNum][];
        parms._num_knots_sorted = MemoryManager.malloc4(gamColNum);
        parms._scale_sorted = MemoryManager.malloc8d(gamColNum);
        parms._bs_sorted = MemoryManager.malloc4(gamColNum);
        parms._gamPredSize = MemoryManager.malloc4(gamColNum);
        parms._spline_orders_sorted = MemoryManager.malloc4(gamColNum);
        for (int index = 0; index < gamColNum; ++index) {
            if (parms._bs[index] == 0) {
                GamUtils.setGamParameters(parms, index, csIndex++);
                continue;
            }
            if (parms._bs[index] == 2) {
                GamUtils.setGamParameters(parms, index, isIndex);
                parms._spline_orders_sorted[isIndex++] = parms._spline_orders[index];
                continue;
            }
            GamUtils.setGamParameters(parms, index, tpIndex++);
        }
    }

    public static void setGamParameters(GAMModel.GAMParameters parms, int gamIndex, int splineIndex) {
        parms._gam_columns_sorted[splineIndex] = (String[])parms._gam_columns[gamIndex].clone();
        parms._num_knots_sorted[splineIndex] = parms._num_knots[gamIndex];
        parms._scale_sorted[splineIndex] = parms._scale[gamIndex];
        parms._gamPredSize[splineIndex] = parms._gam_columns_sorted[splineIndex].length;
        parms._bs_sorted[splineIndex] = parms._bs[gamIndex];
    }

    public static void setDefaultScale(GAMModel.GAMParameters parms) {
        int numGamCol = parms._gam_columns.length;
        parms._scale = new double[numGamCol];
        for (int index = 0; index < numGamCol; ++index) {
            parms._scale[index] = 1.0;
        }
    }

    public static enum AllocateType {
        firstOneLess,
        sameOrig,
        bothOneLess,
        firstTwoLess;

    }
}

