/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.adaptation.dependent.sim;

import carskit.alg.cars.adaptation.dependent.CSLIM;
import carskit.data.setting.Configuration;
import carskit.data.structure.SparseMatrix;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import happy.coding.io.Lists;
import happy.coding.io.Logs;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.data.SymmMatrix;

/**
 * SLIM-style recommender that learns an item-item weight matrix {@code W} together with one
 * scalar position per context condition ({@code cVector_MCS}).
 *
 * <p>The contextual similarity between a target context and the context of a neighbour rating
 * is {@code 1 - dist}. Here {@code dist} is the Euclidean distance over the positions of the
 * i-th condition of each context, pairing conditions by index. A prediction for (u, j, c) is
 * {@code sum_k r(u,k,ctx_k) * W[k][j] * sim(c, ctx_k)}. The sum runs over neighbours {@code k}
 * of item {@code j} that user {@code u} has rated. {@code ctx_k} is one context sampled at
 * random from the contexts in which (u, k) was observed, so predictions are not deterministic.
 */
@Configuration(value="binThold, knn, regLw2, regLw1, similarity, iters, rc")
public class GCSLIM_MCS
extends CSLIM {
    /** Item-item weights; W[k][j] scales neighbour k's rating when scoring item j. */
    private DenseMatrix W;
    /** Per-item neighbour sets, built only when knn > 0. */
    private Multimap<Integer, Integer> itemNNs;
    /** All item columns of {@code train}; used as the neighbour set when knn <= 0. */
    private List<Integer> allItems;
    /** Upper bound for a condition position: 1 / sqrt(number of context dimensions). */
    private double upbound;
    /** Tiny positive constant (1e-100) used as a floor for distances and positions. */
    private double lowbound;
    /** Shared generator for sampling a neighbour rating's context (avoids per-call allocation). */
    private final Random contextRandom = new Random();

    /**
     * Creates the recommender and reads its options.
     *
     * <p>Reads {@code -lw1} and {@code -lw2} (regularisation of W) and {@code -k} (neighbourhood
     * size) from {@code algoOptions}.
     */
    public GCSLIM_MCS(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        this.isCARSRecommender = false;
        this.algoName = "GCSLIM_MCS";
        this.regLw1 = algoOptions.getFloat("-lw1");
        this.regLw2 = algoOptions.getFloat("-lw2");
        knn = algoOptions.getInt("-k");
    }

    /**
     * Initialises condition positions, the weight matrix, the user cache and the neighbourhoods.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        // If every position lies in [0, upbound] with upbound = 1/sqrt(D), each squared
        // per-dimension difference is at most 1/D, so a distance over D dimensions is <= 1.
        this.upbound = 1.0 / Math.sqrt(rateDao.numContextDims());
        this.lowbound = 1.0 / Math.pow(10.0, 100.0);
        this.cVector_MCS = new DenseVector(numConditions);
        this.cVector_MCS.init(this.upbound);
        this.W = new DenseMatrix(this.numItems, this.numItems);
        this.W.init();
        // Self-weights are zeroed; training also never uses k == j as a neighbour.
        for (int j = 0; j < this.numItems; ++j) {
            this.W.set(j, j, 0.0);
        }
        this.userCache = this.train.rowCache(cacheSpec);
        if (knn > 0) {
            SymmMatrix itemCorrs = this.buildCorrs(false);
            this.itemNNs = HashMultimap.create();
            for (int j = 0; j < this.numItems; ++j) {
                Map<Integer, Double> corrs = itemCorrs.row(j).toMap();
                if (knn < corrs.size()) {
                    // Keep only the first knn entries as ranked by Lists.sortMap(..., true).
                    List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(corrs, true);
                    for (Map.Entry<Integer, Double> kv : sorted.subList(0, knn)) {
                        this.itemNNs.put(j, kv.getKey());
                    }
                } else {
                    this.itemNNs.putAll(j, corrs.keySet());
                }
            }
        } else {
            this.allItems = this.train.columns();
        }
    }

    /**
     * Runs {@code numIters} epochs of per-rating stochastic updates over {@code trainMatrix}.
     *
     * <p>Each epoch updates condition positions and the W entries used by each prediction, and
     * accumulates the squared error plus W regularisation into {@code loss}.
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int ui = me.row();
                int u = rateDao.getUserIdFromUI(ui);
                int j = rateDao.getItemIdFromUI(ui);
                int c = me.column();
                double rujc = me.get();
                List<Integer> conditions = this.getConditions(c);

                // Position-gradient terms keyed by (target condition, neighbour condition), summed
                // over neighbours. Each neighbour's term is divided by that neighbour's own distance.
                HashBasedTable<Integer, Integer, Double> simGrads = HashBasedTable.create();
                // W-gradient terms keyed by (k, j).
                HashBasedTable<Integer, Integer, Double> weightGrads = HashBasedTable.create();

                Collection<Integer> nns = knn > 0 ? this.itemNNs.get(j) : this.allItems;
                SparseVector Ru = (SparseVector) this.userCache.get(u);
                double pred = 0.0;
                for (int k : nns) {
                    if (!Ru.contains(k) || k == j) {
                        continue;
                    }
                    int uiid = rateDao.getUserItemId(u + "," + k);
                    List<Integer> ctxIds = this.trainMatrix.getColumns(uiid);
                    if (ctxIds.isEmpty()) {
                        // No observed context to sample from; this neighbour contributes nothing.
                        continue;
                    }
                    int ctx = this.sampleObservedContext(ctxIds);
                    List<Integer> conditionsFrom = this.getConditions(ctx);
                    double ruk = this.trainMatrix.get(uiid, ctx);
                    double rating = ruk * this.W.get(k, j);

                    double dist = this.conditionDistanceMCS(conditions, conditionsFrom);
                    if (dist == 0.0) {
                        // Avoid division by zero in the gradient terms below.
                        dist = this.lowbound;
                    }
                    for (int i = 0; i < conditions.size(); ++i) {
                        int to = conditions.get(i);
                        int from = conditionsFrom.get(i);
                        if (to == from) {
                            continue;
                        }
                        double diff = this.cVector_MCS.get(to) - this.cVector_MCS.get(from);
                        double grad = rating * diff / dist;
                        Double acc = simGrads.get(to, from);
                        simGrads.put(to, from, acc == null ? grad : acc + grad);
                    }

                    double sim = 1.0 - dist;
                    pred += rating * sim;
                    weightGrads.put(k, j, ruk * sim);
                }

                double eujc = rujc - pred;
                this.loss += eujc * eujc;
                this.applyPositionUpdates(eujc, simGrads);
                this.applyWeightUpdates(eujc, weightGrads);
            }
        }
    }

    /**
     * Applies one update step to every condition pair in {@code simGrads}.
     *
     * <p>Positions are read fresh per pair, so pairs sharing a condition see earlier updates.
     * After each step, negative positions become {@code lowbound} and positions above
     * {@code upbound} become {@code upbound - lowbound}.
     */
    private void applyPositionUpdates(double eujc, HashBasedTable<Integer, Integer, Double> simGrads) {
        for (Integer row : simGrads.rowKeySet()) {
            int to = row;
            for (Map.Entry<Integer, Double> cell : simGrads.row(row).entrySet()) {
                int from = cell.getKey();
                double grad = cell.getValue();
                double posTo = this.cVector_MCS.get(to);
                double posFrom = this.cVector_MCS.get(from);
                // NOTE(review): with sim = 1 - dist, d(pred)/d(pos[to]) = -grad and
                // d(pred)/d(pos[from]) = +grad. A descent step on eujc^2 would therefore use the
                // opposite error-term signs to the ones below (the regularisation signs do match
                // descent). Kept as-is pending confirmation of the intended update rule.
                double newTo = posTo + this.lRate * (eujc * grad - (double) regC * posTo);
                double newFrom = posFrom - this.lRate * (eujc * grad + (double) regC * posFrom);
                this.cVector_MCS.set(to, this.clampPositionMCS(newTo));
                this.cVector_MCS.set(from, this.clampPositionMCS(newFrom));
            }
        }
    }

    /**
     * Applies one SGD step to each W entry in {@code weightGrads}.
     *
     * <p>Also adds that entry's regularisation term ({@code regLw2*w^2 + regLw1*w}) to
     * {@code loss}.
     */
    private void applyWeightUpdates(double eujc, HashBasedTable<Integer, Integer, Double> weightGrads) {
        for (Integer row : weightGrads.rowKeySet()) {
            int k = row;
            for (Map.Entry<Integer, Double> cell : weightGrads.row(row).entrySet()) {
                int j = cell.getKey();
                double w = this.W.get(k, j);
                this.loss += (double) this.regLw2 * w * w + (double) this.regLw1 * w;
                double deltaW = eujc * cell.getValue() - (double) this.regLw2 * w - (double) this.regLw1;
                this.W.set(k, j, w + this.lRate * deltaW);
            }
        }
    }

    /** Clamps a position: negatives to lowbound, values above upbound to upbound - lowbound. */
    private double clampPositionMCS(double pos) {
        if (pos < 0.0) {
            pos = this.lowbound;
        }
        return pos > this.upbound ? this.upbound - this.lowbound : pos;
    }

    /**
     * Returns the Euclidean distance between two contexts' condition positions.
     *
     * <p>The i-th entry of {@code conditions} is paired with the i-th entry of
     * {@code conditionsFrom}.
     */
    private double conditionDistanceMCS(List<Integer> conditions, List<Integer> conditionsFrom) {
        double sumSq = 0.0;
        for (int i = 0; i < conditions.size(); ++i) {
            double diff = this.cVector_MCS.get(conditions.get(i)) - this.cVector_MCS.get(conditionsFrom.get(i));
            sumSq += diff * diff;
        }
        return Math.sqrt(sumSq);
    }

    /** Picks one context id uniformly at random from a non-empty list. */
    private int sampleObservedContext(List<Integer> ctxIds) {
        return ctxIds.get(this.contextRandom.nextInt(ctxIds.size()));
    }

    /**
     * Scores item {@code j} for user {@code u} in context {@code c}.
     *
     * @param exclude       when true, {@code excluded_item} is not used as a neighbour
     * @param excluded_item item id to skip when {@code exclude} is true
     * @return the sum over usable neighbours of rating * W[k][j] * (1 - distance); a neighbour's
     *         context is sampled at random, so repeated calls may differ
     */
    protected double predict(int u, int j, int c, boolean exclude, int excluded_item) throws Exception {
        Collection<Integer> nns = knn > 0 ? this.itemNNs.get(j) : this.allItems;
        SparseVector Ru = (SparseVector) this.userCache.get(u);
        List<Integer> conditions = this.getConditions(c);
        double pred = 0.0;
        for (int k : nns) {
            if (!Ru.contains(k) || (exclude && k == excluded_item)) {
                continue;
            }
            int uiid = rateDao.getUserItemId(u + "," + k);
            List<Integer> ctxIds = this.trainMatrix.getColumns(uiid);
            if (ctxIds.isEmpty()) {
                continue;
            }
            int ctx = this.sampleObservedContext(ctxIds);
            double ruk = this.trainMatrix.get(uiid, ctx);
            double sim = 1.0 - this.conditionDistanceMCS(conditions, this.getConditions(ctx));
            pred += ruk * this.W.get(k, j) * sim;
        }
        return pred;
    }

    /** Scores item {@code j} for (u, c), never using {@code j} itself as a neighbour. */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        return this.predict(u, j, c, true, j);
    }

    /**
     * Records the current loss and reports convergence.
     *
     * <p>Converged means: after the first iteration, the loss decreased by less than 1e-5 (an
     * increase also counts).
     */
    @Override
    protected boolean isConverged(int iter) {
        double delta_loss = this.last_loss - this.loss;
        this.last_loss = this.loss;
        if (verbose) {
            Logs.debug("{}{} iter {}: loss = {}, delta_loss = {}", this.algoName, this.foldInfo, iter, this.loss, delta_loss);
        }
        return iter > 1 && delta_loss < 1.0E-5;
    }
}

