/*
 * Decompiled with CFR 0.152.
 */
package librec.baseline;

import java.math.BigDecimal;
import java.math.RoundingMode;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Logs;
import librec.util.Randoms;

/**
 * User-clustering baseline fitted with an EM-style alternating procedure.
 *
 * <p>Each of the {@code numFactors} clusters holds a probability for every
 * rating level ({@link #Pkr}) and a mixing weight ({@link #Pi}). Every user
 * gets a soft assignment over the clusters ({@link #Gamma}), and a prediction
 * is the assignment-weighted expected rating.
 */
@Configuration(value="factors, max.iters")
public class UserCluster
extends GraphicRecommender {
    /** {@code Pkr[k][r]}: probability of the r-th rating level inside cluster k. */
    private DenseMatrix Pkr;
    /** {@code Pi[k]}: mixing weight of cluster k. */
    private DenseVector Pi;
    /** {@code Gamma[u][k]}: responsibility of cluster k for user u (rounded to 6 decimals). */
    private DenseMatrix Gamma;
    /** {@code Nur[u][r]}: number of training ratings of user u that fall on level r. */
    private DenseMatrix Nur;
    /** {@code Nu[u]}: {@code size()} of user u's training row; normaliser in the Pkr update. */
    private DenseVector Nu;

    public UserCluster(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Randomly initialises the cluster parameters and collects the per-user
     * rating-level counts that the training loop works on.
     */
    @Override
    protected void initModel() throws Exception {
        this.Pkr = new DenseMatrix(numFactors, numLevels);
        for (int k = 0; k < numFactors; ++k) {
            double[] probs = Randoms.randProbs(numLevels);
            for (int r = 0; r < numLevels; ++r) {
                this.Pkr.set(k, r, probs[r]);
            }
        }
        this.Pi = new DenseVector(Randoms.randProbs(numFactors));
        this.Gamma = new DenseMatrix(numUsers, numFactors);
        this.Nur = new DenseMatrix(numUsers, numLevels);
        this.Nu = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; ++u) {
            SparseVector ru = this.trainMatrix.row(u);
            for (VectorEntry ve : ru) {
                int level = ratingScale.indexOf(ve.get());
                this.Nur.add(u, level, 1.0);
            }
            this.Nu.set(u, ru.size());
        }
    }

    /**
     * Alternates responsibility and parameter updates for at most
     * {@code numIters} rounds, stopping early once {@link #isConverged(int)}
     * reports convergence.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.updateResponsibilities();
            this.updateClusterParameters();
            this.loss = this.negativeExpectedLogLikelihood();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * E-step: recomputes {@link #Gamma} from the current {@link #Pi} and {@link #Pkr}.
     *
     * <p>The unnormalised weight of cluster k for user u is
     * {@code Pi[k] * prod_r Pkr[k][r]^Nur[u][r]}. That equals the product over the
     * user's individual ratings because {@code Nur} holds exactly those per-level
     * counts. {@link BigDecimal} is used so long products do not underflow.
     */
    private void updateResponsibilities() throws Exception {
        // Convert the parameters once per iteration instead of once per rating.
        BigDecimal[] mixing = new BigDecimal[numFactors];
        BigDecimal[][] levelProb = new BigDecimal[numFactors][numLevels];
        for (int k = 0; k < numFactors; ++k) {
            mixing[k] = new BigDecimal(this.Pi.get(k));
            for (int r = 0; r < numLevels; ++r) {
                levelProb[k][r] = new BigDecimal(this.Pkr.get(k, r));
            }
        }

        int[] counts = new int[numLevels];
        BigDecimal[] joint = new BigDecimal[numFactors];
        for (int u = 0; u < numUsers; ++u) {
            for (int r = 0; r < numLevels; ++r) {
                // Nur only ever receives +1.0 increments, so the cast is exact.
                counts[r] = (int)this.Nur.get(u, r);
            }
            BigDecimal total = BigDecimal.ZERO;
            for (int k = 0; k < numFactors; ++k) {
                BigDecimal weight = mixing[k];
                for (int r = 0; r < numLevels; ++r) {
                    if (counts[r] > 0) {
                        weight = weight.multiply(levelProb[k][r].pow(counts[r]));
                    }
                }
                joint[k] = weight;
                total = total.add(weight);
            }
            for (int k = 0; k < numFactors; ++k) {
                double share = joint[k].divide(total, 6, RoundingMode.HALF_UP).doubleValue();
                this.Gamma.set(u, k, share);
            }
        }
    }

    /**
     * M-step: re-estimates {@link #Pkr} and {@link #Pi} from {@link #Gamma}.
     *
     * <p>A cluster whose normaliser is not positive (no user carries any
     * responsibility for it) keeps its previous level probabilities. Dividing
     * would store NaN, and {@code new BigDecimal(NaN)} in the next E-step throws
     * {@link NumberFormatException}.
     */
    private void updateClusterParameters() throws Exception {
        double[] clusterMass = new double[numFactors];
        double totalMass = 0.0;
        for (int k = 0; k < numFactors; ++k) {
            // Both sums are independent of the rating level, so compute them once.
            double mass = 0.0;
            double denominator = 0.0;
            for (int u = 0; u < numUsers; ++u) {
                double ruk = this.Gamma.get(u, k);
                mass += ruk;
                denominator += ruk * this.Nu.get(u);
            }
            if (denominator > 0.0) {
                for (int r = 0; r < numLevels; ++r) {
                    double numerator = 0.0;
                    for (int u = 0; u < numUsers; ++u) {
                        numerator += this.Gamma.get(u, k) * this.Nur.get(u, r);
                    }
                    this.Pkr.set(k, r, numerator / denominator);
                }
            }
            clusterMass[k] = mass;
            totalMass += mass;
        }
        if (totalMass > 0.0) {
            for (int k = 0; k < numFactors; ++k) {
                this.Pi.set(k, clusterMass[k] / totalMass);
            }
        }
    }

    /**
     * Computes the loss: the negated sum over users and clusters of
     * {@code Gamma[u][k] * (log Pi[k] + sum_r Nur[u][r] * log Pkr[k][r])}.
     *
     * <p>Terms with a zero weight are skipped (the {@code 0 * log 0 = 0}
     * convention). In floating point {@code 0 * -Infinity} is NaN, which would
     * otherwise make {@link #isConverged(int)} stop training prematurely.
     *
     * @return the loss for the current parameters
     */
    private double negativeExpectedLogLikelihood() throws Exception {
        double total = 0.0;
        for (int u = 0; u < numUsers; ++u) {
            for (int k = 0; k < numFactors; ++k) {
                double ruk = this.Gamma.get(u, k);
                if (ruk == 0.0) {
                    continue;
                }
                double sumLevels = 0.0;
                for (int r = 0; r < numLevels; ++r) {
                    double nur = this.Nur.get(u, r);
                    if (nur != 0.0) {
                        sumLevels += nur * Math.log(this.Pkr.get(k, r));
                    }
                }
                total += ruk * (Math.log(this.Pi.get(k)) + sumLevels);
            }
        }
        return -total;
    }

    /**
     * Logs the current loss and decides whether training should stop.
     *
     * @param iter 1-based iteration number
     * @return {@code true} when, after the first iteration, the loss went up
     *         or the loss difference is NaN; otherwise {@code false}, and the
     *         current loss is remembered for the next comparison
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        float deltaLoss = (float)(this.loss - this.lastLoss);
        Logs.debug("{}{} iter {} achives loss = {}, delta_loss = {}", this.algoName, this.foldInfo, iter, Float.valueOf((float)this.loss), Float.valueOf(deltaLoss));
        boolean stop = iter > 1 && (deltaLoss > 0.0f || Double.isNaN(deltaLoss));
        if (!stop) {
            this.lastLoss = this.loss;
            return false;
        }
        Logs.debug("{}{} converges at iter {}", this.algoName, this.foldInfo, iter);
        return true;
    }

    /**
     * Predicts a rating as the responsibility-weighted expected rating level:
     * {@code sum_k Gamma[u][k] * sum_r ratingScale[r] * Pkr[k][r]}.
     *
     * @param u     user index
     * @param j     item index; not used, the prediction depends on the user only
     * @param bound not used by this implementation; the value is returned as computed
     * @return the predicted rating for user {@code u}
     */
    @Override
    protected double predict(int u, int j, boolean bound) throws Exception {
        double pred = 0.0;
        for (int k = 0; k < numFactors; ++k) {
            double expectedRating = 0.0;
            for (int r = 0; r < numLevels; ++r) {
                double rate = (Double)ratingScale.get(r);
                expectedRating += rate * this.Pkr.get(k, r);
            }
            pred += this.Gamma.get(u, k) * expectedRating;
        }
        return pred;
    }

    /** Returns the configuration summary {@code "<numFactors>,<numIters>"}. */
    @Override
    public String toString() {
        return String.valueOf(numFactors) + "," + numIters;
    }
}

