/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.intf.IterativeRecommender;
import librec.util.Randoms;
import librec.util.Strings;

/**
 * Time-aware latent factor recommender registered under the algorithm name {@code "timeSVD++"}.
 *
 * <p>For a user {@code u}, item {@code i} and rating day {@code t} the predicted rating is
 * <pre>
 *   globalMean
 *     + (itemBias[i] + Bit[i, bin(t)]) * (Cu[u] + Cut[u, t])
 *     + userBias[u] + Alpha[u] * dev(u, t) + But[u, t]
 *     + |R(u)|^-0.5 * sum_{j in R(u)} (Y[j] . Q[i])
 *     + sum_k (P[u, k] + Auk[u, k] * dev(u, t) + Pukt[u][k, t]) * Q[i, k]
 * </pre>
 * where {@code R(u)} is the list of items cached for user {@code u} from the training matrix.
 * All parameters are learned by stochastic gradient descent in {@link #buildModel()}.
 *
 * <p>Rating timestamps are converted to day indices with {@link TimeUnit#MILLISECONDS}, i.e. the
 * time matrices are treated as holding millisecond values.
 */
public class TimeSVD extends IterativeRecommender {
    /** Number of distinct day indices spanned by the data (day difference max-min, plus one). */
    private int numDays;
    /** Mean rating day of each user, measured in days since {@code minTimestamp}. */
    private DenseVector userMeanDate;
    /** Exponent applied to the absolute day deviation in {@link #dev(int, int)}; option {@code -beta}. */
    private final float beta;
    /** Number of time bins used for the item bias; option {@code -bins}. */
    private final int numBins;
    /** Implicit-feedback item factors, indexed (item, factor). */
    private DenseMatrix Y;
    /** Time-binned item bias, indexed (item, bin). */
    private DenseMatrix Bit;
    /** Day-specific user bias, indexed (user, day); entries are created lazily during training. */
    private Table<Integer, Integer, Double> But;
    /** Per-user coefficient multiplied by the day deviation in the user bias. */
    private DenseVector Alpha;
    /** Per-user, per-factor coefficient multiplied by the day deviation, indexed (user, factor). */
    private DenseMatrix Auk;
    /** Day-specific user factors: user -> table indexed (factor, day); created lazily during training. */
    private Map<Integer, Table<Integer, Integer, Double>> Pukt;
    /** Per-user scaling term applied to the item bias. */
    private DenseVector Cu;
    /** Day-specific user scaling term applied to the item bias, indexed (user, day). */
    private DenseMatrix Cut;

    /**
     * Creates the recommender and reads the {@code -beta} and {@code -bins} algorithm options.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, passed through to the superclass
     */
    public TimeSVD(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        this.setAlgoName("timeSVD++");
        this.beta = algoOptions.getFloat("-beta");
        this.numBins = algoOptions.getInt("-bins");
    }

    /**
     * Allocates all model parameters and computes the mean rating day of every user.
     *
     * <p>Users without cached items fall back to the mean day over all positively rated training
     * entries; if there is no such entry the fallback is {@code 0.0}.
     *
     * @throws Exception if the superclass initialisation or the user-items cache lookup fails
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.numDays = days(maxTimestamp, minTimestamp) + 1;

        // The allocation/init order below is kept as-is so that any randomness consumed by
        // init() is drawn in the same sequence as before.
        this.userBias = new DenseVector(numUsers);
        this.userBias.init();
        this.itemBias = new DenseVector(numItems);
        this.itemBias.init();
        this.Alpha = new DenseVector(numUsers);
        this.Alpha.init();
        this.Bit = new DenseMatrix(numItems, this.numBins);
        this.Bit.init();
        this.Y = new DenseMatrix(numItems, numFactors);
        this.Y.init();
        this.Auk = new DenseMatrix(numUsers, numFactors);
        this.Auk.init();
        this.But = HashBasedTable.create();
        this.Pukt = new HashMap<Integer, Table<Integer, Integer, Double>>();
        this.Cu = new DenseVector(numUsers);
        this.Cu.init();
        this.Cut = new DenseMatrix(numUsers, this.numDays);
        this.Cut.init();

        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);

        // Mean rating day over all training entries with a positive rating.
        double daySum = 0.0;
        int count = 0;
        for (MatrixEntry me : this.trainMatrix) {
            if (me.get() <= 0.0) {
                continue;
            }
            daySum += days((long) timeMatrix.get(me.row(), me.column()), minTimestamp);
            count++;
        }
        // Guard against 0/0: a NaN here would poison dev() for every user without ratings.
        double globalMeanDate = count > 0 ? daySum / count : 0.0;

        this.userMeanDate = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; u++) {
            List<Integer> rated = itemsRatedBy(u);
            double userDaySum = 0.0;
            for (int i : rated) {
                userDaySum += days((long) timeMatrix.get(u, i), minTimestamp);
            }
            double mean = rated.isEmpty() ? globalMeanDate : userDaySum / rated.size();
            this.userMeanDate.set(u, mean);
        }
    }

    /**
     * Runs up to {@code numIters} SGD passes over the training entries, updating every parameter
     * that contributes to the prediction and accumulating the regularised squared error in
     * {@code loss}. Stops early when {@code isConverged(iter)} reports convergence.
     *
     * @throws Exception if a user-items cache lookup or the convergence check fails
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            this.loss = 0.0;

            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int i = me.column();
                double rui = me.get();

                int t = days((long) timeMatrix.get(u, i), minTimestamp);
                int bin = this.bin(t);
                double dev_ut = this.dev(u, t);

                double bi = this.itemBias.get(i);
                double bit = this.Bit.get(i, bin);
                double bu = this.userBias.get(u);
                double cu = this.Cu.get(u);
                double cut = this.Cut.get(u, t);
                // Day-specific user bias is created on first use with a random value.
                double but = getOrInitRandom(this.But, u, t);
                double au = this.Alpha.get(u);

                // --- prediction with the current parameters ---
                double pui = this.globalMean + (bi + bit) * (cu + cut);
                pui += bu + au * dev_ut + but;

                List<Integer> rated = itemsRatedBy(u);
                double sum_y = 0.0;
                for (int j : rated) {
                    sum_y += DenseMatrix.rowMult(this.Y, j, this.Q, i);
                }
                double wi = rated.size() > 0 ? Math.pow(rated.size(), -0.5) : 0.0;
                pui += sum_y * wi;

                Table<Integer, Integer, Double> Pkt = dayFactorsOf(u);
                for (int k = 0; k < numFactors; k++) {
                    double qik = this.Q.get(i, k);
                    // Day-specific user factor is created on first use with a random value.
                    double pkt = getOrInitRandom(Pkt, k, t);
                    double puk = this.P.get(u, k) + this.Auk.get(u, k) * dev_ut + pkt;
                    pui += puk * qik;
                }

                double eui = pui - rui;
                this.loss += eui * eui;

                // --- bias updates (all gradients use the values read before any update) ---
                double sgd = eui * (cu + cut) + regB * bi;
                this.itemBias.add(i, -this.lRate * sgd);
                this.loss += regB * bi * bi;

                sgd = eui * (cu + cut) + regB * bit;
                this.Bit.add(i, bin, -this.lRate * sgd);
                this.loss += regB * bit * bit;

                sgd = eui * (bi + bit) + regB * cu;
                this.Cu.add(u, -this.lRate * sgd);
                this.loss += regB * cu * cu;

                sgd = eui * (bi + bit) + regB * cut;
                this.Cut.add(u, t, -this.lRate * sgd);
                this.loss += regB * cut * cut;

                sgd = eui + regB * bu;
                this.userBias.add(u, -this.lRate * sgd);
                this.loss += regB * bu * bu;

                sgd = eui * dev_ut + regB * au;
                this.Alpha.add(u, -this.lRate * sgd);
                this.loss += regB * au * au;

                sgd = eui + regB * but;
                this.But.put(u, t, but - this.lRate * sgd);
                this.loss += regB * but * but;

                // --- factor updates ---
                for (int k = 0; k < numFactors; k++) {
                    double qik = this.Q.get(i, k);
                    double puk = this.P.get(u, k);
                    double auk = this.Auk.get(u, k);
                    double pkt = Pkt.get(k, t);
                    double pukt = puk + auk * dev_ut + pkt;

                    double sum_yk = 0.0;
                    for (int j : rated) {
                        sum_yk += this.Y.get(j, k);
                    }

                    sgd = eui * (pukt + wi * sum_yk) + regI * qik;
                    this.Q.add(i, k, -this.lRate * sgd);
                    this.loss += regI * qik * qik;

                    sgd = eui * qik + regU * puk;
                    this.P.add(u, k, -this.lRate * sgd);
                    this.loss += regU * puk * puk;

                    sgd = eui * qik * dev_ut + regU * auk;
                    this.Auk.add(u, k, -this.lRate * sgd);
                    this.loss += regU * auk * auk;

                    sgd = eui * qik + regU * pkt;
                    Pkt.put(k, t, pkt - this.lRate * sgd);
                    this.loss += regU * pkt * pkt;

                    for (int j : rated) {
                        double yjk = this.Y.get(j, k);
                        sgd = eui * wi * qik + regI * yjk;
                        this.Y.add(j, k, -this.lRate * sgd);
                        this.loss += regI * yjk * yjk;
                    }
                }
            }

            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts the rating of user {@code u} for item {@code i} at the day taken from
     * {@code testTimeMatrix}. Day-specific terms that were never trained for that (user, day)
     * contribute {@code 0}.
     *
     * @param u user index
     * @param i item index
     * @return the predicted rating
     * @throws Exception if the user-items cache lookup fails
     */
    @Override
    protected double predict(int u, int i) throws Exception {
        int t = days((long) testTimeMatrix.get(u, i), minTimestamp);
        int bin = this.bin(t);
        double dev_ut = this.dev(u, t);

        double pred = this.globalMean;
        pred += (this.itemBias.get(i) + this.Bit.get(i, bin)) * (this.Cu.get(u) + this.Cut.get(u, t));

        Double storedBut = this.But.get(u, t);
        double but = storedBut != null ? storedBut : 0.0;
        pred += this.userBias.get(u) + this.Alpha.get(u) * dev_ut + but;

        List<Integer> rated = itemsRatedBy(u);
        double sum_y = 0.0;
        for (int j : rated) {
            sum_y += DenseMatrix.rowMult(this.Y, j, this.Q, i);
        }
        double wi = rated.size() > 0 ? Math.pow(rated.size(), -0.5) : 0.0;
        pred += sum_y * wi;

        // May be null when the user never appeared in training.
        Table<Integer, Integer, Double> dayFactors = this.Pukt.get(u);
        for (int k = 0; k < numFactors; k++) {
            double qik = this.Q.get(i, k);
            double puk = this.P.get(u, k) + this.Auk.get(u, k) * dev_ut;
            if (dayFactors != null) {
                Double pkt = dayFactors.get(k, t);
                puk += pkt != null ? pkt : 0.0;
            }
            pred += puk * qik;
        }
        return pred;
    }

    /**
     * @return the superclass description followed by {@code ","} and the {@code beta} and
     *         {@code numBins} settings formatted by {@code Strings.toString}
     */
    @Override
    public String toString() {
        return super.toString() + "," + Strings.toString(new Object[]{this.beta, this.numBins});
    }

    /**
     * Signed, power-scaled deviation of day {@code t} from the mean rating day of user {@code u}:
     * {@code sign(t - tu) * |t - tu|^beta}.
     *
     * @param u user index
     * @param t day index
     * @return the deviation term
     */
    protected double dev(int u, int t) {
        double diff = t - this.userMeanDate.get(u);
        return Math.signum(diff) * Math.pow(Math.abs(diff), this.beta);
    }

    /**
     * Maps a day index to its time bin by scaling {@code day / numDays} to {@code numBins} and
     * truncating.
     *
     * @param day day index
     * @return the bin index
     */
    protected int bin(int day) {
        return (int) ((double) day / this.numDays * this.numBins);
    }

    /**
     * Converts a duration in milliseconds to whole days.
     *
     * @param diff duration in milliseconds
     * @return the number of whole days, narrowed to {@code int}
     */
    protected static int days(long diff) {
        return (int) TimeUnit.MILLISECONDS.toDays(diff);
    }

    /**
     * Whole days between two millisecond timestamps, regardless of their order.
     *
     * @param t1 first timestamp in milliseconds
     * @param t2 second timestamp in milliseconds
     * @return the absolute difference in whole days
     */
    protected static int days(long t1, long t2) {
        return days(Math.abs(t1 - t2));
    }

    /** Returns the cached list of item indices for user {@code u}. */
    @SuppressWarnings("unchecked") // the cache is populated with lists of item indices (iterated as Integer)
    private List<Integer> itemsRatedBy(int u) throws Exception {
        return (List<Integer>) this.userItemsCache.get(u);
    }

    /** Returns the (factor, day) table of user {@code u}, creating and registering it if absent. */
    private Table<Integer, Integer, Double> dayFactorsOf(int u) {
        Table<Integer, Integer, Double> table = this.Pukt.get(u);
        if (table == null) {
            table = HashBasedTable.create();
            this.Pukt.put(u, table);
        }
        return table;
    }

    /** Returns the cell at (row, col), first storing a {@code Randoms.random()} value if it is absent. */
    private static double getOrInitRandom(Table<Integer, Integer, Double> table, int row, int col) {
        Double value = table.get(row, col);
        if (value == null) {
            value = Randoms.random();
            table.put(row, col, value);
        }
        return value;
    }
}

