/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.dev;

import carskit.alg.cars.adaptation.dependent.CAMF;
import carskit.data.structure.SparseMatrix;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;

/**
 * CAMF variant with user-specific context-condition biases ("CU").
 *
 * <p>A rating of user {@code u} on item {@code j} in context {@code c} is estimated as
 * <pre>
 *   globalMean + b_j + p_u . q_j + sum over conditions k of c of B[u][k]
 * </pre>
 * where {@code b_j} is {@code itemBias}, {@code p_u}/{@code q_j} are rows of {@code P}/{@code Q},
 * and {@code B} is the (numUsers x numConditions) matrix {@code ucBias}. All parameters are
 * learned by stochastic gradient descent in {@link #buildModel()}.
 */
public class CAMF_CU extends CAMF {

    /**
     * Creates the recommender and tags it with the algorithm name {@code "CAMF_CU"}.
     *
     * @param trainMatrix training ratings, forwarded to the superclass
     * @param testMatrix  test ratings, forwarded to the superclass
     * @param fold        fold identifier, forwarded to the superclass
     */
    public CAMF_CU(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "CAMF_CU";
    }

    /**
     * Initialises the inherited model, then allocates the item-bias vector and the
     * user-by-condition bias matrix.
     *
     * <p>The two {@code init} calls are kept in their original order: if they draw from a shared
     * random generator (presumably the case), swapping them would change the sampled values.
     * NOTE(review): {@code ucBias} uses the no-argument {@code DenseMatrix.init()} rather than
     * {@code init(initMean, initStd)}. In LibRec that is believed to draw uniform values in
     * (0, 1), so confirm this asymmetry with {@code itemBias} is intended.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();

        DenseVector freshItemBias = new DenseVector(this.numItems);
        freshItemBias.init(initMean, initStd);
        this.itemBias = freshItemBias;

        DenseMatrix freshUcBias = new DenseMatrix(this.numUsers, numConditions);
        freshUcBias.init();
        this.ucBias = freshUcBias;
    }

    /**
     * Estimates the rating of user {@code u} on item {@code j} under context {@code c}.
     *
     * @param u user index
     * @param j item index
     * @param c context index; its conditions come from {@code getConditions(c)}
     * @return global mean + item bias + latent dot product + one user bias per condition of {@code c}
     * @throws Exception propagated from {@code getConditions}
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        double baseline = this.globalMean + this.itemBias.get(j);
        double estimate = baseline + DenseMatrix.rowMult(this.P, u, this.Q, j);
        // Every condition making up the context adds its own user-specific offset.
        for (int condition : this.getConditions(c)) {
            estimate += this.ucBias.get(u, condition);
        }
        return estimate;
    }

    /**
     * Trains the model with SGD.
     *
     * <p>Runs at most {@code numIters} passes over {@code trainMatrix}. Each pass stores half of
     * the accumulated squared error plus regularisation terms in {@code loss}. Training stops
     * early as soon as {@code isConverged(iter)} reports true.
     *
     * @throws Exception propagated from prediction, context lookup or the convergence check
     */
    @Override
    protected void buildModel() throws Exception {
        int iter = 1;
        while (iter <= numIters) {
            this.loss = 0.0;
            for (MatrixEntry entry : this.trainMatrix) {
                this.trainOnEntryCU(entry);
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
            iter++;
        }
    }

    /**
     * Applies one SGD step for a single training entry and adds its error to {@code loss}.
     *
     * <p>The entry's row is a combined user-item index, its column is the context index, and its
     * value is the rating.
     */
    private void trainOnEntryCU(MatrixEntry entry) throws Exception {
        int ui = entry.row();
        int user = rateDao.getUserIdFromUI(ui);
        int item = rateDao.getItemIdFromUI(ui);
        int ctx = entry.column();
        double rating = entry.get();

        // Inherited four-argument overload. NOTE(review): the meaning of the `false` flag is
        // defined in the superclass (presumably "do not clip to the rating scale"); confirm.
        double err = rating - this.predict(user, item, ctx, false);
        this.loss += err * err;

        this.updateItemBiasCU(item, err);
        this.updateUserConditionBiasesCU(user, ctx, err);
        this.updateFactorsCU(user, item, err);
    }

    /** Gradient step on the item bias, plus its regularisation term added to {@code loss}. */
    private void updateItemBiasCU(int item, double err) {
        double bj = this.itemBias.get(item);
        this.itemBias.add(item, this.lRate * (err - regB * bj));
        this.loss += regB * bj * bj;
    }

    /**
     * Gradient step on each user-condition bias of context {@code ctx}. The regC-weighted sum of
     * their squared pre-update values is added to {@code loss}.
     */
    private void updateUserConditionBiasesCU(int user, int ctx, double err) throws Exception {
        double squaredNorm = 0.0;
        for (int cond : this.getConditions(ctx)) {
            double buc = this.ucBias.get(user, cond);
            squaredNorm += Math.pow(buc, 2.0);
            this.ucBias.set(user, cond, buc + this.lRate * (err - regC * buc));
        }
        this.loss += regC * squaredNorm;
    }

    /**
     * Simultaneous gradient step on the user and item latent factors. Both updates use the
     * pre-update values. The L2 regularisation terms are added to {@code loss}.
     */
    private void updateFactorsCU(int user, int item, double err) {
        for (int f = 0; f < numFactors; ++f) {
            double puf = this.P.get(user, f);
            double qjf = this.Q.get(item, f);
            this.P.add(user, f, this.lRate * (err * qjf - regU * puf));
            this.Q.add(item, f, this.lRate * (err * puf - regI * qjf));
            this.loss += regU * puf * puf + regI * qjf * qjf;
        }
    }
}

