Skip to content

Commit f3379c1

Browse files
committed
feat: add structured UDT literal support with dual encoding
Support both opaque (google.protobuf.Any) and structured (Literal.Struct) encodings for user-defined type literals per Substrait spec. - Split UserDefinedLiteral into UserDefinedAny and UserDefinedStruct - Move type parameters to interface level for parameterized types - Comprehensive test coverage including roundtrip tests - Throw exception on unhandled struct-based representation in isthmus
1 parent 22448d1 commit f3379c1

27 files changed

+722
-59
lines changed

core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
151151
return visitFallback(expr, context);
152152
}
153153

154+
@Override
155+
public O visit(Expression.UserDefinedAny expr, C context) throws E {
156+
return visitFallback(expr, context);
157+
}
158+
159+
@Override
160+
public O visit(Expression.UserDefinedStruct expr, C context) throws E {
161+
return visitFallback(expr, context);
162+
}
163+
154164
@Override
155165
public O visit(Expression.Switch expr, C context) throws E {
156166
return visitFallback(expr, context);

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,21 +662,96 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
662662
}
663663
}
664664

665+
/**
666+
* Base interface for user-defined literals.
667+
*
668+
* <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
669+
*
670+
* <ul>
671+
* <li>As {@code google.protobuf.Any} - see {@link UserDefinedAny}
672+
* <li>As {@code Literal.Struct} - see {@link UserDefinedStruct}
673+
* </ul>
674+
*
675+
* @see UserDefinedAny
676+
* @see UserDefinedStruct
677+
*/
678+
interface UserDefinedLiteral extends Literal {
679+
String urn();
680+
681+
String name();
682+
683+
List<io.substrait.proto.Type.Parameter> typeParameters();
684+
}
685+
686+
/**
687+
* User-defined literal with value encoded as {@code google.protobuf.Any}.
688+
*
689+
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
690+
*/
665691
@Value.Immutable
666-
abstract class UserDefinedLiteral implements Literal {
667-
public abstract ByteString value();
692+
abstract class UserDefinedAny implements UserDefinedLiteral {
693+
@Override
694+
public abstract String urn();
695+
696+
@Override
697+
public abstract String name();
698+
699+
@Override
700+
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();
701+
702+
public abstract com.google.protobuf.Any value();
703+
704+
@Override
705+
public Type.UserDefined getType() {
706+
return Type.UserDefined.builder()
707+
.nullable(nullable())
708+
.urn(urn())
709+
.name(name())
710+
.typeParameters(typeParameters())
711+
.build();
712+
}
713+
714+
public static ImmutableExpression.UserDefinedAny.Builder builder() {
715+
return ImmutableExpression.UserDefinedAny.builder();
716+
}
668717

718+
@Override
719+
public <R, C extends VisitationContext, E extends Throwable> R accept(
720+
ExpressionVisitor<R, C, E> visitor, C context) throws E {
721+
return visitor.visit(this, context);
722+
}
723+
}
724+
725+
/**
726+
* User-defined literal with value encoded as {@code Literal.Struct}.
727+
*
728+
* <p>This encoding uses a structured list of fields to represent the literal value.
729+
*/
730+
@Value.Immutable
731+
abstract class UserDefinedStruct implements UserDefinedLiteral {
732+
@Override
669733
public abstract String urn();
670734

735+
@Override
671736
public abstract String name();
672737

673738
@Override
674-
public Type getType() {
675-
return Type.withNullability(nullable()).userDefined(urn(), name());
739+
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();
740+
741+
public abstract List<Literal> fields();
742+
743+
@Override
744+
public Type.UserDefined getType() {
745+
return Type.UserDefined.builder()
746+
.nullable(nullable())
747+
.urn(urn())
748+
.name(name())
749+
.typeParameters(typeParameters())
750+
.build();
676751
}
677752

678-
public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
679-
return ImmutableExpression.UserDefinedLiteral.builder();
753+
public static ImmutableExpression.UserDefinedStruct.Builder builder() {
754+
return ImmutableExpression.UserDefinedStruct.builder();
680755
}
681756

682757
@Override

core/src/main/java/io/substrait/expression/ExpressionCreator.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,51 @@ public static Expression.StructLiteral struct(
286286
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
287287
}
288288

289-
public static Expression.UserDefinedLiteral userDefinedLiteral(
290-
boolean nullable, String urn, String name, Any value) {
291-
return Expression.UserDefinedLiteral.builder()
289+
/**
290+
* Create a UserDefinedAny with google.protobuf.Any representation.
291+
*
292+
* @param nullable whether the literal is nullable
293+
* @param urn the URN of the user-defined type
294+
* @param name the name of the user-defined type
295+
* @param typeParameters the type parameters for the user-defined type (can be empty list)
296+
* @param value the value, encoded as google.protobuf.Any
297+
*/
298+
public static Expression.UserDefinedAny userDefinedLiteralAny(
299+
boolean nullable,
300+
String urn,
301+
String name,
302+
java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
303+
Any value) {
304+
return Expression.UserDefinedAny.builder()
305+
.nullable(nullable)
306+
.urn(urn)
307+
.name(name)
308+
.addAllTypeParameters(typeParameters)
309+
.value(value)
310+
.build();
311+
}
312+
313+
/**
314+
* Create a UserDefinedStruct with Struct representation.
315+
*
316+
* @param nullable whether the literal is nullable
317+
* @param urn the URN of the user-defined type
318+
* @param name the name of the user-defined type
319+
* @param typeParameters the type parameters for the user-defined type (can be empty list)
320+
* @param fields the fields, as a list of Literal values
321+
*/
322+
public static Expression.UserDefinedStruct userDefinedLiteralStruct(
323+
boolean nullable,
324+
String urn,
325+
String name,
326+
java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
327+
java.util.List<Expression.Literal> fields) {
328+
return Expression.UserDefinedStruct.builder()
292329
.nullable(nullable)
293330
.urn(urn)
294331
.name(name)
295-
.value(value.toByteString())
332+
.addAllTypeParameters(typeParameters)
333+
.addAllFields(fields)
296334
.build();
297335
}
298336

core/src/main/java/io/substrait/expression/ExpressionVisitor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
6262

6363
R visit(Expression.StructLiteral expr, C context) throws E;
6464

65-
R visit(Expression.UserDefinedLiteral expr, C context) throws E;
65+
R visit(Expression.UserDefinedAny expr, C context) throws E;
66+
67+
R visit(Expression.UserDefinedStruct expr, C context) throws E;
6668

6769
R visit(Expression.Switch expr, C context) throws E;
6870

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package io.substrait.expression.proto;
22

3-
import com.google.protobuf.Any;
4-
import com.google.protobuf.InvalidProtocolBufferException;
53
import io.substrait.expression.ExpressionVisitor;
64
import io.substrait.expression.FieldReference;
75
import io.substrait.expression.FunctionArg;
@@ -359,21 +357,43 @@ public Expression visit(
359357

360358
@Override
361359
public Expression visit(
362-
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
360+
io.substrait.expression.Expression.UserDefinedAny expr, EmptyVisitationContext context) {
363361
int typeReference =
364362
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
365363
return lit(
366364
bldr -> {
367-
try {
368-
bldr.setNullable(expr.nullable())
369-
.setUserDefined(
370-
Expression.Literal.UserDefined.newBuilder()
371-
.setTypeReference(typeReference)
372-
.setValue(Any.parseFrom(expr.value())))
373-
.build();
374-
} catch (InvalidProtocolBufferException e) {
375-
throw new IllegalStateException(e);
376-
}
365+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
366+
Expression.Literal.UserDefined.newBuilder()
367+
.setTypeReference(typeReference)
368+
.addAllTypeParameters(expr.typeParameters())
369+
.setValue(expr.value());
370+
371+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
372+
});
373+
}
374+
375+
@Override
376+
public Expression visit(
377+
io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) {
378+
int typeReference =
379+
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
380+
return lit(
381+
bldr -> {
382+
Expression.Literal.Struct structLiteral =
383+
Expression.Literal.Struct.newBuilder()
384+
.addAllFields(
385+
expr.fields().stream()
386+
.map(this::toLiteral)
387+
.collect(java.util.stream.Collectors.toList()))
388+
.build();
389+
390+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
391+
Expression.Literal.UserDefined.newBuilder()
392+
.setTypeReference(typeReference)
393+
.addAllTypeParameters(expr.typeParameters())
394+
.setStruct(structLiteral);
395+
396+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
377397
});
378398
}
379399

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
492492
{
493493
io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
494494
literal.getUserDefined();
495+
495496
SimpleExtension.Type type =
496497
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
497-
return ExpressionCreator.userDefinedLiteral(
498-
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
498+
String urn = type.urn();
499+
String name = type.name();
500+
501+
switch (userDefinedLiteral.getValCase()) {
502+
case VALUE:
503+
return ExpressionCreator.userDefinedLiteralAny(
504+
literal.getNullable(),
505+
urn,
506+
name,
507+
userDefinedLiteral.getTypeParametersList(),
508+
userDefinedLiteral.getValue());
509+
case STRUCT:
510+
return ExpressionCreator.userDefinedLiteralStruct(
511+
literal.getNullable(),
512+
urn,
513+
name,
514+
userDefinedLiteral.getTypeParametersList(),
515+
userDefinedLiteral.getStruct().getFieldsList().stream()
516+
.map(this::from)
517+
.collect(Collectors.toList()));
518+
case VAL_NOT_SET:
519+
throw new IllegalStateException(
520+
"UserDefined literal has no value (neither 'value' nor 'struct' is set)");
521+
default:
522+
throw new IllegalStateException(
523+
"Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
524+
}
499525
}
500526
default:
501527
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());

core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public class DefaultExtensionCatalog {
2222
"extension:io.substrait:functions_rounding_decimal";
2323
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
2424
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
25+
public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types";
2526

2627
public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
2728
loadDefaultCollection();
@@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
4445
.map(c -> String.format("/functions_%s.yaml", c))
4546
.collect(Collectors.toList());
4647

48+
defaultFiles.add("/extension_types.yaml");
49+
4750
return SimpleExtension.load(defaultFiles);
4851
}
4952
}

core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,15 @@ public Optional<Expression> visit(Expression.StructLiteral expr, EmptyVisitation
203203
return visitLiteral(expr);
204204
}
205205

206+
@Override
207+
public Optional<Expression> visit(Expression.UserDefinedAny expr, EmptyVisitationContext context)
208+
throws E {
209+
return visitLiteral(expr);
210+
}
211+
206212
@Override
207213
public Optional<Expression> visit(
208-
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {
214+
Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E {
209215
return visitLiteral(expr);
210216
}
211217

core/src/main/java/io/substrait/type/Type.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,23 @@ abstract class UserDefined implements Type {
393393

394394
public abstract String name();
395395

396+
/**
397+
* Returns the type parameters for this user-defined type.
398+
*
399+
* <p>Type parameters are used to represent parameterized/generic types, such as {@code
400+
* List<i32>} or {@code Map<String, i64>}. Each parameter in the list represents a type argument
401+
* that specializes the generic user-defined type.
402+
*
403+
* <p>For example, a user-defined type {@code MyList} parameterized by {@code i32} would have
404+
* one type parameter containing the {@code i32} type definition.
405+
*
406+
* @return a list of type parameters, or an empty list if this type is not parameterized
407+
*/
408+
@Value.Default
409+
public java.util.List<io.substrait.proto.Type.Parameter> typeParameters() {
410+
return java.util.Collections.emptyList();
411+
}
412+
396413
public static ImmutableType.UserDefined.Builder builder() {
397414
return ImmutableType.UserDefined.builder();
398415
}

core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) {
165165
public final T visit(final Type.UserDefined expr) {
166166
int ref =
167167
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
168-
return typeContainer(expr).userDefined(ref);
168+
return typeContainer(expr).userDefined(ref, expr.typeParameters());
169169
}
170170
}

0 commit comments

Comments
 (0)