/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference.strategies;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.expressions.TableSymbol;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.logical.LogicalTypeRoot;

/**
 * Argument type strategy that accepts a literal of a given symbol enum and, optionally, only a
 * subset of that enum's constants.
 *
 * <p>An argument is accepted only if all of the following hold:
 *
 * <ol>
 *   <li>its logical type root is {@link LogicalTypeRoot#SYMBOL} and it is a literal,
 *   <li>its literal value can be retrieved as an instance of {@code symbolClass},
 *   <li>that value is contained in the set of allowed variants.
 * </ol>
 *
 * <p>Failures are reported through {@code CallContext.fail(...)} using the caller-provided
 * {@code throwOnFailure} flag.
 *
 * @param <T> the symbol enum type (an enum implementing {@link TableSymbol}) accepted by this
 *     strategy
 */
@Internal
public class SymbolArgumentTypeStrategy<T extends Enum<? extends TableSymbol>>
implements ArgumentTypeStrategy {
    /** Enum class whose constants are accepted as argument values. */
    private final Class<T> symbolClass;

    /**
     * Constants of {@link #symbolClass} that are accepted. Stored as given (not copied), so
     * callers should not mutate the set after construction.
     */
    private final Set<T> allowedVariants;

    /**
     * Creates a strategy that accepts every constant of {@code symbolClass}.
     *
     * @param symbolClass enum class of the expected symbol
     */
    public SymbolArgumentTypeStrategy(Class<T> symbolClass) {
        // getEnumConstants() already yields T[], so no cast is required to build a Set<T>.
        this(symbolClass, new HashSet<>(Arrays.asList(symbolClass.getEnumConstants())));
    }

    /**
     * Creates a strategy that accepts only the given constants of {@code symbolClass}.
     *
     * @param symbolClass enum class of the expected symbol
     * @param allowedVariants constants of {@code symbolClass} that are accepted
     * @throws NullPointerException if {@code symbolClass} or {@code allowedVariants} is null
     */
    public SymbolArgumentTypeStrategy(Class<T> symbolClass, Set<T> allowedVariants) {
        this.symbolClass = Objects.requireNonNull(symbolClass, "symbolClass");
        this.allowedVariants = Objects.requireNonNull(allowedVariants, "allowedVariants");
    }

    /**
     * Validates the argument at {@code argumentPos} against the expected symbol class and the
     * allowed variants.
     *
     * @param callContext context of the call being validated
     * @param argumentPos zero-based position of the argument
     * @param throwOnFailure flag forwarded to {@code CallContext.fail(...)} on a failed check
     * @return the argument's own data type if all checks pass, otherwise the result of {@code
     *     CallContext.fail(...)}
     */
    @Override
    public Optional<DataType> inferArgumentType(CallContext callContext, int argumentPos, boolean throwOnFailure) {
        final DataType argumentType = callContext.getArgumentDataTypes().get(argumentPos);

        // Step 1: the value has to be inspected below, so only SYMBOL literals can be validated.
        final boolean isSymbolLiteral =
                argumentType.getLogicalType().getTypeRoot() == LogicalTypeRoot.SYMBOL
                        && callContext.isArgumentLiteral(argumentPos);
        if (!isSymbolLiteral) {
            return callContext.fail(
                    throwOnFailure,
                    "Unsupported argument type. Expected symbol type '%s' but actual type was '%s'.",
                    this.symbolClass.getSimpleName(),
                    argumentType);
        }

        // Step 2: presumably empty when the literal is a symbol of a different enum class.
        final Optional<T> value = callContext.getArgumentValue(argumentPos, this.symbolClass);
        if (!value.isPresent()) {
            // Re-read the literal as a generic Enum only to name the mismatching symbol class.
            final String actualSymbol =
                    callContext
                            .getArgumentValue(argumentPos, Enum.class)
                            .map(e -> "'" + e.getClass().getSimpleName() + "'")
                            .orElse("invalid");
            return callContext.fail(
                    throwOnFailure,
                    "Unsupported argument symbol type. Expected symbol '%s' but actual symbol was %s.",
                    this.symbolClass.getSimpleName(),
                    actualSymbol);
        }

        // Step 3: the symbol must be one of the configured variants.
        final T symbol = value.get();
        if (!this.allowedVariants.contains(symbol)) {
            return callContext.fail(
                    throwOnFailure,
                    "Unsupported argument symbol variant. Expected one of the following variants %s but actual symbol was %s.",
                    this.allowedVariants,
                    symbol);
        }
        return Optional.of(argumentType);
    }

    /**
     * Describes the expected argument as the group named after the symbol class.
     *
     * @param functionDefinition definition of the function (not used)
     * @param argumentPos position of the argument (not used)
     * @return a signature argument built from {@link #symbolClass}
     */
    @Override
    public Signature.Argument getExpectedArgument(FunctionDefinition functionDefinition, int argumentPos) {
        return Signature.Argument.ofGroup(this.symbolClass);
    }

    /**
     * Two strategies are equal if they expect the same symbol class and allow the same variants.
     * Both fields take part because strategies that differ only in their variants validate
     * differently.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || this.getClass() != other.getClass()) {
            return false;
        }
        final SymbolArgumentTypeStrategy<?> that = (SymbolArgumentTypeStrategy<?>) other;
        return Objects.equals(this.symbolClass, that.symbolClass)
                && Objects.equals(this.allowedVariants, that.allowedVariants);
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(this.symbolClass, this.allowedVariants);
    }
}

