/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.intf.IterativeRecommender;
import librec.util.Randoms;
import librec.util.Strings;

@Configuration(value="factors, iters")
/**
 * Bayesian Probabilistic Matrix Factorization (BPMF), trained by Gibbs sampling.
 * <p>
 * The user factors {@code P} and item factors {@code Q} each have a Gaussian prior.
 * That prior's mean and precision matrix are in turn resampled every iteration from
 * a Gaussian-Wishart posterior. Each outer iteration:
 * <ol>
 * <li>resamples both sides' hyperparameters;</li>
 * <li>runs two Gibbs sweeps that redraw every rated user row of {@code P} and every
 * rated item row of {@code Q} from their conditional Gaussian posteriors.</li>
 * </ol>
 * Predictions use only the most recent sample of {@code P} and {@code Q}.
 */
public class BPMF
extends IterativeRecommender {

    /** Precision of the Gaussian noise assumed on every observed rating. */
    private static final double BETA = 2.0;

    /** Number of Gibbs sweeps over the user and item factors per outer iteration. */
    private static final int GIBBS_SWEEPS = 2;

    /** Pseudo-count attached to the prior mean {@code mu0} of each side's factor mean. */
    private static final int PRIOR_B0 = 2;

    /**
     * Creates a BPMF recommender for one cross-validation fold.
     * <p>
     * The inherited learning rate is set to {@code -1.0}; Gibbs sampling takes no
     * step size, so this presumably marks it as unused (confirm against the base class).
     *
     * @param trainMatrix training ratings (users x items)
     * @param testMatrix  test ratings (users x items)
     * @param fold        index of the cross-validation fold
     */
    public BPMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.lRate = -1.0;
    }

    /**
     * Trains the model by Gibbs sampling.
     * <p>
     * Initialisation:
     * <ul>
     * <li>{@code P} and {@code Q} are filled via {@code init(0.0, 1.0)};</li>
     * <li>each side's factor mean starts at the column means of its initial factors;</li>
     * <li>each side's precision starts at the inverse covariance of its initial factors.</li>
     * </ul>
     * Each iteration then resamples the user and item hyperparameters, runs
     * {@code GIBBS_SWEEPS} sweeps (all users, then all items), and stores half the
     * training sum of squared errors in {@code loss}. The loop stops after
     * {@code numIters} iterations or as soon as {@code isConverged} returns true.
     *
     * @throws Exception propagated from {@code isConverged}
     */
    @Override
    protected void buildModel() throws Exception {
        this.P = new DenseMatrix(numUsers, numFactors);
        this.Q = new DenseMatrix(numItems, numFactors);
        this.P.init(0.0, 1.0);
        this.Q.init(0.0, 1.0);

        // Hyperprior on each side: mu0 = 0 with pseudo-count PRIOR_B0,
        // Wishart scale = identity, Wishart degrees of freedom = numFactors.
        NormalWishart userPrior = new NormalWishart(new DenseVector(numFactors), PRIOR_B0, numFactors,
                DenseMatrix.eye(numFactors), columnMeans(this.P), this.P.cov().inv());
        NormalWishart itemPrior = new NormalWishart(new DenseVector(numFactors), PRIOR_B0, numFactors,
                DenseMatrix.eye(numFactors), columnMeans(this.Q), this.Q.cov().inv());

        // Scratch buffer for standard-normal draws, reused by every sampling step.
        DenseVector normalRdn = new DenseVector(numFactors);

        for (int iter = 1; iter <= numIters; iter++) {
            sampleHyperparameters(userPrior, this.P, normalRdn);
            sampleHyperparameters(itemPrior, this.Q, normalRdn);

            for (int sweep = 0; sweep < GIBBS_SWEEPS; sweep++) {
                for (int u = 0; u < numUsers; u++) {
                    sampleFactorRow(this.P, u, this.trainMatrix.row(u), this.Q, userPrior, normalRdn);
                }
                for (int j = 0; j < numItems; j++) {
                    sampleFactorRow(this.Q, j, this.trainMatrix.column(j), this.P, itemPrior, normalRdn);
                }
            }

            this.loss = trainingLoss();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Resamples {@code prior.alpha} (precision) and then {@code prior.mu} (mean) from
     * their Gaussian-Wishart posterior, given the current rows of {@code factors}.
     * <p>
     * If the Wishart draw fails (the scale has no Cholesky factor), the previous
     * precision is kept. If the mean's covariance has no Cholesky factor, the previous
     * mean is kept.
     *
     * @param prior     hyperprior and current hyperparameter sample; updated in place
     * @param factors   current latent factors of this side, one row per user/item
     * @param normalRdn scratch vector of length {@code numFactors}
     */
    private void sampleHyperparameters(NormalWishart prior, DenseMatrix factors, DenseVector normalRdn) {
        int n = factors.numRows();
        DenseVector xBar = columnMeans(factors);
        DenseMatrix sBar = factors.cov();

        // Posterior Wishart scale:
        //   inv( inv(WI0) + n*S + (b0*n/(b0+n)) (mu0-xBar)(mu0-xBar)^T ).
        // It is symmetrised to remove round-off asymmetry before the Cholesky step.
        // The product b0*n is computed in double so it cannot overflow int.
        DenseVector diff = prior.mu0.minus(xBar);
        double shrink = (double) prior.b0 * n / (prior.b0 + n);
        DenseMatrix wiPost = prior.wi0Inv.add(sBar.scale(n)).add(diff.outer(diff).scale(shrink)).inv();
        wiPost = wiPost.add(wiPost.transpose()).scale(0.5);

        DenseMatrix alpha = wishart(wiPost, prior.df0 + n);
        if (alpha != null) {
            prior.alpha = alpha;
        }

        // Posterior of the mean: N(muPost, inv(bPost * alpha)).
        double bPost = prior.b0 + n;
        DenseVector muPost = prior.mu0.scale(prior.b0).add(xBar.scale(n)).scale(1.0 / bPost);
        DenseMatrix lam = prior.alpha.scale(bPost).inv().cholesky();
        if (lam != null) {
            prior.mu = lam.transpose().mult(fillStandardNormal(normalRdn)).add(muPost);
        }
    }

    /**
     * Redraws one row of {@code target} from its conditional Gaussian posterior,
     * given the ratings of that user/item and the fixed factors of the other side.
     * <p>
     * The row is left unchanged when it has no ratings, or when its posterior
     * covariance has no Cholesky factor.
     *
     * @param target    factor matrix being sampled ({@code P} or {@code Q})
     * @param row       row of {@code target} to redraw
     * @param ratings   ratings of that user/item, indexed by rows of {@code other}
     * @param other     the opposite side's factor matrix, held fixed
     * @param prior     this side's current hyperparameters
     * @param normalRdn scratch vector of length {@code numFactors}
     */
    private void sampleFactorRow(DenseMatrix target, int row, SparseVector ratings, DenseMatrix other,
            NormalWishart prior, DenseVector normalRdn) {
        int count = ratings.getCount();
        if (count == 0) {
            return;
        }

        // MM stacks the factors of every rated counterpart; rr holds the ratings
        // centred on the global mean, in the same order.
        DenseMatrix MM = new DenseMatrix(count, numFactors);
        DenseVector rr = new DenseVector(count);
        int idx = 0;
        for (int k : ratings.getIndex()) {
            rr.set(idx, ratings.get(k) - this.globalMean);
            for (int f = 0; f < numFactors; f++) {
                MM.set(idx, f, other.get(k, f));
            }
            idx++;
        }

        // Posterior:
        //   covar = inv(alpha + BETA * MM^T MM)
        //   mean  = covar * (BETA * MM^T rr + alpha * mu)
        DenseMatrix MMt = MM.transpose();
        DenseMatrix covar = prior.alpha.add(MMt.mult(MM).scale(BETA)).inv();
        DenseVector mean = covar.mult(MMt.mult(rr).scale(BETA).add(prior.alpha.mult(prior.mu)));

        DenseMatrix lam = covar.cholesky();
        if (lam == null) {
            return;
        }
        DenseVector sample = lam.transpose().mult(fillStandardNormal(normalRdn)).add(mean);
        for (int f = 0; f < numFactors; f++) {
            target.set(row, f, sample.get(f));
        }
    }

    /**
     * Overwrites the first {@code numFactors} entries of {@code v} with independent
     * N(0, 1) draws.
     *
     * @param v vector of length at least {@code numFactors}; modified in place
     * @return {@code v}
     */
    private DenseVector fillStandardNormal(DenseVector v) {
        for (int f = 0; f < numFactors; f++) {
            v.set(f, Randoms.gaussian(0.0, 1.0));
        }
        return v;
    }

    /**
     * Returns a new vector with the mean of each of the first {@code numFactors}
     * columns of {@code factors}.
     *
     * @param factors factor matrix with at least {@code numFactors} columns
     * @return the column means
     */
    private DenseVector columnMeans(DenseMatrix factors) {
        DenseVector means = new DenseVector(numFactors);
        for (int f = 0; f < numFactors; f++) {
            means.set(f, factors.columnMean(f));
        }
        return means;
    }

    /**
     * Returns half the sum of squared errors of {@link #predict} over all training
     * ratings.
     *
     * @return the training loss
     */
    private double trainingLoss() {
        double sse = 0.0;
        for (MatrixEntry me : this.trainMatrix) {
            double err = me.get() - this.predict(me.row(), me.column());
            sse += err * err;
        }
        return sse * 0.5;
    }

    /**
     * Draws one sample from the Wishart distribution W(scale, df) using the Bartlett
     * decomposition.
     * <p>
     * Let {@code A} be the factor returned by {@code scale.cholesky()}. This code
     * relies on {@code A^T A == scale}, i.e. an upper-triangular factor; that is
     * inferred from its use in this class and should be confirmed.
     * <p>
     * Let {@code T} be upper triangular with:
     * <ul>
     * <li>{@code T[i][i]^2 ~ chi^2(df - i)} for 0-based {@code i};</li>
     * <li>N(0, 1) entries above the diagonal.</li>
     * </ul>
     * Then {@code (T A)^T (T A)} is W(scale, df)-distributed. Computing it in that form
     * also makes the result exactly symmetric.
     *
     * @param scale symmetric positive-definite scale matrix (p x p)
     * @param df    degrees of freedom; should exceed {@code p - 1} so every chi-square
     *              shape is positive
     * @return the sample, or {@code null} if {@code scale} has no Cholesky factor
     */
    protected DenseMatrix wishart(DenseMatrix scale, double df) {
        DenseMatrix A = scale.cholesky();
        if (A == null) {
            return null;
        }
        int p = scale.numRows();

        // Only the strictly upper triangle of z enters T. The full p x p block is
        // still drawn, so the sequence of random-number calls stays the same.
        DenseMatrix z = new DenseMatrix(p, p);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                z.set(i, j, Randoms.gaussian(0.0, 1.0));
            }
        }

        // chi^2(k) == Gamma(shape k/2, scale 2); this assumes Randoms.gamma(shape, scale).
        // The Bartlett degrees of freedom for 0-based row i are df - i, i.e.
        // chi^2(df), chi^2(df - 1), ..., chi^2(df - p + 1).
        DenseMatrix T = new DenseMatrix(p, p);
        for (int i = 0; i < p; i++) {
            T.set(i, i, Math.sqrt(Randoms.gamma((df - i) / 2.0, 2.0)));
        }
        for (int i = 0; i < p; i++) {
            for (int j = i + 1; j < p; j++) {
                T.set(i, j, z.get(i, j));
            }
        }

        DenseMatrix TA = T.mult(A);
        return TA.transpose().mult(TA);
    }

    /**
     * Predicts a rating as the global mean plus the combination of row {@code u} of
     * {@code P} with row {@code j} of {@code Q}, computed by
     * {@code DenseMatrix.rowMult}. Presumably that is the inner product of the two rows.
     *
     * @param u user index
     * @param j item index
     * @return the predicted rating, computed from the latest Gibbs sample
     */
    @Override
    protected double predict(int u, int j) {
        return this.globalMean + DenseMatrix.rowMult(this.P, u, this.Q, j);
    }

    /**
     * Returns the configuration (number of factors, number of iterations) as formatted
     * by {@code Strings.toString}.
     */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{numFactors, numIters});
    }

    /**
     * Gaussian-Wishart hyperprior for one side (users or items), together with the
     * current sample of that side's factor mean {@code mu} and precision {@code alpha}.
     */
    private static final class NormalWishart {
        /** Prior mean of {@code mu}. */
        final DenseVector mu0;
        /** Pseudo-count attached to {@code mu0}. */
        final int b0;
        /** Prior degrees of freedom of the Wishart distribution. */
        final int df0;
        /** Inverse of the Wishart prior scale matrix, cached because it never changes. */
        final DenseMatrix wi0Inv;
        /** Current sample of the factor mean. */
        DenseVector mu;
        /** Current sample of the factor precision matrix. */
        DenseMatrix alpha;

        NormalWishart(DenseVector mu0, int b0, int df0, DenseMatrix wi0, DenseVector mu, DenseMatrix alpha) {
            this.mu0 = mu0;
            this.b0 = b0;
            this.df0 = df0;
            this.wi0Inv = wi0.inv();
            this.mu = mu;
            this.alpha = alpha;
        }
    }
}

