/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.cf;

import carskit.data.structure.SparseMatrix;
import carskit.generic.Recommender;
import happy.coding.io.Lists;
import happy.coding.io.Strings;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.SymmMatrix;

/**
 * User-based k-nearest-neighbour collaborative filtering.
 *
 * <p>A rating for (user, item) is predicted from the ratings that similar users gave to the same
 * item. Each neighbour's rating is centred on that neighbour's mean rating and weighted by its
 * similarity to the target user.
 */
public class UserKNN extends Recommender {
    /** User-user similarity matrix, built in {@link #initModel()}. */
    private SymmMatrix userCorrs;
    /** Mean training rating per user; {@code globalMean} for users with no training ratings. */
    private DenseVector userMeans;

    public UserKNN(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "UserKNN";
    }

    /**
     * Builds the user similarity matrix and each user's mean training rating.
     *
     * @throws Exception propagated from the superclass initialisation
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        // NOTE(review): `true` presumably selects user-wise (row) correlations - confirm in buildCorrs.
        this.userCorrs = this.buildCorrs(true);
        this.userMeans = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector uv = this.train.row(u);
            // Users without training ratings fall back to the global mean.
            this.userMeans.set(u, uv.getCount() > 0 ? uv.mean() : this.globalMean);
        }
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j} in context {@code c}.
     *
     * <p>When user (item) splitting is enabled and a split id exists for the (id, context) pair,
     * that split id replaces the original id; otherwise the original id is kept. The context is
     * not used beyond this id mapping.
     *
     * @param u user id
     * @param j item id
     * @param c context id
     * @return the predicted rating
     * @throws Exception propagated from {@link #predict(int, int)}
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        int user = u;
        if (this.isUserSplitting && this.userIdMapper.contains(u, c)) {
            user = (Integer) this.userIdMapper.get(u, c);
        }
        int item = j;
        if (this.isItemSplitting && this.itemIdMapper.contains(j, c)) {
            item = (Integer) this.itemIdMapper.get(j, c);
        }
        return this.predict(user, item);
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j} from the user's nearest neighbours.
     *
     * <p>The prediction is {@code mean(u) + sum(sim(u,v) * (r(v,j) - mean(v))) / sum(|sim(u,v)|)}
     * over the selected neighbours {@code v}. It is {@code globalMean} when no neighbour
     * qualifies or the weight sum is zero.
     *
     * @param u user id
     * @param j item id
     * @return the predicted rating
     * @throws Exception declared by the overridden method
     */
    @Override
    protected double predict(int u, int j) throws Exception {
        // Candidate neighbours: users that have a similarity entry with u and rated item j.
        Map<Integer, Double> nns = new HashMap<Integer, Double>();
        SparseVector dv = this.userCorrs.row(u);
        for (int v : dv.getIndex()) {
            double sim = dv.get(v);
            double rate = this.train.get(v, j);
            // A rating <= 0 is treated as "not rated". Ranking prediction also keeps neighbours
            // with non-positive similarity; rating prediction keeps only positive similarities.
            if (rate > 0.0 && (isRankingPred || sim > 0.0)) {
                nns.put(v, sim);
            }
        }

        // Truncate to the knn neighbours with the highest similarity (knn <= 0 means no limit).
        // NOTE(review): assumes Lists.sortMap(map, true) sorts entries by value, descending.
        if (knn > 0 && knn < nns.size()) {
            List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
            List<Map.Entry<Integer, Double>> topK = sorted.subList(0, knn);
            nns.clear();
            for (Map.Entry<Integer, Double> kv : topK) {
                nns.put(kv.getKey(), kv.getValue());
            }
        }

        if (nns.isEmpty()) {
            return this.globalMean;
        }

        // Mean-centred, similarity-weighted average of the neighbours' ratings.
        double sum = 0.0;
        double ws = 0.0;
        for (Map.Entry<Integer, Double> en : nns.entrySet()) {
            int v = en.getKey();
            double sim = en.getValue();
            double rate = this.train.get(v, j);
            sum += sim * (rate - this.userMeans.get(v));
            ws += Math.abs(sim);
        }
        return ws > 0.0 ? this.userMeans.get(u) + sum / ws : this.globalMean;
    }

    /** Returns the model's hyper-parameters (knn, similarity measure, shrinkage) as a string. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{knn, similarityMeasure, similarityShrinkage});
    }
}

