/*
 * Decompiled with CFR 0.152.
 */
package librec.baseline;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Logs;
import librec.util.Randoms;

/**
 * Item-clustering baseline: a mixture model in which each item belongs softly to one of
 * {@code numFactors} latent clusters, and each cluster has its own categorical distribution
 * over the {@code numLevels} rating levels. The parameters are fitted with EM on the training
 * ratings. The prediction for item {@code j} is the expected rating under that item's cluster
 * responsibilities.
 */
@Configuration(value="factors, max.iters")
public class ItemCluster extends GraphicRecommender {
    /**
     * Working precision for the E-step products. Bounding it keeps each multiplication cheap.
     * Exact BigDecimal products gain tens of digits per factor. 34 significant digits is far
     * more than the 6 decimal places the responsibilities are finally rounded to.
     */
    private static final MathContext PRODUCT_PRECISION = MathContext.DECIMAL128;

    /** Decimal places kept when E-step responsibilities are converted back to double. */
    private static final int GAMMA_SCALE = 6;

    /** 5e-7: the smallest ratio that HALF_UP rounding to GAMMA_SCALE places keeps non-zero. */
    private static final BigDecimal HALF_UNIT_AT_GAMMA_SCALE = BigDecimal.valueOf(5L, GAMMA_SCALE + 1);

    /** Pkr[k][r]: probability that an item in cluster k receives a rating at level r. */
    private DenseMatrix Pkr;
    /** Pi[k]: prior (mixing) weight of cluster k. */
    private DenseVector Pi;
    /** Gamma[i][k]: posterior responsibility of cluster k for item i (E-step output). */
    private DenseMatrix Gamma;
    /** Nir[i][r]: number of training ratings of item i at rating level r. */
    private DenseMatrix Nir;
    /** Ni[i]: number of training ratings of item i. */
    private DenseVector Ni;

    public ItemCluster(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Randomly initializes Pkr (one Randoms.randProbs vector per cluster) and Pi. Also builds
     * the per-item rating-level counts Nir and Ni, which stay fixed during EM.
     */
    @Override
    protected void initModel() throws Exception {
        this.Pkr = new DenseMatrix(numFactors, numLevels);
        for (int k = 0; k < numFactors; k++) {
            double[] probs = Randoms.randProbs(numLevels);
            for (int r = 0; r < numLevels; r++) {
                this.Pkr.set(k, r, probs[r]);
            }
        }
        this.Pi = new DenseVector(Randoms.randProbs(numFactors));
        this.Gamma = new DenseMatrix(numItems, numFactors);

        this.Nir = new DenseMatrix(numItems, numLevels);
        this.Ni = new DenseVector(numItems);
        for (int i = 0; i < numItems; i++) {
            SparseVector ri = this.trainMatrix.column(i);
            for (VectorEntry ve : ri) {
                int r = ratingScale.indexOf(ve.get());
                this.Nir.add(i, r, 1.0);
            }
            this.Ni.set(i, ri.size());
        }
    }

    /**
     * Runs up to numIters EM iterations. Each iteration does an E-step, then an M-step, then
     * recomputes the loss. The loop stops early once isConverged reports convergence.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            this.eStep();
            this.mStep();
            this.loss = this.computeLoss();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * E-step: sets Gamma[i][k] proportional to Pi[k] * prod_r Pkr[k][r]^Nir[i][r], normalized
     * over k and rounded HALF_UP to GAMMA_SCALE decimal places.
     */
    private void eStep() {
        BigDecimal[] joint = new BigDecimal[numFactors];
        for (int i = 0; i < numItems; i++) {
            BigDecimal total = BigDecimal.ZERO;
            for (int k = 0; k < numFactors; k++) {
                BigDecimal p = new BigDecimal(this.Pi.get(k), PRODUCT_PRECISION);
                for (int r = 0; r < numLevels; r++) {
                    // Nir holds integer counts (built by adding 1.0 per rating), so the cast is exact.
                    int count = (int) this.Nir.get(i, r);
                    if (count > 0) {
                        BigDecimal pkr = new BigDecimal(this.Pkr.get(k, r), PRODUCT_PRECISION);
                        p = p.multiply(pkr.pow(count, PRODUCT_PRECISION), PRODUCT_PRECISION);
                    }
                }
                joint[k] = p;
                total = total.add(p, PRODUCT_PRECISION);
            }
            for (int k = 0; k < numFactors; k++) {
                // If every joint term is zero there is nothing to normalize, so fall back to the prior.
                double gik = total.signum() == 0 ? this.Pi.get(k) : roundedRatio(joint[k], total);
                this.Gamma.set(i, k, gik);
            }
        }
    }

    /**
     * Returns joint / total rounded HALF_UP to GAMMA_SCALE decimal places, as a double.
     * The quotient is first computed to PRODUCT_PRECISION significant digits. This keeps the
     * rounding step cheap even when the operands' magnitudes differ by many orders.
     */
    private static double roundedRatio(BigDecimal joint, BigDecimal total) {
        BigDecimal ratio = joint.divide(total, PRODUCT_PRECISION);
        if (ratio.compareTo(HALF_UNIT_AT_GAMMA_SCALE) < 0) {
            return 0.0; // would round to 0.000000 anyway
        }
        return ratio.setScale(GAMMA_SCALE, RoundingMode.HALF_UP).doubleValue();
    }

    /** M-step: re-estimates Pkr and Pi from the current responsibilities Gamma. */
    private void mStep() {
        double[] clusterMass = new double[numFactors];
        double totalMass = 0.0;
        for (int k = 0; k < numFactors; k++) {
            // Expected number of ratings owned by cluster k. It is the same for every level r.
            double denominator = 0.0;
            double mass = 0.0;
            for (int i = 0; i < numItems; i++) {
                double gik = this.Gamma.get(i, k);
                denominator += gik * this.Ni.get(i);
                mass += gik;
            }
            // A cluster with no expected ratings would get Pkr = 0/0 = NaN. That NaN would then
            // make new BigDecimal(...) throw in the next E-step, so keep the previous distribution.
            if (denominator > 0.0) {
                for (int r = 0; r < numLevels; r++) {
                    double numerator = 0.0;
                    for (int i = 0; i < numItems; i++) {
                        numerator += this.Gamma.get(i, k) * this.Nir.get(i, r);
                    }
                    this.Pkr.set(k, r, numerator / denominator);
                }
            }
            clusterMass[k] = mass;
            totalMass += mass;
        }
        for (int k = 0; k < numFactors; k++) {
            this.Pi.set(k, clusterMass[k] / totalMass);
        }
    }

    /**
     * Returns -sum_i sum_k Gamma[i][k] * (log Pi[k] + sum_r Nir[i][r] * log Pkr[k][r]).
     * Zero-weight terms are skipped, using the convention 0 * log 0 = 0. Evaluating them
     * literally yields 0 * -Infinity = NaN, which makes isConverged stop training early.
     */
    private double computeLoss() {
        double logLikelihood = 0.0;
        for (int i = 0; i < numItems; i++) {
            for (int k = 0; k < numFactors; k++) {
                double gik = this.Gamma.get(i, k);
                if (gik == 0.0) {
                    continue;
                }
                double ratingTerm = 0.0;
                for (int r = 0; r < numLevels; r++) {
                    double count = this.Nir.get(i, r);
                    if (count != 0.0) {
                        ratingTerm += count * Math.log(this.Pkr.get(k, r));
                    }
                }
                logLikelihood += gik * (Math.log(this.Pi.get(k)) + ratingTerm);
            }
        }
        return -logLikelihood;
    }

    /**
     * Logs the current loss. Reports convergence (after the first iteration) as soon as the
     * loss increases or the change becomes NaN. Otherwise it records the loss as lastLoss.
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        float deltaLoss = (float) (this.loss - this.lastLoss);
        Logs.debug("{}{} iter {} achives loss = {}, delta_loss = {}", this.algoName, this.foldInfo, iter, Float.valueOf((float) this.loss), Float.valueOf(deltaLoss));
        if (iter > 1 && (deltaLoss > 0.0f || Float.isNaN(deltaLoss))) {
            Logs.debug("{}{} converges at iter {}", this.algoName, this.foldInfo, iter);
            return true;
        }
        this.lastLoss = this.loss;
        return false;
    }

    /**
     * Predicts sum_k Gamma[j][k] * sum_r ratingScale[r] * Pkr[k][r], the expected rating of
     * item j under its cluster responsibilities. The user u and the bound flag are not used.
     */
    @Override
    protected double predict(int u, int j, boolean bound) throws Exception {
        double pred = 0.0;
        for (int k = 0; k < numFactors; k++) {
            double expectedRating = 0.0;
            for (int r = 0; r < numLevels; r++) {
                expectedRating += (Double) ratingScale.get(r) * this.Pkr.get(k, r);
            }
            pred += this.Gamma.get(j, k) * expectedRating;
        }
        return pred;
    }

    @Override
    public String toString() {
        return String.valueOf(numFactors) + "," + numIters;
    }
}

