/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.Table;
import java.util.Iterator;
import java.util.List;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.IterativeRecommender;
import librec.util.Randoms;
import librec.util.Strings;

@Configuration(value="binThold, rho, alpha, factors, lRate, maxLRate, regI, regB, iters")
public class FISMrmse
extends IterativeRecommender {
    private float rho;
    private float alpha;
    private int nnz;

    public FISMrmse(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
    }

    @Override
    protected void initModel() throws Exception {
        this.P = new DenseMatrix(numItems, numFactors);
        this.Q = new DenseMatrix(numItems, numFactors);
        this.P.init(0.01);
        this.Q.init(0.01);
        this.userBias = new DenseVector(numUsers);
        this.itemBias = new DenseVector(numItems);
        this.userBias.init(0.01);
        this.itemBias.init(0.01);
        this.nnz = this.trainMatrix.size();
        algoOptions = cf.getParamOptions("FISM");
        this.rho = algoOptions.getFloat("-rho");
        this.alpha = algoOptions.getFloat("-alpha");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
    }

    @Override
    protected void buildModel() throws Exception {
        int sampleSize = (int)(this.rho * (float)this.nnz);
        int totalSize = numUsers * numItems;
        int iter = 1;
        while (iter <= numIters) {
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(numItems, numFactors);
            DenseMatrix QS = new DenseMatrix(numItems, numFactors);
            Table<Integer, Integer, Double> R = this.trainMatrix.getDataTable();
            List<Integer> indices = Randoms.randInts(sampleSize, 0, totalSize - this.nnz);
            int index = 0;
            int count = 0;
            boolean isDone = false;
            int u = 0;
            while (u < numUsers) {
                int j = 0;
                while (j < numItems) {
                    double ruj = this.trainMatrix.get(u, j);
                    if (ruj == 0.0 && count++ == indices.get(index)) {
                        R.put(u, j, 0.0);
                        if (++index >= indices.size()) {
                            isDone = true;
                            break;
                        }
                    }
                    ++j;
                }
                if (isDone) break;
                ++u;
            }
            for (Table.Cell<Integer, Integer, Double> cell : R.cellSet()) {
                int u2 = cell.getRowKey();
                int j = cell.getColumnKey();
                double ruj = cell.getValue();
                SparseVector Ru = this.trainMatrix.row(u2);
                double bu = this.userBias.get(u2);
                double bj = this.itemBias.get(j);
                double sum_ij = 0.0;
                int cnt = 0;
                for (VectorEntry ve : Ru) {
                    int i = ve.index();
                    if (i == j) continue;
                    sum_ij += DenseMatrix.rowMult(this.P, i, this.Q, j);
                    ++cnt;
                }
                double wu = cnt > 0 ? Math.pow(cnt, -this.alpha) : 0.0;
                double puj = bu + bj + wu * sum_ij;
                double euj = puj - ruj;
                this.loss += euj * euj;
                this.userBias.add(u2, -this.lRate * (euj + (double)regB * bu));
                this.itemBias.add(j, -this.lRate * (euj + (double)regB * bj));
                this.loss += (double)regB * bu * bu + (double)regB * bj * bj;
                int f = 0;
                while (f < numFactors) {
                    double qjf = this.Q.get(j, f);
                    double sum_i = 0.0;
                    for (VectorEntry ve : Ru) {
                        int i = ve.index();
                        if (i == j) continue;
                        sum_i += this.P.get(i, f);
                    }
                    double delta = euj * wu * sum_i + (double)regI * qjf;
                    QS.add(j, f, -this.lRate * delta);
                    this.loss += (double)regI * qjf * qjf;
                    ++f;
                }
                for (VectorEntry ve : Ru) {
                    int i = ve.index();
                    if (i == j) continue;
                    int f2 = 0;
                    while (f2 < numFactors) {
                        double pif = this.P.get(i, f2);
                        double delta = euj * wu * this.Q.get(j, f2) + (double)regI * pif;
                        PS.add(i, f2, -this.lRate * delta);
                        this.loss += (double)regI * pif * pif;
                        ++f2;
                    }
                }
            }
            this.P = this.P.add(PS);
            this.Q = this.Q.add(QS);
            this.loss *= 0.5;
            if (this.isConverged(iter)) break;
            ++iter;
        }
    }

    @Override
    protected double predict(int u, int j) throws Exception {
        double pred = this.userBias.get(u) + this.itemBias.get(j);
        double sum = 0.0;
        int count = 0;
        List items = (List)this.userItemsCache.get(u);
        Iterator iterator = items.iterator();
        while (iterator.hasNext()) {
            int i = (Integer)iterator.next();
            if (i == j) continue;
            sum += DenseMatrix.rowMult(this.P, i, this.Q, j);
            ++count;
        }
        double wu = count > 0 ? Math.pow(count, -this.alpha) : 0.0;
        return pred + wu * sum;
    }

    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), Float.valueOf(this.rho), Float.valueOf(this.alpha), numFactors, Float.valueOf(initLRate), Float.valueOf(maxLRate), Float.valueOf(regI), Float.valueOf(regB), numIters}, ",");
    }
}

