/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.sim;

import carskit.alg.cars.adaptation.dependent.CAMF;
import carskit.data.structure.SparseMatrix;
import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;

/**
 * CAMF variant that models the effect of context through latent vectors attached to each
 * context condition (the rows of {@code cfMatrix_LCS}); "LCS" presumably means latent context
 * similarity — TODO confirm.
 *
 * <p>The estimated rating of user {@code u} on item {@code j} in a context is
 * {@code (p_u . q_j)} multiplied, for every context dimension whose condition differs from the
 * corresponding entry of {@code EmptyContextConditions}, by the dot product of the two
 * conditions' latent vectors. Dimensions whose condition already equals the empty condition
 * contribute a factor of 1. Training and prediction share {@link #contextSimilarity} so they
 * always use the same formula.
 */
public class CAMF_LCS extends CAMF {
    /** Number of latent factors per context condition (algorithm option {@code -f}, default 10). */
    private int numF;

    /**
     * Creates the recommender for one evaluation fold and sets {@code isRankingPred = true}.
     *
     * @param trainMatrix training ratings
     * @param testMatrix test ratings
     * @param fold fold index, forwarded to the base class
     */
    public CAMF_LCS(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "CAMF_LCS";
        isRankingPred = true;
    }

    /**
     * Initializes the base model, then allocates a {@code numConditions x numF} matrix of
     * condition factors and initializes it with {@link DenseMatrix#init()}.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.numF = algoOptions.getInt("-f", 10);
        this.cfMatrix_LCS = new DenseMatrix(numConditions, this.numF);
        this.cfMatrix_LCS.init();
    }

    /**
     * Predicts the rating of user {@code u} on item {@code j} under context {@code c}.
     *
     * <p>Uses exactly the formula optimized in {@link #buildModel()}: a dimension whose condition
     * equals the empty condition contributes 1, not the squared norm of its latent vector.
     *
     * @param u user id
     * @param j item id
     * @param c context id, expanded to per-dimension conditions via {@code getConditions}
     * @return {@code (p_u . q_j) * contextSimilarity}
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        double dotRating = DenseMatrix.rowMult(this.P, u, this.Q, j);
        return dotRating * this.contextSimilarity(this.getConditions(c), null);
    }

    /**
     * Trains user, item and condition factors by stochastic gradient descent on the squared
     * error plus L2 regularization, for up to {@code numIters} passes or until
     * {@code isConverged} reports convergence.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int ui = me.row();
                int u = rateDao.getUserIdFromUI(ui);
                int j = rateDao.getItemIdFromUI(ui);
                double rujc = me.get();

                // (condition, emptyCondition) -> similarity for every dimension that differs.
                HashBasedTable<Integer, Integer, Double> toBeUpdated = HashBasedTable.create();
                double dotRating = DenseMatrix.rowMult(this.P, u, this.Q, j);
                double simc = this.contextSimilarity(this.getConditions(me.column()), toBeUpdated);

                double euj = rujc - dotRating * simc;
                this.loss += euj * euj;

                if (!toBeUpdated.isEmpty()) {
                    this.updateContextFactors(toBeUpdated, euj, dotRating);
                }

                // pred = (p_u . q_j) * simc, so d pred / d p_uf = q_jf * simc and vice versa.
                for (int f = 0; f < numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    double delta_u = euj * qjf * simc - (double) regU * puf;
                    double delta_j = euj * puf * simc - (double) regI * qjf;
                    this.P.add(u, f, this.lRate * delta_u);
                    this.Q.add(j, f, this.lRate * delta_j);
                    this.loss += (double) regU * puf * puf + (double) regI * qjf * qjf;
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Multiplies the latent similarities between each condition in {@code conditions} and the
     * empty-context condition at the same position. Positions where the two are identical
     * contribute 1.0 and are not recorded.
     *
     * @param conditions per-dimension condition ids of one context
     * @param pairs if non-null, receives {@code (condition, emptyCondition) -> similarity} for
     *     every position that contributed a factor
     * @return the product of the contributing similarities (1.0 if none)
     */
    private double contextSimilarity(List<Integer> conditions,
                                     HashBasedTable<Integer, Integer, Double> pairs) {
        double product = 1.0;
        for (int i = 0; i < conditions.size(); ++i) {
            int index1 = conditions.get(i);
            int index2 = (Integer) EmptyContextConditions.get(i);
            if (index1 == index2) {
                continue;
            }
            double sim = DenseMatrix.rowMult(this.cfMatrix_LCS, index1, this.cfMatrix_LCS, index2);
            if (pairs != null) {
                pairs.put(index1, index2, sim);
            }
            product *= sim;
        }
        return product;
    }

    /**
     * Applies one SGD step to the latent vectors of every recorded (condition, emptyCondition)
     * pair and adds their L2 penalty to {@code loss}.
     *
     * @param sims similarities recorded by {@link #contextSimilarity}; not modified here
     * @param euj prediction error for the current rating
     * @param dotRating {@code p_u . q_j} for the current rating
     */
    private void updateContextFactors(HashBasedTable<Integer, Integer, Double> sims,
                                      double euj, double dotRating) {
        for (int index1 : sims.rowKeySet()) {
            for (int index2 : sims.row(index1).keySet()) {
                // Product of all *other* similarities, computed directly instead of simc / sim
                // so that a zero similarity does not yield 0/0 = NaN.
                double scale = euj * dotRating * productExcluding(sims, index1, index2);
                for (int f = 0; f < this.numF; ++f) {
                    double c1f = this.cfMatrix_LCS.get(index1, f);
                    double c2f = this.cfMatrix_LCS.get(index2, f);
                    double delta_c1 = scale * c2f - (double) regC * c1f;
                    double delta_c2 = scale * c1f - (double) regC * c2f;
                    this.cfMatrix_LCS.add(index1, f, this.lRate * delta_c1);
                    this.cfMatrix_LCS.add(index2, f, this.lRate * delta_c2);
                    this.loss += (double) regC * c1f * c1f + (double) regC * c2f * c2f;
                }
            }
        }
    }

    /**
     * Returns the product of every value in {@code sims} except the cell at
     * ({@code row}, {@code col}); 1.0 when that is the only cell.
     */
    private static double productExcluding(HashBasedTable<Integer, Integer, Double> sims,
                                           int row, int col) {
        double product = 1.0;
        for (int r : sims.rowKeySet()) {
            for (int c : sims.row(r).keySet()) {
                // int comparison (unboxed), not Integer reference comparison.
                if (r != row || c != col) {
                    product *= sims.get(r, c);
                }
            }
        }
        return product;
    }
}

