/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.HashBasedTable;
import librec.data.AddConfiguration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.intf.GraphicRecommender;
import librec.util.Gamma;
import librec.util.Logs;
import librec.util.Strings;

/**
 * BUCM: a latent-factor graphical recommender in which every training rating
 * {@code (u, i, r)} is assigned one latent factor {@code k}.
 *
 * <p>Three Dirichlet-smoothed components are estimated from counts of the current
 * assignments {@code z}:
 * <ul>
 *   <li>{@code P(k | u)} from {@code Nuk / Nu}, smoothed by {@code alpha};</li>
 *   <li>{@code P(i | k)} from {@code Nki / Nk}, smoothed by {@code beta};</li>
 *   <li>{@code P(r | k, i)} from {@code Nkir / Nki}, smoothed by {@code gamma}.</li>
 * </ul>
 * {@link #eStep()} resamples every assignment, {@link #mStep()} re-fits the
 * hyperparameters, {@link #readoutParams()} accumulates point estimates and
 * {@link #estimateParams()} averages them over all readouts. A rating level {@code r}
 * is always the position of the rating value in {@code ratingScale}.
 *
 * <p>NOTE(review): the acronym presumably stands for the Bayesian User Community Model
 * (Barbieri et al.) — confirm against the project documentation.
 */
@AddConfiguration(before = "factors, alpha, beta, gamma")
public class BUCM extends GraphicRecommender {

    /** Initial value of every component of {@link #gamma}; read from the {@code -gamma} option. */
    private float initGamma;

    /** Dirichlet hyperparameters over rating levels, one entry per level of {@code ratingScale}. */
    private DenseVector gamma;

    /** {@code Nkir[k][i][r]}: number of training ratings of item i at level r assigned to factor k. */
    private int[][][] Nkir;

    /** Running sum over readouts of the smoothed estimate of {@code P(r | k, i)}. */
    private double[][][] PkirSum;

    /** Estimate of {@code P(r | k, i)} averaged over readouts; filled by {@link #estimateParams()}. */
    protected double[][][] Pkir;

    /**
     * Creates the recommender for one fold; all arguments are forwarded to
     * {@code GraphicRecommender}.
     */
    public BUCM(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Allocates the count and accumulator structures, initializes the hyperparameters
     * and assigns every training rating a uniformly random latent factor.
     *
     * <p>{@code alpha} and {@code beta} start at {@code initAlpha} / {@code initBeta};
     * {@code gamma} starts at the {@code -gamma} option, defaulting to {@code 1 / numLevels}.
     */
    @Override
    protected void initModel() throws Exception {
        PukSum = new DenseMatrix(numUsers, numFactors);
        PkiSum = new DenseMatrix(numFactors, numItems);
        PkirSum = new double[numFactors][numItems][numLevels];

        Nuk = new DenseMatrix(numUsers, numFactors);
        Nu = new DenseVector(numUsers);
        Nki = new DenseMatrix(numFactors, numItems);
        Nk = new DenseVector(numFactors);
        Nkir = new int[numFactors][numItems][numLevels];

        alpha = new DenseVector(numFactors);
        alpha.setAll(initAlpha);
        beta = new DenseVector(numItems);
        beta.setAll(initBeta);
        gamma = new DenseVector(numLevels);
        initGamma = algoOptions.getFloat("-gamma", 1.0f / numLevels);
        gamma.setAll(initGamma);

        z = HashBasedTable.create();
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            int r = ratingScale.indexOf(me.get());
            int t = (int) (Math.random() * numFactors);
            z.put(u, i, t);
            adjustAssignmentCounts(u, i, r, t, 1);
        }
    }

    /**
     * Adds {@code delta} to every count touched by assigning the training rating
     * (u, i, level r) to factor t: {@code Nuk[u][t]}, {@code Nu[u]}, {@code Nki[t][i]},
     * {@code Nk[t]} and {@code Nkir[t][i][r]}.
     */
    private void adjustAssignmentCounts(int u, int i, int r, int t, int delta) {
        Nuk.add(u, t, delta);
        Nu.add(u, delta);
        Nki.add(t, i, delta);
        Nk.add(t, delta);
        Nkir[t][i][r] += delta;
    }

    /**
     * Resamples the latent factor of every training rating, conditioning on all other
     * assignments.
     *
     * <p>For rating (u, i, level r) the unnormalized weight of factor k is
     * <pre>
     *   (Nuk[u][k] + alpha[k])     / (Nu[u]     + sum(alpha))
     * * (Nki[k][i] + beta[i])      / (Nk[k]     + sum(beta))
     * * (Nkir[k][i][r] + gamma[r]) / (Nki[k][i] + sum(gamma))
     * </pre>
     * where the counts exclude the rating's own current assignment.
     */
    @Override
    protected void eStep() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();
        // Every slot is overwritten for each rating, so one buffer serves the whole sweep.
        double[] p = new double[numFactors];

        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            int r = ratingScale.indexOf(me.get());
            int t = (Integer) z.get(u, i);

            // Remove the current assignment so the weights condition on the other ratings only.
            adjustAssignmentCounts(u, i, r, t, -1);

            for (int k = 0; k < numFactors; k++) {
                double v1 = (Nuk.get(u, k) + alpha.get(k)) / (Nu.get(u) + sumAlpha);
                double v2 = (Nki.get(k, i) + beta.get(i)) / (Nk.get(k) + sumBeta);
                double v3 = (Nkir[k][i][r] + gamma.get(r)) / (Nki.get(k, i) + sumGamma);
                p[k] = v1 * v2 * v3;
            }
            // Cumulative weights for inverse-CDF sampling.
            for (int k = 1; k < numFactors; k++) {
                p[k] += p[k - 1];
            }

            t = sampleFactor(p);
            z.put(u, i, t);
            adjustAssignmentCounts(u, i, r, t, 1);
        }
    }

    /**
     * Draws a factor index from cumulative, unnormalized weights.
     *
     * @param cumulative non-decreasing cumulative weights of length {@code numFactors}
     * @return the first k with {@code rand < cumulative[k]}, where rand is uniform on
     *         [0, total); a uniformly random index when the total weight is zero
     * @throws IllegalStateException if the total weight is NaN or infinite
     */
    private int sampleFactor(double[] cumulative) {
        double total = cumulative[numFactors - 1];
        if (Double.isNaN(total) || Double.isInfinite(total)) {
            throw new IllegalStateException("BUCM: non-finite total sampling weight: " + total);
        }
        if (total == 0.0) {
            // Every weight underflowed to zero; a search over the CDF would run past the
            // last factor, so fall back to a uniform draw.
            return (int) (Math.random() * numFactors);
        }
        double rand = Math.random() * total;
        for (int k = 0; k < numFactors - 1; k++) {
            if (rand < cumulative[k]) {
                return k;
            }
        }
        // Never step past the last factor.
        return numFactors - 1;
    }

    /**
     * Re-fits the Dirichlet hyperparameters {@code alpha}, {@code beta} and {@code gamma}
     * with one fixed-point step each, of the form
     * {@code a_k <- a_k * sum[digamma(n_k + a_k) - digamma(a_k)]
     *                   / sum[digamma(n + sum(a)) - digamma(sum(a))]}.
     * A component is left unchanged when its numerator is exactly zero. The
     * hyperparameter sums are taken once, before any component is updated.
     */
    @Override
    protected void mStep() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();

        for (int k = 0; k < numFactors; k++) {
            double ak = alpha.get(k);
            double numerator = 0.0;
            double denominator = 0.0;
            for (int u = 0; u < numUsers; u++) {
                numerator += Gamma.digamma(Nuk.get(u, k) + ak) - Gamma.digamma(ak);
                denominator += Gamma.digamma(Nu.get(u) + sumAlpha) - Gamma.digamma(sumAlpha);
            }
            if (numerator != 0.0) {
                alpha.set(k, ak * (numerator / denominator));
            }
        }

        for (int i = 0; i < numItems; i++) {
            double bi = beta.get(i);
            double numerator = 0.0;
            double denominator = 0.0;
            for (int k = 0; k < numFactors; k++) {
                numerator += Gamma.digamma(Nki.get(k, i) + bi) - Gamma.digamma(bi);
                denominator += Gamma.digamma(Nk.get(k) + sumBeta) - Gamma.digamma(sumBeta);
            }
            if (numerator != 0.0) {
                beta.set(i, bi * (numerator / denominator));
            }
        }

        for (int r = 0; r < numLevels; r++) {
            double gr = gamma.get(r);
            double numerator = 0.0;
            double denominator = 0.0;
            for (int i = 0; i < numItems; i++) {
                for (int k = 0; k < numFactors; k++) {
                    numerator += Gamma.digamma(Nkir[k][i][r] + gr) - Gamma.digamma(gr);
                    denominator += Gamma.digamma(Nki.get(k, i) + sumGamma) - Gamma.digamma(sumGamma);
                }
            }
            if (numerator != 0.0) {
                gamma.set(r, gr * (numerator / denominator));
            }
        }
    }

    /**
     * Refreshes the averaged estimates and stores the mean negative log-likelihood of
     * the training ratings in {@code loss}.
     *
     * @return {@code true} once more than one readout has been taken and the loss has
     *         increased since the previous check
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        loss = 0.0;
        estimateParams();
        int count = 0;
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            int r = ratingScale.indexOf(me.get());
            loss += -Math.log(ratingLevelProbability(u, i, r));
            count++;
        }
        loss /= count;

        float delta = (float) (loss - lastLoss);
        // NOTE: the value logged as "log likelihood" is the mean *negative* log-likelihood.
        Logs.debug("{}{} iter {} achieves log likelihood = {}, delta_LogLLH = {}",
                algoName, foldInfo, iter, (float) loss, delta);
        if (numStats > 1 && delta > 0.0f) {
            Logs.debug("{}{} has converged at iter {}", algoName, foldInfo, iter);
            return true;
        }
        lastLoss = loss;
        return false;
    }

    /**
     * Adds the current smoothed estimates of {@code P(k | u)}, {@code P(i | k)} and
     * {@code P(r | k, i)} to their running sums and increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();

        for (int u = 0; u < numUsers; u++) {
            for (int k = 0; k < numFactors; k++) {
                PukSum.add(u, k, (Nuk.get(u, k) + alpha.get(k)) / (Nu.get(u) + sumAlpha));
            }
        }
        for (int k = 0; k < numFactors; k++) {
            for (int i = 0; i < numItems; i++) {
                PkiSum.add(k, i, (Nki.get(k, i) + beta.get(i)) / (Nk.get(k) + sumBeta));
            }
        }
        for (int k = 0; k < numFactors; k++) {
            for (int i = 0; i < numItems; i++) {
                double denom = Nki.get(k, i) + sumGamma;
                for (int r = 0; r < numLevels; r++) {
                    PkirSum[k][i][r] += (Nkir[k][i][r] + gamma.get(r)) / denom;
                }
            }
        }
        numStats++;
    }

    /** Averages the accumulated sums over {@code numStats} readouts into {@code Puk}, {@code Pki} and {@link #Pkir}. */
    @Override
    protected void estimateParams() {
        Puk = PukSum.scale(1.0 / numStats);
        Pki = PkiSum.scale(1.0 / numStats);
        Pkir = new double[numFactors][numItems][numLevels];
        for (int k = 0; k < numFactors; k++) {
            for (int i = 0; i < numItems; i++) {
                for (int r = 0; r < numLevels; r++) {
                    Pkir[k][i][r] = PkirSum[k][i][r] / numStats;
                }
            }
        }
    }

    /**
     * Returns {@code sum_k Puk[u][k] * Pki[k][i] * Pkir[k][i][r]} under the averaged
     * estimates, the model weight of user u rating item i at level r.
     */
    private double ratingLevelProbability(int u, int i, int r) {
        double prob = 0.0;
        for (int k = 0; k < numFactors; k++) {
            prob += Puk.get(u, k) * Pki.get(k, i) * Pkir[k][i][r];
        }
        return prob;
    }

    /**
     * Returns the negative log of the model weight of the observed rating.
     *
     * <p>The rating level is looked up in {@code ratingScale}, exactly as during
     * training. Arithmetic on {@code minRate} is not used because it only matches for
     * uniformly spaced scales that start at {@code minRate}.
     *
     * @throws IllegalArgumentException if {@code ruj} is not a value of {@code ratingScale}
     */
    @Override
    protected double perplexity(int u, int j, double ruj) throws Exception {
        int r = ratingScale.indexOf(ruj);
        if (r < 0) {
            throw new IllegalArgumentException("Rating " + ruj + " is not in the rating scale " + ratingScale);
        }
        return -Math.log(ratingLevelProbability(u, j, r));
    }

    /**
     * Predicts the expected rating: the average of the rating-scale values weighted by
     * their model weights, normalized over all levels.
     */
    @Override
    protected double predict(int u, int i) throws Exception {
        double pred = 0.0;
        double probs = 0.0;
        for (int r = 0; r < numLevels; r++) {
            double rate = (Double) ratingScale.get(r);
            double prob = ratingLevelProbability(u, i, r);
            pred += prob * rate;
            probs += prob;
        }
        return pred / probs;
    }

    /**
     * Ranking score of item j for user u:
     * {@code sum_k Puk[u][k] * Pki[k][j] * sum_{r : rate(r) > globalMean} Pkir[k][j][r]}.
     */
    @Override
    protected double ranking(int u, int j) throws Exception {
        double rank = 0.0;
        for (int k = 0; k < numFactors; k++) {
            double sum = 0.0;
            for (int r = 0; r < numLevels; r++) {
                double rate = (Double) ratingScale.get(r);
                if (rate > globalMean) {
                    sum += Pkir[k][j][r];
                }
            }
            rank += Puk.get(u, k) * Pki.get(k, j) * sum;
        }
        return rank;
    }

    /** Lists {@code numFactors, initAlpha, initBeta, initGamma}, followed by the superclass description. */
    @Override
    public String toString() {
        return Strings.toString(new Object[] { numFactors, initAlpha, initBeta, initGamma }) + ", " + super.toString();
    }
}

