/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.cf;

import carskit.data.setting.Configuration;
import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseVector;

/**
 * Bayesian Probabilistic Matrix Factorization (BPMF) trained by Gibbs sampling.
 *
 * <p>A rating is predicted as {@code globalMean + P[u] . Q[j]}. In the factor conditionals,
 * observed ratings are treated as Gaussian around that prediction with precision {@link #BETA}.
 * The user rows of {@code P} share a Gaussian prior N(muU, alphaU^-1), and the item rows of
 * {@code Q} share N(muM, alphaM^-1). Every iteration resamples (mu, alpha) for each side from its
 * Gaussian-Wishart hyperposterior, then runs {@link #GIBBS_SWEEPS} sweeps over the factor rows.
 * Variable naming mirrors the Salakhutdinov &amp; Mnih reference code, which is presumably the
 * source of this port.
 */
@Configuration(value = "factors, iters")
public class BPMF extends IterativeRecommender {

    /** Observation precision (beta) used in the per-user / per-item factor conditionals. */
    private static final double BETA = 2.0;

    /** Hyperprior strength b0: the hyper-mean has prior precision b0 * alpha around mu0. */
    private static final int B0 = 2;

    /** Number of Gibbs sweeps over P and Q per outer iteration. */
    private static final int GIBBS_SWEEPS = 2;

    /**
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold identifier forwarded to the superclass
     */
    public BPMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        // The sampler below never reads lRate; -1.0 presumably marks it "not applicable".
        this.lRate = -1.0;
        this.algoName = "BPMF";
    }

    /**
     * Runs up to {@code numIters} Gibbs iterations.
     *
     * <p>Each iteration does three things. It samples the user hyperparameters, then the item
     * hyperparameters. It then runs the factor sweeps and records half the training SSE in
     * {@code loss}. It stops early when {@link #isConverged(int)} says so.
     */
    @Override
    protected void buildModel() throws Exception {
        // Gaussian-Wishart hyperprior, shared by users and items:
        // Wishart scale WI0 = I, Wishart degrees of freedom df0 = numFactors, prior mean mu0 = 0.
        DenseMatrix wi0 = DenseMatrix.eye(numFactors);
        DenseVector mu0 = new DenseVector(numFactors);
        int df0 = numFactors;

        // Random initialisation of the factor matrices via DenseMatrix.init(0.0, 1.0).
        this.P = new DenseMatrix(this.numUsers, numFactors);
        this.Q = new DenseMatrix(this.numItems, numFactors);
        this.P.init(0.0, 1.0);
        this.Q.init(0.0, 1.0);

        // Start the hyperparameters at the empirical mean / inverse covariance of the factors.
        DenseVector muU = columnMeans(this.P);
        DenseVector muM = columnMeans(this.Q);
        DenseMatrix alphaU = this.P.cov().inv();
        DenseMatrix alphaM = this.Q.cov().inv();

        int numU = this.numUsers;
        int numM = this.numItems;
        for (int iter = 1; iter <= numIters; ++iter) {
            // User hyperparameters given P: precision first, then the mean conditioned on it.
            DenseVector xBar = columnMeans(this.P);
            alphaU = samplePrecision(this.P, xBar, numU, wi0, mu0, df0, alphaU);
            muU = sampleMean(xBar, alphaU, mu0, numU, muU);

            // Item hyperparameters given Q.
            xBar = columnMeans(this.Q);
            alphaM = samplePrecision(this.Q, xBar, numM, wi0, mu0, df0, alphaM);
            muM = sampleMean(xBar, alphaM, mu0, numM, muM);

            // Alternate user and item factor updates given the hyperparameters.
            for (int sweep = 0; sweep < GIBBS_SWEEPS; ++sweep) {
                sampleFactors(this.P, this.Q, numU, true, alphaU, muU);
                sampleFactors(this.Q, this.P, numM, false, alphaM, muM);
            }

            this.loss = trainingLoss();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /** Returns the means of the first {@code numFactors} columns of {@code factors}. */
    private DenseVector columnMeans(DenseMatrix factors) {
        DenseVector means = new DenseVector(numFactors);
        for (int f = 0; f < numFactors; ++f) {
            means.set(f, factors.columnMean(f));
        }
        return means;
    }

    /** Draws a vector of {@code numFactors} independent N(0, 1) variates. */
    private DenseVector standardNormal() {
        DenseVector draw = new DenseVector(numFactors);
        for (int f = 0; f < numFactors; ++f) {
            draw.set(f, Randoms.gaussian(0.0, 1.0));
        }
        return draw;
    }

    /**
     * Samples the precision matrix of one side's hyperposterior.
     *
     * @param factors factor matrix of that side (P or Q)
     * @param xBar    column means of {@code factors}
     * @param n       number of rows (users or items) on that side
     * @param wi0     prior Wishart scale matrix
     * @param mu0     prior mean
     * @param df0     prior Wishart degrees of freedom
     * @param current value to keep when {@link #wishart} returns null (Cholesky failure)
     * @return the new precision matrix, or {@code current}
     */
    private DenseMatrix samplePrecision(DenseMatrix factors, DenseVector xBar, int n,
            DenseMatrix wi0, DenseVector mu0, int df0, DenseMatrix current) {
        DenseVector diff = mu0.minus(xBar);
        // Widen before multiplying so n * B0 cannot overflow int.
        DenseMatrix shrinkage = diff.outer(diff).scale((double) n * B0 / (B0 + n));
        DenseMatrix wiPost = wi0.inv().add(factors.cov().scale(n)).add(shrinkage).inv();
        // Symmetrise to remove round-off asymmetry before the Cholesky inside wishart().
        wiPost = wiPost.add(wiPost.transpose()).scale(0.5);
        DenseMatrix draw = wishart(wiPost, (double) (df0 + n));
        return draw != null ? draw : current;
    }

    /**
     * Samples one side's hyper-mean from N(muPost, ((B0 + n) * alpha)^-1).
     *
     * @param xBar    column means of that side's factor matrix
     * @param alpha   precision matrix just sampled for that side
     * @param mu0     prior mean
     * @param n       number of rows (users or items) on that side
     * @param current value to keep when the Cholesky factorisation returns null
     * @return the new mean vector, or {@code current}
     */
    private DenseVector sampleMean(DenseVector xBar, DenseMatrix alpha, DenseVector mu0, int n,
            DenseVector current) {
        DenseVector muPost = mu0.scale(B0).add(xBar.scale(n)).scale(1.0 / (B0 + n));
        DenseMatrix lam = alpha.scale(B0 + n).inv().cholesky();
        if (lam == null) {
            return current;
        }
        return lam.transpose().mult(standardNormal()).add(muPost);
    }

    /**
     * Resamples every row of {@code target} that has at least one training rating. Each row is
     * drawn from its Gaussian conditional, given {@code other} and the side's hyperparameters.
     * Rows whose posterior covariance fails the Cholesky check keep their current value.
     *
     * @param target factor matrix being resampled (P for users, Q for items)
     * @param other  the opposite factor matrix, read only
     * @param count  number of rows of {@code target} to visit
     * @param byRow  true to read ratings from {@code train.row(i)} (users), false to read them
     *               from {@code train.column(i)} (items)
     * @param alpha  prior precision of the rows of {@code target}
     * @param mu     prior mean of the rows of {@code target}
     */
    private void sampleFactors(DenseMatrix target, DenseMatrix other, int count, boolean byRow,
            DenseMatrix alpha, DenseVector mu) {
        // alpha * mu does not depend on the row being sampled, so compute it once per sweep.
        DenseVector priorTerm = alpha.mult(mu);
        for (int i = 0; i < count; ++i) {
            SparseVector ratings = byRow ? this.train.row(i) : this.train.column(i);
            int n = ratings.getCount();
            if (n == 0) {
                continue;
            }
            // Gather the opposite-side factors and the mean-centred ratings.
            DenseMatrix features = new DenseMatrix(n, numFactors);
            DenseVector residuals = new DenseVector(n);
            int idx = 0;
            for (int k : ratings.getIndex()) {
                residuals.set(idx, ratings.get(k) - this.globalMean);
                for (int f = 0; f < numFactors; ++f) {
                    features.set(idx, f, other.get(k, f));
                }
                ++idx;
            }
            DenseMatrix featuresT = features.transpose();
            DenseMatrix covar = alpha.add(featuresT.mult(features).scale(BETA)).inv();
            DenseVector mean = covar.mult(featuresT.mult(residuals).scale(BETA).add(priorTerm));
            DenseMatrix lam = covar.cholesky();
            if (lam == null) {
                continue;
            }
            DenseVector sample = lam.transpose().mult(standardNormal()).add(mean);
            for (int f = 0; f < numFactors; ++f) {
                target.set(i, f, sample.get(f));
            }
        }
    }

    /** Returns half the sum of squared training errors of the current model. */
    private double trainingLoss() throws Exception {
        double sse = 0.0;
        for (MatrixEntry me : this.train) {
            double err = me.get() - this.predict(me.row(), me.column());
            sse += err * err;
        }
        return 0.5 * sse;
    }

    /**
     * Draws one Wishart sample via the Bartlett decomposition (Smith &amp; Hocking).
     *
     * <p>Let {@code A = scale.cholesky()}, and let T be upper-triangular with these entries:
     * {@code T[i][i] = sqrt(c_i)}, where {@code c_i ~ chi-square(df - i)} for zero-based i, and
     * independent N(0, 1) entries above the diagonal. The sample is {@code (T A)^T (T A)}. This is
     * a Wishart(scale, df) draw provided {@code A^T A = scale}, which is assumed to be the factor
     * convention of {@link DenseMatrix#cholesky()}.
     *
     * @param scale symmetric positive-definite scale matrix
     * @param df    degrees of freedom; must exceed {@code scale.numRows() - 1}
     * @return the sample, or {@code null} if {@code scale.cholesky()} returns null
     */
    protected DenseMatrix wishart(DenseMatrix scale, double df) {
        DenseMatrix a = scale.cholesky();
        if (a == null) {
            return null;
        }
        int p = scale.numRows();
        DenseMatrix t = new DenseMatrix(p, p);
        // Draw a full p x p block of normals in row-major order (same count and order as before).
        // Only the strict upper triangle is kept.
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < p; ++j) {
                double g = Randoms.gaussian(0.0, 1.0);
                if (j > i) {
                    t.set(i, j, g);
                }
            }
        }
        for (int i = 0; i < p; ++i) {
            // Gamma(k/2, 2) is used as a chi-square(k) draw (assumes a (shape, scale)
            // parameterisation). Bartlett requires k = df - i for zero-based i; the previous
            // code used df - (i + 1), one degree of freedom short on every diagonal entry.
            t.set(i, i, Math.sqrt(Randoms.gamma((df - i) / 2.0, 2.0)));
        }
        DenseMatrix ta = t.mult(a);
        return ta.transpose().mult(ta);
    }

    /**
     * Context-aware prediction. When user (item) splitting is enabled and a mapping exists for
     * (u, c) ((j, c)), the mapped id is used. The call then delegates to {@link #predict(int, int)}.
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        if (this.isUserSplitting && this.userIdMapper.contains(u, c)) {
            u = (Integer) this.userIdMapper.get(u, c);
        }
        if (this.isItemSplitting && this.itemIdMapper.contains(j, c)) {
            j = (Integer) this.itemIdMapper.get(j, c);
        }
        return this.predict(u, j);
    }

    /** Predicts {@code globalMean + P[u] . Q[j]}. */
    @Override
    protected double predict(int u, int j) throws Exception {
        return this.globalMean + DenseMatrix.rowMult(this.P, u, this.Q, j);
    }

    /** Returns the configuration string "numFactors, numIters" as formatted by Strings. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{numFactors, numIters});
    }
}

