Skip to content

Commit f31ec87

Browse files
author
Alex Wang
committed
fix: fix callback serdes
- Add a new passthrough serdes - If the customer does not provide customized serdes for callback handler, use passthrough serdes for callback result because they are not created by sdk, instead, they are created by backend with customer data.
1 parent a04015e commit f31ec87

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
from aws_durable_execution_sdk_python.operation.wait_for_condition import (
3737
wait_for_condition_handler,
3838
)
39-
from aws_durable_execution_sdk_python.serdes import SerDes, deserialize
39+
from aws_durable_execution_sdk_python.serdes import (
40+
PassThroughSerDes,
41+
SerDes,
42+
deserialize,
43+
)
4044
from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001
4145
from aws_durable_execution_sdk_python.threading import OrderedCounter
4246
from aws_durable_execution_sdk_python.types import (
@@ -144,7 +148,7 @@ def result(self) -> T | None:
144148
return None # type: ignore
145149

146150
return deserialize(
147-
serdes=self.serdes,
151+
serdes=self.serdes if self.serdes is not None else PassThroughSerDes(),
148152
data=checkpointed_result.result,
149153
operation_id=self.operation_id,
150154
durable_execution_arn=self.state.durable_execution_arn,

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,7 @@ def encode(self, obj: Any) -> EncodedValue:
316316
def decode(self, tag: TypeTag, value: Any) -> Any:
317317
match tag:
318318
case (
319-
TypeTag.NONE
320-
| TypeTag.STR
321-
| TypeTag.BOOL
322-
| TypeTag.INT
323-
| TypeTag.FLOAT
319+
TypeTag.NONE | TypeTag.STR | TypeTag.BOOL | TypeTag.INT | TypeTag.FLOAT
324320
):
325321
return self.primitive_codec.decode(tag, value)
326322
case TypeTag.BYTES:
@@ -372,6 +368,14 @@ def is_primitive(obj: Any) -> bool:
372368
return False
373369

374370

371+
class PassThroughSerDes(SerDes[T]):
372+
def serialize(self, value: T, _: SerDesContext) -> str: # noqa: PLR6301
373+
return value # type: ignore
374+
375+
def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301
376+
return data # type: ignore
377+
378+
375379
class JsonSerDes(SerDes[T]):
376380
def serialize(self, value: T, _: SerDesContext) -> str: # noqa: PLR6301
377381
return json.dumps(value)

tests/context_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,28 @@ def test_callback_result_succeeded():
7575
callback = Callback("callback1", "op1", mock_state)
7676
result = callback.result()
7777

78+
assert result == '"success_result"'
79+
mock_state.get_checkpoint_result.assert_called_once_with("op1")
80+
81+
82+
def test_callback_result_succeeded_with_plain_str():
83+
"""Test Callback.result() when operation succeeded."""
84+
mock_state = Mock(spec=ExecutionState)
85+
mock_state.durable_execution_arn = "test_arn"
86+
operation = Operation(
87+
operation_id="op1",
88+
operation_type=OperationType.CALLBACK,
89+
status=OperationStatus.SUCCEEDED,
90+
callback_details=CallbackDetails(
91+
callback_id="callback1", result="success_result"
92+
),
93+
)
94+
mock_result = CheckpointedResult.create_from_operation(operation)
95+
mock_state.get_checkpoint_result.return_value = mock_result
96+
97+
callback = Callback("callback1", "op1", mock_state)
98+
result = callback.result()
99+
78100
assert result == "success_result"
79101
mock_state.get_checkpoint_result.assert_called_once_with("op1")
80102

tests/serdes_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
EncodedValue,
2929
ExtendedTypeSerDes,
3030
JsonSerDes,
31+
PassThroughSerDes,
3132
PrimitiveCodec,
3233
SerDes,
3334
SerDesContext,
@@ -737,6 +738,18 @@ def test_extended_serdes_errors():
737738
# endregion
738739

739740

741+
def test_pass_through_serdes():
742+
serdes = PassThroughSerDes()
743+
744+
data = '"name": "test", "value": 123'
745+
serialized = serialize(serdes, data, "test-op", "test-arn")
746+
assert isinstance(serialized, str)
747+
assert serialized == '"name": "test", "value": 123'
748+
# Dict uses envelope format, so roundtrip through deserialize
749+
deserialized = deserialize(serdes, serialized, "test-op", "test-arn")
750+
assert deserialized == data
751+
752+
740753
# region EnvelopeSerDes Performance and Edge Cases
741754
def test_envelope_large_data_structure():
742755
"""Test with reasonably large data."""

0 commit comments

Comments
 (0)