/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.kNN;

import com.google.gson.JsonObject;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import si.ijs.kt.clus.algo.kNN.KnnClassifier;
import si.ijs.kt.clus.algo.kNN.distance.attributeWeighting.AttributeWeighting;
import si.ijs.kt.clus.algo.kNN.distance.attributeWeighting.NoWeighting;
import si.ijs.kt.clus.algo.kNN.distance.attributeWeighting.RandomForestWeighting;
import si.ijs.kt.clus.algo.kNN.distance.attributeWeighting.UserDefinedWeighting;
import si.ijs.kt.clus.algo.kNN.distance.distanceWeighting.DistanceWeighting;
import si.ijs.kt.clus.algo.kNN.distance.distanceWeighting.WeightConstant;
import si.ijs.kt.clus.algo.kNN.distance.distanceWeighting.WeightMinus;
import si.ijs.kt.clus.algo.kNN.distance.distanceWeighting.WeightOver;
import si.ijs.kt.clus.algo.kNN.methods.SearchAlgorithm;
import si.ijs.kt.clus.algo.kNN.methods.bfMethod.BruteForce;
import si.ijs.kt.clus.algo.kNN.methods.bfMethod.OracleBruteForce;
import si.ijs.kt.clus.algo.kNN.methods.kdTree.KDTree;
import si.ijs.kt.clus.algo.kNN.methods.vpTree.VPTree;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.distance.ClusDistance;
import si.ijs.kt.clus.distance.primitive.ChebyshevDistance;
import si.ijs.kt.clus.distance.primitive.EuclideanDistance;
import si.ijs.kt.clus.distance.primitive.ManhattanDistance;
import si.ijs.kt.clus.distance.primitive.SearchDistance;
import si.ijs.kt.clus.ext.timeseries.TimeSeriesStat;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsKNN;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.KnnMlcStat;
import si.ijs.kt.clus.statistic.StatisticPrintInfo;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.MyArray;

/**
 * A k-nearest-neighbours model.
 * <p>
 * The model keeps no explicit structure of its own. At prediction time it asks a
 * {@link SearchAlgorithm} (VP-tree, KD-tree, brute force or oracle, as selected in the kNN
 * settings) for the nearest examples of a tuple. It then accumulates their target statistics,
 * weighted by a {@link DistanceWeighting} scheme.
 * <p>
 * Several models with different {@code k} / weighting may share one search structure. A
 * <em>master</em> model computes the neighbour list for a tuple, long enough for the largest
 * {@code k} among the models sharing it. Dependent models created with
 * {@link #KnnModel(ClusRun, int, SettingsKNN.DistanceWeights, KnnModel)} reuse that list.
 * A dependent model may therefore only be queried for a tuple right after its master has been
 * queried for the very same tuple object.
 * <p>
 * Field names are intentionally left unchanged: the class is {@link Serializable} and renaming
 * fields would break previously serialized models.
 */
public class KnnModel implements ClusModel, Serializable {
    private static final long serialVersionUID = 1L;

    /** Number of bags used for random-forest attribute weighting when the setting gives none. */
    private static final int DEFAULT_RF_NB_BAGS = 100;

    /**
     * Data-set indices 0 .. (this - 1) of {@code ClusRun.getDataSet} that are scanned for the
     * attribute normalization bounds. Index 0 is the training set and 1 the test set; the meaning
     * of index 2 is defined by ClusRun (NOTE(review): presumably a validation/pruning set).
     */
    private static final int NB_NORMALIZATION_DATA_SETS = 3;

    /** Neighbour search structure; shared with the master model for dependent models. */
    private SearchAlgorithm m_Search;
    /** How the neighbours' contributions are weighted by their distance. */
    private SettingsKNN.DistanceWeights m_WeightingOption;
    private ClusRun m_ClusRun;
    /** Template cloned for every prediction; shared with the master model for dependent models. */
    protected ClusStatistic m_StatTemplate;
    /** Number of neighbours this model predicts with. */
    private int m_K = 1;
    /** Number of neighbours requested from the search structure (at least {@code m_K}). */
    private int m_MaxK = 1;
    /** Tuple for which {@link #m_CurrentNeighbours} was computed (master models only). */
    private DataTuple m_CurrentTuple;
    /** Neighbours of {@link #m_CurrentTuple}, as returned by the search (master models only). */
    private LinkedList<DataTuple> m_CurrentNeighbours;
    /** Model whose neighbour lists are reused, or {@code null} if this model is a master. */
    private KnnModel m_Master = null;
    private boolean m_IsMlcKnn = false;

    /**
     * Creates a dependent model that reuses the search structure, statistic template and cached
     * neighbour lists of {@code master}, but predicts with its own {@code k} and weighting.
     * <p>
     * The master's maximal number of neighbours is raised to {@code k} when needed, so that the
     * neighbour lists computed by the master are long enough for this model too.
     *
     * @param cr        the current run
     * @param k         number of neighbours this model predicts with
     * @param weighting distance weighting scheme of this model
     * @param master    model that computes and caches the neighbours; must not be {@code null}
     */
    public KnnModel(ClusRun cr, int k, SettingsKNN.DistanceWeights weighting, KnnModel master) {
        this.m_ClusRun = cr;
        this.m_K = k;
        this.m_MaxK = Math.max(this.m_K, master.m_MaxK);
        master.m_MaxK = this.m_MaxK;
        this.m_WeightingOption = weighting;
        this.m_Search = master.m_Search;
        this.m_StatTemplate = master.m_StatTemplate;
        this.m_Master = master;
    }

    /**
     * Creates a master model and builds its search structure.
     * <p>
     * Equivalent to the full constructor with {@code trainingExamplesWithMissing == null}.
     *
     * @param cr                            the current run
     * @param k                             number of neighbours this model predicts with
     * @param weighting                     distance weighting scheme
     * @param maxK                          minimal number of neighbours the search structure must provide
     * @param isSparse                      not used; the sparsity is read from the training set
     * @param necessaryDescriptiveAttributes not used; the attributes are derived from the training set
     * @throws ClusException        on invalid kNN settings or search-structure failures
     * @throws IOException          if data or attribute weights cannot be read/written
     * @throws InterruptedException if interrupted while loading data
     */
    public KnnModel(ClusRun cr, int k, SettingsKNN.DistanceWeights weighting, int maxK, boolean isSparse, ClusAttrType[] necessaryDescriptiveAttributes) throws ClusException, IOException, InterruptedException {
        this(cr, k, weighting, maxK, isSparse, necessaryDescriptiveAttributes, null);
    }

    /**
     * Creates a master model: resolves the attribute weighting, creates the search distance and
     * search algorithm from the kNN settings, and builds the search structure.
     * <p>
     * If an attribute weighting other than {@code none} is computed (random forest or
     * user-defined), it is saved to {@code <appName>.weight}. Any other non-{@code none} value
     * means "load the weights from {@code <appName>.weight}".
     *
     * @param cr                            the current run
     * @param k                             number of neighbours this model predicts with
     * @param weighting                     distance weighting scheme
     * @param maxK                          minimal number of neighbours the search structure must provide
     * @param isSparse                      not used; the sparsity is read from the training set
     * @param necessaryDescriptiveAttributes not used; the attributes are derived from the training set
     * @param trainingExamplesWithMissing   if non-null, the structure is built for missing-target
     *                                      imputation of these training examples
     * @throws ClusException        on invalid kNN settings or search-structure failures
     * @throws IOException          if data or attribute weights cannot be read/written
     * @throws InterruptedException if interrupted while loading data
     */
    public KnnModel(ClusRun cr, int k, SettingsKNN.DistanceWeights weighting, int maxK, boolean isSparse, ClusAttrType[] necessaryDescriptiveAttributes, int[] trainingExamplesWithMissing) throws ClusException, IOException, InterruptedException {
        this.m_ClusRun = cr;
        this.m_K = k;
        this.m_MaxK = Math.max(Math.max(this.m_K, this.m_MaxK), maxK);
        this.m_WeightingOption = weighting;
        Settings sett = this.m_ClusRun.getStatManager().getSettings();
        SettingsKNN knnSett = sett.getKNN();

        AttributeWeighting attrWe = KnnModel.resolveAttributeWeighting(cr, knnSett, sett.getGeneric().getAppName());
        SearchDistance searchDistance = KnnModel.createSearchDistance(cr, knnSett, attrWe);
        SettingsKNN.SearchMethod searchMethod = knnSett.getSearchMethod();
        this.m_Search = KnnModel.createSearchAlgorithm(cr, searchMethod, searchDistance);

        if (sett.getGeneral().getVerbose() >= 1) {
            ClusLogger.info("Search method: " + this.m_Search.getClass());
            ClusLogger.info("Search distance: " + searchDistance.getBasicDistance().getClass());
            ClusLogger.info("Number of neighbours: " + this.m_K);
            ClusLogger.info("Distance weights: " + knnSett.getKNNDistanceWeights());
        }
        if (trainingExamplesWithMissing != null) {
            // NOTE(review): this path builds with the maxK argument while the regular path uses
            // m_MaxK (which is also >= k); confirm the asymmetry is intended.
            this.m_Search.buildForMissingTargetImputation(maxK, trainingExamplesWithMissing, knnSett);
        } else {
            this.m_Search.build(this.m_MaxK);
        }
        RowData train = this.m_ClusRun.getDataSet(0);
        RowData test = this.m_ClusRun.getDataSet(1);
        if (searchMethod == SettingsKNN.SearchMethod.Oracle && trainingExamplesWithMissing == null) {
            // The oracle may not know the neighbours of every instance; disable the affected error outputs.
            if (knnSett.mustNotComputeTrainingError(train.getNbRows())) {
                sett.getOutput().setOutTrainError(false);
                System.err.println("Training error will not be computed, since we do not know the neighbours for each training instance.");
            }
            if (test != null && knnSett.mustNotComputeTestError(test.getNbRows())) {
                sett.getOutput().setOutTestError(false);
                System.err.println("Testing error will not be computed, since we do not know the neighbours for each testing instance.");
            }
        }
        this.m_IsMlcKnn = knnSett.isMlcKnn();
        this.m_StatTemplate = this.m_IsMlcKnn
                ? new KnnMlcStat(sett, cr.getStatManager().getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target))
                : cr.getStatManager().getStatistic(ClusAttrType.AttributeUseType.Target);
        ClusStatistic trainingStat = cr.getStatManager().getTrainSetStat(ClusAttrType.AttributeUseType.Target);
        this.m_StatTemplate.setTrainingStat(trainingStat);
    }

    /**
     * Turns the {@code attributeWeighting} kNN setting into an {@link AttributeWeighting}.
     * <ul>
     * <li>{@code none} (any case): no weighting;</li>
     * <li>{@code RF} or {@code RF,<bags>}: random-forest weighting (the setting is rewritten to
     * include the default number of bags when none was given);</li>
     * <li>{@code [w1,w2,...]}: user-defined weights;</li>
     * <li>anything else: weights loaded from {@code <appName>.weight}.</li>
     * </ul>
     * Newly computed (non-loaded, non-{@code NoWeighting}) weights are saved to {@code <appName>.weight}.
     */
    private static AttributeWeighting resolveAttributeWeighting(ClusRun cr, SettingsKNN knnSett, String appName) throws ClusException, IOException, InterruptedException {
        String attrWeighting = knnSett.getKNNAttrWeight();
        String weightFile = appName + ".weight";
        if (attrWeighting.equalsIgnoreCase("none")) {
            return new NoWeighting();
        }
        AttributeWeighting attrWe;
        if (attrWeighting.startsWith("RF")) {
            attrWe = KnnModel.createRandomForestWeighting(cr, knnSett, attrWeighting);
        } else if (attrWeighting.startsWith("[") && attrWeighting.endsWith("]")) {
            attrWe = KnnModel.parseUserDefinedWeighting(attrWeighting);
        } else {
            attrWe = AttributeWeighting.loadFromFile(weightFile);
            // Check for null before logging; otherwise a missing file surfaces as an NPE.
            if (attrWe == null) {
                throw new ClusException("Unrecognized attributeWeighting value (" + attrWeighting + ")");
            }
            ClusLogger.info(attrWe.toString());
            // Weights read from disk are not written back.
            return attrWe;
        }
        if (!(attrWe instanceof NoWeighting)) {
            AttributeWeighting.saveToFile(attrWe, weightFile);
        }
        return attrWe;
    }

    /** Parses {@code RF[,<bags>]} and trains the corresponding random-forest attribute weighting. */
    private static AttributeWeighting createRandomForestWeighting(ClusRun cr, SettingsKNN knnSett, String attrWeighting) throws ClusException {
        try {
            String[] parts = attrWeighting.split(",");
            int nbBags = DEFAULT_RF_NB_BAGS;
            if (parts.length == 2) {
                nbBags = Integer.parseInt(parts[1].trim());
            } else {
                // Record the number of bags actually used in the settings.
                knnSett.setKNNAttrWeight(attrWeighting + "," + nbBags);
            }
            return new RandomForestWeighting(cr, nbBags);
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw KnnModel.withCause(new ClusException("Error at reading attributeWeighting value (" + attrWeighting + "). RF value detected, but an error occurred while reading the number of bags."), e);
        }
    }

    /** Parses a user-defined weight vector of the form {@code [w1,w2,...]}. */
    private static AttributeWeighting parseUserDefinedWeighting(String attrWeighting) throws ClusException {
        try {
            String[] parts = attrWeighting.substring(1, attrWeighting.length() - 1).split(",");
            double[] weights = new double[parts.length];
            for (int i = 0; i < weights.length; ++i) {
                weights[i] = Double.parseDouble(parts[i]);
            }
            return new UserDefinedWeighting(weights);
        }
        catch (Exception e) {
            throw KnnModel.withCause(new ClusException("Error at reading attributeWeighting value. User defined entry detected, but value cannot be read."), e);
        }
    }

    /** Attaches {@code cause} to {@code ex} so the original failure is not lost, and returns {@code ex}. */
    private static ClusException withCause(ClusException ex, Throwable cause) {
        ex.initCause(cause);
        return ex;
    }

    /** Instantiates the search algorithm selected in the kNN settings. */
    private static SearchAlgorithm createSearchAlgorithm(ClusRun cr, SettingsKNN.SearchMethod searchMethod, SearchDistance searchDistance) throws ClusException, IOException, InterruptedException {
        switch (searchMethod) {
            case VPTree:
                return new VPTree(cr, searchDistance);
            case KDTree:
                return new KDTree(cr, searchDistance);
            case BruteForce:
                return new BruteForce(cr, searchDistance);
            case Oracle:
                return new OracleBruteForce(cr, searchDistance);
            default:
                throw new RuntimeException("Wrong search method: " + searchMethod.toString());
        }
    }

    /**
     * Creates the search distance configured in {@code sett}.
     * <p>
     * The basic distance (Euclidean, Chebyshev or Manhattan) is restricted to the necessary
     * descriptive attributes of the training set that have a non-zero weight in {@code attrWe}.
     * Normalization bounds are the per-attribute minima and maxima of the numeric attributes,
     * computed over all available data sets of the run.
     *
     * @param cr     the current run; its data set 0 (training set) must be loadable
     * @param sett   kNN settings providing the distance type
     * @param attrWe attribute weighting applied to the basic distance
     * @return the configured search distance
     * @throws IllegalStateException if the training set cannot be loaded
     */
    public static SearchDistance createSearchDistance(ClusRun cr, SettingsKNN sett, AttributeWeighting attrWe) {
        SearchDistance searchDistance = new SearchDistance();
        RowData train;
        try {
            train = cr.getDataSet(0);
        }
        catch (IOException | InterruptedException | ClusException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Could not load the training set needed to build the kNN search distance.", e);
        }
        if (train == null) {
            throw new IllegalStateException("No training set available to build the kNN search distance.");
        }
        ClusAttrType[] necessaryDescriptiveAttributes = KnnClassifier.getNecessaryDescriptiveAttributes(train);
        necessaryDescriptiveAttributes = ClusDistance.attributesWithNonZeroWeight(necessaryDescriptiveAttributes, attrWe);
        ClusDistance distance = KnnModel.createBasicDistance(sett.getDistance(), searchDistance, train.isSparse(), necessaryDescriptiveAttributes);

        double[][] bounds = KnnModel.computeNumericBounds(cr, train);
        searchDistance.setDistance(distance);
        distance.setWeighting(attrWe);
        searchDistance.setNormalizationWeights(bounds[0], bounds[1]);
        return searchDistance;
    }

    /** Instantiates the basic (per-attribute-combining) distance selected in the settings. */
    private static ClusDistance createBasicDistance(SettingsKNN.Distance dist, SearchDistance searchDistance, boolean isSparse, ClusAttrType[] attributes) {
        switch (dist) {
            case Euclidean:
                return new EuclideanDistance(searchDistance, isSparse, attributes);
            case Chebyshev:
                return new ChebyshevDistance(searchDistance, isSparse, attributes);
            case Manhattan:
                return new ManhattanDistance(searchDistance, isSparse, attributes);
            default:
                throw new RuntimeException("Wrong distance.");
        }
    }

    /**
     * Computes per-attribute minima and maxima of the enabled numeric attributes over data sets
     * {@code 0 .. NB_NORMALIZATION_DATA_SETS - 1} of the run. The already loaded training set is
     * used for index 0. Data sets that are absent or fail to load are skipped.
     * <p>
     * Values equal to {@code +Infinity} are ignored (NOTE(review): presumably the missing-value
     * marker; confirm). Disabled and non-numeric attributes keep their initial +/-Infinity bounds.
     *
     * @return {@code {mins, maxs}}, indexed by attribute index in the schema; both are {@code null}
     *         if no data set was available
     */
    private static double[][] computeNumericBounds(ClusRun cr, RowData train) {
        ClusSchema schema = cr.getStatManager().getSchema();
        double[] mins = null;
        double[] maxs = null;
        for (int type = 0; type < NB_NORMALIZATION_DATA_SETS; ++type) {
            RowData data;
            if (type == 0) {
                data = train;
            } else {
                try {
                    data = cr.getDataSet(type);
                }
                catch (IOException | InterruptedException | ClusException e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    System.err.println("Data set " + type + " could not be loaded and is ignored for the kNN normalization bounds: " + e);
                    continue;
                }
            }
            if (data == null) {
                continue;
            }
            if (mins == null) {
                int nbAttrs = schema.getNbAttributes();
                mins = new double[nbAttrs];
                Arrays.fill(mins, Double.POSITIVE_INFINITY);
                maxs = new double[nbAttrs];
                Arrays.fill(maxs, Double.NEGATIVE_INFINITY);
            }
            for (int row = 0; row < data.getNbRows(); ++row) {
                DataTuple tuple = data.getTuple(row);
                for (int i = 0; i < mins.length; ++i) {
                    ClusAttrType attrType = schema.getAttrType(i);
                    if (attrType.isDisabled() || !(attrType instanceof NumericAttrType)) {
                        continue;
                    }
                    double value = attrType.getNumeric(tuple);
                    if (value == Double.POSITIVE_INFINITY) {
                        continue;
                    }
                    if (value < mins[i]) {
                        mins[i] = value;
                    }
                    if (value > maxs[i]) {
                        maxs[i] = value;
                    }
                }
            }
        }
        return new double[][]{mins, maxs};
    }

    @Override
    public ClusStatistic predictWeighted(DataTuple tuple) throws ClusException {
        return this.predictWeighted(tuple, null);
    }

    /**
     * Predicts the target statistic of {@code tuple} from its {@code k} nearest neighbours.
     * <p>
     * A master model queries the search structure (for {@code m_MaxK} neighbours) and caches the
     * result. A dependent model reuses the master's cached neighbours, which must have been computed
     * for the same tuple object. If fewer than {@code k} neighbours are available, all of them are used.
     *
     * @param tuple         the tuple to predict
     * @param targetsNeeded if non-null, target indices that must all be labeled in at least one
     *                      neighbour; otherwise {@code null} is returned (not checked for time series)
     * @return the (weighted) prediction, or {@code null} as described for {@code targetsNeeded}
     * @throws ClusException         if the neighbour search fails
     * @throws IllegalStateException if this is a dependent model and the master's cached neighbours
     *                               belong to a different tuple
     */
    public ClusStatistic predictWeighted(DataTuple tuple, ArrayList<Integer> targetsNeeded) throws ClusException {
        LinkedList<DataTuple> nearest;
        if (this.m_Master == null) {
            this.m_CurrentNeighbours = this.m_Search.returnNNs(tuple, this.m_MaxK);
            this.m_CurrentTuple = tuple;
            nearest = KnnModel.firstNeighbours(this.m_CurrentNeighbours, this.m_K);
        } else {
            // Identity check on purpose: the cache is only valid for the exact tuple object.
            if (this.m_Master.m_CurrentTuple != tuple) {
                throw new IllegalStateException("The neighbours were computed for tuple\n" + String.valueOf(this.m_Master.m_CurrentTuple) + "\nbut now, we are dealing with tuple\n" + tuple.toString());
            }
            nearest = KnnModel.firstNeighbours(this.m_Master.m_CurrentNeighbours, this.m_K);
        }
        DistanceWeighting weighting;
        switch (this.m_WeightingOption) {
            case OneOverD:
                weighting = new WeightOver(nearest, this.m_Search, tuple);
                break;
            case OneMinusD:
                weighting = new WeightMinus(nearest, this.m_Search, tuple);
                break;
            case Constant:
                weighting = new WeightConstant(nearest, this.m_Search, tuple);
                break;
            default:
                throw new RuntimeException("DistanceWeights unknown!");
        }
        ClusStatistic stat = this.m_StatTemplate.cloneStat();
        if (stat instanceof TimeSeriesStat) {
            // Time series: build one prediction per neighbour and combine the predictions.
            for (DataTuple dt : nearest) {
                ClusStatistic dtStat = this.m_StatTemplate.cloneStat();
                dtStat.setSDataSize(1);
                dtStat.updateWeighted(dt, 0);
                dtStat.computePrediction();
                stat.addPrediction(dtStat, weighting.weight(dt));
            }
            stat.computePrediction();
            return stat;
        }
        for (DataTuple dt : nearest) {
            stat.updateWeighted(dt, weighting.weight(dt));
        }
        if (targetsNeeded != null) {
            for (int target : targetsNeeded) {
                if (!stat.isAnyLabeled(target)) {
                    return null;
                }
            }
        }
        stat.calcMean();
        return stat;
    }

    /**
     * Returns a copy of the first {@code k} neighbours (fewer if the list is shorter).
     * Uses {@code subList} instead of repeated {@code LinkedList.get(i)}, which is O(k^2).
     */
    private static LinkedList<DataTuple> firstNeighbours(LinkedList<DataTuple> neighbours, int k) {
        int count = Math.max(0, Math.min(k, neighbours.size()));
        return new LinkedList<DataTuple>(neighbours.subList(0, count));
    }

    /**
     * Initializes the ML-kNN statistic template; does nothing unless ML-kNN is enabled for this model.
     */
    public void tryInitializeMLC(int[] ks, RowData trainData, double smoothing) throws ClusException {
        if (this.m_IsMlcKnn) {
            ((KnnMlcStat)this.m_StatTemplate).tryInitializeMLC(ks, trainData, this, smoothing);
        }
    }

    /** Returns the number of neighbours requested from the search structure. */
    public int getMaxK() {
        return this.m_MaxK;
    }

    /** Returns the (possibly shared) neighbour search structure. */
    public SearchAlgorithm getSearch() {
        return this.m_Search;
    }

    /** kNN models have no model processors; this is a no-op. */
    @Override
    public void applyModelProcessors(DataTuple tuple, MyArray mproc) throws IOException {
    }

    /** kNN has no meaningful model size; logs this and returns -1. */
    @Override
    public int getModelSize() {
        ClusLogger.info("No specific model size for kNN model.");
        return -1;
    }

    @Override
    public String getModelInfo() {
        return "kNN model weighted with " + this.m_WeightingOption.toString() + " and " + this.m_K + " neighbors.";
    }

    @Override
    public void printModel(PrintWriter wrt) {
        wrt.println("No specific kNN model to write!");
    }

    @Override
    public void printModel(PrintWriter wrt, StatisticPrintInfo info) {
        wrt.println("No specific kNN model to write!");
        wrt.print(info.toString());
    }

    @Override
    public void printModelAndExamples(PrintWriter wrt, StatisticPrintInfo info, RowData examples) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":printModelAndExamples() - Not supported yet for kNN.");
    }

    @Override
    public void printModelToQuery(PrintWriter wrt, ClusRun cr, int starttree, int startitem, boolean exhaustive) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":printModelToQuery() - Not supported yet for kNN.");
    }

    @Override
    public void printModelToPythonScript(PrintWriter wrt, HashMap<String, Integer> indices) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":printModelToPythonScript() - Not supported yet for kNN.");
    }

    @Override
    public void printModelToPythonScript(PrintWriter wrt, HashMap<String, Integer> indices, String modelIdentifier) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":printModelToPythonScript() - Not supported yet for kNN.");
    }

    /** No JSON representation exists for kNN models; returns {@code null}. */
    @Override
    public JsonObject getModelJSON() {
        return null;
    }

    /** No JSON representation exists for kNN models; returns {@code null}. */
    @Override
    public JsonObject getModelJSON(StatisticPrintInfo info) {
        return null;
    }

    /** No JSON representation exists for kNN models; returns {@code null}. */
    @Override
    public JsonObject getModelJSON(StatisticPrintInfo info, RowData examples) {
        return null;
    }

    @Override
    public void attachModel(HashMap table) throws ClusException {
        throw new UnsupportedOperationException(this.getClass().getName() + ":attachModel - Not supported yet for kNN.");
    }

    public void retrieveStatistics(ArrayList list) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":retrieveStatistics - Not supported yet for kNN.");
    }

    @Override
    public ClusModel prune(int prunetype) {
        throw new UnsupportedOperationException(this.getClass().getName() + ":prune - Not supported yet for kNN.");
    }

    @Override
    public int getID() {
        throw new UnsupportedOperationException(this.getClass().getName() + ":getID - Not supported yet for kNN.");
    }
}

