/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import java.util.EnumSet;
import java.util.Random;

public class GlrmMojoModel
extends MojoModel {
    // Expected length of an input row (asserted in score0); also the length of the working row built by getRowData.
    public int _ncolA;
    // Length of the x vector computed by score0; reported by getPredsSize and used to number the output names.
    public int _ncolX;
    // Expected number of columns of _archetypes (asserted in score0).
    public int _ncolY;
    // Expected number of rows of _archetypes; score0 asserts it equals _ncolX.
    public int _nrowY;
    // Archetype matrix, indexed [k][column] by objective() and gradientL().
    public double[][] _archetypes;
    // Not read by the instance methods of this class; the static helpers take a same-named parameter.
    public double[][] _archetypes_raw;
    // Number of levels of each categorical column, indexed by position in the working row.
    public int[] _numLevels;
    // Not read by the instance methods of this class; the static helpers take a same-named parameter.
    public int[] _catOffsets;
    // Maps working-row position i to the index of the raw input row: a[i] = row[_permutation[i]].
    public int[] _permutation;
    // One loss per working-row column (categorical columns first, then numeric ones).
    public GlrmLoss[] _losses;
    // Regularizer applied to x (project / rproxgrad / regularize are called on it).
    public GlrmRegularizer _regx;
    // Multiplier of the regularization term in objective() and of the step size passed to rproxgrad.
    public double _gammax;
    // Not read anywhere in this class.
    public GlrmInitialization _init;
    // Number of leading categorical columns in the working row.
    public int _ncats;
    // Not read anywhere in this class; presumably the number of numeric columns -- TODO confirm against the MOJO reader.
    public int _nnums;
    // Numeric column js is standardized as (a - _normSub[js]) * _normMul[js] before its loss is evaluated.
    public double[] _normSub;
    public double[] _normMul;
    // Base seed; score0(row, preds) adds the running call counter _rcnt to it.
    public long _seed;
    // Not read by the instance methods of this class; the static helpers take same-named parameters.
    public boolean _transposed;
    public boolean _reverse_transform;
    // score0 stops iterating once the relative objective improvement drops below this value.
    public double _accuracyEps = 1.0E-10;
    // Upper bound on the number of update iterations performed by score0.
    public int _iterNumber = 100;
    // Not referenced by the code in this class (initializeAlphas uses the literal 0.5).
    private static final double DOWN_FACTOR = 0.5;
    // Only referenced by the assertion in the static initializer.
    private static final double UP_FACTOR = Math.pow(2.0, 0.25);
    // Incremented by every score0(row, preds) call; combined with _seed to vary the per-row seed.
    public long _rcnt = 0L;
    // Number of candidate step sizes tried per iteration in applyBestAlpha.
    public int _numAlphaFactors = 10;
    // Candidate step sizes indexed by applyBestAlpha; has no initializer here (initializeAlphas builds a suitable array).
    public double[] _allAlphas;
    // Assigned once in the static initializer; returned by getModelCategories().
    private static EnumSet<ModelCategory> CATEGORIES;

    /**
     * Returns the shared static set of model categories built in the static initializer
     * (AutoEncoder and DimReduction). The same set instance is returned on every call.
     */
    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return CATEGORIES;
    }

    /**
     * Creates the model by forwarding all three arguments unchanged to the superclass constructor.
     * None of the GLRM-specific public fields are assigned here.
     *
     * @param columns        passed through to the superclass
     * @param domains        passed through to the superclass
     * @param responseColumn passed through to the superclass
     */
    public GlrmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Returns the size of the prediction array, which is {@code _ncolX} regardless of the
     * model category passed in.
     *
     * @param mc ignored
     * @return {@code _ncolX}
     */
    @Override
    public int getPredsSize(ModelCategory mc) {
        return this._ncolX;
    }

    /**
     * Builds a halving sequence of step sizes: 0.5, 0.25, 0.125, ...
     *
     * @param numAlpha number of step sizes to generate
     * @return a new array of length {@code numAlpha} whose element {@code i} equals {@code 0.5^(i+1)}
     */
    public static double[] initializeAlphas(int numAlpha) {
        double[] steps = new double[numAlpha];
        double current = 1.0;
        int position = 0;
        while (position < numAlpha) {
            // Halve first, then store, so the first entry is 0.5 rather than 1.0.
            current *= 0.5;
            steps[position] = current;
            ++position;
        }
        return steps;
    }

    /**
     * Computes the x vector for one input row and writes it into {@code preds}.
     *
     * <p>The row is first rearranged by {@link #getRowData(double[])}. x is then initialised with
     * Gaussian draws from a {@link Random} seeded with {@code seedValue}, passed through
     * {@code _regx.project}, and refined iteratively: each iteration computes the loss gradient and
     * lets {@link #applyBestAlpha} pick the best of the candidate step sizes. Iteration stops when
     * the relative improvement of the objective is negative or below {@code _accuracyEps}, when the
     * objective is exactly zero, or after {@code _iterNumber} iterations.
     *
     * @param row       raw input row; expected length {@code _ncolA}
     * @param preds     output array; its first {@code _ncolX} entries are overwritten with x
     * @param seedValue seed of the random generator used for this row
     * @return {@code preds}
     */
    public double[] score0(double[] row, double[] preds, long seedValue) {
        assert (row.length == this._ncolA);
        assert (preds.length == this._ncolX);
        assert (this._nrowY == this._ncolX);
        assert (this._archetypes.length == this._nrowY);
        assert (this._archetypes[0].length == this._ncolY);
        double[] a2 = this.getRowData(row);
        double[] x2 = new double[this._ncolX];
        double[] u2 = new double[this._ncolX];
        Random random = new Random(seedValue);
        for (int i2 = 0; i2 < this._ncolX; ++i2) {
            x2[i2] = random.nextGaussian();
        }
        x2 = this._regx.project(x2, random);
        double oldObj = this.objective(x2, a2);
        int iters = 0;
        // An objective of exactly zero cannot be improved: applyBestAlpha returns immediately and
        // 1 - obj/oldObj would be NaN (0/0), which never satisfies the stop tests below. Stopping
        // here avoids recomputing the gradient for every remaining iteration with no effect on x.
        while (oldObj != 0.0 && iters++ < this._iterNumber) {
            double[] grad = this.gradientL(x2, a2);
            double obj = this.applyBestAlpha(u2, x2, grad, a2, oldObj, random);
            double obj_improvement = 1.0 - obj / oldObj;
            if (obj_improvement < 0.0 || obj_improvement < this._accuracyEps) {
                // Objective got worse, or the gain is too small to be worth another iteration.
                break;
            }
            oldObj = obj;
        }
        System.arraycopy(x2, 0, preds, 0, this._ncolX);
        return preds;
    }

    /**
     * Rearranges a raw input row into the working order given by {@code _permutation}.
     *
     * <p>For the first {@code _ncats} positions, a value greater than or equal to the column's
     * number of levels is replaced by NaN; all other values are copied unchanged.
     *
     * @param row raw input row
     * @return a new array of length {@code _ncolA}
     */
    public double[] getRowData(double[] row) {
        double[] working = new double[this._ncolA];
        int pos = 0;
        // Categorical part: levels at or beyond the known level count become NaN.
        while (pos < this._ncats) {
            double value = row[this._permutation[pos]];
            if (value >= (double)this._numLevels[pos]) {
                working[pos] = Double.NaN;
            } else {
                working[pos] = value;
            }
            ++pos;
        }
        // Remaining part: straight copy through the permutation.
        for (pos = this._ncats; pos < this._ncolA; ++pos) {
            working[pos] = row[this._permutation[pos]];
        }
        return working;
    }

    /**
     * Tries each candidate step size and keeps the x that yields the lowest objective.
     *
     * <p>For every candidate {@code alpha} (the entries of {@code _allAlphas}, scaled by
     * {@code 1/oldObj} when {@code oldObj > 10}), {@code u2 = x2 - alpha * grad} is formed, passed to
     * {@code _regx.rproxgrad}, and the objective of the result is evaluated. The search ends early
     * when a candidate reaches an objective of exactly zero. {@code x2} is overwritten only if the
     * best candidate is strictly better than {@code oldObj}.
     *
     * <p>If {@code _allAlphas} is null or holds fewer than {@code _numAlphaFactors} entries it is
     * (re)built with {@link #initializeAlphas(int)} instead of failing with a
     * NullPointerException / ArrayIndexOutOfBoundsException.
     *
     * @param u2     scratch array, overwritten for every candidate
     * @param x2     current x; updated in place when an improvement is found
     * @param grad   gradient of the loss at {@code x2}
     * @param a2     working row as produced by {@link #getRowData(double[])}
     * @param oldObj objective value at {@code x2}
     * @param random random generator handed to the regularizer
     * @return {@code 0.0} if {@code oldObj} is zero; otherwise the lowest objective among the
     *         candidates (which may be larger than {@code oldObj}, in which case {@code x2} is
     *         unchanged), or {@code Double.MAX_VALUE} if no candidate compared lower than that
     */
    public double applyBestAlpha(double[] u2, double[] x2, double[] grad, double[] a2, double oldObj, Random random) {
        if (oldObj == 0.0) {
            return 0.0;
        }
        int numCandidates = this._numAlphaFactors;
        if (numCandidates > 0 && (this._allAlphas == null || this._allAlphas.length < numCandidates)) {
            this._allAlphas = GlrmMojoModel.initializeAlphas(numCandidates);
        }
        double[] bestX = new double[x2.length];
        double lowestObj = Double.MAX_VALUE;
        double alphaScale = oldObj > 10.0 ? 1.0 / oldObj : 1.0;
        for (int index = 0; index < numCandidates; ++index) {
            double alpha = this._allAlphas[index] * alphaScale;
            for (int k2 = 0; k2 < this._ncolX; ++k2) {
                u2[k2] = x2[k2] - alpha * grad[k2];
            }
            double[] xnew = this._regx.rproxgrad(u2, alpha * this._gammax, random);
            double newobj = this.objective(xnew, a2);
            if (lowestObj > newobj) {
                // Copy rather than keep the reference: the regularizer may hand back u2 itself,
                // which is overwritten by the next candidate.
                System.arraycopy(xnew, 0, bestX, 0, xnew.length);
                lowestObj = newobj;
            }
            if (newobj == 0.0) break;
        }
        if (lowestObj < oldObj) {
            System.arraycopy(bestX, 0, x2, 0, x2.length);
        }
        return lowestObj;
    }

    /**
     * Scores a row using a seed derived from {@code _seed} and the running call counter
     * {@code _rcnt}; the counter is incremented before the actual scoring starts.
     *
     * <p>NOTE(review): the counter update is a plain, unsynchronized read-modify-write.
     *
     * @param row   raw input row
     * @param preds output array
     * @return the array returned by {@link #score0(double[], double[], long)}
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        long rowSeed = this._seed + this._rcnt;
        this._rcnt += 1L;
        return this.score0(row, preds, rowSeed);
    }

    /**
     * Fills {@code preds} with imputed values computed from an x factor.
     *
     * <p>Positions {@code 0..ncats-1} are categorical: the per-level products from
     * {@link #lmulCatBlock} are passed to the column loss's {@code mimpute}. The remaining positions
     * are numeric: the product from {@link #lmulNumCol} is passed to the column loss's
     * {@code impute} and, when {@code reverse_transform} is set, mapped back with
     * {@code value / normMul + normSub}. Each result is stored at {@code preds[permutation[d]]}.
     *
     * @return {@code preds}
     */
    public static double[] impute_data(double[] xfactor, double[] preds, int nnums, int ncats, int[] permutation, boolean reverse_transform, double[] normMul, double[] normSub, GlrmLoss[] losses, boolean transposed, double[][] archetypes_raw, int[] catOffsets, int[] numLevels) {
        assert (preds.length == nnums + ncats);
        int col = 0;
        // Categorical columns.
        while (col < ncats) {
            double[] levelProducts = GlrmMojoModel.lmulCatBlock(xfactor, col, numLevels, transposed, archetypes_raw, catOffsets);
            preds[permutation[col]] = losses[col].mimpute(levelProducts);
            ++col;
        }
        // Numeric columns; numIdx is the column's position among the numeric columns.
        for (col = ncats; col < preds.length; ++col) {
            int numIdx = col - ncats;
            double product = GlrmMojoModel.lmulNumCol(xfactor, numIdx, transposed, archetypes_raw, catOffsets);
            preds[permutation[col]] = losses[col].impute(product);
            if (reverse_transform) {
                // Undo the standardization applied to numeric columns.
                preds[permutation[col]] = preds[permutation[col]] / normMul[numIdx] + normSub[numIdx];
            }
        }
        return preds;
    }

    /**
     * Returns the archetype column index of numeric column {@code j2}: the last entry of
     * {@code catOffsets} plus {@code j2}.
     *
     * @param j2         index of the column among the numeric columns
     * @param catOffsets offsets array; must be non-empty
     * @return {@code catOffsets[catOffsets.length - 1] + j2}
     */
    public static int getNumCidx(int j2, int[] catOffsets) {
        int lastOffset = catOffsets[catOffsets.length - 1];
        return lastOffset + j2;
    }

    /**
     * Computes the inner product of {@code x2} with the archetype column that belongs to numeric
     * column {@code j2}.
     *
     * @param x2             factor vector; asserted to have length {@code rank(transposed, archetypes_raw)}
     * @param j2             index of the column among the numeric columns
     * @param transposed     when true the archetypes are indexed {@code [column][k]}, otherwise {@code [k][column]}
     * @param archetypes_raw archetype matrix
     * @param catOffsets     offsets used by {@link #getNumCidx(int, int[])} to locate the column
     * @return the sum over k of {@code x2[k]} times the matching archetype entry
     */
    public static double lmulNumCol(double[] x2, int j2, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        assert (x2 != null && x2.length == GlrmMojoModel.rank(transposed, archetypes_raw)) : "x must be of length " + GlrmMojoModel.rank(transposed, archetypes_raw);
        final int column = GlrmMojoModel.getNumCidx(j2, catOffsets);
        final int k = GlrmMojoModel.rank(transposed, archetypes_raw);
        double dot = 0.0;
        for (int i = 0; i < k; ++i) {
            dot += x2[i] * (transposed ? archetypes_raw[column][i] : archetypes_raw[i][column]);
        }
        return dot;
    }

    /**
     * Returns the archetype column index of a given level of categorical column {@code j2}.
     *
     * <p>With assertions enabled, a column with zero levels or a level outside
     * {@code [0, numLevels[j2])} raises an {@link AssertionError}.
     *
     * @param j2         index of the categorical column
     * @param level      level within that column
     * @param numLevels  number of levels per categorical column
     * @param catOffsets starting archetype column of each categorical column
     * @return {@code catOffsets[j2] + level}
     */
    public static int getCatCidx(int j2, int level, int[] numLevels, int[] catOffsets) {
        int catColJLevel = numLevels[j2];
        assert (catColJLevel != 0) : "Number of levels in categorical column cannot be zero";
        // level is an int and can never be NaN, so only the range check is meaningful.
        assert (level >= 0 && level < catColJLevel) : "Got level = " + level + " when expected integer in [0," + catColJLevel + ")";
        return catOffsets[j2] + level;
    }

    /**
     * Computes, for every level of categorical column {@code j2}, the inner product of {@code x2}
     * with the archetype column that belongs to that level.
     *
     * @param x2             factor vector; asserted to have length {@code rank(transposed, archetypes_raw)}
     * @param j2             index of the categorical column
     * @param numLevels      number of levels per categorical column
     * @param transposed     when true the archetypes are indexed {@code [column][k]}, otherwise {@code [k][column]}
     * @param archetypes_raw archetype matrix
     * @param catOffsets     starting archetype column of each categorical column
     * @return a new array with one product per level of column {@code j2}
     */
    public static double[] lmulCatBlock(double[] x2, int j2, int[] numLevels, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        final int levelCount = numLevels[j2];
        assert (levelCount != 0) : "Number of levels in categorical column cannot be zero";
        assert (x2 != null && x2.length == GlrmMojoModel.rank(transposed, archetypes_raw)) : "x must be of length " + GlrmMojoModel.rank(transposed, archetypes_raw);
        double[] products = new double[levelCount];
        for (int level = 0; level < levelCount; ++level) {
            final int column = GlrmMojoModel.getCatCidx(j2, level, numLevels, catOffsets);
            final int k = GlrmMojoModel.rank(transposed, archetypes_raw);
            double dot = 0.0;
            for (int i = 0; i < k; ++i) {
                dot += x2[i] * (transposed ? archetypes_raw[column][i] : archetypes_raw[i][column]);
            }
            products[level] = dot;
        }
        return products;
    }

    /**
     * Returns the rank dimension of the archetype matrix: the length of its first row when
     * {@code transposed} is true, otherwise its number of rows.
     *
     * @param transposed     selects which dimension of the matrix is the rank
     * @param archetypes_raw archetype matrix; must have at least one row when {@code transposed} is true
     * @return the selected dimension
     */
    public static int rank(boolean transposed, double[][] archetypes_raw) {
        if (transposed) {
            return archetypes_raw[0].length;
        }
        return archetypes_raw.length;
    }

    /**
     * Computes the per-level products of {@code x2} with a contiguous block of archetype columns.
     *
     * @param x2         factor vector of length {@code _ncolX}
     * @param blockStart first archetype column of the block
     * @param n_levels   number of columns (levels) in the block
     * @return a new array whose entry {@code level} is the sum over k of
     *         {@code x2[k] * _archetypes[k][blockStart + level]}
     */
    private double[] catBlockProduct(double[] x2, int blockStart, int n_levels) {
        double[] xy = new double[n_levels];
        for (int level = 0; level < n_levels; ++level) {
            double sum = 0.0;
            for (int k2 = 0; k2 < this._ncolX; ++k2) {
                sum += x2[k2] * this._archetypes[k2][level + blockStart];
            }
            xy[level] = sum;
        }
        return xy;
    }

    /**
     * Computes the gradient of the loss terms with respect to {@code x2}.
     *
     * <p>Columns whose value in {@code a2} is NaN contribute nothing. The archetype column offset
     * is nevertheless advanced past every categorical column, missing or not, because the column
     * layout of {@code _archetypes} is fixed by the model and does not depend on the row; skipping
     * the advance would make every later column read the wrong archetype entries.
     *
     * @param x2 current factor vector of length {@code _ncolX}
     * @param a2 working row as produced by {@link #getRowData(double[])}
     * @return a new gradient array of length {@code _ncolX}
     */
    private double[] gradientL(double[] x2, double[] a2) {
        double[] grad = new double[this._ncolX];
        int cat_offset = 0;
        for (int j2 = 0; j2 < this._ncats; ++j2) {
            int n_levels = this._numLevels[j2];
            int blockStart = cat_offset;
            // Always advance, even when the value is missing (see method comment).
            cat_offset += n_levels;
            if (Double.isNaN(a2[j2])) continue;
            double[] xy = this.catBlockProduct(x2, blockStart, n_levels);
            double[] gradL = this._losses[j2].mlgrad(xy, (int)a2[j2]);
            for (int k2 = 0; k2 < this._ncolX; ++k2) {
                for (int c2 = 0; c2 < n_levels; ++c2) {
                    grad[k2] += gradL[c2] * this._archetypes[k2][c2 + blockStart];
                }
            }
        }
        // Numeric columns start right after the last categorical block.
        for (int j2 = this._ncats; j2 < this._ncolA; ++j2) {
            int js = j2 - this._ncats;
            if (Double.isNaN(a2[j2])) continue;
            int column = js + cat_offset;
            double xy = 0.0;
            for (int k3 = 0; k3 < this._ncolX; ++k3) {
                xy += x2[k3] * this._archetypes[k3][column];
            }
            // The observed value is standardized before being compared with the reconstruction.
            double gradL = this._losses[j2].lgrad(xy, (a2[j2] - this._normSub[js]) * this._normMul[js]);
            for (int k4 = 0; k4 < this._ncolX; ++k4) {
                grad[k4] += gradL * this._archetypes[k4][column];
            }
        }
        return grad;
    }

    /**
     * Evaluates the objective at {@code x2}: the sum of the column losses over all non-missing
     * entries of {@code a2} plus {@code _gammax} times the regularizer value of {@code x2}.
     *
     * <p>As in {@link #gradientL(double[], double[])}, the archetype column offset is advanced past
     * every categorical column, including those whose value is NaN, so that later columns keep
     * reading their own archetype entries.
     *
     * @param x2 factor vector of length {@code _ncolX}
     * @param a2 working row as produced by {@link #getRowData(double[])}
     * @return the objective value
     */
    private double objective(double[] x2, double[] a2) {
        double res = 0.0;
        int cat_offset = 0;
        for (int j2 = 0; j2 < this._ncats; ++j2) {
            int n_levels = this._numLevels[j2];
            int blockStart = cat_offset;
            // Always advance, even when the value is missing (see method comment).
            cat_offset += n_levels;
            if (Double.isNaN(a2[j2])) continue;
            double[] xy = this.catBlockProduct(x2, blockStart, n_levels);
            res += this._losses[j2].mloss(xy, (int)a2[j2]);
        }
        // Numeric columns start right after the last categorical block.
        for (int j2 = this._ncats; j2 < this._ncolA; ++j2) {
            int js = j2 - this._ncats;
            if (Double.isNaN(a2[j2])) continue;
            int column = js + cat_offset;
            double xy = 0.0;
            for (int k2 = 0; k2 < this._ncolX; ++k2) {
                xy += x2[k2] * this._archetypes[k2][column];
            }
            res += this._losses[j2].loss(xy, (a2[j2] - this._normSub[js]) * this._normMul[js]);
        }
        return res + this._gammax * this._regx.regularize(x2);
    }

    /**
     * Returns the names of the output columns: {@code "Arch1"}, {@code "Arch2"}, ... up to
     * {@code "Arch" + _ncolX}.
     *
     * @return a new array of {@code _ncolX} names
     */
    @Override
    public String[] getOutputNames() {
        String[] labels = new String[this._ncolX];
        int position = 0;
        while (position < labels.length) {
            // Names are 1-based.
            int ordinal = position + 1;
            labels[position] = "Arch" + ordinal;
            ++position;
        }
        return labels;
    }

    // Class initialization: checks UP_FACTOR (only when assertions are enabled) and builds the
    // category set that getModelCategories() returns.
    static {
        assert (UP_FACTOR > 1.0);
        CATEGORIES = EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);
    }
}

