|
| 1 | +import json |
1 | 2 | from dataclasses import dataclass |
2 | 3 | from datetime import datetime |
3 | 4 | from typing import List, Optional, Tuple |
4 | 5 |
|
| 6 | +import pytest |
5 | 7 | from pydantic import BaseModel, Field |
6 | 8 | from typing_extensions import Annotated |
7 | 9 |
|
@@ -1044,3 +1046,185 @@ def complex_handler(params: Annotated[QueryParams, Query()]): |
1044 | 1046 | assert type_mapping["int_field"] == "integer" |
1045 | 1047 | assert type_mapping["float_field"] == "number" |
1046 | 1048 | assert type_mapping["bool_field"] == "boolean" |
| 1049 | + |
| 1050 | + |
| 1051 | +@pytest.mark.parametrize( |
| 1052 | + "body_value,expected_value", |
| 1053 | + [ |
| 1054 | + ("50", 50), # Valid: within range |
| 1055 | + ("0", 0), # Valid: at lower bound |
| 1056 | + ("100", 100), # Valid: at upper bound |
| 1057 | + ], |
| 1058 | +) |
| 1059 | +def test_annotated_types_interval_constraints_in_body_params(body_value, expected_value): |
| 1060 | + """ |
| 1061 | + Test for issue #7600: Validate that annotated_types.Interval constraints |
| 1062 | + are properly enforced in Body parameters with valid values. |
| 1063 | + """ |
| 1064 | + from annotated_types import Interval |
| 1065 | + |
| 1066 | + # GIVEN an APIGatewayRestResolver with validation enabled |
| 1067 | + app = APIGatewayRestResolver(enable_validation=True) |
| 1068 | + |
| 1069 | + # AND a constrained type using annotated_types.Interval |
| 1070 | + ConstrainedInt = Annotated[int, Interval(ge=0, le=100)] |
| 1071 | + |
| 1072 | + @app.post("/items") |
| 1073 | + def create_item(value: Annotated[ConstrainedInt, Body()]): |
| 1074 | + return {"value": value} |
| 1075 | + |
| 1076 | + # WHEN sending a request with a valid value |
| 1077 | + event = { |
| 1078 | + "resource": "/items", |
| 1079 | + "path": "/items", |
| 1080 | + "httpMethod": "POST", |
| 1081 | + "body": body_value, |
| 1082 | + "isBase64Encoded": False, |
| 1083 | + } |
| 1084 | + |
| 1085 | + # THEN the request should succeed |
| 1086 | + result = app(event, {}) |
| 1087 | + assert result["statusCode"] == 200 |
| 1088 | + body = json.loads(result["body"]) |
| 1089 | + assert body["value"] == expected_value |
| 1090 | + |
| 1091 | + |
| 1092 | +@pytest.mark.parametrize( |
| 1093 | + "body_value", |
| 1094 | + [ |
| 1095 | + "-1", # Invalid: below range |
| 1096 | + "101", # Invalid: above range |
| 1097 | + ], |
| 1098 | +) |
| 1099 | +def test_annotated_types_interval_constraints_in_body_params_invalid(body_value): |
| 1100 | + """ |
| 1101 | + Test for issue #7600: Validate that annotated_types.Interval constraints |
| 1102 | + reject invalid values in Body parameters. |
| 1103 | + """ |
| 1104 | + from annotated_types import Interval |
| 1105 | + |
| 1106 | + # GIVEN an APIGatewayRestResolver with validation enabled |
| 1107 | + app = APIGatewayRestResolver(enable_validation=True) |
| 1108 | + |
| 1109 | + # AND a constrained type using annotated_types.Interval |
| 1110 | + constrained_int = Annotated[int, Interval(ge=0, le=100)] |
| 1111 | + |
| 1112 | + @app.post("/items") |
| 1113 | + def create_item(value: Annotated[constrained_int, Body()]): |
| 1114 | + return {"value": value} |
| 1115 | + |
| 1116 | + # WHEN sending a request with an invalid value |
| 1117 | + event = { |
| 1118 | + "resource": "/items", |
| 1119 | + "path": "/items", |
| 1120 | + "httpMethod": "POST", |
| 1121 | + "body": body_value, |
| 1122 | + "isBase64Encoded": False, |
| 1123 | + } |
| 1124 | + |
| 1125 | + # THEN validation should fail |
| 1126 | + result = app(event, {}) |
| 1127 | + assert result["statusCode"] == 422 |
| 1128 | + |
| 1129 | + |
| 1130 | +@pytest.mark.parametrize( |
| 1131 | + "query_value,expected_value", |
| 1132 | + [ |
| 1133 | + ("50", 50), # Valid: within range |
| 1134 | + ("0", 0), # Valid: at lower bound |
| 1135 | + ("100", 100), # Valid: at upper bound |
| 1136 | + ], |
| 1137 | +) |
| 1138 | +def test_annotated_types_interval_constraints_in_query_params(query_value, expected_value): |
| 1139 | + """ |
| 1140 | + Test for issue #7600: Validate that annotated_types.Interval constraints |
| 1141 | + are properly enforced in Query parameters with valid values. |
| 1142 | + """ |
| 1143 | + from annotated_types import Interval |
| 1144 | + |
| 1145 | + # GIVEN an APIGatewayRestResolver with validation enabled |
| 1146 | + app = APIGatewayRestResolver(enable_validation=True) |
| 1147 | + |
| 1148 | + # AND a constrained type using annotated_types.Interval |
| 1149 | + constrained_int = Annotated[int, Interval(ge=0, le=100)] |
| 1150 | + |
| 1151 | + @app.get("/items") |
| 1152 | + def list_items(limit: Annotated[constrained_int, Query()]): |
| 1153 | + return {"limit": limit} |
| 1154 | + |
| 1155 | + # WHEN sending a request with a valid value |
| 1156 | + event = { |
| 1157 | + "resource": "/items", |
| 1158 | + "path": "/items", |
| 1159 | + "httpMethod": "GET", |
| 1160 | + "queryStringParameters": {"limit": query_value}, |
| 1161 | + "isBase64Encoded": False, |
| 1162 | + } |
| 1163 | + |
| 1164 | + # THEN the request should succeed |
| 1165 | + result = app(event, {}) |
| 1166 | + assert result["statusCode"] == 200 |
| 1167 | + body = json.loads(result["body"]) |
| 1168 | + assert body["limit"] == expected_value |
| 1169 | + |
| 1170 | + |
| 1171 | +@pytest.mark.parametrize( |
| 1172 | + "query_value", |
| 1173 | + [ |
| 1174 | + "-1", # Invalid: below range |
| 1175 | + "101", # Invalid: above range |
| 1176 | + ], |
| 1177 | +) |
| 1178 | +def test_annotated_types_interval_constraints_in_query_params_invalid(query_value): |
| 1179 | + """ |
| 1180 | + Test for issue #7600: Validate that annotated_types.Interval constraints |
| 1181 | + reject invalid values in Query parameters. |
| 1182 | + """ |
| 1183 | + from annotated_types import Interval |
| 1184 | + |
| 1185 | + # GIVEN an APIGatewayRestResolver with validation enabled |
| 1186 | + app = APIGatewayRestResolver(enable_validation=True) |
| 1187 | + |
| 1188 | + # AND a constrained type using annotated_types.Interval |
| 1189 | + constrained_int = Annotated[int, Interval(ge=0, le=100)] |
| 1190 | + |
| 1191 | + @app.get("/items") |
| 1192 | + def list_items(limit: Annotated[constrained_int, Query()]): |
| 1193 | + return {"limit": limit} |
| 1194 | + |
| 1195 | + # WHEN sending a request with an invalid value |
| 1196 | + event = { |
| 1197 | + "resource": "/items", |
| 1198 | + "path": "/items", |
| 1199 | + "httpMethod": "GET", |
| 1200 | + "queryStringParameters": {"limit": query_value}, |
| 1201 | + "isBase64Encoded": False, |
| 1202 | + } |
| 1203 | + |
| 1204 | + # THEN validation should fail |
| 1205 | + result = app(event, {}) |
| 1206 | + assert result["statusCode"] == 422 |
| 1207 | + |
| 1208 | + |
| 1209 | +def test_annotated_types_interval_in_openapi_schema(): |
| 1210 | + """ |
| 1211 | + Test that annotated_types.Interval constraints are reflected in the OpenAPI schema. |
| 1212 | + """ |
| 1213 | + from annotated_types import Interval |
| 1214 | + |
| 1215 | + app = APIGatewayRestResolver() |
| 1216 | + constrained_int = Annotated[int, Interval(ge=0, le=100)] |
| 1217 | + |
| 1218 | + @app.get("/items") |
| 1219 | + def list_items(limit: Annotated[constrained_int, Query()] = 10): |
| 1220 | + return {"limit": limit} |
| 1221 | + |
| 1222 | + schema = app.get_openapi_schema() |
| 1223 | + |
| 1224 | + # Verify the Query parameter schema includes constraints |
| 1225 | + get_operation = schema.paths["/items"].get |
| 1226 | + limit_param = next(p for p in get_operation.parameters if p.name == "limit") |
| 1227 | + |
| 1228 | + assert limit_param.schema_.type == "integer" |
| 1229 | + assert limit_param.schema_.default == 10 |
| 1230 | + assert limit_param.required is False |
0 commit comments