Skip to content

Commit 21d9f41

Browse files
fix(event-handler): preserve metadata constraints in parameter validation (#7609)
* Fixing regression with annotated types * Fixing regression with annotated types * Fixing regression with annotated types
1 parent 70e94c4 commit 21d9f41

File tree

3 files changed

+224
-2
lines changed

3 files changed

+224
-2
lines changed

aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# We use this for forward reference, as it allows us to handle forward references in type annotations.
1515
from pydantic._internal._typing_extra import eval_type_lenient
1616
from pydantic._internal._utils import lenient_issubclass
17+
from pydantic.fields import FieldInfo as PydanticFieldInfo
1718
from pydantic_core import PydanticUndefined, PydanticUndefinedType
1819
from typing_extensions import Annotated, Literal, get_args, get_origin
1920

@@ -186,8 +187,36 @@ def model_rebuild(model: type[BaseModel]) -> None:
186187
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
187188
# Create a shallow copy of the field_info to preserve its type and all attributes
188189
new_field = copy(field_info)
189-
# Update only the annotation to the new one
190-
new_field.annotation = annotation
190+
191+
# Recursively extract all metadata from nested Annotated types
192+
def extract_metadata(ann: Any) -> tuple[Any, list[Any]]:
193+
"""Extract base type and all non-FieldInfo metadata from potentially nested Annotated types."""
194+
if get_origin(ann) is not Annotated:
195+
return ann, []
196+
197+
args = get_args(ann)
198+
base_type = args[0]
199+
metadata = list(args[1:])
200+
201+
# If base type is also Annotated, recursively extract its metadata
202+
if get_origin(base_type) is Annotated:
203+
inner_base, inner_metadata = extract_metadata(base_type)
204+
all_metadata = [m for m in inner_metadata + metadata if not isinstance(m, PydanticFieldInfo)]
205+
return inner_base, all_metadata
206+
else:
207+
constraint_metadata = [m for m in metadata if not isinstance(m, PydanticFieldInfo)]
208+
return base_type, constraint_metadata
209+
210+
# Extract base type and constraints
211+
base_type, constraints = extract_metadata(annotation)
212+
213+
# Set the annotation with base type and all constraint metadata
214+
# Use tuple unpacking for Python 3.9+ compatibility
215+
if constraints:
216+
new_field.annotation = Annotated[(base_type, *constraints)]
217+
else:
218+
new_field.annotation = base_type
219+
191220
return new_field
192221

193222

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,10 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
11101110
type_annotation = annotated_args[0]
11111111
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]
11121112

1113+
# Preserve non-FieldInfo metadata (like annotated_types constraints)
1114+
# This is important for constraints like Interval, Gt, Lt, etc.
1115+
other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)]
1116+
11131117
# Determine which annotation to use
11141118
powertools_annotation: FieldInfo | None = None
11151119
has_discriminator_with_param = False
@@ -1124,6 +1128,11 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
11241128
else:
11251129
powertools_annotation = next(iter(powertools_annotations), None)
11261130

1131+
# Reconstruct type_annotation with non-FieldInfo metadata if present
1132+
# This ensures constraints like Interval are preserved
1133+
if other_metadata and not has_discriminator_with_param:
1134+
type_annotation = Annotated[(type_annotation, *other_metadata)]
1135+
11271136
# Process the annotation if it exists
11281137
field_info: FieldInfo | None = None
11291138
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover

tests/functional/event_handler/_pydantic/test_openapi_params.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import json
12
from dataclasses import dataclass
23
from datetime import datetime
34
from typing import List, Optional, Tuple
45

6+
import pytest
57
from pydantic import BaseModel, Field
68
from typing_extensions import Annotated
79

@@ -1044,3 +1046,185 @@ def complex_handler(params: Annotated[QueryParams, Query()]):
10441046
assert type_mapping["int_field"] == "integer"
10451047
assert type_mapping["float_field"] == "number"
10461048
assert type_mapping["bool_field"] == "boolean"
1049+
1050+
1051+
@pytest.mark.parametrize(
1052+
"body_value,expected_value",
1053+
[
1054+
("50", 50), # Valid: within range
1055+
("0", 0), # Valid: at lower bound
1056+
("100", 100), # Valid: at upper bound
1057+
],
1058+
)
1059+
def test_annotated_types_interval_constraints_in_body_params(body_value, expected_value):
1060+
"""
1061+
Test for issue #7600: Validate that annotated_types.Interval constraints
1062+
are properly enforced in Body parameters with valid values.
1063+
"""
1064+
from annotated_types import Interval
1065+
1066+
# GIVEN an APIGatewayRestResolver with validation enabled
1067+
app = APIGatewayRestResolver(enable_validation=True)
1068+
1069+
# AND a constrained type using annotated_types.Interval
1070+
ConstrainedInt = Annotated[int, Interval(ge=0, le=100)]
1071+
1072+
@app.post("/items")
1073+
def create_item(value: Annotated[ConstrainedInt, Body()]):
1074+
return {"value": value}
1075+
1076+
# WHEN sending a request with a valid value
1077+
event = {
1078+
"resource": "/items",
1079+
"path": "/items",
1080+
"httpMethod": "POST",
1081+
"body": body_value,
1082+
"isBase64Encoded": False,
1083+
}
1084+
1085+
# THEN the request should succeed
1086+
result = app(event, {})
1087+
assert result["statusCode"] == 200
1088+
body = json.loads(result["body"])
1089+
assert body["value"] == expected_value
1090+
1091+
1092+
@pytest.mark.parametrize(
1093+
"body_value",
1094+
[
1095+
"-1", # Invalid: below range
1096+
"101", # Invalid: above range
1097+
],
1098+
)
1099+
def test_annotated_types_interval_constraints_in_body_params_invalid(body_value):
1100+
"""
1101+
Test for issue #7600: Validate that annotated_types.Interval constraints
1102+
reject invalid values in Body parameters.
1103+
"""
1104+
from annotated_types import Interval
1105+
1106+
# GIVEN an APIGatewayRestResolver with validation enabled
1107+
app = APIGatewayRestResolver(enable_validation=True)
1108+
1109+
# AND a constrained type using annotated_types.Interval
1110+
constrained_int = Annotated[int, Interval(ge=0, le=100)]
1111+
1112+
@app.post("/items")
1113+
def create_item(value: Annotated[constrained_int, Body()]):
1114+
return {"value": value}
1115+
1116+
# WHEN sending a request with an invalid value
1117+
event = {
1118+
"resource": "/items",
1119+
"path": "/items",
1120+
"httpMethod": "POST",
1121+
"body": body_value,
1122+
"isBase64Encoded": False,
1123+
}
1124+
1125+
# THEN validation should fail
1126+
result = app(event, {})
1127+
assert result["statusCode"] == 422
1128+
1129+
1130+
@pytest.mark.parametrize(
1131+
"query_value,expected_value",
1132+
[
1133+
("50", 50), # Valid: within range
1134+
("0", 0), # Valid: at lower bound
1135+
("100", 100), # Valid: at upper bound
1136+
],
1137+
)
1138+
def test_annotated_types_interval_constraints_in_query_params(query_value, expected_value):
1139+
"""
1140+
Test for issue #7600: Validate that annotated_types.Interval constraints
1141+
are properly enforced in Query parameters with valid values.
1142+
"""
1143+
from annotated_types import Interval
1144+
1145+
# GIVEN an APIGatewayRestResolver with validation enabled
1146+
app = APIGatewayRestResolver(enable_validation=True)
1147+
1148+
# AND a constrained type using annotated_types.Interval
1149+
constrained_int = Annotated[int, Interval(ge=0, le=100)]
1150+
1151+
@app.get("/items")
1152+
def list_items(limit: Annotated[constrained_int, Query()]):
1153+
return {"limit": limit}
1154+
1155+
# WHEN sending a request with a valid value
1156+
event = {
1157+
"resource": "/items",
1158+
"path": "/items",
1159+
"httpMethod": "GET",
1160+
"queryStringParameters": {"limit": query_value},
1161+
"isBase64Encoded": False,
1162+
}
1163+
1164+
# THEN the request should succeed
1165+
result = app(event, {})
1166+
assert result["statusCode"] == 200
1167+
body = json.loads(result["body"])
1168+
assert body["limit"] == expected_value
1169+
1170+
1171+
@pytest.mark.parametrize(
1172+
"query_value",
1173+
[
1174+
"-1", # Invalid: below range
1175+
"101", # Invalid: above range
1176+
],
1177+
)
1178+
def test_annotated_types_interval_constraints_in_query_params_invalid(query_value):
1179+
"""
1180+
Test for issue #7600: Validate that annotated_types.Interval constraints
1181+
reject invalid values in Query parameters.
1182+
"""
1183+
from annotated_types import Interval
1184+
1185+
# GIVEN an APIGatewayRestResolver with validation enabled
1186+
app = APIGatewayRestResolver(enable_validation=True)
1187+
1188+
# AND a constrained type using annotated_types.Interval
1189+
constrained_int = Annotated[int, Interval(ge=0, le=100)]
1190+
1191+
@app.get("/items")
1192+
def list_items(limit: Annotated[constrained_int, Query()]):
1193+
return {"limit": limit}
1194+
1195+
# WHEN sending a request with an invalid value
1196+
event = {
1197+
"resource": "/items",
1198+
"path": "/items",
1199+
"httpMethod": "GET",
1200+
"queryStringParameters": {"limit": query_value},
1201+
"isBase64Encoded": False,
1202+
}
1203+
1204+
# THEN validation should fail
1205+
result = app(event, {})
1206+
assert result["statusCode"] == 422
1207+
1208+
1209+
def test_annotated_types_interval_in_openapi_schema():
1210+
"""
1211+
Test that annotated_types.Interval constraints are reflected in the OpenAPI schema.
1212+
"""
1213+
from annotated_types import Interval
1214+
1215+
app = APIGatewayRestResolver()
1216+
constrained_int = Annotated[int, Interval(ge=0, le=100)]
1217+
1218+
@app.get("/items")
1219+
def list_items(limit: Annotated[constrained_int, Query()] = 10):
1220+
return {"limit": limit}
1221+
1222+
schema = app.get_openapi_schema()
1223+
1224+
# Verify the Query parameter schema includes constraints
1225+
get_operation = schema.paths["/items"].get
1226+
limit_param = next(p for p in get_operation.parameters if p.name == "limit")
1227+
1228+
assert limit_param.schema_.type == "integer"
1229+
assert limit_param.schema_.default == 10
1230+
assert limit_param.required is False

0 commit comments

Comments
 (0)