@@ -206,6 +206,7 @@ def run(self, args: argparse.Namespace) -> None:
206
206
class CmdRun (SubCommand ):
207
207
def __init__ (self ) -> None :
208
208
self ._subparser : Optional [argparse .ArgumentParser ] = None
209
+ self ._stdin_data_json : Optional [Dict [str , Any ]] = None
209
210
210
211
def add_arguments (self , subparser : argparse .ArgumentParser ) -> None :
211
212
scheduler_names = get_scheduler_factories ().keys ()
@@ -369,6 +370,15 @@ def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> No
369
370
torchx_run_args .scheduler_cfg = cfg
370
371
self ._run_inner (runner , torchx_run_args )
371
372
373
+ def _get_torchx_stdin_args (
374
+ self , args : argparse .Namespace
375
+ ) -> Optional [Dict [str , Any ]]:
376
+ if not args .stdin :
377
+ return None
378
+ if self ._stdin_data_json is None :
379
+ self ._stdin_data_json = self .torchx_json_from_stdin ()
380
+ return self ._stdin_data_json
381
+
372
382
def torchx_json_from_stdin (self ) -> Dict [str , Any ]:
373
383
try :
374
384
stdin_data_json = json .load (sys .stdin )
@@ -419,11 +429,11 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None:
419
429
)
420
430
421
431
def _run (self , runner : Runner , args : argparse .Namespace ) -> None :
422
- # Verify no conflicting arguments when using to loop over the stdin
423
432
self .verify_no_extra_args (args )
424
433
if args .stdin :
425
- stdin_data_json = self .torchx_json_from_stdin ()
426
- self ._run_from_stdin_args (runner , stdin_data_json )
434
+ stdin_data_json = self ._get_torchx_stdin_args (args )
435
+ if stdin_data_json is not None :
436
+ self ._run_from_stdin_args (runner , stdin_data_json )
427
437
else :
428
438
self ._run_from_cli_args (runner , args )
429
439
0 commit comments