Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ public Type getType() {

public abstract Expression.AggregationInvocation invocation();

/**
* Validates that variadic arguments satisfy the parameter consistency requirement. When
* CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When
* INCONSISTENT, arguments can have different types.
*/
@Value.Check
protected void check() {
VariadicParameterConsistencyValidator.validate(declaration(), arguments());
}

public static ImmutableAggregateFunctionInvocation.Builder builder() {
return ImmutableAggregateFunctionInvocation.builder();
}
Expand Down
20 changes: 20 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,16 @@ public Type getType() {
return outputType();
}

/**
* Validates that variadic arguments satisfy the parameter consistency requirement. When
* CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When
* INCONSISTENT, arguments can have different types.
*/
@Value.Check
protected void check() {
VariadicParameterConsistencyValidator.validate(declaration(), arguments());
}

public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() {
return ImmutableExpression.ScalarFunctionInvocation.builder();
}
Expand Down Expand Up @@ -840,6 +850,16 @@ public Type getType() {

public abstract AggregationInvocation invocation();

/**
* Validates that variadic arguments satisfy the parameter consistency requirement. When
* CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When
* INCONSISTENT, arguments can have different types.
*/
@Value.Check
protected void check() {
VariadicParameterConsistencyValidator.validate(declaration(), arguments());
}

public static ImmutableExpression.WindowFunctionInvocation.Builder builder() {
return ImmutableExpression.WindowFunctionInvocation.builder();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package io.substrait.expression;

import io.substrait.extension.SimpleExtension;
import io.substrait.type.Type;
import java.util.List;

/**
* Helper class for validating variadic parameter consistency in function invocations. Validates that
* when parameterConsistency is CONSISTENT, all variadic arguments have the same type (ignoring
* nullability).
*/
public class VariadicParameterConsistencyValidator {

/**
* Validates that variadic arguments satisfy the parameter consistency requirement. When
* CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When
* INCONSISTENT, arguments can have different types.
*
* @param func the function declaration
* @param arguments the function arguments to validate
* @throws AssertionError if validation fails
*/
public static void validate(
SimpleExtension.Function func, List<FunctionArg> arguments) {
java.util.Optional<SimpleExtension.VariadicBehavior> variadic = func.variadic();
if (!variadic.isPresent()) {
return;
}

SimpleExtension.VariadicBehavior variadicBehavior = variadic.get();
if (variadicBehavior.parameterConsistency()
!= SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) {
// INCONSISTENT allows different types, so validation passes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So unfortunately, this is going to be a bit more complicated than it initially seemed. From the docs:

  // Each argument can be any possible concrete type afforded by the bounds
  // of any parameter defined in the arguments specification.

Let me show you a (made-up) example:

urn: "urn:example:extension"
scalar_functions:
  - name: "inconsistent_sum"
    impls:
      - args:
          - value: "decimal<P,S>"
            variadic:
              min: 1
              parameterConsistency: INCONSISTENT
        return: "decimal<38,S>"

This means that the following are valid:

  • inconsistent_sum(decimal<10, 2>, decimal<10, 2>, decimal<10, 2>)
  • inconsistent_sum(decimal<10, 2>, decimal<10, 2>)
  • inconsistent_sum(decimal<15, 8>, decimal<11, 8>, decimal<119, 8>)

On the other hand, the following are all invalid:

  • inconsistent_sum(decimal<10, 2>, decimal<10, 3>, decimal<10, 4>)
  • inconsistent_sum(decimal<10, 2>, i32)

The above example is showing that there is the implicit constraint across the variadic parameters that the scale S must be the same as the output (and thus must all be the same across the variadic parameter). On the other hand, the precision P is free to be anything. But all of the parameters do have to be decimal.

Not sure if that was instructive or not 😅

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For readability, it may make sense to have two private methods that implement each behavior, then have the public method just be a switch on the parameter consistency options. Just an idea!

return;
}

// Extract types from arguments (only Expression and Type have types, EnumArg doesn't)
List<Type> argumentTypes =
arguments.stream()
.filter(arg -> arg instanceof Expression || arg instanceof Type)
.map(
arg -> {
if (arg instanceof Expression) {
return ((Expression) arg).getType();
} else {
return (Type) arg;
}
})
.collect(java.util.stream.Collectors.toList());

int fixedArgCount = func.args().size();
if (argumentTypes.size() <= fixedArgCount) {
// No variadic arguments, validation passes
return;
}

// For CONSISTENT, all variadic arguments must have the same type (ignoring nullability)
// Compare all variadic arguments to the first one for more informative error messages
// Variadic arguments start after the fixed arguments
int firstVariadicArgIdx = fixedArgCount + Math.max(variadicBehavior.getMin() - 1, 0);
Type firstVariadicType = argumentTypes.get(firstVariadicArgIdx);
for (int i = firstVariadicArgIdx + 1; i < argumentTypes.size(); i++) {
Type currentType = argumentTypes.get(i);
if (!firstVariadicType.equalsIgnoringNullability(currentType)) {
throw new AssertionError(
String.format(
"Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. "
+ "Argument at index %d has type %s but argument at index %d has type %s",
firstVariadicArgIdx, firstVariadicType, i, currentType));
}
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ enum ParameterConsistency {
INCONSISTENT
}

@Value.Default
default ParameterConsistency parameterConsistency() {
return ParameterConsistency.CONSISTENT;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a reasonble default, because I think it's what most people expect in practice and it's also the first value in the enumeration.

We can formalize this in the spec more concretely.

}
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/type/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ default <R, C extends VisitationContext, E extends Throwable> R accept(
return fnArgVisitor.visitType(fnDef, argIdx, this, context);
}

/**
* Compares this type with another type, ignoring nullability differences.
*
* @param other the type to compare with
* @return true if the types are equal when both are treated as nullable
*/
default boolean equalsIgnoringNullability(Type other) {
return TypeCreator.asNullable(this).equals(TypeCreator.asNullable(other));
}

@Value.Immutable
abstract class Bool implements Type {
public static ImmutableType.Bool.Builder builder() {
Expand Down
Loading
Loading