/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.naive_bayes;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.util.Precision;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.HasType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.PoissonDistribution;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PMMLAttributes;
import org.dmg.pmml.naive_bayes.PMMLElements;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.DiscretizationUtil;
import org.jpmml.evaluator.DistributionUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.FieldValues;
import org.jpmml.evaluator.Functions;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.MapHolder;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NumberUtil;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueUtil;
import org.jpmml.evaluator.VerificationUtil;
import org.jpmml.evaluator.naive_bayes.ProbabilityMap;
import org.jpmml.model.XPathUtil;

/**
 * Evaluator for PMML {@link NaiveBayesModel} elements.
 *
 * <p>At construction time the {@link BayesInput} elements (including any that are
 * nested inside {@link Extension} elements) are collected, and for every input
 * field the per-target-category sums of the {@link TargetValueCount} counts are
 * precomputed. Evaluation then accumulates, for each target category, the natural
 * logarithm of the prior count and of every per-input probability, and finally
 * normalises the accumulated values with {@link ValueUtil#normalizeSoftMax}.</p>
 */
public class NaiveBayesModelEvaluator extends ModelEvaluator<NaiveBayesModel> {
    /** All Bayes inputs of the model, direct and extension-provided. */
    private List<BayesInput> bayesInputs = Collections.emptyList();

    /** Input field name -> (target category -> sum of counts over all PairCounts of that input). */
    private Map<FieldName, Map<Object, Number>> fieldCountSums = Collections.emptyMap();

    private NaiveBayesModelEvaluator() {
    }

    /**
     * Creates an evaluator for the first {@link NaiveBayesModel} found in the PMML document.
     *
     * @param pmml the PMML document
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, NaiveBayesModel.class));
    }

    /**
     * Creates an evaluator for the given model and validates its mandatory structure.
     *
     * @param pmml the PMML document
     * @param naiveBayesModel the model to evaluate
     * @throws MissingElementException if {@code BayesInputs}, {@code BayesOutput} or the
     *         output's {@code TargetValueCounts} is absent or empty, or if a
     *         {@code PairCounts} element has no {@code TargetValueCounts}
     * @throws MissingAttributeException if a Bayes input has no field name, or a
     *         {@code TargetValueCount} lacks its value or count
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);

        BayesInputs inputs = naiveBayesModel.getBayesInputs();
        if (inputs == null) {
            throw new MissingElementException(naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESINPUTS);
        }
        // Inputs may be declared directly or wrapped inside Extension elements.
        if (!inputs.hasBayesInputs() && !inputs.hasExtensions()) {
            throw new MissingElementException(inputs, PMMLElements.BAYESINPUTS_BAYESINPUTS);
        }

        this.bayesInputs = ImmutableList.copyOf(parseBayesInputs(inputs));
        this.fieldCountSums = ImmutableMap.copyOf(toImmutableMapMap(calculateFieldCountSums(this.bayesInputs)));

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        if (bayesOutput == null) {
            throw new MissingElementException(naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESOUTPUT);
        }

        TargetValueCounts targetValueCounts = bayesOutput.getTargetValueCounts();
        if (targetValueCounts == null) {
            throw new MissingElementException(bayesOutput, PMMLElements.BAYESOUTPUT_TARGETVALUECOUNTS);
        }
        if (!targetValueCounts.hasTargetValueCounts()) {
            throw new MissingElementException(targetValueCounts, PMMLElements.TARGETVALUECOUNTS_TARGETVALUECOUNTS);
        }
    }

    @Override
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Computes the probability distribution over the target categories.
     *
     * <p>Inputs whose value is missing (before or after discretization), and
     * discrete inputs whose value matches no {@code PairCounts}, are skipped.</p>
     *
     * @param valueFactory factory for the numeric values used in the accumulation
     * @param context the evaluation context supplying the input field values
     * @return the classification result produced by {@link TargetUtil#evaluateClassification}
     * @throws MissingAttributeException if the output field, the threshold, or an input
     *         field name is not specified
     * @throws InvalidAttributeException if the output field differs from the target field
     */
    @Override
    protected <V extends Number> Map<FieldName, ? extends Classification<?, V>> evaluateClassification(final ValueFactory<V> valueFactory, EvaluationContext context) {
        NaiveBayesModel naiveBayesModel = getModel();
        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        TargetField targetField = getTargetField();

        FieldName targetName = bayesOutput.getField();
        if (targetName == null) {
            throw new MissingAttributeException(bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELD);
        }
        if (!Objects.equals(targetField.getFieldName(), targetName)) {
            throw new InvalidAttributeException(bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELD, targetName);
        }

        // "Multiplication" is carried out as an addition of natural logarithms.
        ProbabilityMap<Object, V> probabilities = new ProbabilityMap<Object, V>() {

            @Override
            public ValueFactory<V> getValueFactory() {
                return valueFactory;
            }

            @Override
            public void multiply(Object key, Number probability) {
                ValueFactory<V> factory = getValueFactory();
                Value<V> value = ensureValue(key);
                Value<V> logProbability = factory.newValue(probability).ln();
                value.add(logProbability);
            }
        };

        calculatePriorProbabilities(probabilities, getTargetValueCounts(bayesOutput));

        Number threshold = naiveBayesModel.getThreshold();
        if (threshold == null) {
            throw new MissingAttributeException(naiveBayesModel, PMMLAttributes.NAIVEBAYESMODEL_THRESHOLD);
        }

        Map<FieldName, Map<Object, Number>> fieldCountSums = getFieldCountSums();
        for (BayesInput bayesInput : getBayesInputs()) {
            FieldName name = bayesInput.getField();
            if (name == null) {
                throw new MissingAttributeException(bayesInput, PMMLAttributes.BAYESINPUT_FIELD);
            }

            FieldValue value = context.evaluate(name);
            if (FieldValueUtil.isMissing(value)) {
                continue;
            }

            // An input carrying TargetValueStats is treated as continuous.
            TargetValueStats targetValueStats = getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                calculateContinuousProbabilities(probabilities, targetValueStats, threshold, value);
                continue;
            }

            DerivedField derivedField = bayesInput.getDerivedField();
            if (derivedField != null) {
                value = discretize(derivedField, value);
                if (FieldValueUtil.isMissing(value)) {
                    continue;
                }
            }

            Map<Object, Number> countSums = fieldCountSums.get(name);
            TargetValueCounts matchingCounts = getTargetValueCounts(bayesInput, value);
            if (matchingCounts == null) {
                continue;
            }
            calculateDiscreteProbabilities(probabilities, matchingCounts, threshold, countSums);
        }

        ValueUtil.normalizeSoftMax(probabilities);

        ProbabilityDistribution<V> result = new ProbabilityDistribution<>(probabilities);
        return TargetUtil.evaluateClassification(targetField, result);
    }

    /**
     * Applies the {@link Discretize} expression of the derived field to the value.
     *
     * @return the discretized value cast to the derived field's type, or
     *         {@link FieldValues#MISSING_VALUE} if discretization yields a missing value
     * @throws MisplacedElementException if the derived field's expression is not a {@code Discretize}
     */
    private FieldValue discretize(DerivedField derivedField, FieldValue value) {
        Expression expression = ExpressionUtil.ensureExpression(derivedField);
        if (!(expression instanceof Discretize)) {
            throw new MisplacedElementException(expression);
        }

        FieldValue discretized = DiscretizationUtil.discretize((Discretize)expression, value);
        if (FieldValueUtil.isMissing(discretized)) {
            return FieldValues.MISSING_VALUE;
        }
        return discretized.cast((HasType<?>)derivedField);
    }

    /**
     * Folds, for every target category, the distribution-based probability of the
     * numeric input value into the accumulator. Probabilities below the threshold
     * are replaced by the threshold.
     *
     * @throws MissingAttributeException if a {@code TargetValueStat} has no value
     * @throws MissingElementException if a {@code TargetValueStat} has no distribution
     * @throws MisplacedElementException if the distribution is neither Gaussian nor Poisson
     */
    private void calculateContinuousProbabilities(ProbabilityMap<Object, ?> probabilities, TargetValueStats targetValueStats, Number threshold, FieldValue value) {
        Number x = value.asNumber();

        for (TargetValueStat targetValueStat : targetValueStats) {
            Object targetCategory = targetValueStat.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException(targetValueStat, PMMLAttributes.TARGETVALUESTAT_VALUE);
            }

            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (distribution == null) {
                throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(targetValueStat.getClass()) + "/<ContinuousDistribution>"), targetValueStat);
            }

            boolean supported = (distribution instanceof GaussianDistribution) || (distribution instanceof PoissonDistribution);
            if (!supported) {
                throw new MisplacedElementException(distribution);
            }

            Number probability = DistributionUtil.probability(distribution, x);
            if (NumberUtil.compare(probability, threshold) < 0) {
                probability = threshold;
            }
            probabilities.multiply(targetCategory, probability);
        }
    }

    /**
     * Folds, for every target category, the ratio {@code count / countSum} into the
     * accumulator. A count that is zero (within {@link Precision#EPSILON}) contributes
     * the threshold instead.
     *
     * @param countSums target category -> sum of counts for the current input field
     * @throws MissingAttributeException if a {@code TargetValueCount} lacks its value or count
     */
    private void calculateDiscreteProbabilities(ProbabilityMap<Object, ?> probabilities, TargetValueCounts targetValueCounts, Number threshold, Map<?, Number> countSums) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            Object targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            Number count = targetValueCount.getCount();
            if (count == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
            }

            Number probability;
            if (VerificationUtil.isZero(count, Precision.EPSILON)) {
                probability = threshold;
            } else {
                Number countSum = countSums.get(targetCategory);
                probability = Functions.DIVIDE.evaluate(count, NumberUtil.asDouble(countSum));
            }
            probabilities.multiply(targetCategory, probability);
        }
    }

    /**
     * Seeds the accumulator with the raw count of every target category. The counts
     * are not divided by their total here.
     *
     * @throws MissingAttributeException if a {@code TargetValueCount} lacks its value or count
     */
    private void calculatePriorProbabilities(ProbabilityMap<Object, ?> probabilities, TargetValueCounts targetValueCounts) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            Object targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            Number count = targetValueCount.getCount();
            if (count == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
            }

            probabilities.multiply(targetCategory, count);
        }
    }

    /** Returns the Bayes inputs collected at construction time. */
    protected List<BayesInput> getBayesInputs() {
        return this.bayesInputs;
    }

    /** Returns the per-field, per-target-category count sums computed at construction time. */
    protected Map<FieldName, Map<Object, Number>> getFieldCountSums() {
        return this.fieldCountSums;
    }

    /**
     * For every Bayes input, sums the counts of each target category over all of
     * the input's {@code PairCounts}.
     *
     * @return input field name -> (target category -> count sum), in encounter order
     * @throws MissingAttributeException if an input has no field name, or a
     *         {@code TargetValueCount} lacks its value or count
     * @throws MissingElementException if a {@code PairCounts} has no {@code TargetValueCounts}
     */
    private static Map<FieldName, Map<Object, Number>> calculateFieldCountSums(List<BayesInput> bayesInputs) {
        Map<FieldName, Map<Object, Number>> result = new LinkedHashMap<>();

        for (BayesInput bayesInput : bayesInputs) {
            FieldName name = bayesInput.getField();
            // Fail with a descriptive exception rather than letting a null key
            // reach the immutable map built from this result.
            if (name == null) {
                throw new MissingAttributeException(bayesInput, PMMLAttributes.BAYESINPUT_FIELD);
            }

            Map<Object, Number> countSums = new LinkedHashMap<>();

            List<PairCounts> pairCounts = bayesInput.getPairCounts();
            for (PairCounts pairCount : pairCounts) {
                TargetValueCounts targetValueCounts = pairCount.getTargetValueCounts();
                // Same diagnostic as the lookup performed during evaluation.
                if (targetValueCounts == null) {
                    throw new MissingElementException(pairCount, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
                }

                for (TargetValueCount targetValueCount : targetValueCounts) {
                    Object targetCategory = targetValueCount.getValue();
                    if (targetCategory == null) {
                        throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
                    }

                    Number count = targetValueCount.getCount();
                    if (count == null) {
                        throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
                    }

                    Number countSum = countSums.get(targetCategory);
                    if (countSum == null) {
                        countSum = count;
                    } else {
                        countSum = Functions.ADD.evaluate(countSum, count);
                    }
                    countSums.put(targetCategory, countSum);
                }
            }

            result.put(name, countSums);
        }

        return result;
    }

    /**
     * Collects the directly declared Bayes inputs followed by any {@link BayesInput}
     * objects found in the content of the element's extensions.
     */
    private static List<BayesInput> parseBayesInputs(BayesInputs bayesInputs) {
        if (!bayesInputs.hasExtensions()) {
            return bayesInputs.getBayesInputs();
        }

        List<BayesInput> result = new ArrayList<>(bayesInputs.getBayesInputs());

        List<Extension> extensions = bayesInputs.getExtensions();
        for (Extension extension : extensions) {
            List<?> content = extension.getContent();
            for (Object object : content) {
                if (object instanceof BayesInput) {
                    result.add((BayesInput)object);
                }
            }
        }

        return result;
    }

    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * Finds the target value counts of the {@code PairCounts} matching the input value.
     *
     * @return the matching counts, or {@code null} if no {@code PairCounts} matches
     * @throws MissingAttributeException if a {@code PairCounts} has no value
     * @throws MissingElementException if the matching {@code PairCounts} has no
     *         {@code TargetValueCounts}
     */
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue value) {
        // Inputs that expose a map view are queried directly instead of scanned.
        if (bayesInput instanceof MapHolder) {
            MapHolder<?> mapHolder = (MapHolder<?>)bayesInput;
            return (TargetValueCounts)mapHolder.get(value.getDataType(), value.getValue());
        }

        List<PairCounts> pairCounts = bayesInput.getPairCounts();
        for (PairCounts pairCount : pairCounts) {
            Object category = pairCount.getValue();
            if (category == null) {
                throw new MissingAttributeException(pairCount, PMMLAttributes.PAIRCOUNTS_VALUE);
            }

            if (!value.equalsValue(category)) {
                continue;
            }

            TargetValueCounts targetValueCounts = pairCount.getTargetValueCounts();
            if (targetValueCounts == null) {
                throw new MissingElementException(pairCount, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
            }
            return targetValueCounts;
        }

        return null;
    }

    private static TargetValueCounts getTargetValueCounts(BayesOutput bayesOutput) {
        return bayesOutput.getTargetValueCounts();
    }
}

