Skip to content

Commit e848c76

Browse files
author
Astraea Quinn S
authored
Add tenant_id support for invoke operations (#195)
* Add tenant_id support for invoke operations
1 parent 00b195d commit e848c76

File tree

7 files changed

+183
-9
lines changed

7 files changed

+183
-9
lines changed

src/aws_durable_execution_sdk_python/config.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,34 @@ class MapConfig:
378378
summary_generator: SummaryGenerator | None = None
379379

380380

381-
@dataclass
381+
@dataclass(frozen=True)
382382
class InvokeConfig(Generic[P, R]):
383+
"""
384+
Configuration for invoke operations.
385+
386+
This class configures how function invocations are executed, including
387+
timeout behavior, serialization, and tenant isolation.
388+
389+
Args:
390+
timeout: Maximum duration to wait for the invoked function to complete.
391+
Default is no timeout. Use this to prevent long-running invocations
392+
from blocking execution indefinitely.
393+
394+
serdes_payload: Custom serialization/deserialization for the payload
395+
sent to the invoked function. If None, uses default JSON serialization.
396+
397+
serdes_result: Custom serialization/deserialization for the result
398+
returned from the invoked function. If None, uses default JSON serialization.
399+
400+
tenant_id: Optional tenant identifier for multi-tenant isolation.
401+
If provided, the invocation will be scoped to this tenant.
402+
"""
403+
383404
# retry_strategy: Callable[[Exception, int], RetryDecision] | None = None
384405
timeout: Duration = field(default_factory=Duration)
385406
serdes_payload: SerDes[P] | None = None
386407
serdes_result: SerDes[R] | None = None
408+
tenant_id: str | None = None
387409

388410
@property
389411
def timeout_seconds(self) -> int:

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,17 +303,22 @@ class ChainedInvokeOptions:
303303
"""
304304

305305
function_name: str
306+
tenant_id: str | None = None
306307

307308
@classmethod
308309
def from_dict(cls, data: MutableMapping[str, Any]) -> ChainedInvokeOptions:
309310
return cls(
310311
function_name=data["FunctionName"],
312+
tenant_id=data.get("TenantId"),
311313
)
312314

313315
def to_dict(self) -> MutableMapping[str, Any]:
314316
result: MutableMapping[str, Any] = {
315317
"FunctionName": self.function_name,
316318
}
319+
if self.tenant_id is not None:
320+
result["TenantId"] = self.tenant_id
321+
317322
return result
318323

319324

src/aws_durable_execution_sdk_python/operation/invoke.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def invoke_handler(
4040

4141
if not config:
4242
config = InvokeConfig[P, R]()
43+
tenant_id = config.tenant_id
4344

4445
# Check if we have existing step data
4546
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
@@ -87,7 +88,10 @@ def invoke_handler(
8788
start_operation: OperationUpdate = OperationUpdate.create_invoke_start(
8889
identifier=operation_identifier,
8990
payload=serialized_payload,
90-
chained_invoke_options=ChainedInvokeOptions(function_name=function_name),
91+
chained_invoke_options=ChainedInvokeOptions(
92+
function_name=function_name,
93+
tenant_id=tenant_id,
94+
),
9195
)
9296

9397
# Checkpoint invoke START with blocking (is_sync=True, default).

tests/config_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ChildConfig,
1111
CompletionConfig,
1212
Duration,
13+
InvokeConfig,
1314
ItemBatcher,
1415
ItemsPerBatchUnit,
1516
MapConfig,
@@ -275,3 +276,16 @@ def test_step_future_without_name():
275276

276277
result = step_future.result()
277278
assert result == 42
279+
280+
281+
def test_invoke_config_defaults():
282+
"""Test InvokeConfig defaults."""
283+
config = InvokeConfig()
284+
assert config.tenant_id is None
285+
assert config.timeout_seconds == 0
286+
287+
288+
def test_invoke_config_with_tenant_id():
289+
"""Test InvokeConfig with explicit tenant_id."""
290+
config = InvokeConfig(tenant_id="test-tenant")
291+
assert config.tenant_id == "test-tenant"

tests/context_test.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,9 +662,11 @@ def test_invoke_with_custom_serdes(mock_handler):
662662
"arn:aws:durable:us-east-1:123456789012:execution/test"
663663
)
664664

665+
payload_serdes = CustomDictSerDes()
666+
result_serdes = CustomDictSerDes()
665667
config = InvokeConfig[dict, dict](
666-
serdes_payload=CustomDictSerDes(),
667-
serdes_result=CustomDictSerDes(),
668+
serdes_payload=payload_serdes,
669+
serdes_result=result_serdes,
668670
timeout=Duration.from_minutes(1),
669671
)
670672

@@ -1680,3 +1682,41 @@ def test_operation_id_generation_unique():
16801682

16811683
for i in range(len(ids) - 1):
16821684
assert ids[i] != ids[i + 1]
1685+
1686+
1687+
@patch("aws_durable_execution_sdk_python.context.invoke_handler")
1688+
def test_invoke_with_explicit_tenant_id(mock_handler):
1689+
"""Test invoke with explicit tenant_id in config."""
1690+
mock_handler.return_value = "result"
1691+
mock_state = Mock(spec=ExecutionState)
1692+
mock_state.durable_execution_arn = (
1693+
"arn:aws:durable:us-east-1:123456789012:execution/test"
1694+
)
1695+
1696+
config = InvokeConfig(tenant_id="explicit-tenant")
1697+
context = DurableContext(state=mock_state)
1698+
1699+
result = context.invoke("test_function", "payload", config=config)
1700+
1701+
assert result == "result"
1702+
call_args = mock_handler.call_args[1]
1703+
assert call_args["config"].tenant_id == "explicit-tenant"
1704+
1705+
1706+
@patch("aws_durable_execution_sdk_python.context.invoke_handler")
1707+
def test_invoke_without_tenant_id_defaults_to_none(mock_handler):
1708+
"""Test invoke without tenant_id defaults to None."""
1709+
mock_handler.return_value = "result"
1710+
mock_state = Mock(spec=ExecutionState)
1711+
mock_state.durable_execution_arn = (
1712+
"arn:aws:durable:us-east-1:123456789012:execution/test"
1713+
)
1714+
1715+
context = DurableContext(state=mock_state)
1716+
1717+
result = context.invoke("test_function", "payload")
1718+
1719+
assert result == "result"
1720+
# Config should be None when not provided
1721+
call_args = mock_handler.call_args[1]
1722+
assert call_args["config"] is None

tests/lambda_service_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,16 +400,26 @@ def test_callback_options_from_dict_partial():
400400

401401
def test_invoke_options_from_dict():
402402
"""Test ChainedInvokeOptions.from_dict method."""
403-
data = {"FunctionName": "test-function", "TimeoutSeconds": 120}
403+
data = {"FunctionName": "test-function", "TenantId": "test-tenant"}
404404
options = ChainedInvokeOptions.from_dict(data)
405405
assert options.function_name == "test-function"
406+
assert options.tenant_id == "test-tenant"
406407

407408

408409
def test_invoke_options_from_dict_required_only():
409410
"""Test ChainedInvokeOptions.from_dict with only required field."""
410411
data = {"FunctionName": "test-function"}
411412
options = ChainedInvokeOptions.from_dict(data)
412413
assert options.function_name == "test-function"
414+
assert options.tenant_id is None
415+
416+
417+
def test_invoke_options_from_dict_with_none_tenant():
418+
"""Test ChainedInvokeOptions.from_dict with explicit None tenant_id."""
419+
data = {"FunctionName": "test-function", "TenantId": None}
420+
options = ChainedInvokeOptions.from_dict(data)
421+
assert options.function_name == "test-function"
422+
assert options.tenant_id is None
413423

414424

415425
def test_context_options_from_dict():

tests/operation/invoke_test.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,10 @@ def test_invoke_handler_no_config():
308308

309309
# Verify default config was used
310310
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
311-
assert (
312-
operation_update.to_dict()["ChainedInvokeOptions"]["FunctionName"]
313-
== "test_function"
314-
)
311+
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
312+
assert chained_invoke_options["FunctionName"] == "test_function"
313+
# tenant_id should be None when not specified
314+
assert "TenantId" not in chained_invoke_options
315315

316316

317317
def test_invoke_handler_custom_serdes():
@@ -533,3 +533,82 @@ def test_invoke_handler_suspend_does_not_raise(mock_suspend):
533533
)
534534

535535
mock_suspend.assert_called_once()
536+
537+
538+
def test_invoke_handler_with_tenant_id():
539+
"""Test invoke_handler passes tenant_id to checkpoint."""
540+
mock_state = Mock(spec=ExecutionState)
541+
mock_state.durable_execution_arn = "test_arn"
542+
mock_state.get_checkpoint_result.return_value = (
543+
CheckpointedResult.create_not_found()
544+
)
545+
546+
config = InvokeConfig(tenant_id="test-tenant-123")
547+
548+
with pytest.raises(SuspendExecution):
549+
invoke_handler(
550+
function_name="test_function",
551+
payload="test_input",
552+
state=mock_state,
553+
operation_identifier=OperationIdentifier("invoke1", None, None),
554+
config=config,
555+
)
556+
557+
# Verify checkpoint was called with tenant_id
558+
mock_state.create_checkpoint.assert_called_once()
559+
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
560+
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
561+
assert chained_invoke_options["FunctionName"] == "test_function"
562+
assert chained_invoke_options["TenantId"] == "test-tenant-123"
563+
564+
565+
def test_invoke_handler_without_tenant_id():
566+
"""Test invoke_handler without tenant_id doesn't include it in checkpoint."""
567+
mock_state = Mock(spec=ExecutionState)
568+
mock_state.durable_execution_arn = "test_arn"
569+
mock_state.get_checkpoint_result.return_value = (
570+
CheckpointedResult.create_not_found()
571+
)
572+
573+
config = InvokeConfig(tenant_id=None)
574+
575+
with pytest.raises(SuspendExecution):
576+
invoke_handler(
577+
function_name="test_function",
578+
payload="test_input",
579+
state=mock_state,
580+
operation_identifier=OperationIdentifier("invoke1", None, None),
581+
config=config,
582+
)
583+
584+
# Verify checkpoint was called without tenant_id
585+
mock_state.create_checkpoint.assert_called_once()
586+
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
587+
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
588+
assert chained_invoke_options["FunctionName"] == "test_function"
589+
assert "TenantId" not in chained_invoke_options
590+
591+
592+
def test_invoke_handler_default_config_no_tenant_id():
593+
"""Test invoke_handler with default config has no tenant_id."""
594+
mock_state = Mock(spec=ExecutionState)
595+
mock_state.durable_execution_arn = "test_arn"
596+
mock_state.get_checkpoint_result.return_value = (
597+
CheckpointedResult.create_not_found()
598+
)
599+
600+
with pytest.raises(SuspendExecution):
601+
invoke_handler(
602+
function_name="test_function",
603+
payload="test_input",
604+
state=mock_state,
605+
operation_identifier=OperationIdentifier("invoke1", None, None),
606+
config=None,
607+
)
608+
609+
# Verify checkpoint was called without tenant_id
610+
mock_state.create_checkpoint.assert_called_once()
611+
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
612+
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
613+
assert chained_invoke_options["FunctionName"] == "test_function"
614+
assert "TenantId" not in chained_invoke_options

0 commit comments

Comments
 (0)