/*
 * Decompiled with CFR 0.152.
 */
package carskit.generic;

import carskit.data.structure.SparseMatrix;
import carskit.generic.Recommender;
import happy.coding.io.FileIO;
import happy.coding.io.LineConfiger;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import java.io.File;
import librec.data.Configuration;
import librec.data.DenseMatrix;
import librec.data.DenseVector;

/**
 * Base class for recommenders that maintain a user factor matrix {@link #P} and an item factor
 * matrix {@link #Q} (plus optional bias vectors) and are trained iteratively.
 *
 * <p>Hyper-parameters (learning rate, regularization, number of factors and iterations) live in
 * static fields shared by all instances. The constructor (re)reads them from the configuration
 * whenever {@link #resetStatics} is {@code true}, and then clears that flag.
 */
@Configuration(value="factors, lRate, maxLRate, regB, regU, regI, iters, boldDriver")
public abstract class IterativeRecommender extends Recommender {

    /** Initial learning rate: the main parameter of the "learn.rate" option. */
    protected static float initLRate;
    /** Upper bound for the learning rate ("-max"); only enforced when positive (default -1). */
    protected static float maxLRate;
    /** "-momentum" sub-option of "learn.rate" (default 50); not read by this class itself. */
    protected static float momentum;
    /** Raw "reg.lambda" options, kept so subclasses can read additional sub-options. */
    protected static LineConfiger regOptions;
    /** "-u" sub-option of "reg.lambda"; defaults to {@link #reg}. */
    protected static float regU;
    /** "-i" sub-option of "reg.lambda"; defaults to {@link #reg}. */
    protected static float regI;
    /** "-b" sub-option of "reg.lambda"; defaults to {@link #reg}. */
    protected static float regB;
    /** Main parameter of the "reg.lambda" option. */
    protected static float reg;
    /** "-c" sub-option of "reg.lambda"; defaults to {@link #reg}. */
    protected static float regC;
    /** Number of columns of {@link #P} and {@link #Q} ("num.factors", default 10). */
    protected static int numFactors;
    /** Maximum number of iterations ("num.max.iter", default 100). */
    protected static int numIters;
    /** True when "learn.rate" contains the "-bold-driver" flag. */
    protected static boolean isBoldDriver;
    /** Multiplicative learning-rate decay ("-decay", default -1); applied only when in (0, 1). */
    protected static float decay;
    /** When true, the next constructor call reloads the static settings from the configuration. */
    public static boolean resetStatics = true;

    /** User factor matrix: {@code numUsers x numFactors}. */
    protected DenseMatrix P;
    /** Item factor matrix: {@code numItems x numFactors}. */
    protected DenseMatrix Q;
    /** Optional user bias vector; may be null (it is only persisted when non-null). */
    protected DenseVector userBias;
    /** Optional item bias vector; may be null (it is only persisted when non-null). */
    protected DenseVector itemBias;
    /** Current learning rate; starts at {@link #initLRate} and is adapted by {@link #updateLRate}. */
    protected double lRate;
    /** Loss of the current iteration, set by subclasses. */
    protected double loss;
    /** Loss of the previous iteration. */
    protected double last_loss = 0.0;
    /** Current value of the early-stop measure. */
    protected double measure;
    /** Previous value of the early-stop measure. */
    protected double last_measure = 0.0;
    /** If true, {@link #initModel()} uses {@code init(initMean, initStd)}; otherwise {@code init()}. */
    protected boolean initByNorm;

    /**
     * Creates the recommender and, if {@link #resetStatics} is set, loads the shared
     * hyper-parameters from the "learn.rate", "reg.lambda", "num.factors" and "num.max.iter"
     * options.
     *
     * @param trainMatrix training data
     * @param testMatrix test data
     * @param fold fold index passed through to {@link Recommender}
     */
    public IterativeRecommender(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);
        if (resetStatics) {
            resetStatics = false;

            LineConfiger lc = cf.getParamOptions("learn.rate");
            if (lc != null) {
                initLRate = Float.parseFloat(lc.getMainParam());
                maxLRate = lc.getFloat("-max", -1.0f);
                isBoldDriver = lc.contains("-bold-driver");
                decay = lc.getFloat("-decay", -1.0f);
                momentum = lc.getFloat("-momentum", 50.0f);
            }

            regOptions = cf.getParamOptions("reg.lambda");
            if (regOptions != null) {
                reg = Float.parseFloat(regOptions.getMainParam());
                // Each specific regularization term falls back to the main value.
                regU = regOptions.getFloat("-u", reg);
                regI = regOptions.getFloat("-i", reg);
                regB = regOptions.getFloat("-b", reg);
                regC = regOptions.getFloat("-c", reg);
            }

            numFactors = cf.getInt("num.factors", 10);
            numIters = cf.getInt("num.max.iter", 100);
        }
        this.lRate = initLRate;
        this.initByNorm = true;
    }

    /**
     * Context-aware prediction. When user/item splitting is enabled and the (id, context) pair
     * has a mapping, the split id is used; otherwise the original id is kept. Then it delegates
     * to {@link #predict(int, int)}.
     */
    @Override
    protected double predict(int u, int j, int c) throws Exception {
        if (this.isUserSplitting && this.userIdMapper.contains(u, c)) {
            u = (Integer) this.userIdMapper.get(u, c);
        }
        if (this.isItemSplitting && this.itemIdMapper.contains(j, c)) {
            j = (Integer) this.itemIdMapper.get(j, c);
        }
        return this.predict(u, j);
    }

    /** Predicts via {@code DenseMatrix.rowMult(P, u, Q, j)} (row u of P combined with row j of Q). */
    @Override
    protected double predict(int u, int j) throws Exception {
        return DenseMatrix.rowMult(this.P, u, this.Q, j);
    }

    /**
     * Checks whether training has converged after iteration {@code iter}.
     *
     * <p>The method refreshes {@link #measure} from the configured early-stop measure: the loss
     * itself, or a quiet call to {@code evalRatings()}. It logs progress when verbose. Training
     * counts as converged when {@code |loss| < 1e-5}, or when the measure decreased by a positive
     * amount below {@code 1e-5}. If not converged, the learning rate is updated. Finally, the
     * current loss and measure are remembered for the next call.
     *
     * <p>Terminates the JVM with status -1 if the loss is NaN or infinite.
     *
     * @param iter current iteration number
     * @return true if training should stop
     */
    protected boolean isConverged(int iter) throws Exception {
        float deltaLoss = (float) (this.last_loss - this.loss);

        if (earlyStopMeasure != null) {
            if (earlyStopMeasure == Recommender.Measure.Loss) {
                this.measure = this.loss;
                this.last_measure = this.last_loss;
            } else {
                // Evaluate without emitting results, and restore the flag even if evaluation fails.
                boolean resultsOut = this.isResultsOut;
                this.isResultsOut = false;
                try {
                    this.measure = this.evalRatings().get(earlyStopMeasure);
                } finally {
                    this.isResultsOut = resultsOut;
                }
            }
        }
        float deltaMeasure = (float) (this.last_measure - this.measure);

        if (verbose) {
            String learnRate = this.lRate > 0.0 ? ", learn_rate = " + (float) this.lRate : "";
            String earlyStop = "";
            if (earlyStopMeasure != null && earlyStopMeasure != Recommender.Measure.Loss) {
                earlyStop = String.format(", %s = %.6f, delta_%s = %.6f",
                        earlyStopMeasure, (float) this.measure, earlyStopMeasure, deltaMeasure);
            }
            Logs.debug("{}{} iter {}: loss = {}, delta_loss = {}{}{}", this.algoName, this.foldInfo,
                    iter, Float.valueOf((float) this.loss), Float.valueOf(deltaLoss), earlyStop, learnRate);
        }

        if (Double.isNaN(this.loss) || Double.isInfinite(this.loss)) {
            Logs.error("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
            // NOTE(review): terminating the JVM from library code is heavy-handed; kept because
            // callers may rely on the process stopping here.
            System.exit(-1);
        }

        boolean lossVanished = Math.abs(this.loss) < 1.0E-5;
        boolean measureStalled = deltaMeasure > 0.0f && (double) deltaMeasure < 1.0E-5;
        boolean converged = lossVanished || measureStalled;
        if (!converged) {
            this.updateLRate(iter);
        }
        this.last_loss = this.loss;
        this.last_measure = this.measure;
        return converged;
    }

    /**
     * Adapts {@link #lRate} after an iteration. The method does nothing if the rate is not
     * positive.
     *
     * <p>With bold driver (from the second iteration on), the rate grows by 5% when the absolute
     * loss decreased and is halved otherwise. Without it, a decay in (0, 1) multiplies the rate.
     * In both cases the result is capped at {@link #maxLRate} when that is positive.
     */
    protected void updateLRate(int iter) {
        if (this.lRate <= 0.0) {
            return;
        }
        if (isBoldDriver && iter > 1) {
            this.lRate = Math.abs(this.last_loss) > Math.abs(this.loss) ? this.lRate * 1.05 : this.lRate * 0.5;
        } else if (decay > 0.0f && decay < 1.0f) {
            this.lRate *= (double) decay;
        }
        if (maxLRate > 0.0f && this.lRate > (double) maxLRate) {
            this.lRate = maxLRate;
        }
    }

    /** Allocates {@link #P} and {@link #Q} with {@link #numFactors} columns and initializes them. */
    @Override
    protected void initModel() throws Exception {
        super.initModel();
        this.P = new DenseMatrix(this.numUsers, numFactors);
        this.Q = new DenseMatrix(this.numItems, numFactors);
        if (this.initByNorm) {
            this.P.init(initMean, initStd);
            this.Q.init(initMean, initStd);
        } else {
            this.P.init();
            this.Q.init();
        }
    }

    /**
     * Serializes the train/test matrices and factor matrices into
     * {@code workingPath/algoName}. The bias vectors are written only when they are non-null.
     */
    @Override
    protected void saveModel() throws Exception {
        String dirPath = FileIO.makeDirectory(workingPath, this.algoName);
        String suffix = this.foldInfo + ".bin";
        FileIO.serialize(this.trainMatrix, dirPath + "trainMatrix" + suffix);
        FileIO.serialize(this.testMatrix, dirPath + "testMatrix" + suffix);
        FileIO.serialize(this.P, dirPath + "userFactors" + suffix);
        FileIO.serialize(this.Q, dirPath + "itemFactors" + suffix);
        if (this.userBias != null) {
            FileIO.serialize(this.userBias, dirPath + "userBiases" + suffix);
        }
        if (this.itemBias != null) {
            FileIO.serialize(this.itemBias, dirPath + "itemBiases" + suffix);
        }
        Logs.debug("Learned models are saved to folder \"{}\"", (Object) dirPath);
    }

    /**
     * Restores a model written by {@link #saveModel()}. Because {@code saveModel()} skips null
     * bias vectors, a missing bias file is not an error: the corresponding vector is set to null.
     */
    @Override
    protected void loadModel() throws Exception {
        String dirPath = FileIO.makeDirectory(workingPath, this.algoName);
        Logs.debug("A recommender model is loaded from {}", (Object) dirPath);
        String suffix = this.foldInfo + ".bin";
        this.trainMatrix = (SparseMatrix) FileIO.deserialize(dirPath + "trainMatrix" + suffix);
        this.testMatrix = (SparseMatrix) FileIO.deserialize(dirPath + "testMatrix" + suffix);
        this.P = (DenseMatrix) FileIO.deserialize(dirPath + "userFactors" + suffix);
        this.Q = (DenseMatrix) FileIO.deserialize(dirPath + "itemFactors" + suffix);
        this.userBias = loadOptionalVector(dirPath + "userBiases" + suffix);
        this.itemBias = loadOptionalVector(dirPath + "itemBiases" + suffix);
    }

    /** Deserializes a {@link DenseVector} from {@code path}, or returns null if the file is absent. */
    private static DenseVector loadOptionalVector(String path) throws Exception {
        return new File(path).exists() ? (DenseVector) FileIO.deserialize(path) : null;
    }

    /** Summarizes the shared hyper-parameters. */
    @Override
    public String toString() {
        return Strings.toString(new Object[]{"numFactors: " + numFactors, "numIter: " + numIters, "lrate: " + initLRate, "maxlrate: " + maxLRate, "regB: " + regB, "regU: " + regU, "regI: " + regI, "regC: " + regC, "isBoldDriver: " + isBoldDriver});
    }
}

