/*
 * Decompiled with CFR 0.152.
 */
package carskit.alg.baseline.cf;

import carskit.data.setting.Configuration;
import carskit.generic.IterativeRecommender;
import happy.coding.io.Strings;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;

/**
 * Non-negative matrix factorization (NMF) baseline recommender.
 *
 * <p>The training matrix {@code V} is approximated by the product {@code W * H}. Here
 * {@code W} is {@code numUsers x numFactors} and {@code H} is {@code numFactors x numItems}.
 * Both factors are refined with multiplicative (Lee-Seung style) updates that only look at
 * the observed entries of {@code V}. No step size is involved, so the learning rate is not
 * used by this model.
 */
@Configuration(value="factors, numIters")
public class NMF
extends IterativeRecommender {
    /** Added to every update denominator so the multiplicative ratio never divides by zero. */
    private static final double DENOMINATOR_EPS = 1.0E-9;

    /** User factor matrix, {@code numUsers x numFactors}. */
    protected DenseMatrix W;
    /** Item factor matrix, {@code numFactors x numItems}. */
    protected DenseMatrix H;
    /** Matrix being factorized; set to the training matrix in {@link #initModel()}. */
    protected SparseMatrix V;

    /**
     * Creates an NMF recommender for one evaluation fold.
     *
     * @param trainMatrix training data
     * @param testMatrix  test data
     * @param fold        fold index forwarded to the base recommender
     */
    public NMF(carskit.data.structure.SparseMatrix trainMatrix, carskit.data.structure.SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        // Multiplicative updates have no step size; buildModel() never reads lRate,
        // so it is set to a sentinel value.
        this.lRate = -1.0;
        this.algoName = "NMF";
    }

    /**
     * Allocates and initializes the factor matrices and binds {@code V} to the training data.
     *
     * @throws Exception if base-class initialization fails
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.W = new DenseMatrix(this.numUsers, numFactors);
        this.H = new DenseMatrix(numFactors, this.numItems);
        // NOTE(review): DenseMatrix.init(range) is assumed to fill with small non-negative
        // random values. Multiplicative updates keep the sign of each entry and cannot revive
        // an exact zero, so a non-negative, non-zero start is required.
        this.W.init(0.01);
        this.H.init(0.01);
        this.V = this.train;
    }

    /**
     * Runs up to {@code numIters} rounds of alternating multiplicative updates.
     *
     * <p>Each round does three things, in order:
     * <ol>
     *   <li>Updates every row of {@code W}.</li>
     *   <li>Updates every column of {@code H}.</li>
     *   <li>Recomputes {@code loss} as half the sum of squared errors over the positive
     *       training entries.</li>
     * </ol>
     * It stops early once {@link #isConverged(int)} reports convergence.
     *
     * @throws Exception if prediction or the convergence check fails
     */
    @Override
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            // Update W: W(u,f) *= (H_f . V_u) / (H_f . Vhat_u + eps),
            // where Vhat_u holds predictions on the items observed for user u.
            for (int u = 0; u < this.W.numRows(); ++u) {
                SparseVector uv = this.V.row(u);
                if (uv.getCount() <= 0) {
                    continue; // no training entries for this user: leave its row unchanged
                }
                // Snapshot predictions with the current W row so that every factor f of
                // user u is updated against the same estimate.
                SparseVector euv = new SparseVector(this.V.numColumns());
                for (int j : uv.getIndex()) {
                    euv.set(j, this.predict(u, j));
                }
                for (int f = 0; f < this.W.numColumns(); ++f) {
                    DenseVector fv = this.H.row(f, false);
                    double real = fv.inner(uv);
                    double estm = fv.inner(euv) + DENOMINATOR_EPS;
                    this.W.set(u, f, this.W.get(u, f) * (real / estm));
                }
            }

            // Update H: H(f,j) *= (W^T_f . V_j) / (W^T_f . Vhat_j + eps).
            // Transpose once so the columns of the freshly updated W can be read as rows.
            DenseMatrix trW = this.W.transpose();
            for (int j = 0; j < this.H.numColumns(); ++j) {
                SparseVector jv = this.V.column(j);
                if (jv.getCount() <= 0) {
                    continue; // no training entries for this item: leave its column unchanged
                }
                // Snapshot predictions with the current H column (see the W step above).
                SparseVector ejv = new SparseVector(this.V.numRows());
                for (int u : jv.getIndex()) {
                    ejv.set(u, this.predict(u, j));
                }
                for (int f = 0; f < this.H.numRows(); ++f) {
                    DenseVector fv = trW.row(f, false);
                    double real = fv.inner(jv);
                    double estm = fv.inner(ejv) + DENOMINATOR_EPS;
                    this.H.set(f, j, this.H.get(f, j) * (real / estm));
                }
            }

            // Loss: half the sum of squared errors over strictly positive training entries.
            this.loss = 0.0;
            for (MatrixEntry me : this.V) {
                double ruj = me.get();
                if (!(ruj > 0.0)) {
                    continue;
                }
                double euj = this.predict(me.row(), me.column()) - ruj;
                this.loss += euj * euj;
            }
            this.loss *= 0.5;

            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts the score of item {@code j} for user {@code u} in context {@code c}.
     *
     * <p>Under user (item) splitting, the id is first remapped through {@code userIdMapper}
     * ({@code itemIdMapper}) when that mapper has an entry for the pair {@code (id, c)}.
     * Otherwise the original id is kept.
     *
     * @param u user index
     * @param j item index
     * @param c context index
     * @return the context-free prediction for the (possibly remapped) user/item pair
     * @throws Exception if the underlying prediction fails
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        if (this.isUserSplitting) {
            u = this.userIdMapper.contains(u, c) ? (Integer)this.userIdMapper.get(u, c) : u;
        }
        if (this.isItemSplitting) {
            j = this.itemIdMapper.contains(j, c) ? (Integer)this.itemIdMapper.get(j, c) : j;
        }
        return this.predict(u, j);
    }

    /**
     * Returns the reconstructed score for user {@code u} and item {@code j}.
     * Presumably this is the dot product of row {@code u} of {@code W} and column {@code j}
     * of {@code H}; see {@code DenseMatrix.product}.
     *
     * @param u user index
     * @param j item index
     * @return the predicted score
     * @throws Exception declared by the overridden method
     */
    @Override
    protected double predict(int u, int j) throws Exception {
        return DenseMatrix.product(this.W, u, this.H, j);
    }

    /** Returns the model's configuration ({@code numFactors}, {@code numIters}) as a string. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{numFactors, numIters});
    }
}

