Skip to content

Commit 3dd0dba

Browse files
authored
Add subnet id for pipeline step if specified by user. (#1255)
2 parents 136d2fd + b5395b8 commit 3dd0dba

File tree

4 files changed

+45 additions
-26 deletions

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1751,6 +1751,7 @@ def is_multi_node_job(runtime):
17511751
return (
17521752
MULTI_NODE_JOB_SUPPORT
17531753
and isinstance(runtime, MultiNodeRuntime)
1754+
and runtime.replica
17541755
and runtime.replica > 1
17551756
)
17561757

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -365,6 +365,11 @@ def _get_node_group(self, dsc_job):
365365
dsc_job,
366366
"job_node_configuration_details.job_node_group_configuration_details_list",
367367
)
368+
if node_groups is None:
369+
node_groups = get_value(
370+
dsc_job,
371+
"job_node_configuration_details.jobNodeGroupConfigurationDetailsList",
372+
)
368373
if node_groups and len(node_groups) == 1:
369374
return node_groups[0]
370375
return None
@@ -373,6 +378,7 @@ def _get_replica(self, dsc_job, envs):
373378
node_group = self._get_node_group(dsc_job)
374379
if node_group:
375380
replica = get_value(node_group, "replicas")
381+
envs.pop(self.CONST_NODE_COUNT, None)
376382
elif not envs:
377383
replica = None
378384
elif self.CONST_WORKER_COUNT in envs:
@@ -399,7 +405,9 @@ def _extract_envs(self, dsc_job):
399405
env_attr = "job_configuration_details.environment_variables"
400406
node_group = self._get_node_group(dsc_job)
401407
if node_group:
402-
envs = get_value(node_group, env_attr)
408+
envs = get_value(node_group, env_attr) or get_value(
409+
node_group, "jobConfigurationDetails.environment_variables"
410+
)
403411
else:
404412
envs = get_value(dsc_job, env_attr)
405413
if envs:

ads/pipeline/ads_pipeline.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1728,15 +1728,19 @@ def __step_details(self, pipeline_details: Dict) -> list:
17281728

17291729
def __step_infrastructure_configuration_details(self, step) -> dict:
17301730
step_infrastructure_configuration_details = {}
1731-
step_infrastructure_configuration_details[
1732-
"blockStorageSizeInGBs"
1733-
] = step.infrastructure.block_storage_size
1734-
step_infrastructure_configuration_details[
1735-
"shapeName"
1736-
] = step.infrastructure.shape_name
1737-
step_infrastructure_configuration_details[
1738-
"shapeConfigDetails"
1739-
] = step.infrastructure.shape_config_details
1731+
step_infrastructure_configuration_details["blockStorageSizeInGBs"] = (
1732+
step.infrastructure.block_storage_size
1733+
)
1734+
step_infrastructure_configuration_details["shapeName"] = (
1735+
step.infrastructure.shape_name
1736+
)
1737+
step_infrastructure_configuration_details["shapeConfigDetails"] = (
1738+
step.infrastructure.shape_config_details
1739+
)
1740+
if getattr(step.infrastructure, "subnet_id", ""):
1741+
step_infrastructure_configuration_details["subnetId"] = (
1742+
step.infrastructure.subnet_id
1743+
)
17401744
return step_infrastructure_configuration_details
17411745

17421746
def __step_configuration_details(self, pipeline_details: Dict, step) -> dict:

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
from unittest import mock
1111

1212
from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime
13+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
14+
MULTI_NODE_JOB_SUPPORT,
15+
)
1316
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1417
PyTorchDistributedRuntimeHandler as Handler,
1518
)
@@ -133,23 +136,26 @@ def test_create_job_runs(self, patched_run, *args):
133136
self.assertIsInstance(main_run, DataScienceJobRun)
134137
self.assertEqual(main_run.id, test_ocid)
135138
kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list]
136-
self.assertEqual(
137-
kwarg_list,
138-
[
139-
{
140-
"display_name": "None-0",
141-
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
142-
},
143-
{
144-
"display_name": "None-1",
145-
"environment_variables": {
146-
"NODE_RANK": "1",
147-
"NODE_COUNT": "2",
148-
"MAIN_JOB_RUN_OCID": test_ocid,
139+
if MULTI_NODE_JOB_SUPPORT:
140+
self.assertEqual(kwarg_list, [{}])
141+
else:
142+
self.assertEqual(
143+
kwarg_list,
144+
[
145+
{
146+
"display_name": "None-0",
147+
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
149148
},
150-
},
151-
],
152-
)
149+
{
150+
"display_name": "None-1",
151+
"environment_variables": {
152+
"NODE_RANK": "1",
153+
"NODE_COUNT": "2",
154+
"MAIN_JOB_RUN_OCID": test_ocid,
155+
},
156+
},
157+
],
158+
)
153159

154160
@mock.patch.dict(
155161
os.environ, {utils.CONST_ENV_INPUT_MAPPINGS: json.dumps({INPUT_SRC: INPUT_DST})}

0 commit comments

Comments
 (0)