/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.transformation.prefiltering;

import carskit.data.structure.SparseMatrix;
import carskit.eval.Measures;
import carskit.generic.Recommender;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import happy.coding.io.FileIO;
import happy.coding.io.Lists;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import happy.coding.math.Stats;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.SymmMatrix;

/**
 * Context-aware user-based kNN recommender using exact contextual pre-filtering.
 *
 * <p>Rating prediction: neighbours of the target user contribute their rating of the target item
 * <em>in the target context</em> (looked up in the user-item x context training matrix), weighted
 * by user-user similarity and centred on each user's mean rating. Similarities and means are built
 * once in {@link #initModel()}.
 *
 * <p>Ranking evaluation: for every test context, a user-item matrix containing only that context's
 * training ratings is extracted, and similarities and user means are rebuilt on it before the
 * users of that context are evaluated. Each measure is first averaged over the users of a context,
 * then over contexts.
 */
public class ExactFiltering
extends Recommender {
    /** Ranking measures reported at cut-offs 5, 10 and N for every evaluated user. */
    private static final Recommender.Measure[] ACCURACY_MEASURES = {
        Recommender.Measure.Pre5, Recommender.Measure.Pre10, Recommender.Measure.PreN,
        Recommender.Measure.Rec5, Recommender.Measure.Rec10, Recommender.Measure.RecN,
        Recommender.Measure.AUC5, Recommender.Measure.AUC10, Recommender.Measure.AUCN,
        Recommender.Measure.MAP5, Recommender.Measure.MAP10, Recommender.Measure.MAPN,
        Recommender.Measure.NDCG5, Recommender.Measure.NDCG10, Recommender.Measure.NDCGN,
        Recommender.Measure.MRR5, Recommender.Measure.MRR10, Recommender.Measure.MRRN
    };

    /** Diversity measures; only computed when {@code isDiverseUsed} is set, otherwise reported as 0. */
    private static final Recommender.Measure[] DIVERSITY_MEASURES = {
        Recommender.Measure.D5, Recommender.Measure.D10, Recommender.Measure.DN
    };

    /** Number of buffered output lines after which they are appended to the results file. */
    private static final int FLUSH_THRESHOLD = 1000;

    /** User-user similarities; rebuilt for each context during ranking evaluation. */
    private SymmMatrix userCorrs;
    /** Mean rating of each user (global mean for users without ratings); rebuilt per context when ranking. */
    private DenseVector userMeans;
    /** User-item matrix of the context currently being evaluated (ranking mode only). */
    private SparseMatrix sm;

    public ExactFiltering(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "ExactFiltering";
    }

    /**
     * Builds the context-independent model used for rating prediction. In ranking mode nothing is
     * built here; the model is rebuilt per context inside {@link #evalRankings()}.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        if (!isRankingPred) {
            this.userCorrs = this.buildCorrs(true);
            this.userMeans = new DenseVector(this.numUsers);
            for (int u = 0; u < this.numUsers; ++u) {
                SparseVector uv = this.train.row(u);
                this.userMeans.set(u, uv.getCount() > 0 ? uv.mean() : this.globalMean);
            }
        }
    }

    /**
     * Predicts the score of user {@code a} for item {@code t} in context {@code c}.
     *
     * <p>Candidates are users with a positive similarity to {@code a} who have a positive rating for
     * {@code t}: in ranking mode the rating comes from the current context's user-item matrix
     * (which must have been prepared by {@link #evalRankings()}), otherwise from the training
     * matrix entry of (user, t) in context {@code c}. The {@code knn} most similar candidates are
     * then combined by {@link #aggregateNeighbours}.
     *
     * @return the prediction, or the mean rating of {@code a} when no usable neighbour exists or
     *         the prediction is not positive
     */
    @Override
    protected double predict(int a, int t, int c) throws Exception {
        Map<Integer, Double> nns = new HashMap<Integer, Double>();
        Map<Integer, Double> nnsRatings = new HashMap<Integer, Double>();
        if (isRankingPred) {
            SparseVector sv = this.userCorrs.row(a);
            // Collect every candidate; the top-knn selection by similarity happens afterwards.
            for (int u : sv.getIndex()) {
                double sim = sv.get(u);
                if (!(sim > 0.0)) {
                    continue;
                }
                double rate = this.sm.get(u, t);
                if (!(rate > 0.0)) {
                    continue;
                }
                nns.put(u, sim);
                nnsRatings.put(u, rate);
            }
        } else {
            for (int u : rateDao.getUserList(this.trainMatrix)) {
                if (u == a) {
                    continue;
                }
                int ui = rateDao.getUserItemId(u + "," + t);
                if (ui == -1) {
                    continue;
                }
                double rate = this.trainMatrix.get(ui, c);
                if (!(rate > 0.0)) {
                    continue;
                }
                double sim = this.userCorrs.get(a, u);
                if (!(sim > 0.0)) {
                    continue;
                }
                nns.put(u, sim);
                nnsRatings.put(u, rate);
            }
        }
        return this.aggregateNeighbours(a, nns, nnsRatings);
    }

    /**
     * Combines the {@code knn} most similar neighbours into a mean-centred, similarity-weighted
     * prediction: {@code mean(a) + sum(sim * (r - mean(n))) / sum(sim)}.
     *
     * @param a          target user
     * @param sims       candidate neighbour -> (positive) similarity to {@code a}
     * @param nnsRatings candidate neighbour -> rating of the target item
     * @return the prediction, or the mean of {@code a} when it cannot be computed or is not positive
     */
    private double aggregateNeighbours(int a, Map<Integer, Double> sims, Map<Integer, Double> nnsRatings) {
        double pred = 0.0;
        if (!sims.isEmpty()) {
            List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(sims, true);
            int k = Math.min(sorted.size(), knn);
            double sum1 = 0.0;
            double sum2 = 0.0;
            for (Map.Entry<Integer, Double> en : sorted.subList(0, k)) {
                int ngbr = en.getKey();
                double sim = en.getValue();
                sum2 += sim;
                sum1 += (nnsRatings.get(ngbr) - this.userMeans.get(ngbr)) * sim;
            }
            // With k == 0 this is 0/0 = NaN, which falls through to the mean below.
            pred = this.userMeans.get(a) + sum1 / sum2;
        }
        return pred > 0.0 ? pred : this.userMeans.get(a);
    }

    /**
     * Extracts the user-item matrix holding only the training ratings given in context {@code ctx}.
     *
     * @param ctx inner context id, i.e. a column of the user-item x context training matrix
     * @return a {@code numUsers x numItems} matrix of that context's ratings
     */
    protected SparseMatrix getUIMatrix(int ctx) {
        Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();
        SparseVector ctxRatings = this.trainMatrix.column(ctx);
        for (int ui : ctxRatings.getIndex()) {
            int u = rateDao.getUserIdFromUI(ui);
            int j = rateDao.getItemIdFromUI(ui);
            dataTable.put(u, j, ctxRatings.get(ui));
            colMap.put(j, u);
        }
        return new SparseMatrix(this.numUsers, this.numItems, dataTable, colMap);
    }

    /**
     * Evaluates top-N recommendation per context. For each test context the per-context model is
     * rebuilt, every user of that context with at least one relevant candidate item is scored, and
     * each measure is averaged first over the context's users and then over contexts.
     *
     * <p>When results output is enabled, the top {@code numTopNRanks} recommendations per
     * (user, context) are appended to a file in {@code workingPath}; correct items are marked
     * with {@code *}.
     */
    @Override
    protected Map<Recommender.Measure, Double> evalRankings() throws Exception {
        Map<Integer, HashMultimap<Integer, Integer>> cuiList = rateDao.getCtxUserList(this.testMatrix, binThold);
        Map<Integer, HashMultimap<Integer, Integer>> cuiListTrain = rateDao.getCtxUserList(this.trainMatrix);
        // One entry per context for each measure; the final score is their mean.
        Map<Recommender.Measure, List<Double>> ctxAverages = newMeasureLists(cuiList.size());

        // Defensive copy: the set may be pruned below and must not affect rateDao's own data.
        Set<Integer> candItems = new HashSet<Integer>(rateDao.getItemList(this.trainMatrix));
        int numTopNRanks = numRecs < 0 ? 10 : numRecs;
        List<Integer> cutoffs = Arrays.asList(5, 10, numRecs);

        List<String> preds = null;
        String toFile = null;
        if (this.isResultsOut) {
            preds = new ArrayList<String>(1500);
            preds.add("# userId: recommendations in (itemId, ranking score) pairs, where a correct recommendation is denoted by symbol *.");
            toFile = workingPath + String.format("%s-top-%d-items%s.txt", this.algoName, numTopNRanks, this.foldInfo);
            FileIO.deleteFile(toFile);
        }
        if (verbose) {
            Logs.debug("{}{} has candidate items: {}", this.algoName, this.foldInfo, candItems.size());
        }
        if (numIgnore > 0) {
            this.removeMostPopularItems(candItems);
        }

        for (int ctx : cuiList.keySet()) {
            HashMultimap<Integer, Integer> uis = cuiList.get(ctx);
            HashMultimap<Integer, Integer> trainUis = cuiListTrain.get(ctx);
            if (trainUis == null) {
                trainUis = HashMultimap.create();
            }
            int numCtxUsers = uis.keySet().size();
            // One entry per evaluated user of this context for each measure.
            Map<Recommender.Measure, List<Double>> userScores = newMeasureLists(numCtxUsers);
            this.buildContextModel(ctx);

            int processed = 0;
            for (int u : uis.keySet()) {
                ++processed;
                if (verbose && processed % 100 == 0) {
                    Logs.debug("{}{} evaluates progress: {} / {}", this.algoName, this.foldInfo, processed, numCtxUsers);
                }
                Set<Integer> posItems = uis.get(u);
                List<Integer> correctItems = new ArrayList<Integer>();
                for (Integer j : posItems) {
                    if (candItems.contains(j)) {
                        correctItems.add(j);
                    }
                }
                if (correctItems.isEmpty()) {
                    continue;
                }

                // Score unrated candidates; items already rated in this context are not candidates.
                Set<Integer> ratedItems = trainUis.get(u);
                int numCands = candItems.size();
                List<Map.Entry<Integer, Double>> itemScores =
                        new ArrayList<Map.Entry<Integer, Double>>(Lists.initSize(candItems));
                for (Integer j : candItems) {
                    if (ratedItems.contains(j)) {
                        --numCands;
                        continue;
                    }
                    double rank = this.ranking(u, j, ctx);
                    if (!Double.isNaN(rank) && rank > binThold) {
                        itemScores.add(new AbstractMap.SimpleImmutableEntry<Integer, Double>(j, rank));
                    }
                }
                if (itemScores.isEmpty()) {
                    continue;
                }
                Lists.sortList(itemScores, true);
                List<Map.Entry<Integer, Double>> recomd = numRecs <= 0 || itemScores.size() <= numRecs
                        ? itemScores
                        : itemScores.subList(0, numRecs);

                // Metrics always use the full recommendation list, independent of results output.
                List<Integer> rankedItems = new ArrayList<Integer>(recomd.size());
                for (Map.Entry<Integer, Double> entry : recomd) {
                    rankedItems.add(entry.getKey());
                }
                int numDropped = numCands - rankedItems.size();

                addAtCutoffs(userScores, Measures.PrecAt(rankedItems, correctItems, cutoffs), numRecs,
                        Recommender.Measure.Pre5, Recommender.Measure.Pre10, Recommender.Measure.PreN);
                addAtCutoffs(userScores, Measures.RecallAt(rankedItems, correctItems, cutoffs), numRecs,
                        Recommender.Measure.Rec5, Recommender.Measure.Rec10, Recommender.Measure.RecN);
                addAtCutoffs(userScores, Measures.AUCAt(rankedItems, correctItems, numDropped, cutoffs), numRecs,
                        Recommender.Measure.AUC5, Recommender.Measure.AUC10, Recommender.Measure.AUCN);
                addAtCutoffs(userScores, Measures.APAt(rankedItems, correctItems, cutoffs), numRecs,
                        Recommender.Measure.MAP5, Recommender.Measure.MAP10, Recommender.Measure.MAPN);
                addAtCutoffs(userScores, Measures.nDCGAt(rankedItems, correctItems, cutoffs), numRecs,
                        Recommender.Measure.NDCG5, Recommender.Measure.NDCG10, Recommender.Measure.NDCGN);
                addAtCutoffs(userScores, Measures.RRAt(rankedItems, correctItems, cutoffs), numRecs,
                        Recommender.Measure.MRR5, Recommender.Measure.MRR10, Recommender.Measure.MRRN);
                if (isDiverseUsed) {
                    userScores.get(Recommender.Measure.D5).add(this.diverseAt(rankedItems, 5));
                    userScores.get(Recommender.Measure.D10).add(this.diverseAt(rankedItems, 10));
                    userScores.get(Recommender.Measure.DN).add(this.diverseAt(rankedItems, numRecs));
                }

                if (this.isResultsOut) {
                    preds.add(rateDao.getUserId(u) + ", " + rateDao.getContextSituationFromInnerId(ctx) + ": "
                            + this.formatTopN(recomd, posItems, numTopNRanks));
                    if (preds.size() >= FLUSH_THRESHOLD) {
                        FileIO.writeList(toFile, preds, true);
                        preds.clear();
                    }
                }
            }

            for (Recommender.Measure m : ACCURACY_MEASURES) {
                ctxAverages.get(m).add(Stats.mean(userScores.get(m)));
            }
            if (isDiverseUsed) {
                for (Recommender.Measure m : DIVERSITY_MEASURES) {
                    ctxAverages.get(m).add(Stats.mean(userScores.get(m)));
                }
            }
        }

        if (this.isResultsOut && !preds.isEmpty()) {
            FileIO.writeList(toFile, preds, true);
            Logs.debug("{}{} has written item recommendations to {}", this.algoName, this.foldInfo, toFile);
        }

        Map<Recommender.Measure, Double> measures = new HashMap<Recommender.Measure, Double>();
        for (Recommender.Measure m : ACCURACY_MEASURES) {
            measures.put(m, Stats.mean(ctxAverages.get(m)));
        }
        for (Recommender.Measure m : DIVERSITY_MEASURES) {
            measures.put(m, isDiverseUsed ? Stats.mean(ctxAverages.get(m)) : 0.0);
        }
        return measures;
    }

    /**
     * Rebuilds {@code sm}, {@code userCorrs} and {@code userMeans} from the training ratings of
     * context {@code ctx}.
     */
    private void buildContextModel(int ctx) throws Exception {
        // Drop the previous context's model first so it need not stay referenced while the next is built.
        this.sm = null;
        this.userCorrs = null;
        this.userMeans = null;
        SparseMatrix uim = this.getUIMatrix(ctx);
        this.sm = uim;
        this.userCorrs = this.buildCorrs(true, uim);
        this.userMeans = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector uv = uim.row(u);
            this.userMeans.set(u, uv.getCount() > 0 ? uv.mean() : this.globalMean);
        }
    }

    /** Removes the {@code numIgnore} items with the most training ratings from {@code candItems}. */
    private void removeMostPopularItems(Set<Integer> candItems) {
        List<Map.Entry<Integer, Integer>> itemDegs = new ArrayList<Map.Entry<Integer, Integer>>(candItems.size());
        for (Integer j : candItems) {
            itemDegs.add(new AbstractMap.SimpleImmutableEntry<Integer, Integer>(j,
                    rateDao.getRatingCountByItem(this.trainMatrix, j)));
        }
        Lists.sortList(itemDegs, true);
        int numRemoved = Math.min(numIgnore, itemDegs.size());
        for (Map.Entry<Integer, Integer> entry : itemDegs.subList(0, numRemoved)) {
            candItems.remove(entry.getKey());
        }
    }

    /**
     * Formats at most {@code limit} recommendations as {@code (itemId[*], score), ...}, where
     * {@code *} marks items contained in {@code posItems}.
     */
    private String formatTopN(List<Map.Entry<Integer, Double>> recomd, Collection<Integer> posItems, int limit) {
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (Map.Entry<Integer, Double> entry : recomd) {
            if (count >= limit) {
                break;
            }
            if (count > 0) {
                sb.append(", ");
            }
            Integer item = entry.getKey();
            sb.append("(").append(rateDao.getItemId(item));
            if (posItems.contains(item)) {
                sb.append("*");
            }
            sb.append(", ").append(entry.getValue().floatValue()).append(")");
            ++count;
        }
        return sb.toString();
    }

    /** Creates an empty value list for every ranking and diversity measure. */
    private static Map<Recommender.Measure, List<Double>> newMeasureLists(int expectedSize) {
        Map<Recommender.Measure, List<Double>> lists = new HashMap<Recommender.Measure, List<Double>>();
        for (Recommender.Measure m : ACCURACY_MEASURES) {
            lists.put(m, new ArrayList<Double>(expectedSize));
        }
        for (Recommender.Measure m : DIVERSITY_MEASURES) {
            lists.put(m, new ArrayList<Double>(expectedSize));
        }
        return lists;
    }

    /**
     * Appends the values at cut-offs 5, 10 and {@code n} of one metric to the lists of the
     * corresponding measures.
     */
    private static void addAtCutoffs(Map<Recommender.Measure, List<Double>> lists, Map<Integer, Double> byCutoff,
                                     int n, Recommender.Measure at5, Recommender.Measure at10,
                                     Recommender.Measure atN) {
        lists.get(at5).add(byCutoff.get(5));
        lists.get(at10).add(byCutoff.get(10));
        lists.get(atN).add(byCutoff.get(n));
    }

    @Override
    public String toString() {
        return Strings.toString(new Object[]{knn, similarityMeasure, similarityShrinkage});
    }
}

