/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import java.util.HashMap;
import java.util.List;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.IterativeRecommender;
import librec.util.Logs;
import librec.util.Strings;

@Configuration(value="binThold, factors, isSupportWeight, numIters")
/**
 * Ranking-oriented matrix factorization trained by alternating least squares.
 *
 * <p>Every training round first recomputes all user factor vectors (rows of {@code P}) in closed
 * form with the item factors fixed, then recomputes all item factor vectors (rows of {@code Q})
 * with the new user factors fixed. Item {@code j} carries a weight {@code s_j}: its column size
 * in the training matrix when the {@code -sw} option is on, otherwise 1.
 *
 * <p>NOTE(review): the class name suggests the RankALS method of Takacs &amp; Tikk (RecSys 2012);
 * the update formulas should be checked against that paper.
 */
public class RankALS
extends IterativeRecommender {
    /** Whether item weights {@code s_j} are the item's training column size ({@code -sw}). */
    private boolean isSupportWeight;
    /** Per-item weights {@code s_j}, filled by {@link #initModel()}. */
    private DenseVector s;
    /** {@code sum_j s_j} over all items. */
    private double sum_s;

    /**
     * Creates the recommender for the given fold and marks it as a ranking predictor.
     *
     * @param trainMatrix training entries
     * @param testMatrix  test entries
     * @param fold        fold index, passed through to the base class
     */
    public RankALS(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        isRankingPred = true;
        // NOTE(review): presumably validates/binarizes the input data; see the base class.
        this.checkBinary();
    }

    /**
     * Initializes the base model, reads the {@code -sw} option and precomputes the item weights
     * {@code s} together with their sum {@code sum_s}.
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.isSupportWeight = algoOptions.isOn("-sw");
        this.s = new DenseVector(numItems);
        this.sum_s = 0.0;
        for (int i = 0; i < numItems; i++) {
            double si = this.isSupportWeight ? this.trainMatrix.columnSize(i) : 1;
            this.s.set(i, si);
            this.sum_s += si;
        }
    }

    /**
     * Runs {@code numIters} rounds of alternating least squares.
     *
     * <p>Each round computes the weighted item aggregates from the current {@code Q}. It then
     * updates every user factor vector, and finally every item factor vector.
     */
    @Override
    protected void buildModel() throws Exception {
        // Inclusive bound: exactly numIters rounds, matching the "iter/numIters" progress log.
        for (int iter = 1; iter <= numIters; iter++) {
            if (verbose) {
                Logs.debug("{}{} runs at iter = {}/{}", this.algoName, this.foldInfo, iter, numIters);
            }

            // sum_sq = sum_j s_j q_j and sum_sqq = sum_j s_j q_j q_j^T over all items (current Q).
            DenseVector sum_sq = new DenseVector(numFactors);
            DenseMatrix sum_sqq = new DenseMatrix(numFactors, numFactors);
            for (int j = 0; j < numItems; j++) {
                DenseVector qj = this.Q.row(j);
                double sj = this.s.get(j);
                sum_sq = sum_sq.add(qj.scale(sj));
                sum_sqq = sum_sqq.add(qj.outer(qj).scale(sj));
            }

            // NOTE(review): presumably the users with at least one training entry.
            List<Integer> users = this.trainMatrix.rows();
            updateUserFactors(users, sum_sq, sum_sqq);
            updateItemFactors(users, sum_sq);
        }
    }

    /**
     * P step: recomputes the factor vector of every user in {@code users} in closed form,
     * keeping the item factors fixed.
     *
     * @param users   users whose rows of {@code P} are overwritten
     * @param sum_sq  {@code sum_j s_j q_j} over all items
     * @param sum_sqq {@code sum_j s_j q_j q_j^T} over all items
     */
    private void updateUserFactors(List<Integer> users, DenseVector sum_sq, DenseMatrix sum_sqq) {
        for (int u : users) {
            DenseMatrix sum_cqq = new DenseMatrix(numFactors, numFactors);
            DenseVector sum_cq = new DenseVector(numFactors);
            DenseVector sum_cqr = new DenseVector(numFactors);
            DenseVector sum_sqr = new DenseVector(numFactors);
            SparseVector Ru = this.trainMatrix.row(u);
            // Every stored entry of row u counts with weight 1, so sum_c is the entry count.
            double sum_c = Ru.getCount();
            double sum_sr = 0.0;
            double sum_cr = 0.0;
            for (VectorEntry ve : Ru) {
                int i = ve.index();
                double rui = ve.get();
                DenseVector qi = this.Q.row(i);
                double si = this.s.get(i);
                sum_cqq = sum_cqq.add(qi.outer(qi));
                sum_cq = sum_cq.add(qi);
                sum_cqr = sum_cqr.add(qi.scale(rui));
                sum_sr += si * rui;
                sum_cr += rui;
                sum_sqr = sum_sqr.add(qi.scale(si * rui));
            }
            DenseMatrix M = sum_cqq.scale(this.sum_s).minus(sum_cq.outer(sum_sq)).minus(sum_sq.outer(sum_cq)).add(sum_sqq.scale(sum_c));
            DenseVector y = sum_cqr.scale(this.sum_s).minus(sum_cq.scale(sum_sr)).minus(sum_sq.scale(sum_cr)).add(sum_sqr.scale(sum_c));
            // NOTE(review): no regularization term is added, so M is assumed to be invertible.
            this.P.setRow(u, M.inv().mult(y));
        }
    }

    /**
     * Q step: recomputes the factor vector of every item in closed form, keeping the user
     * factors fixed.
     *
     * <p>The user-side sums that do not depend on the item being solved are accumulated once,
     * before the item loop. Only the terms from users with a positive entry on item {@code i}
     * are summed per item, by walking that item's training column.
     *
     * @param users  users whose rows of {@code P} enter the update
     * @param sum_sq {@code sum_j s_j q_j} over all items, from {@code Q} at the start of the round
     */
    private void updateItemFactors(List<Integer> users, DenseVector sum_sq) {
        // Per-user scalars needed inside the item loop.
        HashMap<Integer, Double> m_sum_sr = new HashMap<>();
        HashMap<Integer, Double> m_sum_c = new HashMap<>();

        // Item-independent aggregates. They use the freshly updated P and the Q from before any
        // item in this round is rewritten.
        DenseMatrix sum_cpp = new DenseMatrix(numFactors, numFactors);
        DenseMatrix sum_p_p_c = new DenseMatrix(numFactors, numFactors);
        DenseVector sum_p_p_cq = new DenseVector(numFactors);
        DenseVector sum_cr_p = new DenseVector(numFactors);

        for (int u : users) {
            SparseVector Ru = this.trainMatrix.row(u);
            double sum_sr = 0.0;
            double sum_cr = 0.0;
            double sum_c = Ru.getCount();
            DenseVector sum_cq = new DenseVector(numFactors);
            for (VectorEntry ve : Ru) {
                int j = ve.index();
                double ruj = ve.get();
                sum_sr += this.s.get(j) * ruj;
                sum_cr += ruj;
                sum_cq = sum_cq.add(this.Q.row(j));
            }
            m_sum_sr.put(u, sum_sr);
            m_sum_c.put(u, sum_c);

            DenseVector pu = this.P.row(u);
            DenseMatrix pp = pu.outer(pu);
            sum_cpp = sum_cpp.add(pp);
            sum_p_p_cq = sum_p_p_cq.add(pp.mult(sum_cq));
            sum_p_p_c = sum_p_p_c.add(pp.scale(sum_c));
            sum_cr_p = sum_cr_p.add(pu.scale(sum_cr));
        }

        for (int i = 0; i < numItems; i++) {
            DenseVector sum_cpr = new DenseVector(numFactors);
            DenseVector sum_c_sr_p = new DenseVector(numFactors);
            DenseVector sum_p_r_c = new DenseVector(numFactors);
            double si = this.s.get(i);
            // Only users with a positive entry on item i contribute these terms.
            for (VectorEntry ve : this.trainMatrix.column(i)) {
                double rui = ve.get();
                if (!(rui > 0.0)) {
                    continue;
                }
                int u = ve.index();
                DenseVector pu = this.P.row(u);
                sum_cpr = sum_cpr.add(pu.scale(rui));
                sum_c_sr_p = sum_c_sr_p.add(pu.scale(m_sum_sr.get(u)));
                sum_p_r_c = sum_p_r_c.add(pu.scale(rui * m_sum_c.get(u)));
            }
            DenseMatrix M = sum_cpp.scale(this.sum_s).add(sum_p_p_c.scale(si));
            DenseVector y = sum_cpp.mult(sum_sq).add(sum_cpr.scale(this.sum_s)).minus(sum_c_sr_p).add(sum_p_p_cq.scale(si)).minus(sum_cr_p.scale(si)).add(sum_p_r_c.scale(si));
            // NOTE(review): no regularization term is added, so M is assumed to be invertible.
            this.Q.setRow(i, M.inv().mult(y));
        }
    }

    /** Returns the configuration values in the order listed by this class's {@code @Configuration}. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{Float.valueOf(binThold), numFactors, this.isSupportWeight, numIters});
    }
}

