/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble.ros;

import java.util.HashMap;
import si.ijs.kt.clus.ext.ensemble.ClusOOBWeights;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSForestInfo;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSModelInfo;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;

/**
 * Error-based voting weights for a ROS forest.
 *
 * <p>When the ROS voting type is {@code TotalAveraging}, both weight calculations are delegated
 * unchanged to {@link ClusOOBWeights}. For every other voting type the weights are computed per
 * target: a model's weight for a target is the inverse of its error divided by the sum of the
 * inverse errors accumulated for that target, and the result is stored in
 * {@code m_ComponentWeights} keyed by model index.
 */
public class ClusROSOOBWeights extends ClusOOBWeights {
    /** Value used in place of {@code 1 / error} when an error is not strictly positive. */
    private static final double NON_POSITIVE_ERROR_INVERSE = 1000000.0;

    private ClusROSForestInfo m_ROSForestInfo = null;
    private SettingsEnsemble.EnsembleROSVotingType m_ROSVotingType = null;

    /**
     * Creates a weights holder.
     *
     * @param et    ensemble voting type, passed to the superclass constructor
     * @param rosVT ROS voting type that selects between the superclass calculation
     *              ({@code TotalAveraging}) and the per-target calculation implemented here
     */
    public ClusROSOOBWeights(SettingsEnsemble.EnsembleVotingType et, SettingsEnsemble.EnsembleROSVotingType rosVT) {
        super(et);
        this.m_ROSVotingType = rosVT;
    }

    /**
     * Computes per-target weights from each model's aggregated error.
     *
     * <p>For {@code TotalAveraging} the superclass implementation is used. Otherwise the inverse
     * aggregated error of each model is summed over the targets returned by that model's
     * {@code getTargets()}, and every model then receives, for every target, its inverse
     * aggregated error divided by that target's sum. The result replaces
     * {@code m_ComponentWeights}.
     */
    @Override
    protected void calculateAggregateWeights() {
        if (this.m_ROSVotingType.equals(SettingsEnsemble.EnsembleROSVotingType.TotalAveraging)) {
            super.calculateAggregateWeights();
            return;
        }

        this.m_ComponentWeights = new HashMap<>();
        // The number of targets is taken from the component errors of model 0.
        double[] inverseSums = new double[((double[]) this.m_ComponentErrors.get(0)).length];

        // NOTE(review): unlike calculateComponentWeights(), targets from getTargetsNotLearned()
        // are not accumulated here, so a target that no model lists keeps a sum of 0 and the
        // division below yields a non-finite weight -- confirm this is intended.
        for (int model = 0; model < this.m_ComponentErrors.size(); ++model) {
            ClusROSModelInfo info = this.m_ROSForestInfo.getROSModelInfo(model);
            // The aggregated error is the same for all targets of a model, so invert it once.
            double inverse = inverseError((Double) this.m_AggregatedError.get(model));
            for (int target : info.getTargets()) {
                inverseSums[target] += inverse;
            }
        }

        for (int model = 0; model < this.m_AggregatedError.size(); ++model) {
            double[] weights = new double[((double[]) this.m_ComponentErrors.get(model)).length];
            double inverse = inverseError((Double) this.m_AggregatedError.get(model));
            for (int target = 0; target < weights.length; ++target) {
                weights[target] = inverse / inverseSums[target];
            }
            this.m_ComponentWeights.put(model, weights);
        }
    }

    /**
     * Computes per-target weights from each model's per-target (component) errors.
     *
     * <p>For {@code TotalAveraging} the superclass implementation is used. Otherwise, for each
     * model, the inverse component error is summed over the targets returned by the model's
     * {@code getTargets()} and over the forest's {@code getTargetsNotLearned()}. Every model then
     * receives, for every target, its inverse component error divided by that target's sum. The
     * result replaces {@code m_ComponentWeights}.
     */
    @Override
    protected void calculateComponentWeights() {
        if (this.m_ROSVotingType.equals(SettingsEnsemble.EnsembleROSVotingType.TotalAveraging)) {
            super.calculateComponentWeights();
            return;
        }

        this.m_ComponentWeights = new HashMap<>();
        // The number of targets is taken from the component errors of model 0.
        double[] inverseSums = new double[((double[]) this.m_ComponentErrors.get(0)).length];

        for (int model = 0; model < this.m_ComponentErrors.size(); ++model) {
            double[] errors = (double[]) this.m_ComponentErrors.get(model);
            ClusROSModelInfo info = this.m_ROSForestInfo.getROSModelInfo(model);
            for (int target : info.getTargets()) {
                inverseSums[target] += inverseError(errors[target]);
            }
            // Every model also contributes for the targets reported as not learned.
            for (Integer target : this.m_ROSForestInfo.getTargetsNotLearned()) {
                inverseSums[target] += inverseError(errors[target]);
            }
        }

        for (int model = 0; model < this.m_ComponentErrors.size(); ++model) {
            double[] errors = (double[]) this.m_ComponentErrors.get(model);
            double[] weights = new double[errors.length];
            for (int target = 0; target < weights.length; ++target) {
                weights[target] = inverseError(errors[target]) / inverseSums[target];
            }
            this.m_ComponentWeights.put(model, weights);
        }
    }

    /**
     * Returns weights for the first {@code numberOfModels} models.
     *
     * <p>The given forest info is always stored on this instance. If {@code numberOfModels}
     * equals the number of aggregated errors held here, this instance is returned, with its
     * weights calculated first when either weight map is still {@code null}. Otherwise a new
     * instance is created with the errors of models {@code 0 .. numberOfModels - 1}, the forest
     * info obtained from {@code info.getNew(numberOfModels)}, and freshly calculated weights.
     *
     * @param numberOfModels number of leading models the returned weights should cover
     * @param info           ROS forest info used for the per-target calculations
     * @return this instance or a new, already calculated instance
     */
    public ClusROSOOBWeights getNew(int numberOfModels, ClusROSForestInfo info) {
        this.m_ROSForestInfo = info;
        if (numberOfModels == this.m_AggregatedError.size()) {
            if (this.m_AggregatedWeight == null || this.m_ComponentWeights == null) {
                this.calculateWeights();
            }
            return this;
        }
        ClusROSOOBWeights weights = new ClusROSOOBWeights(this.m_EnsembleVotingType, this.m_ROSVotingType);
        for (int i = 0; i < numberOfModels; ++i) {
            weights.setErrors(i, (Double) this.m_AggregatedError.get(i), (double[]) this.m_ComponentErrors.get(i));
        }
        // The forest info must be in place before the weights are calculated.
        weights.m_ROSForestInfo = this.m_ROSForestInfo.getNew(numberOfModels);
        weights.calculateWeights();
        return weights;
    }

    /** Returns {@code 1 / error} for a strictly positive error, otherwise the fixed fallback. */
    private static double inverseError(double error) {
        return error > 0.0 ? 1.0 / error : NON_POSITIVE_ERROR_INVERSE;
    }
}

