diff --git a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java index 50923c64..6b25545d 100644 --- a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java +++ b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java @@ -270,7 +270,7 @@ private void processIgnoreEmpty( FieldConstraints fieldConstraints, ValueEvaluator valueEvaluatorEval) throws CompilationException { - if (valueEvaluatorEval.getNestedRule() != null && shouldIgnoreEmpty(fieldConstraints)) { + if (valueEvaluatorEval.hasNestedRule() && shouldIgnoreEmpty(fieldConstraints)) { valueEvaluatorEval.setIgnoreEmpty(zeroValue(fieldDescriptor, true)); } } @@ -348,7 +348,8 @@ private void processFieldExpressions( EnvOption.declarations( Decls.newVar( Variable.THIS_NAME, - Decls.newObjectType(fieldDescriptor.getMessageType().getFullName())))); + DescriptorMappings.getCELType( + fieldDescriptor, valueEvaluatorEval.hasNestedRule())))); } catch (InvalidProtocolBufferException e) { throw new CompilationException("field descriptor type is invalid " + e.getMessage(), e); } @@ -376,7 +377,7 @@ private void processEmbeddedMessage( if (fieldDescriptor.getJavaType() != FieldDescriptor.JavaType.MESSAGE || shouldSkip(fieldConstraints) || fieldDescriptor.isMapField() - || (fieldDescriptor.isRepeated() && valueEvaluatorEval.getNestedRule() == null)) { + || (fieldDescriptor.isRepeated() && !valueEvaluatorEval.hasNestedRule())) { return; } Evaluator embedEval = @@ -393,7 +394,7 @@ private void processWrapperConstraints( if (fieldDescriptor.getJavaType() != FieldDescriptor.JavaType.MESSAGE || shouldSkip(fieldConstraints) || fieldDescriptor.isMapField() - || (fieldDescriptor.isRepeated() && valueEvaluatorEval.getNestedRule() == null)) { + || (fieldDescriptor.isRepeated() && !valueEvaluatorEval.hasNestedRule())) { return; } FieldDescriptor expectedWrapperDescriptor = @@ -418,7 +419,7 @@ private void processStandardConstraints( throws CompilationException { List compile = constraintCache.compile( - fieldDescriptor, fieldConstraints, valueEvaluatorEval.getNestedRule() != null); + fieldDescriptor, fieldConstraints, valueEvaluatorEval.hasNestedRule()); if (compile.isEmpty()) { return; } @@ -429,7 +430,7 @@ private void processAnyConstraints( FieldDescriptor fieldDescriptor, FieldConstraints fieldConstraints, ValueEvaluator valueEvaluatorEval) { - if ((fieldDescriptor.isRepeated() && valueEvaluatorEval.getNestedRule() == null) + if ((fieldDescriptor.isRepeated() && !valueEvaluatorEval.hasNestedRule()) || fieldDescriptor.getJavaType() != FieldDescriptor.JavaType.MESSAGE || !fieldDescriptor.getMessageType().getFullName().equals("google.protobuf.Any")) { return; @@ -484,7 +485,7 @@ private void processRepeatedConstraints( throws CompilationException { if (fieldDescriptor.isMapField() || !fieldDescriptor.isRepeated() - || valueEvaluatorEval.getNestedRule() != null) { + || valueEvaluatorEval.hasNestedRule()) { return; } ListEvaluator listEval = new ListEvaluator(valueEvaluatorEval); diff --git a/src/main/java/build/buf/protovalidate/ValueEvaluator.java b/src/main/java/build/buf/protovalidate/ValueEvaluator.java index ea050b9a..8b0c0035 100644 --- a/src/main/java/build/buf/protovalidate/ValueEvaluator.java +++ b/src/main/java/build/buf/protovalidate/ValueEvaluator.java @@ -59,6 +59,10 @@ class ValueEvaluator implements Evaluator { return nestedRule; } + public boolean hasNestedRule() { + return this.nestedRule != null; + } + @Override public boolean tautology() { return evaluators.isEmpty(); diff --git a/src/test/java/build/buf/protovalidate/ValidatorCelExpressionTest.java b/src/test/java/build/buf/protovalidate/ValidatorCelExpressionTest.java new file mode 100644 index 00000000..e613476f --- /dev/null +++ b/src/test/java/build/buf/protovalidate/ValidatorCelExpressionTest.java @@ -0,0 +1,160 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.protovalidate; + +import static org.assertj.core.api.Assertions.assertThat; + +import build.buf.validate.FieldConstraints; +import build.buf.validate.FieldPath; +import build.buf.validate.Violation; +import com.example.imports.buf.validate.RepeatedRules; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +/** + * This test verifies that custom (CEL-based) field and/or message constraints evaluate as expected. + */ +public class ValidatorCelExpressionTest { + + @Test + public void testFieldExpressionRepeatedMessage() throws Exception { + // Nested message wrapping the int 1 + com.example.imports.validationtest.FieldExpressionRepeatedMessage.Msg one = + com.example.imports.validationtest.FieldExpressionRepeatedMessage.Msg.newBuilder() + .setA(1) + .build(); + + // Nested message wrapping the int 2 + com.example.imports.validationtest.FieldExpressionRepeatedMessage.Msg two = + com.example.imports.validationtest.FieldExpressionRepeatedMessage.Msg.newBuilder() + .setA(2) + .build(); + + // Create a valid message (1, 1) + com.example.imports.validationtest.FieldExpressionRepeatedMessage validMsg = + com.example.imports.validationtest.FieldExpressionRepeatedMessage.newBuilder() + .addAllVal(Arrays.asList(one, one)) + .build(); + + // Create an invalid message (1, 2, 1) + com.example.imports.validationtest.FieldExpressionRepeatedMessage invalidMsg = + com.example.imports.validationtest.FieldExpressionRepeatedMessage.newBuilder() + .addAllVal(Arrays.asList(one, two, one)) + .build(); + + // Build a model of the expected violation + Violation expectedViolation = + Violation.newBuilder() + .setField( + FieldPath.newBuilder() + .addElements( + FieldPathUtils.fieldPathElement( + invalidMsg.getDescriptorForType().findFieldByName("val")) + .toBuilder() + .build())) + .setRule( + FieldPath.newBuilder() + .addElements( + FieldPathUtils.fieldPathElement( + FieldConstraints.getDescriptor() + .findFieldByNumber(FieldConstraints.CEL_FIELD_NUMBER)) + .toBuilder() + .setIndex(0) + .build())) + .setConstraintId("field_expression.repeated.message") + .setMessage("test message field_expression.repeated.message") + .build(); + + Validator validator = new Validator(); + + // Valid message checks + ValidationResult validResult = validator.validate(validMsg); + assertThat(validResult.isSuccess()).isTrue(); + + // Invalid message checks + ValidationResult invalidResult = validator.validate(invalidMsg); + assertThat(invalidResult.isSuccess()).isFalse(); + assertThat(invalidResult.toProto().getViolationsList()).containsExactly(expectedViolation); + } + + @Test + public void testFieldExpressionRepeatedMessageItems() throws Exception { + // Nested message wrapping the int 1 + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.Msg one = + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.Msg.newBuilder() + .setA(1) + .build(); + + // Nested message wrapping the int 2 + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.Msg two = + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.Msg.newBuilder() + .setA(2) + .build(); + + // Create a valid message (1, 1) + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems validMsg = + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.newBuilder() + .addAllVal(Arrays.asList(one, one)) + .build(); + + // Create an invalid message (1, 2, 1) + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems invalidMsg = + com.example.imports.validationtest.FieldExpressionRepeatedMessageItems.newBuilder() + .addAllVal(Arrays.asList(one, two, one)) + .build(); + + // Build a model of the expected violation + Violation expectedViolation = + Violation.newBuilder() + .setField( + FieldPath.newBuilder() + .addElements( + FieldPathUtils.fieldPathElement( + invalidMsg.getDescriptorForType().findFieldByName("val")) + .toBuilder() + .setIndex(1) + .build())) + .setRule( + FieldPath.newBuilder() + .addElements( + FieldPathUtils.fieldPathElement( + FieldConstraints.getDescriptor() + .findFieldByNumber(FieldConstraints.REPEATED_FIELD_NUMBER))) + .addElements( + FieldPathUtils.fieldPathElement( + RepeatedRules.getDescriptor().findFieldByName("items"))) + .addElements( + FieldPathUtils.fieldPathElement( + FieldConstraints.getDescriptor() + .findFieldByNumber(FieldConstraints.CEL_FIELD_NUMBER)) + .toBuilder() + .setIndex(0) + .build())) + .setConstraintId("field_expression.repeated.message.items") + .setMessage("test message field_expression.repeated.message.items") + .build(); + + Validator validator = new Validator(); + + // Valid message checks + ValidationResult validResult = validator.validate(validMsg); + assertThat(validResult.isSuccess()).isTrue(); + + // Invalid message checks + ValidationResult invalidResult = validator.validate(invalidMsg); + assertThat(invalidResult.isSuccess()).isFalse(); + assertThat(invalidResult.toProto().getViolationsList()).containsExactly(expectedViolation); + } +} diff --git a/src/test/resources/proto/validationtest/custom_constraints.proto b/src/test/resources/proto/validationtest/custom_constraints.proto new file mode 100644 index 00000000..98d7b8b6 --- /dev/null +++ b/src/test/resources/proto/validationtest/custom_constraints.proto @@ -0,0 +1,41 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package validationtest; + +import "buf/validate/validate.proto"; + +message FieldExpressionRepeatedMessage { + repeated Msg val = 1 [(buf.validate.field).cel = { + id: "field_expression.repeated.message" + message: "test message field_expression.repeated.message" + expression: "this.all(e, e.a == 1)" + }]; + message Msg { + int32 a = 1; + } +} + +message FieldExpressionRepeatedMessageItems { + repeated Msg val = 1 [(buf.validate.field).repeated.items.cel = { + id: "field_expression.repeated.message.items" + message: "test message field_expression.repeated.message.items" + expression: "this.a == 1" + }]; + message Msg { + int32 a = 1; + } +}