Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.Switch expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
87 changes: 81 additions & 6 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -662,21 +662,96 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
* Base interface for user-defined literals.
*
* <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
*
* <ul>
* <li>As {@code google.protobuf.Any} - see {@link UserDefinedAnyLiteral}
* <li>As {@code Literal.Struct} - see {@link UserDefinedStructLiteral}
* </ul>
*
* @see UserDefinedAnyLiteral
* @see UserDefinedStructLiteral
*/
interface UserDefinedLiteral extends Literal {
String urn();

String name();

List<io.substrait.type.Type.Parameter> typeParameters();
}

/**
* User-defined literal with value encoded as {@code google.protobuf.Any}.
*
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
*/
@Value.Immutable
abstract class UserDefinedLiteral implements Literal {
public abstract ByteString value();
abstract class UserDefinedAnyLiteral implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public abstract List<io.substrait.type.Type.Parameter> typeParameters();

public abstract com.google.protobuf.Any value();

@Override
public Type.UserDefined getType() {
return Type.UserDefined.builder()
.nullable(nullable())
.urn(urn())
.name(name())
.typeParameters(typeParameters())
.build();
}

public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() {
return ImmutableExpression.UserDefinedAnyLiteral.builder();
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

/**
* User-defined literal with value encoded as {@code Literal.Struct}.
*
* <p>This encoding uses a structured list of fields to represent the literal value.
*/
@Value.Immutable
abstract class UserDefinedStructLiteral implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public Type getType() {
return Type.withNullability(nullable()).userDefined(urn(), name());
public abstract List<io.substrait.type.Type.Parameter> typeParameters();

public abstract List<Literal> fields();

@Override
public Type.UserDefined getType() {
return Type.UserDefined.builder()
.nullable(nullable())
.urn(urn())
.name(name())
.typeParameters(typeParameters())
.build();
}

public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
return ImmutableExpression.UserDefinedLiteral.builder();
public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() {
return ImmutableExpression.UserDefinedStructLiteral.builder();
}

@Override
Expand Down
46 changes: 42 additions & 4 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,51 @@ public static Expression.StructLiteral struct(
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
}

public static Expression.UserDefinedLiteral userDefinedLiteral(
boolean nullable, String urn, String name, Any value) {
return Expression.UserDefinedLiteral.builder()
/**
* Create a UserDefinedAnyLiteral with google.protobuf.Any representation.
*
* @param nullable whether the literal is nullable
* @param urn the URN of the user-defined type
* @param name the name of the user-defined type
* @param typeParameters the type parameters for the user-defined type (can be empty list)
* @param value the value, encoded as google.protobuf.Any
*/
public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny(
boolean nullable,
String urn,
String name,
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
Any value) {
return Expression.UserDefinedAnyLiteral.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.addAllTypeParameters(typeParameters)
.value(value)
.build();
}

/**
* Create a UserDefinedStructLiteral with Struct representation.
*
* @param nullable whether the literal is nullable
* @param urn the URN of the user-defined type
* @param name the name of the user-defined type
* @param typeParameters the type parameters for the user-defined type (can be empty list)
* @param fields the fields, as a list of Literal values
*/
public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct(
boolean nullable,
String urn,
String name,
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
java.util.List<Expression.Literal> fields) {
return Expression.UserDefinedStructLiteral.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.value(value.toByteString())
.addAllTypeParameters(typeParameters)
.addAllFields(fields)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.StructLiteral expr, C context) throws E;

R visit(Expression.UserDefinedLiteral expr, C context) throws E;
R visit(Expression.UserDefinedAnyLiteral expr, C context) throws E;

R visit(Expression.UserDefinedStructLiteral expr, C context) throws E;

R visit(Expression.Switch expr, C context) throws E;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.substrait.expression.proto;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -359,21 +357,51 @@ public Expression visit(

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
try {
bldr.setNullable(expr.nullable())
.setUserDefined(
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.setValue(Any.parseFrom(expr.value())))
.build();
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException(e);
}
Expression.Literal.UserDefined.Builder userDefinedBuilder =
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.addAllTypeParameters(
expr.typeParameters().stream()
.map(typeProtoConverter::toProto)
.collect(java.util.stream.Collectors.toList()))
.setValue(expr.value());

bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
});
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedStructLiteral expr,
EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
Expression.Literal.Struct structLiteral =
Expression.Literal.Struct.newBuilder()
.addAllFields(
expr.fields().stream()
.map(this::toLiteral)
.collect(java.util.stream.Collectors.toList()))
.build();

Expression.Literal.UserDefined.Builder userDefinedBuilder =
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.addAllTypeParameters(
expr.typeParameters().stream()
.map(typeProtoConverter::toProto)
.collect(java.util.stream.Collectors.toList()))
.setStruct(structLiteral);

bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,40 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
{
io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
literal.getUserDefined();

SimpleExtension.Type type =
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
return ExpressionCreator.userDefinedLiteral(
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
String urn = type.urn();
String name = type.name();

switch (userDefinedLiteral.getValCase()) {
case VALUE:
return ExpressionCreator.userDefinedLiteralAny(
literal.getNullable(),
urn,
name,
userDefinedLiteral.getTypeParametersList().stream()
.map(protoTypeConverter::from)
.collect(Collectors.toList()),
userDefinedLiteral.getValue());
case STRUCT:
return ExpressionCreator.userDefinedLiteralStruct(
literal.getNullable(),
urn,
name,
userDefinedLiteral.getTypeParametersList().stream()
.map(protoTypeConverter::from)
.collect(Collectors.toList()),
userDefinedLiteral.getStruct().getFieldsList().stream()
.map(this::from)
.collect(Collectors.toList()));
case VAL_NOT_SET:
throw new IllegalStateException(
"UserDefined literal has no value (neither 'value' nor 'struct' is set)");
default:
throw new IllegalStateException(
"Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
}
}
default:
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog {
"extension:io.substrait:functions_rounding_decimal";
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types";

public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
loadDefaultCollection();
Expand All @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
.map(c -> String.format("/functions_%s.yaml", c))
.collect(Collectors.toList());

defaultFiles.add("/extension_types.yaml");

return SimpleExtension.load(defaultFiles);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ public Optional<Expression> visit(Expression.StructLiteral expr, EmptyVisitation

@Override
public Optional<Expression> visit(
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {
Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(
Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E {
return visitLiteral(expr);
}

Expand Down
Loading
Loading