@@ -308,10 +308,10 @@ def test_invoke_handler_no_config():
308308
309309 # Verify default config was used
310310 operation_update = mock_state .create_checkpoint .call_args [1 ]["operation_update" ]
311- assert (
312- operation_update . to_dict ()[ "ChainedInvokeOptions" ][ "FunctionName" ]
313- == "test_function"
314- )
311+ chained_invoke_options = operation_update . to_dict ()[ "ChainedInvokeOptions" ]
312+ assert chained_invoke_options [ "FunctionName" ] == "test_function"
313+ # tenant_id should be None when not specified
314+ assert "TenantId" not in chained_invoke_options
315315
316316
317317def test_invoke_handler_custom_serdes ():
@@ -533,3 +533,82 @@ def test_invoke_handler_suspend_does_not_raise(mock_suspend):
533533 )
534534
535535 mock_suspend .assert_called_once ()
536+
537+
538+ def test_invoke_handler_with_tenant_id ():
539+ """Test invoke_handler passes tenant_id to checkpoint."""
540+ mock_state = Mock (spec = ExecutionState )
541+ mock_state .durable_execution_arn = "test_arn"
542+ mock_state .get_checkpoint_result .return_value = (
543+ CheckpointedResult .create_not_found ()
544+ )
545+
546+ config = InvokeConfig (tenant_id = "test-tenant-123" )
547+
548+ with pytest .raises (SuspendExecution ):
549+ invoke_handler (
550+ function_name = "test_function" ,
551+ payload = "test_input" ,
552+ state = mock_state ,
553+ operation_identifier = OperationIdentifier ("invoke1" , None , None ),
554+ config = config ,
555+ )
556+
557+ # Verify checkpoint was called with tenant_id
558+ mock_state .create_checkpoint .assert_called_once ()
559+ operation_update = mock_state .create_checkpoint .call_args [1 ]["operation_update" ]
560+ chained_invoke_options = operation_update .to_dict ()["ChainedInvokeOptions" ]
561+ assert chained_invoke_options ["FunctionName" ] == "test_function"
562+ assert chained_invoke_options ["TenantId" ] == "test-tenant-123"
563+
564+
565+ def test_invoke_handler_without_tenant_id ():
566+ """Test invoke_handler without tenant_id doesn't include it in checkpoint."""
567+ mock_state = Mock (spec = ExecutionState )
568+ mock_state .durable_execution_arn = "test_arn"
569+ mock_state .get_checkpoint_result .return_value = (
570+ CheckpointedResult .create_not_found ()
571+ )
572+
573+ config = InvokeConfig (tenant_id = None )
574+
575+ with pytest .raises (SuspendExecution ):
576+ invoke_handler (
577+ function_name = "test_function" ,
578+ payload = "test_input" ,
579+ state = mock_state ,
580+ operation_identifier = OperationIdentifier ("invoke1" , None , None ),
581+ config = config ,
582+ )
583+
584+ # Verify checkpoint was called without tenant_id
585+ mock_state .create_checkpoint .assert_called_once ()
586+ operation_update = mock_state .create_checkpoint .call_args [1 ]["operation_update" ]
587+ chained_invoke_options = operation_update .to_dict ()["ChainedInvokeOptions" ]
588+ assert chained_invoke_options ["FunctionName" ] == "test_function"
589+ assert "TenantId" not in chained_invoke_options
590+
591+
592+ def test_invoke_handler_default_config_no_tenant_id ():
593+ """Test invoke_handler with default config has no tenant_id."""
594+ mock_state = Mock (spec = ExecutionState )
595+ mock_state .durable_execution_arn = "test_arn"
596+ mock_state .get_checkpoint_result .return_value = (
597+ CheckpointedResult .create_not_found ()
598+ )
599+
600+ with pytest .raises (SuspendExecution ):
601+ invoke_handler (
602+ function_name = "test_function" ,
603+ payload = "test_input" ,
604+ state = mock_state ,
605+ operation_identifier = OperationIdentifier ("invoke1" , None , None ),
606+ config = None ,
607+ )
608+
609+ # Verify checkpoint was called without tenant_id
610+ mock_state .create_checkpoint .assert_called_once ()
611+ operation_update = mock_state .create_checkpoint .call_args [1 ]["operation_update" ]
612+ chained_invoke_options = operation_update .to_dict ()["ChainedInvokeOptions" ]
613+ assert chained_invoke_options ["FunctionName" ] == "test_function"
614+ assert "TenantId" not in chained_invoke_options
0 commit comments