Skip to content

Commit b2e53b3

Browse files
committed
Impl of record, convert exp to use records
1 parent 2cbd3be commit b2e53b3

File tree

4 files changed

+217
-94
lines changed

4 files changed

+217
-94
lines changed

smartsim/experiment.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory
4343
from smartsim._core.utils import helpers as _helpers
4444
from smartsim.error import errors
45-
from smartsim.launchable.job import Job
45+
from smartsim.launchable.job import Job, Record
4646
from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus
4747

4848
from ._core import Generator, Manifest
@@ -52,7 +52,6 @@
5252
from .log import ctx_exp_path, get_logger, method_contextualizer
5353

5454
if t.TYPE_CHECKING:
55-
from smartsim.launchable.job import Job
5655
from smartsim.types import LaunchedJobID
5756

5857
logger = get_logger(__name__)
@@ -158,7 +157,7 @@ def __init__(self, name: str, exp_path: str | None = None):
158157
experiment
159158
"""
160159

161-
def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
160+
def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[Record, ...]:
162161
"""Execute a collection of `Job` instances.
163162
164163
:param jobs: A collection of other job instances to start
@@ -172,12 +171,11 @@ def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
172171
if not jobs:
173172
raise ValueError("No jobs provided to start")
174173

175-
# Create the run id
176174
jobs_ = list(_helpers.unpack(jobs))
177-
178175
run_id = datetime.datetime.now().replace(microsecond=0).isoformat()
179176
root = pathlib.Path(self.exp_path, run_id)
180-
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_)
177+
ids = self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_)
178+
return tuple(Record(id_, job) for id_, job in zip(ids, jobs_))
181179

182180
def _dispatch(
183181
self,
@@ -233,9 +231,7 @@ def execute_dispatch(generator: Generator, job: Job, idx: int) -> LaunchedJobID:
233231
execute_dispatch(generator, job, idx) for idx, job in enumerate(jobs, 1)
234232
)
235233

236-
def get_status(
237-
self, *ids: LaunchedJobID
238-
) -> tuple[JobStatus | InvalidJobStatus, ...]:
234+
def get_status(self, *records: Record) -> tuple[JobStatus | InvalidJobStatus, ...]:
239235
"""Get the status of jobs launched through the `Experiment` from their
240236
launched job id returned when calling `Experiment.start`.
241237
@@ -252,16 +248,17 @@ def get_status(
252248
launched job ids issued by user defined launcher are not sufficiently
253249
unique.
254250
255-
:param ids: A sequence of launched job ids issued by the experiment.
251+
:param records: A sequence of records issued by the experiment.
256252
:raises TypeError: If ids provided are not the correct type
257253
:raises ValueError: No IDs were provided.
258254
:returns: A tuple of statuses with order respective of the order of the
259255
calling arguments.
260256
"""
261-
if not ids:
262-
raise ValueError("No job ids provided to get status")
263-
if not all(isinstance(id, str) for id in ids):
264-
raise TypeError("ids argument was not of type LaunchedJobID")
257+
if not records:
258+
raise ValueError("No records provided to get status")
259+
if not all(isinstance(record, Record) for record in records):
260+
raise TypeError("record argument was not of type Record")
261+
ids = tuple(record.launched_id for record in records)
265262

266263
to_query = self._launch_history.group_by_launcher(
267264
set(ids), unknown_ok=True
@@ -272,39 +269,38 @@ def get_status(
272269
return tuple(stats)
273270

274271
def wait(
275-
self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True
272+
self, *records: Record, timeout: float | None = None, verbose: bool = True
276273
) -> None:
277274
"""Block execution until all of the provided launched jobs, represented
278275
by an ID, have entered a terminal status.
279276
280-
:param ids: The ids of the launched jobs to wait for.
277+
:param records: The records of the launched jobs to wait for.
281278
:param timeout: The max time to wait for all of the launched jobs to end.
282279
:param verbose: Whether found statuses should be displayed in the console.
283280
:raises TypeError: If IDs provided are not the correct type
284281
:raises ValueError: No IDs were provided.
285282
"""
286-
if ids:
287-
if not all(isinstance(id, str) for id in ids):
288-
raise TypeError("ids argument was not of type LaunchedJobID")
289-
else:
290-
raise ValueError("No job ids to wait on provided")
283+
if not records:
284+
raise ValueError("No records to wait on provided")
285+
if not all(isinstance(record, Record) for record in records):
286+
raise TypeError("record argument was not of type Record")
291287
self._poll_for_statuses(
292-
ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose
288+
records, TERMINAL_STATUSES, timeout=timeout, verbose=verbose
293289
)
294290

295291
def _poll_for_statuses(
296292
self,
297-
ids: t.Sequence[LaunchedJobID],
293+
records: t.Sequence[Record],
298294
statuses: t.Collection[JobStatus],
299295
timeout: float | None = None,
300296
interval: float = 5.0,
301297
verbose: bool = True,
302-
) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]:
298+
) -> dict[Record, JobStatus | InvalidJobStatus]:
303299
"""Poll the experiment's launchers for the statuses of the launched
304300
jobs with the provided ids, until the status of the changes to one of
305301
the provided statuses.
306302
307-
:param ids: The ids of the launched jobs to wait for.
303+
:param records: The records of the launched jobs to wait for.
308304
:param statuses: A collection of statuses to poll for.
309305
:param timeout: The minimum amount of time to spend polling all jobs to
310306
reach one of the supplied statuses. If not supplied or `None`, the
@@ -320,12 +316,10 @@ def _poll_for_statuses(
320316
log = logger.info if verbose else lambda *_, **__: None
321317
method_timeout = _interval.SynchronousTimeInterval(timeout)
322318
iter_timeout = _interval.SynchronousTimeInterval(interval)
323-
final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {}
319+
final: dict[Record, JobStatus | InvalidJobStatus] = {}
324320

325-
def is_finished(
326-
id_: LaunchedJobID, status: JobStatus | InvalidJobStatus
327-
) -> bool:
328-
job_title = f"Job({id_}): "
321+
def is_finished(record: Record, status: JobStatus | InvalidJobStatus) -> bool:
322+
job_title = f"Job({record.launched_id}, {record.job.name}): "
329323
if done := status in terminal:
330324
log(f"{job_title}Finished with status '{status.value}'")
331325
else:
@@ -334,22 +328,22 @@ def is_finished(
334328

335329
if iter_timeout.infinite:
336330
raise ValueError("Polling interval cannot be infinite")
337-
while ids and not method_timeout.expired:
331+
while records and not method_timeout.expired:
338332
iter_timeout = iter_timeout.new_interval()
339-
stats = zip(ids, self.get_status(*ids))
333+
stats = zip(records, self.get_status(*records))
340334
is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats)
341335
final |= dict(is_done.get(True, ()))
342-
ids = tuple(id_ for id_, _ in is_done.get(False, ()))
343-
if ids:
336+
records = tuple(rec for rec, _ in is_done.get(False, ()))
337+
if records:
344338
(
345339
iter_timeout
346340
if iter_timeout.remaining < method_timeout.remaining
347341
else method_timeout
348342
).block()
349-
if ids:
343+
if records:
350344
raise TimeoutError(
351-
f"Job ID(s) {', '.join(map(str, ids))} failed to reach "
352-
"terminal status before timeout"
345+
f"Job ID(s) {', '.join(rec.launched_id for rec in records)} "
346+
"failed to reach terminal status before timeout"
353347
)
354348
return final
355349

@@ -445,20 +439,20 @@ def summary(self, style: str = "github") -> str:
445439
disable_numparse=True,
446440
)
447441

448-
def stop(self, *ids: LaunchedJobID) -> tuple[JobStatus | InvalidJobStatus, ...]:
442+
def stop(self, *records: Record) -> tuple[JobStatus | InvalidJobStatus, ...]:
449443
"""Cancel the execution of a previously launched job.
450444
451-
:param ids: The ids of the launched jobs to stop.
445+
:param records: The records of the launched jobs to stop.
452446
:raises TypeError: If ids provided are not the correct type
453447
:raises ValueError: No job ids were provided.
454448
:returns: A tuple of job statuses upon cancellation with order
455449
respective of the order of the calling arguments.
456450
"""
457-
if ids:
458-
if not all(isinstance(id, str) for id in ids):
459-
raise TypeError("ids argument was not of type LaunchedJobID")
460-
else:
461-
raise ValueError("No job ids provided")
451+
if not records:
452+
raise ValueError("No records provided")
453+
if not all(isinstance(reocrd, Record) for reocrd in records):
454+
raise TypeError("record argument was not of type Record")
455+
ids = tuple(record.launched_id for record in records)
462456
by_launcher = self._launch_history.group_by_launcher(set(ids), unknown_ok=True)
463457
id_to_stop_stat = (
464458
launcher.stop_jobs(*launched).items()

smartsim/launchable/job.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
if t.TYPE_CHECKING:
4141
from smartsim.entity.entity import SmartSimEntity
42+
from smartsim.types import LaunchedJobID
4243

4344

4445
@t.final
@@ -158,3 +159,35 @@ def __str__(self) -> str: # pragma: no cover
158159
string = f"SmartSim Entity: {self.entity}\n"
159160
string += f"Launch Settings: {self.launch_settings}"
160161
return string
162+
163+
164+
@t.final
class Record:
    """A record of a job that was launched and the launch ID that was
    assigned to that launch.

    The record stores its own deep copy of the job and only ever hands out
    further copies, so neither the code that launched the job nor the code
    that reads the record can change what the record holds.
    """

    def __init__(self, launch_id: LaunchedJobID, job: Job) -> None:
        """Initialize a new record of a launched job

        :param launch_id: A unique identifier for the launch of the job.
        :param job: The job that was launched.
        """
        self._id = launch_id
        # Snapshot the job at construction time: later mutation of the
        # caller's object must not alter the stored job.
        self._job = deepcopy(job)

    def __repr__(self) -> str:
        """Return a debugging representation showing the launch ID and job.

        :returns: A string of the form ``Record(launched_id=..., job=...)``.
        """
        # Read the private snapshot directly; going through `self.job`
        # would pay for a deep copy just to format a string.
        return (
            f"{type(self).__name__}"
            f"(launched_id={self._id!r}, job={self._job!r})"
        )

    @property
    def launched_id(self) -> LaunchedJobID:
        """The unique identifier for the launched job.

        :returns: A unique identifier for the launched job.
        """
        return self._id

    @property
    def job(self) -> Job:
        """A deep copy of the job that was launched.

        A fresh copy is produced on every access, so mutating the returned
        object never affects the record.

        :returns: A deep copy of the launched job.
        """
        return deepcopy(self._job)

0 commit comments

Comments
 (0)