4242from smartsim ._core .control .launch_history import LaunchHistory as _LaunchHistory
4343from smartsim ._core .utils import helpers as _helpers
4444from smartsim .error import errors
45- from smartsim .launchable .job import Job
45+ from smartsim .launchable .job import Job , Record
4646from smartsim .status import TERMINAL_STATUSES , InvalidJobStatus , JobStatus
4747
4848from ._core import Generator , Manifest
5252from .log import ctx_exp_path , get_logger , method_contextualizer
5353
5454if t .TYPE_CHECKING :
55- from smartsim .launchable .job import Job
5655 from smartsim .types import LaunchedJobID
5756
5857logger = get_logger (__name__ )
@@ -158,7 +157,7 @@ def __init__(self, name: str, exp_path: str | None = None):
158157 experiment
159158 """
160159
161- def start (self , * jobs : Job | t .Sequence [Job ]) -> tuple [LaunchedJobID , ...]:
160+ def start (self , * jobs : Job | t .Sequence [Job ]) -> tuple [Record , ...]:
162161 """Execute a collection of `Job` instances.
163162
164163 :param jobs: A collection of other job instances to start
@@ -172,12 +171,11 @@ def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
172171 if not jobs :
173172 raise ValueError ("No jobs provided to start" )
174173
175- # Create the run id
176174 jobs_ = list (_helpers .unpack (jobs ))
177-
178175 run_id = datetime .datetime .now ().replace (microsecond = 0 ).isoformat ()
179176 root = pathlib .Path (self .exp_path , run_id )
180- return self ._dispatch (Generator (root ), dispatch .DEFAULT_DISPATCHER , * jobs_ )
177+ ids = self ._dispatch (Generator (root ), dispatch .DEFAULT_DISPATCHER , * jobs_ )
178+ return tuple (Record (id_ , job ) for id_ , job in zip (ids , jobs_ ))
181179
182180 def _dispatch (
183181 self ,
@@ -233,9 +231,7 @@ def execute_dispatch(generator: Generator, job: Job, idx: int) -> LaunchedJobID:
233231 execute_dispatch (generator , job , idx ) for idx , job in enumerate (jobs , 1 )
234232 )
235233
236- def get_status (
237- self , * ids : LaunchedJobID
238- ) -> tuple [JobStatus | InvalidJobStatus , ...]:
234+ def get_status (self , * records : Record ) -> tuple [JobStatus | InvalidJobStatus , ...]:
239235 """Get the status of jobs launched through the `Experiment` from their
240236 launched job id returned when calling `Experiment.start`.
241237
@@ -252,16 +248,17 @@ def get_status(
252248 launched job ids issued by user defined launcher are not sufficiently
253249 unique.
254250
255- :param ids : A sequence of launched job ids issued by the experiment.
251+ :param records : A sequence of records issued by the experiment.
256252 :raises TypeError: If ids provided are not the correct type
257253 :raises ValueError: No IDs were provided.
258254 :returns: A tuple of statuses with order respective of the order of the
259255 calling arguments.
260256 """
261- if not ids :
262- raise ValueError ("No job ids provided to get status" )
263- if not all (isinstance (id , str ) for id in ids ):
264- raise TypeError ("ids argument was not of type LaunchedJobID" )
257+ if not records :
258+ raise ValueError ("No records provided to get status" )
259+ if not all (isinstance (record , Record ) for record in records ):
260+ raise TypeError ("record argument was not of type Record" )
261+ ids = tuple (record .launched_id for record in records )
265262
266263 to_query = self ._launch_history .group_by_launcher (
267264 set (ids ), unknown_ok = True
@@ -272,39 +269,38 @@ def get_status(
272269 return tuple (stats )
273270
274271 def wait (
275- self , * ids : LaunchedJobID , timeout : float | None = None , verbose : bool = True
272+ self , * records : Record , timeout : float | None = None , verbose : bool = True
276273 ) -> None :
277274 """Block execution until all of the provided launched jobs, represented
278275 by an ID, have entered a terminal status.
279276
280- :param ids : The ids of the launched jobs to wait for.
277+ :param records : The records of the launched jobs to wait for.
281278 :param timeout: The max time to wait for all of the launched jobs to end.
282279 :param verbose: Whether found statuses should be displayed in the console.
283280 :raises TypeError: If IDs provided are not the correct type
284281 :raises ValueError: No IDs were provided.
285282 """
286- if ids :
287- if not all (isinstance (id , str ) for id in ids ):
288- raise TypeError ("ids argument was not of type LaunchedJobID" )
289- else :
290- raise ValueError ("No job ids to wait on provided" )
283+ if not records :
284+ raise ValueError ("No records to wait on provided" )
285+ if not all (isinstance (record , Record ) for record in records ):
286+ raise TypeError ("record argument was not of type Record" )
291287 self ._poll_for_statuses (
292- ids , TERMINAL_STATUSES , timeout = timeout , verbose = verbose
288+ records , TERMINAL_STATUSES , timeout = timeout , verbose = verbose
293289 )
294290
295291 def _poll_for_statuses (
296292 self ,
297- ids : t .Sequence [LaunchedJobID ],
293+ records : t .Sequence [Record ],
298294 statuses : t .Collection [JobStatus ],
299295 timeout : float | None = None ,
300296 interval : float = 5.0 ,
301297 verbose : bool = True ,
302- ) -> dict [LaunchedJobID , JobStatus | InvalidJobStatus ]:
298+ ) -> dict [Record , JobStatus | InvalidJobStatus ]:
303299 """Poll the experiment's launchers for the statuses of the launched
304300 jobs with the provided ids, until the status of the changes to one of
305301 the provided statuses.
306302
307- :param ids : The ids of the launched jobs to wait for.
303+ :param records : The records of the launched jobs to wait for.
308304 :param statuses: A collection of statuses to poll for.
309305 :param timeout: The minimum amount of time to spend polling all jobs to
310306 reach one of the supplied statuses. If not supplied or `None`, the
@@ -320,12 +316,10 @@ def _poll_for_statuses(
320316 log = logger .info if verbose else lambda * _ , ** __ : None
321317 method_timeout = _interval .SynchronousTimeInterval (timeout )
322318 iter_timeout = _interval .SynchronousTimeInterval (interval )
323- final : dict [LaunchedJobID , JobStatus | InvalidJobStatus ] = {}
319+ final : dict [Record , JobStatus | InvalidJobStatus ] = {}
324320
325- def is_finished (
326- id_ : LaunchedJobID , status : JobStatus | InvalidJobStatus
327- ) -> bool :
328- job_title = f"Job({ id_ } ): "
321+ def is_finished (record : Record , status : JobStatus | InvalidJobStatus ) -> bool :
322+ job_title = f"Job({ record .launched_id } , { record .job .name } ): "
329323 if done := status in terminal :
330324 log (f"{ job_title } Finished with status '{ status .value } '" )
331325 else :
@@ -334,22 +328,22 @@ def is_finished(
334328
335329 if iter_timeout .infinite :
336330 raise ValueError ("Polling interval cannot be infinite" )
337- while ids and not method_timeout .expired :
331+ while records and not method_timeout .expired :
338332 iter_timeout = iter_timeout .new_interval ()
339- stats = zip (ids , self .get_status (* ids ))
333+ stats = zip (records , self .get_status (* records ))
340334 is_done = _helpers .group_by (_helpers .pack_params (is_finished ), stats )
341335 final |= dict (is_done .get (True , ()))
342- ids = tuple (id_ for id_ , _ in is_done .get (False , ()))
343- if ids :
336+ records = tuple (rec for rec , _ in is_done .get (False , ()))
337+ if records :
344338 (
345339 iter_timeout
346340 if iter_timeout .remaining < method_timeout .remaining
347341 else method_timeout
348342 ).block ()
349- if ids :
343+ if records :
350344 raise TimeoutError (
351- f"Job ID(s) { ', ' .join (map ( str , ids )) } failed to reach "
352- "terminal status before timeout"
345+ f"Job ID(s) { ', ' .join (rec . launched_id for rec in records ) } "
346+ "failed to reach terminal status before timeout"
353347 )
354348 return final
355349
@@ -445,20 +439,20 @@ def summary(self, style: str = "github") -> str:
445439 disable_numparse = True ,
446440 )
447441
448- def stop (self , * ids : LaunchedJobID ) -> tuple [JobStatus | InvalidJobStatus , ...]:
442+ def stop (self , * records : Record ) -> tuple [JobStatus | InvalidJobStatus , ...]:
449443 """Cancel the execution of a previously launched job.
450444
451- :param ids : The ids of the launched jobs to stop.
445+ :param records : The records of the launched jobs to stop.
452446 :raises TypeError: If ids provided are not the correct type
453447 :raises ValueError: No job ids were provided.
454448 :returns: A tuple of job statuses upon cancellation with order
455449 respective of the order of the calling arguments.
456450 """
457- if ids :
458- if not all ( isinstance ( id , str ) for id in ids ):
459- raise TypeError ( "ids argument was not of type LaunchedJobID" )
460- else :
461- raise ValueError ( "No job ids provided" )
451+ if not records :
452+ raise ValueError ( "No records provided" )
453+ if not all ( isinstance ( reocrd , Record ) for reocrd in records ):
454+ raise TypeError ( "record argument was not of type Record" )
455+ ids = tuple ( record . launched_id for record in records )
462456 by_launcher = self ._launch_history .group_by_launcher (set (ids ), unknown_ok = True )
463457 id_to_stop_stat = (
464458 launcher .stop_jobs (* launched ).items ()
0 commit comments