/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.ranking;

import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import librec.data.SparseVector;

/**
 * Bayesian Personalized Ranking (BPR) baseline for item ranking.
 *
 * <p>Learns user factors {@code P} and item factors {@code Q} with stochastic gradient steps over
 * sampled triples (u, i, j). In each triple, i is an item in user u's training row (a positive) and
 * j is an item outside that row (a negative). Each step raises the predicted score of i relative to
 * j.
 */
public class BPR extends IterativeRecommender {

    /** Number of (u, i, j) triples drawn per user in each training iteration. */
    private static final int SAMPLES_PER_USER = 100;

    /**
     * Creates a BPR recommender.
     *
     * @param trainMatrix training data, passed to the superclass
     * @param testMatrix test data, passed to the superclass
     * @param fold fold index, passed to the superclass
     */
    public BPR(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        // NOTE(review): presumably makes the superclass skip normal-distribution factor
        // initialization; confirm against IterativeRecommender.initModel.
        this.initByNorm = false;
        this.algoName = "BPR";
    }

    /** Initializes the superclass model, then builds the per-user row cache used for sampling. */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.userCache = this.train.rowCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIters} iterations of SGD. Each iteration draws
     * {@code numUsers * SAMPLES_PER_USER} triples and stops early once {@code isConverged} says so.
     *
     * @throws IllegalStateException if no user has both a rated and an unrated item, in which case
     *     no triple can be sampled
     */
    @Override
    protected void buildModel() throws Exception {
        ensureSamplableUserExists();
        final int samplesPerIteration = numUsers * SAMPLES_PER_USER;

        for (int iter = 1; iter <= numIters; ++iter) {
            loss = 0.0;
            for (int s = 0; s < samplesPerIteration; ++s) {
                int u;
                SparseVector pu;
                // Rejection-sample a user who has a positive to draw AND a negative to draw.
                // A user who has rated every item would otherwise hang the negative loop below.
                do {
                    u = Randoms.uniform(numUsers);
                    pu = (SparseVector) userCache.get(u);
                } while (!isSamplable(pu));

                int[] positives = pu.getIndex();
                int i = positives[Randoms.uniform(positives.length)];

                // Rejection-sample a negative item j that is absent from the user's training row.
                int j;
                do {
                    j = Randoms.uniform(numItems);
                } while (pu.contains(j));

                updateFactors(u, i, j);
            }
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Applies one SGD step for triple (u, i, j) and adds the triple's objective terms to
     * {@code loss}.
     *
     * <p>The objective term is {@code -ln g(x_ui - x_uj)}, plus {@code regU * puf^2 +
     * regI * (qif^2 + qjf^2)} for each factor f. It is computed with the pre-update factor values.
     * {@code g} comes from the superclass (assumed to be the logistic sigmoid; confirm).
     */
    private void updateFactors(int u, int i, int j) throws Exception {
        double xuij = predict(u, i) - predict(u, j);
        loss += -Math.log(g(xuij));
        // Gradient weight of the pairwise term; with a logistic g this equals 1 - g(xuij).
        double cmg = g(-xuij);

        for (int f = 0; f < numFactors; ++f) {
            // Read every factor before any write so all three updates use the old values.
            double puf = P.get(u, f);
            double qif = Q.get(i, f);
            double qjf = Q.get(j, f);
            P.add(u, f, lRate * (cmg * (qif - qjf) - regU * puf));
            Q.add(i, f, lRate * (cmg * puf - regI * qif));
            Q.add(j, f, lRate * (cmg * -puf - regI * qjf));
            loss += regU * puf * puf + regI * qif * qif + regI * qjf * qjf;
        }
    }

    /** Returns true when the user row has at least one item but does not contain every item. */
    private boolean isSamplable(SparseVector pu) {
        int count = pu.getCount();
        return count > 0 && count < numItems;
    }

    /**
     * Fails fast when no user can yield a triple; without this check the sampling loop in
     * {@link #buildModel()} would spin forever. Returns on the first user that qualifies.
     */
    private void ensureSamplableUserExists() throws Exception {
        for (int u = 0; u < numUsers; ++u) {
            if (isSamplable((SparseVector) userCache.get(u))) {
                return;
            }
        }
        throw new IllegalStateException(
                "BPR: no user has both a rated and an unrated item in the training data; "
                        + "cannot sample (u, i, j) triples");
    }

    /** Returns the hyper-parameters as a comma-separated string. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), numFactors, Float.valueOf(initLRate),
                Float.valueOf(maxLRate), Float.valueOf(regU), Float.valueOf(regI), numIters}, ",");
    }
}

