/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.SymmMatrix;
import librec.data.VectorEntry;
import librec.intf.IterativeRecommender;
import librec.util.Lists;
import librec.util.Logs;
import librec.util.Strings;

@Configuration(value="binThold, knn, regL2, regL1, similarity, iters")
/**
 * SLIM item-based ranking recommender.
 * <p>
 * Learns an item-item weight matrix {@code W} by cyclic coordinate descent on a
 * least-squares objective with L1 and L2 penalties. The score of user {@code u}
 * for item {@code j} is {@code sum_k r_uk * W(k, j)} over the candidate items
 * {@code k} of {@code j}.
 * <p>
 * The candidate items depend on {@code knn}:
 * <ul>
 * <li>{@code knn > 0}: the {@code knn} most similar items, from the configured
 * similarity measure.</li>
 * <li>Otherwise: every column returned by {@code trainMatrix.columns()}.</li>
 * </ul>
 * The diagonal of {@code W} is kept at zero so an item never scores itself.
 */
public class SLIM
extends IterativeRecommender {
    /** Item-item weights; {@code W(k, j)} is the weight of item k when scoring item j. */
    private DenseMatrix W;
    /** Nearest-neighbour item ids per item; populated only when {@code knn > 0}. */
    private Multimap<Integer, Integer> itemNNs;
    /** Candidate item ids shared by every item; populated only when {@code knn <= 0}. */
    private List<Integer> allItems;
    /** L1 regularisation weight, read from the "-l1" algorithm option. */
    private float regL1;
    /** L2 regularisation weight, read from the "-l2" algorithm option. */
    private float regL2;

    public SLIM(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        this.regL1 = algoOptions.getFloat("-l1");
        this.regL2 = algoOptions.getFloat("-l2");
    }

    /**
     * Initialises {@code W}, the per-user rating cache and the candidate-item
     * structure. Then zeroes the diagonal of {@code W}.
     */
    @Override
    protected void initModel() throws Exception {
        this.W = new DenseMatrix(numItems, numItems);
        this.W.init();
        this.userCache = this.trainMatrix.rowCache(cacheSpec);

        if (knn > 0) {
            SymmMatrix itemCorrs = this.buildCorrs(false);
            this.itemNNs = HashMultimap.create();
            for (int j = 0; j < numItems; j++) {
                Map<Integer, Double> corrs = itemCorrs.row(j).toMap();
                if (knn < corrs.size()) {
                    // Keep only the knn entries that come first after sortMap(..., true).
                    List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(corrs, true);
                    for (Map.Entry<Integer, Double> kv : sorted.subList(0, knn)) {
                        this.itemNNs.put(j, kv.getKey());
                    }
                } else {
                    this.itemNNs.putAll(j, corrs.keySet());
                }
            }
        } else {
            this.allItems = this.trainMatrix.columns();
        }

        // An item must not explain itself: the diagonal of W stays zero.
        for (int j = 0; j < numItems; j++) {
            this.W.set(j, j, 0.0);
        }
    }

    /**
     * Runs up to {@code numIters} sweeps of coordinate descent over W.
     * Stops early when {@link #isConverged(int)} says so.
     */
    @Override
    protected void buildModel() throws Exception {
        this.last_loss = 0.0;
        for (int iter = 1; iter <= numIters; iter++) {
            this.loss = 0.0;
            for (int j = 0; j < numItems; j++) {
                for (int i : this.candidateItems(j)) {
                    // Without this skip, W(j, j) would be learned and leak r_uj into
                    // the residuals of every other coordinate of column j.
                    if (i == j) {
                        continue;
                    }
                    this.updateWeight(i, j);
                }
            }
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Performs one coordinate-descent step on {@code W(i, j)}.
     * Also adds this coordinate's term to {@code loss}.
     *
     * @param i the weight's row (the neighbour item)
     * @param j the weight's column (the item being reconstructed)
     */
    private void updateWeight(int i, int j) throws Exception {
        SparseVector Ri = this.trainMatrix.column(i);
        int N = Ri.getCount();
        if (N == 0) {
            // No ratings for item i. Dividing by N would produce NaN and poison
            // the loss, so just zero the weight.
            this.W.set(i, j, 0.0);
            return;
        }

        double gradSum = 0.0;
        double rateSum = 0.0;
        double errs = 0.0;
        for (VectorEntry ve : Ri) {
            int u = ve.index();
            double rui = ve.get();
            double ruj = this.trainMatrix.get(u, j);
            // Residual of r_uj, predicted without item i's own contribution.
            double euj = ruj - this.predict(u, j, i);
            gradSum += rui * euj;
            rateSum += rui * rui;
            errs += euj * euj;
        }
        gradSum /= N;
        rateSum /= N;
        errs /= N;

        double wij = this.W.get(i, j);
        // The L1 penalty uses |w|; a signed w would reward negative weights.
        this.loss += errs + 0.5 * this.regL2 * wij * wij + this.regL1 * Math.abs(wij);

        // Soft-thresholding update.
        // NOTE(review): the gradSum < 0 branch can produce negative weights.
        // This implementation does not enforce W >= 0.
        double update = 0.0;
        if (this.regL1 < Math.abs(gradSum)) {
            double shrunk = gradSum > 0.0 ? gradSum - this.regL1 : gradSum + this.regL1;
            update = shrunk / (this.regL2 + rateSum);
        }
        this.W.set(i, j, update);
    }

    /**
     * Returns the candidate items whose weights contribute to item {@code j}'s score.
     */
    private Collection<Integer> candidateItems(int j) {
        return knn > 0 ? this.itemNNs.get(j) : this.allItems;
    }

    /**
     * Scores item {@code j} for user {@code u}, optionally leaving one item out.
     *
     * @param u             user index
     * @param j             item index being scored
     * @param excluded_item candidate item to leave out of the sum, or -1 for none
     * @return {@code sum_k r_uk * W(k, j)} over the candidates {@code k} of {@code j}
     *         that user {@code u} has rated, excluding {@code excluded_item}
     */
    protected double predict(int u, int j, int excluded_item) throws Exception {
        SparseVector Ru = (SparseVector)this.userCache.get(u);
        double pred = 0.0;
        for (int k : this.candidateItems(j)) {
            if (k == excluded_item || !Ru.contains(k)) {
                continue;
            }
            pred += Ru.get(k) * this.W.get(k, j);
        }
        return pred;
    }

    @Override
    protected double predict(int u, int j) throws Exception {
        return this.predict(u, j, -1);
    }

    /**
     * Logs the loss (when verbose) and decides whether to stop.
     * Stops from the second iteration on once the loss decrease
     * {@code last_loss - loss} is below 1e-5. A loss increase (negative delta)
     * also counts as converged.
     */
    @Override
    protected boolean isConverged(int iter) {
        double delta_loss = this.last_loss - this.loss;
        this.last_loss = this.loss;
        if (verbose) {
            Logs.debug("{}{} iter {}: loss = {}, delta_loss = {}", this.algoName, this.foldInfo, iter, this.loss, delta_loss);
        }
        return iter > 1 && delta_loss < 1.0E-5;
    }

    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), knn, Float.valueOf(this.regL2), Float.valueOf(this.regL1), similarityMeasure, numIters});
    }
}

