diff --git a/build.gradle.kts b/build.gradle.kts index b1525593..2618aa79 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -339,12 +339,12 @@ dependencies { annotationProcessor(libs.nullaway) api(libs.jspecify) api(libs.protobuf.java) - implementation(enforcedPlatform(libs.cel)) - implementation(libs.cel.core) + implementation(libs.cel) buf("build.buf:buf:${libs.versions.buf.get()}:${osdetector.classifier}@exe") testImplementation(libs.assertj) + testImplementation(libs.grpc.protobuf) testImplementation(platform(libs.junit.bom)) testImplementation("org.junit.jupiter:junit-jupiter") testRuntimeOnly("org.junit.platform:junit-platform-launcher") diff --git a/conformance/expected-failures.yaml b/conformance/expected-failures.yaml index e69de29b..c69ddf5a 100644 --- a/conformance/expected-failures.yaml +++ b/conformance/expected-failures.yaml @@ -0,0 +1,9 @@ +standard_rules/bytes: + - pattern/invalid/not_utf8 +# input: [ type.googleapis.com/buf.validate.conformance.cases.BytesPattern ]:{val:"\x99"} +# want: runtime error: value must be valid UTF-8 to apply regexp +# got: validation error (1 violation) +# 1. rule_id: "bytes.pattern" +# message: "value must match regex pattern `^[\\x00-\\x7F]+$`" +# field: "val" elements:{field_number:1 field_name:"val" field_type:TYPE_BYTES} +# rule: "bytes.pattern" elements:{field_number:15 field_name:"bytes" field_type:TYPE_MESSAGE} elements:{field_number:4 field_name:"pattern" field_type:TYPE_STRING} \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 3d58bb4b..a717bbe4 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,7 +1,7 @@ [versions] assertj = "3.27.3" buf = "1.54.0" -cel = "0.5.3" +cel = "0.9.1" error-prone = "2.38.0" junit = "5.13.0" maven-publish = "0.32.0" @@ -10,10 +10,10 @@ protobuf = "4.31.1" [libraries] assertj = { module = "org.assertj:assertj-core", version.ref = "assertj" } buf = { module = "build.buf:buf", version.ref = "buf" } -cel = { module = "org.projectnessie.cel:cel-bom", version.ref = "cel" } -cel-core = { module = "org.projectnessie.cel:cel-core" } +cel = { module = "dev.cel:cel", version.ref = "cel" } errorprone-annotations = { module = "com.google.errorprone:error_prone_annotations", version.ref = "error-prone" } errorprone-core = { module = "com.google.errorprone:error_prone_core", version.ref = "error-prone" } +grpc-protobuf = { module = "io.grpc:grpc-protobuf", version = "1.73.0" } jspecify = { module ="org.jspecify:jspecify", version = "1.0.0" } junit-bom = { module = "org.junit:junit-bom", version.ref = "junit" } maven-plugin = { module = "com.vanniktech:gradle-maven-publish-plugin", version.ref = "maven-publish" } diff --git a/src/main/java/build/buf/protovalidate/AstExpression.java b/src/main/java/build/buf/protovalidate/AstExpression.java index 24cce494..81841d9b 100644 --- a/src/main/java/build/buf/protovalidate/AstExpression.java +++ b/src/main/java/build/buf/protovalidate/AstExpression.java @@ -15,20 +15,22 @@ package build.buf.protovalidate; import build.buf.protovalidate.exceptions.CompilationException; -import com.google.api.expr.v1alpha1.Type; -import org.projectnessie.cel.Ast; -import org.projectnessie.cel.Env; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; +import dev.cel.common.types.CelKind; +import dev.cel.compiler.CelCompiler; -/** {@link AstExpression} is a compiled CEL {@link Ast}. */ +/** {@link AstExpression} is a compiled CEL {@link CelAbstractSyntaxTree}. */ class AstExpression { /** The compiled CEL AST. */ - public final Ast ast; + public final CelAbstractSyntaxTree ast; /** Contains the original expression from the proto file. */ public final Expression source; /** Constructs a new {@link AstExpression}. */ - private AstExpression(Ast ast, Expression source) { + private AstExpression(CelAbstractSyntaxTree ast, Expression source) { this.ast = ast; this.source = source; } @@ -36,25 +38,32 @@ private AstExpression(Ast ast, Expression source) { /** * Compiles the given expression to a {@link AstExpression}. * - * @param env The CEL environment. + * @param cel The CEL compiler. * @param expr The expression to compile. * @return The compiled {@link AstExpression}. * @throws CompilationException if the expression compilation fails. */ - public static AstExpression newAstExpression(Env env, Expression expr) + public static AstExpression newAstExpression(CelCompiler cel, Expression expr) throws CompilationException { - Env.AstIssuesTuple astIssuesTuple = env.compile(expr.expression); - if (astIssuesTuple.hasIssues()) { + CelValidationResult compileResult = cel.compile(expr.expression); + if (!compileResult.getAllIssues().isEmpty()) { throw new CompilationException( - "Failed to compile expression " + expr.id + ":\n" + astIssuesTuple.getIssues()); + "Failed to compile expression " + expr.id + ":\n" + compileResult.getIssueString()); } - Ast ast = astIssuesTuple.getAst(); - Type outType = ast.getResultType(); - if (outType.getPrimitive() != Type.PrimitiveType.BOOL - && outType.getPrimitive() != Type.PrimitiveType.STRING) { + CelAbstractSyntaxTree ast; + try { + ast = compileResult.getAst(); + } catch (CelValidationException e) { + // This will not happen as we checked for issues, and it only throws when + // it has at least one issue of error severity. + throw new CompilationException( + "Failed to compile expression " + expr.id + ":\n" + compileResult.getIssueString()); + } + CelKind outKind = ast.getResultType().kind(); + if (outKind != CelKind.BOOL && outKind != CelKind.STRING) { throw new CompilationException( String.format( - "Expression outputs, wanted either bool or string: %s %s", expr.id, outType)); + "Expression outputs, wanted either bool or string: %s %s", expr.id, outKind)); } return new AstExpression(ast, expr); } diff --git a/src/main/java/build/buf/protovalidate/CelPrograms.java b/src/main/java/build/buf/protovalidate/CelPrograms.java index 74ab7abf..0921642f 100644 --- a/src/main/java/build/buf/protovalidate/CelPrograms.java +++ b/src/main/java/build/buf/protovalidate/CelPrograms.java @@ -15,6 +15,7 @@ package build.buf.protovalidate; import build.buf.protovalidate.exceptions.ExecutionException; +import dev.cel.runtime.CelVariableResolver; import java.util.ArrayList; import java.util.List; import org.jspecify.annotations.Nullable; @@ -44,10 +45,10 @@ public boolean tautology() { @Override public List evaluate(Value val, boolean failFast) throws ExecutionException { - Variable activation = Variable.newThisVariable(val.value(Object.class)); + CelVariableResolver bindings = Variable.newThisVariable(val.value(Object.class)); List violations = new ArrayList<>(); for (CompiledProgram program : programs) { - RuleViolation.Builder violation = program.eval(val, activation); + RuleViolation.Builder violation = program.eval(val, bindings); if (violation != null) { violations.add(violation); if (failFast) { diff --git a/src/main/java/build/buf/protovalidate/CompiledProgram.java b/src/main/java/build/buf/protovalidate/CompiledProgram.java index 3fac0d33..27ca0a11 100644 --- a/src/main/java/build/buf/protovalidate/CompiledProgram.java +++ b/src/main/java/build/buf/protovalidate/CompiledProgram.java @@ -16,10 +16,10 @@ import build.buf.protovalidate.exceptions.ExecutionException; import build.buf.validate.FieldPath; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelRuntime.Program; +import dev.cel.runtime.CelVariableResolver; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.Program; -import org.projectnessie.cel.common.types.Err; -import org.projectnessie.cel.common.types.ref.Val; /** * {@link CompiledProgram} is a parsed and type-checked {@link Program} along with the source {@link @@ -38,6 +38,12 @@ class CompiledProgram { /** The rule value. */ @Nullable private final Value ruleValue; + /** + * Global variables to pass to the evaluation step. Program/CelRuntime doesn't have a concept of + * global variables. + */ + @Nullable private final CelVariableResolver globals; + /** * Constructs a new {@link CompiledProgram}. * @@ -47,30 +53,38 @@ class CompiledProgram { * @param ruleValue The rule value. */ public CompiledProgram( - Program program, Expression source, @Nullable FieldPath rulePath, @Nullable Value ruleValue) { + Program program, + Expression source, + @Nullable FieldPath rulePath, + @Nullable Value ruleValue, + @Nullable CelVariableResolver globals) { this.program = program; this.source = source; this.rulePath = rulePath; this.ruleValue = ruleValue; + this.globals = globals; } /** - * Evaluate the compiled program with a given set of {@link Variable} bindings. + * Evaluate the compiled program with a given set of {@link Variable} variables. * - * @param bindings Variable bindings used for the evaluation. + * @param variables Variables used for the evaluation. * @param fieldValue Field value to return in violations. * @return The {@link build.buf.validate.Violation} from the evaluation, or null if there are no * violations. * @throws ExecutionException If the evaluation of the CEL program fails with an error. */ - public RuleViolation.@Nullable Builder eval(Value fieldValue, Variable bindings) + public RuleViolation.@Nullable Builder eval(Value fieldValue, CelVariableResolver variables) throws ExecutionException { - Program.EvalResult evalResult = program.eval(bindings); - Val val = evalResult.getVal(); - if (val instanceof Err) { - throw new ExecutionException(String.format("error evaluating %s: %s", source.id, val)); + Object value; + try { + if (this.globals != null) { + variables = CelVariableResolver.hierarchicalVariableResolver(variables, this.globals); + } + value = program.eval(variables); + } catch (CelEvaluationException e) { + throw new ExecutionException(String.format("error evaluating %s: %s", source.id, e)); } - Object value = val.value(); if (value instanceof String) { if ("".equals(value)) { return null; @@ -88,7 +102,7 @@ public CompiledProgram( } return builder; } else if (value instanceof Boolean) { - if (val.booleanValue()) { + if (Boolean.TRUE.equals(value)) { return null; } RuleViolation.Builder builder = @@ -101,7 +115,7 @@ public CompiledProgram( } return builder; } else { - throw new ExecutionException(String.format("resolved to an unexpected type %s", val)); + throw new ExecutionException(String.format("resolved to an unexpected type %s", value)); } } } diff --git a/src/main/java/build/buf/protovalidate/CustomDeclarations.java b/src/main/java/build/buf/protovalidate/CustomDeclarations.java index d46d07a7..25f57845 100644 --- a/src/main/java/build/buf/protovalidate/CustomDeclarations.java +++ b/src/main/java/build/buf/protovalidate/CustomDeclarations.java @@ -14,15 +14,20 @@ package build.buf.protovalidate; -import com.google.api.expr.v1alpha1.Decl; +import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static dev.cel.common.CelOverloadDecl.newMemberOverload; + +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.types.CelType; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; -import org.projectnessie.cel.checker.Decls; -import org.projectnessie.cel.checker.Types; -import org.projectnessie.cel.common.types.TimestampT; /** Defines custom declaration functions. */ final class CustomDeclarations { @@ -32,141 +37,153 @@ final class CustomDeclarations { * * @return the list of function declarations. */ - static List create() { - List decls = new ArrayList<>(); - - // Add 'now' variable declaration - decls.add(Decls.newVar("now", Decls.newObjectType(TimestampT.TimestampType.typeName()))); + static List create() { + List decls = new ArrayList<>(); // Add 'getField' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "getField", - Decls.newOverload( - "get_field_any_string", Arrays.asList(Decls.Any, Decls.String), Decls.Any))); - + newGlobalOverload( + "get_field_any_string", + SimpleType.DYN, + Arrays.asList(SimpleType.ANY, SimpleType.STRING)))); // Add 'isIp' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isIp", - Decls.newInstanceOverload("is_ip", Arrays.asList(Decls.String, Decls.Int), Decls.Bool), - Decls.newInstanceOverload( - "is_ip_unary", Collections.singletonList(Decls.String), Decls.Bool))); + newMemberOverload( + "is_ip", SimpleType.BOOL, Arrays.asList(SimpleType.STRING, SimpleType.INT)), + newMemberOverload( + "is_ip_unary", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); // Add 'isIpPrefix' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isIpPrefix", - Decls.newInstanceOverload( + newMemberOverload( "is_ip_prefix_int_bool", - Arrays.asList(Decls.String, Decls.Int, Decls.Bool), - Decls.Bool), - Decls.newInstanceOverload( - "is_ip_prefix_int", Arrays.asList(Decls.String, Decls.Int), Decls.Bool), - Decls.newInstanceOverload( - "is_ip_prefix_bool", Arrays.asList(Decls.String, Decls.Bool), Decls.Bool), - Decls.newInstanceOverload( - "is_ip_prefix", Collections.singletonList(Decls.String), Decls.Bool))); + SimpleType.BOOL, + Arrays.asList(SimpleType.STRING, SimpleType.INT, SimpleType.BOOL)), + newMemberOverload( + "is_ip_prefix_int", + SimpleType.BOOL, + Arrays.asList(SimpleType.STRING, SimpleType.INT)), + newMemberOverload( + "is_ip_prefix_bool", + SimpleType.BOOL, + Arrays.asList(SimpleType.STRING, SimpleType.BOOL)), + newMemberOverload( + "is_ip_prefix", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); // Add 'isUriRef' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isUriRef", - Decls.newInstanceOverload( - "is_uri_ref", Collections.singletonList(Decls.String), Decls.Bool))); + newMemberOverload( + "is_uri_ref", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); // Add 'isUri' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isUri", - Decls.newInstanceOverload( - "is_uri", Collections.singletonList(Decls.String), Decls.Bool))); + newMemberOverload( + "is_uri", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); // Add 'isEmail' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isEmail", - Decls.newInstanceOverload( - "is_email", Collections.singletonList(Decls.String), Decls.Bool))); + newMemberOverload( + "is_email", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); // Add 'isHostname' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isHostname", - Decls.newInstanceOverload( - "is_hostname", Collections.singletonList(Decls.String), Decls.Bool))); + newMemberOverload( + "is_hostname", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); decls.add( - Decls.newFunction( + newFunctionDeclaration( "isHostAndPort", - Decls.newInstanceOverload( + newMemberOverload( "string_bool_is_host_and_port_bool", - Arrays.asList(Decls.String, Decls.Bool), - Decls.Bool))); + SimpleType.BOOL, + Arrays.asList(SimpleType.STRING, SimpleType.BOOL)))); // Add 'startsWith' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "startsWith", - Decls.newInstanceOverload( - "starts_with_bytes", Arrays.asList(Decls.Bytes, Decls.Bytes), Decls.Bool))); + newMemberOverload( + "starts_with_bytes", + SimpleType.BOOL, + Arrays.asList(SimpleType.BYTES, SimpleType.BYTES)))); // Add 'endsWith' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "endsWith", - Decls.newInstanceOverload( - "ends_with_bytes", Arrays.asList(Decls.Bytes, Decls.Bytes), Decls.Bool))); + newMemberOverload( + "ends_with_bytes", + SimpleType.BOOL, + Arrays.asList(SimpleType.BYTES, SimpleType.BYTES)))); // Add 'contains' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "contains", - Decls.newInstanceOverload( - "contains_bytes", Arrays.asList(Decls.Bytes, Decls.Bytes), Decls.Bool))); + newMemberOverload( + "contains_bytes", + SimpleType.BOOL, + Arrays.asList(SimpleType.BYTES, SimpleType.BYTES)))); // Add 'isNan' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isNan", - Decls.newInstanceOverload( - "is_nan", Collections.singletonList(Decls.Double), Decls.Bool))); + newMemberOverload( + "is_nan", SimpleType.BOOL, Collections.singletonList(SimpleType.DOUBLE)))); // Add 'isInf' function declaration decls.add( - Decls.newFunction( + newFunctionDeclaration( "isInf", - Decls.newInstanceOverload( - "is_inf_unary", Collections.singletonList(Decls.Double), Decls.Bool), - Decls.newInstanceOverload( - "is_inf_binary", Arrays.asList(Decls.Double, Decls.Int), Decls.Bool))); + newMemberOverload( + "is_inf_unary", SimpleType.BOOL, Collections.singletonList(SimpleType.DOUBLE)), + newMemberOverload( + "is_inf_binary", + SimpleType.BOOL, + Arrays.asList(SimpleType.DOUBLE, SimpleType.INT)))); // Add 'unique' function declaration - List uniqueOverloads = new ArrayList<>(); - for (com.google.api.expr.v1alpha1.Type type : - Arrays.asList(Decls.String, Decls.Int, Decls.Uint, Decls.Double, Decls.Bytes, Decls.Bool)) { - uniqueOverloads.add( - Decls.newInstanceOverload( - String.format("unique_%s", Types.formatCheckedType(type).toLowerCase(Locale.US)), - Collections.singletonList(type), - Decls.Bool)); + List uniqueOverloads = new ArrayList<>(); + for (CelType type : + Arrays.asList( + SimpleType.STRING, + SimpleType.INT, + SimpleType.UINT, + SimpleType.DOUBLE, + SimpleType.BYTES, + SimpleType.BOOL)) { uniqueOverloads.add( - Decls.newInstanceOverload( - String.format("unique_list_%s", Types.formatCheckedType(type).toLowerCase(Locale.US)), - Collections.singletonList(Decls.newListType(type)), - Decls.Bool)); + newMemberOverload( + String.format("unique_list_%s", type.name().toLowerCase(Locale.US)), + SimpleType.BOOL, + Collections.singletonList(ListType.create(type)))); } - decls.add(Decls.newFunction("unique", uniqueOverloads)); + decls.add(newFunctionDeclaration("unique", uniqueOverloads)); // Add 'format' function declaration - List formatOverloads = new ArrayList<>(); + List formatOverloads = new ArrayList<>(); formatOverloads.add( - Decls.newInstanceOverload( + newMemberOverload( "format_list_dyn", - Arrays.asList(Decls.String, Decls.newListType(Decls.Dyn)), - Decls.String)); + SimpleType.STRING, + Arrays.asList(SimpleType.STRING, ListType.create(SimpleType.DYN)))); - decls.add(Decls.newFunction("format", formatOverloads)); + decls.add(newFunctionDeclaration("format", formatOverloads)); return Collections.unmodifiableList(decls); } diff --git a/src/main/java/build/buf/protovalidate/CustomOverload.java b/src/main/java/build/buf/protovalidate/CustomOverload.java index 8383e38c..2950d89b 100644 --- a/src/main/java/build/buf/protovalidate/CustomOverload.java +++ b/src/main/java/build/buf/protovalidate/CustomOverload.java @@ -14,42 +14,25 @@ package build.buf.protovalidate; +import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; import com.google.protobuf.Message; +import dev.cel.common.types.CelType; +import dev.cel.common.types.SimpleType; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelRuntime.CelFunctionBinding; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.List; +import java.util.Locale; import java.util.Set; import java.util.regex.Pattern; -import org.projectnessie.cel.common.types.BoolT; -import org.projectnessie.cel.common.types.Err; -import org.projectnessie.cel.common.types.IntT; -import org.projectnessie.cel.common.types.ListT; -import org.projectnessie.cel.common.types.StringT; -import org.projectnessie.cel.common.types.Types; -import org.projectnessie.cel.common.types.pb.DefaultTypeAdapter; -import org.projectnessie.cel.common.types.ref.TypeEnum; -import org.projectnessie.cel.common.types.ref.Val; -import org.projectnessie.cel.common.types.traits.Lister; -import org.projectnessie.cel.interpreter.functions.Overload; /** Defines custom function overloads (the implementation). */ final class CustomOverload { - private static final String OVERLOAD_GET_FIELD = "getField"; - private static final String OVERLOAD_FORMAT = "format"; - private static final String OVERLOAD_UNIQUE = "unique"; - private static final String OVERLOAD_STARTS_WITH = "startsWith"; - private static final String OVERLOAD_ENDS_WITH = "endsWith"; - private static final String OVERLOAD_CONTAINS = "contains"; - private static final String OVERLOAD_IS_HOSTNAME = "isHostname"; - private static final String OVERLOAD_IS_EMAIL = "isEmail"; - private static final String OVERLOAD_IS_IP = "isIp"; - private static final String OVERLOAD_IS_IP_PREFIX = "isIpPrefix"; - private static final String OVERLOAD_IS_URI = "isUri"; - private static final String OVERLOAD_IS_URI_REF = "isUriRef"; - private static final String OVERLOAD_IS_NAN = "isNan"; - private static final String OVERLOAD_IS_INF = "isInf"; - private static final String OVERLOAD_IS_HOST_AND_PORT = "isHostAndPort"; - // See https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address private static final Pattern EMAIL_REGEX = Pattern.compile( @@ -58,184 +41,146 @@ final class CustomOverload { /** * Create custom function overload list. * - * @return an array of overloaded functions. + * @return a list of overloaded functions. */ - static Overload[] create() { - return new Overload[] { - celGetField(), - celFormat(), - celUnique(), - celStartsWith(), - celEndsWith(), - celContains(), - celIsHostname(), - celIsEmail(), - celIsIp(), - celIsIpPrefix(), - celIsUri(), - celIsUriRef(), - celIsNan(), - celIsInf(), - celIsHostAndPort(), - }; + static List create() { + ArrayList bindings = new ArrayList<>(); + bindings.addAll( + Arrays.asList( + celGetField(), + celFormat(), + celStartsWithBytes(), + celEndsWithBytes(), + celContainsBytes(), + celIsHostname(), + celIsEmail(), + celIsIpUnary(), + celIsIp(), + celIsIpPrefix(), + celIsIpPrefixInt(), + celIsIpPrefixBool(), + celIsIpPrefixIntBool(), + celIsUri(), + celIsUriRef(), + celIsNan(), + celIsInfUnary(), + celIsInfBinary(), + celIsHostAndPort())); + bindings.addAll(celUnique()); + return Collections.unmodifiableList(bindings); } /** * Creates a custom function overload for the "getField" operation. * - * @return The {@link Overload} instance for the "getField" operation. + * @return The {@link CelFunctionBinding} instance for the "getField" operation. */ - private static Overload celGetField() { - return Overload.binary( - OVERLOAD_GET_FIELD, - (msgarg, namearg) -> { - if (msgarg.type().typeEnum() != TypeEnum.Object - || namearg.type().typeEnum() != TypeEnum.String) { - return Err.newErr("no such overload"); - } - Message message = msgarg.convertToNative(Message.class); - String fieldName = namearg.convertToNative(String.class); + private static CelFunctionBinding celGetField() { + return CelFunctionBinding.from( + "get_field_any_string", + Message.class, + String.class, + (message, fieldName) -> { Descriptors.FieldDescriptor field = message.getDescriptorForType().findFieldByName(fieldName); if (field == null) { - return Err.newErr("no such field: " + fieldName); + throw new CelEvaluationException("no such field: " + fieldName); } - return DefaultTypeAdapter.Instance.nativeToValue(message.getField(field)); + return ProtoAdapter.toCel(field, message.getField(field)); }); } /** * Creates a custom binary function overload for the "format" operation. * - * @return The {@link Overload} instance for the "format" operation. + * @return The {@link CelFunctionBinding} instance for the "format" operation. */ - private static Overload celFormat() { - return Overload.binary( - OVERLOAD_FORMAT, - (lhs, rhs) -> { - if (lhs.type().typeEnum() != TypeEnum.String || rhs.type().typeEnum() != TypeEnum.List) { - return Err.noSuchOverload(lhs, OVERLOAD_FORMAT, rhs); - } - ListT list = (ListT) rhs.convertToType(ListT.ListType); - String formatString = (String) lhs.value(); - try { - return StringT.stringOf(Format.format(formatString, list)); - } catch (Err.ErrException e) { - return e.getErr(); - } - }); + private static CelFunctionBinding celFormat() { + return CelFunctionBinding.from("format_list_dyn", String.class, List.class, Format::format); } /** * Creates a custom unary function overload for the "unique" operation. * - * @return The {@link Overload} instance for the "unique" operation. + * @return The {@link CelFunctionBinding} instance for the "unique" operation. */ - private static Overload celUnique() { - return Overload.unary( - OVERLOAD_UNIQUE, - (val) -> { - if (val.type().typeEnum() != TypeEnum.List) { - return Err.noSuchOverload(val, OVERLOAD_UNIQUE, null); - } - return uniqueList((Lister) val); - }); + private static List celUnique() { + List uniqueOverloads = new ArrayList<>(); + for (CelType type : + Arrays.asList( + SimpleType.STRING, + SimpleType.INT, + SimpleType.UINT, + SimpleType.DOUBLE, + SimpleType.BYTES, + SimpleType.BOOL)) { + uniqueOverloads.add( + CelFunctionBinding.from( + String.format("unique_list_%s", type.name().toLowerCase(Locale.US)), + List.class, + CustomOverload::uniqueList)); + } + return Collections.unmodifiableList(uniqueOverloads); } /** * Creates a custom binary function overload for the "startsWith" operation. * - * @return The {@link Overload} instance for the "startsWith" operation. + * @return The {@link CelFunctionBinding} instance for the "startsWith" operation. */ - private static Overload celStartsWith() { - return Overload.binary( - OVERLOAD_STARTS_WITH, - (lhs, rhs) -> { - TypeEnum lhsType = lhs.type().typeEnum(); - if (lhsType != rhs.type().typeEnum()) { - return Err.noSuchOverload(lhs, OVERLOAD_STARTS_WITH, rhs); + private static CelFunctionBinding celStartsWithBytes() { + return CelFunctionBinding.from( + "starts_with_bytes", + ByteString.class, + ByteString.class, + (receiver, param) -> { + if (receiver.size() < param.size()) { + return false; } - if (lhsType == TypeEnum.String) { - String receiver = lhs.value().toString(); - String param = rhs.value().toString(); - return Types.boolOf(receiver.startsWith(param)); - } - if (lhsType == TypeEnum.Bytes) { - byte[] receiver = (byte[]) lhs.value(); - byte[] param = (byte[]) rhs.value(); - if (receiver.length < param.length) { - return BoolT.False; - } - for (int i = 0; i < param.length; i++) { - if (param[i] != receiver[i]) { - return BoolT.False; - } + for (int i = 0; i < param.size(); i++) { + if (param.byteAt(i) != receiver.byteAt(i)) { + return false; } - return BoolT.True; } - return Err.noSuchOverload(lhs, OVERLOAD_STARTS_WITH, rhs); + return true; }); } /** * Creates a custom binary function overload for the "endsWith" operation. * - * @return The {@link Overload} instance for the "endsWith" operation. + * @return The {@link CelFunctionBinding} instance for the "endsWith" operation. */ - private static Overload celEndsWith() { - return Overload.binary( - OVERLOAD_ENDS_WITH, - (lhs, rhs) -> { - TypeEnum lhsType = lhs.type().typeEnum(); - if (lhsType != rhs.type().typeEnum()) { - return Err.noSuchOverload(lhs, OVERLOAD_ENDS_WITH, rhs); + private static CelFunctionBinding celEndsWithBytes() { + return CelFunctionBinding.from( + "ends_with_bytes", + ByteString.class, + ByteString.class, + (receiver, param) -> { + if (receiver.size() < param.size()) { + return false; } - if (lhsType == TypeEnum.String) { - String receiver = (String) lhs.value(); - String param = (String) rhs.value(); - return Types.boolOf(receiver.endsWith(param)); - } - if (lhsType == TypeEnum.Bytes) { - byte[] receiver = (byte[]) lhs.value(); - byte[] param = (byte[]) rhs.value(); - if (receiver.length < param.length) { - return BoolT.False; - } - for (int i = 0; i < param.length; i++) { - if (param[param.length - i - 1] != receiver[receiver.length - i - 1]) { - return BoolT.False; - } + for (int i = 0; i < param.size(); i++) { + if (param.byteAt(param.size() - i - 1) != receiver.byteAt(receiver.size() - i - 1)) { + return false; } - return BoolT.True; } - return Err.noSuchOverload(lhs, OVERLOAD_ENDS_WITH, rhs); + return true; }); } /** * Creates a custom binary function overload for the "contains" operation. * - * @return The {@link Overload} instance for the "contains" operation. + * @return The {@link CelFunctionBinding} instance for the "contains" operation. */ - private static Overload celContains() { - return Overload.binary( - OVERLOAD_CONTAINS, - (lhs, rhs) -> { - TypeEnum lhsType = lhs.type().typeEnum(); - if (lhsType != rhs.type().typeEnum()) { - return Err.noSuchOverload(lhs, OVERLOAD_CONTAINS, rhs); - } - if (lhsType == TypeEnum.String) { - String receiver = lhs.value().toString(); - String param = rhs.value().toString(); - return Types.boolOf(receiver.contains(param)); - } - if (lhsType == TypeEnum.Bytes) { - byte[] receiver = (byte[]) lhs.value(); - byte[] param = (byte[]) rhs.value(); - return Types.boolOf(bytesContains(receiver, param)); - } - return Err.noSuchOverload(lhs, OVERLOAD_CONTAINS, rhs); + private static CelFunctionBinding celContainsBytes() { + return CelFunctionBinding.from( + "contains_bytes", + ByteString.class, + ByteString.class, + (receiver, param) -> { + return bytesContains(receiver.toByteArray(), param.toByteArray()); }); } @@ -264,200 +209,153 @@ static boolean bytesContains(byte[] arr, byte[] subArr) { /** * Creates a custom binary function overload for the "isHostname" operation. * - * @return The {@link Overload} instance for the "isHostname" operation. + * @return The {@link CelFunctionBinding} instance for the "isHostname" operation. */ - private static Overload celIsHostname() { - return Overload.unary( - OVERLOAD_IS_HOSTNAME, - value -> { - if (value.type().typeEnum() != TypeEnum.String) { - return Err.noSuchOverload(value, OVERLOAD_IS_HOSTNAME, null); - } - String host = (String) value.value(); - return Types.boolOf(isHostname(host)); - }); + private static CelFunctionBinding celIsHostname() { + return CelFunctionBinding.from("is_hostname", String.class, CustomOverload::isHostname); } /** * Creates a custom unary function overload for the "isEmail" operation. * - * @return The {@link Overload} instance for the "isEmail" operation. + * @return The {@link CelFunctionBinding} instance for the "isEmail" operation. */ - private static Overload celIsEmail() { - return Overload.unary( - OVERLOAD_IS_EMAIL, - value -> { - if (value.type().typeEnum() != TypeEnum.String) { - return Err.noSuchOverload(value, OVERLOAD_IS_EMAIL, null); - } - String addr = (String) value.value(); - return Types.boolOf(isEmail(addr)); - }); + private static CelFunctionBinding celIsEmail() { + return CelFunctionBinding.from("is_email", String.class, CustomOverload::isEmail); } /** * Creates a custom function overload for the "isIp" operation. * - * @return The {@link Overload} instance for the "isIp" operation. + * @return The {@link CelFunctionBinding} instance for the "isIp" operation. */ - private static Overload celIsIp() { - return Overload.overload( - OVERLOAD_IS_IP, - null, - value -> { - if (value.type().typeEnum() != TypeEnum.String) { - return Err.noSuchOverload(value, OVERLOAD_IS_IP, null); - } - String addr = (String) value.value(); - return Types.boolOf(isIp(addr, 0L)); - }, - (lhs, rhs) -> { - if (lhs.type().typeEnum() != TypeEnum.String || rhs.type().typeEnum() != TypeEnum.Int) { - return Err.noSuchOverload(lhs, OVERLOAD_IS_IP, rhs); - } - String address = (String) lhs.value(); - return Types.boolOf(isIp(address, rhs.intValue())); - }, - null); + private static CelFunctionBinding celIsIpUnary() { + return CelFunctionBinding.from("is_ip_unary", String.class, value -> isIp(value, 0L)); + } + + /** + * Creates a custom function overload for the "isIp" operation that also accepts a port. + * + * @return The {@link CelFunctionBinding} instance for the "isIp" operation. + */ + private static CelFunctionBinding celIsIp() { + return CelFunctionBinding.from("is_ip", String.class, Long.class, CustomOverload::isIp); } /** * Creates a custom function overload for the "isIpPrefix" operation. * - * @return The {@link Overload} instance for the "isIpPrefix" operation. + * @return The {@link CelFunctionBinding} instance for the "isIpPrefix" operation. */ - private static Overload celIsIpPrefix() { - return Overload.overload( - OVERLOAD_IS_IP_PREFIX, - null, - value -> { - if (value.type().typeEnum() != TypeEnum.String - && value.type().typeEnum() != TypeEnum.Bool) { - return Err.noSuchOverload(value, OVERLOAD_IS_IP_PREFIX, null); - } - String prefix = (String) value.value(); - return Types.boolOf(isIpPrefix(prefix, 0L, false)); - }, - (lhs, rhs) -> { - if (lhs.type().typeEnum() != TypeEnum.String - || (rhs.type().typeEnum() != TypeEnum.Int - && rhs.type().typeEnum() != TypeEnum.Bool)) { - return Err.noSuchOverload(lhs, OVERLOAD_IS_IP_PREFIX, rhs); - } - String prefix = (String) lhs.value(); - if (rhs.type().typeEnum() == TypeEnum.Int) { - return Types.boolOf(isIpPrefix(prefix, rhs.intValue(), false)); - } - return Types.boolOf(isIpPrefix(prefix, 0L, rhs.booleanValue())); - }, - (values) -> { - if (values.length != 3 - || values[0].type().typeEnum() != TypeEnum.String - || values[1].type().typeEnum() != TypeEnum.Int - || values[2].type().typeEnum() != TypeEnum.Bool) { - return Err.noSuchOverload(values[0], OVERLOAD_IS_IP_PREFIX, "", values); - } - String prefix = (String) values[0].value(); - return Types.boolOf(isIpPrefix(prefix, values[1].intValue(), values[2].booleanValue())); + private static CelFunctionBinding celIsIpPrefix() { + return CelFunctionBinding.from( + "is_ip_prefix", String.class, prefix -> isIpPrefix(prefix, 0L, false)); + } + + /** + * Creates a custom function overload for the "isIpPrefix" operation that accepts a version. + * + * @return The {@link CelFunctionBinding} instance for the "isIpPrefix" operation. + */ + private static CelFunctionBinding celIsIpPrefixInt() { + return CelFunctionBinding.from( + "is_ip_prefix_int", + String.class, + Long.class, + (prefix, version) -> { + return isIpPrefix(prefix, version, false); }); } /** - * Creates a custom unary function overload for the "isUri" operation. + * Creates a custom function overload for the "isIpPrefix" operation that accepts a strict flag. * - * @return The {@link Overload} instance for the "isUri" operation. + * @return The {@link CelFunctionBinding} instance for the "isIpPrefix" operation. */ - private static Overload celIsUri() { - return Overload.unary( - OVERLOAD_IS_URI, - value -> { - if (value.type().typeEnum() != TypeEnum.String) { - return Err.noSuchOverload(value, OVERLOAD_IS_URI, null); - } - String addr = (String) value.value(); - return Types.boolOf(isUri(addr)); + private static CelFunctionBinding celIsIpPrefixBool() { + return CelFunctionBinding.from( + "is_ip_prefix_bool", + String.class, + Boolean.class, + (prefix, strict) -> { + return isIpPrefix(prefix, 0L, strict); }); } + /** + * Creates a custom function overload for the "isIpPrefix" operation that accepts both version and + * strict flag. + * + * @return The {@link CelFunctionBinding} instance for the "isIpPrefix" operation. + */ + private static CelFunctionBinding celIsIpPrefixIntBool() { + return CelFunctionBinding.from( + "is_ip_prefix_int_bool", + Arrays.asList(String.class, Long.class, Boolean.class), + (args) -> isIpPrefix((String) args[0], (Long) args[1], (Boolean) args[2])); + } + + /** + * Creates a custom unary function overload for the "isUri" operation. + * + * @return The {@link CelFunctionBinding} instance for the "isUri" operation. + */ + private static CelFunctionBinding celIsUri() { + return CelFunctionBinding.from("is_uri", String.class, CustomOverload::isUri); + } + /** * Creates a custom unary function overload for the "isUriRef" operation. * - * @return The {@link Overload} instance for the "isUriRef" operation. + * @return The {@link CelFunctionBinding} instance for the "isUriRef" operation. */ - private static Overload celIsUriRef() { - return Overload.unary( - OVERLOAD_IS_URI_REF, - value -> { - if (value.type().typeEnum() != TypeEnum.String) { - return Err.noSuchOverload(value, OVERLOAD_IS_URI_REF, null); - } - String addr = (String) value.value(); - return Types.boolOf(isUriRef(addr)); - }); + private static CelFunctionBinding celIsUriRef() { + return CelFunctionBinding.from("is_uri_ref", String.class, CustomOverload::isUriRef); } /** * Creates a custom unary function overload for the "isNan" operation. * - * @return The {@link Overload} instance for the "isNan" operation. + * @return The {@link CelFunctionBinding} instance for the "isNan" operation. */ - private static Overload celIsNan() { - return Overload.unary( - OVERLOAD_IS_NAN, - value -> { - if (value.type().typeEnum() != TypeEnum.Double) { - return Err.noSuchOverload(value, OVERLOAD_IS_NAN, null); - } - Double doubleVal = (Double) value.value(); - return Types.boolOf(doubleVal.isNaN()); - }); + private static CelFunctionBinding celIsNan() { + return CelFunctionBinding.from("is_nan", Double.class, value -> Double.isNaN(value)); } /** * Creates a custom unary function overload for the "isInf" operation. * - * @return The {@link Overload} instance for the "isInf" operation. + * @return The {@link CelFunctionBinding} instance for the "isInf" operation. */ - private static Overload celIsInf() { - return Overload.overload( - OVERLOAD_IS_INF, - null, - value -> { - if (value.type().typeEnum() != TypeEnum.Double) { - return Err.noSuchOverload(value, OVERLOAD_IS_INF, null); - } - Double doubleVal = (Double) value.value(); - return Types.boolOf(doubleVal.isInfinite()); - }, - (lhs, rhs) -> { - if (lhs.type().typeEnum() != TypeEnum.Double || rhs.type().typeEnum() != TypeEnum.Int) { - return Err.noSuchOverload(lhs, OVERLOAD_IS_INF, rhs); - } - Double value = (Double) lhs.value(); - long sign = rhs.intValue(); + private static CelFunctionBinding celIsInfUnary() { + return CelFunctionBinding.from("is_inf_unary", Double.class, value -> value.isInfinite()); + } + + /** + * Creates a custom unary function overload for the "isInf" operation with sign option. + * + * @return The {@link CelFunctionBinding} instance for the "isInf" operation. + */ + private static CelFunctionBinding celIsInfBinary() { + return CelFunctionBinding.from( + "is_inf_binary", + Double.class, + Long.class, + (value, sign) -> { if (sign == 0) { - return Types.boolOf(value.isInfinite()); + return value.isInfinite(); } double expectedValue = (sign > 0) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; - return Types.boolOf(value == expectedValue); - }, - null); + return value == expectedValue; + }); } - private static Overload celIsHostAndPort() { - return Overload.overload( - OVERLOAD_IS_HOST_AND_PORT, - null, - null, - (lhs, rhs) -> { - if (lhs.type().typeEnum() != TypeEnum.String || rhs.type().typeEnum() != TypeEnum.Bool) { - return Err.noSuchOverload(lhs, OVERLOAD_IS_HOST_AND_PORT, rhs); - } - String value = (String) lhs.value(); - boolean portRequired = rhs.booleanValue(); - return Types.boolOf(isHostAndPort(value, portRequired)); - }, - null); + private static CelFunctionBinding celIsHostAndPort() { + return CelFunctionBinding.from( + "string_bool_is_host_and_port_bool", + String.class, + Boolean.class, + CustomOverload::isHostAndPort); } /** @@ -534,38 +432,23 @@ private static boolean isPort(String str) { /** * Determines if the input list contains unique values. If the list contains duplicate values, it - * returns {@link BoolT#False}. If the list contains unique values, it returns {@link BoolT#True}. + * returns {@code false}. If the list contains unique values, it returns {@code true}. * * @param list The input list to check for uniqueness. - * @return {@link BoolT#True} if the list contains unique scalar values, {@link BoolT#False} - * otherwise. + * @return {@code true} if the list contains unique scalar values, {@code false} otherwise. */ - private static Val uniqueList(Lister list) { - long size = list.size().intValue(); + private static boolean uniqueList(List list) throws CelEvaluationException { + long size = list.size(); if (size == 0) { - return BoolT.True; - } - Set exist = new HashSet<>((int) size); - Val firstVal = list.get(IntT.intOf(0)); - switch (firstVal.type().typeEnum()) { - case Bool: - case Int: - case Uint: - case Double: - case String: - case Bytes: - break; - default: - return Err.noSuchOverload(list, OVERLOAD_UNIQUE, null); + return true; } - exist.add(firstVal); - for (int i = 1; i < size; i++) { - Val val = list.get(IntT.intOf(i)); + Set exist = new HashSet<>((int) size); + for (Object val : list) { if (!exist.add(val)) { - return BoolT.False; + return false; } } - return BoolT.True; + return true; } /** @@ -616,7 +499,8 @@ private static boolean isHostname(String val) { for (String part : parts) { allDigits = true; - // if part is empty, longer than 63 chars, or starts/ends with '-', it is invalid + // if part is empty, longer than 63 chars, or starts/ends with '-', it is + // invalid int len = part.length(); if (len == 0 || len > 63 || part.startsWith("-") || part.endsWith("-")) { return false; diff --git a/src/main/java/build/buf/protovalidate/DescriptorMappings.java b/src/main/java/build/buf/protovalidate/DescriptorMappings.java index 24345811..dbbe42a6 100644 --- a/src/main/java/build/buf/protovalidate/DescriptorMappings.java +++ b/src/main/java/build/buf/protovalidate/DescriptorMappings.java @@ -15,14 +15,19 @@ package build.buf.protovalidate; import build.buf.validate.FieldRules; -import com.google.api.expr.v1alpha1.Type; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypes; +import dev.cel.common.types.ListType; +import dev.cel.common.types.MapType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.common.types.UnspecifiedType; import java.util.HashMap; import java.util.Map; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.checker.Decls; /** * DescriptorMappings provides mappings between protocol buffer descriptors and CEL declarations. @@ -131,11 +136,11 @@ public static FieldDescriptor expectedWrapperRules(String fqn) { * @param kind The protobuf field type. * @return The corresponding CEL type for the protobuf field. */ - public static Type protoKindToCELType(FieldDescriptor.Type kind) { + public static CelType protoKindToCELType(FieldDescriptor.Type kind) { switch (kind) { case FLOAT: case DOUBLE: - return Decls.newPrimitiveType(Type.PrimitiveType.DOUBLE); + return SimpleType.DOUBLE; case INT32: case INT64: case SINT32: @@ -143,25 +148,23 @@ public static Type protoKindToCELType(FieldDescriptor.Type kind) { case SFIXED32: case SFIXED64: case ENUM: - return Decls.newPrimitiveType(Type.PrimitiveType.INT64); + return SimpleType.INT; case UINT32: case UINT64: case FIXED32: case FIXED64: - return Decls.newPrimitiveType(Type.PrimitiveType.UINT64); + return SimpleType.UINT; case BOOL: - return Decls.newPrimitiveType(Type.PrimitiveType.BOOL); + return SimpleType.BOOL; case STRING: - return Decls.newPrimitiveType(Type.PrimitiveType.STRING); + return SimpleType.STRING; case BYTES: - return Decls.newPrimitiveType(Type.PrimitiveType.BYTES); + return SimpleType.BYTES; case MESSAGE: case GROUP: - return Type.newBuilder().setMessageType(kind.getJavaType().name()).build(); + return StructTypeReference.create(kind.getJavaType().name()); default: - return Type.newBuilder() - .setPrimitive(Type.PrimitiveType.PRIMITIVE_TYPE_UNSPECIFIED) - .build(); + return UnspecifiedType.create(); } } @@ -189,42 +192,20 @@ static FieldDescriptor getExpectedRuleDescriptor( * Resolves the CEL value type for the provided {@link FieldDescriptor}. If forItems is true, the * type for the repeated list items is returned instead of the list type itself. */ - static Type getCELType(FieldDescriptor fieldDescriptor, boolean forItems) { + static CelType getCELType(FieldDescriptor fieldDescriptor, boolean forItems) { if (!forItems) { if (fieldDescriptor.isMapField()) { - return Decls.newMapType( + return MapType.create( getCELType(fieldDescriptor.getMessageType().findFieldByNumber(1), true), getCELType(fieldDescriptor.getMessageType().findFieldByNumber(2), true)); } else if (fieldDescriptor.isRepeated()) { - return Decls.newListType(getCELType(fieldDescriptor, true)); + return ListType.create(getCELType(fieldDescriptor, true)); } } if (fieldDescriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { String fqn = fieldDescriptor.getMessageType().getFullName(); - switch (fqn) { - case "google.protobuf.Any": - return Decls.newWellKnownType(Type.WellKnownType.ANY); - case "google.protobuf.Duration": - return Decls.newWellKnownType(Type.WellKnownType.DURATION); - case "google.protobuf.Timestamp": - return Decls.newWellKnownType(Type.WellKnownType.TIMESTAMP); - case "google.protobuf.BytesValue": - return Decls.newWrapperType(Decls.Bytes); - case "google.protobuf.DoubleValue": - case "google.protobuf.FloatValue": - return Decls.newWrapperType(Decls.Double); - case "google.protobuf.Int32Value": - case "google.protobuf.Int64Value": - return Decls.newWrapperType(Decls.Int); - case "google.protobuf.StringValue": - return Decls.newWrapperType(Decls.String); - case "google.protobuf.UInt32Value": - case "google.protobuf.UInt64Value": - return Decls.newWrapperType(Decls.Uint); - default: - return Decls.newObjectType(fqn); - } + return CelTypes.getWellKnownCelType(fqn).orElse(StructTypeReference.create(fqn)); } return DescriptorMappings.protoKindToCELType(fieldDescriptor.getType()); } diff --git a/src/main/java/build/buf/protovalidate/EnumEvaluator.java b/src/main/java/build/buf/protovalidate/EnumEvaluator.java index e2484983..e9a0c506 100644 --- a/src/main/java/build/buf/protovalidate/EnumEvaluator.java +++ b/src/main/java/build/buf/protovalidate/EnumEvaluator.java @@ -32,7 +32,7 @@ class EnumEvaluator implements Evaluator { private final RuleViolationHelper helper; /** Captures all the defined values for this enum */ - private final Set values; + private final Set values; private static final Descriptors.FieldDescriptor DEFINED_ONLY_DESCRIPTOR = EnumRules.getDescriptor().findFieldByNumber(EnumRules.DEFINED_ONLY_FIELD_NUMBER); @@ -57,9 +57,7 @@ class EnumEvaluator implements Evaluator { this.values = Collections.emptySet(); } else { this.values = - valueDescriptors.stream() - .map(Descriptors.EnumValueDescriptor::getNumber) - .collect(Collectors.toSet()); + valueDescriptors.stream().map(it -> (long) it.getNumber()).collect(Collectors.toSet()); } } @@ -79,11 +77,11 @@ public boolean tautology() { @Override public List evaluate(Value val, boolean failFast) throws ExecutionException { - Descriptors.EnumValueDescriptor enumValue = val.value(Descriptors.EnumValueDescriptor.class); + Object enumValue = val.value(Object.class); if (enumValue == null) { return RuleViolation.NO_VIOLATIONS; } - if (!values.contains(enumValue.getNumber())) { + if (!values.contains(enumValue)) { return Collections.singletonList( RuleViolation.newBuilder() .addAllRulePathElements(helper.getRulePrefixElements()) diff --git a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java index f44d0603..96213ecf 100644 --- a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java +++ b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java @@ -22,7 +22,6 @@ import build.buf.validate.MessageRules; import build.buf.validate.OneofRules; import build.buf.validate.Rule; -import com.google.api.expr.v1alpha1.Decl; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors.Descriptor; @@ -30,18 +29,17 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; +import dev.cel.common.types.StructTypeReference; +import dev.cel.runtime.CelEvaluationException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.Env; -import org.projectnessie.cel.EnvOption; -import org.projectnessie.cel.checker.Decls; /** A build-through cache of message evaluators keyed off the provided descriptor. */ class EvaluatorBuilder { @@ -51,34 +49,34 @@ class EvaluatorBuilder { private volatile Map evaluatorCache = Collections.emptyMap(); - private final Env env; + private final Cel cel; private final boolean disableLazy; private final RuleCache rules; /** * Constructs a new {@link EvaluatorBuilder}. * - * @param env The CEL environment for evaluation. + * @param cel The CEL environment for evaluation. * @param config The configuration to use for the evaluation. */ - EvaluatorBuilder(Env env, Config config) { - this.env = env; + EvaluatorBuilder(Cel cel, Config config) { + this.cel = cel; this.disableLazy = false; - this.rules = new RuleCache(env, config); + this.rules = new RuleCache(cel, config); } /** * Constructs a new {@link EvaluatorBuilder}. * - * @param env The CEL environment for evaluation. + * @param cel The CEL environment for evaluation. * @param config The configuration to use for the evaluation. */ - EvaluatorBuilder(Env env, Config config, List descriptors, boolean disableLazy) + EvaluatorBuilder(Cel cel, Config config, List descriptors, boolean disableLazy) throws CompilationException { Objects.requireNonNull(descriptors, "descriptors must not be null"); - this.env = env; + this.cel = cel; this.disableLazy = disableLazy; - this.rules = new RuleCache(env, config); + this.rules = new RuleCache(cel, config); for (Descriptor descriptor : descriptors) { this.build(descriptor); @@ -119,7 +117,7 @@ private Evaluator build(Descriptor desc) throws CompilationException { } // Rebuild cache with this descriptor (and any of its dependencies). Map updatedCache = - new DescriptorCacheBuilder(env, rules, evaluatorCache).build(desc); + new DescriptorCacheBuilder(cel, rules, evaluatorCache).build(desc); evaluatorCache = updatedCache; eval = updatedCache.get(desc); if (eval == null) { @@ -132,13 +130,13 @@ private Evaluator build(Descriptor desc) throws CompilationException { private static class DescriptorCacheBuilder { private final RuleResolver resolver = new RuleResolver(); - private final Env env; + private final Cel cel; private final RuleCache ruleCache; private final HashMap cache; private DescriptorCacheBuilder( - Env env, RuleCache ruleCache, Map previousCache) { - this.env = Objects.requireNonNull(env, "env"); + Cel cel, RuleCache ruleCache, Map previousCache) { + this.cel = Objects.requireNonNull(cel, "cel"); this.ruleCache = Objects.requireNonNull(ruleCache, "ruleCache"); this.cache = new HashMap<>(previousCache); } @@ -186,31 +184,6 @@ private void buildMessage(Descriptor desc, MessageEvaluator msgEval) } } - private void collectDependencies(Set dependencyTypes, Descriptor desc) { - dependencyTypes.add(desc); - for (FieldDescriptor field : desc.getFields()) { - if (field.getJavaType() != FieldDescriptor.JavaType.MESSAGE) { - continue; - } - Descriptor submessageDesc = field.getMessageType(); - if (dependencyTypes.contains(submessageDesc)) { - continue; - } - collectDependencies(dependencyTypes, submessageDesc); - } - } - - private Message[] getTypesForMessage(Message message) { - Set dependencyTypes = new HashSet<>(); - collectDependencies(dependencyTypes, message.getDescriptorForType()); - Message[] dependencyTypeMessages = new Message[dependencyTypes.size()]; - int i = 0; - for (Descriptor dependencyType : dependencyTypes) { - dependencyTypeMessages[i++] = DynamicMessage.newBuilder(dependencyType).buildPartial(); - } - return dependencyTypeMessages; - } - private void processMessageExpressions( Descriptor desc, MessageRules msgRules, MessageEvaluator msgEval, DynamicMessage message) throws CompilationException { @@ -218,12 +191,12 @@ private void processMessageExpressions( if (celList.isEmpty()) { return; } - Env finalEnv = - env.extend( - EnvOption.types((Object[]) getTypesForMessage(message)), - EnvOption.declarations( - Decls.newVar(Variable.THIS_NAME, Decls.newObjectType(desc.getFullName())))); - List compiledPrograms = compileRules(celList, finalEnv, false); + Cel finalCel = + cel.toCelBuilder() + .addMessageTypes(message.getDescriptorForType()) + .addVar(Variable.THIS_NAME, StructTypeReference.create(desc.getFullName())) + .build(); + List compiledPrograms = compileRules(celList, finalCel, false); if (compiledPrograms.isEmpty()) { throw new CompilationException("compile returned null"); } @@ -312,14 +285,10 @@ private Object zeroValue(FieldDescriptor fieldDescriptor, boolean forItems) if (forItems && fieldDescriptor.isRepeated()) { switch (fieldDescriptor.getType().getJavaType()) { case INT: - zero = 0; - break; case LONG: zero = 0L; break; case FLOAT: - zero = 0F; - break; case DOUBLE: zero = 0D; break; @@ -333,7 +302,7 @@ private Object zeroValue(FieldDescriptor fieldDescriptor, boolean forItems) zero = ByteString.EMPTY; break; case ENUM: - zero = fieldDescriptor.getEnumType().getValues().get(0); + zero = (long) fieldDescriptor.getEnumType().getValues().get(0).getNumber(); break; case MESSAGE: zero = createMessageForType(fieldDescriptor.getMessageType()); @@ -346,7 +315,8 @@ private Object zeroValue(FieldDescriptor fieldDescriptor, boolean forItems) && !fieldDescriptor.isRepeated()) { zero = createMessageForType(fieldDescriptor.getMessageType()); } else { - zero = fieldDescriptor.getDefaultValue(); + zero = + ProtoAdapter.scalarToCel(fieldDescriptor.getType(), fieldDescriptor.getDefaultValue()); } return zero; } @@ -366,24 +336,17 @@ private void processFieldExpressions( if (rulesCelList.isEmpty()) { return; } - Decl celType = - Decls.newVar( + CelBuilder builder = cel.toCelBuilder(); + builder = + builder.addVar( Variable.THIS_NAME, DescriptorMappings.getCELType(fieldDescriptor, valueEvaluatorEval.hasNestedRule())); - List opts = new ArrayList(); - opts.add(EnvOption.declarations(celType)); if (fieldDescriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { - try { - DynamicMessage defaultInstance = - DynamicMessage.parseFrom(fieldDescriptor.getMessageType(), new byte[0]); - opts.add(EnvOption.types((Object[]) getTypesForMessage(defaultInstance))); - } catch (InvalidProtocolBufferException e) { - throw new CompilationException("field descriptor type is invalid " + e.getMessage(), e); - } + builder = builder.addMessageTypes(fieldDescriptor.getMessageType()); } - Env finalEnv = env.extend(opts.toArray(new EnvOption[opts.size()])); - List compiledPrograms = compileRules(rulesCelList, finalEnv, true); + Cel finalCel = builder.build(); + List compiledPrograms = compileRules(rulesCelList, finalCel, true); if (!compiledPrograms.isEmpty()) { valueEvaluatorEval.append(new CelPrograms(valueEvaluatorEval, compiledPrograms)); } @@ -526,13 +489,13 @@ private void processRepeatedRules( valueEvaluatorEval.append(listEval); } - private static List compileRules(List rules, Env env, boolean isField) + private static List compileRules(List rules, Cel cel, boolean isField) throws CompilationException { List expressions = Expression.fromRules(rules); List compiledPrograms = new ArrayList<>(); for (int i = 0; i < expressions.size(); i++) { Expression expression = expressions.get(i); - AstExpression astExpression = AstExpression.newAstExpression(env, expression); + AstExpression astExpression = AstExpression.newAstExpression(cel, expression); @Nullable FieldPath rulePath = null; if (isField) { rulePath = @@ -540,12 +503,17 @@ private static List compileRules(List rules, Env env, boo .addElements(CEL_FIELD_PATH_ELEMENT.toBuilder().setIndex(i)) .build(); } - compiledPrograms.add( - new CompiledProgram( - env.program(astExpression.ast), - astExpression.source, - rulePath, - new MessageValue(rules.get(i)))); + try { + compiledPrograms.add( + new CompiledProgram( + cel.createProgram(astExpression.ast), + astExpression.source, + rulePath, + new MessageValue(rules.get(i)), + null)); + } catch (CelEvaluationException e) { + throw new CompilationException("failed to evaluate rule " + rules.get(i).getId(), e); + } } return compiledPrograms; } diff --git a/src/main/java/build/buf/protovalidate/FieldEvaluator.java b/src/main/java/build/buf/protovalidate/FieldEvaluator.java index 5721da57..a8366447 100644 --- a/src/main/java/build/buf/protovalidate/FieldEvaluator.java +++ b/src/main/java/build/buf/protovalidate/FieldEvaluator.java @@ -135,7 +135,8 @@ public List evaluate(Value val, boolean failFast) return RuleViolation.NO_VIOLATIONS; } Object fieldValue = message.getField(descriptor); - if (this.shouldIgnoreDefault() && Objects.equals(zero, fieldValue)) { + if (this.shouldIgnoreDefault() + && Objects.equals(zero, ProtoAdapter.toCel(descriptor, fieldValue))) { return RuleViolation.NO_VIOLATIONS; } return valueEvaluator.evaluate(new ObjectValue(descriptor, fieldValue), failFast); diff --git a/src/main/java/build/buf/protovalidate/Format.java b/src/main/java/build/buf/protovalidate/Format.java index 9b1a8211..f3dd494f 100644 --- a/src/main/java/build/buf/protovalidate/Format.java +++ b/src/main/java/build/buf/protovalidate/Format.java @@ -16,38 +16,38 @@ import static java.time.format.DateTimeFormatter.ISO_INSTANT; +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.ByteString; import com.google.protobuf.Duration; +import com.google.protobuf.NullValue; import com.google.protobuf.Timestamp; +import dev.cel.common.types.TypeType; +import dev.cel.runtime.CelEvaluationException; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.text.DecimalFormat; import java.time.Instant; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.SortedMap; import java.util.TreeMap; import java.util.stream.Collectors; -import org.projectnessie.cel.common.types.Err.ErrException; -import org.projectnessie.cel.common.types.IntT; -import org.projectnessie.cel.common.types.IteratorT; -import org.projectnessie.cel.common.types.ListT; -import org.projectnessie.cel.common.types.MapT; -import org.projectnessie.cel.common.types.ref.TypeEnum; -import org.projectnessie.cel.common.types.ref.Val; /** String formatter for CEL evaluation. */ final class Format { /** - * Format the string with a {@link ListT}. + * Format the string with a {@link List}. * * @param fmtString the string to format. * @param list the arguments. * @return the formatted string. - * @throws ErrException If an error occurs formatting the string. + * @throws CelEvaluationException If an error occurs formatting the string. */ - static String format(String fmtString, ListT list) { + static String format(String fmtString, List list) throws CelEvaluationException { // StringBuilder to accumulate the formatted string StringBuilder builder = new StringBuilder(); int index = 0; @@ -67,7 +67,7 @@ static String format(String fmtString, ListT list) { continue; } if (index >= fmtString.length()) { - throw new ErrException("format: expected format specifier"); + throw new CelEvaluationException("format: expected format specifier"); } if (fmtString.charAt(index) == '%') { // Escaped '%', append '%' and move to the next character @@ -75,10 +75,10 @@ static String format(String fmtString, ListT list) { index++; continue; } - if (argIndex >= list.size().intValue()) { - throw new ErrException("index " + argIndex + " out of range"); + if (argIndex >= list.size()) { + throw new CelEvaluationException("index " + argIndex + " out of range"); } - Val arg = list.get(IntT.intOf(argIndex++)); + Object arg = list.get(argIndex++); c = fmtString.charAt(index++); int precision = 6; if (c == '.') { @@ -90,7 +90,7 @@ static String format(String fmtString, ListT list) { precision = precision * 10 + (fmtString.charAt(index++) - '0'); } if (index >= fmtString.length()) { - throw new ErrException("format: expected format specifier"); + throw new CelEvaluationException("format: expected format specifier"); } c = fmtString.charAt(index++); } @@ -122,7 +122,7 @@ static String format(String fmtString, ListT list) { builder.append(formatOctal(arg)); break; default: - throw new ErrException( + throw new CelEvaluationException( "could not parse formatting clause: unrecognized formatting clause \"" + c + "\""); } } @@ -134,46 +134,43 @@ static String format(String fmtString, ListT list) { * * @param val the value to format. */ - private static String formatString(Val val) { - TypeEnum type = val.type().typeEnum(); - switch (type) { - case Type: - case String: - return val.value().toString(); - case Bool: - return Boolean.toString(val.booleanValue()); - case Int: - case Uint: - Optional str = validateNumber(val); - if (str.isPresent()) { - return str.get(); - } - return val.value().toString(); - case Bytes: - String byteStr = new String((byte[]) val.value(), StandardCharsets.UTF_8); - // Collapse any contiguous placeholders into one - return byteStr.replaceAll("\\ufffd+", "\ufffd"); - case Double: - Optional result = validateNumber(val); - if (result.isPresent()) { - return result.get(); - } - return formatDecimal(val); - case Duration: - return formatDuration(val); - case Timestamp: - return formatTimestamp(val); - case List: - return formatList((ListT) val); - case Map: - return formatMap((MapT) val); - case Null: - return "null"; - default: - throw new ErrException( - "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given " - + val.type()); + private static String formatString(Object val) throws CelEvaluationException { + if (val instanceof String) { + return (String) val; + } else if (val instanceof TypeType) { + return ((TypeType) val).containingTypeName(); + } else if (val instanceof Boolean) { + return Boolean.toString((Boolean) val); + } else if (val instanceof Long || val instanceof UnsignedLong) { + Optional str = validateNumber(val); + if (str.isPresent()) { + return str.get(); + } + return val.toString(); + } else if (val instanceof ByteString) { + String byteStr = ((ByteString) val).toStringUtf8(); + // Collapse any contiguous placeholders into one + return byteStr.replaceAll("\\ufffd+", "\ufffd"); + } else if (val instanceof Double) { + Optional result = validateNumber(val); + if (result.isPresent()) { + return result.get(); + } + return formatDecimal(val); + } else if (val instanceof Duration) { + return formatDuration((Duration) val); + } else if (val instanceof Timestamp) { + return formatTimestamp((Timestamp) val); + } else if (val instanceof List) { + return formatList((List) val); + } else if (val instanceof Map) { + return formatMap((Map) val); + } else if (val == null || val instanceof NullValue) { + return "null"; } + throw new CelEvaluationException( + "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given " + + val.getClass()); } /** @@ -181,15 +178,15 @@ private static String formatString(Val val) { * * @param val the value to format. */ - private static String formatList(ListT val) { + private static String formatList(List val) throws CelEvaluationException { StringBuilder builder = new StringBuilder(); builder.append('['); - IteratorT iter = val.iterator(); - while (iter.hasNext().booleanValue()) { - Val v = iter.next(); + Iterator iter = val.iterator(); + while (iter.hasNext()) { + Object v = iter.next(); builder.append(formatString(v)); - if (iter.hasNext().booleanValue()) { + if (iter.hasNext()) { builder.append(", "); } } @@ -197,18 +194,14 @@ private static String formatList(ListT val) { return builder.toString(); } - private static String formatMap(MapT val) { + private static String formatMap(Map val) throws CelEvaluationException { StringBuilder builder = new StringBuilder(); builder.append('{'); SortedMap sorted = new TreeMap<>(); - IteratorT iter = val.iterator(); - while (iter.hasNext().booleanValue()) { - Val key = iter.next(); - String mapKey = formatString(key); - String mapVal = formatString(val.find(key)); - sorted.put(mapKey, mapVal); + for (Entry entry : val.entrySet()) { + sorted.put(formatString(entry.getKey()), formatString(entry.getValue())); } String result = @@ -224,22 +217,19 @@ private static String formatMap(MapT val) { /** * Formats a timestamp value. * - * @param val the value to format. + * @param timestamp the value to format. */ - private static String formatTimestamp(Val val) { - Timestamp timestamp = val.convertToNative(Timestamp.class); - Instant instant = Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()); - return ISO_INSTANT.format(instant); + private static String formatTimestamp(Timestamp timestamp) { + return ISO_INSTANT.format(Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos())); } /** * Formats a duration value. * - * @param val the value to format. + * @param duration the value to format. */ - private static String formatDuration(Val val) { + private static String formatDuration(Duration duration) { StringBuilder builder = new StringBuilder(); - Duration duration = val.convertToNative(Duration.class); double totalSeconds = duration.getSeconds() + (duration.getNanos() / 1_000_000_000.0); @@ -255,23 +245,24 @@ private static String formatDuration(Val val) { * * @param val the value to format. */ - private static String formatHex(Val val) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Int || type == TypeEnum.Uint) { - return Long.toHexString(val.intValue()); - } else if (type == TypeEnum.Bytes) { + private static String formatHex(Object val) throws CelEvaluationException { + if (val instanceof Long) { + return Long.toHexString((Long) val); + } else if (val instanceof UnsignedLong) { + return Long.toHexString(((UnsignedLong) val).longValue()); + } else if (val instanceof ByteString) { StringBuilder hexString = new StringBuilder(); - for (byte b : (byte[]) val.value()) { + for (byte b : (ByteString) val) { hexString.append(String.format("%02x", b)); } return hexString.toString(); - } else if (type == TypeEnum.String) { - String arg = val.value().toString(); + } else if (val instanceof String) { + String arg = (String) val; return String.format("%x", new BigInteger(1, arg.getBytes(StandardCharsets.UTF_8))); } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given " - + val.type()); + + val.getClass()); } } @@ -280,65 +271,64 @@ private static String formatHex(Val val) { * * @param val the value to format. */ - private static String formatDecimal(Val val) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Int || type == TypeEnum.Uint || type == TypeEnum.Double) { + private static String formatDecimal(Object val) throws CelEvaluationException { + if (val instanceof Long || val instanceof UnsignedLong || val instanceof Double) { Optional str = validateNumber(val); if (str.isPresent()) { return str.get(); } DecimalFormat formatter = new DecimalFormat("0.#########"); - return formatter.format(val.value()); + return formatter.format(val); } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: decimal clause can only be used on integers, was given " - + val.type()); + + val.getClass()); } } - private static String formatOctal(Val val) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Int || type == TypeEnum.Uint) { - return Long.toOctalString(Long.valueOf(val.intValue())); + private static String formatOctal(Object val) throws CelEvaluationException { + if (val instanceof Long) { + return Long.toOctalString((Long) val); + } else if (val instanceof UnsignedLong) { + return Long.toOctalString(((UnsignedLong) val).longValue()); } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: octal clause can only be used on integers, was given " - + val.type()); + + val.getClass()); } } - private static String formatBinary(Val val) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Int || type == TypeEnum.Uint) { - return Long.toBinaryString(Long.valueOf(val.intValue())); - } else if (type == TypeEnum.Bool) { - return val.booleanValue() ? "1" : "0"; + private static String formatBinary(Object val) throws CelEvaluationException { + if (val instanceof Long) { + return Long.toBinaryString((Long) val); + } else if (val instanceof UnsignedLong) { + return Long.toBinaryString(((UnsignedLong) val).longValue()); + } else if (val instanceof Boolean) { + return Boolean.TRUE.equals(val) ? "1" : "0"; } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: only integers and bools can be formatted as binary, was given " - + val.type()); + + val.getClass()); } } - private static String formatExponential(Val val, int precision) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Double) { + private static String formatExponential(Object val, int precision) throws CelEvaluationException { + if (val instanceof Double) { Optional str = validateNumber(val); if (str.isPresent()) { return str.get(); } String pattern = "%." + precision + "e"; - return String.format(pattern, val.doubleValue()); + return String.format(pattern, (Double) val); } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: scientific clause can only be used on doubles, was given " - + val.type()); + + val.getClass()); } } - private static String formatFloat(Val val, int precision) { - TypeEnum type = val.type().typeEnum(); - if (type == TypeEnum.Double) { + private static String formatFloat(Object val, int precision) throws CelEvaluationException { + if (val instanceof Double) { Optional str = validateNumber(val); if (str.isPresent()) { return str.get(); @@ -352,19 +342,21 @@ private static String formatFloat(Val val, int precision) { pattern.append("########"); } DecimalFormat formatter = new DecimalFormat(pattern.toString()); - return formatter.format(val.value()); + return formatter.format(val); } else { - throw new ErrException( + throw new CelEvaluationException( "error during formatting: fixed-point clause can only be used on doubles, was given " - + val.type()); + + val.getClass()); } } - private static Optional validateNumber(Val val) { - if (val.doubleValue() == Double.POSITIVE_INFINITY) { - return Optional.of("Infinity"); - } else if (val.doubleValue() == Double.NEGATIVE_INFINITY) { - return Optional.of("-Infinity"); + private static Optional validateNumber(Object val) { + if (val instanceof Double) { + if ((Double) val == Double.POSITIVE_INFINITY) { + return Optional.of("Infinity"); + } else if ((Double) val == Double.NEGATIVE_INFINITY) { + return Optional.of("-Infinity"); + } } return Optional.empty(); } diff --git a/src/main/java/build/buf/protovalidate/ListElementValue.java b/src/main/java/build/buf/protovalidate/ListElementValue.java new file mode 100644 index 00000000..d76754ed --- /dev/null +++ b/src/main/java/build/buf/protovalidate/ListElementValue.java @@ -0,0 +1,73 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.protovalidate; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.jspecify.annotations.Nullable; + +/** + * The {@link Value} type that contains a field descriptor for repeated field and the value of an + * element. + */ +final class ListElementValue implements Value { + /** Object type since the object type is inferred from the field descriptor. */ + private final Object value; + + /** + * {@link com.google.protobuf.Descriptors.FieldDescriptor} is the field descriptor for the value. + */ + private final Descriptors.FieldDescriptor fieldDescriptor; + + ListElementValue(Descriptors.FieldDescriptor fieldDescriptor, Object value) { + this.value = value; + this.fieldDescriptor = fieldDescriptor; + } + + @Override + public Descriptors.@Nullable FieldDescriptor fieldDescriptor() { + return fieldDescriptor; + } + + @Override + public @Nullable Message messageValue() { + if (fieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.MESSAGE) { + return (Message) value; + } + return null; + } + + @Override + public T value(Class clazz) { + Descriptors.FieldDescriptor.Type type = fieldDescriptor.getType(); + if (type == Descriptors.FieldDescriptor.Type.MESSAGE) { + return clazz.cast(value); + } + return clazz.cast(ProtoAdapter.scalarToCel(type, value)); + } + + @Override + public List repeatedValue() { + return Collections.emptyList(); + } + + @Override + public Map mapValue() { + return Collections.emptyMap(); + } +} diff --git a/src/main/java/build/buf/protovalidate/NowVariable.java b/src/main/java/build/buf/protovalidate/NowVariable.java index 97953460..8330d0c5 100644 --- a/src/main/java/build/buf/protovalidate/NowVariable.java +++ b/src/main/java/build/buf/protovalidate/NowVariable.java @@ -14,41 +14,39 @@ package build.buf.protovalidate; +import com.google.protobuf.Timestamp; +import dev.cel.runtime.CelVariableResolver; import java.time.Instant; +import java.util.Optional; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.common.types.TimestampT; -import org.projectnessie.cel.interpreter.Activation; -import org.projectnessie.cel.interpreter.ResolvedValue; /** - * {@link NowVariable} implements {@link Activation}, providing a lazily produced timestamp for - * accessing the variable `now` that's constant within an evaluation. + * {@link NowVariable} implements {@link CelVariableResolver}, providing a lazily produced timestamp + * for accessing the variable `now` that's constant within an evaluation. */ -class NowVariable implements Activation { +class NowVariable implements CelVariableResolver { /** The name of the 'now' variable. */ - private static final String NOW_NAME = "now"; + public static final String NOW_NAME = "now"; /** The resolved value of the 'now' variable. */ - @Nullable private ResolvedValue resolvedValue; + @Nullable private Timestamp now; /** Creates an instance of a "now" variable. */ public NowVariable() {} @Override - public ResolvedValue resolveName(String name) { + public Optional find(String name) { if (!name.equals(NOW_NAME)) { - return ResolvedValue.ABSENT; - } else if (resolvedValue != null) { - return resolvedValue; + return Optional.empty(); } - Instant instant = Instant.now(); // UTC. - TimestampT value = TimestampT.timestampOf(instant); - resolvedValue = ResolvedValue.resolvedValue(value); - return resolvedValue; - } - - @Override - public Activation parent() { - return Activation.emptyActivation(); + if (this.now == null) { + Instant nowInstant = Instant.now(); + now = + Timestamp.newBuilder() + .setSeconds(nowInstant.getEpochSecond()) + .setNanos(nowInstant.getNano()) + .build(); + } + return Optional.of(this.now); } } diff --git a/src/main/java/build/buf/protovalidate/ObjectValue.java b/src/main/java/build/buf/protovalidate/ObjectValue.java index af37251f..07829d34 100644 --- a/src/main/java/build/buf/protovalidate/ObjectValue.java +++ b/src/main/java/build/buf/protovalidate/ObjectValue.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.Map; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.common.ULong; /** The {@link Value} type that contains a field descriptor and its value. */ final class ObjectValue implements Value { @@ -63,25 +62,7 @@ public Message messageValue() { @Override public T value(Class clazz) { - Descriptors.FieldDescriptor.Type type = fieldDescriptor.getType(); - if (fieldDescriptor.isMapField()) { - return clazz.cast(mapValueAsObject()); - } - if (!fieldDescriptor.isRepeated() - && (type == Descriptors.FieldDescriptor.Type.UINT32 - || type == Descriptors.FieldDescriptor.Type.UINT64 - || type == Descriptors.FieldDescriptor.Type.FIXED32 - || type == Descriptors.FieldDescriptor.Type.FIXED64)) { - /* - * Java does not have native support for unsigned int/long or uint32/uint64 types. - * To work with CEL's uint type in Java, special handling is required. - * - * When using uint32/uint64 in your protobuf objects or CEL expressions in Java, - * wrap them with the org.projectnessie.cel.common.ULong type. - */ - return clazz.cast(ULong.valueOf(((Number) value).longValue())); - } - return clazz.cast(value); + return clazz.cast(ProtoAdapter.toCel(fieldDescriptor, value)); } @Override @@ -90,44 +71,12 @@ public List repeatedValue() { if (fieldDescriptor.isRepeated()) { List list = (List) value; for (Object o : list) { - out.add(new ObjectValue(fieldDescriptor, o)); + out.add(new ListElementValue(fieldDescriptor, o)); } } return out; } - // TODO - This should be refactored at some point. - // - // This is essentially the same functionality as `mapValue` except that it - // returns a Map of Objects rather than a Map of protovalidate-java Values. - // It is used for binding to a CEL variable (i.e. `this`). - // Trying to bind a Map of Values to a CEL variable does not work because - // CEL-Java doesn't know how to interpret that proprietary Value object. - // - // Ideally, we should be using CEL-Java's org.projectnessie.cel.common.types.ref.Val - // type instead of our own custom Value abstraction. However, since we are evaluating - // Java CEL implementations, we should probably wait until that decision is made before - // making such a large refactor. This should suffice as a stopgap until then. - private Map mapValueAsObject() { - List input = - value instanceof List - ? (List) value - : Collections.singletonList((AbstractMessage) value); - - Descriptors.FieldDescriptor keyDesc = fieldDescriptor.getMessageType().findFieldByNumber(1); - Descriptors.FieldDescriptor valDesc = fieldDescriptor.getMessageType().findFieldByNumber(2); - Map out = new HashMap<>(input.size()); - - for (AbstractMessage entry : input) { - Object keyValue = entry.getField(keyDesc); - Object valValue = entry.getField(valDesc); - - out.put(keyValue, valValue); - } - - return out; - } - @Override public Map mapValue() { List input = diff --git a/src/main/java/build/buf/protovalidate/ProtoAdapter.java b/src/main/java/build/buf/protovalidate/ProtoAdapter.java new file mode 100644 index 00000000..0c20ce79 --- /dev/null +++ b/src/main/java/build/buf/protovalidate/ProtoAdapter.java @@ -0,0 +1,90 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.protovalidate; + +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Descriptors; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * CEL supports protobuf natively but when we pass it field values (like scalars, repeated, and + * maps) it has no way to treat them like a proto message field. This class has methods to convert + * to a cel values. + */ +final class ProtoAdapter { + /** Converts a protobuf field value to CEL compatible value. */ + public static Object toCel(Descriptors.FieldDescriptor fieldDescriptor, Object value) { + Descriptors.FieldDescriptor.Type type = fieldDescriptor.getType(); + if (fieldDescriptor.isMapField()) { + List input = + value instanceof List + ? (List) value + : Collections.singletonList((AbstractMessage) value); + Descriptors.FieldDescriptor keyDesc = fieldDescriptor.getMessageType().findFieldByNumber(1); + Descriptors.FieldDescriptor valDesc = fieldDescriptor.getMessageType().findFieldByNumber(2); + Map out = new HashMap<>(input.size()); + + for (AbstractMessage entry : input) { + Object keyValue = entry.getField(keyDesc); + Object valValue = entry.getField(valDesc); + out.put(toCel(keyDesc, keyValue), toCel(valDesc, valValue)); + } + return out; + } + // Cel understands protobuf message so we return as is (even if it is repeated). + if (type == Descriptors.FieldDescriptor.Type.MESSAGE) { + return value; + } + if (fieldDescriptor.isRepeated()) { + List out = new ArrayList<>(); + List list = (List) value; + for (Object element : list) { + out.add(scalarToCel(type, element)); + } + return out; + } + return scalarToCel(type, value); + } + + /** Converts a scalar type to cel value. */ + public static Object scalarToCel(Descriptors.FieldDescriptor.Type type, Object value) { + switch (type) { + case ENUM: + if (value instanceof Descriptors.EnumValueDescriptor) { + return (long) ((Descriptors.EnumValueDescriptor) value).getNumber(); + } + return value; + case FLOAT: + return Double.valueOf((Float) value); + case INT32: + case SINT32: + case SFIXED32: + return Long.valueOf((Integer) value); + case FIXED32: + case UINT32: + return UnsignedLong.fromLongBits(Long.valueOf((Integer) value)); + case UINT64: + case FIXED64: + return UnsignedLong.fromLongBits((Long) value); + default: + return value; + } + } +} diff --git a/src/main/java/build/buf/protovalidate/RuleCache.java b/src/main/java/build/buf/protovalidate/RuleCache.java index 5ea1af99..9618b473 100644 --- a/src/main/java/build/buf/protovalidate/RuleCache.java +++ b/src/main/java/build/buf/protovalidate/RuleCache.java @@ -27,21 +27,15 @@ import com.google.protobuf.Message; import com.google.protobuf.MessageLite; import com.google.protobuf.TypeRegistry; +import dev.cel.bundle.Cel; +import dev.cel.common.types.StructTypeReference; +import dev.cel.runtime.CelEvaluationException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.Ast; -import org.projectnessie.cel.Env; -import org.projectnessie.cel.EnvOption; -import org.projectnessie.cel.EvalOption; -import org.projectnessie.cel.Program; -import org.projectnessie.cel.ProgramOption; -import org.projectnessie.cel.checker.Decls; -import org.projectnessie.cel.common.types.ref.Val; -import org.projectnessie.cel.interpreter.Activation; /** A build-through cache for computed standard rules. */ class RuleCache { @@ -63,14 +57,6 @@ public CelRule(AstExpression astExpression, FieldDescriptor field, FieldPath rul EXTENSION_REGISTRY.add(ValidateProto.predefined); } - /** Partial eval options for evaluating the rule's expression. */ - private static final ProgramOption PARTIAL_EVAL_OPTIONS = - ProgramOption.evalOptions( - EvalOption.OptTrackState, - EvalOption.OptExhaustiveEval, - EvalOption.OptOptimize, - EvalOption.OptPartialEval); - /** * Concurrent map for caching {@link FieldDescriptor} and their associated List of {@link * AstExpression}. @@ -79,7 +65,7 @@ public CelRule(AstExpression astExpression, FieldDescriptor field, FieldPath rul new ConcurrentHashMap<>(); /** The environment to use for evaluation. */ - private final Env env; + private final Cel cel; /** Registry used to resolve dynamic messages. */ private final TypeRegistry typeRegistry; @@ -94,11 +80,11 @@ public CelRule(AstExpression astExpression, FieldDescriptor field, FieldPath rul * Constructs a new build-through cache for the standard rules, with a provided registry to * resolve dynamic extensions. * - * @param env The CEL environment for evaluation. + * @param cel The CEL environment for evaluation. * @param config The configuration to use for the rule cache. */ - public RuleCache(Env env, Config config) { - this.env = env; + public RuleCache(Cel cel, Config config) { + this.cel = cel; this.typeRegistry = config.getTypeRegistry(); this.extensionRegistry = config.getExtensionRegistry(); this.allowUnknownFields = config.isAllowingUnknownFields(); @@ -133,37 +119,19 @@ public List compile( } List programs = new ArrayList<>(); for (CelRule rule : completeProgramList) { - Env ruleEnv = getRuleEnv(fieldDescriptor, message, rule.field, forItems); - Variable ruleVar = Variable.newRuleVariable(message, message.getField(rule.field)); - ProgramOption globals = ProgramOption.globals(ruleVar); - Value ruleValue = new ObjectValue(rule.field, message.getField(rule.field)); + Cel ruleCel = getRuleCel(fieldDescriptor, message, rule.field, forItems); try { - Program program = ruleEnv.program(rule.astExpression.ast, globals, PARTIAL_EVAL_OPTIONS); - Program.EvalResult evalResult = program.eval(Activation.emptyActivation()); - Val value = evalResult.getVal(); - if (value != null) { - Object val = value.value(); - if (val instanceof Boolean && value.booleanValue()) { - continue; - } - if (val instanceof String && val.equals("")) { - continue; - } - } - Ast residual = ruleEnv.residualAst(rule.astExpression.ast, evalResult.getEvalDetails()); - programs.add( - new CompiledProgram( - ruleEnv.program(residual, globals), - rule.astExpression.source, - rule.rulePath, - ruleValue)); - } catch (Exception e) { programs.add( new CompiledProgram( - ruleEnv.program(rule.astExpression.ast, globals), + ruleCel.createProgram(rule.astExpression.ast), rule.astExpression.source, rule.rulePath, - ruleValue)); + new ObjectValue(rule.field, message.getField(rule.field)), + Variable.newRuleVariable( + message, ProtoAdapter.toCel(rule.field, message.getField(rule.field))))); + } catch (CelEvaluationException e) { + throw new CompilationException( + "failed to evaluate rule " + rule.astExpression.source.id, e); } } return Collections.unmodifiableList(programs); @@ -184,7 +152,7 @@ public List compile( if (rules == null) return null; List expressions = Expression.fromRules(rules.getCelList()); celRules = new ArrayList<>(expressions.size()); - Env ruleEnv = getRuleEnv(fieldDescriptor, message, ruleFieldDesc, forItems); + Cel ruleCel = getRuleCel(fieldDescriptor, message, ruleFieldDesc, forItems); for (Expression expression : expressions) { FieldPath rulePath = FieldPath.newBuilder() @@ -193,7 +161,7 @@ public List compile( .build(); celRules.add( new CelRule( - AstExpression.newAstExpression(ruleEnv, expression), ruleFieldDesc, rulePath)); + AstExpression.newAstExpression(ruleCel, expression), ruleFieldDesc, rulePath)); } descriptorMap.put(ruleFieldDesc, celRules); return celRules; @@ -243,20 +211,19 @@ public List compile( * @param forItems Whether the field is a list type or not. * @return An environment with requisite declarations and types added. */ - private Env getRuleEnv( + private Cel getRuleCel( FieldDescriptor fieldDescriptor, Message ruleMessage, FieldDescriptor ruleFieldDesc, boolean forItems) { - return env.extend( - EnvOption.types(ruleMessage.getDefaultInstanceForType()), - EnvOption.declarations( - Decls.newVar( - Variable.THIS_NAME, DescriptorMappings.getCELType(fieldDescriptor, forItems)), - Decls.newVar( - Variable.RULES_NAME, - Decls.newObjectType(ruleMessage.getDescriptorForType().getFullName())), - Decls.newVar(Variable.RULE_NAME, DescriptorMappings.getCELType(ruleFieldDesc, false)))); + return cel.toCelBuilder() + .addMessageTypes(ruleMessage.getDescriptorForType()) + .addVar(Variable.THIS_NAME, DescriptorMappings.getCELType(fieldDescriptor, forItems)) + .addVar( + Variable.RULES_NAME, + StructTypeReference.create(ruleMessage.getDescriptorForType().getFullName())) + .addVar(Variable.RULE_NAME, DescriptorMappings.getCELType(ruleFieldDesc, false)) + .build(); } private static class ResolvedRule { diff --git a/src/main/java/build/buf/protovalidate/ValidateLibrary.java b/src/main/java/build/buf/protovalidate/ValidateLibrary.java index f28fe152..71e66706 100644 --- a/src/main/java/build/buf/protovalidate/ValidateLibrary.java +++ b/src/main/java/build/buf/protovalidate/ValidateLibrary.java @@ -14,43 +14,46 @@ package build.buf.protovalidate; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import org.projectnessie.cel.EnvOption; -import org.projectnessie.cel.EvalOption; -import org.projectnessie.cel.Library; -import org.projectnessie.cel.ProgramOption; +import dev.cel.checker.CelCheckerBuilder; +import dev.cel.common.CelVarDecl; +import dev.cel.common.types.SimpleType; +import dev.cel.compiler.CelCompilerLibrary; +import dev.cel.parser.CelParserBuilder; +import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelRuntimeBuilder; +import dev.cel.runtime.CelRuntimeLibrary; /** - * Custom {@link Library} for CEL. Provides all the custom extension function definitions and - * overloads. + * Custom {@link CelCompilerLibrary} and {@link CelRuntimeLibrary}. Provides all the custom + * extension function definitions and overloads. */ -class ValidateLibrary implements Library { +class ValidateLibrary implements CelCompilerLibrary, CelRuntimeLibrary { /** Creates a ValidateLibrary with all custom declarations and overloads. */ public ValidateLibrary() {} - /** - * Returns the compile options for the CEL environment. - * - * @return the compile options. - */ @Override - public List getCompileOptions() { - return Collections.singletonList(EnvOption.declarations(CustomDeclarations.create())); + public void setParserOptions(CelParserBuilder parserBuilder) { + parserBuilder.setStandardMacros( + CelStandardMacro.ALL, + CelStandardMacro.EXISTS, + CelStandardMacro.EXISTS_ONE, + CelStandardMacro.FILTER, + CelStandardMacro.HAS, + CelStandardMacro.MAP, + CelStandardMacro.MAP_FILTER); } - /** - * Returns the program options for the CEL program. - * - * @return the program options. - */ @Override - public List getProgramOptions() { - return Arrays.asList( - ProgramOption.evalOptions(EvalOption.OptOptimize), - ProgramOption.globals(new NowVariable()), - ProgramOption.functions(CustomOverload.create())); + public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { + checkerBuilder + .addVarDeclarations( + CelVarDecl.newVarDeclaration(NowVariable.NOW_NAME, SimpleType.TIMESTAMP)) + .addFunctionDeclarations(CustomDeclarations.create()); + } + + @Override + public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { + runtimeBuilder.addFunctionBindings(CustomOverload.create()); } } diff --git a/src/main/java/build/buf/protovalidate/ValidatorImpl.java b/src/main/java/build/buf/protovalidate/ValidatorImpl.java index d40304a9..6ed84ad1 100644 --- a/src/main/java/build/buf/protovalidate/ValidatorImpl.java +++ b/src/main/java/build/buf/protovalidate/ValidatorImpl.java @@ -18,10 +18,10 @@ import build.buf.protovalidate.exceptions.ValidationException; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Message; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; import java.util.ArrayList; import java.util.List; -import org.projectnessie.cel.Env; -import org.projectnessie.cel.Library; class ValidatorImpl implements Validator { /** evaluatorBuilder is the builder used to construct the evaluator for a given message. */ @@ -34,15 +34,25 @@ class ValidatorImpl implements Validator { private final boolean failFast; ValidatorImpl(Config config) { - Env env = Env.newEnv(Library.Lib(new ValidateLibrary())); - this.evaluatorBuilder = new EvaluatorBuilder(env, config); + ValidateLibrary validateLibrary = new ValidateLibrary(); + Cel cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(validateLibrary) + .addRuntimeLibraries(validateLibrary) + .build(); + this.evaluatorBuilder = new EvaluatorBuilder(cel, config); this.failFast = config.isFailFast(); } ValidatorImpl(Config config, List descriptors, boolean disableLazy) throws CompilationException { - Env env = Env.newEnv(Library.Lib(new ValidateLibrary())); - this.evaluatorBuilder = new EvaluatorBuilder(env, config, descriptors, disableLazy); + ValidateLibrary validateLibrary = new ValidateLibrary(); + Cel cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(validateLibrary) + .addRuntimeLibraries(validateLibrary) + .build(); + this.evaluatorBuilder = new EvaluatorBuilder(cel, config, descriptors, disableLazy); this.failFast = config.isFailFast(); } diff --git a/src/main/java/build/buf/protovalidate/Variable.java b/src/main/java/build/buf/protovalidate/Variable.java index 7f9830e3..8e71a46a 100644 --- a/src/main/java/build/buf/protovalidate/Variable.java +++ b/src/main/java/build/buf/protovalidate/Variable.java @@ -14,15 +14,15 @@ package build.buf.protovalidate; +import dev.cel.runtime.CelVariableResolver; +import java.util.Optional; import org.jspecify.annotations.Nullable; -import org.projectnessie.cel.interpreter.Activation; -import org.projectnessie.cel.interpreter.ResolvedValue; /** - * {@link Variable} implements {@link org.projectnessie.cel.interpreter.Activation}, providing a - * lightweight named variable to cel.Program executions. + * {@link Variable} implements {@link CelVariableResolver}, providing a lightweight named variable + * to cel.Program executions. */ -class Variable implements Activation { +class Variable implements CelVariableResolver { /** The {@value} variable in CEL. */ public static final String THIS_NAME = "this"; @@ -32,8 +32,8 @@ class Variable implements Activation { /** The {@value} variable in CEL. */ public static final String RULE_NAME = "rule"; - /** The parent activation */ - private final Activation next; + /** The {@value} variable in CEL. */ + public static final String NOW_NAME = "now"; /** The variable's name */ private final String name; @@ -42,8 +42,7 @@ class Variable implements Activation { @Nullable private final Object val; /** Creates a variable with the given name and value. */ - private Variable(Activation activation, String name, @Nullable Object val) { - this.next = activation; + private Variable(String name, @Nullable Object val) { this.name = name; this.val = val; } @@ -54,8 +53,9 @@ private Variable(Activation activation, String name, @Nullable Object val) { * @param val the value. * @return {@link Variable}. */ - public static Variable newThisVariable(@Nullable Object val) { - return new Variable(new NowVariable(), THIS_NAME, val); + public static CelVariableResolver newThisVariable(@Nullable Object val) { + return CelVariableResolver.hierarchicalVariableResolver( + new NowVariable(), new Variable(THIS_NAME, val)); } /** @@ -64,8 +64,8 @@ public static Variable newThisVariable(@Nullable Object val) { * @param val the value. * @return {@link Variable}. */ - public static Variable newRulesVariable(Object val) { - return new Variable(Activation.emptyActivation(), RULES_NAME, val); + public static CelVariableResolver newRulesVariable(Object val) { + return new Variable(RULES_NAME, val); } /** @@ -75,22 +75,16 @@ public static Variable newRulesVariable(Object val) { * @param val the value of the "rule" variable. * @return {@link Variable}. */ - public static Variable newRuleVariable(Object rules, Object val) { - return new Variable(newRulesVariable(rules), RULE_NAME, val); + public static CelVariableResolver newRuleVariable(Object rules, Object val) { + return CelVariableResolver.hierarchicalVariableResolver( + newRulesVariable(rules), new Variable(RULE_NAME, val)); } @Override - public ResolvedValue resolveName(String name) { - if (this.name.equals(name)) { - return ResolvedValue.resolvedValue(val); - } else if (next != null) { - return next.resolveName(name); + public Optional find(String name) { + if (!this.name.equals(name) || val == null) { + return Optional.empty(); } - return ResolvedValue.ABSENT; - } - - @Override - public Activation parent() { - return next; + return Optional.of(val); } } diff --git a/src/test/java/build/buf/protovalidate/CustomOverloadTest.java b/src/test/java/build/buf/protovalidate/CustomOverloadTest.java index a4c7a4d1..602d5303 100644 --- a/src/test/java/build/buf/protovalidate/CustomOverloadTest.java +++ b/src/test/java/build/buf/protovalidate/CustomOverloadTest.java @@ -17,23 +17,28 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; +import dev.cel.runtime.CelEvaluationException; import java.util.Arrays; import java.util.Collections; +import java.util.Map; import org.junit.jupiter.api.Test; -import org.projectnessie.cel.Ast; -import org.projectnessie.cel.Env; -import org.projectnessie.cel.Library; -import org.projectnessie.cel.Program; -import org.projectnessie.cel.common.types.Err; -import org.projectnessie.cel.common.types.ref.Val; -import org.projectnessie.cel.interpreter.Activation; public class CustomOverloadTest { - private final Env env = Env.newEnv(Library.Lib(new ValidateLibrary())); + private final ValidateLibrary validateLibrary = new ValidateLibrary(); + private final Cel cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(validateLibrary) + .addRuntimeLibraries(validateLibrary) + .build(); @Test - public void testIsInf() { + public void testIsInf() throws Exception { assertThat(evalToBool("0.0.isInf()")).isFalse(); assertThat(evalToBool("(1.0/0.0).isInf()")).isTrue(); assertThat(evalToBool("(1.0/0.0).isInf(0)")).isTrue(); @@ -48,15 +53,12 @@ public void testIsInf() { @Test public void testIsInfUnsupported() { for (String testCase : Arrays.asList("'abc'.isInf()", "0.0.isInf('abc')")) { - Val val = eval(testCase).getVal(); - assertThat(Err.isError(val)).isTrue(); - assertThatThrownBy(() -> val.convertToNative(Exception.class)) - .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> evalToBool(testCase)).isInstanceOf(CelValidationException.class); } } @Test - public void testIsNan() { + public void testIsNan() throws Exception { assertThat(evalToBool("0.0.isNan()")).isFalse(); assertThat(evalToBool("(0.0/0.0).isNan()")).isTrue(); assertThat(evalToBool("(1.0/0.0).isNan()")).isFalse(); @@ -65,15 +67,12 @@ public void testIsNan() { @Test public void testIsNanUnsupported() { for (String testCase : Collections.singletonList("'foo'.isNan()")) { - Val val = eval(testCase).getVal(); - assertThat(Err.isError(val)).isTrue(); - assertThatThrownBy(() -> val.convertToNative(Exception.class)) - .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> evalToBool(testCase)).isInstanceOf(CelValidationException.class); } } @Test - public void testUnique() { + public void testUnique() throws Exception { assertThat(evalToBool("[].unique()")).isTrue(); assertThat(evalToBool("[true].unique()")).isTrue(); assertThat(evalToBool("[true, false].unique()")).isTrue(); @@ -97,16 +96,12 @@ public void testUnique() { @Test public void testUniqueUnsupported() { for (String testCase : Collections.singletonList("1.unique()")) { - Program.EvalResult result = eval(testCase); - Val val = result.getVal(); - assertThat(Err.isError(val)).isTrue(); - assertThatThrownBy(() -> val.convertToNative(Exception.class)) - .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> evalToBool(testCase)).isInstanceOf(CelValidationException.class); } } @Test - public void testIsIpPrefix() { + public void testIsIpPrefix() throws Exception { assertThat(evalToBool("'1.2.3.0/24'.isIpPrefix()")).isTrue(); assertThat(evalToBool("'1.2.3.4/24'.isIpPrefix()")).isTrue(); assertThat(evalToBool("'1.2.3.0/24'.isIpPrefix(true)")).isTrue(); @@ -141,22 +136,18 @@ public void testIsIpPrefixUnsupported() { "'1.2.3.0/24'.isIpPrefix('foo')", "'1.2.3.0/24'.isIpPrefix(4,'foo')", "'1.2.3.0/24'.isIpPrefix('foo',true)")) { - Program.EvalResult result = eval(testCase); - Val val = result.getVal(); - assertThat(Err.isError(val)).isTrue(); - assertThatThrownBy(() -> val.convertToNative(Exception.class)) - .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> eval(testCase)).isInstanceOf(CelValidationException.class); } } @Test - public void testIsHostname() { + public void testIsHostname() throws Exception { assertThat(evalToBool("'example.com'.isHostname()")).isTrue(); assertThat(evalToBool("'example.123'.isHostname()")).isFalse(); } @Test - public void testIsEmail() { + public void testIsEmail() throws Exception { assertThat(evalToBool("'foo@example.com'.isEmail()")).isTrue(); assertThat(evalToBool("''.isEmail()")).isFalse(); assertThat(evalToBool("' foo@example.com'.isEmail()")).isFalse(); @@ -164,7 +155,7 @@ public void testIsEmail() { } @Test - public void testBytesContains() { + public void testBytesContains() throws Exception { assertThat(evalToBool("bytes('12345').contains(bytes(''))")).isTrue(); assertThat(evalToBool("bytes('12345').contains(bytes('1'))")).isTrue(); assertThat(evalToBool("bytes('12345').contains(bytes('5'))")).isTrue(); @@ -179,19 +170,18 @@ public void testBytesContains() { assertThat(evalToBool("bytes('12345').contains(bytes('123456'))")).isFalse(); } - private Program.EvalResult eval(String source) { - return eval(source, Activation.emptyActivation()); + private Object eval(String source) throws Exception { + return eval(source, Collections.emptyMap()); } - private Program.EvalResult eval(String source, Object vars) { - Env.AstIssuesTuple parsed = env.parse(source); - assertThat(parsed.hasIssues()).isFalse(); - Ast ast = parsed.getAst(); - return env.program(ast).eval(vars); + private Object eval(String source, Map vars) + throws CelEvaluationException, CelValidationException { + CelValidationResult parsed = cel.compile(source); + CelAbstractSyntaxTree ast = parsed.getAst(); + return cel.createProgram(ast).eval(vars); } - private boolean evalToBool(String source) { - Program.EvalResult result = eval(source); - return result.getVal().booleanValue(); + private boolean evalToBool(String source) throws Exception { + return (Boolean) eval(source); } } diff --git a/src/test/java/build/buf/protovalidate/FormatTest.java b/src/test/java/build/buf/protovalidate/FormatTest.java index 5a6591bd..7438ccf7 100644 --- a/src/test/java/build/buf/protovalidate/FormatTest.java +++ b/src/test/java/build/buf/protovalidate/FormatTest.java @@ -14,8 +14,7 @@ package build.buf.protovalidate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +import static org.assertj.core.api.Assertions.*; import cel.expr.conformance.proto3.TestAllTypes; import com.cel.expr.Decl; @@ -25,6 +24,14 @@ import com.cel.expr.conformance.test.SimpleTestFile; import com.cel.expr.conformance.test.SimpleTestSection; import com.google.protobuf.TextFormat; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; +import dev.cel.common.types.SimpleType; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelRuntime.Program; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; @@ -39,22 +46,13 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.projectnessie.cel.Env; -import org.projectnessie.cel.EnvOption; -import org.projectnessie.cel.EvalOption; -import org.projectnessie.cel.Library; -import org.projectnessie.cel.Program; -import org.projectnessie.cel.ProgramOption; -import org.projectnessie.cel.checker.Decls; -import org.projectnessie.cel.common.types.ref.TypeEnum; -import org.projectnessie.cel.interpreter.Activation; class FormatTest { // Version of the cel-spec that this implementation is conformant with // This should be kept in sync with the version in gradle.properties private static String CEL_SPEC_VERSION = "v0.24.0"; - private static Env env; + private static Cel cel; private static List formatTests; private static List formatErrorTests; @@ -88,23 +86,26 @@ public static void setUp() throws Exception { .flatMap(s -> s.getTestList().stream()) .collect(Collectors.toList()); - env = Env.newEnv(Library.Lib(new ValidateLibrary())); + ValidateLibrary validateLibrary = new ValidateLibrary(); + cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(validateLibrary) + .addRuntimeLibraries(validateLibrary) + .build(); } @ParameterizedTest() @MethodSource("getFormatTests") - void testFormatSuccess(SimpleTest test) { - Program.EvalResult result = evaluate(test); - assertThat(result.getVal().value()).isEqualTo(getExpectedResult(test)); - assertThat(result.getVal().type().typeEnum()).isEqualTo(TypeEnum.String); + void testFormatSuccess(SimpleTest test) throws CelValidationException, CelEvaluationException { + Object result = evaluate(test); + assertThat(result).isEqualTo(getExpectedResult(test)); + assertThat(result).isInstanceOf(String.class); } @ParameterizedTest() @MethodSource("getFormatErrorTests") void testFormatError(SimpleTest test) { - Program.EvalResult result = evaluate(test); - assertThat(result.getVal().value()).isEqualTo(getExpectedResult(test)); - assertThat(result.getVal().type().typeEnum()).isEqualTo(TypeEnum.Err); + assertThatThrownBy(() -> evaluate(test)).isInstanceOf(CelEvaluationException.class); } // Loads test data from the given text format file @@ -120,22 +121,19 @@ private static List loadTestData(String fileName) throws Exce // Runs a test by extending the cel environment with the specified // types, variables and declarations, then evaluating it with the cel runtime. - private static Program.EvalResult evaluate(SimpleTest test) { - List decls = buildDecls(test); + private static Object evaluate(SimpleTest test) + throws CelValidationException, CelEvaluationException { - TestAllTypes msg = TestAllTypes.newBuilder().getDefaultInstanceForType(); - Env newEnv = env.extend(EnvOption.types(msg), EnvOption.declarations(decls)); + CelBuilder builder = cel.toCelBuilder().addMessageTypes(TestAllTypes.getDescriptor()); + addDecls(builder, test); + Cel newCel = builder.build(); - Env.AstIssuesTuple ast = newEnv.compile(test.getExpr()); - if (ast.hasIssues()) { - fail("error building AST for evaluation: " + ast.getIssues().toString()); + CelValidationResult validationResult = newCel.compile(test.getExpr()); + if (!validationResult.getAllIssues().isEmpty()) { + fail("error building AST for evaluation: " + validationResult.getIssueString()); } - Map vars = buildVariables(test.getBindingsMap()); - ProgramOption globals = ProgramOption.globals(vars); - Program program = - newEnv.program(ast.getAst(), globals, ProgramOption.evalOptions(EvalOption.OptTrackState)); - - return program.eval(Activation.emptyActivation()); + Program program = newCel.createProgram(validationResult.getAst()); + return program.eval(buildVariables(test.getBindingsMap())); } private static Stream getTestStream(List tests) { @@ -186,19 +184,17 @@ private static String getExpectedResult(SimpleTest test) { } // Builds the declarations for a given test - private static List buildDecls(SimpleTest test) { - List decls = new ArrayList<>(); + private static void addDecls(CelBuilder builder, SimpleTest test) { for (Decl decl : test.getTypeEnvList()) { if (decl.hasIdent()) { Decl.IdentDecl ident = decl.getIdent(); com.cel.expr.Type type = ident.getType(); if (type.hasPrimitive()) { if (type.getPrimitive() == com.cel.expr.Type.PrimitiveType.STRING) { - decls.add(Decls.newVar(decl.getName(), Decls.String)); + builder.addVar(decl.getName(), SimpleType.STRING); } } } } - return decls; } } diff --git a/src/test/java/build/buf/protovalidate/ValidatorConstructionTest.java b/src/test/java/build/buf/protovalidate/ValidatorConstructionTest.java index 85177646..470c00f4 100644 --- a/src/test/java/build/buf/protovalidate/ValidatorConstructionTest.java +++ b/src/test/java/build/buf/protovalidate/ValidatorConstructionTest.java @@ -108,8 +108,7 @@ public void testSeedDescriptorsImmutable() { FieldExpressionMapInt32 msg = FieldExpressionMapInt32.newBuilder().putAllVal(testMap).build(); List seedDescriptors = new ArrayList(); - FieldExpressionMapInt32 reg = FieldExpressionMapInt32.newBuilder().build(); - seedDescriptors.add(reg.getDescriptorForType()); + seedDescriptors.add(msg.getDescriptorForType()); Config cfg = Config.newBuilder().setFailFast(true).build(); try {