/*
 * Decompiled with CFR 0.152.
 */
package hex.svd;

import hex.CustomMetric;
import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsUnsupervised;
import hex.svd.SVD;
import java.util.ArrayList;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.JCodeGen;
import water.util.SBPrintStream;

public class SVDModel
extends Model<SVDModel, SVDParameters, SVDOutput> {
    /**
     * Creates an SVD model by forwarding the key, parameters and output
     * to the {@code Model} superclass constructor.
     *
     * @param selfKey key handed to the superclass for this model
     * @param parms   SVD parameters handed to the superclass
     * @param output  SVD output object handed to the superclass
     */
    public SVDModel(Key<SVDModel> selfKey, SVDParameters parms, SVDOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Removes the frames referenced by the output's {@code _u_key} and
     * {@code _v_key}, then lets the superclass finish the removal.
     *
     * @param fs      futures object passed to every removal call
     * @param cascade forwarded unchanged to the superclass; the two frame
     *                keys are always removed with {@code true}
     * @return the value returned by {@code super.remove_impl}
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        final SVDOutput out = (SVDOutput)this._output;
        Keyed.remove(out._u_key, fs, true);
        Keyed.remove(out._v_key, fs, true);
        return super.remove_impl(fs, cascade);
    }

    /**
     * Writes the output's {@code _u_key} followed by its {@code _v_key}
     * into the buffer, then delegates to the superclass.
     *
     * @param ab buffer to write into
     * @return the value returned by {@code super.writeAll_impl}
     */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        final SVDOutput out = (SVDOutput)this._output;
        // Order matters: readAll_impl consumes the keys in the same U-then-V order.
        ab.putKey(out._u_key);
        ab.putKey(out._v_key);
        return super.writeAll_impl(ab);
    }

    /**
     * Reads back the entries for the output's {@code _u_key} and
     * {@code _v_key} (in that order), then delegates to the superclass.
     *
     * @param ab buffer to read from
     * @param fs futures object passed to each {@code getKey} call
     * @return the value returned by {@code super.readAll_impl}
     */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        final SVDOutput out = (SVDOutput)this._output;
        // Same U-then-V order as writeAll_impl.
        ab.getKey(out._u_key, fs);
        ab.getKey(out._v_key, fs);
        return super.readAll_impl(ab, fs);
    }

    /**
     * Creates the metric builder for this model, sized by the {@code _nv}
     * parameter.
     *
     * @param domain not used by this implementation
     * @return a new {@link ModelMetricsSVD.SVDModelMetrics}
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final int dims = ((SVDParameters)this._parms)._nv;
        return new ModelMetricsSVD.SVDModelMetrics(dims);
    }

    /**
     * Scores every row of the adapted frame and returns the resulting
     * {@code PC1..PCn} columns as a new frame stored under
     * {@code destination_key}.
     *
     * <p>One zero-filled column per component ({@code _nv}) is appended to a
     * copy of the adapted frame, an {@code MRTask} fills those columns via
     * {@code score0}, and the appended columns are then extracted into the
     * result frame.
     *
     * @param orig             not used by this implementation
     * @param adaptedFr        frame to score; wrapped in a new {@code Frame} before columns are added
     * @param destination_key  name used to build the key of the result frame
     * @param j2               optional job; checked for a stop request and updated once per {@code map} call
     * @param computeMetrics   not used by this implementation
     * @param customMetricFunc not used by this implementation
     * @return a result holding the metric builder and the prediction frame (passed twice)
     */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, final Job j2, boolean computeMetrics, CFuncRef customMetricFunc) {
        final Frame scoringFrame = new Frame(adaptedFr);
        final int numComponents = ((SVDParameters)this._parms)._nv;
        for (int idx = 0; idx < numComponents; idx++) {
            final String columnName = "PC" + (idx + 1);
            scoringFrame.add(columnName, scoringFrame.anyVec().makeZero());
        }

        new MRTask(){

            @Override
            public void map(Chunk[] chks) {
                if (this.isCancelled()) {
                    return;
                }
                if (j2 != null && j2.stop_requested()) {
                    return;
                }
                // Prediction columns start right after the columns named in the output.
                final int inputCols = ((SVDOutput)SVDModel.this._output)._names.length;
                final int outputCols = ((SVDParameters)SVDModel.this._parms)._nv;
                final double[] rowBuffer = new double[inputCols];
                final double[] projection = new double[outputCols];
                for (int r = 0; r < chks[0]._len; r++) {
                    final double[] scored = SVDModel.this.score0(chks, r, rowBuffer, projection);
                    for (int k = 0; k < projection.length; k++) {
                        chks[inputCols + k].set(r, scored[k]);
                    }
                }
                if (j2 != null) {
                    j2.update(1L);
                }
            }
        }.doAll(scoringFrame);

        // Pull out the appended columns: start = number of output names, end = total column count.
        final int firstPredictionCol = ((SVDOutput)this._output)._names.length;
        final int endCol = scoringFrame.numCols();
        final Frame extracted = scoringFrame.extractFrame(firstPredictionCol, endCol);
        final Frame result = new Frame(Key.make(destination_key), extracted.names(), extracted.vecs());
        DKV.put(result);
        ModelMetrics.MetricBuilder mb = this.makeMetricBuilder(null);
        return new Model.PredictScoreResult(this, mb, result, result);
    }

    /**
     * Projects one row of data onto the first {@code _nv} columns of the
     * output's {@code _v} matrix.
     *
     * <p>For each component the result is the sum of
     * <ul>
     *   <li>one {@code _v} entry per categorical column, picked by the
     *       column's level, and</li>
     *   <li>{@code (value - _normSub[j]) * _normMul[j] * _v[row][component]}
     *       for each numeric column.</li>
     * </ul>
     * Input columns are addressed through {@code _permutation}: the first
     * {@code _ncats} permutation entries are read as categoricals and the
     * next {@code _nnums} entries as numerics.
     *
     * @param data  row values, one per entry of {@code _permutation}
     * @param preds output array; its first {@code _nv} entries are overwritten
     * @return {@code preds}
     */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        final SVDOutput out = (SVDOutput)this._output;
        final SVDParameters parms = (SVDParameters)this._parms;

        // Rows of _v for numeric columns start after all categorical level rows.
        final int numStart = out._catOffsets[out._catOffsets.length - 1];
        assert (data.length == out._permutation.length);

        // Level indices are shifted down by one when not all factor levels are used
        // (presumably because the first level was dropped during training -- confirm).
        final int levelShift = parms._use_all_factor_levels ? 0 : 1;

        for (int comp = 0; comp < parms._nv; ++comp) {
            preds[comp] = 0.0;

            for (int cat = 0; cat < out._ncats; ++cat) {
                final double value = data[out._permutation[cat]];
                final int lastLevel = out._catOffsets[cat + 1] - out._catOffsets[cat] - 1;
                // A missing value maps to the last level of the column.
                final int level = Double.isNaN(value) ? lastLevel : (int)value - levelShift;
                if (level < 0 || level > lastLevel) {
                    // Level has no row in _v for this column; it contributes nothing.
                    continue;
                }
                preds[comp] += out._v[out._catOffsets[cat] + level][comp];
            }

            int dataCol = out._ncats;
            int vRow = numStart;
            for (int num = 0; num < out._nnums; ++num) {
                final double normalized = (data[out._permutation[dataCol]] - out._normSub[num]) * out._normMul[num];
                preds[comp] += normalized * out._v[vRow][comp];
                ++dataCol;
                ++vRow;
            }
        }
        return preds;
    }

    /**
     * Emits the static part of the generated scoring class: the
     * {@code isSupervised}, {@code nfeatures} and {@code nclasses} accessors
     * plus the constant arrays used by the generated predict body.
     *
     * <p>{@code NORMMUL} and {@code NORMSUB} are only emitted when the output
     * has at least one numeric column ({@code _nnums > 0}).
     *
     * @param sb      stream to write generated code to
     * @param fileCtx passed through to the superclass
     * @return the stream returned by the superclass call, after the extra code is written
     */
    @Override
    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        sb = super.toJavaInit(sb, fileCtx);
        final SVDOutput out = (SVDOutput)this._output;

        final String supervisedDecl = "public boolean isSupervised() { return " + this.isSupervised() + "; }";
        sb.ip(supervisedDecl).nl();
        final String nfeaturesDecl = "public int nfeatures() { return " + out.nfeatures() + "; }";
        sb.ip(nfeaturesDecl).nl();
        // nclasses() of the generated class reports the number of components (_nv).
        final String nclassesDecl = "public int nclasses() { return " + ((SVDParameters)this._parms)._nv + "; }";
        sb.ip(nclassesDecl).nl();

        final boolean hasNumerics = out._nnums > 0;
        if (hasNumerics) {
            JCodeGen.toStaticVar((JCodeSB)sb, "NORMMUL", out._normMul, "Standardization/Normalization scaling factor for numerical variables.");
            JCodeGen.toStaticVar((JCodeSB)sb, "NORMSUB", out._normSub, "Standardization/Normalization offset for numerical variables.");
        }
        JCodeGen.toStaticVar((JCodeSB)sb, "CATOFFS", out._catOffsets, "Categorical column offsets.");
        JCodeGen.toStaticVar((JCodeSB)sb, "PERMUTE", out._permutation, "Permutation index vector.");
        JCodeGen.toStaticVar((JCodeSB)sb, "EIGVECS", out._v, "Eigenvector matrix.");
        return sb;
    }

    /**
     * Emits the body of the generated predict method. The generated code
     * mirrors {@code score0}: for each component it adds the categorical
     * contributions from {@code EIGVECS}, then (only when numeric columns
     * exist) the normalized numeric contributions.
     *
     * @param bodySb      stream receiving the generated statements
     * @param classCtx    not used by this implementation
     * @param fileCtx     not used by this implementation
     * @param verboseCode not used by this implementation
     */
    @Override
    protected void toJavaPredictBody(SBPrintStream bodySb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final SVDOutput out = (SVDOutput)this._output;
        final SVDParameters parms = (SVDParameters)this._parms;
        final int catCount = out._ncats;
        final int numCount = out._nnums;
        final boolean hasCats = catCount > 0;

        bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
        bodySb.i().p("final int nstart = CATOFFS[CATOFFS.length-1];").nl();
        bodySb.i().p("for(int i = 0; i < ").p(parms._nv).p("; i++) {").nl();

        // Categorical columns: pick one EIGVECS row per column based on the level.
        bodySb.i(1).p("for(int j = 0; j < ").p(catCount).p("; j++) {").nl();
        bodySb.i(2).p("double d = data[PERMUTE[j]];").nl();
        bodySb.i(2).p("int last = CATOFFS[j+1]-CATOFFS[j]-1;").nl();
        // The level index is shifted by one when not all factor levels are used.
        final String levelTail = parms._use_all_factor_levels ? ";" : "-1;";
        bodySb.i(2).p("int c = Double.isNaN(d) ? last : (int)d").p(levelTail).nl();
        bodySb.i(2).p("if(c < 0 || c > last) continue;").nl();
        bodySb.i(2).p("preds[i] += EIGVECS[CATOFFS[j]+c][i];").nl();
        bodySb.i(1).p("}").nl();

        // Numeric columns: the offsets are only spelled out when categoricals precede them.
        if (numCount > 0) {
            final String dataOffset = hasCats ? "+" + catCount : "";
            final String vecOffset = hasCats ? "+ nstart" : "";
            bodySb.i(1).p("for(int j = 0; j < ").p(numCount).p("; j++) {").nl();
            bodySb.i(2).p("preds[i] += (data[PERMUTE[j" + dataOffset + "]]-NORMSUB[j])*NORMMUL[j]*EIGVECS[j" + vecOffset + "][i];").nl();
            bodySb.i(1).p("}").nl();
        }

        bodySb.i().p("}").nl();
    }

    /**
     * Model metrics object for SVD models. It adds no state of its own and
     * only fixes two of the superclass constructor arguments.
     */
    public static class ModelMetricsSVD
    extends ModelMetricsUnsupervised {
        /**
         * @param model        model the metrics belong to
         * @param frame        frame the metrics were computed on
         * @param customMetric custom metric forwarded to the superclass
         */
        public ModelMetricsSVD(Model model, Frame frame, CustomMetric customMetric) {
            // NOTE(review): 0L and NaN are presumably the observation count and
            // error placeholders -- confirm against ModelMetricsUnsupervised.
            super(model, frame, 0L, Double.NaN, customMetric);
        }

        /**
         * Metric builder for SVD. Per-row processing returns the predictions
         * unchanged; the final metrics are an empty {@link ModelMetricsSVD}.
         */
        public static class SVDModelMetrics
        extends ModelMetricsUnsupervised.MetricBuilderUnsupervised<SVDModelMetrics> {
            /**
             * @param dims length of the work array allocated for this builder
             */
            public SVDModelMetrics(int dims) {
                final double[] workArray = new double[dims];
                this._work = workArray;
            }

            /**
             * Returns {@code preds} as-is; nothing is accumulated.
             */
            @Override
            public double[] perRow(double[] preds, float[] dataRow, Model m4) {
                return preds;
            }

            /**
             * Builds a {@link ModelMetricsSVD} for the model and frame and
             * registers it through {@code addModelMetrics}.
             */
            @Override
            public ModelMetrics makeModelMetrics(Model m4, Frame f2) {
                final ModelMetricsSVD metrics = new ModelMetricsSVD(m4, f2, this._customMetric);
                return m4.addModelMetrics(metrics);
            }
        }
    }

    /**
     * Output of an SVD model build. Field names are kept exactly as they
     * are because other code accesses them directly.
     */
    public static class SVDOutput
    extends Model.Output {
        /** Iteration counter. NOTE(review): presumably iterations run by the builder -- confirm. */
        public int _iterations;
        /** Matrix indexed as {@code _v[expandedColumn][component]} by scoring; exported as "Eigenvector matrix.". */
        public double[][] _v;
        /** Key of the frame associated with {@code _v}; removed together with the model. */
        public Key<Frame> _v_key;
        /** NOTE(review): presumably the singular values -- not read anywhere in this file. */
        public double[] _d;
        /** Key of the U frame; removed together with the model. */
        public Key<Frame> _u_key;
        /** Number of categorical columns read by scoring. */
        public int _ncats;
        /** Number of numeric columns read by scoring. */
        public int _nnums;
        /** NOTE(review): presumably the number of training observations -- confirm. */
        public long _nobs;
        /** NOTE(review): presumably total variance of the training data -- confirm. */
        public double _total_variance;
        /** Categorical column offsets into the rows of {@code _v}; the last entry is where numeric rows start. */
        public int[] _catOffsets;
        /** Offset subtracted from each numeric column before scaling. */
        public double[] _normSub;
        /** Scaling factor applied to each numeric column after the offset. */
        public double[] _normMul;
        /** Permutation index vector mapping scoring order (categoricals first) to input columns. */
        public int[] _permutation;
        /** NOTE(review): presumably column names after categorical expansion -- confirm. */
        public String[] _names_expanded;
        // History lists start empty; the diamond operator avoids the raw-type
        // (unchecked) ArrayList construction.
        public ArrayList<Double> _history_average_SEE = new ArrayList<>();
        public ArrayList<Double> _history_err = new ArrayList<>();
        public ArrayList<Double> _history_eigenVectorIndex = new ArrayList<>();
        public ArrayList<Long> _training_time_ms = new ArrayList<>();

        /**
         * @param b2 builder forwarded to the {@code Model.Output} constructor
         */
        public SVDOutput(SVD b2) {
            super(b2);
        }

        /**
         * @return always {@link ModelCategory#DimReduction}
         */
        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.DimReduction;
        }
    }

    /**
     * Parameters of an SVD model build. Field names and defaults are part
     * of the public surface and are kept unchanged.
     */
    public static class SVDParameters
    extends Model.Parameters {
        /** Data transformation type; defaults to {@code NONE}. */
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        /** Algorithm variant; also determines {@link #progressUnits()}. */
        public Method _svd_method = Method.GramSVD;
        /** Number of components produced by scoring (one {@code PC} column each). */
        public int _nv = 1;
        /** Iteration limit; counted in the progress units of the {@code Randomized} method. */
        public int _max_iterations = 1000;
        /** NOTE(review): presumably the name for the U frame -- confirm against the builder. */
        public String _u_name;
        /** NOTE(review): presumably the name for the V frame -- confirm against the builder. */
        public String _v_name;
        public boolean _keep_u = true;
        public boolean _save_v_frame = true;
        public boolean _only_v = false;
        /** When {@code false}, scoring shifts categorical level indices down by one. */
        public boolean _use_all_factor_levels = true;
        public boolean _impute_missing = false;

        @Override
        public String algoName() {
            return "SVD";
        }

        @Override
        public String fullName() {
            return "Singular Value Decomposition";
        }

        @Override
        public String javaName() {
            return SVDModel.class.getName();
        }

        /**
         * Number of progress units reported for a build, depending on
         * {@code _svd_method}.
         *
         * <p>The sums are computed in {@code long} so that a very large
         * {@code _nv} or {@code _max_iterations} cannot wrap around to a
         * negative value before being widened.
         *
         * @return 2 for {@code GramSVD}, {@code 1 + _nv} for {@code Power},
         *         {@code 5 + _max_iterations} for {@code Randomized}
         */
        @Override
        public long progressUnits() {
            switch (this._svd_method) {
                case GramSVD: {
                    return 2L;
                }
                case Power: {
                    return 1L + this._nv;
                }
                case Randomized: {
                    return 5L + this._max_iterations;
                }
                default: {
                    return this._nv;
                }
            }
        }

        /** Available SVD algorithm variants. */
        public enum Method {
            GramSVD,
            Power,
            Randomized
        }
    }
}

