Skip to content

Commit 3ec77a3

Browse files
FullyTypedAstraea Quinn S
authored andcommitted
fix(sdk): pass item_serdes to executor
1. Pass item_serdes to executor factory methods. 2. Add tests to verify fallback and default behaviour.
1 parent 3e09a17 commit 3ec77a3

File tree

5 files changed

+264
-2
lines changed

5 files changed

+264
-2
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,11 @@ def map_in_child_context() -> BatchResult[R]:
343343
operation_identifier=operation_identifier,
344344
config=ChildConfig(
345345
sub_type=OperationSubType.MAP,
346-
serdes=config.serdes if config is not None else None,
346+
serdes=getattr(config, "serdes", None),
347+
# child_handler should only know the serdes of the parent serdes,
348+
# the item serdes will be passed when we are actually executing
349+
# the branch within its own child_handler.
350+
item_serdes=None,
347351
),
348352
)
349353

@@ -380,7 +384,11 @@ def parallel_in_child_context() -> BatchResult[T]:
380384
operation_identifier=operation_identifier,
381385
config=ChildConfig(
382386
sub_type=OperationSubType.PARALLEL,
383-
serdes=config.serdes if config is not None else None,
387+
serdes=getattr(config, "serdes", None),
388+
# child_handler should only know the serdes of the parent serdes,
389+
# the item serdes will be passed when we are actually executing
390+
# the branch within its own child_handler.
391+
item_serdes=None,
384392
),
385393
)
386394

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def from_items(
8282
name_prefix="map-item-",
8383
serdes=config.serdes,
8484
summary_generator=config.summary_generator,
85+
item_serdes=config.item_serdes,
8586
)
8687

8788
def execute_item(self, child_context, executable: Executable[Callable]) -> R:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def from_callables(
6969
name_prefix="parallel-branch-",
7070
serdes=config.serdes,
7171
summary_generator=config.summary_generator,
72+
item_serdes=config.item_serdes,
7273
)
7374

7475
def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301

tests/operation/map_test.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from unittest.mock import Mock, patch
44

5+
import pytest
6+
57
# Mock the executor.execute method
68
from aws_durable_execution_sdk_python.concurrency import (
79
BatchItem,
@@ -15,6 +17,7 @@
1517
ItemBatcher,
1618
MapConfig,
1719
)
20+
from aws_durable_execution_sdk_python.context import DurableContext
1821
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1922
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
2023
from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler
@@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id):
750753
# Verify replay was called, execute was not
751754
mock_replay.assert_called_once()
752755
mock_execute.assert_not_called()
756+
757+
758+
@pytest.mark.parametrize(
759+
("item_serdes", "batch_serdes"),
760+
[
761+
(Mock(), Mock()),
762+
(None, Mock()),
763+
(Mock(), None),
764+
],
765+
)
766+
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
767+
def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes):
768+
"""Test map serializes items with item_serdes or fallback."""
769+
mock_serialize.return_value = '"serialized"'
770+
771+
parent_checkpoint = Mock()
772+
parent_checkpoint.is_succeeded.return_value = False
773+
parent_checkpoint.is_failed.return_value = False
774+
parent_checkpoint.is_started.return_value = False
775+
parent_checkpoint.is_existent.return_value = True
776+
parent_checkpoint.is_replay_children.return_value = False
777+
778+
child_checkpoint = Mock()
779+
child_checkpoint.is_succeeded.return_value = False
780+
child_checkpoint.is_failed.return_value = False
781+
child_checkpoint.is_started.return_value = False
782+
child_checkpoint.is_existent.return_value = True
783+
child_checkpoint.is_replay_children.return_value = False
784+
785+
def get_checkpoint(op_id):
786+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
787+
788+
mock_state = Mock()
789+
mock_state.durable_execution_arn = "arn:test"
790+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
791+
mock_state.create_checkpoint = Mock()
792+
793+
context_map = {}
794+
795+
def create_id(self, i):
796+
ctx_id = id(self)
797+
if ctx_id not in context_map:
798+
context_map[ctx_id] = []
799+
context_map[ctx_id].append(i)
800+
return (
801+
"parent"
802+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
803+
else f"child-{i}"
804+
)
805+
806+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
807+
context = DurableContext(state=mock_state)
808+
context.map(
809+
["a", "b"],
810+
lambda ctx, item, idx, items: item,
811+
config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
812+
)
813+
814+
expected = item_serdes or batch_serdes
815+
assert mock_serialize.call_args_list[0][1]["serdes"] is expected
816+
assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0"
817+
assert mock_serialize.call_args_list[1][1]["serdes"] is expected
818+
assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1"
819+
assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes
820+
assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent"
821+
822+
823+
@pytest.mark.parametrize(
824+
("item_serdes", "batch_serdes"),
825+
[
826+
(Mock(), Mock()),
827+
(None, Mock()),
828+
(Mock(), None),
829+
],
830+
)
831+
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
832+
def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
833+
"""Test map deserializes items with item_serdes or fallback."""
834+
mock_deserialize.return_value = "deserialized"
835+
836+
parent_checkpoint = Mock()
837+
parent_checkpoint.is_succeeded.return_value = False
838+
parent_checkpoint.is_failed.return_value = False
839+
parent_checkpoint.is_existent.return_value = False
840+
841+
child_checkpoint = Mock()
842+
child_checkpoint.is_succeeded.return_value = True
843+
child_checkpoint.is_failed.return_value = False
844+
child_checkpoint.is_replay_children.return_value = False
845+
child_checkpoint.result = '"cached"'
846+
847+
def get_checkpoint(op_id):
848+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
849+
850+
mock_state = Mock()
851+
mock_state.durable_execution_arn = "arn:test"
852+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
853+
mock_state.create_checkpoint = Mock()
854+
855+
context_map = {}
856+
857+
def create_id(self, i):
858+
ctx_id = id(self)
859+
if ctx_id not in context_map:
860+
context_map[ctx_id] = []
861+
context_map[ctx_id].append(i)
862+
return (
863+
"parent"
864+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
865+
else f"child-{i}"
866+
)
867+
868+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
869+
context = DurableContext(state=mock_state)
870+
context.map(
871+
["a", "b"],
872+
lambda ctx, item, idx, items: item,
873+
config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
874+
)
875+
876+
expected = item_serdes or batch_serdes
877+
assert mock_deserialize.call_args_list[0][1]["serdes"] is expected
878+
assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0"
879+
assert mock_deserialize.call_args_list[1][1]["serdes"] is expected
880+
assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"

tests/operation/parallel_test.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Executable,
1515
)
1616
from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig
17+
from aws_durable_execution_sdk_python.context import DurableContext
1718
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1819
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1920
from aws_durable_execution_sdk_python.operation.parallel import (
@@ -734,3 +735,126 @@ def get_checkpoint_result(self, operation_id):
734735
# Verify replay was called, execute was not
735736
mock_replay.assert_called_once()
736737
mock_execute.assert_not_called()
738+
739+
740+
@pytest.mark.parametrize(
741+
("item_serdes", "batch_serdes"),
742+
[
743+
(Mock(), Mock()),
744+
(None, Mock()),
745+
(Mock(), None),
746+
],
747+
)
748+
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
749+
def test_parallel_item_serialize(mock_serialize, item_serdes, batch_serdes):
750+
"""Test parallel serializes branches with item_serdes or fallback."""
751+
mock_serialize.return_value = '"serialized"'
752+
753+
parent_checkpoint = Mock()
754+
parent_checkpoint.is_succeeded.return_value = False
755+
parent_checkpoint.is_failed.return_value = False
756+
parent_checkpoint.is_started.return_value = False
757+
parent_checkpoint.is_existent.return_value = True
758+
parent_checkpoint.is_replay_children.return_value = False
759+
760+
child_checkpoint = Mock()
761+
child_checkpoint.is_succeeded.return_value = False
762+
child_checkpoint.is_failed.return_value = False
763+
child_checkpoint.is_started.return_value = False
764+
child_checkpoint.is_existent.return_value = True
765+
child_checkpoint.is_replay_children.return_value = False
766+
767+
def get_checkpoint(op_id):
768+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
769+
770+
mock_state = Mock()
771+
mock_state.durable_execution_arn = "arn:test"
772+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
773+
mock_state.create_checkpoint = Mock()
774+
775+
context_map = {}
776+
777+
def create_id(self, i):
778+
ctx_id = id(self)
779+
if ctx_id not in context_map:
780+
context_map[ctx_id] = []
781+
context_map[ctx_id].append(i)
782+
return (
783+
"parent"
784+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
785+
else f"child-{i}"
786+
)
787+
788+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
789+
context = DurableContext(state=mock_state)
790+
context.parallel(
791+
[lambda ctx: "a", lambda ctx: "b"],
792+
config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
793+
)
794+
795+
expected = item_serdes or batch_serdes
796+
assert mock_serialize.call_args_list[0][1]["serdes"] is expected
797+
assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0"
798+
assert mock_serialize.call_args_list[1][1]["serdes"] is expected
799+
assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1"
800+
assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes
801+
assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent"
802+
803+
804+
@pytest.mark.parametrize(
805+
("item_serdes", "batch_serdes"),
806+
[
807+
(Mock(), Mock()),
808+
(None, Mock()),
809+
(Mock(), None),
810+
],
811+
)
812+
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
813+
def test_parallel_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
814+
"""Test parallel deserializes branches with item_serdes or fallback."""
815+
mock_deserialize.return_value = "deserialized"
816+
817+
parent_checkpoint = Mock()
818+
parent_checkpoint.is_succeeded.return_value = False
819+
parent_checkpoint.is_failed.return_value = False
820+
parent_checkpoint.is_existent.return_value = False
821+
822+
child_checkpoint = Mock()
823+
child_checkpoint.is_succeeded.return_value = True
824+
child_checkpoint.is_failed.return_value = False
825+
child_checkpoint.is_replay_children.return_value = False
826+
child_checkpoint.result = '"cached"'
827+
828+
def get_checkpoint(op_id):
829+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
830+
831+
mock_state = Mock()
832+
mock_state.durable_execution_arn = "arn:test"
833+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
834+
mock_state.create_checkpoint = Mock()
835+
836+
context_map = {}
837+
838+
def create_id(self, i):
839+
ctx_id = id(self)
840+
if ctx_id not in context_map:
841+
context_map[ctx_id] = []
842+
context_map[ctx_id].append(i)
843+
return (
844+
"parent"
845+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
846+
else f"child-{i}"
847+
)
848+
849+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
850+
context = DurableContext(state=mock_state)
851+
context.parallel(
852+
[lambda ctx: "a", lambda ctx: "b"],
853+
config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
854+
)
855+
856+
expected = item_serdes or batch_serdes
857+
assert mock_deserialize.call_args_list[0][1]["serdes"] is expected
858+
assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0"
859+
assert mock_deserialize.call_args_list[1][1]["serdes"] is expected
860+
assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"

0 commit comments

Comments
 (0)