/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.sim;

import carskit.alg.cars.adaptation.dependent.CAMF;
import carskit.data.structure.SparseMatrix;
import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SymmMatrix;

/**
 * Similarity-based context-aware matrix factorization (CAMF_ICS).
 *
 * <p>The score for user {@code u}, item {@code j} under context {@code c} is the dot product of
 * row {@code u} of {@code P} and row {@code j} of {@code Q}, multiplied by one factor per context
 * dimension {@code i}: the entry of {@code ccMatrix_ICS} for the pair (rating's condition in
 * dimension {@code i}, {@code EmptyContextConditions.get(i)}). All pair similarities start at 1.0
 * and are learned jointly with {@code P} and {@code Q} by stochastic gradient descent in
 * {@link #buildModel()}.
 */
public class CAMF_ICS
extends CAMF {

    /**
     * Creates the recommender for one cross-validation fold.
     *
     * @param trainMatrix training ratings; {@link #buildModel()} reads each entry's row as a
     *                    user-item id and its column as a context id
     * @param testMatrix  test ratings, passed through to {@code CAMF}
     * @param fold        fold index, passed through to {@code CAMF}
     */
    public CAMF_ICS(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "CAMF_ICS";
        // NOTE(review): this forces ranking mode. isRankingPred is inherited and may be a shared
        // (static) configuration flag; confirm that overriding it here is intended, since it also
        // makes the non-ranking branch of initModel unreachable for instances built this way.
        isRankingPred = true;
    }

    /**
     * Initializes the latent factors and sets every condition-pair similarity to 1.0, so that
     * context initially leaves the dot-product prediction unchanged.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        if (!isRankingPred) {
            // presumably init(mean, std); TODO confirm against DenseMatrix.init(double, double)
            this.P.init(1.0, 0.1);
            this.Q.init(1.0, 0.1);
        } else {
            this.P.init();
            this.Q.init();
        }
        this.ccMatrix_ICS = new SymmMatrix(numConditions);
        // The matrix is symmetric, so filling the lower triangle (j <= i) covers every pair.
        for (int i = 0; i < numConditions; ++i) {
            for (int j = 0; j <= i; ++j) {
                this.ccMatrix_ICS.set(i, j, 1.0);
            }
        }
    }

    /**
     * Predicts the score of user {@code u} for item {@code j} under context {@code c}.
     *
     * @param u user id
     * @param j item id
     * @param c context id
     * @return {@code P[u] . Q[j]} times the similarity between each condition of {@code c} and the
     *         corresponding entry of {@code EmptyContextConditions}
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        double pred = DenseMatrix.rowMult(this.P, u, this.Q, j);
        List<Integer> conditions = this.getConditions(c);
        for (int i = 0; i < conditions.size(); ++i) {
            int emptyCondition = (Integer) EmptyContextConditions.get(i);
            pred *= this.ccMatrix_ICS.get(conditions.get(i), emptyCondition);
        }
        return pred;
    }

    /**
     * Trains {@code P}, {@code Q} and {@code ccMatrix_ICS} by stochastic gradient descent over
     * {@code trainMatrix} for up to {@code numIters} iterations, stopping early when
     * {@code isConverged(iter)} returns true. The loss accumulated per iteration is half of the
     * squared error plus the L2 penalties on the similarities and on the touched factors.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int ui = me.row();
                int u = rateDao.getUserIdFromUI(ui);
                int j = rateDao.getItemIdFromUI(ui);
                int ctx = me.column();
                double rujc = me.get();

                double dotRating = DenseMatrix.rowMult(this.P, u, this.Q, j);
                List<Integer> conditions = this.getConditions(ctx);
                int numDims = conditions.size();

                // Learned (non-identical) condition pairs touched by this rating, together with
                // their similarity values captured before any update. Each context dimension
                // contributes at most one pair, so pairs do not repeat within one rating.
                int[] pairRows = new int[numDims];
                int[] pairCols = new int[numDims];
                double[] pairSims = new double[numDims];
                int numLearned = 0;

                // simc = product of the learned similarities; pred = dotRating * simc.
                double simc = 1.0;
                double pred = dotRating;
                for (int i = 0; i < numDims; ++i) {
                    int index1 = conditions.get(i);
                    int index2 = (Integer) EmptyContextConditions.get(i);
                    double sim = 1.0;
                    if (index1 != index2) {
                        sim = this.ccMatrix_ICS.get(index1, index2);
                        pairRows[numLearned] = index1;
                        pairCols[numLearned] = index2;
                        pairSims[numLearned] = sim;
                        ++numLearned;
                        simc *= sim;
                    }
                    // NOTE(review): identical-condition dimensions add a constant regC * 1.0 here
                    // although their similarity is not trained; this only shifts the reported loss.
                    this.loss += (double) regC * sim * sim;
                    pred *= sim;
                }
                double euj = rujc - pred;
                this.loss += euj * euj;

                // Similarity updates. d(pred)/d(sim_k) is dotRating times the product of the OTHER
                // learned similarities. Computing it directly (instead of simc / sim_k) avoids
                // 0.0 / 0.0 = NaN when a similarity reaches exactly zero.
                for (int k = 0; k < numLearned; ++k) {
                    double sim = pairSims[k];
                    double othersProduct = productExcluding(pairSims, numLearned, k);
                    double grad = euj * dotRating * othersProduct - (double) regC * sim;
                    this.ccMatrix_ICS.set(pairRows[k], pairCols[k], sim + this.lRate * grad);
                }

                // Factor updates; the prediction is (P[u] . Q[j]) * simc, hence the simc factor.
                for (int f = 0; f < numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    double delta_u = euj * qjf * simc - (double) regU * puf;
                    double delta_j = euj * puf * simc - (double) regI * qjf;
                    this.P.add(u, f, this.lRate * delta_u);
                    this.Q.add(j, f, this.lRate * delta_j);
                    this.loss += (double) regU * puf * puf + (double) regI * qjf * qjf;
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Returns the product of {@code values[0..count)} with the element at index {@code skip}
     * left out (1.0 when no other element remains).
     */
    private static double productExcluding(double[] values, int count, int skip) {
        double product = 1.0;
        for (int m = 0; m < count; ++m) {
            if (m != skip) {
                product *= values[m];
            }
        }
        return product;
    }
}

