/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.imputation;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import si.ijs.kt.clus.algo.kNN.KnnClassifier;
import si.ijs.kt.clus.algo.kNN.KnnModel;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.settings.section.SettingsKNN;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Imputes missing target values of training examples using a k-nearest-neighbours model.
 *
 * <p>Imputed values are written directly into the tuples of the run's training set.
 */
public class MissingTargetImputation {

    /**
     * Imputes every missing target value of the training set of {@code cr}.
     *
     * @param cr the run whose training set is modified in place
     */
    public static void impute(ClusRun cr) {
        MissingTargetImputation.impute(cr, null);
    }

    /**
     * Imputes the given missing target values of the training set of {@code cr}, in place.
     *
     * <p>Examples are processed in ascending index order. An example for which the kNN model
     * yields no prediction (or whose prediction fails) is retried in the next round, so that it
     * can benefit from neighbours imputed in earlier rounds. Rounds stop when every example has
     * been imputed or when a round makes no progress; remaining examples are reported on stderr.
     *
     * @param cr the run whose training set (expected to be a {@link RowData}) is modified
     * @param missing map from example index to the indices of that example's missing targets, or
     *     {@code null} to use {@link RowData#getMissingTargets()} of the training set
     * @throws RuntimeException if the targets are not all numeric, all nominal or all classes
     * @throws IllegalStateException if the training set cannot be loaded or the kNN model cannot
     *     be built
     */
    public static void impute(ClusRun cr, HashMap<Integer, ArrayList<Integer>> missing) {
        SettingsKNN settings = cr.getStatManager().getSettings().getKNN();
        ClusAttrType[] targets = cr.getStatManager().getSchema().getAllAttrUse(ClusAttrType.AttributeUseType.Target);
        boolean allNominal = true;
        boolean allNumeric = true;
        boolean allClasses = true;
        for (ClusAttrType target : targets) {
            allNominal &= target.isNominal();
            allNumeric &= target.isNumeric();
            allClasses &= target.isClasses();
        }
        if (!(allNominal || allNumeric || allClasses)) {
            throw new RuntimeException("Targets should be all numeric or all nominal or all classes.");
        }

        RowData data = getTrainingData(cr);
        if (missing == null) {
            missing = data.getMissingTargets();
        }

        // Indices of the examples whose neighbours the kNN model must be able to provide.
        int[] neededNeighbours = new int[missing.size()];
        int i = 0;
        for (int j : missing.keySet()) {
            neededNeighbours[i++] = j;
        }
        Arrays.sort(neededNeighbours);
        boolean singular = neededNeighbours.length == 1;
        ClusLogger.info(String.format("%d example%s need%s imputation.", neededNeighbours.length, singular ? "" : "s", singular ? "s" : ""));

        KnnModel knn = buildModel(cr, settings, data, neededNeighbours);

        int iterations = 0;
        ArrayList<Integer> toProcess = new ArrayList<>(neededNeighbours.length);
        for (int example : neededNeighbours) {
            toProcess.add(example);
        }
        while (!toProcess.isEmpty()) {
            ArrayList<Integer> toProcessNext = new ArrayList<>();
            // Only retry the examples that were not imputed yet; already imputed ones are final.
            for (int example : toProcess) {
                DataTuple tuple = data.getTuple(example);
                if (!imputeExample(knn, tuple, missing.get(example), allNumeric, allNominal)) {
                    toProcessNext.add(example);
                }
            }
            ++iterations;
            // No progress in this round: further rounds would not impute anything new either.
            if (toProcessNext.size() >= toProcess.size()) {
                break;
            }
            toProcess = toProcessNext;
        }
        if (!toProcess.isEmpty()) {
            System.err.println("Cannot impute the values in a finite number of steps. Number of examples with missing values: " + toProcess.size());
        }
        ClusLogger.info(String.format("Values imputed in %d iteration(s).", iterations));
    }

    /**
     * Returns the training set of {@code cr} as row data.
     *
     * @throws IllegalStateException if the thread is interrupted while loading (the interrupt
     *     flag is restored)
     */
    private static RowData getTrainingData(ClusRun cr) {
        try {
            return (RowData) cr.getTrainingSet();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while loading the training set for target imputation.", e);
        }
    }

    /**
     * Builds the kNN model used for imputation from the run's kNN settings.
     *
     * <p>Uses the first configured distance weighting and the maximal configured k.
     *
     * @throws IllegalStateException if the model cannot be built
     */
    private static KnnModel buildModel(ClusRun cr, SettingsKNN settings, RowData data, int[] neededNeighbours) {
        boolean isSparse = data.isSparse();
        int maxK = KnnClassifier.getMaxK(settings.getKNNk());
        ClusAttrType[] necessaryDescriptiveAttributes = KnnClassifier.getNecessaryDescriptiveAttributes(data);
        SettingsKNN.DistanceWeights distWeight = settings.getKNNDistanceWeights().get(0);
        try {
            return new KnnModel(cr, maxK, distWeight, maxK, isSparse, necessaryDescriptiveAttributes, neededNeighbours);
        }
        catch (IOException | InterruptedException | ClusException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Could not build the kNN model used for target imputation.", e);
        }
    }

    /**
     * Predicts the missing targets of one example and writes them into its tuple.
     *
     * @param knn the model providing weighted neighbour predictions
     * @param tuple the example to update in place
     * @param missingTargets indices of the example's missing targets
     * @param allNumeric whether every target is numeric
     * @param allNominal whether every target is nominal (ignored when {@code allNumeric})
     * @return {@code true} if the example was imputed, {@code false} if no prediction was
     *     available and the example should be retried later
     */
    private static boolean imputeExample(KnnModel knn, DataTuple tuple, ArrayList<Integer> missingTargets,
            boolean allNumeric, boolean allNominal) {
        ClusStatistic prediction;
        try {
            prediction = knn.predictWeighted(tuple, missingTargets);
        }
        catch (ClusException e) {
            // Treat a failed prediction like an unavailable one: the example is retried and, if it
            // never succeeds, reported as not imputed instead of crashing on a null prediction.
            e.printStackTrace();
            return false;
        }
        if (prediction == null) {
            return false;
        }
        if (allNumeric) {
            double[] predictedNum = prediction.getNumericPred();
            for (int targetIndex : missingTargets) {
                prediction.predictTupleOneComponent(tuple, targetIndex, predictedNum[targetIndex]);
            }
        } else if (allNominal) {
            int[] predictedNom = prediction.getNominalPred();
            for (int targetIndex : missingTargets) {
                prediction.predictTupleOneComponent(tuple, targetIndex, predictedNom[targetIndex]);
            }
        } else if (!missingTargets.isEmpty()) {
            // "classes" targets: the whole prediction is written at once, so one call suffices.
            prediction.predictTuple(tuple);
        }
        return true;
    }
}

