/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import librec.data.AddConfiguration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.intf.GraphicRecommender;
import librec.util.Logs;
import librec.util.Strings;

/**
 * LDCC rating recommender: a Gibbs-sampled user/item co-clustering model (presumably "Latent
 * Dirichlet Co-Clustering" — NOTE(review): confirm against the original reference).
 *
 * <p>Every training rating {@code r(u, v)} carries a user-cluster assignment {@code i} in
 * {@code [0, Ku)} and an item-cluster assignment {@code j} in {@code [0, Kv)}. The sampler keeps
 * per-user cluster counts, per-item cluster counts and, for each co-cluster {@code (i, j)}, the
 * number of ratings at every rating level. Smoothed versions of these counts (additive constants
 * {@code au}, {@code av}, {@code beta}) give the user-cluster, item-cluster and level distributions
 * used by {@link #predict(int, int)}.
 *
 * <p>Options read from {@code algoOptions}: {@code -ku}, {@code -kv} (default {@code numFactors}),
 * {@code -au} (default {@code 1/Ku}), {@code -av} (default {@code 1/Kv}) and {@code -beta}
 * (default {@code 1/numLevels}).
 */
@AddConfiguration(before = "Ku, Kv, au, av, beta")
public class LDCC extends GraphicRecommender {

    /** Current user-cluster assignment of each training rating, keyed by (user, item). */
    private Table<Integer, Integer, Integer> Zu;
    /** Current item-cluster assignment of each training rating, keyed by (user, item). */
    private Table<Integer, Integer, Integer> Zv;

    /** Nui(u, i): number of user u's ratings currently assigned to user cluster i. */
    private DenseMatrix Nui;
    /** Nvj(v, j): number of item v's ratings currently assigned to item cluster j. */
    private DenseMatrix Nvj;
    /** Nv(v): number of training ratings of item v (the per-user analogue is the inherited Nu). */
    private DenseVector Nv;
    /** Nijl[i][j][l]: number of ratings at level l assigned to co-cluster (i, j). */
    private int[][][] Nijl;
    /** Nij(i, j): number of ratings assigned to co-cluster (i, j), summed over levels. */
    private DenseMatrix Nij;

    /** Number of user clusters. */
    private int Ku;
    /** Number of item clusters. */
    private int Kv;
    /** Additive smoothing for user-cluster counts. */
    private float au;
    /** Additive smoothing for item-cluster counts. */
    private float av;
    /** Additive smoothing for rating-level counts (option {@code -beta}). */
    private float bl;

    /** Averaged user-cluster distributions, set by {@link #estimateParams()}. */
    private DenseMatrix PIu;
    /** Averaged item-cluster distributions, set by {@link #estimateParams()}. */
    private DenseMatrix PIv;
    /** Running sum of user-cluster distributions over read-outs. */
    private DenseMatrix PIuSum;
    /** Running sum of item-cluster distributions over read-outs. */
    private DenseMatrix PIvSum;
    /** Averaged per-co-cluster rating-level distributions, set by {@link #estimateParams()}. */
    private double[][][] Pijl;
    /** Running sum of per-co-cluster rating-level distributions over read-outs. */
    private double[][][] PijlSum;

    public LDCC(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Reads the hyper-parameters, allocates the count structures and assigns every training
     * rating to a uniformly random co-cluster.
     */
    @Override
    protected void initModel() throws Exception {
        Ku = algoOptions.getInt("-ku", numFactors);
        Kv = algoOptions.getInt("-kv", numFactors);

        Nui = new DenseMatrix(numUsers, Ku);
        Nu = new DenseVector(numUsers);
        Nvj = new DenseMatrix(numItems, Kv);
        Nv = new DenseVector(numItems);
        Nijl = new int[Ku][Kv][numLevels];
        Nij = new DenseMatrix(Ku, Kv);

        au = algoOptions.getFloat("-au", 1.0f / Ku);
        av = algoOptions.getFloat("-av", 1.0f / Kv);
        bl = algoOptions.getFloat("-beta", 1.0f / numLevels);

        Zu = HashBasedTable.create();
        Zv = HashBasedTable.create();

        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int v = me.column();
            int l = ratingScale.indexOf(me.get());

            int i = (int) (Ku * Math.random());
            int j = (int) (Kv * Math.random());

            updateCounts(u, v, l, i, j, 1);
            Zu.put(u, v, i);
            Zv.put(u, v, j);
        }

        PIuSum = new DenseMatrix(numUsers, Ku);
        PIvSum = new DenseMatrix(numItems, Kv);
        Pijl = new double[Ku][Kv][numLevels];
        PijlSum = new double[Ku][Kv][numLevels];
    }

    /**
     * One Gibbs sweep over the training ratings: each rating's assignment is removed from the
     * counts, a new user cluster and item cluster are drawn, and the counts are restored.
     *
     * <p>The unnormalised joint weight of co-cluster (m, n) is the product of the smoothed
     * user-cluster, item-cluster and level terms. The user cluster is drawn from the row marginal
     * and the item cluster from the column marginal of that joint weight (as in the original
     * implementation).
     */
    @Override
    protected void eStep() {
        // Loop-invariant prior masses (float products, as in the original formulas).
        double userPriorMass = Ku * au;
        double itemPriorMass = Kv * av;
        double levelPriorMass = numLevels * bl;

        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int v = me.column();
            int l = ratingScale.indexOf(me.get());
            int i = Zu.get(u, v);
            int j = Zv.get(u, v);

            // Exclude this rating so the conditional is computed from all the others.
            updateCounts(u, v, l, i, j, -1);

            double userDenom = Nu.get(u) + userPriorMass;
            double itemDenom = Nv.get(v) + itemPriorMass;

            // The item term depends only on n, so compute it once per rating.
            double[] itemTerm = new double[Kv];
            for (int n = 0; n < Kv; n++) {
                itemTerm[n] = (Nvj.get(v, n) + av) / itemDenom;
            }

            double[] rowMass = new double[Ku];
            double[] colMass = new double[Kv];
            double total = 0.0;
            for (int m = 0; m < Ku; m++) {
                double userTerm = (Nui.get(u, m) + au) / userDenom;
                for (int n = 0; n < Kv; n++) {
                    // Add in double: int counts beyond 2^24 are not exact in float.
                    double levelTerm = (Nijl[m][n][l] + (double) bl) / (Nij.get(m, n) + levelPriorMass);
                    double p = userTerm * itemTerm[n] * levelTerm;
                    rowMass[m] += p;
                    colMass[n] += p;
                    total += p;
                }
            }

            i = sampleIndex(rowMass, total);
            j = sampleIndex(colMass, total);

            updateCounts(u, v, l, i, j, 1);
            Zu.put(u, v, i);
            Zv.put(u, v, j);
        }
    }

    /**
     * Adds {@code delta} to every count touched by a rating of level {@code l} by user {@code u}
     * on item {@code v} that is assigned to co-cluster ({@code i}, {@code j}).
     */
    private void updateCounts(int u, int v, int l, int i, int j, int delta) {
        double d = delta;
        Nui.add(u, i, d);
        Nu.add(u, d);
        Nvj.add(v, j, d);
        Nv.add(v, d);
        Nijl[i][j][l] += delta;
        Nij.add(i, j, d);
    }

    /**
     * Draws an index with probability proportional to {@code weights[k]}.
     *
     * <p>The draw is scaled by {@code total} (the sum of the weights) instead of normalising the
     * weights. If rounding leaves the running sum just below the target, the last index is
     * returned, so the result is always a valid index into {@code weights}.
     */
    private static int sampleIndex(double[] weights, double total) {
        double target = Math.random() * total;
        double cumulative = 0.0;
        for (int k = 0; k < weights.length; k++) {
            cumulative += weights[k];
            if (target < cumulative) {
                return k;
            }
        }
        return weights.length - 1;
    }

    /**
     * Adds the current smoothed distributions to the running sums and counts one more read-out in
     * {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double userPriorMass = Ku * au;
        for (int u = 0; u < numUsers; u++) {
            double denom = Nu.get(u) + userPriorMass;
            for (int i = 0; i < Ku; i++) {
                PIuSum.add(u, i, (Nui.get(u, i) + au) / denom);
            }
        }

        double itemPriorMass = Kv * av;
        for (int v = 0; v < numItems; v++) {
            double denom = Nv.get(v) + itemPriorMass;
            for (int j = 0; j < Kv; j++) {
                PIvSum.add(v, j, (Nvj.get(v, j) + av) / denom);
            }
        }

        double levelPriorMass = numLevels * bl;
        for (int i = 0; i < Ku; i++) {
            for (int j = 0; j < Kv; j++) {
                double denom = Nij.get(i, j) + levelPriorMass;
                for (int l = 0; l < numLevels; l++) {
                    PijlSum[i][j][l] += (Nijl[i][j][l] + (double) bl) / denom;
                }
            }
        }
        numStats++;
    }

    /** Averages the accumulated read-outs into {@code PIu}, {@code PIv} and {@code Pijl}. */
    @Override
    protected void estimateParams() {
        double scale = 1.0 / numStats;
        PIu = PIuSum.scale(scale);
        PIv = PIvSum.scale(scale);
        for (int i = 0; i < Ku; i++) {
            for (int j = 0; j < Kv; j++) {
                for (int l = 0; l < numLevels; l++) {
                    Pijl[i][j][l] = PijlSum[i][j][l] / numStats;
                }
            }
        }
    }

    /**
     * Re-estimates the parameters and computes the training perplexity. Returns {@code true}
     * (stop) once more than one read-out has been taken and the perplexity increased since the
     * last check; otherwise records the new perplexity in {@code loss}.
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        estimateParams();

        int numRatings = 0;
        double sum = 0.0;
        for (MatrixEntry me : trainMatrix) {
            sum += perplexity(me.row(), me.column(), me.get());
            numRatings++;
        }

        double perp = Math.exp(sum / numRatings);
        double delta = perp - loss;
        Logs.debug("{}{} iter {} achieves perplexity = {}, delta_perp = {}", algoName, foldInfo, iter, perp, delta);

        if (numStats > 1 && delta > 0.0) {
            return true;
        }
        loss = perp;
        return false;
    }

    /**
     * Returns the negative log-probability of observing rating {@code pred} for (u, v).
     *
     * <p>The level index comes from {@code ratingScale.indexOf}, the same mapping used during
     * sampling. The earlier arithmetic {@code pred / minRate - 1} broke for scales that do not
     * start at {@code minRate} or are not multiples of it, and under floating-point truncation.
     *
     * @throws IllegalArgumentException if {@code pred} is not a value on the rating scale
     */
    @Override
    protected double perplexity(int u, int v, double pred) throws Exception {
        int l = ratingScale.indexOf(pred);
        if (l < 0) {
            throw new IllegalArgumentException("rating " + pred + " is not on the rating scale");
        }
        return -Math.log(levelProbability(u, v, l));
    }

    /** Predicts the expected rating: sum over levels of (level value x its probability). */
    @Override
    protected double predict(int u, int v) throws Exception {
        double pred = 0.0;
        for (int l = 0; l < numLevels; l++) {
            double rate = (Double) ratingScale.get(l);
            pred += rate * levelProbability(u, v, l);
        }
        return pred;
    }

    /**
     * Returns the estimated probability of rating level {@code l} for (u, v), i.e. the sum over
     * co-clusters (i, j) of {@code Pijl[i][j][l] * PIu(u, i) * PIv(v, j)}.
     */
    private double levelProbability(int u, int v, int l) {
        double prob = 0.0;
        for (int i = 0; i < Ku; i++) {
            for (int j = 0; j < Kv; j++) {
                prob += Pijl[i][j][l] * PIu.get(u, i) * PIv.get(v, j);
            }
        }
        return prob;
    }

    @Override
    public String toString() {
        return Strings.toString(new Object[] {Ku, Kv, au, av, bl}) + ", " + super.toString();
    }
}

