|
2 | 2 |
|
3 | 3 | from unittest.mock import Mock, patch |
4 | 4 |
|
| 5 | +import pytest |
| 6 | + |
5 | 7 | # Mock the executor.execute method |
6 | 8 | from aws_durable_execution_sdk_python.concurrency import ( |
7 | 9 | BatchItem, |
|
15 | 17 | ItemBatcher, |
16 | 18 | MapConfig, |
17 | 19 | ) |
| 20 | +from aws_durable_execution_sdk_python.context import DurableContext |
18 | 21 | from aws_durable_execution_sdk_python.identifier import OperationIdentifier |
19 | 22 | from aws_durable_execution_sdk_python.lambda_service import OperationSubType |
20 | 23 | from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler |
@@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id): |
750 | 753 | # Verify replay was called, execute was not |
751 | 754 | mock_replay.assert_called_once() |
752 | 755 | mock_execute.assert_not_called() |
| 756 | + |
| 757 | + |
| 758 | +@pytest.mark.parametrize( |
| 759 | + ("item_serdes", "batch_serdes"), |
| 760 | + [ |
| 761 | + (Mock(), Mock()), |
| 762 | + (None, Mock()), |
| 763 | + (Mock(), None), |
| 764 | + ], |
| 765 | +) |
| 766 | +@patch("aws_durable_execution_sdk_python.operation.child.serialize") |
| 767 | +def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes): |
| 768 | + """Test map serializes items with item_serdes or fallback.""" |
| 769 | + mock_serialize.return_value = '"serialized"' |
| 770 | + |
| 771 | + parent_checkpoint = Mock() |
| 772 | + parent_checkpoint.is_succeeded.return_value = False |
| 773 | + parent_checkpoint.is_failed.return_value = False |
| 774 | + parent_checkpoint.is_started.return_value = False |
| 775 | + parent_checkpoint.is_existent.return_value = True |
| 776 | + parent_checkpoint.is_replay_children.return_value = False |
| 777 | + |
| 778 | + child_checkpoint = Mock() |
| 779 | + child_checkpoint.is_succeeded.return_value = False |
| 780 | + child_checkpoint.is_failed.return_value = False |
| 781 | + child_checkpoint.is_started.return_value = False |
| 782 | + child_checkpoint.is_existent.return_value = True |
| 783 | + child_checkpoint.is_replay_children.return_value = False |
| 784 | + |
| 785 | + def get_checkpoint(op_id): |
| 786 | + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint |
| 787 | + |
| 788 | + mock_state = Mock() |
| 789 | + mock_state.durable_execution_arn = "arn:test" |
| 790 | + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) |
| 791 | + mock_state.create_checkpoint = Mock() |
| 792 | + |
| 793 | + context_map = {} |
| 794 | + |
| 795 | + def create_id(self, i): |
| 796 | + ctx_id = id(self) |
| 797 | + if ctx_id not in context_map: |
| 798 | + context_map[ctx_id] = [] |
| 799 | + context_map[ctx_id].append(i) |
| 800 | + return ( |
| 801 | + "parent" |
| 802 | + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 |
| 803 | + else f"child-{i}" |
| 804 | + ) |
| 805 | + |
| 806 | + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): |
| 807 | + context = DurableContext(state=mock_state) |
| 808 | + context.map( |
| 809 | + ["a", "b"], |
| 810 | + lambda ctx, item, idx, items: item, |
| 811 | + config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), |
| 812 | + ) |
| 813 | + |
| 814 | + expected = item_serdes or batch_serdes |
| 815 | + assert mock_serialize.call_args_list[0][1]["serdes"] is expected |
| 816 | + assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" |
| 817 | + assert mock_serialize.call_args_list[1][1]["serdes"] is expected |
| 818 | + assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" |
| 819 | + assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes |
| 820 | + assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" |
| 821 | + |
| 822 | + |
| 823 | +@pytest.mark.parametrize( |
| 824 | + ("item_serdes", "batch_serdes"), |
| 825 | + [ |
| 826 | + (Mock(), Mock()), |
| 827 | + (None, Mock()), |
| 828 | + (Mock(), None), |
| 829 | + ], |
| 830 | +) |
| 831 | +@patch("aws_durable_execution_sdk_python.operation.child.deserialize") |
| 832 | +def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes): |
| 833 | + """Test map deserializes items with item_serdes or fallback.""" |
| 834 | + mock_deserialize.return_value = "deserialized" |
| 835 | + |
| 836 | + parent_checkpoint = Mock() |
| 837 | + parent_checkpoint.is_succeeded.return_value = False |
| 838 | + parent_checkpoint.is_failed.return_value = False |
| 839 | + parent_checkpoint.is_existent.return_value = False |
| 840 | + |
| 841 | + child_checkpoint = Mock() |
| 842 | + child_checkpoint.is_succeeded.return_value = True |
| 843 | + child_checkpoint.is_failed.return_value = False |
| 844 | + child_checkpoint.is_replay_children.return_value = False |
| 845 | + child_checkpoint.result = '"cached"' |
| 846 | + |
| 847 | + def get_checkpoint(op_id): |
| 848 | + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint |
| 849 | + |
| 850 | + mock_state = Mock() |
| 851 | + mock_state.durable_execution_arn = "arn:test" |
| 852 | + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) |
| 853 | + mock_state.create_checkpoint = Mock() |
| 854 | + |
| 855 | + context_map = {} |
| 856 | + |
| 857 | + def create_id(self, i): |
| 858 | + ctx_id = id(self) |
| 859 | + if ctx_id not in context_map: |
| 860 | + context_map[ctx_id] = [] |
| 861 | + context_map[ctx_id].append(i) |
| 862 | + return ( |
| 863 | + "parent" |
| 864 | + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 |
| 865 | + else f"child-{i}" |
| 866 | + ) |
| 867 | + |
| 868 | + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): |
| 869 | + context = DurableContext(state=mock_state) |
| 870 | + context.map( |
| 871 | + ["a", "b"], |
| 872 | + lambda ctx, item, idx, items: item, |
| 873 | + config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), |
| 874 | + ) |
| 875 | + |
| 876 | + expected = item_serdes or batch_serdes |
| 877 | + assert mock_deserialize.call_args_list[0][1]["serdes"] is expected |
| 878 | + assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" |
| 879 | + assert mock_deserialize.call_args_list[1][1]["serdes"] is expected |
| 880 | + assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1" |
0 commit comments