/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.intf.IterativeRecommender;

/**
 * Rating-frequency recommender (RF-Rec).
 *
 * <p>For a user/item pair the model builds two frequency-based estimates:
 * <ul>
 *   <li>one from the distribution of ratings the user gave in the training data;</li>
 *   <li>one from the distribution of ratings the item received in the training data.</li>
 * </ul>
 * Each estimate is a weighted mean over the rating scale. A scale value {@code v} gets weight
 * {@code count(v) + 1 + [v == round(average)]}.
 *
 * <p>The final prediction is {@code w_u * userEstimate + w_j * itemEstimate}. The per-user and
 * per-item weights are learned by stochastic gradient descent on the squared error, with L2
 * regularization ({@code regU}, {@code regI}).
 */
public class RfRec extends IterativeRecommender {
    /** Base value of the initial per-user weight. */
    private static final double INIT_USER_WEIGHT = 0.6;
    /** Base value of the initial per-item weight. */
    private static final double INIT_ITEM_WEIGHT = 0.4;
    /** Upper bound of the uniform random jitter added to each initial weight. */
    private static final double INIT_WEIGHT_JITTER = 0.01;

    /** Mean training rating of each user. */
    private DenseVector userAverages;
    /** Mean training rating of each item. */
    private DenseVector itemAverages;
    /** Rating counts: row = user, column = position of the rating value in {@code ratingScale}. */
    private DenseMatrix userRatingFrequencies;
    /** Rating counts: row = item, column = position of the rating value in {@code ratingScale}. */
    private DenseMatrix itemRatingFrequencies;
    /** Learned weight of the user-based estimate, per user. */
    private DenseVector userWeights;
    /** Learned weight of the item-based estimate, per item. */
    private DenseVector itemWeights;

    public RfRec(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Prepares the model from the training data.
     *
     * <p>Computes per-user and per-item average ratings, counts how often each rating-scale
     * value occurs per user and per item, and draws the initial user and item weights.
     *
     * @throws IllegalStateException if a training rating is not a member of {@code ratingScale}
     */
    @Override
    protected void initModel() throws Exception {
        userAverages = new DenseVector(numUsers);
        userWeights = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; u++) {
            userAverages.set(u, trainMatrix.row(u).mean());
            userWeights.set(u, INIT_USER_WEIGHT + Math.random() * INIT_WEIGHT_JITTER);
        }

        itemAverages = new DenseVector(numItems);
        itemWeights = new DenseVector(numItems);
        for (int j = 0; j < numItems; j++) {
            itemAverages.set(j, trainMatrix.column(j).mean());
            itemWeights.set(j, INIT_ITEM_WEIGHT + Math.random() * INIT_WEIGHT_JITTER);
        }

        // Columns are indexed by the rating's position in ratingScale, not by (rating - 1).
        // Scales that do not start at 1 or are not integral (0/1, 0.5 steps, ...) would
        // otherwise overflow the matrix or merge distinct values into one bucket.
        int levels = ratingScale.size();
        userRatingFrequencies = new DenseMatrix(numUsers, levels);
        itemRatingFrequencies = new DenseMatrix(numItems, levels);
        for (MatrixEntry me : trainMatrix) {
            double ruj = me.get();
            int level = ratingScale.indexOf(ruj);
            if (level < 0) {
                throw new IllegalStateException("Rating " + ruj + " of user " + me.row()
                        + " on item " + me.column() + " is not in the rating scale " + ratingScale);
            }
            userRatingFrequencies.add(me.row(), level, 1.0);
            itemRatingFrequencies.add(me.column(), level, 1.0);
        }
        // The weight vectors must not be re-created here: that would reset every weight to 0
        // and discard the initial values drawn above.
    }

    /**
     * Learns the user and item weights by stochastic gradient descent over the training ratings.
     *
     * <p>Runs for at most {@code numIters} epochs, stopping early when {@code isConverged}
     * returns true. {@code loss} holds half the sum of squared errors plus the L2 penalties.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            loss = 0.0;
            for (MatrixEntry me : trainMatrix) {
                int u = me.row();
                int j = me.column();
                double wu = userWeights.get(u);
                double wj = itemWeights.get(j);
                double euj = me.get() - predict(u, j);
                loss += euj * euj + regU * wu * wu + regI * wj * wj;

                // Fallback predictions do not use the weights, so such pairs give no gradient.
                if (!hasProfile(userRatingFrequencies, userAverages, u)
                        || !hasProfile(itemRatingFrequencies, itemAverages, j)) {
                    continue;
                }
                // The derivative of 0.5 * (r - w_u * p_u - w_j * p_j)^2 with respect to w_u
                // is -e * p_u (likewise for w_j). The error term must therefore be scaled by
                // the matching frequency estimate.
                double pu = frequencyEstimate(userRatingFrequencies, u, userAverages.get(u));
                double pj = frequencyEstimate(itemRatingFrequencies, j, itemAverages.get(j));
                userWeights.set(u, wu + lRate * (euj * pu - regU * wu));
                itemWeights.set(j, wj + lRate * (euj * pj - regI * wj));
            }
            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Returns 1 if {@code avg} rounds (via {@link Math#round(double)}) to exactly
     * {@code rating}, otherwise 0.
     */
    private static int isAvgRating(double avg, double rating) {
        return Math.round(avg) == rating ? 1 : 0;
    }

    /**
     * Tells whether a row has usable training data.
     *
     * @param frequencies user or item rating-count matrix
     * @param averages    matching user or item average vector
     * @param index       row (user or item) index
     * @return true when the row has at least one counted rating and a positive average
     */
    private static boolean hasProfile(DenseMatrix frequencies, DenseVector averages, int index) {
        return frequencies.row(index).sum() > 0.0 && averages.get(index) > 0.0;
    }

    /**
     * Computes the frequency-weighted mean over the rating scale for one user or item row.
     *
     * <p>Each scale value {@code v} gets weight {@code count(v) + 1 + isAvgRating(average, v)}.
     * The {@code +1} gives values never seen in this row a weight of 1 instead of 0.
     *
     * @param frequencies user or item rating-count matrix
     * @param index       row (user or item) index
     * @param average     that row's average rating
     * @return the weighted mean rating value
     */
    private double frequencyEstimate(DenseMatrix frequencies, int index, double average) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (int level = 0; level < ratingScale.size(); level++) {
            double value = (Double) ratingScale.get(level);
            double weight = frequencies.get(index, level) + 1.0 + isAvgRating(average, value);
            numerator += weight * value;
            denominator += weight;
        }
        return numerator / denominator;
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j}.
     *
     * <p>When both the user and the item have usable training data (see {@code hasProfile}),
     * returns {@code w_u * userEstimate + w_j * itemEstimate}. Otherwise it falls back, in order:
     * <ul>
     *   <li>to the user's average when only the user is usable;</li>
     *   <li>to the item's average when only the item is usable;</li>
     *   <li>to {@code globalMean} when neither is usable.</li>
     * </ul>
     */
    @Override
    protected double predict(int u, int j) {
        boolean userKnown = hasProfile(userRatingFrequencies, userAverages, u);
        boolean itemKnown = hasProfile(itemRatingFrequencies, itemAverages, j);
        if (userKnown && itemKnown) {
            double userEstimate = frequencyEstimate(userRatingFrequencies, u, userAverages.get(u));
            double itemEstimate = frequencyEstimate(itemRatingFrequencies, j, itemAverages.get(j));
            return userWeights.get(u) * userEstimate + itemWeights.get(j) * itemEstimate;
        }
        if (userKnown) {
            return userAverages.get(u);
        }
        // Use the item average only when the item actually has training ratings. An average
        // over zero ratings is meaningless: it may be 0 or NaN, depending on
        // SparseVector.mean(), which is not visible here.
        return itemKnown ? itemAverages.get(j) : globalMean;
    }
}

