/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.clustering;

import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.dmg.pmml.Array;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Distance;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Similarity;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.clustering.CenterFields;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.MissingValueWeights;
import org.dmg.pmml.clustering.PMMLAttributes;
import org.dmg.pmml.clustering.PMMLElements;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.MeasureUtil;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.TypeInfos;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.clustering.ClusterAffinityDistribution;

/**
 * Evaluator for PMML {@link ClusteringModel} elements.
 *
 * <p>Only center-based models are accepted (any other model class is rejected by the
 * constructor), and the comparison measure must resolve to either a {@link Distance} or a
 * {@link Similarity}. Cluster centers are parsed once at construction time and cached per
 * {@link Cluster}: as immutable lists of {@link FieldValue} for distance measures, or as
 * {@link BitSet} instances for similarity measures.
 */
public class ClusteringModelEvaluator
extends ModelEvaluator<ClusteringModel>
implements HasEntityRegistry<Cluster> {
    /** Registry of the model's clusters, built by {@link EntityUtil#buildBiMap}. */
    private BiMap<String, Cluster> entityRegistry = ImmutableBiMap.of();

    /**
     * Pre-parsed cluster centers. The value type depends on the comparison measure:
     * {@code List<FieldValue>} for {@link Distance}, {@link BitSet} for {@link Similarity}.
     */
    private Map<Cluster, ?> clusterCentroids = Collections.emptyMap();

    private ClusteringModelEvaluator() {
    }

    /**
     * Creates an evaluator for the {@link ClusteringModel} located in the given PMML document
     * by {@link PMMLUtil#findModel}.
     *
     * @param pmml the PMML document
     */
    public ClusteringModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, ClusteringModel.class));
    }

    /**
     * Creates an evaluator for the given clustering model, validating its structure and
     * caching the cluster centers.
     *
     * @param pmml the PMML document
     * @param clusteringModel the clustering model to evaluate
     * @throws MisplacedElementException if the model declares a {@code Targets} element
     * @throws MissingElementException if the comparison measure, the clustering fields, the
     *         clusters, or a cluster's array is missing
     * @throws UnsupportedAttributeException if the model class is not {@code CENTER_BASED}
     * @throws UnsupportedElementException if the model declares {@code CenterFields}, or the
     *         measure is neither a {@link Distance} nor a {@link Similarity}
     */
    public ClusteringModelEvaluator(PMML pmml, ClusteringModel clusteringModel) {
        super(pmml, clusteringModel);

        Targets targets = clusteringModel.getTargets();
        if (targets != null) {
            throw new MisplacedElementException((PMMLObject)targets);
        }

        ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure();
        if (comparisonMeasure == null) {
            throw new MissingElementException((PMMLObject)clusteringModel, PMMLElements.CLUSTERINGMODEL_COMPARISONMEASURE);
        }

        ClusteringModel.ModelClass modelClass = clusteringModel.getModelClass();
        switch (modelClass) {
            case CENTER_BASED:
                break;
            default:
                throw new UnsupportedAttributeException((PMMLObject)clusteringModel, (Enum<?>)modelClass);
        }

        CenterFields centerFields = clusteringModel.getCenterFields();
        if (centerFields != null) {
            throw new UnsupportedElementException((PMMLObject)centerFields);
        }

        if (!clusteringModel.hasClusteringFields()) {
            throw new MissingElementException((PMMLObject)clusteringModel, PMMLElements.CLUSTERINGMODEL_CLUSTERINGFIELDS);
        }

        if (!clusteringModel.hasClusters()) {
            throw new MissingElementException((PMMLObject)clusteringModel, PMMLElements.CLUSTERINGMODEL_CLUSTERS);
        }

        List<Cluster> clusters = clusteringModel.getClusters();

        this.entityRegistry = ImmutableBiMap.copyOf(EntityUtil.buildBiMap(clusters));

        Map<Cluster, List<FieldValue>> clusterValues = parseClusterValues(clusters);

        Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);
        if (measure instanceof Distance) {
            // Freeze every per-cluster list before freezing the map itself.
            clusterValues.replaceAll((cluster, values) -> ImmutableList.copyOf(values));

            this.clusterCentroids = ImmutableMap.copyOf(clusterValues);
        } else if (measure instanceof Similarity) {
            Map<Cluster, BitSet> clusterFlags = clusterValues.entrySet().stream()
                .collect(Collectors.toMap(entry -> entry.getKey(), entry -> MeasureUtil.toBitSet(entry.getValue())));

            this.clusterCentroids = ImmutableMap.copyOf(clusterFlags);
        } else {
            throw new UnsupportedElementException((PMMLObject)measure);
        }
    }

    /**
     * @return the fixed summary string {@code "Clustering model"}
     */
    @Override
    public String getSummary() {
        return "Clustering model";
    }

    /**
     * Always returns {@code null}; the constructor rejects models that declare a
     * {@code Targets} element.
     *
     * @param name the target field name (ignored)
     * @return {@code null}
     */
    @Override
    public Target getTarget(FieldName name) {
        return null;
    }

    /**
     * @return the registry of clusters built at construction time
     */
    @Override
    public BiMap<String, Cluster> getEntityRegistry() {
        return this.entityRegistry;
    }

    /**
     * Evaluates the input against every cluster and returns the resulting affinity
     * distribution keyed by the target name.
     *
     * <p>Only clustering fields whose {@code isCenterField} attribute is {@code TRUE}
     * contribute an input value; fields marked {@code FALSE} are skipped.
     *
     * @param valueFactory factory for the numeric values of the result
     * @param context evaluation context used to resolve the input field values
     * @return a single-entry map from the target name to the cluster affinity distribution
     * @throws MissingAttributeException if a clustering field has no field name
     * @throws UnsupportedAttributeException if a clustering field has a center-field setting
     *         other than {@code TRUE} or {@code FALSE}
     * @throws UnsupportedElementException if the measure is neither a {@link Similarity} nor
     *         a {@link Distance}
     */
    @Override
    protected <V extends Number> Map<FieldName, ClusterAffinityDistribution<V>> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context) {
        ClusteringModel clusteringModel = (ClusteringModel)this.getModel();

        ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure();

        List<ClusteringField> clusteringFields = clusteringModel.getClusteringFields();

        List<FieldValue> values = new ArrayList<>(clusteringFields.size());

        for (ClusteringField clusteringField : clusteringFields) {
            FieldName name = clusteringField.getField();
            if (name == null) {
                throw new MissingAttributeException((PMMLObject)clusteringField, PMMLAttributes.CLUSTERINGFIELD_FIELD);
            }

            ClusteringField.CenterField centerField = clusteringField.getCenterField();
            switch (centerField) {
                case TRUE:
                    break;
                case FALSE:
                    // Not a center field: contributes no input value.
                    continue;
                default:
                    throw new UnsupportedAttributeException((PMMLObject)clusteringField, (Enum<?>)centerField);
            }

            values.add(context.evaluate(name));
        }

        ClusterAffinityDistribution<V> result;

        Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);
        if (measure instanceof Similarity) {
            result = evaluateSimilarity(valueFactory, comparisonMeasure, clusteringFields, values);
        } else if (measure instanceof Distance) {
            result = evaluateDistance(valueFactory, comparisonMeasure, clusteringFields, values);
        } else {
            throw new UnsupportedElementException((PMMLObject)measure);
        }

        result.computeResult(DataType.STRING);

        return Collections.singletonMap(this.getTargetName(), result);
    }

    /**
     * Computes the similarity between the input values and the cached flags of every cluster.
     *
     * @param valueFactory factory for the numeric values of the result
     * @param comparisonMeasure the model's comparison measure
     * @param clusteringFields the model's clustering fields
     * @param values the input values of the center fields
     * @return the similarity of the input to each cluster
     * @throws InvalidElementException if a cluster's flags fail the size check against the
     *         input flags
     */
    private <V extends Number> ClusterAffinityDistribution<V> evaluateSimilarity(ValueFactory<V> valueFactory, ComparisonMeasure comparisonMeasure, List<ClusteringField> clusteringFields, List<FieldValue> values) {
        ClusteringModel clusteringModel = (ClusteringModel)this.getModel();

        List<Cluster> clusters = clusteringModel.getClusters();

        ClusterAffinityDistribution<V> result = createClusterAffinityDistribution(Classification.Type.SIMILARITY, clusters);

        BitSet flags = MeasureUtil.toBitSet(values);

        for (Cluster cluster : clusters) {
            BitSet clusterFlags = (BitSet)getClusterCentroid(cluster);

            // NOTE(review): BitSet.size() reports the allocated storage in bits, not the number
            // of logical flags, so this only detects length mismatches that change the amount
            // of allocated storage. Kept as-is to preserve behavior - confirm this is intended.
            if (flags.size() != clusterFlags.size()) {
                throw new InvalidElementException((PMMLObject)cluster);
            }

            Value<V> similarity = MeasureUtil.evaluateSimilarity(valueFactory, comparisonMeasure, clusteringFields, flags, clusterFlags);

            result.put(cluster, similarity);
        }

        return result;
    }

    /**
     * Computes the distance between the input values and the cached center of every cluster.
     *
     * <p>When the model declares {@code MissingValueWeights}, its array is passed to
     * {@link MeasureUtil#calculateAdjustment} together with the input values; otherwise the
     * adjustment is calculated from the input values alone.
     *
     * @param valueFactory factory for the numeric values of the result
     * @param comparisonMeasure the model's comparison measure
     * @param clusteringFields the model's clustering fields
     * @param values the input values of the center fields
     * @return the distance of the input to each cluster
     * @throws InvalidElementException if the number of missing value weights, or the number of
     *         values of a cluster center, differs from the number of input values
     */
    private <V extends Number> ClusterAffinityDistribution<V> evaluateDistance(ValueFactory<V> valueFactory, ComparisonMeasure comparisonMeasure, List<ClusteringField> clusteringFields, List<FieldValue> values) {
        ClusteringModel clusteringModel = (ClusteringModel)this.getModel();

        List<Cluster> clusters = clusteringModel.getClusters();

        Value<V> adjustment;

        MissingValueWeights missingValueWeights = clusteringModel.getMissingValueWeights();
        if (missingValueWeights != null) {
            Array array = missingValueWeights.getArray();

            List<? extends Number> adjustmentValues = ArrayUtil.asNumberList(array);
            if (values.size() != adjustmentValues.size()) {
                throw new InvalidElementException((PMMLObject)missingValueWeights);
            }

            adjustment = MeasureUtil.calculateAdjustment(valueFactory, values, adjustmentValues);
        } else {
            adjustment = MeasureUtil.calculateAdjustment(valueFactory, values);
        }

        ClusterAffinityDistribution<V> result = createClusterAffinityDistribution(Classification.Type.DISTANCE, clusters);

        for (Cluster cluster : clusters) {
            // Safe: this method is only reached for Distance measures, for which the
            // constructor caches List<FieldValue> centers.
            @SuppressWarnings("unchecked")
            List<FieldValue> clusterValues = (List<FieldValue>)getClusterCentroid(cluster);

            if (values.size() != clusterValues.size()) {
                throw new InvalidElementException((PMMLObject)cluster);
            }

            Value<V> distance = MeasureUtil.evaluateDistance(valueFactory, comparisonMeasure, clusteringFields, values, clusterValues, adjustment);

            result.put(cluster, distance);
        }

        return result;
    }

    /**
     * Creates an empty affinity distribution whose entity registry delegates to this
     * evaluator.
     *
     * @param type the classification type of the distribution
     * @param clusters the model's clusters; only their count is used, to size the backing map
     * @return a new, empty distribution
     */
    private <V extends Number> ClusterAffinityDistribution<V> createClusterAffinityDistribution(Classification.Type type, List<Cluster> clusters) {
        ClusterAffinityDistribution<V> result = new ClusterAffinityDistribution<V>(type, new ValueMap<>(2 * clusters.size())){

            @Override
            public BiMap<String, Cluster> getEntityRegistry() {
                return ClusteringModelEvaluator.this.getEntityRegistry();
            }
        };

        return result;
    }

    /**
     * Looks up the cached center of the given cluster.
     *
     * @param cluster the cluster
     * @return the cached value for the cluster (see {@link #clusterCentroids} for its type)
     */
    private Object getClusterCentroid(Cluster cluster) {
        return this.clusterCentroids.get(cluster);
    }

    /**
     * Parses the array of every cluster into a mutable list of field values created with
     * {@link TypeInfos#CONTINUOUS_DOUBLE}.
     *
     * @param clusters the clusters to parse
     * @return a mutable map from each cluster to its parsed values
     * @throws MissingElementException if a cluster has no array
     */
    private static Map<Cluster, List<FieldValue>> parseClusterValues(List<Cluster> clusters) {
        Map<Cluster, List<FieldValue>> result = new HashMap<>();

        for (Cluster cluster : clusters) {
            Array array = cluster.getArray();
            if (array == null) {
                throw new MissingElementException((PMMLObject)cluster, PMMLElements.CLUSTER_ARRAY);
            }

            List<? extends Number> numbers = ArrayUtil.asNumberList(array);

            List<FieldValue> fieldValues = new ArrayList<>(numbers.size());
            for (Number number : numbers) {
                fieldValues.add(FieldValueUtil.create(TypeInfos.CONTINUOUS_DOUBLE, number));
            }

            result.put(cluster, fieldValues);
        }

        return result;
    }
}

