/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.sim;

import carskit.alg.cars.adaptation.dependent.CAMF;
import carskit.data.structure.SparseMatrix;
import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;

/**
 * CAMF variant using a multidimensional context-similarity (MCS) model.
 *
 * <p>Every context condition owns a scalar position in {@code cVector_MCS}. For a context
 * {@code c}, {@code dist} is the Euclidean distance between the positions of the conditions
 * returned by {@code getConditions(c)} and the positions of the entries of
 * {@code EmptyContextConditions} at the same list indices. The context similarity is
 * {@code 1 - dist}, and the predicted rating is the user/item factor dot product scaled by
 * that similarity. Condition positions, user factors ({@code P}) and item factors ({@code Q})
 * are learned jointly by stochastic gradient descent in {@link #buildModel()}.
 */
public class CAMF_MCS
extends CAMF {
    /** Upper limit for a condition position: {@code 1 / sqrt(rateDao.numContextDims())}. */
    private double upbound;
    /** Tiny positive constant (1e-100): floor for clamped positions and division-by-zero guard. */
    private double lowbound;

    /**
     * Creates the recommender for one fold.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, passed through to {@link CAMF}
     */
    public CAMF_MCS(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "CAMF_MCS";
        isRankingPred = true;
    }

    /**
     * Initializes the base model, computes the position bounds and places every condition at
     * {@code upbound}.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        // NOTE(review): presumably chosen so that, with one condition per context dimension,
        // dist <= 1 and hence the similarity stays non-negative — confirm.
        this.upbound = 1.0 / Math.sqrt(rateDao.numContextDims());
        this.lowbound = 1.0 / Math.pow(10.0, 100.0);
        this.cVector_MCS = new DenseVector(numConditions);
        this.cVector_MCS.init(this.upbound);
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j} in context {@code c}.
     *
     * @return {@code P[u]·Q[j]} scaled by the similarity of {@code c} to the empty context
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        double dotRating = DenseMatrix.rowMult(this.P, u, this.Q, j);
        return dotRating * similarityToEmpty(this.getConditions(c));
    }

    /**
     * Returns {@code 1 - dist}, where {@code dist} is the Euclidean distance between the positions
     * of {@code conditions} and those of the empty-context conditions at the same list indices.
     */
    private double similarityToEmpty(List<Integer> conditions) {
        double sumSq = 0.0;
        for (int i = 0; i < conditions.size(); ++i) {
            int condition = conditions.get(i);
            int emptyCondition = (Integer) EmptyContextConditions.get(i);
            double diff = this.cVector_MCS.get(condition) - this.cVector_MCS.get(emptyCondition);
            sumSq += diff * diff;
        }
        return 1.0 - Math.sqrt(sumSq);
    }

    /**
     * Clamps an updated condition position.
     *
     * <p>Negative values become {@code lowbound}; values above {@code upbound} become
     * {@code upbound - lowbound}.
     */
    private double clampPosition(double pos) {
        double clamped = pos < 0.0 ? this.lowbound : pos;
        return clamped > this.upbound ? this.upbound - this.lowbound : clamped;
    }

    /**
     * Trains condition positions and user/item factors by SGD for up to {@code numIters}
     * iterations. The loop stops early when {@code isConverged(iter)} reports convergence.
     *
     * <p>For each training entry with error {@code euj = r - dot * (1 - dist)}:
     * <ul>
     *   <li>{@code d pred / d pos1 = -dot * diff / dist}</li>
     *   <li>{@code d pred / d pos2 = +dot * diff / dist}</li>
     * </ul>
     * Here {@code diff = pos1 - pos2}, {@code pos1} is a condition's position and {@code pos2}
     * is the matching empty-condition position. Every parameter {@code θ} moves by
     * {@code lRate * (euj * d pred / dθ - reg * θ)}, the same rule used for P and Q.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int ui = me.row();
                int u = rateDao.getUserIdFromUI(ui);
                int j = rateDao.getItemIdFromUI(ui);
                int ctx = me.column();
                double rujc = me.get();

                double dotRating = DenseMatrix.rowMult(this.P, u, this.Q, j);
                List<Integer> conditions = this.getConditions(ctx);

                // condition index -> empty-condition index -> (pos1 - pos2), captured before any
                // position of this entry is updated.
                HashBasedTable<Integer, Integer, Double> toBeUpdated = HashBasedTable.create();
                double sumSq = 0.0;
                for (int i = 0; i < conditions.size(); ++i) {
                    int index1 = conditions.get(i);
                    int index2 = (Integer) EmptyContextConditions.get(i);
                    double pos1 = this.cVector_MCS.get(index1);
                    double pos2 = this.cVector_MCS.get(index2);
                    double diff = pos1 - pos2;
                    sumSq += diff * diff;
                    // a condition that already is the empty condition has nothing to learn
                    if (index1 != index2) {
                        toBeUpdated.put(index1, index2, diff);
                    }
                    this.loss += (double) regC * pos1 * pos1 + (double) regC * pos2 * pos2;
                }
                double dist = Math.sqrt(sumSq);
                double sim = 1.0 - dist;
                double euj = rujc - dotRating * sim;
                this.loss += euj * euj;

                // d(dist)/d(pos) divides by dist; guard against dist == 0 without altering sim.
                double gradDist = dist == 0.0 ? this.lowbound : dist;
                for (int index1 : toBeUpdated.rowKeySet()) {
                    for (int index2 : toBeUpdated.row(index1).keySet()) {
                        double pos1 = this.cVector_MCS.get(index1);
                        double pos2 = this.cVector_MCS.get(index2);
                        double grad = euj * dotRating * toBeUpdated.get(index1, index2) / gradDist;
                        // Raising dist lowers pred, so a positive error pulls the pair together:
                        // pos1 moves against diff and pos2 along it.
                        double pos1Update = pos1 - this.lRate * (grad + (double) regC * pos1);
                        double pos2Update = pos2 + this.lRate * (grad - (double) regC * pos2);
                        this.cVector_MCS.set(index1, clampPosition(pos1Update));
                        this.cVector_MCS.set(index2, clampPosition(pos2Update));
                    }
                }

                for (int f = 0; f < numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    double delta_u = euj * qjf * sim - (double) regU * puf;
                    double delta_j = euj * puf * sim - (double) regI * qjf;
                    this.P.add(u, f, this.lRate * delta_u);
                    this.Q.add(j, f, this.lRate * delta_j);
                    this.loss += (double) regU * puf * puf + (double) regI * qjf * qjf;
                }
            }
            // NOTE(review): SGD loss reports are commonly halved (0.5); 0.05 may be a typo.
            // Left unchanged because isConverged presumably depends on the loss magnitude — confirm.
            this.loss *= 0.05;
            if (this.isConverged(iter)) break;
        }
    }
}

