/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble;

import java.util.HashMap;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;

/**
 * Out-of-bag (OOB) based voting weights for the base models of an ensemble.
 * <p>
 * For every base model (identified by its index) the OOB error is stored both as a single
 * aggregated value and as an array of per-component values. Depending on the configured
 * {@link SettingsEnsemble.EnsembleVotingType}, {@link #calculateWeights()} derives either one
 * weight per model ({@code OOBModelWeighted}) or one weight per model and component
 * ({@code OOBTargetWeighted}). A weight is proportional to the inverse of the corresponding error
 * and is normalized so that, across all models, the weights sum to one. Non-positive errors are
 * treated as if they were {@link #SMALL_BUT_MORE_THAN_ZERO}, which keeps every weight finite.
 * <p>
 * The weight computations iterate over indices {@code 0 .. size-1}, so base-model indices are
 * expected to be contiguous and to start at 0.
 */
public class ClusOOBWeights {
    /** Error value substituted for non-positive errors so that their inverse stays finite. */
    protected final double SMALL_BUT_MORE_THAN_ZERO = 1.0E-6;
    /** Aggregated OOB error per base-model index. */
    protected HashMap<Integer, Double> m_AggregatedError;
    /** Per-component OOB errors per base-model index. */
    protected HashMap<Integer, double[]> m_ComponentErrors;
    /** Per-model weights; filled by {@link #calculateAggregateWeights()}. */
    protected HashMap<Integer, Double> m_AggregatedWeight;
    /** Per-model, per-component weights; filled by {@link #calculateComponentWeights()}. */
    protected HashMap<Integer, double[]> m_ComponentWeights;
    protected SettingsEnsemble.EnsembleVotingType m_EnsembleVotingType = null;
    protected double m_Alpha = 1.0; // NOTE(review): not read anywhere in this class

    /**
     * Creates an empty weight container.
     *
     * @param et the voting type that decides which weights {@link #calculateWeights()} computes
     */
    public ClusOOBWeights(SettingsEnsemble.EnsembleVotingType et) {
        this.m_EnsembleVotingType = et;
        this.m_AggregatedError = new HashMap<>();
        this.m_ComponentErrors = new HashMap<>();
    }

    /**
     * Returns weights computed from the errors of the first {@code numberOfModels} base models.
     * <p>
     * If {@code numberOfModels} equals the number of stored models, the weights of this instance
     * are (re)computed and {@code this} is returned. Otherwise a new instance holding only the
     * errors of models {@code 0 .. numberOfModels-1} is created and its weights are computed.
     *
     * @param numberOfModels number of leading base models to consider
     * @return an instance whose weights have been calculated
     * @throws IllegalArgumentException if errors are missing for one of the requested models
     * @throws RuntimeException if the voting type is not OOB-based
     */
    public ClusOOBWeights getNew(int numberOfModels) {
        if (numberOfModels == this.m_AggregatedError.size()) {
            // Recompute so the weights always reflect the currently stored errors.
            // calculateWeights() fills only the map of the active voting type, so a
            // "both maps non-null" cache check would not short-circuit anyway.
            this.calculateWeights();
            return this;
        }
        ClusOOBWeights weights = new ClusOOBWeights(this.m_EnsembleVotingType);
        for (int i = 0; i < numberOfModels; ++i) {
            Double aggregatedError = this.m_AggregatedError.get(i);
            if (aggregatedError == null) {
                throw new IllegalArgumentException("No OOB errors stored for base model " + i
                        + " (requested " + numberOfModels + " models, "
                        + this.m_AggregatedError.size() + " available).");
            }
            weights.setErrors(i, aggregatedError, this.m_ComponentErrors.get(i));
        }
        weights.calculateWeights();
        return weights;
    }

    /**
     * Stores the OOB errors of one base model, replacing any previous entry for that index.
     *
     * @param baseModelNumber index of the base model
     * @param aggregatedError aggregated OOB error of the model
     * @param componentErrors per-component OOB errors of the model (stored by reference)
     */
    public void setErrors(int baseModelNumber, double aggregatedError, double[] componentErrors) {
        this.m_AggregatedError.put(baseModelNumber, aggregatedError);
        this.m_ComponentErrors.put(baseModelNumber, componentErrors);
    }

    /**
     * Stores the OOB errors of one base model taken from a {@link ClusError}: the aggregated value
     * from {@code getModelError()} and one component per dimension from
     * {@code getModelErrorComponent(i)}.
     *
     * @param baseModelNumber index of the base model
     * @param error error object evaluated on the model's OOB examples
     */
    public void setErrors(int baseModelNumber, ClusError error) {
        double[] components = new double[error.getDimension()];
        for (int i = 0; i < components.length; ++i) {
            components[i] = error.getModelErrorComponent(i);
        }
        this.setErrors(baseModelNumber, error.getModelError(), components);
    }

    /**
     * Computes the weights that belong to the configured voting type: per-model weights for
     * {@code OOBModelWeighted}, per-component weights for {@code OOBTargetWeighted}.
     *
     * @throws RuntimeException if the voting type is not one of the two OOB-based types
     */
    public void calculateWeights() {
        switch (this.m_EnsembleVotingType) {
            case OOBModelWeighted: {
                this.calculateAggregateWeights();
                break;
            }
            case OOBTargetWeighted: {
                this.calculateComponentWeights();
                break;
            }
            default: {
                throw new RuntimeException("Selected voting scheme is not OOB-based.");
            }
        }
    }

    /**
     * Fills {@link #m_AggregatedWeight}: each model's weight is its inverse aggregated error
     * divided by the sum of inverse aggregated errors over all models.
     */
    protected void calculateAggregateWeights() {
        this.m_AggregatedWeight = new HashMap<>();
        int numberOfModels = this.m_AggregatedError.size();
        double sum = 0.0;
        for (int model = 0; model < numberOfModels; ++model) {
            sum += this.inverseError(this.m_AggregatedError.get(model));
        }
        for (int model = 0; model < numberOfModels; ++model) {
            this.m_AggregatedWeight.put(model, this.inverseError(this.m_AggregatedError.get(model)) / sum);
        }
    }

    /**
     * Fills {@link #m_ComponentWeights}: for every component, each model's weight is its inverse
     * component error divided by the sum of that component's inverse errors over all models.
     * The number of components is taken from model 0; all models are assumed to have at least
     * that many components.
     */
    protected void calculateComponentWeights() {
        this.m_ComponentWeights = new HashMap<>();
        int numberOfModels = this.m_ComponentErrors.size();
        if (numberOfModels == 0) {
            // Nothing to weight; also avoids dereferencing a missing entry for model 0.
            return;
        }
        // Per-component normalization constants.
        double[] componentSums = new double[this.m_ComponentErrors.get(0).length];
        for (int model = 0; model < numberOfModels; ++model) {
            double[] components = this.m_ComponentErrors.get(model);
            for (int c = 0; c < components.length; ++c) {
                componentSums[c] += this.inverseError(components[c]);
            }
        }
        for (int model = 0; model < numberOfModels; ++model) {
            double[] components = this.m_ComponentErrors.get(model);
            double[] weights = new double[components.length];
            for (int c = 0; c < weights.length; ++c) {
                weights[c] = this.inverseError(components[c]) / componentSums[c];
            }
            this.m_ComponentWeights.put(model, weights);
        }
    }

    /**
     * Returns {@code 1 / error}; non-positive errors are replaced by
     * {@link #SMALL_BUT_MORE_THAN_ZERO} so that the result stays finite.
     */
    private double inverseError(double error) {
        return error > 0.0 ? 1.0 / error : 1.0 / this.SMALL_BUT_MORE_THAN_ZERO;
    }

    /**
     * Returns the per-model weight computed by {@link #calculateAggregateWeights()}.
     *
     * @param baseModelNumber index of the base model
     * @return the model's normalized weight
     */
    public double getModelWeight(int baseModelNumber) {
        return this.m_AggregatedWeight.get(baseModelNumber);
    }

    /**
     * Returns the per-component weight computed by {@link #calculateComponentWeights()}.
     *
     * @param baseModelNumber index of the base model
     * @param component index of the component (target)
     * @return the model's normalized weight for that component
     */
    public double getComponentWeight(int baseModelNumber, int component) {
        return this.m_ComponentWeights.get(baseModelNumber)[component];
    }
}

