/*
 * Decompiled with CFR 0.152.
 */
package librec.ext;

import java.util.HashMap;
import java.util.Map;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.SymmMatrix;
import librec.data.VectorEntry;
import librec.ranking.RankSGD;
import librec.util.Lists;
import librec.util.Randoms;

/**
 * Pairwise ranking recommender that extends {@link RankSGD} with two per-item quantities:
 * a relative importance weight (the item's rating count divided by the largest rating count)
 * and an item-item correlation that is turned into a dissimilarity term
 * {@code sqrt(1 - corr(i, j))} scaling the target rating difference.
 *
 * <p>Negative items are drawn from the popularity distribution stored in {@code itemProbs}
 * (inherited from the superclass) and the latent factors {@code P} and {@code Q} are updated
 * by stochastic gradient steps on the weighted pairwise error.
 */
public class PRankD extends RankSGD {
    /** Relative importance of each item: rating count divided by the maximum rating count. */
    private DenseVector s;

    /** Item-item correlations as returned by {@code buildCorrs(false)}. */
    private SymmMatrix itemCorrs;

    /** Scaling factor applied to the raw similarity before {@code tanh}; read from option "-alpha". */
    private float alpha;

    /**
     * Creates the recommender and marks it as producing ranking predictions.
     *
     * @param trainMatrix training ratings, forwarded to the superclass
     * @param testMatrix  test ratings, forwarded to the superclass
     * @param fold        fold identifier, forwarded to the superclass
     */
    public PRankD(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
    }

    /**
     * Initialises the superclass model, then derives the item sampling distribution, the
     * per-item importance weights, the {@code alpha} option and the item correlation matrix.
     *
     * @throws Exception propagated from the superclass initialisation or the helpers it calls
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();

        HashMap<Integer, Double> itemProbsMap = new HashMap<Integer, Double>();
        double maxUsers = 0.0;
        s = new DenseVector(numItems);

        for (int j = 0; j < numItems; j++) {
            int users = trainMatrix.columnSize(j);
            if (maxUsers < users) {
                maxUsers = users;
            }
            s.set(j, users);

            // Sampling probability of item j: its share of all training ratings.
            double prob = (double) users / (double) numRates;
            if (prob > 0.0) {
                itemProbsMap.put(j, prob);
            }
        }
        itemProbs = Lists.sortMap(itemProbsMap);

        // Normalise counts into [0, 1]. Skipped when no item has any rating, because
        // dividing 0 by 0 would fill the importance vector with NaN.
        if (maxUsers > 0.0) {
            for (int j = 0; j < numItems; j++) {
                s.set(j, s.get(j) / maxUsers);
            }
        }

        alpha = algoOptions.getFloat("-alpha");
        itemCorrs = buildCorrs(false);
    }

    /**
     * Computes the transformed correlation between two item vectors: the "cos-binary"
     * similarity (NaN treated as 0) scaled by {@code alpha} and squashed with {@code tanh}.
     *
     * @param iv rating vector of the first item
     * @param jv rating vector of the second item
     * @return {@code tanh(alpha * sim)}
     */
    @Override
    protected double correlation(SparseVector iv, SparseVector jv) {
        double sim = correlation(iv, jv, "cos-binary");
        if (Double.isNaN(sim)) {
            sim = 0.0;
        }
        return Math.tanh((double) alpha * sim);
    }

    /**
     * Trains the latent factors. For every observed (user, item) rating an unrated item is
     * sampled, the weighted pairwise error is accumulated into {@code loss}, and the user and
     * item factors are moved along the negative gradient. Iteration stops after
     * {@code numIters} passes or as soon as {@code isConverged} reports convergence.
     *
     * @throws Exception propagated from prediction or the convergence check
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            loss = 0.0;

            for (int u : trainMatrix.rows()) {
                SparseVector Ru = trainMatrix.row(u);

                for (VectorEntry ve : Ru) {
                    int i = ve.index();
                    double rui = ve.get();

                    int j = sampleUnratedItem(Ru);
                    // The sampled item is unrated by u, so its rating is taken as 0.
                    double ruj = 0.0;

                    double pui = predict(u, i);
                    double puj = predict(u, j);
                    double dij = Math.sqrt(1.0 - itemCorrs.get(i, j));
                    double sj = s.get(j);

                    double e = sj * (pui - puj - dij * (rui - ruj));
                    loss += e * e;

                    double ye = lRate * e;
                    for (int f = 0; f < numFactors; f++) {
                        // Read all three factors before writing so every update uses old values.
                        double puf = P.get(u, f);
                        double qif = Q.get(i, f);
                        double qjf = Q.get(j, f);

                        P.add(u, f, -ye * (qif - qjf));
                        Q.add(i, f, -ye * puf);
                        Q.add(j, f, ye * puf);
                    }
                }
            }

            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Repeatedly draws items from the popularity distribution until one is found that the
     * given user vector does not contain.
     *
     * <p>NOTE(review): like the original loop, this does not terminate if the user has rated
     * every item present in {@code itemProbs}.
     *
     * @param ratedItems the user's rating vector
     * @return index of an item not contained in {@code ratedItems}
     */
    private int sampleUnratedItem(SparseVector ratedItems) {
        int candidate;
        do {
            candidate = drawItemByPopularity();
        } while (ratedItems.contains(candidate));
        return candidate;
    }

    /**
     * Draws one item by walking {@code itemProbs} until the cumulative probability reaches a
     * uniform random threshold.
     *
     * @return the drawn item index
     * @throws IllegalStateException if {@code itemProbs} holds no entries
     */
    private int drawItemByPopularity() {
        double threshold = Randoms.random();
        double cumulative = 0.0;
        int lastItem = -1;

        for (Map.Entry<?, ?> entry : itemProbs) {
            lastItem = (Integer) entry.getKey();
            cumulative += (Double) entry.getValue();
            if (cumulative >= threshold) {
                return lastItem;
            }
        }

        if (lastItem < 0) {
            throw new IllegalStateException("No items available for sampling: itemProbs is empty");
        }
        // Rounding can leave the accumulated total marginally below the threshold; the final
        // entry is the one exact arithmetic would have selected, so use it instead of
        // returning an invalid index.
        return lastItem;
    }

    /**
     * Returns the superclass description followed by a comma and the {@code alpha} value.
     */
    @Override
    public String toString() {
        return super.toString() + "," + alpha;
    }
}

