/*
 * Decompiled with CFR 0.152.
 */
package water.rapids;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import water.Keyed;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MRUtils;
import water.util.TwoDimTable;
import water.util.VecUtils;

/**
 * Permutation variable importance for a trained {@link Model}.
 *
 * <p>For every selected feature, the feature column of the input frame (or of a row sample of it) is
 * replaced by a shuffled version and the model is re-scored. The importance of the feature is the
 * absolute difference between the metric on the shuffled data and the metric on the unshuffled data.
 */
public class PermutationVarImp {
    private final Model _model;
    private final Frame _inputFrame;

    /**
     * @param model trained model whose features are evaluated
     * @param fr    frame to score; must have at least 2 rows and contain the model's response column
     * @throws IllegalArgumentException if {@code fr} has fewer than 2 rows or lacks the response column
     */
    public PermutationVarImp(Model model, Frame fr) {
        if (fr.numRows() < 2L) {
            throw new IllegalArgumentException("Frame must contain more than 1 rows to be used in permutation variable importance!");
        }
        if (!ArrayUtils.contains(fr.names(), ((Model.Parameters) model._parms)._response_column)) {
            throw new IllegalArgumentException("Frame must contain the response column for the use in permutation variable importance!");
        }
        this._model = model;
        this._inputFrame = fr;
    }

    /**
     * Extracts {@code metric} from the given model metrics.
     *
     * @throws IllegalArgumentException if the metric is not available (NaN) in {@code mm}
     */
    private static double getMetric(ModelMetrics mm, String metric) {
        assert mm != null;
        double metricValue = ModelMetrics.getMetricFromModelMetric(mm, metric);
        if (Double.isNaN(metricValue)) {
            throw new IllegalArgumentException("Model doesn't support the following metric: " + metric);
        }
        return metricValue;
    }

    /**
     * Lower-cases {@code metric} and resolves {@code "auto"} from the type of the model's training metrics:
     * binomial gives "auc", regression gives "rmse" and multinomial gives "logloss".
     *
     * @return the resolved, lower-case metric name
     * @throws IllegalArgumentException if "auto" cannot be resolved or the metric is not allowed for the model
     */
    private String inferAndValidateMetric(String metric) {
        Set<String> allowedMetrics = ModelMetrics.getAllowedMetrics(this._model._key);
        // Locale.ROOT: default-locale lower-casing (e.g. Turkish) would map 'I' to a dotless 'ı'
        // and make valid metric names fail the lookup below.
        metric = metric.toLowerCase(Locale.ROOT);
        if (metric.equals("auto")) {
            ModelMetrics trainingMetrics = ((Model.Output) this._model._output)._training_metrics;
            if (trainingMetrics instanceof ModelMetricsBinomial) {
                metric = "auc";
            } else if (trainingMetrics instanceof ModelMetricsRegression) {
                metric = "rmse";
            } else if (trainingMetrics instanceof ModelMetricsMultinomial) {
                metric = "logloss";
            } else {
                throw new IllegalArgumentException("Unable to infer metric. Please specify metric for permutation variable importance.");
            }
        }
        if (!allowedMetrics.contains(metric)) {
            throw new IllegalArgumentException("Permutation Variable Importance doesn't support " + metric + " for model " + this._model._key);
        }
        return metric;
    }

    /**
     * Submits the shuffling of the first feature after {@code currentFeature} that is in
     * {@code featuresToCompute}.
     *
     * @return the pending shuffled Vec, or {@code null} if no such feature remains
     */
    private static Future<Vec> precomputeShuffledVec(ExecutorService executor, Frame fr, HashSet<String> featuresToCompute, String[] variables, int currentFeature, long seed) {
        for (int col = currentFeature + 1; col < fr.numCols(); ++col) {
            if (featuresToCompute.contains(variables[col])) {
                final int target = col;
                return executor.submit(() -> VecUtils.shuffleVec(fr.vec(target), seed));
            }
        }
        return null;
    }

    /**
     * Selects the frame to compute importance on.
     *
     * <p>Returns a row sample of the input frame when {@code n_samples > 1}, and the input frame itself
     * otherwise. Samples of more than 1000 rows, or for weighted models, use {@link MRUtils#sampleFrame}.
     * Smaller samples use {@link MRUtils#sampleFrameSmall}.
     */
    private Frame selectScoringFrame(long n_samples, long seed) {
        if (n_samples <= 1L) {
            return this._inputFrame;
        }
        String weightsColumn = ((Model.Parameters) this._model._parms)._weights_column;
        if (n_samples > 1000L || weightsColumn != null) {
            return MRUtils.sampleFrame(this._inputFrame, n_samples, weightsColumn, seed);
        }
        return MRUtils.sampleFrameSmall(this._inputFrame, (int) n_samples, seed);
    }

    /**
     * Releases the Vec of a precomputed shuffle that will not be consumed.
     *
     * <p>A task that already completed cannot be cancelled, and its result would otherwise be leaked.
     * NOTE(review): if {@code VecUtils.shuffleVec} ignores interruption, a shuffle that is running at
     * cancellation time may still finish and its Vec is not removed here. TODO confirm.
     */
    private static void discardPendingShuffle(Future<Vec> pending) {
        if (pending == null || pending.cancel(true) || pending.isCancelled()) {
            return;
        }
        try {
            Vec leftover = pending.get(); // already done, so this does not block
            if (leftover != null) {
                leftover.remove();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException ignored) {
            // The shuffle itself failed, so there is no Vec to clean up.
        }
    }

    /**
     * Computes the permutation importance of each selected feature.
     *
     * @param metric    lower-case metric name, already resolved and validated
     * @param n_samples number of rows to sample (> 1), or -1 to use the whole input frame; 1 is rejected
     * @param features  features to evaluate; {@code null} or empty means all frame columns. Non-predictors
     *                  and the model's ignored columns are always excluded.
     * @param seed      sampling and shuffling seed; -1 draws a random seed
     * @return map from feature name to |metric(shuffled) - metric(original)|
     * @throws IllegalArgumentException if {@code n_samples == 1} or the metric is unavailable
     * @throws RuntimeException wrapping an interruption or a failed shuffle
     */
    Map<String, Double> calculatePermutationVarImp(String metric, long n_samples, String[] features, long seed) {
        if (-1L == seed) {
            seed = new Random().nextLong();
        }
        if (n_samples == 1L) {
            throw new IllegalArgumentException("Unable to permute one row. Please set n_samples to higher value or to -1 to use the whole dataset.");
        }
        Model.Parameters parms = (Model.Parameters) this._model._parms;
        String[] variables = this._inputFrame.names();
        HashSet<String> featuresToCompute = new HashSet<>(Arrays.asList(null != features && features.length > 0 ? features : variables));
        featuresToCompute.removeAll(Arrays.asList(parms.getNonPredictors()));
        if (parms._ignored_columns != null) {
            featuresToCompute.removeAll(Arrays.asList(parms._ignored_columns));
        }

        Frame fr = this.selectScoringFrame(n_samples, seed);
        ExecutorService executor = Executors.newSingleThreadExecutor();
        Vec shuffledFeature = null;
        Future<Vec> shuffledFeatureFuture = null;
        Map<String, Double> result = new HashMap<>();
        try {
            // Start shuffling the first feature while the unshuffled baseline is scored.
            shuffledFeatureFuture = precomputeShuffledVec(executor, fr, featuresToCompute, variables, -1, seed);
            this._model.score(fr).remove();
            double origMetric = getMetric(ModelMetrics.getFromDKV(this._model, fr), metric);
            for (int f = 0; f < fr.numCols(); ++f) {
                if (!featuresToCompute.contains(variables[f])) {
                    continue;
                }
                assert shuffledFeatureFuture != null;
                shuffledFeature = shuffledFeatureFuture.get();
                shuffledFeatureFuture = null; // the Vec is now owned by shuffledFeature
                // Overlap shuffling of the next feature with scoring of this one.
                shuffledFeatureFuture = precomputeShuffledVec(executor, fr, featuresToCompute, variables, f, seed);
                Vec origFeature = fr.replace(f, shuffledFeature);
                try {
                    this._model.score(fr).remove();
                    double permutedMetric = getMetric(ModelMetrics.getFromDKV(this._model, fr), metric);
                    result.put(variables[f], Math.abs(permutedMetric - origMetric));
                } finally {
                    // Always restore the original column: fr may be the caller's input frame.
                    fr.replace(f, origFeature);
                }
                shuffledFeature.remove();
                shuffledFeature = null;
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException("Unable to calculate the permutation variable importance.", e);
        } finally {
            if (null != shuffledFeature) {
                shuffledFeature.remove();
            }
            discardPendingShuffle(shuffledFeatureFuture);
            executor.shutdownNow();
            if (null != fr && fr != this._inputFrame) {
                fr.remove();
            }
        }
        return result;
    }

    /**
     * Computes permutation variable importance and formats it with {@link ModelMetrics#calcVarImp}.
     *
     * @param metric    metric name (case-insensitive) or "auto"
     * @param n_samples number of rows to sample (> 1), or -1 for the whole frame
     * @param features  features to evaluate; {@code null} or empty means all predictors
     * @param seed      seed, or -1 for a random one
     */
    public TwoDimTable getPermutationVarImp(String metric, long n_samples, String[] features, long seed) {
        metric = this.inferAndValidateMetric(metric);
        Map<String, Double> varImps = this.calculatePermutationVarImp(metric, n_samples, features, seed);
        String[] names = new String[varImps.size()];
        double[] importance = new double[varImps.size()];
        int i = 0;
        for (Map.Entry<String, Double> entry : varImps.entrySet()) {
            names[i] = entry.getKey();
            importance[i] = entry.getValue();
            ++i;
        }
        return ModelMetrics.calcVarImp(importance, names);
    }

    /**
     * Repeats the permutation importance computation {@code n_repeats} times.
     *
     * <p>Returns one column per run. Rows are ordered by the first run's importance, in descending order.
     * Run {@code i} uses {@code seed + i}, or a random seed when {@code seed == -1}.
     * NOTE(review): if {@code seed + i} happens to equal -1, that run gets a random seed.
     *
     * @throws IllegalArgumentException if {@code n_repeats < 1}, or on an invalid metric or sample size
     */
    public TwoDimTable getRepeatedPermutationVarImp(String metric, long n_samples, int n_repeats, String[] features, long seed) {
        if (n_repeats < 1) {
            throw new IllegalArgumentException("n_repeats must be at least 1 for repeated permutation variable importance!");
        }
        metric = this.inferAndValidateMetric(metric);
        List<Map<String, Double>> runs = new ArrayList<>(n_repeats);
        for (int run = 0; run < n_repeats; ++run) {
            runs.add(this.calculatePermutationVarImp(metric, n_samples, features, seed == -1L ? -1L : seed + (long) run));
        }
        List<Map.Entry<String, Double>> sortedFeatures = new ArrayList<>(runs.get(0).entrySet());
        sortedFeatures.sort(Map.Entry.comparingByValue(Collections.reverseOrder()));
        String[] names = new String[sortedFeatures.size()];
        double[][] importance = new double[sortedFeatures.size()][n_repeats];
        for (int row = 0; row < sortedFeatures.size(); ++row) {
            String feature = sortedFeatures.get(row).getKey();
            names[row] = feature;
            for (int run = 0; run < n_repeats; ++run) {
                importance[row][run] = runs.get(run).get(feature);
            }
        }
        String[] colHeaders = IntStream.range(0, n_repeats).mapToObj(run -> "Run " + (run + 1)).toArray(String[]::new);
        String[] colTypes = IntStream.range(0, n_repeats).mapToObj(run -> "double").toArray(String[]::new);
        return new TwoDimTable("Repeated Permutation Variable Importance", null, names, colHeaders, colTypes, null, "Variable", new String[names.length][], importance);
    }

    /** Computes permutation variable importance on the whole input frame with all predictors and a random seed. */
    public TwoDimTable getPermutationVarImp(String metric) {
        return this.getPermutationVarImp(metric, -1L, null, -1L);
    }
}

