/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.intf.SocialRecommender;

/**
 * SocialMF recommender: matrix factorization with a social regularizer, after Jamali and Ester,
 * "A Matrix Factorization Technique with Trust Propagation for Recommendation in Social Networks"
 * (RecSys 2010).
 *
 * <p>Ratings are fitted as {@code g(P_u . Q_j)} against {@code normalize(r_uj)}, where {@code g},
 * {@code gd}, {@code normalize} and {@code denormalize} come from the superclass. The social term
 * pulls each user's factor vector {@code P_u} toward the average of the factor vectors of the users
 * stored in row {@code u} of {@code socialMatrix}, with link values divided by the number of
 * outgoing links of {@code u}.
 */
public class SocialMF extends SocialRecommender {

    /**
     * Creates a SocialMF recommender.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, passed through to the superclass
     */
    public SocialMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        // NOTE(review): turns off the superclass's normal-distribution initialization of the
        // factor matrices; which initialization is used instead is decided by the superclass.
        this.initByNorm = false;
    }

    /**
     * Learns user factors {@code P} and item factors {@code Q} by full-batch gradient descent.
     *
     * <p>Each iteration accumulates the gradient of the whole objective (rating error, L2
     * regularization and social regularization) into fresh gradient matrices. It then takes one
     * step of size {@code lRate} on both factor matrices. {@code loss} is recomputed every iteration
     * and halved, matching the {@code 1/2}-scaled objective, before the convergence check.
     *
     * @throws Exception as declared by the overridden method (e.g. from {@code isConverged})
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            loss = 0.0;
            DenseMatrix gradP = new DenseMatrix(numUsers, numFactors);
            DenseMatrix gradQ = new DenseMatrix(numItems, numFactors);

            accumulateRatingGradients(gradP, gradQ);
            accumulateSocialGradients(gradP);

            // Both factor matrices are updated only after the full gradient is known (batch GD).
            P = P.add(gradP.scale(-lRate));
            Q = Q.add(gradQ.scale(-lRate));

            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Adds, for every training entry, the rating-error and L2 gradients to {@code gradP} and
     * {@code gradQ}, and the corresponding squared terms to {@code loss}.
     */
    private void accumulateRatingGradients(DenseMatrix gradP, DenseMatrix gradQ) {
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int j = me.column();
            double pred = predict(u, j, false);
            double euj = g(pred) - normalize(me.get());
            loss += euj * euj;

            // Chain rule through g; assumes gd is the derivative of g (per its name) — confirm.
            double csgd = gd(pred) * euj;
            for (int f = 0; f < numFactors; f++) {
                double puf = P.get(u, f);
                double qjf = Q.get(j, f);
                gradP.add(u, f, csgd * qjf + regU * puf);
                gradQ.add(j, f, csgd * puf + regI * qjf);
                loss += regU * puf * puf;
                loss += regI * qjf * qjf;
            }
        }
    }

    /**
     * Adds the social-regularization gradient to {@code gradP} and its loss term to {@code loss}.
     *
     * <p>Let {@code T = socialMatrix} and let {@code N(u)} be the stored columns of row {@code u}.
     * Define {@code d_u = P_u - (sum over v in N(u) of T(u,v) * P_v) / |N(u)|}. The objective then
     * contains {@code regS / 2 * ||d_u||^2} for every user with at least one outgoing link. Its
     * gradient with respect to {@code P_u} is {@code regS * d_u} (u's own term) minus
     * {@code regS * T(v,u) / |N(v)| * d_v} for every user {@code v} that links to {@code u}.
     */
    private void accumulateSocialGradients(DenseMatrix gradP) {
        // P is not modified until the end of the iteration. Every d_u can therefore be computed
        // once and reused for u's own term and for each user that u links to.
        double[][] diffs = new double[numUsers][];
        int[] outDegree = new int[numUsers];
        for (int u = 0; u < numUsers; u++) {
            SparseVector links = socialMatrix.row(u);
            outDegree[u] = links.getCount();
            if (outDegree[u] > 0) {
                diffs[u] = socialDiff(u, links, outDegree[u]);
            }
        }

        for (int u = 0; u < numUsers; u++) {
            double[] du = diffs[u];
            if (du != null) {
                for (int f = 0; f < numFactors; f++) {
                    gradP.add(u, f, regS * du[f]);
                    loss += regS * du[f] * du[f];
                }
            }

            // u appears in the neighbourhood average of every v that links to u. v's term
            // therefore contributes to u's gradient whether or not u has outgoing links itself.
            // The weight T(v,u) / |N(v)| is the same normalization used inside d_v.
            for (int v : socialMatrix.column(u).getIndex()) {
                double[] dv = diffs[v];
                if (dv == null) {
                    continue; // defensive: a user that links to u should have outDegree[v] > 0
                }
                double weight = regS * socialMatrix.get(v, u) / outDegree[v];
                for (int f = 0; f < numFactors; f++) {
                    gradP.add(u, f, -weight * dv[f]);
                }
            }
        }
    }

    /**
     * Computes {@code d_u = P_u - (sum over v in links of T(u,v) * P_v) / numLinks}.
     *
     * @param u        user index
     * @param links    row {@code u} of {@code socialMatrix}
     * @param numLinks number of stored entries in {@code links}; must be positive
     * @return a new array of length {@code numFactors}
     */
    private double[] socialDiff(int u, SparseVector links, int numLinks) {
        double[] diff = new double[numFactors];
        for (int v : links.getIndex()) {
            double tuv = socialMatrix.get(u, v);
            for (int f = 0; f < numFactors; f++) {
                diff[f] += tuv * P.get(v, f);
            }
        }
        for (int f = 0; f < numFactors; f++) {
            diff[f] = P.get(u, f) - diff[f] / numLinks;
        }
        return diff;
    }

    /**
     * Predicts user {@code u}'s score for item {@code j}. The score is row {@code u} of {@code P}
     * combined with row {@code j} of {@code Q} by {@code DenseMatrix.rowMult}.
     *
     * @param u       user index
     * @param j       item index
     * @param bounded if {@code true}, the raw score is passed through {@code g} and then
     *                {@code denormalize}; if {@code false}, the raw score is returned (this is the
     *                form the gradient computation uses)
     * @return the predicted score
     */
    @Override
    protected double predict(int u, int j, boolean bounded) {
        double pred = DenseMatrix.rowMult(P, u, Q, j);
        return bounded ? denormalize(g(pred)) : pred;
    }
}

