File tree Expand file tree Collapse file tree 2 files changed +39
-0
lines changed
src/sagemaker/modules/train/sm_recipes
tests/unit/sagemaker/modules/train/sm_recipes Expand file tree Collapse file tree 2 files changed +39
-0
lines changed Original file line number Diff line number Diff line change @@ -305,6 +305,13 @@ def _get_args_from_nova_recipe(
305
305
)
306
306
args ["hyperparameters" ]["kms_key" ] = kms_key
307
307
308
+ # Handle eval custom lambda configuration
309
+ if recipe .get ("evaluation" , {}):
310
+ processor = recipe .get ("processor" , {})
311
+ lambda_arn = processor .get ("lambda_arn" , "" )
312
+ if lambda_arn :
313
+ args ["hyperparameters" ]["lambda_arn" ] = lambda_arn
314
+
308
315
_register_custom_resolvers ()
309
316
310
317
# Resolve Final Recipe
Original file line number Diff line number Diff line change @@ -446,3 +446,35 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case):
446
446
_get_args_from_nova_recipe (
447
447
recipe = recipe , compute = test_case ["compute" ], role = test_case .get ("role" )
448
448
)
449
+
450
+
451
+ @pytest .mark .parametrize (
452
+ "test_case" ,
453
+ [
454
+ {
455
+ "recipe" : {
456
+ "evaluation" : {"task:" : "gen_qa" , "strategy" : "gen_qa" , "metric" : "all" },
457
+ "processor" : {
458
+ "lambda_arn" : "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction"
459
+ },
460
+ },
461
+ "compute" : Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 ),
462
+ "role" : "arn:aws:iam::123456789012:role/SageMakerRole" ,
463
+ "expected_args" : {
464
+ "compute" : Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 ),
465
+ "hyperparameters" : {
466
+ "lambda_arn" : "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction" ,
467
+ },
468
+ "training_image" : None ,
469
+ "source_code" : None ,
470
+ "distributed" : None ,
471
+ },
472
+ },
473
+ ],
474
+ )
475
+ def test_get_args_from_nova_recipe_with_evaluation (test_case ):
476
+ recipe = OmegaConf .create (test_case ["recipe" ])
477
+ args , _ = _get_args_from_nova_recipe (
478
+ recipe = recipe , compute = test_case ["compute" ], role = test_case ["role" ]
479
+ )
480
+ assert args == test_case ["expected_args" ]
You can’t perform that action at this time.
0 commit comments