Skip to content

Commit bcd5348

Browse files
authored
feature: add eval custom lambda arn to hyperparameters (#5272)
1 parent fd566bd commit bcd5348

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,13 @@ def _get_args_from_nova_recipe(
305305
)
306306
args["hyperparameters"]["kms_key"] = kms_key
307307

308+
# Handle eval custom lambda configuration
309+
if recipe.get("evaluation", {}):
310+
processor = recipe.get("processor", {})
311+
lambda_arn = processor.get("lambda_arn", "")
312+
if lambda_arn:
313+
args["hyperparameters"]["lambda_arn"] = lambda_arn
314+
308315
_register_custom_resolvers()
309316

310317
# Resolve Final Recipe

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,35 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case):
446446
_get_args_from_nova_recipe(
447447
recipe=recipe, compute=test_case["compute"], role=test_case.get("role")
448448
)
449+
450+
451+
@pytest.mark.parametrize(
452+
"test_case",
453+
[
454+
{
455+
"recipe": {
456+
"evaluation": {"task:": "gen_qa", "strategy": "gen_qa", "metric": "all"},
457+
"processor": {
458+
"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction"
459+
},
460+
},
461+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
462+
"role": "arn:aws:iam::123456789012:role/SageMakerRole",
463+
"expected_args": {
464+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
465+
"hyperparameters": {
466+
"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction",
467+
},
468+
"training_image": None,
469+
"source_code": None,
470+
"distributed": None,
471+
},
472+
},
473+
],
474+
)
475+
def test_get_args_from_nova_recipe_with_evaluation(test_case):
476+
recipe = OmegaConf.create(test_case["recipe"])
477+
args, _ = _get_args_from_nova_recipe(
478+
recipe=recipe, compute=test_case["compute"], role=test_case["role"]
479+
)
480+
assert args == test_case["expected_args"]

0 commit comments

Comments
 (0)