Skip to content

Commit c0e26d9

Browse files
committed
feat: set ChainInvoke to default to json serdes
1 parent 0761a6a commit c0e26d9

File tree

5 files changed

+73
-13
lines changed

5 files changed

+73
-13
lines changed

ops/__tests__/test_parse_sdk_branch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def test():
7373

7474
for input_text, expected in test_cases:
7575
result = parse_sdk_branch(input_text)
76-
if result != expected:
77-
return False
78-
79-
return True
76+
# Assert is expected in test functions
77+
assert result == expected, ( # noqa: S101
78+
f"Expected '{expected}' but got '{result}' for input: {input_text[:50]}..."
79+
)
8080

8181

8282
if __name__ == "__main__":

src/aws_durable_execution_sdk_python/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,12 @@ class InvokeConfig(Generic[P, R]):
392392
from blocking execution indefinitely.
393393
394394
serdes_payload: Custom serialization/deserialization for the payload
395-
sent to the invoked function. If None, uses default JSON serialization.
395+
sent to the invoked function. Defaults to DEFAULT_JSON_SERDES when
396+
not set.
396397
397398
serdes_result: Custom serialization/deserialization for the result
398-
returned from the invoked function. If None, uses default JSON serialization.
399+
returned from the invoked function. Defaults to DEFAULT_JSON_SERDES when
400+
not set.
399401
400402
tenant_id: Optional tenant identifier for multi-tenant isolation.
401403
If provided, the invocation will be scoped to this tenant.

src/aws_durable_execution_sdk_python/operation/invoke.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
ChainedInvokeOptions,
1212
OperationUpdate,
1313
)
14-
from aws_durable_execution_sdk_python.serdes import deserialize, serialize
14+
from aws_durable_execution_sdk_python.serdes import (
15+
DEFAULT_JSON_SERDES,
16+
deserialize,
17+
serialize,
18+
)
1519
from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay
1620

1721
if TYPE_CHECKING:
@@ -53,7 +57,7 @@ def invoke_handler(
5357
and checkpointed_result.operation.chained_invoke_details.result
5458
):
5559
return deserialize(
56-
serdes=config.serdes_result,
60+
serdes=config.serdes_result or DEFAULT_JSON_SERDES,
5761
data=checkpointed_result.operation.chained_invoke_details.result,
5862
operation_id=operation_identifier.operation_id,
5963
durable_execution_arn=state.durable_execution_arn,
@@ -78,7 +82,7 @@ def invoke_handler(
7882
suspend_with_optional_resume_delay(msg, config.timeout_seconds)
7983

8084
serialized_payload: str = serialize(
81-
serdes=config.serdes_payload,
85+
serdes=config.serdes_payload or DEFAULT_JSON_SERDES,
8286
value=payload,
8387
operation_id=operation_identifier.operation_id,
8488
durable_execution_arn=state.durable_execution_arn,

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
441441
return obj
442442

443443

444-
_DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
445-
_EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()
444+
DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
445+
EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()
446446

447447

448448
def serialize(
@@ -463,7 +463,7 @@ def serialize(
463463
FatalError: If serialization fails
464464
"""
465465
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
466-
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
466+
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
467467
try:
468468
return active_serdes.serialize(value, serdes_context)
469469
except Exception as e:
@@ -493,7 +493,7 @@ def deserialize(
493493
FatalError: If deserialization fails
494494
"""
495495
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
496-
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
496+
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
497497
try:
498498
return active_serdes.deserialize(data, serdes_context)
499499
except Exception as e:

tests/operation/invoke_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,57 @@ def test_invoke_handler_default_config_no_tenant_id():
612612
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
613613
assert chained_invoke_options["FunctionName"] == "test_function"
614614
assert "TenantId" not in chained_invoke_options
615+
616+
617+
def test_invoke_handler_defaults_to_json_serdes():
618+
"""Test invoke_handler uses DEFAULT_JSON_SERDES when config has no serdes."""
619+
mock_state = Mock(spec=ExecutionState)
620+
mock_state.durable_execution_arn = "test_arn"
621+
mock_state.get_checkpoint_result.return_value = (
622+
CheckpointedResult.create_not_found()
623+
)
624+
625+
config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)
626+
payload = {"key": "value", "number": 42}
627+
628+
with pytest.raises(SuspendExecution):
629+
invoke_handler(
630+
function_name="test_function",
631+
payload=payload,
632+
state=mock_state,
633+
operation_identifier=OperationIdentifier("invoke_json", None, None),
634+
config=config,
635+
)
636+
637+
# Verify JSON serialization was used (not extended types)
638+
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
639+
assert operation_update.payload == json.dumps(payload)
640+
641+
642+
def test_invoke_handler_result_defaults_to_json_serdes():
643+
"""Test invoke_handler uses DEFAULT_JSON_SERDES for result deserialization."""
644+
mock_state = Mock(spec=ExecutionState)
645+
mock_state.durable_execution_arn = "test_arn"
646+
647+
result_data = {"key": "value", "number": 42}
648+
operation = Operation(
649+
operation_id="invoke_result_json",
650+
operation_type=OperationType.CHAINED_INVOKE,
651+
status=OperationStatus.SUCCEEDED,
652+
chained_invoke_details=ChainedInvokeDetails(result=json.dumps(result_data)),
653+
)
654+
mock_result = CheckpointedResult.create_from_operation(operation)
655+
mock_state.get_checkpoint_result.return_value = mock_result
656+
657+
config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)
658+
659+
result = invoke_handler(
660+
function_name="test_function",
661+
payload={"input": "data"},
662+
state=mock_state,
663+
operation_identifier=OperationIdentifier("invoke_result_json", None, None),
664+
config=config,
665+
)
666+
667+
# Verify JSON deserialization was used (not extended types)
668+
assert result == result_data

0 commit comments

Comments
 (0)