/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent;

import carskit.data.structure.DenseMatrix;
import carskit.data.structure.DenseVector;
import carskit.data.structure.SparseMatrix;
import carskit.generic.ContextRecommender;
import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import librec.data.MatrixEntry;

/**
 * Factorization Machine for context-aware rating prediction, trained with alternating
 * least squares (closed-form coordinate updates of every parameter).
 *
 * <p>A rating of item j by user u in context c is encoded as a feature vector x of length
 * p = numUsers + numItems + numConditions with x[u] = 1, x[numUsers + j] = 1 and
 * x[numUsers + numItems + c] = 1 / numContextDims (all other entries 0). The model is
 * <pre>
 *   y(x) = w0 + sum_l w_l x_l + 1/2 sum_f [ (sum_l V_lf x_l)^2 - sum_l (V_lf x_l)^2 ]
 * </pre>
 * Regularisation strengths come from the "-lw" option (w0 and w) and the "-lf" option (V).
 */
public class FM extends ContextRecommender {
    /** Global bias. */
    private double w0;
    /** Length of the feature vector: numUsers + numItems + numConditions. */
    private int p;
    /** Number of latent factors per feature. */
    private int k;
    /** Number of training ratings (entries of trainMatrix). */
    private int size;
    /** Linear weight of every feature. */
    private DenseVector w;
    /** Latent factors: row l holds the k factors of feature l. */
    private DenseMatrix V;
    /** ALS cache: Q(i, f) = sum_l V(l, f) * x_il for training sample i. */
    private DenseMatrix Q;
    /** Regularisation of w0 and w ("-lw"). */
    private float regLw;
    /** Regularisation of V ("-lf"). */
    private float regLf;

    /** Sparse vector stored as parallel arrays of non-zero positions and their values. */
    private static final class SparseVec {
        final int[] index;
        final double[] value;

        SparseVec(int[] index, double[] value) {
            this.index = index;
            this.value = value;
        }
    }

    public FM(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "FM";
        this.regLw = algoOptions.getFloat("-lw");
        this.regLf = algoOptions.getFloat("-lf");
    }

    /**
     * Sizes the model and initialises its parameters: w0 = 0, w via DenseVector.init(),
     * V with (initMean, initStd), and allocates the size x k cache Q.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.k = numFactors;
        this.p = this.numUsers + this.numItems + numConditions;
        this.w0 = 0.0;
        this.size = this.trainMatrix.size();
        this.w = new DenseVector(this.p);
        this.w.init();
        this.V = new DenseMatrix(this.p, this.k);
        this.V.init(initMean, initStd);
        this.Q = new DenseMatrix(this.size, this.k);
    }

    /**
     * Encodes (u, j, c) as the sparse feature vector described on the class.
     *
     * <p>Positions outside [0, p) are dropped, exactly as the dense encoding this replaces
     * silently ignored them.
     *
     * <p>NOTE(review): c is taken from the training-matrix column and used directly as a
     * condition index. If it is actually a multi-dimensional context id, the 1/numContextDims
     * weight suggests every condition of that context was meant to be set. Confirm against
     * rateDao.
     */
    private SparseVec sparseFeatures(int u, int j, int c) {
        int[] positions = {u, this.numUsers + j, this.numUsers + this.numItems + c};
        double[] weights = {1.0, 1.0, 1.0 / (double) rateDao.numContextDims()};
        int kept = 0;
        for (int position : positions) {
            if (position >= 0 && position < this.p) {
                ++kept;
            }
        }
        int[] index = new int[kept];
        double[] value = new double[kept];
        int next = 0;
        for (int t = 0; t < positions.length; ++t) {
            if (positions[t] >= 0 && positions[t] < this.p) {
                index[next] = positions[t];
                value[next] = weights[t];
                ++next;
            }
        }
        return new SparseVec(index, value);
    }

    /**
     * Predicts the rating of item j by user u in context c.
     *
     * <p>Only the non-zero features are visited, so the cost is O(k) instead of O(p * k).
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        SparseVec x = this.sparseFeatures(u, j, c);
        double pred = this.w0;
        for (int t = 0; t < x.index.length; ++t) {
            pred += this.w.get(x.index[t]) * x.value[t];
        }
        // Pairwise interactions via the identity
        // sum_{l<m} <v_l, v_m> x_l x_m = 1/2 sum_f [(sum_l v_lf x_l)^2 - sum_l (v_lf x_l)^2].
        double pairwise = 0.0;
        for (int f = 0; f < this.k; ++f) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int t = 0; t < x.index.length; ++t) {
                double term = this.V.get(x.index[t], f) * x.value[t];
                sum += term;
                sumOfSquares += term * term;
            }
            pairwise += sum * sum - sumOfSquares;
        }
        return pred + 0.5 * pairwise;
    }

    /**
     * Trains the model with numIters ALS sweeps over w0, every w_l and every V_lf.
     *
     * <p>The residuals e_i = prediction_i - rating_i and the cache Q are kept in sync after
     * every single-parameter update. this.loss is set to the regularised objective at the
     * start of each sweep.
     */
    @Override
    protected void buildModel() throws Exception {
        double[] errors = new double[this.size];
        SparseVec[] samples = this.encodeTrainingSamples(errors);
        SparseVec[] columns = this.indexByFeature(samples);
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = this.alsObjective(errors);
            this.alsUpdateBias(errors);
            this.alsUpdateLinear(errors, columns);
            this.alsUpdateFactors(errors, columns);
        }
    }

    /**
     * Encodes every training rating and fills the initial residuals and the Q cache.
     *
     * <p>Sets errors[i] = prediction - rating under the initial model and fills Q.
     * Returns the sparse features of each sample, in trainMatrix iteration order.
     */
    private SparseVec[] encodeTrainingSamples(double[] errors) throws Exception {
        SparseVec[] samples = new SparseVec[this.size];
        Iterator<MatrixEntry> itor = this.trainMatrix.iterator();
        int i = 0;
        while (itor.hasNext()) {
            MatrixEntry me = itor.next();
            int ui = me.row();
            int u = rateDao.getUserIdFromUI(ui);
            int j = rateDao.getItemIdFromUI(ui);
            int c = me.column();
            SparseVec x = this.sparseFeatures(u, j, c);
            // The ALS updates below assume e = prediction - rating; storing
            // rating - prediction would make the model fit reflected targets.
            errors[i] = this.predict(u, j, c) - me.get();
            for (int f = 0; f < this.k; ++f) {
                double q = 0.0;
                for (int t = 0; t < x.index.length; ++t) {
                    q += this.V.get(x.index[t], f) * x.value[t];
                }
                this.Q.set(i, f, q);
            }
            samples[i] = x;
            ++i;
        }
        return samples;
    }

    /**
     * Transposes per-sample features into per-feature columns.
     *
     * <p>Column l lists the samples i with x_il != 0 and the matching values x_il.
     */
    private SparseVec[] indexByFeature(SparseVec[] samples) {
        int[] counts = new int[this.p];
        for (SparseVec x : samples) {
            for (int l : x.index) {
                ++counts[l];
            }
        }
        SparseVec[] columns = new SparseVec[this.p];
        for (int l = 0; l < this.p; ++l) {
            columns[l] = new SparseVec(new int[counts[l]], new double[counts[l]]);
            counts[l] = 0; // reused below as the fill cursor of column l
        }
        for (int i = 0; i < samples.length; ++i) {
            SparseVec x = samples[i];
            for (int t = 0; t < x.index.length; ++t) {
                int l = x.index[t];
                columns[l].index[counts[l]] = i;
                columns[l].value[counts[l]] = x.value[t];
                ++counts[l];
            }
        }
        return columns;
    }

    /**
     * Returns the regularised squared-error objective for the current parameters.
     *
     * <p>The value is 0.5 * (sum_i e_i^2 + regLw * (w0^2 + sum_l w_l^2)
     * + regLf * sum_lf V_lf^2).
     */
    private double alsObjective(double[] errors) {
        double squaredError = 0.0;
        for (double e : errors) {
            squaredError += e * e;
        }
        double linearNorm = this.w0 * this.w0;
        for (int l = 0; l < this.p; ++l) {
            linearNorm += this.w.get(l) * this.w.get(l);
        }
        double factorNorm = 0.0;
        for (int l = 0; l < this.p; ++l) {
            for (int f = 0; f < this.k; ++f) {
                factorNorm += this.V.get(l, f) * this.V.get(l, f);
            }
        }
        return 0.5 * (squaredError + this.regLw * linearNorm + this.regLf * factorNorm);
    }

    /**
     * Performs the closed-form ALS step for w0 and shifts every residual by the change.
     *
     * <p>The step is w0* = -sum_i (e_i - w0) / (n + regLw). When the denominator is zero
     * (no data and no regularisation) w0 is left unchanged instead of becoming NaN.
     */
    private void alsUpdateBias(double[] errors) {
        double numerator = 0.0;
        for (double e : errors) {
            numerator += e - this.w0;
        }
        // Computed in double: float loses integer precision once n exceeds 2^24.
        double denominator = (double) this.size + this.regLw;
        if (denominator == 0.0) {
            return;
        }
        double updated = -numerator / denominator;
        double delta = updated - this.w0;
        for (int i = 0; i < errors.length; ++i) {
            errors[i] += delta;
        }
        this.w0 = updated;
    }

    /**
     * Performs the closed-form ALS step for every linear weight w_l.
     *
     * <p>The step is w_l* = -sum_i (e_i - w_l x_il) x_il / (sum_i x_il^2 + regLw), with
     * regLw counted once, as for w0. Samples with x_il = 0 contribute nothing, so only
     * column l is visited.
     */
    private void alsUpdateLinear(double[] errors, SparseVec[] columns) {
        for (int l = 0; l < this.p; ++l) {
            SparseVec column = columns[l];
            double current = this.w.get(l);
            double numerator = 0.0;
            double denominator = this.regLw;
            for (int t = 0; t < column.index.length; ++t) {
                double x = column.value[t];
                numerator += (errors[column.index[t]] - current * x) * x;
                denominator += x * x;
            }
            if (denominator == 0.0) {
                continue; // unobserved feature without regularisation: keep current value
            }
            double updated = -numerator / denominator;
            double delta = updated - current;
            for (int t = 0; t < column.index.length; ++t) {
                errors[column.index[t]] += delta * column.value[t];
            }
            this.w.set(l, updated);
        }
    }

    /**
     * Performs the closed-form ALS step for every factor V_lf.
     *
     * <p>The prediction is linear in V_lf with coefficient h_ilf = x_il * (Q_if - x_il V_lf).
     * The step is V_lf* = -sum_i (e_i - V_lf h_ilf) h_ilf / (sum_i h_ilf^2 + regLf).
     * Residuals then move by delta * h_ilf and Q by delta * x_il.
     */
    private void alsUpdateFactors(double[] errors, SparseVec[] columns) {
        for (int f = 0; f < this.k; ++f) {
            for (int l = 0; l < this.p; ++l) {
                SparseVec column = columns[l];
                double current = this.V.get(l, f);
                double numerator = 0.0;
                double denominator = this.regLf;
                for (int t = 0; t < column.index.length; ++t) {
                    int i = column.index[t];
                    double x = column.value[t];
                    double h = x * this.Q.get(i, f) - x * x * current;
                    numerator += (errors[i] - current * h) * h;
                    denominator += h * h;
                }
                if (denominator == 0.0) {
                    continue; // objective does not depend on V_lf: keep current value
                }
                double updated = -numerator / denominator;
                double delta = updated - current;
                for (int t = 0; t < column.index.length; ++t) {
                    int i = column.index[t];
                    double x = column.value[t];
                    // h must be evaluated with Q before Q(i, f) is refreshed below.
                    double h = x * this.Q.get(i, f) - x * x * current;
                    errors[i] += delta * h;
                    this.Q.set(i, f, this.Q.get(i, f) + delta * x);
                }
                this.V.set(l, f, updated);
            }
        }
    }
}

