/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.cf;

import carskit.alg.baseline.cf.BiasedMF;
import carskit.data.structure.SparseMatrix;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;

/**
 * SVD++ rating predictor: biased matrix factorization augmented with implicit feedback.
 *
 * <p>A rating is predicted as
 * {@code globalMean + b_u + b_j + p_u . q_j + sum_{k in N(u)} (y_k . q_j) / sqrt(|N(u)|)},
 * where {@code N(u)} is the list cached in {@code userItemsCache} for user {@code u}
 * (built from {@code train.rowColumnsCache}) and {@code y_k} is row {@code k} of {@link #Y}.
 * All parameters are learned by stochastic gradient descent over the training entries.
 */
public class SVDPlusPlus extends BiasedMF {
    /** Implicit-feedback item factors: {@code numItems} rows by {@code numFactors} columns. */
    protected DenseMatrix Y;

    /**
     * Creates an SVD++ recommender for one fold.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, forwarded to the parent constructor
     */
    public SVDPlusPlus(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "SVD++";
    }

    /**
     * Initializes the biased-MF parameters, the implicit factor matrix {@link #Y}, and the
     * per-user item cache read by both training and prediction.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.Y = new DenseMatrix(this.numItems, numFactors);
        this.Y.init(initMean, initStd);
        this.userItemsCache = this.train.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIters} SGD epochs over the training entries, updating user/item
     * biases, P, Q and Y after every rating. {@code loss} accumulates the squared errors plus
     * the L2 penalty terms and is halved at the end of each epoch before the convergence check.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.train) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double euj = ruj - this.predict(u, j);
                this.loss += euj * euj;

                List<Integer> items = this.implicitItemsOf(u);
                double w = Math.sqrt(items.size());

                // Bias updates.
                double bu = this.userBias.get(u);
                this.userBias.add(u, this.lRate * (euj - (double) regB * bu));
                this.loss += (double) regB * bu * bu;

                double bj = this.itemBias.get(j);
                this.itemBias.add(j, this.lRate * (euj - (double) regB * bj));
                this.loss += (double) regB * bj * bj;

                // Normalized implicit-feedback sum, computed from Y before Y is updated.
                double[] sumYs = this.scaledImplicitSum(items, w);

                for (int f = 0; f < numFactors; ++f) {
                    // Read the old values so P, Q and Y all step from the same point.
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    double sgdU = euj * qjf - (double) regU * puf;
                    double sgdJ = euj * (puf + sumYs[f]) - (double) regI * qjf;
                    this.P.add(u, f, this.lRate * sgdU);
                    this.Q.add(j, f, this.lRate * sgdJ);
                    this.loss += (double) regU * puf * puf + (double) regI * qjf * qjf;

                    // Error term shared by every y_k for this factor. The w > 0 guard mirrors
                    // the one in scaledImplicitSum. In practice w >= 1 here, because u rated j.
                    double gradY = w > 0.0 ? euj * qjf / w : euj * qjf;
                    for (int k : items) {
                        double ykf = this.Y.get(k, f);
                        this.Y.add(k, f, this.lRate * (gradY - (double) regU * ykf));
                        this.loss += (double) regU * ykf * ykf;
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j} in context {@code c}.
     *
     * <p>When user (item) splitting is enabled and the corresponding id mapper has an entry
     * for {@code (u, c)} ({@code (j, c)}), the mapped id replaces the original one. The
     * prediction then delegates to {@link #predict(int, int)}.
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        if (this.isUserSplitting && this.userIdMapper.contains(u, c)) {
            u = (Integer) this.userIdMapper.get(u, c);
        }
        if (this.isItemSplitting && this.itemIdMapper.contains(j, c)) {
            j = (Integer) this.itemIdMapper.get(j, c);
        }
        return this.predict(u, j);
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j} with the SVD++ formula.
     *
     * <p>For a user with no cached items the implicit term is skipped entirely, so the
     * zero-valued normalizer is never used as a divisor.
     */
    @Override
    protected double predict(int u, int j) throws Exception {
        double pred = this.globalMean + this.userBias.get(u) + this.itemBias.get(j)
                + DenseMatrix.rowMult(this.P, u, this.Q, j);
        List<Integer> items = this.implicitItemsOf(u);
        double w = Math.sqrt(items.size());
        for (int k : items) {
            pred += DenseMatrix.rowMult(this.Y, k, this.Q, j) / w;
        }
        return pred;
    }

    /** Returns the cached item list of user {@code u}, typed as {@code List<Integer>}. */
    @SuppressWarnings("unchecked")
    private List<Integer> implicitItemsOf(int u) throws Exception {
        return (List<Integer>) this.userItemsCache.get(u);
    }

    /**
     * Sums the rows {@code Y[k]} for every {@code k} in {@code items}, one entry per factor,
     * and divides each sum by {@code w} when {@code w > 0}.
     */
    private double[] scaledImplicitSum(List<Integer> items, double w) {
        double[] sums = new double[numFactors];
        // Items outer, factors inner: each sums[f] still adds terms in item order.
        for (int k : items) {
            for (int f = 0; f < numFactors; ++f) {
                sums[f] += this.Y.get(k, f);
            }
        }
        if (w > 0.0) {
            for (int f = 0; f < numFactors; ++f) {
                sums[f] /= w;
            }
        }
        return sums;
    }
}

