Skip to content

Commit b06d91d

Browse files
authored
Add nova custom lambda in hyperparameter from estimator (#5282)
* Add nova custom lambda in hyperparameter from estimator * Add nova custom lambda in hyperparameter from estimator
1 parent 7b865f5 commit b06d91d

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _get_args_from_nova_recipe(
310310
processor = recipe.get("processor", {})
311311
lambda_arn = processor.get("lambda_arn", "")
312312
if lambda_arn:
313-
args["hyperparameters"]["lambda_arn"] = lambda_arn
313+
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
314314

315315
_register_custom_resolvers()
316316

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,13 @@ def _setup_for_nova_recipe(
12241224
)
12251225
args["hyperparameters"]["kms_key"] = kms_key
12261226

1227+
# Handle eval custom lambda configuration
1228+
if recipe.get("evaluation", {}):
1229+
processor = recipe.get("processor", {})
1230+
lambda_arn = processor.get("lambda_arn", "")
1231+
if lambda_arn:
1232+
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
1233+
12271234
# Resolve and save the final recipe
12281235
self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"])
12291236

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case):
463463
"expected_args": {
464464
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
465465
"hyperparameters": {
466-
"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction",
466+
"eval_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction",
467467
},
468468
"training_image": None,
469469
"source_code": None,

tests/unit/test_pytorch_nova.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,50 @@ def test_framework_hyperparameters_nova():
684684
assert hyperparams["bool_param"] == "true"
685685

686686

687+
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
688+
def test_setup_for_nova_recipe_with_evaluation_lambda(mock_resolve_save, sagemaker_session):
689+
"""Test that _setup_for_nova_recipe correctly handles evaluation lambda configuration."""
690+
# Create a mock recipe with evaluation and processor config
691+
recipe = OmegaConf.create(
692+
{
693+
"run": {
694+
"model_type": "amazon.nova.foobar3",
695+
"model_name_or_path": "foobar/foobar-3-8b",
696+
"replicas": 1,
697+
},
698+
"evaluation": {"task:": "gen_qa", "strategy": "gen_qa", "metric": "all"},
699+
"processor": {
700+
"lambda_arn": "arn:aws:lambda:us-west-2:123456789012:function:eval-function"
701+
},
702+
}
703+
)
704+
705+
with patch(
706+
"sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
707+
):
708+
mock_resolve_save.return_value = recipe
709+
710+
pytorch = PyTorch(
711+
training_recipe="nova_recipe",
712+
role=ROLE,
713+
sagemaker_session=sagemaker_session,
714+
instance_count=INSTANCE_COUNT,
715+
instance_type=INSTANCE_TYPE_GPU,
716+
image_uri=IMAGE_URI,
717+
framework_version="1.13.1",
718+
py_version="py3",
719+
)
720+
721+
# Check that the Nova recipe was correctly identified
722+
assert pytorch.is_nova_recipe is True
723+
724+
# Verify that eval_lambda_arn hyperparameter was set correctly
725+
assert (
726+
pytorch._hyperparameters.get("eval_lambda_arn")
727+
== "arn:aws:lambda:us-west-2:123456789012:function:eval-function"
728+
)
729+
730+
687731
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
688732
def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_session):
689733
"""Test that _setup_for_nova_recipe correctly handles distillation configurations."""

0 commit comments

Comments
 (0)