File tree Expand file tree Collapse file tree 2 files changed +39
-0
lines changed
src/sagemaker/modules/train/sm_recipes
tests/unit/sagemaker/modules/train/sm_recipes Expand file tree Collapse file tree 2 files changed +39
-0
lines changed Original file line number Diff line number Diff line change @@ -305,6 +305,13 @@ def _get_args_from_nova_recipe(
305305 )
306306 args ["hyperparameters" ]["kms_key" ] = kms_key
307307
308+ # Handle eval custom lambda configuration
309+ if recipe .get ("evaluation" , {}):
310+ processor = recipe .get ("processor" , {})
311+ lambda_arn = processor .get ("lambda_arn" , "" )
312+ if lambda_arn :
313+ args ["hyperparameters" ]["lambda_arn" ] = lambda_arn
314+
308315 _register_custom_resolvers ()
309316
310317 # Resolve Final Recipe
Original file line number Diff line number Diff line change @@ -446,3 +446,35 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case):
446446 _get_args_from_nova_recipe (
447447 recipe = recipe , compute = test_case ["compute" ], role = test_case .get ("role" )
448448 )
449+
450+
451+ @pytest .mark .parametrize (
452+ "test_case" ,
453+ [
454+ {
455+ "recipe" : {
456+ "evaluation" : {"task:" : "gen_qa" , "strategy" : "gen_qa" , "metric" : "all" },
457+ "processor" : {
458+ "lambda_arn" : "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction"
459+ },
460+ },
461+ "compute" : Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 ),
462+ "role" : "arn:aws:iam::123456789012:role/SageMakerRole" ,
463+ "expected_args" : {
464+ "compute" : Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 ),
465+ "hyperparameters" : {
466+ "lambda_arn" : "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction" ,
467+ },
468+ "training_image" : None ,
469+ "source_code" : None ,
470+ "distributed" : None ,
471+ },
472+ },
473+ ],
474+ )
475+ def test_get_args_from_nova_recipe_with_evaluation (test_case ):
476+ recipe = OmegaConf .create (test_case ["recipe" ])
477+ args , _ = _get_args_from_nova_recipe (
478+ recipe = recipe , compute = test_case ["compute" ], role = test_case ["role" ]
479+ )
480+ assert args == test_case ["expected_args" ]
You can’t perform that action at this time.
0 commit comments