/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.transformation.hybridfiltering;

import carskit.alg.cars.transformation.hybridfiltering.Particle_CFPSO;
import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Lists;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import happy.coding.math.Stats;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.data.SymmMatrix;

/**
 * DCW recommender (the name is presumably short for "Differential Context
 * Weighting" -- confirm against the project documentation).
 *
 * <p>A prediction for (user, item, context) is the sum of two terms:
 * <ol>
 *   <li>a baseline: the context-similarity-weighted average of the target
 *       user's own training ratings (falling back to the user's mean), and</li>
 *   <li>a neighbourhood adjustment: the user-correlation-weighted average of
 *       each neighbour's deviation from their own context-weighted average.</li>
 * </ol>
 * Context similarity is a weighted match over context dimensions. The weight
 * vector has {@code num_dim * num_component} entries: slice 0 is used to decide
 * whether a neighbour qualifies, slice 1 to weight neighbour ratings, slice 2
 * to weight the target user's own ratings.
 *
 * <p>The weight vector is either searched with a swarm of
 * {@link Particle_CFPSO} particles (minimising squared error on the training
 * data) or loaded verbatim from the {@code -sol} option.
 */
public class DCW extends IterativeRecommender {
    /** User-user correlations, built by {@code buildCorrs(true)}. */
    private SymmMatrix userCorrs;
    /** Mean training rating per user; the global mean for users without ratings. */
    private DenseVector userMeans;
    /** Number of particles in the swarm (option {@code -p}). */
    private int p;
    /** Coefficient read from option {@code -lp}; used in the particle update. */
    private double lp;
    /** Coefficient read from option {@code -lg}; used in the particle update. */
    private double lg;
    /** Read from option {@code -wt}; only used to compute {@link #w}. */
    private double wt;
    /** Read from option {@code -wd}; only used to compute {@link #w}. */
    private double wd;
    /** Linearly interpolated between {@code wt} and {@code wd}; written but never read in this class. */
    private double w;
    /** Minimum context similarity for a rating to count as a contextual match (option {@code -th}). */
    private double th;
    /** Number of context dimensions reported by {@code rateDao}. */
    private int num_dim;
    /** Number of weight slices packed into one position vector. */
    private int num_component = 3;
    /** Constant used as the {@code lp}-weighted attractor in the particle update. */
    double p1 = 3.0;
    /** Constant used as the {@code lg}-weighted attractor in the particle update. */
    double p2 = 4.0;
    /** Best weight vector found so far (or the loaded solution). */
    private DenseVector pos_gbest;
    /** Training loss of {@link #pos_gbest}. */
    private double fitness_gbest;
    private Particle_CFPSO[] swarm;
    /** Length of a position vector: {@code num_dim * num_component}. */
    private int len;
    /** Optional pre-computed solution, semicolon separated (option {@code -sol}). */
    private String sol = "";

    /**
     * Creates the recommender and reads its options from {@code algoOptions}.
     *
     * @param trainMatrix training ratings, passed to the superclass
     * @param testMatrix  test ratings, passed to the superclass
     * @param fold        fold number, passed to the superclass
     */
    public DCW(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.algoName = "DCW";
        this.lp = algoOptions.getDouble("-lp");
        this.lg = algoOptions.getDouble("-lg");
        this.wt = algoOptions.getDouble("-wt");
        this.wd = algoOptions.getDouble("-wd");
        this.p = algoOptions.getInt("-p");
        this.th = algoOptions.getDouble("-th");
        this.sol = algoOptions.getString("-sol", "");
    }

    /**
     * Builds user correlations and user means, then allocates the swarm and
     * the global-best position vector.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.userCorrs = this.buildCorrs(true);
        this.userMeans = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector ratings = this.train.row(u);
            this.userMeans.set(u, ratings.getCount() > 0 ? ratings.mean() : this.globalMean);
        }
        this.num_dim = rateDao.numContextDims();
        this.fitness_gbest = Double.MAX_VALUE;
        this.len = this.num_dim * this.num_component;
        this.pos_gbest = new DenseVector(this.len);
        this.swarm = new Particle_CFPSO[this.p];
        for (int i = 0; i < this.p; ++i) {
            this.swarm[i] = new Particle_CFPSO(this.len);
        }
    }

    /**
     * Either loads the weight vector from the {@code -sol} option or searches
     * for one: every particle is run for {@code numIters} iterations, one
     * particle after the other, and the position with the lowest training loss
     * seen so far is kept in {@code pos_gbest}.
     */
    @Override
    protected void buildModel() throws Exception {
        if (!this.sol.isEmpty()) {
            this.loadSolution();
            return;
        }
        for (int i = 0; i < this.p; ++i) {
            Particle_CFPSO particle = this.swarm[i];
            for (int iter = 1; iter <= numIters; ++iter) {
                double loss = this.squaredError(particle.pos);
                // Snapshot the position that produced `loss`; the particle may
                // move below before the global best is updated.
                DenseVector evaluated = particle.pos.clone();
                if (loss < particle.fitness_best) {
                    particle.fitness_best = loss;
                    particle.pos_best = particle.pos.clone();
                    // NOTE(review): the particle only moves when it has just
                    // improved its personal best; once it fails to improve, the
                    // remaining iterations re-evaluate the same position.
                    // Kept as in the original -- confirm this is intended.
                    this.moveParticle(particle, iter);
                }
                if (loss < this.fitness_gbest) {
                    this.fitness_gbest = loss;
                    this.pos_gbest = evaluated;
                }
                Logs.info("Fold[" + this.fold + "]: current particle: " + (i + 1) + ", current iteration: " + iter + ", current loss: " + loss + ", lowest loss: " + this.fitness_gbest);
            }
        }
    }

    /**
     * Parses the semicolon-separated {@code -sol} option into {@code pos_gbest}.
     * Logs an error and leaves {@code pos_gbest} untouched when the number of
     * values does not equal {@code len}.
     */
    private void loadSolution() {
        String[] strs = this.sol.split(";", -1);
        if (strs.length != this.len) {
            Logs.error("Error: the length of your solution should be " + this.len);
            return;
        }
        for (int i = 0; i < this.len; ++i) {
            this.pos_gbest.set(i, Double.parseDouble(strs[i].trim()));
        }
        Logs.info("You solution has been successfully loaded.");
    }

    /** Returns the sum of squared prediction errors over all training entries for the given weights. */
    private double squaredError(DenseVector position) throws Exception {
        double loss = 0.0;
        for (MatrixEntry me : this.trainMatrix) {
            int ui = me.row();
            int u = rateDao.getUserIdFromUI(ui);
            int j = rateDao.getItemIdFromUI(ui);
            int ctx = me.column();
            double rujc = me.get();
            double prediction = this.predict(u, j, ctx, position);
            loss += Math.pow(rujc - prediction, 2.0);
        }
        return loss;
    }

    /**
     * Updates the velocity and position of one particle in place.
     *
     * <p>The arithmetic is kept exactly as in the original implementation.
     */
    private void moveParticle(Particle_CFPSO particle, int iter) {
        // NOTE(review): w is computed here but never read anywhere in this class.
        this.w = this.wd + (this.wt - this.wd) * (double)(numIters - iter) / (double)numIters;
        // NOTE(review): Math.sqrt yields NaN when lp + lg lies strictly between
        // 0 and 4, which makes x (and every updated position) NaN.
        double x = 2.0 * Math.random() / Math.abs(2.0 - this.lp - this.lg - Math.sqrt((this.lp + this.lg) * (this.lp + this.lg - 4.0)));
        for (int j = 0; j < this.len; ++j) {
            // NOTE(review): the lg term sits inside the lp factor; this looks
            // like a misplaced parenthesis but is preserved -- confirm.
            particle.volocity.set(j, particle.volocity.get(j) + this.lp * (this.p1 - particle.pos.get(j) + this.lg * (this.p2 - particle.pos.get(j))));
            particle.pos.set(j, x * particle.volocity.get(j) + x * particle.pos.get(j) + (1.0 - x) * (this.lp * this.p1 + this.lg * this.p2) / (this.lp + this.lg));
        }
    }

    /**
     * Predicts the rating of user {@code a} for item {@code t} in context
     * {@code c} using the given weight vector.
     *
     * @param a        target user id
     * @param t        target item id
     * @param c        target context id
     * @param position weight vector holding the three weight slices
     * @return baseline term plus neighbourhood adjustment; the baseline alone
     *         when no neighbour qualifies or none is retained
     */
    protected double predict(int a, int t, int c, DenseVector position) throws Exception {
        double[] pos = position.getData();
        double[] matchWeights = this.weightSlice(pos, 0);
        double[] neighbourWeights = this.weightSlice(pos, 1);
        double[] baselineWeights = this.weightSlice(pos, 2);

        double baselineSum = 0.0;
        double baselineNorm = 0.0;
        HashMap<Integer, Double> nns = new HashMap<Integer, Double>();
        HashMap<Integer, Double> neighbourRatings = new HashMap<Integer, Double>();
        for (MatrixEntry me : this.trainMatrix) {
            int ui = me.row();
            int u = rateDao.getUserIdFromUI(ui);
            int ctx = me.column();
            double rujc = me.get();
            if (u == a) {
                // The target user's own ratings feed the baseline term.
                double sim = this.ContextSimilarity(c, ctx, baselineWeights);
                if (sim >= this.th) {
                    baselineSum += sim * rujc;
                    baselineNorm += sim;
                }
                continue;
            }
            // The neighbour test depends only on (u, t, c), not on this entry,
            // so an already accepted user need not be evaluated again.
            if (nns.containsKey(u)) {
                continue;
            }
            double simu = this.userCorrs.get(a, u);
            if (!(simu > 0.0)) {
                continue;
            }
            int newui = rateDao.getUserItemId(u + "," + t);
            if (newui == -1) {
                continue;
            }
            SparseVector sv = this.trainMatrix.row(newui);
            if (sv == null) {
                continue;
            }
            int[] cs = sv.getIndex();
            if (!this.ContextMatch(c, cs, matchWeights)) {
                continue;
            }
            nns.put(u, simu);
            double rate = this.ContextWeight(c, cs, neighbourWeights, sv);
            if (rate == -1.0) {
                // No context of this neighbour passed the threshold under the
                // neighbour weights; use the value stored in the user-item matrix.
                rate = this.train.get(u, t);
            }
            neighbourRatings.put(u, rate);
        }

        double pred = 0.0;
        pred += baselineNorm == 0.0 ? this.userMeans.get(a) : baselineSum / baselineNorm;

        int k = nns.size();
        if (k == 0) {
            return pred;
        }

        // Keep only the knn most correlated neighbours.
        List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
        k = Math.min(k, knn);
        List<Map.Entry<Integer, Double>> subset = sorted.subList(0, k);
        nns.clear();
        for (Map.Entry<Integer, Double> kv : subset) {
            nns.put(kv.getKey(), kv.getValue());
        }

        // Context-weighted rating sum and weight sum per retained neighbour,
        // accumulated over all of that neighbour's training ratings.
        HashMap<Integer, Double> weightedSums = new HashMap<Integer, Double>();
        HashMap<Integer, Double> weightTotals = new HashMap<Integer, Double>();
        List<Integer> uiids = this.trainMatrix.rows();
        for (int uiid : uiids) {
            int u = rateDao.getUserIdFromUI(uiid);
            if (!nns.containsKey(u)) {
                continue;
            }
            SparseVector sv = this.trainMatrix.row(uiid);
            for (int ctx : sv.getIndex()) {
                double sim = this.ContextSimilarity(c, ctx, neighbourWeights);
                double r = sv.get(ctx);
                Double runningSum = weightedSums.get(u);
                if (runningSum != null) {
                    weightedSums.put(u, runningSum + sim * r);
                    weightTotals.put(u, weightTotals.get(u) + sim);
                } else {
                    weightedSums.put(u, sim * r);
                    weightTotals.put(u, sim);
                }
            }
        }

        double sum1 = 0.0;
        double sum2 = 0.0;
        for (Map.Entry<Integer, Double> en : nns.entrySet()) {
            int ngbr = en.getKey();
            double corr = en.getValue();
            sum2 += corr;
            Double total = weightTotals.get(ngbr);
            // Fall back to the neighbour's plain mean when nothing was
            // accumulated or the accumulated weight is zero (avoids 0/0),
            // mirroring the fallback used for the baseline term.
            double average = total != null && total != 0.0
                    ? weightedSums.get(ngbr) / total
                    : this.userMeans.get(ngbr);
            sum1 += corr * (neighbourRatings.get(ngbr) - average);
        }
        // sum2 is zero only when no neighbour was retained (knn == 0); the
        // baseline is then returned on its own instead of NaN.
        if (sum2 > 0.0) {
            pred += sum1 / sum2;
        }
        return pred;
    }

    /**
     * Copies weight slice {@code component} (0-based, {@code num_dim} entries)
     * out of a packed position array; entries beyond the end of {@code pos}
     * are left at zero.
     */
    private double[] weightSlice(double[] pos, int component) {
        double[] slice = new double[this.num_dim];
        int from = component * this.num_dim;
        int count = Math.min(this.num_dim, pos.length - from);
        if (count > 0) {
            System.arraycopy(pos, from, slice, 0, count);
        }
        return slice;
    }

    /**
     * Averages the ratings in {@code sv} whose context is at least {@code th}
     * similar to {@code c}.
     *
     * @param c   target context id
     * @param cs  context ids rated in {@code sv}
     * @param pos per-dimension weights used for the similarity
     * @param sv  ratings indexed by context id
     * @return the unweighted mean of the matching ratings, or {@code -1.0}
     *         when no context matches
     */
    protected double ContextWeight(int c, int[] cs, double[] pos, SparseVector sv) {
        double total = 0.0;
        double matches = 0.0;
        for (int ctx : cs) {
            if (this.ContextSimilarity(c, ctx, pos) >= this.th) {
                total += sv.get(ctx);
                matches += 1.0;
            }
        }
        return matches == 0.0 ? -1.0 : total / matches;
    }

    /**
     * Tells whether at least one context in {@code cs} is at least {@code th}
     * similar to {@code c} under the weights {@code pos}.
     */
    protected boolean ContextMatch(int c, int[] cs, double[] pos) {
        for (int ctx : cs) {
            if (this.ContextSimilarity(c, ctx, pos) >= this.th) {
                return true;
            }
        }
        return false;
    }

    /**
     * Weighted overlap of two contexts: the sum of the weights of the
     * dimensions on which both contexts have the same condition, divided by
     * the sum of all weights.
     *
     * @param c   first context id
     * @param ctx second context id
     * @param pos one weight per context dimension
     * @return the normalised similarity, or {@code 0.0} when the weights sum
     *         to exactly zero
     */
    protected double ContextSimilarity(int c, int ctx, double[] pos) {
        ArrayList<Integer> conds1 = rateDao.getContextConditionsList().get(c);
        ArrayList<Integer> conds2 = rateDao.getContextConditionsList().get(ctx);
        double sim = 0.0;
        for (int i = 0; i < pos.length; ++i) {
            // Compare unboxed values: == / != on Integer objects compares
            // references and is only reliable inside the small-value cache.
            int left = conds1.get(i);
            int right = conds2.get(i);
            if (left == right) {
                sim += pos[i];
            }
        }
        double total = Stats.sum(pos);
        if (total == 0.0) {
            // An all-zero weight vector carries no preference; avoid 0/0 = NaN.
            return 0.0;
        }
        return sim / total;
    }

    /** Predicts with the best weight vector found (or loaded) so far. */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        return this.predict(u, j, c, this.pos_gbest);
    }

    /** Lists the option values and the current best weight vector. */
    @Override
    public String toString() {
        // Plain concatenation keeps this safe if called before initModel().
        return Strings.toString(new Object[]{"p: " + this.p, "lp: " + this.lp, "lg: " + this.lg, "wt: " + this.wt, "wd: " + this.wd, "th: " + this.th, "sol: " + this.pos_gbest});
    }
}

