Skip to content

Commit dbf7b72

Browse files
authored
Allow for sequences into exp.start(), and unpack iterables (#712)
The Experiment.start() method is now able to iterate over sequences of jobs and unpack any nested sequence to be deployed. [ committed by @juliaputko ] [ reviewed by @amandarichardsonn @MattToast @mellis13 ]
1 parent 4d9ab27 commit dbf7b72

File tree

4 files changed

+163
-6
lines changed

4 files changed

+163
-6
lines changed

smartsim/_core/utils/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,39 @@
4343
from datetime import datetime
4444
from shutil import which
4545

46+
from typing_extensions import TypeAlias
47+
4648
if t.TYPE_CHECKING:
4749
from types import FrameType
4850

4951
from typing_extensions import TypeVarTuple, Unpack
5052

53+
from smartsim.launchable.job import Job
54+
5155
_Ts = TypeVarTuple("_Ts")
5256

5357

5458
_T = t.TypeVar("_T")
5559
_HashableT = t.TypeVar("_HashableT", bound=t.Hashable)
5660
_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]
5761

62+
_NestedJobSequenceType: TypeAlias = "t.Sequence[Job | _NestedJobSequenceType]"
63+
64+
65+
def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]:
66+
"""Unpack any iterable input in order to obtain a
67+
single sequence of values
68+
69+
:param value: Sequence containing elements of type Job or other
70+
sequences that are also of type _NestedJobSequenceType
71+
:return: flattened list of Jobs"""
72+
73+
for item in value:
74+
if isinstance(item, t.Iterable):
75+
yield from unpack(item)
76+
else:
77+
yield item
78+
5879

5980
def check_name(name: str) -> None:
6081
"""

smartsim/experiment.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,18 @@ def __init__(self, name: str, exp_path: str | None = None):
153153
experiment
154154
"""
155155

156-
def start(self, *jobs: Job) -> tuple[LaunchedJobID, ...]:
156+
def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
157157
"""Execute a collection of `Job` instances.
158158
159159
:param jobs: A collection of other job instances to start
160160
:returns: A sequence of ids with order corresponding to the sequence of
161161
jobs that can be used to query or alter the status of that
162162
particular execution of the job.
163163
"""
164-
# Create the run id
164+
jobs_ = list(_helpers.unpack(jobs))
165165
run_id = datetime.datetime.now().replace(microsecond=0).isoformat()
166-
# Generate the root path
167166
root = pathlib.Path(self.exp_path, run_id)
168-
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs)
167+
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_)
169168

170169
def _dispatch(
171170
self,

tests/_legacy/test_helpers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,32 @@
3030
import pytest
3131

3232
from smartsim._core.utils import helpers
33-
from smartsim._core.utils.helpers import cat_arg_and_value
33+
from smartsim._core.utils.helpers import cat_arg_and_value, unpack
34+
from smartsim.entity.application import Application
35+
from smartsim.launchable.job import Job
36+
from smartsim.settings.launch_settings import LaunchSettings
3437

3538
# The tests in this file belong to the group_a group
3639
pytestmark = pytest.mark.group_a
3740

3841

42+
def test_unpack_iterates_over_nested_jobs_in_expected_order(wlmutils):
43+
launch_settings = LaunchSettings(wlmutils.get_test_launcher())
44+
app = Application("app_name", exe="python")
45+
job_1 = Job(app, launch_settings)
46+
job_2 = Job(app, launch_settings)
47+
job_3 = Job(app, launch_settings)
48+
job_4 = Job(app, launch_settings)
49+
job_5 = Job(app, launch_settings)
50+
51+
assert (
52+
[job_1, job_2, job_3, job_4, job_5]
53+
== list(unpack([job_1, [job_2, job_3], job_4, [job_5]]))
54+
== list(unpack([job_1, job_2, [job_3, job_4], job_5]))
55+
== list(unpack([job_1, [job_2, [job_3, job_4], job_5]]))
56+
)
57+
58+
3959
def test_double_dash_concat():
4060
result = cat_arg_and_value("--foo", "FOO")
4161
assert result == "--foo=FOO"

tests/test_experiment.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,33 @@
3434
import time
3535
import typing as t
3636
import uuid
37+
from os import path as osp
3738

3839
import pytest
3940

4041
from smartsim._core import dispatch
41-
from smartsim._core.control.interval import SynchronousTimeInterval
4242
from smartsim._core.control.launch_history import LaunchHistory
4343
from smartsim._core.generation.generator import Job_Path
4444
from smartsim._core.utils.launcher import LauncherProtocol, create_job_id
45+
from smartsim.builders.ensemble import Ensemble
4546
from smartsim.entity import entity
47+
from smartsim.entity.application import Application
4648
from smartsim.error import errors
4749
from smartsim.experiment import Experiment
4850
from smartsim.launchable import job
4951
from smartsim.settings import launch_settings
5052
from smartsim.settings.arguments import launch_arguments
5153
from smartsim.status import InvalidJobStatus, JobStatus
54+
from smartsim.types import LaunchedJobID
5255

5356
pytestmark = pytest.mark.group_a
5457

58+
_ID_GENERATOR = (str(i) for i in itertools.count())
59+
60+
61+
def random_id():
62+
return next(_ID_GENERATOR)
63+
5564

5665
@pytest.fixture
5766
def experiment(monkeypatch, test_dir, dispatcher):
@@ -614,3 +623,111 @@ def test_experiment_stop_does_not_raise_on_unknown_job_id(
614623
assert stat == InvalidJobStatus.NEVER_STARTED
615624
after_cancel = exp.get_status(*all_known_ids)
616625
assert before_cancel == after_cancel
626+
627+
628+
@pytest.mark.parametrize(
629+
"job_list",
630+
(
631+
pytest.param(
632+
[
633+
(
634+
job.Job(
635+
Application(
636+
"test_name",
637+
exe="echo",
638+
exe_args=["spam", "eggs"],
639+
),
640+
launch_settings.LaunchSettings("local"),
641+
),
642+
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
643+
launch_settings.LaunchSettings("local")
644+
),
645+
)
646+
],
647+
id="(job1, (job2, job_3))",
648+
),
649+
pytest.param(
650+
[
651+
(
652+
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
653+
launch_settings.LaunchSettings("local")
654+
),
655+
(
656+
job.Job(
657+
Application(
658+
"test_name",
659+
exe="echo",
660+
exe_args=["spam", "eggs"],
661+
),
662+
launch_settings.LaunchSettings("local"),
663+
),
664+
job.Job(
665+
Application(
666+
"test_name_2",
667+
exe="echo",
668+
exe_args=["spam", "eggs"],
669+
),
670+
launch_settings.LaunchSettings("local"),
671+
),
672+
),
673+
)
674+
],
675+
id="((job1, job2), (job3, job4))",
676+
),
677+
pytest.param(
678+
[
679+
(
680+
job.Job(
681+
Application(
682+
"test_name",
683+
exe="echo",
684+
exe_args=["spam", "eggs"],
685+
),
686+
launch_settings.LaunchSettings("local"),
687+
),
688+
)
689+
],
690+
id="(job,)",
691+
),
692+
pytest.param(
693+
[
694+
[
695+
job.Job(
696+
Application(
697+
"test_name",
698+
exe="echo",
699+
exe_args=["spam", "eggs"],
700+
),
701+
launch_settings.LaunchSettings("local"),
702+
),
703+
(
704+
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
705+
launch_settings.LaunchSettings("local")
706+
),
707+
job.Job(
708+
Application(
709+
"test_name_2",
710+
exe="echo",
711+
exe_args=["spam", "eggs"],
712+
),
713+
launch_settings.LaunchSettings("local"),
714+
),
715+
),
716+
]
717+
],
718+
id="[job_1, ((job_2, job_3), job_4)]",
719+
),
720+
),
721+
)
722+
def test_start_unpack(
723+
test_dir: str, wlmutils, monkeypatch: pytest.MonkeyPatch, job_list: job.Job
724+
):
725+
"""Test unpacking a sequences of jobs"""
726+
727+
monkeypatch.setattr(
728+
"smartsim._core.dispatch._LauncherAdapter.start",
729+
lambda launch, exe, job_execution_path, env, out, err: random_id(),
730+
)
731+
732+
exp = Experiment(name="exp_name", exp_path=test_dir)
733+
exp.start(*job_list)

0 commit comments

Comments
 (0)