/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.cars.transformation.hybridfiltering;

import carskit.alg.cars.transformation.hybridfiltering.Particle_BPSO;
import carskit.data.structure.SparseMatrix;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Lists;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.data.SymmMatrix;

public class DCR
extends IterativeRecommender {
    // User-user similarities, filled from buildCorrs(true) in initModel().
    private SymmMatrix userCorrs;
    // Per-user mean of the training ratings; globalMean is stored for users without ratings.
    private DenseVector userMeans;
    // Number of particles in the swarm (option "-p").
    private int p;
    // Scale of the random velocity term driven by the particle's own position bits (option "-lp").
    private double lp;
    // Scale of the random velocity term driven by the global best bits (option "-lg").
    private double lg;
    // Inertia weight endpoint (option "-wt"): w is close to wt in the first iteration.
    private double wt;
    // Inertia weight endpoint (option "-wd"): w equals wd in the last iteration.
    private double wd;
    // Inertia weight of the current iteration, interpolated linearly between wt and wd.
    private double w;
    // Number of context dimensions, read from rateDao.numContextDims().
    private int num_dim;
    // Number of bit groups per position; predict() slices a position into three groups of num_dim bits.
    private int num_component = 3;
    // Best position evaluated so far (or the parsed "-sol" option); used by predict(u, j, c).
    private DenseVector pos_gbest;
    // Training loss of pos_gbest; reset to Double.MAX_VALUE in initModel().
    private double fitness_gbest;
    // The particles, created in initModel().
    private Particle_BPSO[] swarm;
    // Length of a position vector: num_dim * num_component.
    private int len;
    // -1 until the first loss has been evaluated in buildModel(), 0 afterwards.
    private int start = -1;
    // Preset solution (option "-sol"), bits separated by ';'. Empty means buildModel() runs the swarm search.
    private String sol = "";

    /**
     * Creates the recommender for one fold and reads its settings from the algorithm options.
     *
     * @param trainMatrix training ratings, passed through to the superclass
     * @param testMatrix  test ratings, passed through to the superclass
     * @param fold        fold number, passed through to the superclass
     */
    public DCR(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        algoName = "DCR";

        // Velocity coefficients.
        lp = algoOptions.getDouble("-lp");
        lg = algoOptions.getDouble("-lg");

        // Inertia weight endpoints.
        wt = algoOptions.getDouble("-wt");
        wd = algoOptions.getDouble("-wd");

        // Swarm size.
        p = algoOptions.getInt("-p");

        // Optional preset solution; the empty default triggers the search in buildModel().
        sol = algoOptions.getString("-sol", "");
    }

    /**
     * Prepares everything the search and the predictor need: user similarities,
     * per-user rating means, the position length and a fresh swarm.
     *
     * @throws Exception propagated from the superclass initialisation or the similarity computation
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();

        userCorrs = buildCorrs(true);

        // A user without training ratings gets the global mean instead of an undefined average.
        userMeans = new DenseVector(numUsers);
        for (int user = 0; user < numUsers; ++user) {
            SparseVector ratings = train.row(user);
            double mean;
            if (ratings.getCount() > 0) {
                mean = ratings.mean();
            } else {
                mean = globalMean;
            }
            userMeans.set(user, mean);
        }

        num_dim = rateDao.numContextDims();
        fitness_gbest = Double.MAX_VALUE;

        // One bit per context dimension for each of the num_component groups.
        len = num_dim * num_component;
        pos_gbest = new DenseVector(len);

        swarm = new Particle_BPSO[p];
        int slot = 0;
        while (slot < p) {
            swarm[slot] = new Particle_BPSO(len);
            ++slot;
        }
    }

    /**
     * Determines the bit vector used by {@code predict(u, j, c)}.
     *
     * <p>When the "-sol" option is empty, every particle is run for {@code numIters}
     * iterations: its current position is scored by the summed squared training error,
     * personal and global bests are updated, and the particle is moved. Otherwise the
     * preset solution is parsed into {@code pos_gbest}.
     *
     * @throws Exception propagated from {@code predict}
     */
    @Override
    protected void buildModel() throws Exception {
        if (!this.sol.equals("")) {
            this.loadPresetSolution();
            return;
        }
        for (int i = 0; i < this.p; ++i) {
            Particle_BPSO bp = this.swarm[i];
            for (int iter = 1; iter <= numIters; ++iter) {
                double loss = this.sumSquaredTrainError(bp.pos);

                // Seed the global best with the very first evaluated position so the
                // first move already has a meaningful global best to refer to.
                if (this.start == -1) {
                    this.start = 0;
                    if (loss < this.fitness_gbest) {
                        this.fitness_gbest = loss;
                        this.pos_gbest = bp.pos.clone();
                    }
                }

                // Snapshot of the bits that produced `loss`; the move below overwrites bp.pos.
                DenseVector evaluated = bp.pos.clone();

                if (loss < bp.fitness_best) {
                    bp.fitness_best = loss;
                    bp.pos_best = bp.pos.clone();
                }

                // The move is made on every iteration. Making it only after an improved
                // personal best would park the particle at the first position that fails
                // to improve, and all remaining iterations would re-score the same bits.
                this.advanceParticle(bp, iter);

                // As before, the global best is replaced after the move, so the move
                // itself still refers to the previous global best.
                if (loss < this.fitness_gbest) {
                    this.fitness_gbest = loss;
                    this.pos_gbest = evaluated;
                }
                Logs.info("Fold[" + this.fold + "]: current particle: " + (i + 1) + ", current iteration: " + iter + ", current loss: " + loss + ", lowest loss: " + this.fitness_gbest);
            }
        }
    }

    /**
     * Sums the squared differences between every training rating and its prediction under
     * the given position.
     */
    private double sumSquaredTrainError(DenseVector position) throws Exception {
        double total = 0.0;
        for (MatrixEntry me : this.trainMatrix) {
            int ui = me.row();
            int u = rateDao.getUserIdFromUI(ui);
            int j = rateDao.getItemIdFromUI(ui);
            int ctx = me.column();
            double rujc = me.get();
            double predicted = this.predict(u, j, ctx, position);
            total += Math.pow(rujc - predicted, 2.0);
        }
        return total;
    }

    /**
     * Updates both velocity vectors of the particle and flips each position bit with a
     * probability given by the sigmoid of the velocity that applies to the bit's current value.
     */
    private void advanceParticle(Particle_BPSO bp, int iter) {
        // Inertia weight: close to wt in the first iteration, exactly wd in the last one.
        this.w = this.wd + (this.wt - this.wd) * (double)(numIters - iter) / (double)numIters;

        // NOTE(review): pos_best is assumed to be set by the first evaluation; if the
        // particle's initial fitness_best was never beaten it may still be unset, in
        // which case the current position stands in for it.
        DenseVector personalBest = bp.pos_best != null ? bp.pos_best : bp.pos;

        for (int j = 0; j < this.len; ++j) {
            double d11;
            double d01;
            double d12;
            double d02;

            // Personal-best term: pushes towards 1 when the best bit is 1, towards 0 otherwise.
            double r1 = Math.random();
            if (personalBest.get(j) == 1.0) {
                d11 = this.lp * r1;
                d01 = 0.0 - d11;
            } else {
                d01 = this.lp * r1;
                d11 = 0.0 - d01;
            }

            // Global-best term, same scheme.
            double r2 = Math.random();
            if (this.pos_gbest.get(j) == 1.0) {
                d12 = this.lg * r2;
                d02 = 0.0 - d12;
            } else {
                d02 = this.lg * r2;
                d12 = 0.0 - d02;
            }

            bp.volocity_1.set(j, this.w * bp.volocity_1.get(j) + d11 + d12);
            bp.volocity_0.set(j, this.w * bp.volocity_0.get(j) + d01 + d02);

            // A 0-bit is governed by the "towards 1" velocity and vice versa.
            boolean bitIsOne = bp.pos.get(j) == 1.0;
            double v = bp.pos.get(j) == 0.0 ? bp.volocity_1.get(j) : bp.volocity_0.get(j);
            double flipProbability = 1.0 / (1.0 + Math.exp(0.0 - v));
            if (Math.random() < flipProbability) {
                bp.pos.set(j, bitIsOne ? 0.0 : 1.0);
            }
        }
    }

    /**
     * Parses the ';'-separated "-sol" option into {@code pos_gbest}. A solution of the wrong
     * length is reported and ignored, leaving {@code pos_gbest} untouched.
     */
    private void loadPresetSolution() {
        String[] strs = this.sol.split(";", -1);
        if (strs.length != this.len) {
            Logs.error("Error: the length of your solution should be " + this.len);
            return;
        }
        for (int i = 0; i < this.len; ++i) {
            int bit = Integer.parseInt(strs[i].trim());
            this.pos_gbest.set(i, bit);
        }
        Logs.info("You solution has been successfully loaded.");
    }

    /**
     * Predicts the rating of user {@code a} on item {@code t} in context {@code c} under the
     * given bit vector.
     *
     * <p>The position is read as three groups of {@code num_dim} bits. A set bit means the
     * corresponding context dimension must match (see {@code ContextRelaxation}):
     * <ol>
     *   <li>group 1 decides whether a user with positive similarity qualifies as a neighbour,
     *       i.e. has rated {@code t} in a matching context;</li>
     *   <li>group 2 selects the neighbour's rating on {@code t} (falling back to
     *       {@code train.get(u, t)}) and the ratings averaged into the neighbour's baseline;</li>
     *   <li>group 3 selects the ratings of {@code a} that are averaged into the target baseline.</li>
     * </ol>
     * The result is the target baseline plus the similarity-weighted mean deviation of the
     * top {@code knn} neighbours from their own baselines. Without any retained neighbour
     * only the target baseline is returned.
     *
     * @param a        target user id
     * @param t        target item id
     * @param c        target context id
     * @param position bit vector to evaluate
     * @return the predicted rating
     * @throws Exception declared for compatibility with the overridden prediction API
     */
    protected double predict(int a, int t, int c, DenseVector position) throws Exception {
        double[] bits = position.getData();
        double[] selectionBits = this.componentBits(bits, 0);
        double[] neighborBits = this.componentBits(bits, 1);
        double[] targetBits = this.componentBits(bits, 2);

        double targetSum = 0.0;
        double targetCount = 0.0;
        HashMap<Integer, Double> nns = new HashMap<Integer, Double>();
        Map<Integer, Double> neighborRatings = new HashMap<Integer, Double>();
        // Whether a user qualifies depends only on (u, t, c, position), so one check per
        // user is enough even though the user shows up once per rating entry.
        Set<Integer> examined = new HashSet<Integer>();

        for (MatrixEntry me : this.trainMatrix) {
            int u = rateDao.getUserIdFromUI(me.row());
            if (u == a) {
                if (this.ContextRelaxation(c, me.column(), targetBits)) {
                    targetSum += me.get();
                    targetCount += 1.0;
                }
                continue;
            }
            if (!examined.add(u)) {
                continue;
            }
            double sim = this.userCorrs.get(a, u);
            if (!(sim > 0.0)) {
                continue;
            }
            int newui = rateDao.getUserItemId(u + "," + t);
            if (newui == -1) {
                continue;
            }
            SparseVector sv = this.trainMatrix.row(newui);
            if (sv == null) {
                continue;
            }
            int[] cs = sv.getIndex();
            if (this.ContextRelaxation(c, cs, selectionBits, sv) == -1.0) {
                continue;
            }
            nns.put(u, sim);
            double rate = this.ContextRelaxation(c, cs, neighborBits, sv);
            if (rate == -1.0) {
                rate = this.train.get(u, t);
            }
            neighborRatings.put(u, rate);
        }

        double pred = targetCount == 0.0 ? this.userMeans.get(a) : targetSum / targetCount;
        if (nns.isEmpty()) {
            return pred;
        }

        // Keep only the knn most similar neighbours.
        List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
        List<Map.Entry<Integer, Double>> subset = sorted.subList(0, Math.min(knn, nns.size()));
        nns.clear();
        for (Map.Entry<Integer, Double> kv : subset) {
            nns.put(kv.getKey(), kv.getValue());
        }

        // Per-neighbour sum and count of the ratings given in contexts matching group 2.
        Map<Integer, Double> baselineSum = new HashMap<Integer, Double>();
        Map<Integer, Double> baselineCount = new HashMap<Integer, Double>();
        List<Integer> uiids = this.trainMatrix.rows();
        for (int uiid : uiids) {
            int u = rateDao.getUserIdFromUI(uiid);
            if (!nns.containsKey(u)) {
                continue;
            }
            SparseVector sv = this.trainMatrix.row(uiid);
            for (int ctx : sv.getIndex()) {
                if (!this.ContextRelaxation(c, ctx, neighborBits)) {
                    continue;
                }
                double r = sv.get(ctx);
                Double soFar = baselineSum.get(u);
                if (soFar == null) {
                    baselineSum.put(u, r);
                    baselineCount.put(u, 1.0);
                } else {
                    baselineSum.put(u, soFar + r);
                    baselineCount.put(u, baselineCount.get(u) + 1.0);
                }
            }
        }

        double weightedDeviation = 0.0;
        double simTotal = 0.0;
        for (Map.Entry<Integer, Double> en : nns.entrySet()) {
            int ngbr = en.getKey();
            double sim = en.getValue();
            simTotal += sim;
            double baseline = baselineSum.containsKey(ngbr)
                    ? baselineSum.get(ngbr) / baselineCount.get(ngbr)
                    : this.userMeans.get(ngbr);
            weightedDeviation += sim * (neighborRatings.get(ngbr) - baseline);
        }
        // With knn == 0 no neighbour is retained; dividing 0 by 0 would turn the
        // prediction into NaN, so the neighbour term is skipped in that case.
        if (simTotal > 0.0) {
            pred += weightedDeviation / simTotal;
        }
        return pred;
    }

    /**
     * Copies the {@code component}-th group of {@code num_dim} bits out of a position array.
     * Entries missing from a too-short array stay 0.
     */
    private double[] componentBits(double[] bits, int component) {
        double[] slice = new double[this.num_dim];
        int offset = component * this.num_dim;
        for (int d = 0; d < this.num_dim && offset + d < bits.length; ++d) {
            slice[d] = bits[offset + d];
        }
        return slice;
    }

    /**
     * Returns the value stored in {@code sv} for the first context id in {@code cs} that
     * matches {@code c} on the dimensions selected by {@code pos}.
     *
     * @param c   reference context id
     * @param cs  candidate context ids, tested in array order
     * @param pos selection bits, one per context dimension
     * @param sv  vector the value of the matching context is read from
     * @return the value at the first matching context id, or -1 if none matches
     */
    protected double ContextRelaxation(int c, int[] cs, double[] pos, SparseVector sv) {
        int at = 0;
        while (at < cs.length && !this.ContextRelaxation(c, cs[at], pos)) {
            ++at;
        }
        if (at == cs.length) {
            return -1.0;
        }
        int matched = cs[at];
        // -1 doubles as the "nothing matched" marker, so a matched id of -1 is reported the same way.
        return matched == -1 ? -1.0 : sv.get(matched);
    }

    /**
     * Tests whether two contexts agree on every dimension selected by {@code pos}.
     *
     * @param c   first context id
     * @param ctx second context id
     * @param pos selection bits; dimension {@code i} is compared only when {@code pos[i] == 1.0}
     * @return {@code true} if the condition ids of both contexts are equal on all selected
     *         dimensions (trivially {@code true} when no dimension is selected)
     */
    protected boolean ContextRelaxation(int c, int ctx, double[] pos) {
        ArrayList<Integer> conds1 = rateDao.getContextConditionsList().get(c);
        ArrayList<Integer> conds2 = rateDao.getContextConditionsList().get(ctx);
        for (int i = 0; i < pos.length; ++i) {
            if (pos[i] != 1.0) {
                continue;
            }
            Integer left = conds1.get(i);
            Integer right = conds2.get(i);
            // Compare by value: == on boxed Integers tests object identity, which only
            // coincides with numeric equality for small cached values.
            boolean same = left == null ? right == null : left.equals(right);
            if (!same) {
                return false;
            }
        }
        return true;
    }

    /**
     * Predicts the rating of user {@code u} on item {@code j} in context {@code c} using the
     * global best position.
     *
     * @throws Exception propagated from the position-based predictor
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        DenseVector best = this.pos_gbest;
        return this.predict(u, j, c, best);
    }

    /**
     * Describes the configuration: swarm size, velocity coefficients, inertia weight
     * endpoints and the current global best position.
     */
    @Override
    public String toString() {
        Object[] settings = new Object[]{
            "p: " + this.p,
            "lp: " + this.lp,
            "lg: " + this.lg,
            "wt: " + this.wt,
            "wd: " + this.wd,
            "sol: " + this.pos_gbest.toString()
        };
        return Strings.toString(settings);
    }
}

