/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.HashBasedTable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.AddConfiguration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.RatingContext;
import librec.data.SparseMatrix;
import librec.intf.GraphicRecommender;
import librec.util.Gamma;
import librec.util.Strings;

/**
 * Topic model over each user's ordered sequence of training items, in which every item is
 * generated from a latent topic conditioned on the user's preceding item (an item "bigram").
 *
 * <p>Count statistics maintained by the sampler:
 * <ul>
 *   <li>{@code Nuk(u, k)} / {@code Nu(u)}: number of user {@code u}'s items assigned to topic
 *       {@code k} / in total;</li>
 *   <li>{@code Nkji[k][j][i]}: number of times item {@code i} follows item {@code j} while
 *       assigned to topic {@code k};</li>
 *   <li>{@code Nkj(k, j)}: sum over {@code i} of {@code Nkji[k][j][i]}.</li>
 * </ul>
 * The previous-item index {@code j} ranges over {@code 0..numItems}; the extra value
 * {@code numItems} is a sentinel meaning "no previous item" and is used for each user's first
 * item. Note that {@code Nkji}, {@code Pkji} and {@code PkjiSum} are dense
 * {@code numFactors x (numItems + 1) x numItems} arrays.
 */
@AddConfiguration(before = "factors, alpha, beta")
public class ItemBigram extends GraphicRecommender {

    /** Each user's training items, ordered by {@link RatingContext}'s natural ordering
     *  (presumably by rating time — TODO confirm RatingContext.compareTo). */
    private Map<Integer, List<Integer>> userItemsMap;
    /** Topic / previous-item / item transition counts. */
    private int[][][] Nkji;
    /** Row totals of {@code Nkji} over the current item {@code i}. */
    private DenseMatrix Nkj;
    /** Averaged per-topic transition probabilities (set by {@link #estimateParams()}). */
    private double[][][] Pkji;
    /** Running sums of sampled transition probabilities (filled by {@link #readoutParams()}). */
    private double[][][] PkjiSum;
    /** Dirichlet hyper-parameters for the transition distributions, one per (topic, previous item). */
    private DenseMatrix beta;

    public ItemBigram(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
    }

    /**
     * Builds each user's ordered item sequence, allocates all count/parameter structures,
     * initialises the hyper-parameters and assigns every (user, item) a uniformly random topic.
     */
    @Override
    protected void initModel() throws Exception {
        userItemsMap = new HashMap<Integer, List<Integer>>();
        for (int u = 0; u < numUsers; u++) {
            userItemsMap.put(u, sortedItemsOf(u));
        }

        Nuk = new DenseMatrix(numUsers, numFactors);
        Nu = new DenseVector(numUsers);
        Nkji = new int[numFactors][numItems + 1][numItems];
        Nkj = new DenseMatrix(numFactors, numItems + 1);

        PukSum = new DenseMatrix(numUsers, numFactors);
        PkjiSum = new double[numFactors][numItems + 1][numItems];
        Pkji = new double[numFactors][numItems + 1][numItems];

        alpha = new DenseVector(numFactors);
        alpha.setAll(initAlpha);
        beta = new DenseMatrix(numFactors, numItems + 1);
        beta.setAll(initBeta);

        z = HashBasedTable.create();
        for (Map.Entry<Integer, List<Integer>> en : userItemsMap.entrySet()) {
            int u = en.getKey();
            List<Integer> items = en.getValue();
            for (int m = 0; m < items.size(); m++) {
                int i = items.get(m);
                int k = (int) (Math.random() * (double) numFactors);
                z.put(u, i, k);
                addCounts(u, k, previousItem(items, m), i, 1);
            }
        }
    }

    /** Returns user {@code u}'s training items sorted by {@link RatingContext} order. */
    private List<Integer> sortedItemsOf(int u) {
        List<Integer> unsortedItems = trainMatrix.getColumns(u);
        List<RatingContext> contexts = new ArrayList<RatingContext>(unsortedItems.size());
        for (Integer i : unsortedItems) {
            contexts.add(new RatingContext(u, i, (long) timeMatrix.get(u, i)));
        }
        Collections.sort(contexts);

        List<Integer> sortedItems = new ArrayList<Integer>(contexts.size());
        for (RatingContext rc : contexts) {
            sortedItems.add(rc.getItem());
        }
        return sortedItems;
    }

    /** Returns the item preceding position {@code m}, or the sentinel {@code numItems} for {@code m == 0}. */
    private int previousItem(List<Integer> items, int m) {
        return m > 0 ? items.get(m - 1) : numItems;
    }

    /** Adds {@code delta} to every count touched by assigning topic {@code k} to transition {@code j -> i} of user {@code u}. */
    private void addCounts(int u, int k, int j, int i, int delta) {
        Nuk.add(u, k, (double) delta);
        Nu.add(u, (double) delta);
        Nkji[k][j][i] += delta;
        Nkj.add(k, j, (double) delta);
    }

    /**
     * One sweep of sampling: for every (user, item) the current topic is removed from the counts,
     * a new topic is drawn proportionally to its conditional weight, and the counts are restored.
     */
    @Override
    protected void eStep() {
        double sumAlpha = alpha.sum();

        // beta is not modified during the E-step, so its row sums are computed once per sweep.
        double[] sumBeta = new double[numFactors];
        for (int t = 0; t < numFactors; t++) {
            sumBeta[t] = beta.sumOfRow(t);
        }

        // cumulative (unnormalised) topic weights; fully overwritten for each token
        double[] cdf = new double[numFactors];

        for (Map.Entry<Integer, List<Integer>> en : userItemsMap.entrySet()) {
            int u = en.getKey();
            List<Integer> items = en.getValue();
            for (int m = 0; m < items.size(); m++) {
                int i = items.get(m);
                int j = previousItem(items, m);

                // exclude the current assignment from the statistics
                int k = (Integer) z.get(u, i);
                addCounts(u, k, j, i, -1);

                // NOTE(review): v2 divides by N_tj + sum_j' beta(t, j'), while the per-item
                // pseudo-count is beta(t, j); summed over i these do not normalise to 1 unless
                // sum_j' beta(t, j') == numItems * beta(t, j). Kept as-is; confirm the intended prior.
                for (int t = 0; t < numFactors; t++) {
                    double v1 = (Nuk.get(u, t) + alpha.get(t)) / (Nu.get(u) + sumAlpha);
                    double v2 = ((double) Nkji[t][j][i] + beta.get(t, j)) / (Nkj.get(t, j) + sumBeta[t]);
                    cdf[t] = v1 * v2;
                    if (t > 0) {
                        cdf[t] += cdf[t - 1];
                    }
                }

                // inverse-CDF draw; stopping at the last topic guards against an out-of-range
                // index when no bucket exceeds rand (e.g. all-zero or NaN weights)
                double rand = Math.random() * cdf[numFactors - 1];
                int sampled = 0;
                while (sampled < numFactors - 1 && !(rand < cdf[sampled])) {
                    sampled++;
                }

                z.put(u, i, sampled);
                addCounts(u, sampled, j, i, 1);
            }
        }
    }

    /**
     * Updates {@code alpha} and {@code beta} with multiplicative fixed-point steps
     * ({@code new = old * numerator / denominator}) built from digamma differences of the counts.
     * A parameter is left unchanged when its numerator is exactly zero.
     */
    @Override
    protected void mStep() {
        double sumAlpha = alpha.sum();
        double digammaSumAlpha = Gamma.digamma(sumAlpha);

        // the alpha denominator does not depend on k, so it is computed once
        double alphaDenominator = 0.0;
        for (int u = 0; u < numUsers; u++) {
            alphaDenominator += Gamma.digamma(Nu.get(u) + sumAlpha) - digammaSumAlpha;
        }

        for (int k = 0; k < numFactors; k++) {
            double ak = alpha.get(k);
            double digammaAk = Gamma.digamma(ak);
            double numerator = 0.0;
            for (int u = 0; u < numUsers; u++) {
                numerator += Gamma.digamma(Nuk.get(u, k) + ak) - digammaAk;
            }
            if (numerator != 0.0) {
                alpha.set(k, ak * (numerator / alphaDenominator));
            }
        }

        for (int k = 0; k < numFactors; k++) {
            // row sum taken before any entry of row k is updated
            double bk = beta.sumOfRow(k);
            double digammaBk = Gamma.digamma(bk);
            for (int j = 0; j < numItems + 1; j++) {
                double bkj = beta.get(k, j);
                double digammaBkj = Gamma.digamma(bkj);
                double denominatorTerm = Gamma.digamma(Nkj.get(k, j) + bk) - digammaBk;

                double numerator = 0.0;
                double denominator = 0.0;
                for (int i = 0; i < numItems; i++) {
                    numerator += Gamma.digamma((double) Nkji[k][j][i] + bkj) - digammaBkj;
                    denominator += denominatorTerm;
                }
                if (numerator != 0.0) {
                    beta.set(k, j, bkj * (numerator / denominator));
                }
            }
        }
    }

    /**
     * Adds the current sample's user-topic and topic-transition probabilities to the running
     * sums {@code PukSum} / {@code PkjiSum} and increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = alpha.sum();

        // one row per user: PukSum is numUsers x numFactors
        for (int u = 0; u < numUsers; u++) {
            double userNorm = Nu.get(u) + sumAlpha;
            for (int k = 0; k < numFactors; k++) {
                PukSum.add(u, k, (Nuk.get(u, k) + alpha.get(k)) / userNorm);
            }
        }

        for (int k = 0; k < numFactors; k++) {
            double bk = beta.sumOfRow(k);
            for (int j = 0; j < numItems + 1; j++) {
                double bkj = beta.get(k, j);
                double norm = Nkj.get(k, j) + bk;
                double[] sums = PkjiSum[k][j];
                for (int i = 0; i < numItems; i++) {
                    sums[i] += ((double) Nkji[k][j][i] + bkj) / norm;
                }
            }
        }
        ++numStats;
    }

    /**
     * Averages the accumulated samples into {@code Puk} and {@code Pkji}.
     * NOTE(review): assumes {@code numStats > 0}; with no read-outs the results are NaN/Infinity.
     */
    @Override
    protected void estimateParams() {
        Puk = PukSum.scale(1.0 / (double) numStats);
        for (int k = 0; k < numFactors; k++) {
            for (int j = 0; j < numItems + 1; j++) {
                for (int i = 0; i < numItems; i++) {
                    Pkji[k][j][i] = PkjiSum[k][j][i] / (double) numStats;
                }
            }
        }
    }

    /**
     * Scores item {@code i} for user {@code u} as {@code sum_k Puk(u, k) * Pkji[k][j][i]}, where
     * {@code j} is the last item in the user's ordered training sequence (or the sentinel
     * {@code numItems} if the user has no training items).
     */
    @Override
    protected double ranking(int u, int i) throws Exception {
        List<Integer> items = userItemsMap.get(u);
        int j = items.isEmpty() ? numItems : items.get(items.size() - 1);

        double rank = 0.0;
        for (int k = 0; k < numFactors; k++) {
            rank += Puk.get(u, k) * Pkji[k][j][i];
        }
        return rank;
    }

    @Override
    public String toString() {
        return Strings.toString(new Object[] {numFactors, Float.valueOf(initAlpha), Float.valueOf(initBeta)})
                + ", " + super.toString();
    }
}

