/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.intf.SocialRecommender;

public class SoReg
extends SocialRecommender {
    private Table<Integer, Integer, Double> userCorrs;
    private float beta;

    public SoReg(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.initByNorm = false;
    }

    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.userCorrs = HashBasedTable.create();
        this.beta = algoOptions.getFloat("-beta");
    }

    protected double similarity(Integer u, Integer v) {
        SparseVector vv;
        SparseVector uv;
        if (this.userCorrs.contains(u, v)) {
            return this.userCorrs.get(u, v);
        }
        if (this.userCorrs.contains(v, u)) {
            return this.userCorrs.get(v, u);
        }
        double sim = Double.NaN;
        if (u < this.trainMatrix.numRows() && v < this.trainMatrix.numRows() && (uv = this.trainMatrix.row(u)).getCount() > 0 && !Double.isNaN(sim = this.correlation(uv, vv = this.trainMatrix.row(v), "pcc"))) {
            sim = (1.0 + sim) / 2.0;
        }
        this.userCorrs.put(u, v, sim);
        return sim;
    }

    @Override
    protected void buildModel() throws Exception {
        int iter = 1;
        while (iter <= numIters) {
            int j;
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(numUsers, numFactors);
            DenseMatrix QS = new DenseMatrix(numItems, numFactors);
            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                j = me.column();
                double ruj = me.get();
                double pred = this.predict(u, j);
                double euj = pred - ruj;
                this.loss += euj * euj;
                int f = 0;
                while (f < numFactors) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    PS.add(u, f, euj * qjf + (double)regU * puf);
                    QS.add(j, f, euj * puf + (double)regI * qjf);
                    this.loss += (double)regU * puf * puf + (double)regI * qjf * qjf;
                    ++f;
                }
            }
            int u = 0;
            while (u < numUsers) {
                SparseVector uos = socialMatrix.row(u);
                int[] nArray = uos.getIndex();
                int n = nArray.length;
                j = 0;
                while (j < n) {
                    int k = nArray[j];
                    double suk = this.similarity(u, k);
                    if (!Double.isNaN(suk)) {
                        int f = 0;
                        while (f < numFactors) {
                            double euk = this.P.get(u, f) - this.P.get(k, f);
                            PS.add(u, f, (double)this.beta * suk * euk);
                            this.loss += (double)this.beta * suk * euk * euk;
                            ++f;
                        }
                    }
                    ++j;
                }
                SparseVector uis = socialMatrix.column(u);
                int[] nArray2 = uis.getIndex();
                int n2 = nArray2.length;
                n = 0;
                while (n < n2) {
                    int g = nArray2[n];
                    double sug = this.similarity(u, g);
                    if (!Double.isNaN(sug)) {
                        int f = 0;
                        while (f < numFactors) {
                            double eug = this.P.get(u, f) - this.P.get(g, f);
                            PS.add(u, f, (double)this.beta * sug * eug);
                            ++f;
                        }
                    }
                    ++n;
                }
                ++u;
            }
            this.P = this.P.add(PS.scale(-this.lRate));
            this.Q = this.Q.add(QS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) break;
            ++iter;
        }
    }
}

