/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.ranking;

import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Lists;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.SparseVector;
import librec.data.VectorEntry;

/**
 * Pairwise ranking recommender trained with stochastic gradient descent.
 *
 * <p>For every observed (user, item, rating) entry of the training matrix a second item
 * that the user has not rated is sampled, with a probability proportional to the number
 * of entries in that item's column. The latent factors are then updated so that the
 * difference between the two predictions approaches the difference between the two
 * ratings (the sampled item is given a rating of 0).
 *
 * <p>NOTE(review): this appears to correspond to the RankSGD method of Jahrer and
 * Toescher (KDD Cup 2011) as implemented in LibRec; confirm against the upstream source.
 */
public class RankSGD
extends IterativeRecommender {
    /**
     * Sampling distribution over items: (item index, probability) pairs in the order
     * produced by {@code Lists.sortMap}. Only items with a strictly positive probability
     * are present. Populated by {@link #initModel()}.
     */
    protected List<Map.Entry<Integer, Double>> itemProbs;

    /**
     * Creates the recommender, flags it as a ranking predictor and runs the binary check
     * provided by the superclass.
     *
     * @param trainMatrix training data, forwarded to the superclass
     * @param testMatrix  test data, forwarded to the superclass
     * @param fold        fold number, forwarded to the superclass
     */
    public RankSGD(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        this.algoName = "RankSGD";
        this.checkBinary();
    }

    /**
     * Initialises the superclass model and builds {@link #itemProbs}: for each item the
     * size of its column in the training matrix divided by {@code numRates}.
     *
     * @throws Exception if the superclass initialisation fails
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        HashMap<Integer, Double> itemProbsMap = new HashMap<Integer, Double>();
        for (int j = 0; j < this.numItems; ++j) {
            int users = this.train.columnSize(j);
            double prob = (double) users / (double) this.numRates;
            // Items nobody rated (and NaN results) can never be drawn, so leave them out.
            if (prob > 0.0) {
                itemProbsMap.put(j, prob);
            }
        }
        this.itemProbs = Lists.sortMap(itemProbsMap);
    }

    /**
     * Runs up to {@code numIters} SGD passes over all training entries, stopping early
     * when {@code isConverged} reports convergence.
     *
     * <p>Users for whom no unrated item can be sampled are skipped; previously such a
     * user made the sampling loop spin forever.
     *
     * @throws Exception if prediction or the convergence check fails
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (int u : this.train.rows()) {
                SparseVector Ru = this.train.row(u);
                if (!this.hasUnratedCandidate(Ru)) {
                    // Nothing to contrast this user's items against.
                    continue;
                }
                for (VectorEntry ve : Ru) {
                    int i = ve.index();
                    double rui = ve.get();
                    int j = this.sampleUnratedItem(Ru);
                    // The sampled item is unrated by u, so its target rating is 0.
                    double ruj = 0.0;
                    double pui = this.predict(u, i);
                    double puj = this.predict(u, j);
                    double e = pui - puj - (rui - ruj);
                    this.loss += e * e;
                    double ye = this.lRate * e;
                    for (int f = 0; f < numFactors; ++f) {
                        // Read all three factors before writing so every update uses
                        // the values from before this step.
                        double puf = this.P.get(u, f);
                        double qif = this.Q.get(i, f);
                        double qjf = this.Q.get(j, f);
                        this.P.add(u, f, -ye * (qif - qjf));
                        this.Q.add(i, f, -ye * puf);
                        this.Q.add(j, f, ye * puf);
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Tells whether {@link #itemProbs} contains at least one item that is absent from
     * the given rating vector.
     */
    private boolean hasUnratedCandidate(SparseVector Ru) {
        for (Map.Entry<Integer, Double> en : this.itemProbs) {
            if (!Ru.contains(en.getKey())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Draws an item from {@link #itemProbs} by walking the cumulative distribution, and
     * repeats until the drawn item is not contained in {@code Ru}.
     *
     * <p>A draw whose random threshold exceeds the total accumulated probability (the
     * probabilities need not sum to exactly 1) selects nothing and is simply repeated,
     * so an invalid index is never returned.
     *
     * @param Ru rating vector of the current user; must leave at least one candidate
     *           unrated (see {@link #hasUnratedCandidate(SparseVector)})
     * @return index of an item from {@link #itemProbs} that is not in {@code Ru}
     */
    private int sampleUnratedItem(SparseVector Ru) {
        int j;
        do {
            j = -1;
            double sum = 0.0;
            double rand = Randoms.random();
            for (Map.Entry<Integer, Double> en : this.itemProbs) {
                sum += en.getValue();
                if (sum >= rand) {
                    j = en.getKey();
                    break;
                }
            }
        } while (j < 0 || Ru.contains(j));
        return j;
    }

    /**
     * Describes the configuration: {@code binThold}, {@code initLRate} and
     * {@code numIters}, formatted by {@code Strings.toString} with "," as separator.
     */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), Float.valueOf(initLRate), numIters}, ",");
    }
}

