/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.Configuration;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.SymmMatrix;
import librec.intf.Recommender;
import librec.util.Lists;
import librec.util.Stats;
import librec.util.Strings;

/**
 * User-based nearest-neighbour recommender.
 *
 * <p>A prediction for user {@code u} on item {@code j} is assembled from the
 * users that are correlated with {@code u} and that have a positive entry for
 * {@code j} in the training matrix.
 */
@Configuration(value="knn, similarity, shrinkage")
public class UserKNN
extends Recommender {
    /** Pairwise correlations, filled in by {@link #initModel()} via {@code buildCorrs(true)}. */
    private SymmMatrix userCorrs;
    /** Per-user mean training rating; the global mean is stored for users with no ratings. */
    private DenseVector userMeans;

    /**
     * Creates the recommender for one fold; all arguments are forwarded to
     * the {@link Recommender} constructor unchanged.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index
     */
    public UserKNN(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
    }

    /**
     * Builds the correlation matrix and caches each user's mean rating.
     *
     * @throws Exception declared to match the overridden method's signature
     */
    @Override
    protected void initModel() throws Exception {
        // NOTE(review): the `true` flag presumably selects user-user (rather
        // than item-item) correlations - confirm against Recommender.buildCorrs.
        this.userCorrs = this.buildCorrs(true);
        this.userMeans = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; ++u) {
            SparseVector ratings = this.trainMatrix.row(u);
            // A user without training ratings falls back to the global mean.
            double mean = ratings.getCount() > 0 ? ratings.mean() : this.globalMean;
            this.userMeans.set(u, mean);
        }
    }

    /**
     * Predicts a score for user {@code u} on item {@code j}.
     *
     * <p>In ranking mode the result is the sum of the selected neighbours'
     * similarities (0 when there are none). Otherwise it is the user's mean
     * plus the similarity-weighted average of the neighbours' mean-centred
     * ratings, or the global mean when no neighbour qualifies.
     *
     * @param u user index
     * @param j item index
     * @return the predicted score
     */
    @Override
    protected double predict(int u, int j) {
        Map<Integer, Double> nns = collectNeighbours(u, j);

        if (nns.isEmpty()) {
            return isRankingPred ? 0.0 : this.globalMean;
        }
        if (isRankingPred) {
            return Stats.sum(nns.values());
        }

        double sum = 0.0;
        double ws = 0.0;
        for (Map.Entry<Integer, Double> en : nns.entrySet()) {
            int v = en.getKey();
            double sim = en.getValue();
            double rate = this.trainMatrix.get(v, j);
            // Mean-centre the neighbour's rating before weighting it.
            sum += sim * (rate - this.userMeans.get(v));
            ws += Math.abs(sim);
        }
        return ws > 0.0 ? this.userMeans.get(u) + sum / ws : this.globalMean;
    }

    /**
     * Gathers the neighbours of {@code u} usable for item {@code j}, mapped to
     * their similarity with {@code u}, trimmed to at most {@code knn} entries
     * when {@code knn} is positive.
     */
    private Map<Integer, Double> collectNeighbours(int u, int j) {
        Map<Integer, Double> nns = new HashMap<Integer, Double>();
        SparseVector dv = this.userCorrs.row(u);
        for (int v : dv.getIndex()) {
            double sim = dv.get(v);
            double rate = this.trainMatrix.get(v, j);
            // A neighbour needs a positive training entry for j. Ranking mode
            // accepts any similarity (including non-positive); rating mode
            // requires a strictly positive one.
            if (rate > 0.0 && (isRankingPred || sim > 0.0)) {
                nns.put(v, sim);
            }
        }

        if (knn > 0 && knn < nns.size()) {
            // NOTE(review): sortMap(nns, true) presumably orders entries by
            // descending similarity, so the first knn are the closest - confirm
            // against librec.util.Lists.
            List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
            List<Map.Entry<Integer, Double>> subset = sorted.subList(0, knn);
            nns.clear();
            for (Map.Entry<Integer, Double> kv : subset) {
                nns.put(kv.getKey(), kv.getValue());
            }
        }
        return nns;
    }

    /** Renders the configured {@code knn}, similarity measure and shrinkage via {@code Strings.toString}. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{knn, similarityMeasure, similarityShrinkage});
    }
}

