/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.cf;

import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;

/**
 * Matrix-factorization recommender registered under the algorithm name {@code "PMF"}.
 *
 * <p>A rating r(u, j) is predicted as the inner product of user factor row {@code P[u]} and item
 * factor row {@code Q[j]}. The factors are learned by stochastic gradient descent on the squared
 * rating error with L2 penalties {@code regU} (user factors) and {@code regI} (item factors).
 */
public class PMF extends IterativeRecommender {

    /**
     * Creates the recommender and sets its algorithm name to {@code "PMF"}.
     *
     * @param rm   first rating matrix, forwarded to the base class (presumably the training data
     *             — confirm in {@code IterativeRecommender})
     * @param tm   second rating matrix, forwarded to the base class (presumably the test data —
     *             confirm)
     * @param fold fold identifier, forwarded to the base class
     */
    public PMF(SparseMatrix rm, SparseMatrix tm, int fold) {
        super(rm, tm, fold);
        this.algoName = "PMF";
    }

    /**
     * Trains the factor matrices {@code P} and {@code Q} with stochastic gradient descent.
     *
     * <p>Makes at most {@code numIters} passes over {@code train}. For every observed rating the
     * user and item factor rows are updated simultaneously: both updates use the factor values
     * read before either is modified. {@code loss} accumulates the squared error plus the L2
     * penalty of those pre-update factors, and the total is halved at the end of each pass.
     * Training stops early as soon as {@code isConverged(iter)} returns {@code true}.
     *
     * @throws Exception propagated from {@code predict} or {@code isConverged}
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.train) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                // -1 / false are passed to the base-class predict(int, int, int, boolean);
                // presumably "no context" and "do not clip to the rating range" — TODO confirm.
                double euj = ruj - this.predict(u, j, -1, false);
                this.loss += euj * euj;
                for (int f = 0; f < numFactors; ++f) {
                    // Snapshot both values so the P and Q updates see the same old factors.
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    this.P.add(u, f, this.lRate * (euj * qjf - regU * puf));
                    this.Q.add(j, f, this.lRate * (euj * puf - regI * qjf));
                    // Penalty is charged once per observed rating, mirroring the per-rating update.
                    this.loss += regU * puf * puf + regI * qjf * qjf;
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts the rating of item {@code j} by user {@code u} under context {@code c}.
     *
     * <p>User splitting: if {@code isUserSplitting} is set and {@code userIdMapper} contains the
     * pair (u, c), the mapped id replaces {@code u}. Item splitting: if {@code isItemSplitting}
     * is set and {@code itemIdMapper} contains the pair (j, c), the mapped id replaces {@code j}.
     * In both cases the original id is kept when no mapping exists. The result is the
     * context-free {@link #predict(int, int)} on the resolved ids.
     *
     * @param u user id
     * @param j item id
     * @param c context id used only to look up split ids
     * @return the predicted rating (unclipped)
     * @throws Exception propagated from {@link #predict(int, int)}
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        int user = u;
        if (this.isUserSplitting && this.userIdMapper.contains(u, c)) {
            user = (Integer) this.userIdMapper.get(u, c);
        }
        int item = j;
        if (this.isItemSplitting && this.itemIdMapper.contains(j, c)) {
            item = (Integer) this.itemIdMapper.get(j, c);
        }
        return this.predict(user, item);
    }

    /**
     * Context-free prediction: the inner product of user factor row {@code P[u]} and item factor
     * row {@code Q[j]}, computed by {@code DenseMatrix.rowMult}.
     *
     * @param u user id (row of {@code P})
     * @param j item id (row of {@code Q})
     * @return the raw predicted rating
     * @throws Exception declared by the overridden base-class method
     */
    @Override
    protected double predict(int u, int j) throws Exception {
        return DenseMatrix.rowMult(this.P, u, this.Q, j);
    }
}

