/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import java.util.HashMap;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.intf.SocialRecommender;

/**
 * SoRec recommender: jointly factorizes the user-item rating matrix and the user-user social
 * matrix.
 *
 * <p>Ratings are modeled as {@code g(p_u . q_j)} and social links as {@code g(p_u . z_v)}, where
 * {@code g} is the link function inherited from the base class. The user factor matrix {@code P}
 * is shared by both reconstructions, which is how the social graph influences rating predictions.
 * Training is full-batch gradient descent: each iteration accumulates the complete gradients for
 * {@code P}, {@code Q} and {@code Z} before applying any of them.
 */
public class SoRec extends SocialRecommender {
    /** Latent factors for users on the column (target) side of a social link; numUsers x numFactors. */
    private DenseMatrix Z;
    /** Weight of the social-reconstruction error term (algorithm option {@code -c}). */
    private float regC;
    /** L2 regularization weight applied to {@code Z} (algorithm option {@code -z}). */
    private float regZ;
    /** {@code inDegrees[u] == socialMatrix.columnSize(u)}, i.e. u's in-degree in the social graph. */
    private int[] inDegrees;
    /** {@code outDegrees[u] == socialMatrix.rowSize(u)}, i.e. u's out-degree in the social graph. */
    private int[] outDegrees;

    /**
     * Creates a SoRec recommender.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, passed through to the base class
     */
    public SoRec(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        // Opt out of the base class's "init by norm" factor initialization; the alternative
        // initialization is defined in the superclass.
        this.initByNorm = false;
    }

    /**
     * Initializes the inherited model state plus the SoRec-specific parts: the social factor
     * matrix {@code Z}, the {@code -c}/{@code -z} hyper-parameters, and cached social degrees.
     *
     * @throws Exception if base-class initialization or option lookup fails
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.Z = new DenseMatrix(numUsers, numFactors);
        this.Z.init();
        this.regC = algoOptions.getFloat("-c");
        this.regZ = algoOptions.getFloat("-z");
        // Degrees are read for every social link in every iteration, so cache them once in
        // primitive arrays (no boxing on lookup).
        this.inDegrees = new int[numUsers];
        this.outDegrees = new int[numUsers];
        for (int u = 0; u < numUsers; u++) {
            this.inDegrees[u] = socialMatrix.columnSize(u);
            this.outDegrees[u] = socialMatrix.rowSize(u);
        }
    }

    /**
     * Trains the model with full-batch gradient descent for up to {@code numIters} iterations.
     *
     * <p>Each iteration accumulates, into fresh gradient matrices:
     * <ul>
     *   <li>the rating term {@code (g(p_u . q_j) - normalize(r_uj))^2} plus L2 terms on
     *       {@code p_u}, {@code q_j} for every training rating;</li>
     *   <li>the social term {@code regC * (g(p_u . z_v) - w_uv * t_uv)^2} plus an L2 term on
     *       {@code z_v} for every positive social link, where
     *       {@code w_uv = sqrt(in(v) / (out(u) + in(v)))}.</li>
     * </ul>
     * Regularization is accumulated once per observed entry, matching the gradients.
     * {@code P}, {@code Q} and {@code Z} are then updated with step {@code lRate}, the loss is
     * halved, and training stops early when {@code isConverged(iter)} returns true.
     *
     * @throws Exception propagated from the base-class convergence check
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(numUsers, numFactors);
            DenseMatrix QS = new DenseMatrix(numItems, numFactors);
            DenseMatrix ZS = new DenseMatrix(numUsers, numFactors);

            // Rating reconstruction term.
            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double pred = this.predict(u, j, false);
                double euj = this.g(pred) - this.normalize(ruj);
                this.loss += euj * euj;
                // gd(pred) * euj does not depend on f: compute it once per rating rather than
                // twice per factor. (Assumes gd is a pure function of its argument.)
                double ratingScale = this.gd(pred) * euj;
                for (int f = 0; f < numFactors; f++) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    PS.add(u, f, ratingScale * qjf + regU * puf);
                    QS.add(j, f, ratingScale * puf + regI * qjf);
                    this.loss += regU * puf * puf + regI * qjf * qjf;
                }
            }

            // Social reconstruction term; only positive links contribute.
            for (MatrixEntry me : socialMatrix) {
                int u = me.row();
                int v = me.column();
                double tuv = me.get();
                if (tuv <= 0.0) {
                    continue;
                }
                double pred = DenseMatrix.rowMult(this.P, u, this.Z, v);
                int vIn = this.inDegrees[v];
                int uOut = this.outDegrees[u];
                // Down-weight links from users with many outgoing links relative to the
                // target's popularity.
                double weight = Math.sqrt((double) vIn / (uOut + vIn));
                double euv = this.g(pred) - weight * tuv;
                this.loss += this.regC * euv * euv;
                // Factor-independent part of the social gradient, hoisted out of the f-loop.
                double socialScale = this.regC * this.gd(pred) * euv;
                for (int f = 0; f < numFactors; f++) {
                    double puf = this.P.get(u, f);
                    double zvf = this.Z.get(v, f);
                    PS.add(u, f, socialScale * zvf);
                    ZS.add(v, f, socialScale * puf + this.regZ * zvf);
                    this.loss += this.regZ * zvf * zvf;
                }
            }

            // Full-batch step: gradients are applied only after both passes have read P, Q, Z.
            this.P = this.P.add(PS.scale(-this.lRate));
            this.Q = this.Q.add(QS.scale(-this.lRate));
            this.Z = this.Z.add(ZS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts user {@code u}'s score for item {@code j}.
     *
     * @param u       user index
     * @param j       item index
     * @param bounded if true, pass the raw score through {@code g} and then {@code denormalize};
     *                if false, return the raw dot product {@code p_u . q_j}
     * @return the predicted score
     */
    @Override
    protected double predict(int u, int j, boolean bounded) {
        double pred = DenseMatrix.rowMult(this.P, u, this.Q, j);
        if (bounded) {
            return this.denormalize(this.g(pred));
        }
        return pred;
    }
}

