/*
 * Decompiled with CFR 0.152.
 */
package model.tree;

import data.catalog.Catalog;
import data.feature.SimpleFeature;
import data.instance.Instance;
import data.instance.Instances;
import data.parameter.AffineShift;
import data.parameter.NumericShiftFunction;
import data.value.Value;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import java.util.stream.Collectors;
import model.Model;
import model.ModelOptions;
import model.NodeSplit;
import model.criterion.stop.LeafTooSmall;
import model.criterion.stop.PureLeaf;
import model.criterion.stop.StopCriterion;
import model.distribution.ClassDistribution;
import model.distribution.Distribution;
import model.distribution.NumericDistribution;
import model.inference.Inference;
import model.inference.InferenceRF;
import model.inference.hc.AggregateBase;
import model.tree.AggregatePrototype;
import model.tree.InternalNode;
import model.tree.Leaf;
import model.tree.Node;
import org.jdom2.Content;
import org.jdom2.Element;
import util.Couple;
import util.GlobalRandom;
import util.Logging;

/**
 * A single decision tree model.
 *
 * <p>The tree is grown recursively by {@link #build(Instances, Catalog)}. At each node, the
 * configured {@link StopCriterion} list is checked first. If no criterion fires, the
 * {@link Inference} strategy is asked for a {@link NodeSplit}, and three children (left, right,
 * no-value) are built from the split's instance partitions.
 *
 * <p>In {@code "regression"} mode, predictions are additionally passed through an output
 * {@link AffineShift} whose parameters are derived from the range of the class feature.
 */
public class DecisionTree
extends Model {
    /** Root of the learned tree; {@code null} until built or loaded from XML. */
    protected Node root;
    /** Strategy used to find the split at each internal node. */
    protected Inference inference;
    /** Criteria checked before trying to split a node; any one firing turns the node into a leaf. */
    protected ArrayList<StopCriterion> stops;
    /** Prototype distribution cloned for every node (class- or numeric-valued depending on mode). */
    protected Distribution distribution;

    /**
     * Creates an unbuilt tree whose split inference uses the global random generator.
     *
     * @param opt model options; {@code opt.mode} selects classification or regression
     */
    public DecisionTree(ModelOptions opt) {
        this.initCommon(opt);
        this.inference = new InferenceRF(this.opts, GlobalRandom.instance());
    }

    /**
     * Creates an unbuilt tree whose split inference uses the supplied random generator.
     *
     * @param opt model options; {@code opt.mode} selects classification or regression
     * @param gr random generator handed to the split inference
     */
    public DecisionTree(ModelOptions opt, Random gr) {
        // Shares the same setup as the other constructor; previously this path skipped the
        // output-shift initialization, which made regression build/classify throw an NPE.
        this.initCommon(opt);
        this.inference = new InferenceRF(this.opts, gr);
    }

    /**
     * Restores a tree from its XML form.
     *
     * @param opt model options
     * @param rootEl XML element whose first child is the serialized root node
     * @param cat catalog used to resolve features while deserializing
     */
    public DecisionTree(ModelOptions opt, Element rootEl, Catalog cat) {
        this(opt);
        this.root = Node.fromXML((Element)rootEl.getChildren().get(0), cat, 1);
    }

    /**
     * Initializes the state shared by every constructor: options, output shift, node
     * distribution prototype and default stop criteria.
     *
     * <p>If {@code opt.mode} is neither {@code "classification"} nor {@code "regression"}, the
     * distribution prototype is left {@code null}.
     */
    private void initCommon(ModelOptions opt) {
        this.opts = opt;
        this.shift = new AffineShift();
        this.shift.setOutput();
        if (this.opts.mode.equals("classification")) {
            this.distribution = new ClassDistribution();
        } else if (this.opts.mode.equals("regression")) {
            this.distribution = new NumericDistribution();
        }
        this.stops = new ArrayList<>();
        this.stops.add(new PureLeaf());
        this.stops.add(new LeafTooSmall(this.opts.instPerLeaf));
    }

    /** Returns whether this tree runs in {@code "regression"} mode. */
    private boolean isRegression() {
        return this.opts.mode.equals("regression");
    }

    /**
     * Computes the range of the class feature over the given instances.
     *
     * @return the array produced by the class feature's {@code range} method for the instance ids
     */
    private double[] classRange(Instances insts, Catalog cat) {
        ArrayList<Value> ids = new ArrayList<Value>(insts.stream().map(inst -> inst.getId()).collect(Collectors.toList()));
        return ((SimpleFeature)cat.getClassFeature()).range(ids, cat);
    }

    /**
     * Learns the tree from the given instances, then initializes the shift parameters.
     *
     * @param ids training instances
     * @param cat catalog describing the features
     */
    @Override
    public void build(Instances ids, Catalog cat) {
        Logging.model.info((Object)("Learning of model " + this.name() + " starts."));
        this.root = this.buildNode(ids, cat, 1);
        this.initShifts(ids, cat);
        Logging.model.info((Object)("Learning of model " + this.name() + " finishes."));
    }

    /**
     * Recursively builds the subtree for {@code insts}.
     *
     * @param insts instances reaching this node
     * @param cat catalog describing the features
     * @param d depth of this node (the root is at depth 1)
     * @return a {@link Leaf} when the set is empty, a stop criterion fires, or no split is found;
     *     otherwise an {@link InternalNode} with left, right and no-value children
     */
    private Node buildNode(Instances insts, Catalog cat, int d) {
        Distribution distr = this.distribution.clone();
        distr.init(insts);
        Logging.model.info((Object)(this.name() + " - Node distribution : " + distr.toString()));

        // Criteria are evaluated in order and evaluation stops at the first one that fires.
        boolean stopNow = insts.isEmpty();
        for (StopCriterion stop : this.stops) {
            if (stopNow) {
                break;
            }
            stopNow = stop.stopCriterion(insts);
        }
        if (stopNow) {
            Logging.model.info((Object)(this.name() + " - Stop : " + distr.toString()));
            return new Leaf(distr, d);
        }

        long t1 = System.currentTimeMillis();
        NodeSplit spl = this.inference.infer(insts, cat);
        long t2 = System.currentTimeMillis();
        double seconds = (double)(t2 - t1) / 1000.0;
        if (spl == null) {
            Logging.model.info((Object)(this.name() + " - No split found in " + seconds + "s " + distr.toString()));
            return new Leaf(distr, d);
        }

        Logging.model.info((Object)(this.name() + " - Best split found in " + seconds + "s : " + spl.toString()));
        // Children are built in a fixed order (left, right, no-value); inference may consume
        // randomness, so this order is kept stable.
        Node left = this.buildNode(new Instances(spl.getLeft()), cat, d + 1);
        Node right = this.buildNode(new Instances(spl.getRight()), cat, d + 1);
        Node noValue = this.buildNode(new Instances(spl.getNoValue()), cat, d + 1);
        InternalNode res = new InternalNode(spl, left, right, noValue, distr, d);
        res.createFiliation();
        return res;
    }

    /**
     * Classifies an instance by routing it through the tree.
     *
     * @return the root's prediction, passed through the output shift in regression mode
     */
    @Override
    public Value classify(Instance inst, Catalog cat) {
        Value res = this.root.classify(inst, cat);
        if (this.isRegression()) {
            res = this.shift.apply(res);
        }
        return res;
    }

    /** Returns the string form of the root node. */
    @Override
    public String toString() {
        return this.root.toString();
    }

    /**
     * Returns a new, unbuilt tree with a cloned copy of this tree's options.
     *
     * <p>The learned structure ({@link #root}) is not copied.
     */
    @Override
    public Model clone() {
        ModelOptions optsCl = this.opts.clone();
        return new DecisionTree(optsCl);
    }

    /** Returns the model name from the options. */
    @Override
    public String name() {
        return this.opts.name;
    }

    /** Delegates to the root node's {@code classifyMajority}. */
    public HashSet<Value> classifyMajority(Instance inst, Catalog cat) {
        return this.root.classifyMajority(inst, cat);
    }

    /** Delegates to the root node's {@code classifyProbabilities}. */
    public HashMap<Value, Double> classifyProbabilities(Instance inst, Catalog cat) {
        return this.root.classifyProbabilities(inst, cat);
    }

    /**
     * Always returns {@code null}.
     *
     * <p>NOTE(review): callers may test for {@code null}; confirm before switching to an empty set.
     */
    @Override
    public HashSet<NumericShiftFunction> getParameters() {
        return null;
    }

    /** Delegates to the root node's {@code getAllFeatures}. */
    public HashMap<AggregatePrototype, ArrayList<Double>> getAllFeatures(ArrayList<SimpleFeature> sfs) {
        return this.root.getAllFeatures(sfs);
    }

    /** Delegates to the root node's {@code getBias}. */
    public HashMap<AggregateBase, HashMap<HashSet<SimpleFeature>, Couple<Double, Long>>> getBias() {
        return this.root.getBias();
    }

    /** Delegates to the root node's {@code getComplexAggregates}. */
    public HashMap<AggregateBase, HashSet<SimpleFeature>> getComplexAggregates() {
        return this.root.getComplexAggregates();
    }

    /** Delegates to the root node's {@code getMainFeatures}. */
    public HashSet<SimpleFeature> getMainFeatures() {
        return this.root.getMainFeatures();
    }

    /**
     * Collects the shift functions of the tree and their parameters for the given instances.
     *
     * @return a new map containing the root's shifts and, in regression mode, the output shift
     *     mapped to the class-feature range over {@code insts}
     */
    @Override
    public HashMap<NumericShiftFunction, double[]> getShifts(Instances insts, Catalog cat) {
        HashMap<NumericShiftFunction, double[]> shifts = new HashMap<NumericShiftFunction, double[]>();
        shifts.putAll(this.root.getShifts(insts, cat));
        if (this.isRegression()) {
            shifts.put(this.shift, this.classRange(insts, cat));
        }
        return shifts;
    }

    /** Serializes the tree as a {@code <tree name="...">} element wrapping the root node. */
    @Override
    public Element toXMLElement() {
        Element dt = new Element("tree");
        dt.setAttribute("name", this.name());
        dt.addContent((Content)this.root.toXMLElement());
        return dt;
    }

    /**
     * Initializes the original-domain shift parameters from the training instances.
     *
     * <p>Delegates to the catalog and the root node. In regression mode, it also sets the output
     * shift's original parameters from the class-feature range.
     */
    @Override
    public void initShifts(Instances insts, Catalog cat) {
        cat.initShifts(insts, cat);
        this.root.initShifts(insts, cat);
        if (this.isRegression()) {
            this.shift.setParamsOrig(this.classRange(insts, cat));
        }
    }

    /**
     * Sets the deployment-domain shift parameters from the given instances.
     *
     * <p>Delegates to the root node. In regression mode, it also sets the output shift's deploy
     * parameters from the class-feature range.
     */
    @Override
    public void deployShifts(Instances insts, Catalog cat) {
        this.root.deployShifts(insts, cat);
        if (this.isRegression()) {
            this.shift.setParamsDeploy(this.classRange(insts, cat));
        }
    }

    /**
     * Returns the root's shift functions, plus the output shift in regression mode.
     *
     * <p>NOTE(review): the set returned by the root is modified in place; confirm it is not a
     * shared internal collection.
     */
    @Override
    public HashSet<NumericShiftFunction> getShifts() {
        HashSet<NumericShiftFunction> res = this.root.getShifts();
        if (this.isRegression()) {
            res.add(this.getOutputShift());
        }
        return res;
    }
}

