/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.util.tools.optimization.gd;

import java.util.ArrayList;
import si.ijs.kt.clus.algo.rules.ClusRuleSet;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsRules;
import si.ijs.kt.clus.util.tools.optimization.OptimizationProblem;

/**
 * Bridge to the native gradient-descent optimiser in the {@code GDInterface} library.
 *
 * <p>{@link #main} normalises the rule predictions and the (training + optional validation) data,
 * flattens them into primitive arrays, calls {@link #externalOptim}, folds the default-rule weight
 * back into the default rule's prediction and returns the weights as a list.
 */
public class CallExternGD {

    /**
     * Native entry point implemented in the {@code GDInterface} library.
     *
     * @param settings newline-separated {@code "key value"} settings built by {@link #main}
     * @param data row-major data matrix; each row holds the descriptive columns (if any) followed
     *     by the normalised target columns; training rows come first, then validation rows
     * @param rulePreds normalised rule predictions, indexed {@code iRule * nbTargs + iTarg}
     * @param ruleCovers coverage flags, indexed {@code iRule * nbRows + iRow}
     * @return the optimised weights; {@link #main} reads the first
     *     {@code nbOfRules + nbDescr * nbTargs} entries and treats entry 0 as the default rule's weight
     */
    native double[] externalOptim(String settings, double[] data, double[] rulePreds, boolean[] ruleCovers);

    /**
     * Prepares the optimisation data, runs the native gradient descent and returns the weights.
     *
     * <p>Side effects: rule 0's prediction in {@code optInfo} is overwritten with the square roots
     * of the normalisation factors, and rule 0 of {@code rset} receives a new numeric prediction
     * (see {@link #undoNormalization}).
     *
     * @param clusStatManager source of settings, schema and target statistic
     * @param optInfo rule predictions and true values to optimise over; rule 0 must predict the
     *     target means
     * @param rset rule set whose default rule (rule 0) is updated with the de-normalised prediction
     * @return one weight per rule, followed by the linear-term weights when
     *     {@code OptimizationGDAddLinearTerms.YesSaveMemory} is selected
     * @throws IllegalStateException if the native optimiser returns {@code null} or too few weights
     */
    public static ArrayList<Double> main(ClusStatManager clusStatManager, OptimizationProblem.OptimizationParameter optInfo, ClusRuleSet rset) {
        SettingsRules set = clusStatManager.getSettings().getRules();
        boolean saveMemoryLinTerms = set.getOptAddLinearTerms() == SettingsRules.OptimizationGDAddLinearTerms.YesSaveMemory;

        int nbOfRules = optInfo.m_rulePredictions.length;
        int nbOfWeights = nbOfRules;
        int nbTargs = clusStatManager.getStatistic(ClusAttrType.AttributeUseType.Target).getNbAttributes();
        int nbRows = optInfo.m_trueValues.length;
        int nbDescrForDataMatrix = 0;
        if (saveMemoryLinTerms) {
            // Linear terms are evaluated on the native side: one extra weight per (descriptive attr, target).
            nbDescrForDataMatrix = clusStatManager.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Descriptive).length;
            nbOfWeights += nbDescrForDataMatrix * nbTargs;
        }

        double[] normFactors = OptimizationProblem.initNormFactors(nbTargs, clusStatManager.getSettings());
        double[] targetAvg = OptimizationProblem.initMeans(nbTargs);
        // Loop-invariant square roots of the normalisation factors.
        double[] normScale = new double[nbTargs];
        for (int iTarg = 0; iTarg < nbTargs; ++iTarg) {
            normScale[iTarg] = Math.sqrt(normFactors[iTarg]);
        }

        checkAndRescaleDefaultRule(optInfo, nbTargs, targetAvg, normScale);

        // Optionally hold out a fraction of the rows as a validation set for early stopping.
        OptimizationProblem.OptimizationParameter trainingSet = optInfo;
        OptimizationProblem.OptimizationParameter validationSet = null;
        if (set.getOptGDEarlyStopAmount() > 0.0) {
            int nbDataTest = (int) Math.ceil((double) nbRows * set.getOptGDEarlyStopAmount());
            validationSet = new OptimizationProblem.OptimizationParameter(optInfo.m_rulePredictions.length, optInfo.m_baseFuncPredictions.length, nbDataTest, nbTargs, optInfo.m_implicitLinearTerms);
            trainingSet = new OptimizationProblem.OptimizationParameter(optInfo.m_rulePredictions.length, optInfo.m_baseFuncPredictions.length, nbRows - nbDataTest, nbTargs, optInfo.m_implicitLinearTerms);
            OptimizationProblem.splitDataIntoValAndTrainSet(clusStatManager, optInfo, validationSet, trainingSet);
        }

        // Rule predictions, normalised per target; layout iRule * nbTargs + iTarg.
        double[] rulePreds = new double[nbOfRules * nbTargs];
        for (int iRule = 0; iRule < nbOfRules; ++iRule) {
            for (int iTarg = 0; iTarg < nbTargs; ++iTarg) {
                rulePreds[iRule * nbTargs + iTarg] = trainingSet.m_rulePredictions[iRule].m_prediction[iTarg][0] / normScale[iTarg];
            }
        }

        // NOTE(review): assumes m_cover.length() equals the number of instances in the set — confirm
        // (e.g. a java.util.BitSet would report highest-set-bit + 1 instead).
        int nbInstTrain = trainingSet.m_rulePredictions[0].m_cover.length();
        // Without early stopping there is no validation set; previously this dereferenced null.
        int nbInstVal = validationSet == null ? 0 : validationSet.m_rulePredictions[0].m_cover.length();

        // Coverage matrix; layout iRule * nbRows + iRow, training rows first, then validation rows.
        boolean[] ruleCovers = new boolean[nbOfRules * nbRows];
        copyCovers(trainingSet, nbInstTrain, nbOfRules, nbRows, 0, ruleCovers);
        if (validationSet != null) {
            copyCovers(validationSet, nbInstVal, nbOfRules, nbRows, nbInstTrain, ruleCovers);
        }

        // Data matrix; each row = [descriptive columns..., normalised targets...].
        double[] binData = new double[nbRows * (nbTargs + nbDescrForDataMatrix)];
        fillDataRows(trainingSet, nbInstTrain, 0, nbTargs, nbDescrForDataMatrix, targetAvg, normScale, binData);
        if (validationSet != null) {
            fillDataRows(validationSet, nbInstVal, nbInstTrain, nbTargs, nbDescrForDataMatrix, targetAvg, normScale, binData);
        }

        StringBuilder settings = new StringBuilder();
        if (saveMemoryLinTerms) {
            settings.append("linTermsUsed 1\n");
        }
        settings.append("nbOfTargs ").append(nbTargs)
                .append("\nnbTrainData ").append(nbInstTrain)
                .append("\nnbValData ").append(nbInstVal)
                .append("\nnbOfRules ").append(nbOfRules)
                .append("\nnbOfDescrAttr ").append(nbDescrForDataMatrix)
                .append("\nnbOfIterations ").append(set.getOptGDMaxIter())
                .append("\nminTVal ").append(set.getOptGDGradTreshold())
                .append("\nnbOfDiffTVal ").append(set.getOptGDNbOfTParameterTry())
                .append("\nnbNonZeroWeights ").append(set.getOptGDMaxNbWeights())
                .append("\n");

        CallExternGD mappedFile = new CallExternGD();
        double[] weights = mappedFile.externalOptim(settings.toString(), binData, rulePreds, ruleCovers);
        if (weights == null || weights.length < nbOfWeights) {
            throw new IllegalStateException("Native gradient descent returned "
                    + (weights == null ? "null" : weights.length + " weights")
                    + ", expected at least " + nbOfWeights);
        }

        // Fold the default-rule weight into rule 0's prediction; its weight becomes 1.0.
        weights[0] = CallExternGD.undoNormalization(weights[0], rulePreds, nbTargs, targetAvg, rset, normFactors);

        ArrayList<Double> result = new ArrayList<>(nbOfWeights);
        for (int i = 0; i < nbOfWeights; ++i) {
            result.add(weights[i]);
        }
        return result;
    }

    /**
     * Checks that rule 0 predicts exactly the target means (terminating the JVM otherwise) and then
     * overwrites its prediction with the square roots of the normalisation factors.
     */
    private static void checkAndRescaleDefaultRule(OptimizationProblem.OptimizationParameter optInfo, int nbTargs, double[] targetAvg, double[] normScale) {
        for (int iTarg = 0; iTarg < nbTargs; ++iTarg) {
            if (optInfo.m_rulePredictions[0].m_prediction[iTarg][0] != targetAvg[iTarg]) {
                System.err.println("Error: Difference in main for target nb " + iTarg + ". The values are " + optInfo.m_rulePredictions[0].m_prediction[iTarg][0] + " and " + targetAvg[iTarg]);
                System.exit(1);
            }
            optInfo.m_rulePredictions[0].m_prediction[iTarg][0] = normScale[iTarg];
        }
    }

    /** Copies the first {@code nbInst} coverage flags of every rule into rows starting at {@code firstRow}. */
    private static void copyCovers(OptimizationProblem.OptimizationParameter data, int nbInst, int nbOfRules, int nbRows, int firstRow, boolean[] ruleCovers) {
        for (int iRule = 0; iRule < nbOfRules; ++iRule) {
            int base = iRule * nbRows + firstRow;
            for (int iInst = 0; iInst < nbInst; ++iInst) {
                ruleCovers[base + iInst] = data.m_rulePredictions[iRule].m_cover.get(iInst);
            }
        }
    }

    /** Writes {@code nbInst} data rows (descriptive columns, then normalised targets) starting at {@code firstRow}. */
    private static void fillDataRows(OptimizationProblem.OptimizationParameter data, int nbInst, int firstRow, int nbTargs, int nbDescr, double[] targetAvg, double[] normScale, double[] binData) {
        int rowWidth = nbTargs + nbDescr;
        for (int iInst = 0; iInst < nbInst; ++iInst) {
            int rowStart = (firstRow + iInst) * rowWidth;
            for (int iDescrDim = 0; iDescrDim < nbDescr; ++iDescrDim) {
                // NOTE(review): every descriptive column is scaled by the FIRST target's factor — confirm intended.
                binData[rowStart + iDescrDim] = data.m_implicitLinearTerms.predict(iDescrDim * nbTargs, data.m_trueValues[iInst].m_dataExample, 0, nbTargs) / normScale[0];
            }
            for (int iTarDim = 0; iTarDim < nbTargs; ++iTarDim) {
                binData[rowStart + nbDescr + iTarDim] = (data.m_trueValues[iInst].m_targets[iTarDim] - targetAvg[iTarDim]) / normScale[iTarDim];
            }
        }
    }

    /**
     * Folds the optimised default-rule weight into rule 0's prediction.
     *
     * <p>Sets rule 0 of {@code rset} to {@code pred[t] * sqrt(normFactors[t]) * weight + mean[t]} for
     * each target {@code t}, where {@code pred} is read from the first {@code nbTargs} entries of
     * {@code defaultRulePreds}.
     *
     * @return {@code 1.0}, the weight the default rule carries after folding
     */
    private static final double undoNormalization(double defaultRuleWeights, double[] defaultRulePreds, int nbTargs, double[] targMeans, ClusRuleSet rset, double[] normFactors) {
        double[] newDefault = new double[nbTargs];
        for (int iTarg = 0; iTarg < nbTargs; ++iTarg) {
            newDefault[iTarg] = defaultRulePreds[iTarg] * Math.sqrt(normFactors[iTarg]) * defaultRuleWeights + targMeans[iTarg];
        }
        rset.getRule(0).setNumericPrediction(newDefault);
        return 1.0;
    }

    static {
        System.loadLibrary("GDInterface");
    }
}

