/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.IterativeRecommender;
import librec.util.Randoms;
import librec.util.Strings;

/**
 * Item-similarity recommender with a pairwise (ranking) objective.
 *
 * <p>A user's score for an item is the item's bias plus a normalised sum of dot products
 * between the factors of the user's other rated items ({@code P}) and the item's target
 * factors ({@code Q}). Training runs SGD over pairs of (rated item i, sampled unrated item j)
 * on the squared error {@code (r_ui - r_uj) - (p_ui - p_uj)}, with {@code r_uj} fixed to 0.
 *
 * <p>Judging by its name and the "FISM" option key, this appears to follow the FISMauc
 * algorithm of Kabbur et al. (KDD 2013); confirm against the paper before changing the
 * update rules.
 */
@Configuration(value="binThold, rho, alpha, factors, lRate, maxLRate, regI, regB, iters")
public class FISMauc
extends IterativeRecommender {

    /** Number of unrated items sampled per (user, rated item) pair; read from option "-rho". */
    private int rho;

    /** Exponent of the neighbourhood-size normalisation {@code n^-alpha}; read from option "-alpha". */
    private float alpha;

    /**
     * Creates the recommender for one train/test fold and marks it as a ranking predictor.
     *
     * @param trainMatrix training user-item matrix
     * @param testMatrix  test user-item matrix
     * @param fold        fold index passed through to the base class
     */
    public FISMauc(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
    }

    /**
     * Allocates and initialises the model state and reads the algorithm options.
     *
     * <p>The state is two {@code numItems x numFactors} factor matrices ({@code P} for the
     * neighbour items, {@code Q} for the target item) and the item bias vector. The options
     * {@code -rho} and {@code -alpha} are read from the "FISM" option group. Each training
     * row's column list is also cached, because {@link #predict(int, int)} reads it.
     */
    @Override
    protected void initModel() throws Exception {
        P = new DenseMatrix(numItems, numFactors);
        Q = new DenseMatrix(numItems, numFactors);
        P.init(smallValue);
        Q.init(smallValue);
        itemBias = new DenseVector(numItems);
        itemBias.init(smallValue);
        algoOptions = cf.getParamOptions("FISM");
        rho = algoOptions.getInt("-rho");
        alpha = algoOptions.getFloat("-alpha");
        userItemsCache = trainMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIters} epochs of SGD, stopping early when {@code isConverged} says so.
     *
     * <p>For every user u and rated item i:
     * <ol>
     *   <li>Sample up to {@code rho} distinct items j that u has not rated.</li>
     *   <li>For each j, update the biases of i and j and the {@code Q} rows of i and j.</li>
     *   <li>Update the {@code P} rows of u's other rated items, using the errors accumulated
     *       over the sampled j.</li>
     * </ol>
     * {@code loss} accumulates half of the squared pair errors plus the L2 penalties of the
     * parameters touched.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            loss = 0.0;

            for (int u : trainMatrix.rows()) {
                SparseVector Ru = trainMatrix.row(u);
                int[] ratedItems = Ru.getIndex();
                int nu = Ru.getCount();

                // Per-user normalisation weights. A rated item i excludes itself from its
                // neighbourhood, giving (n_u - 1)^-alpha. An unrated item j is scored against
                // all n_u rated items, giving n_u^-alpha.
                double wu = nu - 1 > 0 ? Math.pow(nu - 1, -alpha) : 0.0;
                double wuj = Math.pow(nu, -alpha);

                // Cap the sample size at the number of unrated items. Otherwise the rejection
                // sampler could never finish for a user who has rated (almost) every item.
                int sampleSize = Math.min(rho, numItems - nu);

                for (VectorEntry ve : Ru) {
                    int i = ve.index();
                    double rui = ve.get();

                    List<Integer> js = sampleUnratedItems(Ru, sampleSize);

                    // P is only modified after the j loop below, so the per-factor neighbour
                    // sums are constant for this (u, i) and can be computed once.
                    double[] sumP = sumFactorsExcluding(ratedItems, i);

                    // x accumulates e_ij * (q_i - q_j) over the sampled j for the P update.
                    double[] x = new double[numFactors];

                    for (int j : js) {
                        double sum_i = 0.0;
                        double sum_j = 0.0;
                        for (int k : ratedItems) {
                            // j is unrated, so only i has to leave its own neighbourhood.
                            if (k != i) {
                                sum_i += DenseMatrix.rowMult(P, k, Q, i);
                            }
                            sum_j += DenseMatrix.rowMult(P, k, Q, j);
                        }

                        double bi = itemBias.get(i);
                        double bj = itemBias.get(j);
                        double pui = bi + wu * sum_i;
                        double puj = bj + wuj * sum_j;
                        double ruj = 0.0; // unrated items are treated as a target of 0
                        double eij = rui - ruj - (pui - puj);
                        loss += eij * eij;

                        // NOTE(review): for b_j and q_j the whole (e - reg * param) term is
                        // subtracted, so their regularisation term pushes away from zero.
                        // The q_j step also uses the i-neighbourhood sum, not the one used
                        // in puj. This looks like it mirrors the published update rules;
                        // confirm before changing.
                        itemBias.add(i, lRate * (eij - regB * bi));
                        itemBias.add(j, -lRate * (eij - regB * bj));
                        // L2 penalties are non-negative and must be added for both items.
                        loss += regB * bi * bi + regB * bj * bj;

                        for (int f = 0; f < numFactors; f++) {
                            double qif = Q.get(i, f);
                            double qjf = Q.get(j, f);
                            double grad = eij * wu * sumP[f];

                            Q.add(i, f, lRate * (grad - regI * qif));
                            Q.add(j, f, -lRate * (grad - regI * qjf));

                            x[f] += eij * (qif - qjf);
                            loss += regI * qif * qif + regI * qjf * qjf;
                        }
                    }

                    // Average over the items actually sampled. With an empty sample, x is all
                    // zeros and the step is pure weight decay, instead of 0/0 = NaN.
                    int sampled = js.size();
                    for (int k : ratedItems) {
                        if (k == i) {
                            continue;
                        }
                        for (int f = 0; f < numFactors; f++) {
                            double pkf = P.get(k, f);
                            double pull = sampled > 0 ? wu * x[f] / sampled : 0.0;
                            P.add(k, f, lRate * (pull - regI * pkf));
                            loss += regI * pkf * pkf;
                        }
                    }
                }
            }

            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Draws {@code size} distinct items uniformly at random from those not present in
     * {@code Ru}, using rejection sampling.
     *
     * <p>The caller must ensure {@code size} does not exceed the number of such items;
     * otherwise this loop cannot terminate. A non-positive {@code size} yields an empty list.
     */
    private List<Integer> sampleUnratedItems(SparseVector Ru, int size) {
        List<Integer> sampled = new ArrayList<Integer>(Math.max(size, 0));
        while (sampled.size() < size) {
            int j = Randoms.uniform(numItems);
            if (!Ru.contains(j) && !sampled.contains(j)) {
                sampled.add(j);
            }
        }
        return sampled;
    }

    /**
     * Returns, for each factor f, the sum of {@code P(k, f)} over every k in {@code items}
     * except {@code excluded}. Items are summed in the order they appear in {@code items}.
     */
    private double[] sumFactorsExcluding(int[] items, int excluded) {
        double[] sums = new double[numFactors];
        for (int k : items) {
            if (k == excluded) {
                continue;
            }
            for (int f = 0; f < numFactors; f++) {
                sums[f] += P.get(k, f);
            }
        }
        return sums;
    }

    /**
     * Scores item {@code i} for user {@code u}.
     *
     * <p>The score is the bias of i plus {@code count^-alpha} times the sum of
     * {@code P_j . Q_i} over the user's cached training items j other than i. Here
     * {@code count} is the number of such items, and the weight is 0 when there are none.
     * This matches the weights used in training: {@code (n_u - 1)^-alpha} for a rated i
     * and {@code n_u^-alpha} for an unrated one.
     *
     * @param u user index
     * @param i item index
     * @return the predicted ranking score
     */
    @Override
    protected double predict(int u, int i) throws Exception {
        @SuppressWarnings("unchecked")
        List<Integer> items = (List<Integer>) userItemsCache.get(u);

        double sum = 0.0;
        int count = 0;
        for (int j : items) {
            if (j == i) {
                continue;
            }
            sum += DenseMatrix.rowMult(P, j, Q, i);
            count++;
        }
        double wu = count > 0 ? Math.pow(count, -alpha) : 0.0;
        return itemBias.get(i) + wu * sum;
    }

    /** Renders the configured hyper-parameters, in the order of the {@code @Configuration} list. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), rho, Float.valueOf(alpha), numFactors,
                Float.valueOf(initLRate), Float.valueOf(maxLRate), Float.valueOf(regI), Float.valueOf(regB), numIters});
    }
}

