/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.intf.SocialRecommender;
import librec.util.Randoms;
import librec.util.Strings;

/**
 * GBPR: group-preference-based Bayesian personalized ranking.
 *
 * <p>For a user {@code u}, a positive item {@code i} and a group {@code G} of users who also rated
 * {@code i}, the fused score is
 * {@code r(G,u,i) = rho * (mean_{w in G} P_w.Q_i + b_i) + (1 - rho) * (P_u.Q_i + b_i)}
 * (see {@link #predict(int, int, List)}). Training draws {@code (u, i, G, j)} tuples, where
 * {@code j} is an item {@code u} has not rated. The per-sample loss is {@code -ln g(x)} with
 * {@code x = r(G,u,i) - r(u,j)}, and updates are scaled by {@code g(-x)}. That factor is the
 * gradient of {@code ln g(x)} when {@code g} is the logistic sigmoid (presumably what the base-class
 * {@code g} is — confirm).
 *
 * <p>NOTE(review): appears to follow Pan &amp; Chen, "GBPR", IJCAI 2013 — confirm.
 */
public class GBPR extends SocialRecommender {
    /** Weight of the group term in the fused score (option {@code -rho}). */
    private float rho;
    /** Target group size when sampling co-raters of a positive item (option {@code -gSize}). */
    private int gLen;

    /**
     * Creates a GBPR recommender for one train/test fold.
     *
     * <p>Marks the model as a ranking predictor ({@code isRankingPred}) and clears the inherited
     * {@code initByNorm} flag.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, forwarded to the base class
     */
    public GBPR(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        this.initByNorm = false;
    }

    /**
     * Initializes base-class state and the item-bias vector.
     *
     * <p>Also reads the {@code -rho} and {@code -gSize} options and builds the row/column caches of
     * the training matrix that the sampler uses (user to rated items, item to raters).
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.itemBias = new DenseVector(numItems);
        this.itemBias.init();
        this.rho = algoOptions.getFloat("-rho");
        this.gLen = algoOptions.getInt("-gSize");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.itemUsersCache = this.trainMatrix.columnRowsCache(cacheSpec);
    }

    /**
     * Trains for up to {@code numIters} iterations of {@code numUsers * 100} sampled updates.
     *
     * <p>Factor gradients are accumulated and applied once at the end of each iteration. Item biases,
     * by contrast, are updated in place after every sample. Training stops early when
     * {@code isConverged(iter)} returns true.
     *
     * @throws IllegalStateException if no user has both a rated and an unrated item; without such a
     *     user the rejection sampling below could never terminate
     */
    @Override
    protected void buildModel() throws Exception {
        if (numIters >= 1 && numUsers > 0) {
            requireSampleableUser();
        }
        for (int iter = 1; iter <= numIters; iter++) {
            this.loss = 0.0;
            DenseMatrix pGrad = new DenseMatrix(numUsers, numFactors);
            DenseMatrix qGrad = new DenseMatrix(numItems, numFactors);
            for (int s = 0, smax = numUsers * 100; s < smax; s++) {
                sampleAndAccumulate(pGrad, qGrad);
            }
            this.P = this.P.add(pGrad);
            this.Q = this.Q.add(qGrad);
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /** Fails fast unless some user has at least one rated and at least one unrated item. */
    private void requireSampleableUser() throws Exception {
        for (int u = 0, rows = this.trainMatrix.numRows(); u < rows; u++) {
            int rated = ratedItemsOf(u).size();
            if (rated > 0 && rated < numItems) {
                return;
            }
        }
        throw new IllegalStateException(
                "GBPR cannot sample training tuples: no user has both rated and unrated items");
    }

    /** Draws one (u, i, G, j) tuple, updates item biases and accumulates factor gradients. */
    private void sampleAndAccumulate(DenseMatrix pGrad, DenseMatrix qGrad) throws Exception {
        // u: needs a rated item (positive) and an unrated item (negative). Without the second
        // condition, a user who rated everything would make the j-loop below spin forever.
        int u;
        List<Integer> ratedItems;
        do {
            u = Randoms.uniform(this.trainMatrix.numRows());
            ratedItems = ratedItemsOf(u);
        } while (ratedItems.isEmpty() || ratedItems.size() >= numItems);

        // i: a positive item of u.
        int i = (Integer) Randoms.random(ratedItems);

        // G: co-raters of i, plus the fused group score (computed before j is drawn or biases move).
        List<Integer> group = sampleGroup(u, i);
        double pgui = this.predict(u, i, group);

        // j: an item u has not rated.
        int j;
        do {
            j = Randoms.uniform(numItems);
        } while (ratedItems.contains(j));
        double puj = this.predict(u, j);

        double pgij = pgui - puj;
        this.loss += -Math.log(this.g(pgij));
        double cmg = this.g(-pgij);

        updateItemBiases(i, j, cmg);
        // rho / |G|: the weight of each group member in the group term.
        double groupWeight = this.rho * (1.0 / group.size());
        double[] groupFactorSum = accumulateUserGradients(pGrad, u, i, j, group, groupWeight, cmg);
        accumulateItemGradients(qGrad, u, i, j, groupFactorSum, groupWeight, cmg);
    }

    /**
     * Builds the group for (u, i).
     *
     * <p>If item i has at most {@code gLen} raters, the group is all of them. Otherwise it is u plus
     * distinct raters drawn at random until the group holds {@code gLen} users.
     */
    private List<Integer> sampleGroup(int u, int i) throws Exception {
        List<Integer> raters = ratersOf(i);
        if (raters.size() <= this.gLen) {
            return new ArrayList<Integer>(raters);
        }
        List<Integer> group = new ArrayList<Integer>();
        group.add(u);
        while (group.size() < this.gLen) {
            Integer w = (Integer) Randoms.random(raters);
            if (!group.contains(w)) {
                group.add(w);
            }
        }
        return group;
    }

    /** Takes one in-place gradient step on the biases of the positive item i and the negative item j. */
    private void updateItemBiases(int i, int j, double cmg) {
        double bi = this.itemBias.get(i);
        this.itemBias.add(i, this.lRate * (cmg - regB * bi));
        this.loss += regB * bi * bi;
        double bj = this.itemBias.get(j);
        this.itemBias.add(j, this.lRate * (-cmg - regB * bj));
        this.loss += regB * bj * bj;
    }

    /**
     * Accumulates user-factor gradients for every group member into {@code pGrad}.
     *
     * @return per-factor sum of the group members' current user factors, for the item update
     */
    private double[] accumulateUserGradients(DenseMatrix pGrad, int u, int i, int j,
            List<Integer> group, double groupWeight, double cmg) {
        // Computed in float and then widened, exactly as the original expression did.
        double oneMinusRho = 1 - this.rho;
        double[] sumW = new double[numFactors];
        for (int w : group) {
            // Only u's own factors take part in the individual terms r(u,i) and r(u,j).
            double delta = w == u ? 1 : 0;
            for (int f = 0; f < numFactors; f++) {
                double pwf = this.P.get(w, f);
                double qif = this.Q.get(i, f);
                double qjf = this.Q.get(j, f);
                double deltaPwf = groupWeight * qif + oneMinusRho * delta * qif - delta * qjf;
                pGrad.add(w, f, this.lRate * (cmg * deltaPwf - regU * pwf));
                this.loss += regU * pwf * pwf;
                sumW[f] += pwf;
            }
        }
        return sumW;
    }

    /** Accumulates item-factor gradients for the positive item i and the negative item j into {@code qGrad}. */
    private void accumulateItemGradients(DenseMatrix qGrad, int u, int i, int j,
            double[] groupFactorSum, double groupWeight, double cmg) {
        double oneMinusRho = 1 - this.rho;
        for (int f = 0; f < numFactors; f++) {
            double puf = this.P.get(u, f);
            double qif = this.Q.get(i, f);
            double qjf = this.Q.get(j, f);
            double deltaQif = groupWeight * groupFactorSum[f] + oneMinusRho * puf;
            qGrad.add(i, f, this.lRate * (cmg * deltaQif - regI * qif));
            this.loss += regI * qif * qif;
            double deltaQjf = -puf;
            qGrad.add(j, f, this.lRate * (cmg * deltaQjf - regI * qjf));
            this.loss += regI * qjf * qjf;
        }
    }

    /** Returns the cached column indices of training row {@code u}, i.e. the items u rated. */
    @SuppressWarnings("unchecked") // cache values are used as List<Integer>; cast is a no-op if already typed
    private List<Integer> ratedItemsOf(int u) throws Exception {
        return (List<Integer>) this.userItemsCache.get(u);
    }

    /** Returns the cached row indices of training column {@code i}, i.e. the users who rated i. */
    @SuppressWarnings("unchecked") // cache values are used as List<Integer>; cast is a no-op if already typed
    private List<Integer> ratersOf(int i) throws Exception {
        return (List<Integer>) this.itemUsersCache.get(i);
    }

    /**
     * Individual score of item {@code j} for user {@code u}.
     *
     * <p>The score is {@code b_j + DenseMatrix.rowMult(P, u, Q, j)}; rowMult is presumably the dot
     * product of row u of P and row j of Q.
     */
    @Override
    protected double predict(int u, int j) {
        return this.itemBias.get(j) + DenseMatrix.rowMult(this.P, u, this.Q, j);
    }

    /**
     * Fused group/individual score of item {@code j} for user {@code u} within group {@code g}.
     *
     * @param u user index
     * @param j item index
     * @param g group member user indices; if empty, the individual score {@link #predict(int, int)}
     *     is returned because the group mean is undefined
     * @return {@code rho * (mean_{w in g} P_w.Q_j + b_j) + (1 - rho) * predict(u, j)}
     */
    protected double predict(int u, int j, List<Integer> g) {
        double ruj = this.predict(u, j);
        if (g.isEmpty()) {
            // Avoid 0/0 = NaN: with no group members only the individual term is meaningful.
            return ruj;
        }
        double sum = 0.0;
        for (int w : g) {
            sum += DenseMatrix.rowMult(this.P, w, this.Q, j);
        }
        double rgj = sum / g.size() + this.itemBias.get(j);
        return this.rho * rgj + (1 - this.rho) * ruj;
    }

    /** Returns the model's hyper-parameters formatted by {@code Strings.toString}. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{binThold, this.rho, this.gLen, numFactors, initLRate,
                maxLRate, regU, regI, regB, numIters});
    }
}

