/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.mining;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.HasProbability;
import org.jpmml.evaluator.ProbabilityAggregator;
import org.jpmml.evaluator.TypeCheckException;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueAggregator;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.VoteAggregator;
import org.jpmml.evaluator.mining.SegmentResult;

public class MiningModelUtil {
    private MiningModelUtil() {
    }

    /**
     * Returns the predictions as a {@link SegmentResult} for the selection-style methods.
     *
     * @param multipleModelMethod the segmentation's multiple model method
     * @param predictions the predictions to inspect
     * @return {@code predictions} cast to {@link SegmentResult} when the method is
     *     {@code SELECT_FIRST}, {@code SELECT_ALL} or {@code MODEL_CHAIN} and the object is a
     *     {@link SegmentResult}; {@code null} otherwise
     */
    public static SegmentResult asSegmentResult(Segmentation.MultipleModelMethod multipleModelMethod, Map<FieldName, ?> predictions) {
        switch (multipleModelMethod) {
            case SELECT_FIRST:
            case SELECT_ALL:
            case MODEL_CHAIN:
                if (predictions instanceof SegmentResult) {
                    return (SegmentResult) predictions;
                }
                break;
            default:
                break;
        }
        return null;
    }

    /**
     * Aggregates the numeric target values of the segment results.
     *
     * @param valueFactory factory for the numeric value representation
     * @param multipleModelMethod one of {@code AVERAGE}, {@code SUM}, {@code MEDIAN},
     *     {@code WEIGHTED_AVERAGE}, {@code WEIGHTED_SUM} or {@code WEIGHTED_MEDIAN}
     * @param missingPredictionTreatment how a segment with a missing prediction is handled;
     *     {@code CONTINUE} is handled the same way as {@code RETURN_MISSING} here
     * @param missingThreshold weight fraction of skipped segments above which the result is missing
     * @param segmentResults the segment results to aggregate
     * @return the aggregated value, or {@code null} when the result is missing
     * @throws IllegalArgumentException if the multiple model method is not a value-aggregating one
     */
    public static <V extends Number> Value<V> aggregateValues(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<SegmentResult> segmentResults) {
        ValueAggregator<V> aggregator;
        switch (multipleModelMethod) {
            case AVERAGE:
            case SUM:
                aggregator = new ValueAggregator.UnivariateStatistic<V>(valueFactory);
                break;
            case MEDIAN:
                aggregator = new ValueAggregator.Median<V>(valueFactory, segmentResults.size());
                break;
            case WEIGHTED_AVERAGE:
            case WEIGHTED_SUM:
                aggregator = new ValueAggregator.WeightedUnivariateStatistic<V>(valueFactory);
                break;
            case WEIGHTED_MEDIAN:
                aggregator = new ValueAggregator.WeightedMedian<V>(valueFactory, segmentResults.size());
                break;
            default:
                throw new IllegalArgumentException();
        }

        Fraction<V> missingFraction = null;
        int aggregatedCount = 0;

        for (SegmentResult segmentResult : segmentResults) {
            Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());

            if (targetValue == null) {
                switch (missingPredictionTreatment) {
                    case RETURN_MISSING:
                    case CONTINUE:
                        return null;
                    case SKIP_SEGMENT:
                        if (missingFraction == null) {
                            missingFraction = new Fraction<V>(valueFactory, segmentResults);
                        }
                        if (missingFraction.update(segmentResult, missingThreshold)) {
                            return null;
                        }
                        continue;
                    default:
                        throw new IllegalArgumentException();
                }
            }

            Number value;
            try {
                value = (targetValue instanceof Number) ? (Number) targetValue : (Number) TypeUtil.cast(DataType.DOUBLE, targetValue);
            } catch (TypeCheckException tce) {
                throw tce.ensureContext((PMMLObject) segmentResult.getSegment());
            }

            switch (multipleModelMethod) {
                case AVERAGE:
                case SUM:
                case MEDIAN:
                    aggregator.add(value);
                    break;
                case WEIGHTED_AVERAGE:
                case WEIGHTED_SUM:
                case WEIGHTED_MEDIAN: {
                    Number weight = segmentResult.getWeight();
                    aggregator.add(value, weight);
                    break;
                }
                default:
                    throw new IllegalArgumentException();
            }
            aggregatedCount++;
        }

        // Every segment was skipped because of a missing prediction: there is nothing to aggregate
        if (missingFraction != null && aggregatedCount == 0) {
            return null;
        }

        switch (multipleModelMethod) {
            case AVERAGE:
                return aggregator.average();
            case WEIGHTED_AVERAGE:
                return aggregator.weightedAverage();
            case SUM:
                return aggregator.sum();
            case WEIGHTED_SUM:
                return aggregator.weightedSum();
            case MEDIAN:
                return aggregator.median();
            case WEIGHTED_MEDIAN:
                return aggregator.weightedMedian();
            default:
                throw new IllegalArgumentException();
        }
    }

    /**
     * Aggregates the (decoded) target values of the segment results as category votes.
     *
     * <p>With {@code CONTINUE}, a missing prediction is counted as a vote for the {@code null}
     * pseudo-category. The result is missing if that pseudo-category strictly outweighs every real
     * category. It is also missing if no real category received a vote.
     *
     * @param valueFactory factory for the numeric value representation
     * @param multipleModelMethod {@code MAJORITY_VOTE} or {@code WEIGHTED_MAJORITY_VOTE}
     * @param missingPredictionTreatment how a segment with a missing prediction is handled
     * @param missingThreshold weight fraction of missing segments above which the result is missing
     * @param segmentResults the segment results to aggregate
     * @return the vote sums per category, or {@code null} when the result is missing
     * @throws IllegalArgumentException if the multiple model method is not a voting one
     */
    public static <V extends Number> ValueMap<Object, V> aggregateVotes(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<SegmentResult> segmentResults) {
        VoteAggregator<Object, V> aggregator = new VoteAggregator<Object, V>(valueFactory);
        Fraction<V> missingFraction = null;

        for (SegmentResult segmentResult : segmentResults) {
            Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());

            if (targetValue == null) {
                switch (missingPredictionTreatment) {
                    case RETURN_MISSING:
                        return null;
                    case SKIP_SEGMENT:
                    case CONTINUE:
                        if (missingFraction == null) {
                            missingFraction = new Fraction<V>(valueFactory, segmentResults);
                        }
                        if (missingFraction.update(segmentResult, missingThreshold)) {
                            return null;
                        }
                        break;
                    default:
                        throw new IllegalArgumentException();
                }

                // SKIP_SEGMENT drops the segment; CONTINUE records its vote under the null key
                if (missingPredictionTreatment == Segmentation.MissingPredictionTreatment.SKIP_SEGMENT) {
                    continue;
                }
            }

            switch (multipleModelMethod) {
                case MAJORITY_VOTE:
                    aggregator.add(targetValue);
                    break;
                case WEIGHTED_MAJORITY_VOTE: {
                    Number weight = segmentResult.getWeight();
                    aggregator.add(targetValue, weight);
                    break;
                }
                default:
                    throw new IllegalArgumentException();
            }
        }

        ValueMap<Object, V> result = aggregator.sumMap();

        switch (missingPredictionTreatment) {
            case CONTINUE: {
                // Remove the "missing" pseudo-category; it wins only if it strictly outweighs every real category
                Value<V> missingVoteSum = result.remove(null);
                if (missingVoteSum != null && !result.isEmpty()) {
                    Collection<Value<V>> voteSums = result.values();
                    if (missingVoteSum.compareTo(Collections.max(voteSums)) > 0) {
                        return null;
                    }
                }
                break;
            }
            default:
                break;
        }

        // Every vote was missing (skipped or removed above): no category can be reported
        if (missingFraction != null && result.isEmpty()) {
            return null;
        }

        return result;
    }

    /**
     * Aggregates the class probabilities carried by the segment results.
     *
     * @param valueFactory factory for the numeric value representation
     * @param multipleModelMethod one of {@code AVERAGE}, {@code WEIGHTED_AVERAGE}, {@code MEDIAN} or {@code MAX}
     * @param missingPredictionTreatment how a segment with a missing prediction is handled;
     *     {@code CONTINUE} is handled the same way as {@code RETURN_MISSING} here
     * @param missingThreshold weight fraction of skipped segments above which the result is missing
     * @param categories the target categories, used by the median and max aggregations
     * @param segmentResults the segment results to aggregate
     * @return the aggregated probabilities per category, or {@code null} when the result is missing
     * @throws IllegalArgumentException if the multiple model method is not a probability-aggregating one
     */
    public static <V extends Number> ValueMap<Object, V> aggregateProbabilities(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<?> categories, List<SegmentResult> segmentResults) {
        ProbabilityAggregator<V> aggregator;
        switch (multipleModelMethod) {
            case AVERAGE:
                aggregator = new ProbabilityAggregator.Average<V>(valueFactory);
                break;
            case WEIGHTED_AVERAGE:
                aggregator = new ProbabilityAggregator.WeightedAverage<V>(valueFactory);
                break;
            case MEDIAN:
                aggregator = new ProbabilityAggregator.Median<V>(valueFactory, segmentResults.size());
                break;
            case MAX:
                aggregator = new ProbabilityAggregator.Max<V>(valueFactory, segmentResults.size());
                break;
            default:
                throw new IllegalArgumentException();
        }

        Fraction<V> missingFraction = null;
        int aggregatedCount = 0;

        for (SegmentResult segmentResult : segmentResults) {
            // Not decoded: the raw target value is the object that carries the probabilities
            Object targetValue = segmentResult.getTargetValue();

            if (targetValue == null) {
                switch (missingPredictionTreatment) {
                    case RETURN_MISSING:
                    case CONTINUE:
                        return null;
                    case SKIP_SEGMENT:
                        if (missingFraction == null) {
                            missingFraction = new Fraction<V>(valueFactory, segmentResults);
                        }
                        if (missingFraction.update(segmentResult, missingThreshold)) {
                            return null;
                        }
                        continue;
                    default:
                        throw new IllegalArgumentException();
                }
            }

            HasProbability hasProbability;
            try {
                hasProbability = TypeUtil.cast(HasProbability.class, targetValue);
            } catch (TypeCheckException tce) {
                throw tce.ensureContext((PMMLObject) segmentResult.getSegment());
            }

            switch (multipleModelMethod) {
                case AVERAGE:
                case MEDIAN:
                case MAX:
                    aggregator.add(hasProbability);
                    break;
                case WEIGHTED_AVERAGE: {
                    Number weight = segmentResult.getWeight();
                    aggregator.add(hasProbability, weight);
                    break;
                }
                default:
                    throw new IllegalArgumentException();
            }
            aggregatedCount++;
        }

        // Every segment was skipped because of a missing prediction: there is nothing to aggregate
        if (missingFraction != null && aggregatedCount == 0) {
            return null;
        }

        switch (multipleModelMethod) {
            case AVERAGE:
                return aggregator.averageMap();
            case WEIGHTED_AVERAGE:
                return aggregator.weightedAverageMap();
            case MEDIAN:
                return aggregator.medianMap(categories);
            case MAX:
                return aggregator.maxMap(categories);
            default:
                throw new IllegalArgumentException();
        }
    }

    /**
     * Tracks the weight fraction of segments whose prediction was missing.
     *
     * <p>The total weight is summed once over all segment results at construction time. Each
     * {@link #update} call adds one missing segment's weight.
     */
    private static class Fraction<V extends Number> {
        private final Value<V> weightSum;
        private final Value<V> missingWeightSum;

        private Fraction(ValueFactory<V> valueFactory, List<SegmentResult> segmentResults) {
            this.weightSum = valueFactory.newValue();
            this.missingWeightSum = valueFactory.newValue();
            for (SegmentResult segmentResult : segmentResults) {
                this.weightSum.add(segmentResult.getWeight());
            }
        }

        /**
         * Records a segment with a missing prediction.
         *
         * <p>NOTE(review): a zero total weight makes the ratio NaN (compares false) or infinite;
         * confirm that segment weights are expected to be positive.
         *
         * @param segmentResult the segment whose prediction is missing
         * @param missingThreshold the maximum tolerated missing weight fraction
         * @return {@code true} if the missing weight fraction now strictly exceeds the threshold
         */
        public boolean update(SegmentResult segmentResult, Number missingThreshold) {
            this.missingWeightSum.add(segmentResult.getWeight());
            double missingRatio = this.missingWeightSum.doubleValue() / this.weightSum.doubleValue();
            return missingRatio > missingThreshold.doubleValue();
        }
    }
}

