/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.python;

import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.razorvine.pickle.Pickler;
import net.razorvine.pickle.Unpickler;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.Scalar;
import numpy.core.ScalarUtil;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorFunction;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.ResultTableCollector;
import org.jpmml.evaluator.Table;
import org.jpmml.evaluator.TableCollector;
import org.jpmml.python.PickleUtil;

/**
 * Bridges pickled Python data and the JPMML {@link Evaluator}.
 *
 * <p>The {@code byte[]} entry points unpickle their input, evaluate it and pickle the result back;
 * the {@code Map} entry points do the same work on already-unpickled objects. In every entry point a
 * {@code null} evaluator means "pass the (converted and decoded) arguments through unchanged".
 */
public class PythonEvaluatorUtil {

    static {
        // Runs once at class initialization, i.e. before any static method below can unpickle data.
        // NOTE(review): presumably registers the custom unpickling handlers (e.g. for NumPy scalars)
        // declared in "python2pmml.properties" — confirm against PickleUtil.
        ClassLoader clazzLoader = PythonEvaluatorUtil.class.getClassLoader();
        PickleUtil.init(clazzLoader, "python2pmml.properties");
    }

    private PythonEvaluatorUtil() {
    }

    /**
     * Evaluates a single record supplied as a pickled Python dict.
     *
     * @param evaluator the evaluator, or {@code null} to return the decoded arguments themselves
     * @param dictBytes a pickled dict mapping field names to argument values
     * @param dropColumns names of result entries to omit, or {@code null} to keep all of them
     * @return the pickled result dict
     * @throws IOException if unpickling or pickling fails
     */
    @SuppressWarnings("unchecked")
    public static byte[] evaluate(Evaluator evaluator, byte[] dictBytes, String[] dropColumns) throws IOException {
        Map<String, ?> arguments = (Map<String, ?>) unpickle(dictBytes);
        Map<String, ?> results = evaluate(evaluator, arguments, toSet(dropColumns));
        return pickle(results);
    }

    /**
     * Evaluates a single record.
     *
     * <p>Argument values are converted with {@link #toJavaPrimitive(Object)} only when they are read,
     * and the outcome is decoded with {@link EvaluatorUtil#decodeAll(Map)}.
     *
     * @param evaluator the evaluator, or {@code null} to return the converted arguments themselves
     * @param arguments field names mapped to unpickled Python values
     * @param dropColumns result keys to remove, or {@code null} to keep all of them
     * @return the decoded results
     * @throws IllegalArgumentException if a value that is read has an unsupported type
     */
    public static Map<String, ?> evaluate(Evaluator evaluator, Map<String, ?> arguments, Set<String> dropColumns) {
        // A live view rather than a copy: arguments the evaluator never reads are never converted.
        Map<String, Object> pmmlArguments = Maps.transformValues(arguments, PythonEvaluatorUtil::toJavaPrimitive);

        Map<String, ?> pmmlResults = (evaluator != null) ? evaluator.evaluate(pmmlArguments) : pmmlArguments;

        Map<String, ?> results = EvaluatorUtil.decodeAll(pmmlResults);
        if (dropColumns != null) {
            // Relies on decodeAll returning a mutable map, as the previous implementation did.
            results.keySet().removeAll(dropColumns);
        }
        return results;
    }

    /**
     * Evaluates a batch of records supplied as a pickled, column-oriented Python dict.
     *
     * @param evaluator the evaluator, or {@code null} to return the arguments themselves
     * @param dictBytes a pickled dict with {@code "columns"} and {@code "data"} entries
     * @param dropColumns names of result columns to omit, or {@code null} to keep all of them
     * @param parallelism see {@link #evaluateAll(Evaluator, Map, Set, int)}
     * @return the pickled result dict
     * @throws IOException if unpickling or pickling fails
     */
    @SuppressWarnings("unchecked")
    public static byte[] evaluateAll(Evaluator evaluator, byte[] dictBytes, String[] dropColumns, int parallelism) throws IOException {
        Map<String, ?> argumentsDict = (Map<String, ?>) unpickle(dictBytes);
        Map<String, ?> resultsDict = evaluateAll(evaluator, argumentsDict, toSet(dropColumns), parallelism);
        return pickle(resultsDict);
    }

    /**
     * Evaluates a batch of records.
     *
     * @param evaluator the evaluator, or {@code null} to return the arguments themselves
     * @param argumentsDict a dict with a {@code "columns"} list of names and a {@code "data"} list
     *     holding one list of values per column
     * @param dropColumns names of result columns to omit, or {@code null} to keep all of them
     * @param parallelism {@code -1} for a parallel stream on the default pool, {@code 1} for a
     *     sequential stream, any other positive value for a dedicated pool of that size
     * @return a dict with {@code "columns"}, {@code "data"} and {@code "errors"} entries
     * @throws IllegalArgumentException if the dict is malformed or {@code parallelism} is invalid
     */
    public static Map<String, ?> evaluateAll(Evaluator evaluator, Map<String, ?> argumentsDict, Set<String> dropColumns, int parallelism) {
        Table argumentsTable = parseDict(argumentsDict);

        Function<Map<String, ?>, Map<String, ?>> function;
        TableCollector tableCollector;
        if (evaluator != null) {
            function = new EvaluatorFunction(evaluator);
            tableCollector = createResultsCollector(evaluator, dropColumns);
        } else {
            function = arguments -> arguments;
            tableCollector = createArgumentsCollector(dropColumns);
        }

        Table resultsTable = collectResults(argumentsTable, function, tableCollector, parallelism);
        return formatDict(resultsTable);
    }

    /** Builds a collector for the evaluator's target and output fields, minus the dropped ones. */
    private static TableCollector createResultsCollector(Evaluator evaluator, Set<String> dropColumns) {
        // NOTE(review): the 'true' flag presumably enables value decoding — confirm against ResultTableCollector.
        return new ResultTableCollector(
            Stream.concat(evaluator.getTargetFields().stream(), evaluator.getOutputFields().stream())
                .filter(resultField -> dropColumns == null || !dropColumns.contains(resultField.getName()))
                .collect(Collectors.toList()),
            true);
    }

    /** Builds a collector that stores incoming rows as-is, except for entries named in dropColumns. */
    private static TableCollector createArgumentsCollector(Set<String> dropColumns) {
        return new TableCollector() {

            @Override
            protected Table.Row createFinisherRow(Table table) {
                // Row is an inner class of Table, hence the qualified instance creation.
                return table.new Row(0, -1) {

                    @Override
                    public Object put(String key, Object value) {
                        // Silently discard writes to dropped columns so they never reach the table.
                        if (dropColumns != null && dropColumns.contains(key)) {
                            return null;
                        }
                        return super.put(key, value);
                    }
                };
            }
        };
    }

    /**
     * Maps every row of {@code argumentsTable} through {@code function} and gathers the outcome.
     *
     * @throws IllegalArgumentException if {@code parallelism} is {@code 0} or less than {@code -1}
     */
    private static Table collectResults(Table argumentsTable, Function<Map<String, ?>, Map<String, ?>> function, TableCollector tableCollector, int parallelism) {
        if (parallelism == -1) {
            return argumentsTable.parallelStream().map(function).collect(tableCollector);
        }
        if (parallelism == 1) {
            return argumentsTable.stream().map(function).collect(tableCollector);
        }
        if (parallelism < 1) {
            throw new IllegalArgumentException("Expected -1, 1 or another positive parallelism, got " + parallelism);
        }

        // Relies on the JDK behaviour that a parallel stream started from inside a ForkJoinPool task
        // runs its subtasks in that pool rather than in the common pool.
        ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism);
        try {
            ForkJoinTask<Table> forkJoinTask = ForkJoinTask.adapt(() -> argumentsTable.parallelStream().map(function).collect(tableCollector));
            return forkJoinPool.invoke(forkJoinTask);
        } finally {
            // Release the worker threads even when evaluation fails.
            forkJoinPool.shutdown();
        }
    }

    /**
     * Converts an unpickled Python value into a value the PMML evaluator can consume.
     *
     * <p>{@code null}, {@link String}, {@link Boolean} and {@link Number} values are returned as-is;
     * {@link Scalar} values are unwrapped with {@link ScalarUtil#decode}.
     *
     * @param value the unpickled value
     * @return the Java value
     * @throws IllegalArgumentException for any other Python or Java type
     */
    public static Object toJavaPrimitive(Object value) {
        if (value == null || value instanceof String || value instanceof Boolean || value instanceof Number) {
            return value;
        }
        // Checked before ClassDict on purpose, in case Scalar is itself a ClassDict subtype.
        if (value instanceof Scalar) {
            return ScalarUtil.decode(value);
        }
        if (value instanceof ClassDict) {
            ClassDict classDict = (ClassDict) value;
            throw new IllegalArgumentException("Python type " + classDict.getClassName() + " is not supported");
        }
        throw new IllegalArgumentException("Java type " + value.getClass().getName() + " is not supported");
    }

    /**
     * Copies an array into a set. A {@code null} array yields an empty set, which every consumer in
     * this class treats exactly like "drop nothing".
     */
    private static <E> Set<E> toSet(E[] values) {
        if (values == null || values.length == 0) {
            return Collections.emptySet();
        }
        if (values.length == 1) {
            return Collections.singleton(values[0]);
        }
        return new HashSet<>(Arrays.asList(values));
    }

    /**
     * Builds a table from a column-oriented dict: {@code data.get(i)} holds the values of
     * {@code columns.get(i)}.
     *
     * @throws IllegalArgumentException if an entry is missing or the two lists differ in length
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Table parseDict(Map<String, ?> dict) {
        List<String> columns = (List<String>) dict.get("columns");
        List<?> data = (List<?>) dict.get("data");
        if (columns == null || data == null) {
            throw new IllegalArgumentException("Expected a dict with \"columns\" and \"data\" entries");
        }
        if (columns.size() != data.size()) {
            throw new IllegalArgumentException("Expected " + columns.size() + " data lists (one per column), got " + data.size());
        }

        // NOTE(review): 256 appears to be an initial capacity hint — confirm against Table.
        Table result = new Table(columns, 256);
        for (int i = 0; i < columns.size(); i++) {
            result.setValues(columns.get(i), (List) data.get(i));
        }
        return result;
    }

    /**
     * Converts a table into a column-oriented dict with {@code "columns"}, {@code "data"} (one value
     * list per column) and {@code "errors"} entries. {@code "errors"} is {@code null} when the table
     * has no exceptions, otherwise it holds each exception's {@code toString()} (or {@code null}).
     */
    public static Map<String, ?> formatDict(Table table) {
        List<String> columns = table.getColumns();
        List<List<?>> data = new ArrayList<>(columns.size());
        for (String column : columns) {
            data.add(table.getValues(column));
        }

        List<String> errors = null;
        if (table.hasExceptions()) {
            List<?> exceptions = table.getExceptions();
            errors = exceptions.stream()
                .map(exception -> Objects.toString(exception, null))
                .collect(Collectors.toList());
        }

        Map<String, Object> result = new HashMap<>();
        result.put("columns", columns);
        result.put("data", data);
        result.put("errors", errors);
        return result;
    }

    private static Object unpickle(byte[] bytes) throws IOException {
        Unpickler unpickler = new Unpickler();
        return unpickler.loads(bytes);
    }

    private static byte[] pickle(Object object) throws IOException {
        Pickler pickler = new Pickler();
        return pickler.dumps(object);
    }
}

