/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.util.tools.optimization.gd;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.LinkedList;
import java.util.ListIterator;
import si.ijs.kt.clus.algo.rules.ClusRuleSet;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSModelInfo;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.main.settings.section.SettingsRules;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.format.ClusFormat;
import si.ijs.kt.clus.util.format.ClusNumberFormat;
import si.ijs.kt.clus.util.tools.optimization.OptimizationProblem;

/**
 * Gradient-descent optimization problem for fitting the weights of a rule ensemble
 * (rule terms and, optionally, linear terms).
 *
 * <p>The class lazily caches pairwise covariances between base predictors (computed as
 * averaged products of predictions) and the covariance of every predictor with the true
 * values. It also keeps the current gradient vector so that it can be updated incrementally
 * after each step. When {@code OptGDEarlyStopAmount > 0}, part of the data is held out as a
 * validation set that drives early stopping.
 */
public class GDProblem
extends OptimizationProblem {
    /** When true, extra GD debug output is produced (dynamic step size, gradient dumps, NaN covariances). */
    public static boolean m_printGDDebugInformation = false;
    /** Upper-triangular cache of predictor covariances, indexed [min][max]; NaN marks "not yet computed". */
    protected double[][] m_covariances;
    /** {@code m_isCovComputed[i]} is true once predictor i's covariances with all others have been computed. */
    protected boolean[] m_isCovComputed;
    /** Marks weights that have been activated (see {@link #computeCovariancesIfNeeded(int)}). */
    protected boolean[] m_isWeightNonZero;
    /** Number of activated weights, not counting weight index 0. */
    protected int m_nbOfNonZeroRules;
    /** Covariance of each predictor with the true values, averaged over targets. */
    protected double[] m_predCovWithTrue;
    /** Current gradient for each weight. */
    protected double[] m_gradients;
    /**
     * Only allocated for the MaxLoss multi-target gradient combination. {@link #getMaxGradients(int)}
     * skips a weight while its entry exceeds the current iteration number. Not written in this class.
     */
    protected int[] m_bannedWeights;
    /** Current step size (multiplier applied to gradients). */
    protected double m_stepSize;
    /** Validation problem used for early stopping; null when early stopping is disabled. */
    protected OptimizationProblem m_earlyStopProbl;
    /** Rule set providing the rule terms. */
    protected ClusRuleSet m_ruleSet = null;
    /** Step size computed by {@link #computeDynStepSize()}: 1 / (sum of diagonal covariances). */
    double m_dynStepLowerBound = 0.0;
    /** Best (lowest) validation fitness seen in the current run. */
    protected double m_minFitness;
    /** Weights that achieved {@link #m_minFitness}. */
    protected ArrayList<Double> m_minFitWeights;
    /** Ensures the "unsupported loss function" message is logged at most once per instance. */
    private boolean m_lossFallbackWarned = false;

    /**
     * Creates the problem. If enabled, it splits off an early-stopping validation set.
     * It then precomputes the predictor/true-value covariances and, if enabled, the
     * dynamic step size.
     *
     * <p>NOTE(review): this constructor calls the overridable
     * {@link #initPredictorVsTrueValuesCovariances()}. Subclass overrides must not rely on
     * their own fields being initialized.
     *
     * @param stat_mgr statistics manager
     * @param optInfo  optimization data (rule and base-function predictions)
     * @param rset     rule set whose rules provide the rule terms
     */
    public GDProblem(ClusStatManager stat_mgr, OptimizationProblem.OptimizationParameter optInfo, ClusRuleSet rset) {
        super(stat_mgr, optInfo);
        this.m_ruleSet = rset;
        this.preparePredictionsForNormalization();
        if (this.getSettings().getRules().getOptGDEarlyStopAmount() > 0.0) {
            // Hold out ceil(amount * N) instances for early stopping and train on the remainder.
            int nbDataTest = (int)Math.ceil((double)this.getNbOfInstances() * this.getSettings().getRules().getOptGDEarlyStopAmount());
            OptimizationProblem.OptimizationParameter dataEarlyStop = new OptimizationProblem.OptimizationParameter(optInfo.m_rulePredictions.length, optInfo.m_baseFuncPredictions.length, nbDataTest, this.getNbOfTargets(), optInfo.m_implicitLinearTerms);
            OptimizationProblem.OptimizationParameter trainingSet = new OptimizationProblem.OptimizationParameter(optInfo.m_rulePredictions.length, optInfo.m_baseFuncPredictions.length, this.getNbOfInstances() - nbDataTest, this.getNbOfTargets(), optInfo.m_implicitLinearTerms);
            GDProblem.splitDataIntoValAndTrainSet(stat_mgr, optInfo, dataEarlyStop, trainingSet);
            this.changeData(trainingSet);
            this.m_earlyStopProbl = new OptimizationProblem(stat_mgr, dataEarlyStop);
        }
        int nbWeights = this.getNumVar();
        this.m_covariances = new double[nbWeights][nbWeights];
        for (double[] row : this.m_covariances) {
            // NaN marks "not computed yet"; getWeightCov relies on this sentinel.
            Arrays.fill(row, Double.NaN);
        }
        this.m_isCovComputed = new boolean[nbWeights];
        this.initPredictorVsTrueValuesCovariances();
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            this.computeDynStepSize();
        }
    }

    /**
     * Resets per-run state (activated weights, gradients, step size, early-stopping bests)
     * so that a new GD run can reuse the already computed predictions and covariances.
     */
    public void initGDForNewRunWithSamePredictions() {
        int nbWeights = this.getNumVar();
        if (this.getSettings().getRules().getOptGDEarlyStopAmount() > 0.0) {
            this.m_minFitness = Double.POSITIVE_INFINITY;
            this.m_minFitWeights = new ArrayList<>(nbWeights);
            for (int iWeight = 0; iWeight < nbWeights; ++iWeight) {
                this.m_minFitWeights.add(0.0);
            }
        }
        this.m_isWeightNonZero = new boolean[nbWeights];
        this.m_bannedWeights = this.getSettings().getRules().getOptGDMTGradientCombine() == SettingsRules.OptimizationGDMTCombineGradient.MaxLoss
                ? new int[nbWeights]
                : null;
        this.m_gradients = new double[nbWeights];
        this.m_nbOfNonZeroRules = 0;
        this.m_stepSize = this.getSettings().getRules().getOptGDStepSize();
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            this.m_stepSize = this.m_dynStepLowerBound;
        }
    }

    /**
     * Computes the diagonal covariances and sets {@link #m_dynStepLowerBound} to
     * 1 / (their sum).
     *
     * <p>NOTE(review): the result is infinite if all diagonal covariances are zero.
     */
    private void computeDynStepSize() {
        for (int dimension = 0; dimension < this.getNumVar(); ++dimension) {
            this.m_covariances[dimension][dimension] = this.computeCovFor2Preds(dimension, dimension);
        }
        double sum = 0.0;
        for (int dimension = 0; dimension < this.getNumVar(); ++dimension) {
            sum += this.getWeightCov(dimension, dimension);
        }
        this.m_dynStepLowerBound = 1.0 / sum;
        if (m_printGDDebugInformation) {
            ClusLogger.info("DEBUG: DynStepSize lower bound is " + this.m_dynStepLowerBound);
        }
    }

    /** @return a new mutable list of {@link #getNumVar()} zero weights */
    protected ArrayList<Double> getInitialWeightVector() {
        int nbWeights = this.getNumVar();
        ArrayList<Double> result = new ArrayList<>(nbWeights);
        for (int i = 0; i < nbWeights; ++i) {
            result.add(0.0);
        }
        return result;
    }

    /** @return the cached covariance of predictor {@code iPred} with the true values */
    protected final double getCovForPrediction(int iPred) {
        return this.m_predCovWithTrue[iPred];
    }

    /** Fills {@link #m_predCovWithTrue} for every predictor. */
    protected void initPredictorVsTrueValuesCovariances() {
        this.m_predCovWithTrue = new double[this.getNumVar()];
        for (int iPred = 0; iPred < this.getNumVar(); ++iPred) {
            this.m_predCovWithTrue[iPred] = this.computePredVsTrueValueCov(iPred);
        }
    }

    /**
     * Computes the covariance of predictor {@code iPred} with the true values.
     *
     * <p>For each target, the code sums {@code trueValue * prediction} over the instances that
     * have a valid true value and a non-NaN prediction. It divides that sum by the total number
     * of instances and, if enabled, by the target's normalization factor. The per-target values
     * are then averaged over the targets that received at least one contribution.
     *
     * @return the averaged covariance, or 0.0 when no target received any contribution
     *         (avoids a 0/0 NaN that would poison the gradients)
     */
    private double computePredVsTrueValueCov(int iPred) {
        final int nbOfTargets = this.getNbOfTargets();
        final int nbOfInstances = this.getNbOfInstances();
        final boolean normalize = this.getSettings().getRules().isOptNormalization();
        double[] covs = new double[nbOfTargets];
        BitSet contributingTargets = new BitSet(nbOfTargets);
        for (int iTarget = 0; iTarget < nbOfTargets; ++iTarget) {
            for (int iInstance = 0; iInstance < nbOfInstances; ++iInstance) {
                double trueVal = this.getTrueValue(iInstance, iTarget);
                if (!this.isValidValue(trueVal)) {
                    continue;
                }
                double pred = this.predictWithRule(iPred, iInstance, iTarget);
                if (Double.isNaN(pred)) {
                    continue;
                }
                covs[iTarget] += trueVal * pred;
                contributingTargets.set(iTarget);
            }
            covs[iTarget] /= nbOfInstances;
            if (normalize) {
                covs[iTarget] /= this.getNormFactor(iTarget);
            }
        }
        int nbOfContributing = contributingTargets.cardinality();
        if (nbOfContributing == 0) {
            return 0.0;
        }
        double avgCov = 0.0;
        for (int iTarget = 0; iTarget < nbOfTargets; ++iTarget) {
            avgCov += covs[iTarget] / (double)nbOfContributing;
        }
        return avgCov;
    }

    /**
     * Returns the cached covariance of two predictors. The order of the arguments is irrelevant.
     *
     * @throws Error if the covariance has not been computed yet
     */
    protected final double getWeightCov(int iFirst, int iSecond) {
        int min = Math.min(iFirst, iSecond);
        int max = Math.max(iFirst, iSecond);
        if (Double.isNaN(this.m_covariances[min][max])) {
            throw new Error("Asked covariance not yet computed. Something wrong in the covariances in GDProbl.");
        }
        return this.m_covariances[min][max];
    }

    /**
     * Computes the covariances of predictor {@code dimension} with every predictor whose own
     * covariances have not been computed yet, plus its variance (diagonal entry).
     */
    private void computeWeightCov(int dimension) {
        for (int iMin = 0; iMin < dimension; ++iMin) {
            if (this.m_isCovComputed[iMin]) continue;
            this.m_covariances[iMin][dimension] = this.computeCovFor2Preds(iMin, dimension);
        }
        this.m_covariances[dimension][dimension] = this.computeCovFor2Preds(dimension, dimension);
        for (int iMax = dimension + 1; iMax < this.getNumVar(); ++iMax) {
            if (this.m_isCovComputed[iMax]) continue;
            this.m_covariances[dimension][iMax] = this.computeCovFor2Preds(dimension, iMax);
        }
    }

    /**
     * Dispatches to the rule/rule, rule/linear or linear/linear covariance computation.
     * Callers pass {@code iPrevious <= iLatter}. NOTE(review): the dispatch appears to assume
     * that rule terms come before linear terms in index order.
     */
    private double computeCovFor2Preds(int iPrevious, int iLatter) {
        if (this.isRuleTerm(iLatter)) {
            return this.computeCovFor2Rules(iPrevious, iLatter);
        }
        if (this.isRuleTerm(iPrevious)) {
            return this.computeCovForRuleAndLin(iPrevious, iLatter);
        }
        return this.computeCovFor2Lin(iPrevious, iLatter);
    }

    /**
     * Covariance of two linear terms. It is zero when they belong to different target
     * dimensions; otherwise it is the mean product over all instances, divided by the number
     * of targets and, if enabled, by the normalization factor.
     */
    private double computeCovFor2Lin(int iPrevious, int iLatter) {
        int iTarget = this.getLinTargetDim(iPrevious);
        if (iTarget != this.getLinTargetDim(iLatter)) {
            return 0.0;
        }
        int nbOfInstances = this.getNbOfInstances();
        int nbOfTargets = this.getNbOfTargets();
        double avgCov = 0.0;
        for (int iInstance = 0; iInstance < nbOfInstances; ++iInstance) {
            avgCov += this.predictWithRule(iPrevious, iInstance, iTarget) * this.predictWithRule(iLatter, iInstance, iTarget);
        }
        // Multiply in double to avoid int overflow on large targets * instances products.
        avgCov /= (double)nbOfTargets * nbOfInstances;
        if (this.getSettings().getRules().isOptNormalization()) {
            avgCov /= this.getNormFactor(iTarget);
        }
        return avgCov;
    }

    /**
     * Covariance of a rule term with a linear term. The rule's prediction is taken from
     * instance 0 (presumably constant over the covered instances). It is multiplied by the sum
     * of the linear term's predictions over the instances the rule covers.
     *
     * @return 0.0 if the rule has no prediction (NaN) for the linear term's target
     */
    private double computeCovForRuleAndLin(int iRule, int iLinear) {
        int iTarget = this.getLinTargetDim(iLinear);
        double predRule = this.getPredictionsWhenCovered(iRule, 0, iTarget);
        if (Double.isNaN(predRule)) {
            return 0.0;
        }
        int nbOfTargets = this.getRuleEnabledTargets(iRule);
        double avgCov = 0.0;
        // getRuleNextCovered yields the next covered instance index, negative when none remain.
        for (int iInstance = this.getRuleNextCovered(iRule, 0); iInstance >= 0; iInstance = this.getRuleNextCovered(iRule, iInstance + 1)) {
            avgCov += this.getPredictionsWhenCovered(iLinear, iInstance, iTarget);
        }
        avgCov *= predRule;
        // Multiply in double to avoid int overflow on large targets * instances products.
        avgCov /= (double)nbOfTargets * this.getNbOfInstances();
        if (this.getSettings().getRules().isOptNormalization()) {
            avgCov /= this.getNormFactor(iTarget);
        }
        if (Double.isNaN(avgCov) && m_printGDDebugInformation) {
            ClusLogger.info("DEBUG: NaN covariance between rule " + iRule + " and linear term " + iLinear);
        }
        return avgCov;
    }

    /**
     * Returns the number of targets the rule predicts. This is the size of the rule's ROS
     * target subspace when ROS subspace averaging is enabled and the rule has ROS info;
     * otherwise it is the total number of targets.
     */
    private int getRuleEnabledTargets(int iRule) {
        if (this.getSettings().getEnsemble().isEnsembleROSEnabled()
                && this.getSettings().getEnsemble().getEnsembleROSVotingType() == SettingsEnsemble.EnsembleROSVotingType.SubspaceAveraging) {
            ClusROSModelInfo info = this.m_ruleSet.getRule(iRule).getROSModelInfo();
            if (info != null) {
                return info.getTargets().size();
            }
        }
        return this.getNbOfTargets();
    }

    /**
     * Covariance of two rule terms.
     *
     * <p>It is the average, over the targets where both rules predict (non-NaN), of the product
     * of their predictions (taken from instance 0, normalized if enabled). This average is
     * scaled by the fraction of instances covered by both rules.
     *
     * @return 0.0 if no target has predictions from both rules
     */
    private double computeCovFor2Rules(int iPrevious, int iLatter) {
        BitSet bothCover = (BitSet)this.getRuleCovers(iPrevious).clone();
        bothCover.and(this.getRuleCovers(iLatter));
        final int nbOfTargets = this.getNbOfTargets();
        final boolean normalize = this.getSettings().getRules().isOptNormalization();
        double avgCov = 0.0;
        double omitted = 0.0;
        for (int iTarget = 0; iTarget < nbOfTargets; ++iTarget) {
            double p1 = this.getPredictionsWhenCovered(iPrevious, 0, iTarget);
            double p2 = this.getPredictionsWhenCovered(iLatter, 0, iTarget);
            if (Double.isNaN(p1) || Double.isNaN(p2)) {
                omitted += 1.0;
                continue;
            }
            double cov = 0.0;
            cov += p1 * p2;
            if (normalize) {
                cov /= this.getNormFactor(iTarget);
            }
            avgCov += cov;
        }
        if (omitted == (double)nbOfTargets) {
            return 0.0;
        }
        avgCov /= (double)nbOfTargets - omitted;
        return avgCov * ((double)bothCover.cardinality() / (double)this.getNbOfInstances());
    }

    /** @return the predictor's prediction for the instance/target, or NaN if it does not cover the instance */
    protected final double predictWithRule(int iRule, int iInstance, int iTarget) {
        return this.isCovered(iRule, iInstance) ? this.getPredictionsWhenCovered(iRule, iInstance, iTarget) : Double.NaN;
    }

    /** Recomputes every gradient from scratch for the given weights. */
    public void fullGradientComputation(ArrayList<Double> weights) {
        for (int iWeight = 0; iWeight < weights.size(); ++iWeight) {
            this.m_gradients[iWeight] = this.getGradient(iWeight, weights);
        }
    }

    /**
     * Returns the gradient for one weight. Only squared loss is implemented. For ZeroOneError,
     * Huber and RRMSE, a message is logged once and squared loss is used instead.
     */
    protected double getGradient(int iWeightDim, ArrayList<Double> weights) {
        switch (this.getSettings().getRules().getOptGDLossFunction()) {
            case ZeroOneError:
            case Huber:
            case RRMSE:
                this.warnSquaredLossFallback();
                break;
            default:
                break;
        }
        return this.gradientSquared(iWeightDim, weights);
    }

    /** Logs (once per instance) that the configured loss is unsupported and squared loss is used. */
    private void warnSquaredLossFallback() {
        if (this.m_lossFallbackWarned) {
            return;
        }
        this.m_lossFallbackWarned = true;
        ClusLogger.info("Loss function " + this.getSettings().getRules().getOptGDLossFunction()
                + " not yet implemented for Gradient descent. Using squared loss.");
    }

    /**
     * Squared-loss gradient for weight {@code iGradWeightDim}:
     * cov(pred_j, y) - sum over activated weights k of w_k * cov(pred_k, pred_j).
     */
    private double gradientSquared(int iGradWeightDim, ArrayList<Double> weights) {
        double gradient = this.getCovForPrediction(iGradWeightDim);
        for (int iWeight = 0; iWeight < this.getNumVar(); ++iWeight) {
            if (!this.m_isWeightNonZero[iWeight]) continue;
            gradient -= weights.get(iWeight) * this.getWeightCov(iWeight, iGradWeightDim);
        }
        return gradient;
    }

    /**
     * Incrementally updates the gradients after the given weights changed. {@code weights} is
     * currently unused (only squared loss is supported).
     */
    protected final void modifyGradients(int[] changedWeightIndex, ArrayList<Double> weights) {
        this.modifyGradientSquared(changedWeightIndex);
    }

    /**
     * Incremental squared-loss gradient update. For each changed weight c, every gradient g_j
     * is decreased by {@code cov(c, j) * m_stepSize * oldGradient(c)}, using the gradients
     * captured before this update.
     *
     * <p>Once a changed linear term has been seen, the inner loop skips linear terms of other
     * target dimensions (their covariance with a linear term is 0, see computeCovFor2Lin).
     * NOTE(review): this skipping assumes the changed indices are sorted with rule terms first,
     * and that linear terms are laid out so their target dimension cycles with period
     * nbOfTargets. Confirm with the caller.
     */
    public void modifyGradientSquared(int[] iChangedWeights) {
        // Snapshot: the loop below mutates m_gradients, including the changed weights' own entries.
        double[] oldGradsOfChanged = new double[iChangedWeights.length];
        for (int iCopy = 0; iCopy < iChangedWeights.length; ++iCopy) {
            oldGradsOfChanged[iCopy] = this.m_gradients[iChangedWeights[iCopy]];
        }
        boolean firstLinearTermReached = false;
        int nbOfTargs = this.getNbOfTargets();
        int nbOfChanged = iChangedWeights.length;
        int nbOfGrads = this.m_gradients.length;
        for (int iiAffecting = 0; iiAffecting < nbOfChanged; ++iiAffecting) {
            int iAffecting = iChangedWeights[iiAffecting];
            if (!firstLinearTermReached && !this.isRuleTerm(iAffecting)) {
                firstLinearTermReached = true;
            }
            boolean secondLinearTermReached = false;
            double stepAmount = this.m_stepSize * oldGradsOfChanged[iiAffecting];
            for (int iWeightChange = 0; iWeightChange < nbOfGrads; ++iWeightChange) {
                this.m_gradients[iWeightChange] -= this.getWeightCov(iAffecting, iWeightChange) * stepAmount;
                if (!firstLinearTermReached) continue;
                if (secondLinearTermReached) {
                    // Jump to the next linear term of the same target dimension.
                    iWeightChange += nbOfTargs - 1;
                    continue;
                }
                if (this.isRuleTerm(iWeightChange)) continue;
                // First linear term: align with the changed term's target dimension.
                iWeightChange += (this.getLinTargetDim(iAffecting) + nbOfTargs - 1) % nbOfTargs;
                secondLinearTermReached = true;
            }
        }
    }

    /**
     * Selects the weights to update in this iteration.
     *
     * <p>The code first finds the largest |gradient| among eligible weights. It then returns
     * every weight whose |gradient| is at least {@code OptGDGradTreshold} times that maximum.
     * Eligibility rules:
     * <ul>
     * <li>Banned weights are skipped.</li>
     * <li>Once the maximum number of weights is reached, only already-activated weights (and
     *     index 0) qualify.</li>
     * <li>Index 0 is exempt from the threshold and the weight limit.</li>
     * <li>With a threshold of 1.0, the search stops at the first qualifying non-zero index.</li>
     * </ul>
     * If adding all new weights would exceed the limit, only the ones with the largest
     * |gradient| are kept.
     *
     * @param nbOfIterations current iteration number (compared against banned-until entries)
     * @return indices of the selected weights
     */
    public int[] getMaxGradients(int nbOfIterations) {
        final int maxElements = this.getSettings().getRules().getOptGDMaxNbWeights();
        final double gradThreshold = this.getSettings().getRules().getOptGDGradTreshold();
        boolean maxNbOfWeightReached = maxElements > 0 && this.m_nbOfNonZeroRules >= maxElements;
        double maxGrad = 0.0;
        for (int iGrad = 0; iGrad < this.m_gradients.length; ++iGrad) {
            if (this.m_bannedWeights != null && this.m_bannedWeights[iGrad] > nbOfIterations || !(Math.abs(this.m_gradients[iGrad]) > maxGrad) || maxNbOfWeightReached && !this.m_isWeightNonZero[iGrad] && iGrad != 0) continue;
            maxGrad = Math.abs(this.m_gradients[iGrad]);
        }
        ArrayList<Integer> iMaxGradients = new ArrayList<>();
        double minAllowed = gradThreshold * maxGrad;
        for (int iCopy = 0; iCopy < this.m_gradients.length; ++iCopy) {
            if (this.m_bannedWeights != null && this.m_bannedWeights[iCopy] > nbOfIterations || (!(Math.abs(this.m_gradients[iCopy]) >= minAllowed) || maxNbOfWeightReached && !this.m_isWeightNonZero[iCopy]) && iCopy != 0) continue;
            iMaxGradients.add(iCopy);
            if (gradThreshold == 1.0 && iCopy != 0) break;
        }
        if (maxElements > 0 && !maxNbOfWeightReached && gradThreshold < 1.0) {
            // Count selected weights that are already active (or index 0): they do not consume the budget.
            int nbOfOldGrads = 0;
            for (int iGrad = 0; iGrad < iMaxGradients.size(); ++iGrad) {
                int idx = iMaxGradients.get(iGrad);
                if (this.m_isWeightNonZero[idx] || idx == 0) {
                    ++nbOfOldGrads;
                }
            }
            int nbOfAllowedNewGradients = maxElements - this.m_nbOfNonZeroRules;
            if (nbOfAllowedNewGradients < iMaxGradients.size() - nbOfOldGrads) {
                // Move new candidates into a list sorted by descending |gradient|, capped at the budget.
                LinkedList<Integer> iAllowedNewMaxGradients = new LinkedList<>();
                for (int iGrad = 0; iGrad < iMaxGradients.size(); ++iGrad) {
                    int candidate = iMaxGradients.get(iGrad);
                    if (this.m_isWeightNonZero[candidate] || candidate == 0) continue;
                    double candidateAbs = Math.abs(this.m_gradients[candidate]);
                    ListIterator<Integer> iAllowed = iAllowedNewMaxGradients.listIterator();
                    while (iAllowed.hasNext()) {
                        if (Math.abs(this.m_gradients[iAllowed.next()]) < candidateAbs) {
                            iAllowed.previous();
                            break;
                        }
                    }
                    iAllowed.add(candidate);
                    iMaxGradients.remove(iGrad); // remove by index
                    --iGrad;
                    if (iAllowedNewMaxGradients.size() > nbOfAllowedNewGradients) {
                        iAllowedNewMaxGradients.removeLast();
                    }
                }
                ListIterator<Integer> iList = iAllowedNewMaxGradients.listIterator();
                for (int addedElements = 0; addedElements < nbOfAllowedNewGradients; ++addedElements) {
                    iMaxGradients.add(iList.next());
                }
            }
        }
        int[] iMaxGradientsArray = new int[iMaxGradients.size()];
        for (int iCopy = 0; iCopy < iMaxGradientsArray.length; ++iCopy) {
            iMaxGradientsArray[iCopy] = iMaxGradients.get(iCopy);
        }
        return iMaxGradientsArray;
    }

    /** @return the weight change for the given weight: step size times its current gradient */
    public final double howMuchWeightChanges(int iTargetWeight) {
        return this.m_stepSize * this.m_gradients[iTargetWeight];
    }

    /**
     * Ensures the covariances of weight {@code iWeight} are computed and marks the weight as
     * activated. Weight index 0 is not counted in {@link #m_nbOfNonZeroRules}.
     */
    public void computeCovariancesIfNeeded(int iWeight) {
        if (!this.m_isCovComputed[iWeight]) {
            this.computeWeightCov(iWeight);
            this.m_isCovComputed[iWeight] = true;
        }
        if (!this.m_isWeightNonZero[iWeight]) {
            this.m_isWeightNonZero[iWeight] = true;
            if (iWeight != 0) {
                ++this.m_nbOfNonZeroRules;
            }
        }
    }

    /**
     * Multiplies the step size by {@code amount}. Prints a warning (but still applies it) when
     * {@code amount >= 1}.
     */
    public final void dropStepSize(double amount) {
        if (amount >= 1.0) {
            System.err.println("Something wrong with dropStepSize. Argument >= 1.");
        }
        this.m_stepSize *= amount;
    }

    /** @return the lowest validation fitness seen in the current run */
    public double getBestFitness() {
        return this.m_minFitness;
    }

    /**
     * Evaluates the weights on the early-stopping validation set and records them if they are
     * the best so far.
     *
     * <p>Precondition: early stopping is enabled and
     * {@link #initGDForNewRunWithSamePredictions()} has been called. Otherwise the validation
     * problem and the best-weights list are null.
     *
     * @return true if the validation fitness exceeds {@code OptGDEarlyStopTreshold} times the best
     */
    public boolean isEarlyStop(ArrayList<Double> weights) {
        double newFitness = this.m_earlyStopProbl.calcFitness(weights, this.m_ruleSet);
        if (newFitness < this.m_minFitness) {
            this.m_minFitness = newFitness;
            for (int iWeight = 0; iWeight < weights.size(); ++iWeight) {
                this.m_minFitWeights.set(iWeight, (double)weights.get(iWeight));
            }
        }
        boolean stop = false;
        if (newFitness > this.getSettings().getRules().getOptGDEarlyStopTreshold() * this.m_minFitness) {
            stop = true;
            ClusLogger.fine("GD: Independent test set error increase detected - overfitting.");
        }
        return stop;
    }

    /** Copies the best weights recorded by {@link #isEarlyStop(ArrayList)} into {@code targetWeights}. */
    public void restoreBestWeight(ArrayList<Double> targetWeights) {
        for (int iWeight = 0; iWeight < targetWeights.size(); ++iWeight) {
            targetWeights.set(iWeight, (double)this.m_minFitWeights.get(iWeight));
        }
    }

    /**
     * Maps a uniform random number to a tree depth.
     *
     * <p>With L = 2^avgDepth, the method computes
     * terminalNodes = 2 + (2 - L) / ln(L - 2) * ln(unifRand) and returns
     * ceil(log2(terminalNodes)). It returns -1 when {@code unifRand == 0}.
     * NOTE(review): for avgDepth &lt;= 1 the logarithm argument is &lt;= 0; check the intended range.
     */
    public static int randDepthWighExponentialDistribution(double unifRand, int avgDepth) {
        int maxDepths = 0;
        if (unifRand == 0.0) {
            maxDepths = -1;
        } else {
            int avgNbLeaves = (int)Math.pow(2.0, avgDepth);
            double terminalNodes = 2.0 + (double)(2 - avgNbLeaves) / Math.log(avgNbLeaves - 2) * Math.log(unifRand);
            maxDepths = (int)Math.ceil(Math.log(terminalNodes) / Math.log(2.0));
        }
        return maxDepths;
    }

    /**
     * Writes the current gradients, formatted with six decimals, as one line. Writes nothing
     * unless {@link #m_printGDDebugInformation} is set.
     */
    public void printGradientsToFile(int iterNro, PrintWriter wrt) {
        if (!m_printGDDebugInformation) {
            return;
        }
        ClusNumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        wrt.print("Iteration " + iterNro + ":");
        for (int i = 0; i < this.m_gradients.length; ++i) {
            wrt.print(fr.format(this.m_gradients[i]) + "\t");
        }
        wrt.print("\n");
    }
}

