/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.semisupervised;

import java.io.IOException;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.Accuracy;
import si.ijs.kt.clus.error.RMSError;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.hmlc.HierErrorMeasures;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsHMLC;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.selection.RandomSelection;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Base class for semi-supervised induction algorithms.
 *
 * <p>Provides helpers shared by concrete semi-supervised learners: splitting the data of a
 * {@link ClusRun} into a labeled and an unlabeled part, choosing a test set when none is
 * configured, and computing test-set and out-of-bag (OOB) errors for {@link #m_Model}.
 */
public abstract class ClusSemiSupervisedInduce
extends ClusInductionAlgorithm {
    /** Unlabeled examples; taken from the run or split off the training set by {@link #partitionData}. */
    RowData m_UnlabeledData;
    /** Labeled examples used for training; filled by {@link #partitionData}. */
    RowData m_TrainingSet;
    /**
     * Value handed to {@link RandomSelection} when unlabeled examples must be drawn from the
     * training set. Not assigned in this class -- presumably set by subclasses (verify).
     */
    double m_PercentageLabeled;
    /**
     * Model evaluated by {@link #calculateError(RowData)} and {@link #getOOBError}. Not assigned in
     * this class; {@link #getOOBError} requires it to be a {@link ClusForest}.
     */
    ClusModel m_Model;

    public ClusSemiSupervisedInduce(ClusSchema schema, Settings sett) throws ClusException, IOException {
        super(schema, sett);
    }

    public ClusSemiSupervisedInduce(ClusInductionAlgorithm other) {
        super(other);
    }

    /**
     * Splits the run's data into labeled ({@link #m_TrainingSet}) and unlabeled
     * ({@link #m_UnlabeledData}) parts, then makes sure the run has a test set.
     *
     * <p>If the run already provides an unlabeled set, it is used as-is and the run's training set
     * becomes the labeled set. Otherwise the run's training set is split:
     * <ul>
     *   <li>if it contains unlabeled tuples, those form the unlabeled set;</li>
     *   <li>if not, a {@link RandomSelection} driven by {@link #m_PercentageLabeled} and the
     *       configured random seed decides which tuples stay labeled.</li>
     * </ul>
     * In both split cases the tuples are deep-cloned and the run's training set is replaced by the
     * labeled part.
     *
     * @param cr the run whose data is partitioned; its training (and possibly test) set is updated
     */
    public void partitionData(ClusRun cr) throws IOException, ClusException, InterruptedException {
        this.m_UnlabeledData = cr.getUnlabeledSet();
        if (this.m_UnlabeledData != null) {
            // Unlabeled data was supplied explicitly: the training set is the labeled part as-is.
            this.m_TrainingSet = (RowData)cr.getTrainingSet();
        } else {
            ClusSchema schema = cr.getStatManager().getSchema();
            this.m_UnlabeledData = new RowData(schema);
            this.m_TrainingSet = new RowData(schema);
            RowData source = (RowData)cr.getTrainingSet();
            int nbRows = source.getNbRows();
            if (source.getNbUnlabeled() > 0) {
                // The training set carries its own unlabeled tuples: separate them out.
                for (int i = 0; i < nbRows; ++i) {
                    DataTuple tuple = source.getTuple(i);
                    RowData target = tuple.isUnlabeled() ? this.m_UnlabeledData : this.m_TrainingSet;
                    target.add(tuple.deepCloneTuple());
                }
            } else {
                ClusLogger.info("UnlabeledData not set. Unlabeled examples will be selected from training set (Percentage labeled = " + this.m_PercentageLabeled + ")");
                RandomSelection selection = new RandomSelection(nbRows, this.m_PercentageLabeled, cr.getStatManager().getSettings().getGeneral().getRandomSeed());
                for (int i = 0; i < nbRows; ++i) {
                    // Selected tuples keep their labels; the rest are treated as unlabeled.
                    DataTuple tuple = source.getTuple(i);
                    RowData target = selection.isSelected(i) ? this.m_TrainingSet : this.m_UnlabeledData;
                    target.add(tuple.deepCloneTuple());
                }
            }
            cr.setTrainingSet(this.m_TrainingSet);
        }
        this.setTestSet(cr);
    }

    /**
     * Ensures the run has a test set. If none is configured, deep copies of the tuples in
     * {@link #m_UnlabeledData} become the test set.
     *
     * <p>Requires {@link #m_UnlabeledData} to be non-null when the run has no test set (it is set
     * by {@link #partitionData}).
     *
     * @param cr the run whose test set may be filled in
     */
    public void setTestSet(ClusRun cr) throws IOException, ClusException {
        RowData existing = cr.getTestSet();
        if (existing != null) {
            return;
        }
        ClusLogger.info("Testing data not set. Semi-supervised learning will be evaluated on unlabeled data.");
        RowData testSet = new RowData(cr.getStatManager().getSchema());
        int nbUnlabeled = this.m_UnlabeledData.getNbRows();
        for (int i = 0; i < nbUnlabeled; ++i) {
            testSet.add(this.m_UnlabeledData.getTuple(i).deepCloneTuple());
        }
        cr.setTestSet(testSet.getIterator());
    }

    /**
     * Evaluates {@code model} on the first {@code maxInstanceIndex} tuples of {@code testSet}.
     *
     * <p>The measure depends on the target: pooled AUPRC ({@link HierErrorMeasures}) in hierarchical
     * mode, otherwise {@link Accuracy} if there are nominal targets, else {@link RMSError} for
     * numeric targets.
     *
     * @param model            the model whose weighted predictions are scored
     * @param testSet          the data to evaluate on
     * @param maxInstanceIndex exclusive upper bound on the tuple indices evaluated (presumably
     *                         must not exceed {@code testSet.getNbRows()})
     * @return the error measure after all examples have been added
     * @throws ClusException if the schema has neither nominal nor numeric target attributes
     *                       (previously this produced a {@code null} measure and a later NPE)
     */
    public ClusError calculateError(ClusModel model, RowData testSet, int maxInstanceIndex) throws ClusException, InterruptedException {
        ClusErrorList errorList = new ClusErrorList();
        ClusError error;
        if (this.getStatManager().getTargetMode() == ClusStatManager.Mode.HIERARCHICAL) {
            error = this.createHierErrorMeasures(errorList);
        } else {
            NumericAttrType[] num = this.m_Schema.getNumericAttrUse(ClusAttrType.AttributeUseType.Target);
            NominalAttrType[] nom = this.m_Schema.getNominalAttrUse(ClusAttrType.AttributeUseType.Target);
            if (nom.length != 0) {
                error = new Accuracy(errorList, nom);
            } else if (num.length != 0) {
                error = new RMSError(errorList, num);
            } else {
                throw new ClusException("Cannot compute semi-supervised error: the schema has no nominal or numeric target attributes");
            }
        }
        errorList.addError(error);
        for (int t = 0; t < maxInstanceIndex; ++t) {
            DataTuple tuple = testSet.getTuple(t);
            ClusStatistic pred = model.predictWeighted(tuple);
            errorList.addExample(tuple, pred);
        }
        return error;
    }

    /**
     * Evaluates {@link #m_Model} on every tuple of {@code testSet}.
     *
     * @param testSet the data to evaluate on
     * @return the error measure chosen as in {@link #calculateError(ClusModel, RowData, int)}
     */
    public ClusError calculateError(RowData testSet) throws ClusException, InterruptedException {
        return this.calculateError(this.m_Model, testSet, testSet.getNbRows());
    }

    /**
     * Computes the out-of-bag error of {@link #m_Model} over the first {@code maxInstanceIndex}
     * tuples of {@code all_data}, skipping tuples for which the forest has no OOB prediction.
     *
     * <p>When OOB estimation is disabled in the ensemble settings, an {@link Accuracy} over the
     * nominal targets with no examples added is returned, regardless of target mode.
     *
     * @param all_data         the data whose OOB predictions are scored
     * @param maxInstanceIndex exclusive upper bound on the tuple indices considered
     * @return the OOB error measure
     * @throws ClusException if the target mode is not hierarchical, regression or classification,
     *                       or if {@link #m_Model} is not a {@link ClusForest}
     */
    public ClusError getOOBError(RowData all_data, int maxInstanceIndex) throws ClusException, InterruptedException {
        ClusErrorList oobErrorList = new ClusErrorList();
        if (!this.getSettings().getEnsemble().shouldEstimateOOB()) {
            return new Accuracy(oobErrorList, this.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target));
        }
        ClusError error;
        if (this.getStatManager().getTargetMode() == ClusStatManager.Mode.HIERARCHICAL) {
            error = this.createHierErrorMeasures(oobErrorList);
        } else if (this.getStatManager().getTargetMode() == ClusStatManager.Mode.REGRESSION) {
            error = new RMSError(oobErrorList, this.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Target));
        } else if (this.getStatManager().getTargetMode() == ClusStatManager.Mode.CLASSIFY) {
            error = new Accuracy(oobErrorList, this.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target));
        } else {
            throw new ClusException("OOB error is not supported for target mode " + this.getStatManager().getTargetMode());
        }
        if (!(this.m_Model instanceof ClusForest)) {
            throw new ClusException("OOB error requires the model to be a ClusForest, but it is "
                    + (this.m_Model == null ? "null" : this.m_Model.getClass().getName()));
        }
        ClusForest forest = (ClusForest)this.m_Model;
        oobErrorList.addError(error);
        for (int t = 0; t < maxInstanceIndex; ++t) {
            DataTuple tuple = all_data.getTuple(t);
            if (!forest.containsOOBForTuple(tuple)) {
                continue;
            }
            ClusStatistic pred = forest.predictWeightedOOB(tuple);
            oobErrorList.addExample(tuple, pred);
        }
        return error;
    }

    /**
     * Builds the pooled-AUPRC hierarchical error measure used in hierarchical mode by both
     * {@link #calculateError(ClusModel, RowData, int)} and {@link #getOOBError}.
     *
     * @param errorList the list the new measure belongs to (the caller registers it)
     * @return a new {@link HierErrorMeasures} configured from the HMLC and output settings
     */
    private ClusError createHierErrorMeasures(ClusErrorList errorList) throws ClusException {
        return new HierErrorMeasures(errorList, this.m_StatManager.getHier(), this.m_StatManager.getSettings().getHMLC().getRecallValues().getDoubleVector(), SettingsHMLC.HierarchyMeasures.PooledAUPRC, this.m_StatManager.getSettings().getOutput().isWriteCurves(), this.getSettings().getOutput().isGzipOutput());
    }
}

