/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.InputTypeValidator;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategy;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;

@Internal
public final class TypeInferenceUtil {
    public static Result runTypeInference(TypeInference typeInference, CallContext callContext) {
        try {
            return TypeInferenceUtil.runTypeInferenceInternal(typeInference, callContext);
        }
        catch (ValidationException e) {
            throw new ValidationException(String.format("Invalid call to function '%s'. Given arguments: %s", callContext.getName(), callContext.getArgumentDataTypes().stream().map(DataType::toString).collect(Collectors.joining(", "))), e);
        }
        catch (Throwable t) {
            throw new TableException(String.format("Unexpected error in type inference logic of function '%s'. This is a bug.", callContext.getName()), t);
        }
    }

    private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext) {
        List<DataType> argumentTypes = callContext.getArgumentDataTypes();
        try {
            TypeInferenceUtil.validateArgumentCount(typeInference.getInputTypeValidator().getArgumentCount(), argumentTypes.size());
        }
        catch (ValidationException e) {
            throw TypeInferenceUtil.getInvalidInputException(typeInference.getInputTypeValidator(), callContext);
        }
        List<DataType> expectedTypes2 = typeInference.getArgumentTypes().orElse(argumentTypes);
        AdaptedCallContext adaptedCallContext = TypeInferenceUtil.adaptArguments(callContext, expectedTypes2);
        try {
            TypeInferenceUtil.validateInputTypes(typeInference.getInputTypeValidator(), adaptedCallContext);
        }
        catch (ValidationException e) {
            throw TypeInferenceUtil.getInvalidInputException(typeInference.getInputTypeValidator(), adaptedCallContext);
        }
        return TypeInferenceUtil.inferTypes(adaptedCallContext, typeInference.getAccumulatorTypeStrategy().orElse(null), typeInference.getOutputTypeStrategy());
    }

    private static ValidationException getInvalidInputException(InputTypeValidator validator, CallContext callContext) {
        String expectedSignatures = validator.getExpectedSignatures(callContext.getFunctionDefinition()).stream().map(s -> TypeInferenceUtil.formatSignature(callContext.getName(), s)).collect(Collectors.joining("\n"));
        return new ValidationException(String.format("Invalid input arguments. Expected signatures are:\n%s", expectedSignatures));
    }

    private static String formatSignature(String name, Signature s) {
        String arguments = s.getArguments().stream().map(TypeInferenceUtil::formatArgument).collect(Collectors.joining(", "));
        return String.format("%s(%s)", name, arguments);
    }

    private static String formatArgument(Signature.Argument arg) {
        StringBuilder stringBuilder = new StringBuilder();
        arg.getName().ifPresent(n -> stringBuilder.append((String)n).append(" => "));
        stringBuilder.append(arg.getType());
        return stringBuilder.toString();
    }

    private static void validateArgumentCount(ArgumentCount argumentCount, int actualCount) {
        argumentCount.getMinCount().ifPresent(min -> {
            if (actualCount < min) {
                throw new ValidationException(String.format("Invalid number of arguments. At least %d arguments expected but %d passed.", min, actualCount));
            }
        });
        argumentCount.getMaxCount().ifPresent(max -> {
            if (actualCount > max) {
                throw new ValidationException(String.format("Invalid number of arguments. At most %d arguments expected but %d passed.", max, actualCount));
            }
        });
        if (!argumentCount.isValidCount(actualCount)) {
            throw new ValidationException(String.format("Invalid number of arguments. %d arguments passed.", actualCount));
        }
    }

    private static void validateInputTypes(InputTypeValidator inputTypeValidator, CallContext callContext) {
        if (!inputTypeValidator.validate(callContext, true)) {
            throw new ValidationException("Invalid input arguments.");
        }
    }

    private static AdaptedCallContext adaptArguments(CallContext callContext, List<DataType> expectedTypes2) {
        List<DataType> actualTypes = callContext.getArgumentDataTypes();
        for (int pos = 0; pos < actualTypes.size(); ++pos) {
            DataType expectedType = expectedTypes2.get(pos);
            DataType actualType = actualTypes.get(pos);
            if (actualType.equals(expectedType) || TypeInferenceUtil.canCast(actualType, expectedType)) continue;
            throw new ValidationException(String.format("Invalid argument type at position %d. Data type %s expected but %s passed.", pos, expectedType, actualType));
        }
        return new AdaptedCallContext(callContext, expectedTypes2);
    }

    private static boolean canCast(DataType sourceDataType, DataType targetDataType) {
        return LogicalTypeCasts.supportsImplicitCast(sourceDataType.getLogicalType(), targetDataType.getLogicalType());
    }

    private static Result inferTypes(AdaptedCallContext adaptedCallContext, @Nullable TypeStrategy accumulatorTypeStrategy, TypeStrategy outputTypeStrategy) {
        Optional<DataType> potentialOutputType = outputTypeStrategy.inferType(adaptedCallContext);
        if (!potentialOutputType.isPresent()) {
            throw new ValidationException("Could not infer an output type for the given arguments.");
        }
        DataType outputType = potentialOutputType.get();
        if (adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.TABLE_AGGREGATE || adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.AGGREGATE) {
            if (accumulatorTypeStrategy == null) {
                return new Result(adaptedCallContext.expectedArguments, outputType, outputType);
            }
            Optional<DataType> potentialAccumulatorType = accumulatorTypeStrategy.inferType(adaptedCallContext);
            if (!potentialAccumulatorType.isPresent()) {
                throw new ValidationException("Could not infer an accumulator type for the given arguments.");
            }
            return new Result(adaptedCallContext.expectedArguments, potentialAccumulatorType.get(), outputType);
        }
        return new Result(adaptedCallContext.expectedArguments, null, outputType);
    }

    private TypeInferenceUtil() {
    }

    private static class AdaptedCallContext
    implements CallContext {
        private final CallContext originalContext;
        private final List<DataType> expectedArguments;

        public AdaptedCallContext(CallContext originalContext, List<DataType> castedArguments) {
            this.originalContext = originalContext;
            this.expectedArguments = castedArguments;
        }

        @Override
        public List<DataType> getArgumentDataTypes() {
            return this.expectedArguments;
        }

        @Override
        public FunctionDefinition getFunctionDefinition() {
            return this.originalContext.getFunctionDefinition();
        }

        @Override
        public boolean isArgumentLiteral(int pos) {
            if (this.isCasted(pos)) {
                return false;
            }
            return this.originalContext.isArgumentLiteral(pos);
        }

        @Override
        public boolean isArgumentNull(int pos) {
            return this.originalContext.isArgumentNull(pos);
        }

        @Override
        public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) {
            if (this.isCasted(pos)) {
                return Optional.empty();
            }
            return this.originalContext.getArgumentValue(pos, clazz);
        }

        @Override
        public String getName() {
            return this.originalContext.getName();
        }

        private boolean isCasted(int pos) {
            return !this.originalContext.getArgumentDataTypes().get(pos).equals(this.expectedArguments.get(pos));
        }
    }

    public static final class Result {
        private final List<DataType> expectedArgumentTypes;
        @Nullable
        private final DataType accumulatorDataType;
        private final DataType outputDataType;

        public Result(List<DataType> expectedArgumentTypes, @Nullable DataType accumulatorDataType, DataType outputDataType) {
            this.expectedArgumentTypes = expectedArgumentTypes;
            this.accumulatorDataType = accumulatorDataType;
            this.outputDataType = outputDataType;
        }

        public List<DataType> getExpectedArgumentTypes() {
            return this.expectedArgumentTypes;
        }

        public Optional<DataType> getAccumulatorDataType() {
            return Optional.ofNullable(this.accumulatorDataType);
        }

        public DataType getOutputDataType() {
            return this.outputDataType;
        }
    }
}

