/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.kNN;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.kNN.KnnModel;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsKNN;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;

public class KnnClassifier
extends ClusInductionAlgorithmType {
    public static String DEFAULT_MODEL_NAME_WITH_CONSTANT_WEIGHTS = "Default 1-nn model with " + SettingsKNN.DistanceWeights.Constant.toString() + " weights";
    private final String modelNameTemplate = "Original %s-nn model with %s weighting";

    public KnnClassifier(Clus clus) {
        super(clus);
    }

    @Override
    public ClusInductionAlgorithm createInduce(ClusSchema schema, Settings sett, CMDLineArgs cargs) throws ClusException, IOException {
        ClusInductionAlgorithmImpl induce = new ClusInductionAlgorithmImpl(schema, sett);
        return induce;
    }

    @Override
    public void pruneAll(ClusRun cr) throws ClusException, IOException {
    }

    @Override
    public ClusModel pruneSingle(ClusModel model, ClusRun cr) throws ClusException, IOException {
        return model;
    }

    public static int getMaxK(int[] ks) {
        int maxK = 1;
        for (int k : ks) {
            maxK = Math.max(maxK, k);
        }
        return maxK;
    }

    public static ClusAttrType[] getNecessaryDescriptiveAttributes(RowData data) {
        ClusAttrType[] necessaryDescriptiveAttributes = data.isSparse() ? data.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Descriptive) : data.getSchema().getAllAttrUse(ClusAttrType.AttributeUseType.Descriptive);
        return necessaryDescriptiveAttributes;
    }

    public static ClusModel induceDefaultModel(ClusRun cr) throws ClusException, InterruptedException {
        ClusNode node = new ClusNode();
        RowData data = (RowData)cr.getTrainingSet();
        node.initTargetStat(cr.getStatManager(), data);
        node.computePrediction();
        node.makeLeaf();
        return node;
    }

    @Override
    public void postProcess(ClusRun cr) throws ClusException, IOException {
    }

    private class ClusInductionAlgorithmImpl
    extends ClusInductionAlgorithm {
        public ClusInductionAlgorithmImpl(ClusSchema schema, Settings sett) throws ClusException, IOException {
            super(schema, sett);
        }

        @Override
        public ClusModel induceSingleUnpruned(ClusRun cr) throws ClusException, IOException, InterruptedException {
            int[] ks = this.getSettings().getKNN().getKNNk();
            List<SettingsKNN.DistanceWeights> distWeight = this.getSettings().getKNN().getKNNDistanceWeights();
            RowData trainData = cr.getDataSet(0);
            boolean isSparse = trainData.isSparse();
            int maxK = KnnClassifier.getMaxK(ks);
            Arrays.sort(ks);
            ClusAttrType[] necessaryDescriptiveAttributes = KnnClassifier.getNecessaryDescriptiveAttributes(trainData);
            String model_name = DEFAULT_MODEL_NAME_WITH_CONSTANT_WEIGHTS;
            KnnModel model = new KnnModel(cr, 1, SettingsKNN.DistanceWeights.Constant, maxK, isSparse, necessaryDescriptiveAttributes);
            model.tryInitializeMLC(ks, trainData, this.getSettings().getKNN().getMlcCountSmoother());
            ClusModelInfo model_info = cr.addModelInfo(1, model_name);
            model_info.setModel(model);
            model_info.setName(model_name);
            ClusModel defModel = KnnClassifier.induceDefaultModel(cr);
            ClusModelInfo defModelInfo = cr.addModelInfo(0);
            defModelInfo.setModel(defModel);
            defModelInfo.setName("Default");
            int modelCnt = 2;
            for (int k : ks) {
                for (SettingsKNN.DistanceWeights w : distWeight) {
                    if (k == 1 && w.equals((Object)SettingsKNN.DistanceWeights.Constant)) continue;
                    KnnModel tmpmodel = new KnnModel(cr, k, w, model);
                    model_name = String.format("Original %s-nn model with %s weighting", k, w.toString());
                    ClusModelInfo tmpmodel_info = cr.addModelInfo(modelCnt++, model_name);
                    tmpmodel_info.setModel(tmpmodel);
                    tmpmodel_info.setName(model_name);
                }
            }
            return model;
        }
    }
}

