/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.dev;

import carskit.data.structure.SparseMatrix;
import carskit.generic.ContextRecommender;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import happy.coding.math.Randoms;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;

/**
 * Context-aware matrix factorization in which each contextual condition contributes two
 * learned offsets: one tied to the user ({@link #ucBias}) and one tied to the item
 * ({@link #icBias}).
 *
 * <p>Estimated rating of user {@code u} for item {@code j} in context {@code c}:
 *
 * <pre>
 *   r(u, j, c) = globalMean + P[u] . Q[j]
 *              + sum over cond in getConditions(c) of (icBias[j][cond] + ucBias[u][cond])
 * </pre>
 *
 * <p>All parameters are fitted by stochastic gradient descent on squared error with L2
 * penalties: {@code regC} for the condition offsets and {@code regU}/{@code regI} for the
 * user/item latent factors.
 */
public class CAMF_CUCI extends ContextRecommender {

    /** Item-condition offsets, keyed by (item index, condition index). */
    protected Table<Integer, Integer, Double> icBias;

    /** User-condition offsets, keyed by (user index, condition index). */
    protected Table<Integer, Integer, Double> ucBias;

    /**
     * Forwards the training/test matrices and the fold index to {@link ContextRecommender}
     * and sets {@code algoName} to {@code "CAMF_CUCI"}.
     */
    public CAMF_CUCI(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        algoName = "CAMF_CUCI";
    }

    /**
     * Initialises the inherited model state, then fills the user-condition table followed by
     * the item-condition table with draws from {@code Randoms.gaussian(initMean, initStd)}.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        // Users are sampled before items so the random stream is consumed in a fixed order.
        ucBias = gaussianConditionTable(numUsers);
        icBias = gaussianConditionTable(numItems);
    }

    /**
     * Creates a table holding one Gaussian draw for every pair (row, condition), with rows
     * {@code 0..rows-1} and conditions {@code 0..numConditions-1}, filled row by row.
     */
    private Table<Integer, Integer, Double> gaussianConditionTable(int rows) {
        Table<Integer, Integer, Double> table = HashBasedTable.create();
        for (int row = 0; row < rows; row++) {
            for (int cond = 0; cond < numConditions; cond++) {
                table.put(row, cond, Randoms.gaussian(initMean, initStd));
            }
        }
        return table;
    }

    /**
     * Scores item {@code j} for user {@code u} under context {@code c}. The score is the
     * global mean, plus the dot product of row {@code u} of {@code P} and row {@code j} of
     * {@code Q}, plus the item- and user-condition offsets of every condition returned by
     * {@code getConditions(c)}.
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        double score = globalMean + DenseMatrix.rowMult(P, u, Q, j);
        for (int cond : getConditions(c)) {
            double offset = icBias.get(j, cond) + ucBias.get(u, cond);
            score += offset;
        }
        return score;
    }

    /**
     * Runs up to {@code numIters} SGD passes over the training entries. It stops early when
     * the inherited {@code isConverged(iter)} check returns true.
     *
     * <p>For each pass, {@code loss} ends up as one half of the sum of:
     * <ul>
     *   <li>the squared errors;</li>
     *   <li>{@code regC} times the squared condition offsets touched by each rating;</li>
     *   <li>{@code regU}/{@code regI} times the squared user/item factors.</li>
     * </ul>
     * Every penalty is evaluated before its update and counted once per training rating.
     */
    @Override
    protected void buildModel() throws Exception {
        int iter = 1;
        while (iter <= numIters) {
            loss = 0.0;
            for (MatrixEntry entry : trainMatrix) {
                // Each row of the training matrix encodes a (user, item) pair; columns are contexts.
                int uiIndex = entry.row();
                int user = rateDao.getUserIdFromUI(uiIndex);
                int item = rateDao.getItemIdFromUI(uiIndex);
                int context = entry.column();

                // NOTE(review): 4-arg overload inherited from the superclass; presumably `false`
                // disables clamping of the prediction — confirm in ContextRecommender.
                double err = entry.get() - predict(user, item, context, false);
                loss += err * err;

                sgdConditionBiases(user, item, context, err);
                sgdLatentFactors(user, item, err);
            }
            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
            iter++;
        }
    }

    /**
     * Applies one SGD step to the user- and item-condition offsets of every condition in
     * {@code context}. Adds {@code regC} times their squared pre-update values to
     * {@code loss}.
     */
    private void sgdConditionBiases(int user, int item, int context, double err) throws Exception {
        final double regCond = regC;
        double userSq = 0.0;
        double itemSq = 0.0;
        for (int cond : getConditions(context)) {
            double bUser = ucBias.get(user, cond);
            double bItem = icBias.get(item, cond);
            userSq += Math.pow(bUser, 2.0);
            itemSq += Math.pow(bItem, 2.0);
            ucBias.put(user, cond, bUser + lRate * (err - regCond * bUser));
            icBias.put(item, cond, bItem + lRate * (err - regCond * bItem));
        }
        loss += regCond * itemSq + regCond * userSq;
    }

    /**
     * Applies one SGD step to every latent factor of {@code user} (row of {@code P}) and
     * {@code item} (row of {@code Q}). The L2 penalty of the pre-update values is added to
     * {@code loss} one factor at a time.
     */
    private void sgdLatentFactors(int user, int item, double err) {
        final double regUser = regU;
        final double regItem = regI;
        for (int f = 0; f < numFactors; f++) {
            double pUser = P.get(user, f);
            double qItem = Q.get(item, f);
            loss += regUser * pUser * pUser + regItem * qItem * qItem;
            P.add(user, f, lRate * (err * qItem - regUser * pUser));
            Q.add(item, f, lRate * (err * pUser - regItem * qItem));
        }
    }
}

