/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.ranking;

import carskit.data.setting.Configuration;
import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Strings;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;

/**
 * List-wise ranking matrix factorization recommender ("LRMF").
 *
 * <p>For every user, the training ratings define a target "top-one" distribution
 * {@code exp(r_uj) / sum_k exp(r_uk)} over that user's training items. The model distribution
 * is the softmax of the factor scores {@code P_u . Q_i} over the same items. Training runs
 * stochastic updates that move the model distribution toward the target, with L2
 * regularization weighted by {@code regU} / {@code regI}.
 */
@Configuration(value="binThold, numFactors, initLRate, maxLRate, regU, regI, numIters")
public class LRMF extends IterativeRecommender {
    /**
     * Per-user normalizer of the target distribution: entry {@code u} holds the sum of
     * {@code exp(r_uj)} over all training entries of user {@code u}. Filled by {@link #initModel()}.
     */
    public DenseVector userExp;

    /**
     * Creates the recommender and marks it as a ranking predictor.
     *
     * @param trainMatrix training data, forwarded to the superclass
     * @param testMatrix  test data, forwarded to the superclass
     * @param fold        fold index, forwarded to the superclass
     */
    public LRMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        // Turns off the superclass's initByNorm option; its effect on factor initialization
        // is defined in IterativeRecommender.
        this.initByNorm = false;
        this.algoName = "LRMF";
    }

    /**
     * Runs the superclass initialization and then precomputes {@link #userExp} from the
     * training entries.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.userExp = new DenseVector(this.numUsers);
        for (MatrixEntry me : this.train) {
            int u = me.row();
            double ruj = me.get();
            this.userExp.add(u, Math.exp(ruj));
        }
    }

    /**
     * Trains the user factors {@code P} and item factors {@code Q} for up to {@code numIters}
     * passes over the training entries. Stops early when {@code isConverged(iter)} reports
     * convergence.
     *
     * <p>{@code loss} accumulates, per pass, the cross-entropy terms
     * {@code -targetProb * log(modelProb)} plus the L2 penalty terms.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.train) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double pred = DenseMatrix.rowMult(this.P, u, this.Q, j);
                List<Integer> items = this.train.getColumns(u);

                // The normalizer is computed in log space, so large scores no longer overflow
                // exp() to Infinity. With the naive form, Inf/Inf produced NaN and corrupted P/Q.
                double logNorm = logSumExpScores(u, items);
                double logModelProb = pred - logNorm;
                double modelProb = Math.exp(logModelProb);
                // Target probability of item j for user u; it does not depend on the factors.
                double targetProb = Math.exp(ruj) / this.userExp.get(u);

                this.loss -= targetProb * logModelProb;

                // Gradient factor shared by both updates; invariant across the latent dimensions.
                // NOTE(review): the softmax uses exp(pred) directly, but the step is scaled by
                // gd(pred), which is inherited and not shown here. Confirm whether the softmax
                // was meant to be taken over a squashed score g(pred).
                double err = (targetProb - modelProb) * this.gd(pred);
                for (int f = 0; f < numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    double delta_u = err * qjf - (double)regU * puf;
                    double delta_j = err * puf - (double)regI * qjf;
                    this.P.add(u, f, this.lRate * delta_u);
                    this.Q.add(j, f, this.lRate * delta_j);
                    this.loss += 0.5 * (double)regU * puf * puf + 0.5 * (double)regI * qjf * qjf;
                }
            }
            if (this.isConverged(iter)) break;
        }
    }

    /**
     * Computes {@code log(sum_i exp(P_u . Q_i))} over {@code items} using the current factors.
     * Every exponent is shifted by the maximum score before exponentiation, so the result is
     * finite whenever the scores are.
     *
     * @param u     user index
     * @param items item indices to normalize over
     * @return the log of the softmax normalizer. Returns negative infinity for an empty list
     *         and positive infinity if some score is positive infinity.
     */
    private double logSumExpScores(int u, List<Integer> items) {
        double[] scores = new double[items.size()];
        double max = Double.NEGATIVE_INFINITY;
        int k = 0;
        for (int i : items) {
            double s = DenseMatrix.rowMult(this.P, u, this.Q, i);
            scores[k++] = s;
            if (s > max) max = s;
        }
        if (Double.isInfinite(max)) {
            // Empty list (-Inf) or an infinite score (+Inf). Shifting by max would give NaN.
            return max;
        }
        double sum = 0.0;
        for (double s : scores) {
            sum += Math.exp(s - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Returns the configuration values in the order declared by {@link Configuration},
     * separated by commas.
     */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), numFactors, Float.valueOf(initLRate), Float.valueOf(maxLRate), Float.valueOf(regU), Float.valueOf(regI), numIters}, ",");
    }
}

