/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import librec.data.AddConfiguration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Gaussian;
import librec.util.Logs;
import librec.util.Randoms;
import librec.util.Stats;
import librec.util.Strings;

@AddConfiguration(before="factors, q, b")
/**
 * Gaussian pLSA rating predictor trained with (tempered) EM.
 *
 * <p>Each user {@code u} has a distribution {@code Puk(u, .)} over {@code numFactors} latent
 * classes. Each (item, class) pair owns a Gaussian with mean {@code Mu(i, z)} and standard
 * deviation {@code Sigma(i, z)} over user-normalized ratings. {@link #eStep()} computes class
 * posteriors {@code Q(u, i, z)}; {@link #mStep()} re-estimates {@code Puk}, {@code Mu} and
 * {@code Sigma} from them.
 *
 * <p>Options:
 * <ul>
 *   <li>{@code -q}: pseudo-count weight of the global mean/deviation when estimating each user's
 *       rating mean and deviation.</li>
 *   <li>{@code -b}: exponent applied to the E-step terms (default 1).</li>
 * </ul>
 *
 * <p>Note: {@link #initModel()} overwrites every entry of {@code trainMatrix} in place with its
 * user-normalized value {@code (r - mu_u) / sigma_u}.
 */
public class GPLSA
extends GraphicRecommender {
    /**
     * Lower bound for every component standard deviation, on the user-normalized rating scale.
     *
     * <p>Without it, a component whose responsibility covers a single distinct rating (e.g. an
     * item with one training rating) collapses to {@code Sigma = 0}. This degenerates the Gaussian
     * density evaluated in the E-step.
     */
    private static final double MIN_SIGMA = 1e-3;

    /** Posterior over latent classes for every training entry: (user, item) -> {class -> prob}. */
    private Table<Integer, Integer, Map<Integer, Double>> Q;
    /** Per (item, class) Gaussian mean of normalized ratings. */
    private DenseMatrix Mu;
    /** Per (item, class) Gaussian standard deviation of normalized ratings. */
    private DenseMatrix Sigma;
    /** Per-user rating mean used for normalization and de-normalization. */
    private DenseVector mu;
    /** Per-user rating deviation used for normalization and de-normalization. */
    private DenseVector sigma;
    /** Pseudo-count weight of the global statistics in {@link #mu}/{@link #sigma}. */
    private float q;
    /** Exponent applied to the E-step terms. */
    private float b;
    /** Validation RMSE from the previous {@link #isConverged(int)} call. */
    private double preRMSE;

    public GPLSA(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Initializes user class distributions, per-user normalization statistics, the normalized
     * training matrix, the posterior table and the per-item Gaussian components.
     */
    @Override
    protected void initModel() throws Exception {
        // random initial P(z|u)
        this.Puk = new DenseMatrix(numUsers, numFactors);
        for (int u = 0; u < numUsers; u++) {
            double[] probs = Randoms.randProbs(numFactors);
            for (int z = 0; z < numFactors; z++) {
                this.Puk.set(u, z, probs[z]);
            }
        }

        double mean = this.globalMean;
        double sd = Stats.sd(this.trainMatrix.getData(), mean);
        this.q = algoOptions.getFloat("-q");
        this.b = algoOptions.getFloat("-b", 1.0f);

        // per-user mean/deviation, shrunk towards the global statistics with q pseudo-counts
        this.mu = new DenseVector(numUsers);
        this.sigma = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; u++) {
            SparseVector ru = this.trainMatrix.row(u);
            int numRatings = ru.size();
            if (numRatings < 1) {
                // No training ratings: with sigma_u left at 0, predict() returns the global mean
                // instead of 0.
                this.mu.set(u, mean);
                continue;
            }
            double muU = (ru.sum() + this.q * mean) / (numRatings + this.q);
            this.mu.set(u, muU);

            double squares = 0.0;
            for (VectorEntry ve : ru) {
                squares += Math.pow(ve.get() - muU, 2.0);
            }
            squares += this.q * Math.pow(sd, 2.0);
            this.sigma.set(u, Math.sqrt(squares / (numRatings + this.q)));
        }

        // normalize training ratings in place and allocate one posterior map per training entry
        this.Q = HashBasedTable.create();
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            double sigmaU = this.sigma.get(u);
            // With q >= 0, sigma_u == 0 only if every rating of u equals mu_u, so the z-score is 0;
            // this avoids a 0/0 division.
            double normalized = sigmaU > 0.0 ? (me.get() - this.mu.get(u)) / sigmaU : 0.0;
            me.set(normalized);
            this.Q.put(u, me.column(), new HashMap<Integer, Double>());
        }

        // per-item Gaussian components, initialized around the item's normalized rating statistics
        this.Mu = new DenseMatrix(numItems, numFactors);
        this.Sigma = new DenseMatrix(numItems, numFactors);
        for (int i = 0; i < numItems; i++) {
            SparseVector ci = this.trainMatrix.column(i);
            int numRatings = ci.size();
            if (numRatings < 1) {
                continue;
            }
            double muI = ci.mean();
            double squares = 0.0;
            for (VectorEntry ve : ci) {
                squares += Math.pow(ve.get() - muI, 2.0);
            }
            double sdI = Math.sqrt(squares / numRatings);
            for (int z = 0; z < numFactors; z++) {
                // small random perturbation breaks the symmetry between classes
                this.Mu.set(i, z, muI + smallValue * Math.random());
                this.Sigma.set(i, z, Math.max(MIN_SIGMA, sdI + smallValue * Math.random()));
            }
        }
    }

    /**
     * E-step: for every training entry (u, i) with normalized rating r, sets
     * {@code Q(u, i, z) = w_z / sum_z' w_z'} where
     * {@code w_z = (Puk(u, z) * N(r; Mu(i, z), Sigma(i, z)))^b}.
     * If the normalizer is not positive (or NaN), all posteriors of the entry are set to 0.
     */
    @Override
    protected void eStep() {
        double[] weights = new double[numFactors];
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double r = me.get();

            double total = 0.0;
            for (int z = 0; z < numFactors; z++) {
                double pdf = Gaussian.pdf(r, this.Mu.get(i, z), this.Sigma.get(i, z));
                weights[z] = Math.pow(this.Puk.get(u, z) * pdf, this.b);
                total += weights[z];
            }

            Map<Integer, Double> factorProbs = this.Q.get(u, i);
            for (int z = 0; z < numFactors; z++) {
                factorProbs.put(z, total > 0.0 ? weights[z] / total : 0.0);
            }
        }
    }

    /**
     * M-step: re-estimates the model from the posteriors {@code Q}.
     *
     * <ul>
     *   <li>{@code Puk(u, .)}: normalized sum of the user's posteriors.</li>
     *   <li>{@code Mu(i, z)}, {@code Sigma(i, z)}: posterior-weighted mean and deviation of the
     *       item's normalized ratings.</li>
     * </ul>
     *
     * <p>Parameters whose posterior mass is zero keep their previous values. Deviations are
     * floored at {@link #MIN_SIGMA}.
     */
    @Override
    protected void mStep() {
        for (int u = 0; u < numUsers; u++) {
            List<Integer> items = this.trainMatrix.getColumns(u);
            if (items.isEmpty()) {
                continue;
            }
            double[] mass = new double[numFactors];
            double total = 0.0;
            for (int z = 0; z < numFactors; z++) {
                for (int i : items) {
                    mass[z] += this.Q.get(u, i).get(z);
                }
                total += mass[z];
            }
            // every posterior of this user is zero: keep the previous distribution instead of NaN
            if (!(total > 0.0)) {
                continue;
            }
            for (int z = 0; z < numFactors; z++) {
                this.Puk.set(u, z, mass[z] / total);
            }
        }

        for (int i = 0; i < numItems; i++) {
            List<Integer> users = this.trainMatrix.getRows(i);
            if (users.isEmpty()) {
                continue;
            }
            for (int z = 0; z < numFactors; z++) {
                double weightedSum = 0.0;
                double mass = 0.0;
                for (int u : users) {
                    double prob = this.Q.get(u, i).get(z);
                    weightedSum += this.trainMatrix.get(u, i) * prob;
                    mass += prob;
                }
                // No responsibility for this component: keep the previous Mu/Sigma. Resetting them
                // to 0 would leave a zero-deviation component that can never recover.
                if (!(mass > 0.0)) {
                    continue;
                }
                double muIz = weightedSum / mass;

                double weightedSquares = 0.0;
                for (int u : users) {
                    double prob = this.Q.get(u, i).get(z);
                    weightedSquares += Math.pow(this.trainMatrix.get(u, i) - muIz, 2.0) * prob;
                }
                this.Mu.set(i, z, muIz);
                this.Sigma.set(i, z, Math.max(MIN_SIGMA, Math.sqrt(weightedSquares / mass)));
            }
        }
    }

    /**
     * Predicts {@code mu_u + sigma_u * sum_z Puk(u, z) * Mu(i, z)}, i.e. de-normalizes the expected
     * normalized rating. Items without training ratings have {@code Mu(i, .) = 0} and therefore
     * get {@code mu_u}.
     */
    @Override
    protected double predict(int u, int i) throws Exception {
        double expected = 0.0;
        for (int z = 0; z < numFactors; z++) {
            expected += this.Puk.get(u, z) * this.Mu.get(i, z);
        }
        return this.mu.get(u) + this.sigma.get(u) * expected;
    }

    /**
     * Evaluates RMSE on {@code validationMatrix} (skipping NaN predictions). Reports convergence
     * once {@code numStats > 1} and the RMSE increased relative to the previous call.
     * Returns {@code false} when there is no validation matrix or no usable prediction.
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        if (this.validationMatrix == null) {
            return false;
        }
        int numCount = 0;
        double sum = 0.0;
        for (MatrixEntry me : this.validationMatrix) {
            double rate = me.get();
            double pred = this.predict(me.row(), me.column(), true);
            if (Double.isNaN(pred)) {
                continue;
            }
            double err = rate - pred;
            sum += err * err;
            ++numCount;
        }
        // RMSE is undefined without predictions; do not let a NaN overwrite preRMSE
        if (numCount == 0) {
            return false;
        }
        double RMSE = Math.sqrt(sum / numCount);
        double delta = RMSE - this.preRMSE;
        if (verbose) {
            Logs.debug("{}{} iter {} achieves RMSE = {}, delta_RMSE = {}", this.algoName, this.foldInfo, iter, Float.valueOf((float)RMSE), Float.valueOf((float)delta));
        }
        if (this.numStats > 1 && delta > 0.0) {
            return true;
        }
        this.preRMSE = RMSE;
        ++this.numStats;
        return false;
    }

    /** Returns the configuration summary: factors, q, b, followed by the superclass description. */
    @Override
    public String toString() {
        return String.valueOf(Strings.toString(new Object[]{numFactors, Float.valueOf(this.q), Float.valueOf(this.b)})) + ", " + super.toString();
    }
}

