/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble;

import java.io.IOException;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.TupleIterator;
import si.ijs.kt.clus.ext.ensemble.ClusEnsembleInduceOptimization;
import si.ijs.kt.clus.ext.ensemble.ros.ClusEnsembleROSInfo;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.statistic.ClassificationStat;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusUtil;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Ensemble induction optimization for classification targets.
 *
 * <p>For every tuple tracked by the superclass (its {@code m_TuplePositions}), this class stores
 * one array of class values per target attribute, indexed as
 * {@code m_AvgPredictions[tuplePosition][targetAttribute][class]}. Each time a base model is
 * added, its class counts for every train/test tuple are transformed (majority vote or
 * probability distribution, depending on the ensemble voting type) and folded into the stored
 * arrays via the superclass's {@code incrementPredictions}.
 */
public class ClusEnsembleInduceOptClassification
extends ClusEnsembleInduceOptimization {
    private static final long serialVersionUID = 1L;
    /** Per-tuple, per-target class values; indexed as [tuple position][target attribute][class]. */
    private double[][][] m_AvgPredictions;

    public ClusEnsembleInduceOptClassification(TupleIterator train, TupleIterator test, Settings sett) throws IOException, ClusException {
        super(train, test, sett);
    }

    /**
     * Allocates zero-filled prediction storage for every tracked tuple and stores the ROS info.
     *
     * @param stat a {@link ClassificationStat} describing the targets; one row is allocated per
     *        target attribute, sized by that target's number of classes
     * @param ensembleROSInfo random-output-subspace info, kept for later coverage updates
     */
    @Override
    public void initPredictions(ClusStatistic stat, ClusEnsembleROSInfo ensembleROSInfo) {
        ClassificationStat nstat = (ClassificationStat)stat;
        int nbTargets = nstat.getNbAttributes();
        this.m_AvgPredictions = new double[this.m_TuplePositions.size()][nbTargets][];
        for (int tuple = 0; tuple < this.m_AvgPredictions.length; ++tuple) {
            for (int i = 0; i < nbTargets; ++i) {
                this.m_AvgPredictions[tuple][i] = new double[nstat.getNbClasses(i)];
            }
        }
        this.m_EnsembleROSInfo = ensembleROSInfo;
    }

    /**
     * Folds the predictions of a newly added base model into the stored per-tuple predictions.
     *
     * <p>Increments the model counter. When ROS is enabled, it also records coverage for the
     * targets of the model's subspace, using subspace index {@code m_NbUpdates - 1}. It then
     * updates predictions for every tuple of {@code train} and {@code test}; either may be
     * {@code null}. Both write locks are released in {@code finally} blocks, so a failure while
     * scoring tuples cannot leave them held.
     *
     * @param model the newly added base model
     * @param train training tuples, or {@code null}
     * @param test test tuples, or {@code null}
     */
    @Override
    public synchronized void updatePredictionsForTuples(ClusModel model, TupleIterator train, TupleIterator test) throws IOException, ClusException, InterruptedException {
        this.m_NbUpdatesLock.writingLock();
        try {
            this.m_AvgPredictionsLock.writingLock();
            try {
                ++this.m_NbUpdates;
                if (this.getSettings().getEnsemble().isEnsembleROSEnabled()) {
                    int[] enabledTargets = this.m_EnsembleROSInfo.getOnlyTargets(this.m_EnsembleROSInfo.getModelSubspace(this.m_NbUpdates - 1));
                    this.m_EnsembleROSInfo.incrementCoverageOpt(enabledTargets);
                }
                this.updateTuplesWithModel(train, model);
                this.updateTuplesWithModel(test, model);
            } finally {
                this.m_AvgPredictionsLock.writingUnlock();
            }
        } finally {
            this.m_NbUpdatesLock.writingUnlock();
        }
    }

    /**
     * Scores every tuple of {@code iterator} with {@code model} and folds the result into
     * {@code m_AvgPredictions}. The iterator is re-initialized before and after the pass.
     * Does nothing when {@code iterator} is {@code null}.
     */
    private void updateTuplesWithModel(TupleIterator iterator, ClusModel model) throws IOException, ClusException, InterruptedException {
        if (iterator == null) {
            return;
        }
        // The voting type does not depend on the tuple, so decide it once per pass.
        boolean majorityVoting = false;
        switch (this.getSettings().getEnsemble().getEnsembleVotingType()) {
            case Majority:
                majorityVoting = true;
                break;
            default:
                break;
        }
        iterator.init();
        for (DataTuple tuple = iterator.readTuple(); tuple != null; tuple = iterator.readTuple()) {
            int position = this.locateTuple(tuple);
            ClassificationStat stat = (ClassificationStat)model.predictWeighted(tuple);
            double[][] counts = (double[][])stat.getClassCounts().clone();
            counts = majorityVoting
                    ? ClusEnsembleInduceOptClassification.transformToMajority(counts)
                    : ClusEnsembleInduceOptClassification.transformToProbabilityDistribution(counts);
            if (this.m_NbUpdates == 1) {
                // First model: start from zeros shaped exactly like this prediction. Targets may
                // have different class counts, and each tuple must own its own row arrays.
                this.m_AvgPredictions[position] = zerosShapedLike(counts);
            }
            this.m_AvgPredictions[position] = this.incrementPredictions(this.m_AvgPredictions[position], counts, this.m_NbUpdates);
        }
        iterator.init();
    }

    /**
     * Returns a new zero-filled jagged array whose row {@code i} has length
     * {@code template[i].length}. Every row is a freshly allocated array.
     */
    private static double[][] zerosShapedLike(double[][] template) {
        double[][] zeros = new double[template.length][];
        for (int i = 0; i < template.length; ++i) {
            zeros[i] = new double[template[i].length];
        }
        return zeros;
    }

    /**
     * Returns the number of target attributes stored for the tuple at position {@code tuple}.
     */
    @Override
    public int getPredictionLength(int tuple) {
        return this.m_AvgPredictions[tuple].length;
    }

    /**
     * Returns the stored class values for one target of one tuple.
     *
     * <p>The internal array is returned, not a copy, so callers must not modify it.
     *
     * @param tuple tuple position
     * @param attribute target attribute index
     */
    public double[] getPredictionValueClassification(int tuple, int attribute) {
        return this.m_AvgPredictions[tuple][attribute];
    }

    /**
     * Rounds every stored class value in place to 4 significant figures.
     */
    @Override
    public void roundPredictions() {
        for (double[][] tuplePredictions : this.m_AvgPredictions) {
            for (double[] targetPredictions : tuplePredictions) {
                for (int k = 0; k < targetPredictions.length; ++k) {
                    targetPredictions[k] = ClusUtil.roundToSignificantFigures(targetPredictions[k], 4);
                }
            }
        }
    }
}

