/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.ext.ensemble.ClusEnsembleInduceOptimization;
import si.ijs.kt.clus.ext.ensemble.ClusReadWriteLock;
import si.ijs.kt.clus.main.ClusOutput;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.main.settings.section.SettingsSSL;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.model.processor.ModelProcessorCollection;
import si.ijs.kt.clus.selection.OOBSelection;
import si.ijs.kt.clus.statistic.ClassificationStat;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.RegressionStat;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Accumulates out-of-bag (OOB) predictions of ensemble members and uses them to compute the
 * ensemble's OOB error estimate.
 *
 * <p>Per-tuple bookkeeping lives in static maps keyed by {@link DataTuple#hashCode()}:
 * <ul>
 *   <li>how many models treated the tuple as OOB,</li>
 *   <li>the accumulated prediction for the tuple,</li>
 *   <li>(for the automatic-OOB SSL criteria) the individual per-model predictions.</li>
 * </ul>
 * Constructing a new instance replaces these static maps, i.e. resets the bookkeeping.
 *
 * <p>Access to the prediction map, the usage map and the calculation flag is guarded by
 * {@link ClusReadWriteLock}s. Every acquire is paired with a release in a {@code finally}
 * block, so an exception (e.g. asking for a tuple that has no stored prediction) cannot leave
 * a lock held.
 *
 * <p>NOTE(review): keys are hash codes, not tuples, so two distinct tuples with colliding hash
 * codes share one entry — confirm this is acceptable for DataTuple's hashCode.
 */
public class ClusOOBErrorEstimate {
    /** Accumulated OOB prediction per tuple hash: double[] (regression/hierarchical) or double[][] (classification). */
    static HashMap<Integer, Object> m_OOBPredictions;
    /** Number of models for which the tuple (by hash) was out-of-bag. */
    static HashMap<Integer, Integer> m_OOBUsage;
    /** True while the OOB error is being computed in {@link #postProcessForestForOOBEstimate}. */
    static boolean m_OOBCalculation;
    ClusStatManager.Mode m_Mode;
    Settings m_Settings;
    /** Tuple hash -> ArrayList of individual ClusStatistic predictions; only filled for the AutomaticOOB SSL criteria. */
    static HashMap m_OOBVotes;
    /** Model/tree number -> ArrayList of hash codes of the tuples that were OOB for that model. */
    static HashMap OOBMapping;
    /** Guards {@link #m_OOBPredictions}. */
    static ClusReadWriteLock m_LockPredictions;
    /** Guards {@link #m_OOBUsage}; also taken by the reader {@link #isOOBForTree} for {@link #OOBMapping}. */
    static ClusReadWriteLock m_LockUsage;
    /** Guards {@link #m_OOBCalculation}. */
    static ClusReadWriteLock m_LockCalculation;

    /**
     * Creates an estimator and resets the (static, shared) OOB bookkeeping.
     *
     * @param mode learning mode; selects how predictions are extracted and accumulated
     * @param sett settings used for the voting type and SSL criteria
     */
    public ClusOOBErrorEstimate(ClusStatManager.Mode mode, Settings sett) {
        m_OOBPredictions = new HashMap<Integer, Object>();
        m_OOBUsage = new HashMap<Integer, Integer>();
        m_OOBCalculation = false;
        this.m_Mode = mode;
        this.m_Settings = sett;
        m_OOBVotes = new HashMap();
        OOBMapping = new HashMap();
    }

    private Settings getSettings() {
        return this.m_Settings;
    }

    /**
     * Returns whether an accumulated OOB prediction exists for the tuple.
     *
     * @param tuple tuple to look up (by hash code)
     * @return true if a prediction is stored for the tuple
     * @throws InterruptedException if interrupted while acquiring the lock
     */
    public static boolean containsPredictionForTuple(DataTuple tuple) throws InterruptedException {
        m_LockPredictions.readingLock();
        try {
            return m_OOBPredictions.containsKey(tuple.hashCode());
        } finally {
            m_LockPredictions.readingUnlock();
        }
    }

    /**
     * Returns a copy of the accumulated 1-D (regression / hierarchical) OOB prediction.
     *
     * @param tuple tuple to look up (by hash code)
     * @return a defensive copy of the stored prediction vector
     * @throws NullPointerException if no prediction is stored for the tuple
     * @throws InterruptedException if interrupted while acquiring the lock
     */
    public static double[] getPredictionForRegressionHMCTuple(DataTuple tuple) throws InterruptedException {
        m_LockPredictions.readingLock();
        try {
            double[] pred = (double[]) m_OOBPredictions.get(tuple.hashCode());
            return Arrays.copyOf(pred, pred.length);
        } finally {
            m_LockPredictions.readingUnlock();
        }
    }

    /**
     * Returns a deep copy of the accumulated 2-D (classification) OOB prediction.
     *
     * @param tuple tuple to look up (by hash code)
     * @return a deep copy of the stored prediction matrix
     * @throws NullPointerException if no prediction is stored for the tuple
     * @throws InterruptedException if interrupted while acquiring the lock
     */
    public static double[][] getPredictionForClassificationTuple(DataTuple tuple) throws InterruptedException {
        m_LockPredictions.readingLock();
        try {
            return copy2D((double[][]) m_OOBPredictions.get(tuple.hashCode()));
        } finally {
            m_LockPredictions.readingUnlock();
        }
    }

    /**
     * Returns the list of individual per-model predictions recorded for the tuple, or null.
     *
     * <p>NOTE(review): m_OOBVotes is read and written without any lock; callers must ensure
     * this is not used concurrently with {@link #addOOBTuple}/{@link #updateOOBTuple}.
     */
    public static ArrayList getVotesForTuple(DataTuple tuple) {
        return (ArrayList) m_OOBVotes.get(tuple.hashCode());
    }

    /**
     * Computes the OOB error of the forest and, when OOB output or estimation is enabled,
     * writes it to {@code <appName><addname>.oob}.
     *
     * <p>The OOB-calculation flag is set for the duration of the call and is always reset, even
     * if the computation fails; the output is always closed once writing has been attempted.
     * NOTE(review): the ClusOutput is created even when nothing is written (kept from the
     * original to preserve its file-system side effects).
     */
    public synchronized void postProcessForestForOOBEstimate(ClusRun cr, OOBSelection oob_total, RowData all_data, Clus cl, String addname) throws ClusException, IOException, InterruptedException {
        Settings sett = cr.getStatManager().getSettings();
        ClusSchema schema = all_data.getSchema();
        ClusOutput output = new ClusOutput(sett.getGeneric().getAppName() + addname + ".oob", schema, sett);
        this.setOOBCalculation(true);
        try {
            if (sett.getOutput().isWriteOOBFile() || sett.getEnsemble().shouldEstimateOOB()) {
                try {
                    this.calcOOBError(oob_total, all_data, 0, cr);
                    cl.calcExtraTrainingSetErrors(cr);
                    output.writeHeader();
                    output.writeOutput(cr, true, cl.getSettings().getOutput().isOutTrainError());
                } finally {
                    output.close();
                }
            }
        } finally {
            this.setOOBCalculation(false);
        }
    }

    /**
     * Records the predictions of {@code model} for every tuple selected by {@code oob_sel}
     * (i.e. OOB for this model) and remembers which tuples were OOB for model {@code modelNo}.
     */
    public synchronized void updateOOBTuples(OOBSelection oob_sel, RowData train_data, ClusModel model, int modelNo) throws IOException, ClusException, InterruptedException {
        for (int i = 0; i < train_data.getNbRows(); ++i) {
            if (!oob_sel.isSelected(i)) {
                continue;
            }
            DataTuple tuple = train_data.getTuple(i);
            if (this.existsOOBtuple(tuple)) {
                this.updateOOBTuple(tuple, model);
            } else {
                this.addOOBTuple(tuple, model);
            }
            this.updateOOBMapping(tuple, modelNo);
        }
    }

    /**
     * Appends the tuple's hash code to the OOB list of tree {@code treeNumber}.
     *
     * <p>NOTE(review): this writes OOBMapping without m_LockUsage, while {@link #isOOBForTree}
     * reads it under that lock; taking the lock here would require declaring
     * InterruptedException, which would break existing callers.
     */
    public synchronized void updateOOBMapping(DataTuple tuple, int treeNumber) {
        if (OOBMapping.containsKey(treeNumber)) {
            ((ArrayList) OOBMapping.get(treeNumber)).add(tuple.hashCode());
        } else {
            ArrayList<Integer> hashCodes = new ArrayList<Integer>();
            hashCodes.add(tuple.hashCode());
            OOBMapping.put(treeNumber, hashCodes);
        }
    }

    /**
     * Returns true if the tuple is present in both the usage and the prediction maps; reports
     * an inconsistency on stderr when it is present in only one of them.
     */
    public boolean existsOOBtuple(DataTuple tuple) throws InterruptedException {
        boolean existsInUsage = this.existsInOOBUsage(tuple);
        boolean existsInPred = this.existsInOOBPredictions(tuple);
        if (!existsInUsage && existsInPred) {
            System.err.println(this.getClass().getName() + ":existsOOBtuple(DataTuple) OOB tuples mismatch-> Usage = False, Predictions = True");
        }
        if (existsInUsage && !existsInPred) {
            System.err.println(this.getClass().getName() + ":existsOOBtuple(DataTuple) OOB tuples mismatch-> Usage = True, Predictions = False");
        }
        return existsInUsage && existsInPred;
    }

    /**
     * Starts the OOB bookkeeping for a tuple seen OOB for the first time: usage count 1 and the
     * model's prediction stored as the initial accumulated prediction. Modes other than
     * HIERARCHICAL, REGRESSION and CLASSIFY store no prediction.
     */
    public void addOOBTuple(DataTuple tuple, ClusModel model) throws ClusException, InterruptedException {
        this.putToOOBUsage(tuple, 1);
        ClusStatistic stat = model.predictWeighted(tuple);
        switch (this.m_Mode) {
            case HIERARCHICAL: {
                this.put1DArrayToOOBPredictions(tuple, ((WHTDStatistic) stat).getNumericPred());
                break;
            }
            case REGRESSION: {
                this.put1DArrayToOOBPredictions(tuple, ((RegressionStat) stat).getNumericPred());
                break;
            }
            case CLASSIFY: {
                double[][] counts = ((ClassificationStat) stat).m_ClassCounts;
                this.put2DArrayToOOBPredictions(tuple, this.isProbabilityDistributionVoting()
                        ? ClusEnsembleInduceOptimization.transformToProbabilityDistribution(counts)
                        : ClusEnsembleInduceOptimization.transformToMajority(counts));
                break;
            }
        }
        if (this.collectsOOBVotes()) {
            ArrayList<ClusStatistic> votes = new ArrayList<ClusStatistic>();
            votes.add(stat);
            m_OOBVotes.put(tuple.hashCode(), votes);
        }
    }

    /**
     * Folds another model's prediction into the accumulated OOB prediction of an already known
     * tuple and increments its usage count.
     *
     * <p>For regression/hierarchical modes the new usage count is passed to
     * {@code incrementPredictions} (presumably maintaining a running average — see that
     * helper); for classification the transformed votes are added to the accumulated matrix.
     *
     * @throws NullPointerException if the tuple has no usage entry (call only when
     *         {@link #existsOOBtuple} is true)
     */
    public void updateOOBTuple(DataTuple tuple, ClusModel model) throws ClusException, InterruptedException {
        int used = this.getFromOOBUsage(tuple) + 1;
        this.putToOOBUsage(tuple, used);
        ClusStatistic stat = model.predictWeighted(tuple);
        switch (this.m_Mode) {
            case HIERARCHICAL: {
                double[] predictions = ((WHTDStatistic) stat).getNumericPred();
                double[] avg_predictions = this.get1DArrayFromOOBPredictions(tuple);
                avg_predictions = ClusEnsembleInduceOptimization.incrementPredictions(avg_predictions, predictions, (double) used);
                this.put1DArrayToOOBPredictions(tuple, avg_predictions);
                break;
            }
            case REGRESSION: {
                double[] predictions = ((RegressionStat) stat).getNumericPred();
                double[] avg_predictions = this.get1DArrayFromOOBPredictions(tuple);
                avg_predictions = ClusEnsembleInduceOptimization.incrementPredictions(avg_predictions, predictions, (double) used);
                this.put1DArrayToOOBPredictions(tuple, avg_predictions);
                break;
            }
            case CLASSIFY: {
                ClassificationStat statc = (ClassificationStat) stat;
                // Shallow clone of the outer array (kept from the original implementation).
                double[][] preds = (double[][]) statc.m_ClassCounts.clone();
                preds = this.isProbabilityDistributionVoting()
                        ? ClusEnsembleInduceOptimization.transformToProbabilityDistribution(preds)
                        : ClusEnsembleInduceOptimization.transformToMajority(preds);
                double[][] sum_predictions = this.get2DArrayFromOOBPredictions(tuple);
                sum_predictions = ClusEnsembleInduceOptimization.incrementPredictions(sum_predictions, preds);
                this.put2DArrayToOOBPredictions(tuple, sum_predictions);
                break;
            }
        }
        if (this.collectsOOBVotes()) {
            ArrayList votes = (ArrayList) m_OOBVotes.get(tuple.hashCode());
            if (votes == null) {
                // No vote list was started for this tuple (e.g. settings changed); start one instead of failing.
                votes = new ArrayList<ClusStatistic>();
                m_OOBVotes.put(tuple.hashCode(), votes);
            }
            votes.add(stat);
        }
    }

    /**
     * Evaluates every model of {@code cr} on the tuples selected by {@code oob_tot}: each
     * prediction is added to the model's error list for {@code type} (if any) and passed to the
     * model's processors; the all-models processor collection sees every selected tuple.
     */
    public final void calcOOBError(OOBSelection oob_tot, RowData all_data, int type, ClusRun cr) throws IOException, ClusException, InterruptedException {
        ClusSchema mschema = all_data.getSchema();
        cr.initModelProcessors(type, mschema);
        ModelProcessorCollection allcoll = cr.getAllModelsMI().getAddModelProcessors(type);
        for (int t = 0; t < all_data.getNbRows(); ++t) {
            if (!oob_tot.isSelected(t)) {
                continue;
            }
            DataTuple tuple = all_data.getTuple(t);
            allcoll.exampleUpdate(tuple);
            for (int i = 0; i < cr.getNbModels(); ++i) {
                ClusModelInfo mi = cr.getModelInfo(i);
                ClusModel model = mi.getModel();
                if (model == null) {
                    continue;
                }
                ClusStatistic pred = model.predictWeighted(tuple);
                ClusErrorList err = mi.getError(type);
                if (err != null) {
                    err.addExample(tuple, pred);
                }
                ModelProcessorCollection coll = mi.getModelProcessors(type);
                if (coll == null) {
                    continue;
                }
                if (coll.needsModelUpdate()) {
                    model.applyModelProcessors(tuple, coll);
                    coll.modelDone();
                }
                coll.exampleUpdate(tuple, pred);
            }
            allcoll.exampleDone();
        }
        cr.termModelProcessors(type);
    }

    /**
     * Returns whether the tuple (by hash code) was recorded as OOB for tree {@code treeNumber}.
     * Lookup is linear in the number of OOB tuples of that tree.
     */
    public static boolean isOOBForTree(DataTuple tuple, int treeNumber) throws InterruptedException {
        m_LockUsage.readingLock();
        try {
            ArrayList hashCodes = (ArrayList) OOBMapping.get(treeNumber);
            return hashCodes != null && hashCodes.contains(tuple.hashCode());
        } finally {
            m_LockUsage.readingUnlock();
        }
    }

    /** Returns whether the OOB error computation is currently running. */
    public static boolean isOOBCalculation() throws InterruptedException {
        m_LockCalculation.readingLock();
        try {
            return m_OOBCalculation;
        } finally {
            m_LockCalculation.readingUnlock();
        }
    }

    /** Sets the OOB-calculation flag under the calculation write lock. */
    public void setOOBCalculation(boolean value) throws InterruptedException {
        m_LockCalculation.writingLock();
        try {
            m_OOBCalculation = value;
        } finally {
            m_LockCalculation.writingUnlock();
        }
    }

    private boolean existsInOOBPredictions(DataTuple tuple) throws InterruptedException {
        return containsPredictionForTuple(tuple);
    }

    /** Stores a 1-D accumulated prediction for the tuple. */
    public void put1DArrayToOOBPredictions(DataTuple tuple, double[] value) throws InterruptedException {
        putToOOBPredictions(tuple, value);
    }

    /** Stores a 2-D accumulated prediction for the tuple. */
    public void put2DArrayToOOBPredictions(DataTuple tuple, double[][] value) throws InterruptedException {
        putToOOBPredictions(tuple, value);
    }

    /** Writes a prediction object under the prediction write lock. */
    private static void putToOOBPredictions(DataTuple tuple, Object value) throws InterruptedException {
        m_LockPredictions.writingLock();
        try {
            m_OOBPredictions.put(tuple.hashCode(), value);
        } finally {
            m_LockPredictions.writingUnlock();
        }
    }

    private double[] get1DArrayFromOOBPredictions(DataTuple tuple) throws InterruptedException {
        return getPredictionForRegressionHMCTuple(tuple);
    }

    private double[][] get2DArrayFromOOBPredictions(DataTuple tuple) throws InterruptedException {
        return getPredictionForClassificationTuple(tuple);
    }

    /** Returns a row-by-row copy of {@code source}; throws NullPointerException if it is null. */
    private static double[][] copy2D(double[][] source) {
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; ++i) {
            copy[i] = Arrays.copyOf(source[i], source[i].length);
        }
        return copy;
    }

    /** True when the ensemble voting type is ProbabilityDistribution (otherwise majority voting is used). */
    private boolean isProbabilityDistributionVoting() {
        return SettingsEnsemble.EnsembleVotingType.ProbabilityDistribution.equals(this.getSettings().getEnsemble().getEnsembleVotingType());
    }

    /** True when the SSL unlabeled criterion needs the individual OOB votes (AutomaticOOB / AutomaticOOBInitial). */
    private boolean collectsOOBVotes() {
        return Arrays.asList(SettingsSSL.SSLUnlabeledCriteria.AutomaticOOB, SettingsSSL.SSLUnlabeledCriteria.AutomaticOOBInitial).contains((Object) this.getSettings().getSSL().getUnlabeledCriteria());
    }

    private boolean existsInOOBUsage(DataTuple tuple) throws InterruptedException {
        m_LockUsage.readingLock();
        try {
            return m_OOBUsage.containsKey(tuple.hashCode());
        } finally {
            m_LockUsage.readingUnlock();
        }
    }

    private void putToOOBUsage(DataTuple tuple, int i) throws InterruptedException {
        m_LockUsage.writingLock();
        try {
            m_OOBUsage.put(tuple.hashCode(), i);
        } finally {
            m_LockUsage.writingUnlock();
        }
    }

    /** Returns the OOB usage count of the tuple, or null if it has none. */
    public Integer getFromOOBUsage(DataTuple tuple) throws InterruptedException {
        m_LockUsage.readingLock();
        try {
            return m_OOBUsage.get(tuple.hashCode());
        } finally {
            m_LockUsage.readingUnlock();
        }
    }

    static {
        m_LockPredictions = new ClusReadWriteLock();
        m_LockUsage = new ClusReadWriteLock();
        m_LockCalculation = new ClusReadWriteLock();
    }
}

