 import os
 import subprocess
 import sys
-from typing import Any, cast, Dict, FrozenSet, List, Optional, Sequence
+from typing import Any, Dict, FrozenSet, List, Optional, Sequence
 
 from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
 from monarch._rust_bindings.monarch_hyperactor.config import configure
 
 from monarch._src.actor.bootstrap import attach_to_workers
-from monarch._src.actor.host_mesh import HostMesh
 from monarch._src.job.job import JobState, JobTrait
 
 
@@ -55,6 +54,8 @@ def __init__(
         log_dir: Optional[str] = None,
         exclusive: bool = True,
         gpus_per_node: Optional[int] = None,
+        cpus_per_task: Optional[int] = None,
+        mem: Optional[str] = None,
     ) -> None:
         """
         Args:
@@ -84,6 +85,8 @@ def __init__(
         self._log_dir: str = log_dir if log_dir is not None else os.getcwd()
         self._exclusive = exclusive
         self._gpus_per_node = gpus_per_node
+        self._cpus_per_task = cpus_per_task
+        self._mem = mem
         # Track the single SLURM job ID and all allocated hostnames
         self._slurm_job_id: Optional[str] = None
         self._all_hostnames: List[str] = []
@@ -128,6 +131,12 @@ def _submit_slurm_job(self, num_nodes: int) -> str:
         if self._gpus_per_node is not None:
             sbatch_directives.append(f"#SBATCH --gpus-per-node={self._gpus_per_node}")
 
+        if self._cpus_per_task is not None:
+            sbatch_directives.append(f"#SBATCH --cpus-per-task={self._cpus_per_task}")
+
+        if self._mem is not None:
+            sbatch_directives.append(f"#SBATCH --mem={self._mem}")
+
         if self._exclusive:
             sbatch_directives.append("#SBATCH --exclusive")
 
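For reference, a standalone sketch of what the two new branches emit. The values here are hypothetical (they are not part of the patch); the directive strings mirror the f-strings added above:

    cpus_per_task, mem = 16, "128G"   # hypothetical values, not from the patch
    directives = []
    if cpus_per_task is not None:
        directives.append(f"#SBATCH --cpus-per-task={cpus_per_task}")
    if mem is not None:
        directives.append(f"#SBATCH --mem={mem}")
    print("\n".join(directives))
    # #SBATCH --cpus-per-task=16
    # #SBATCH --mem=128G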
@@ -297,6 +306,8 @@ def can_run(self, spec: "JobTrait") -> bool:
             and spec._time_limit == self._time_limit
             and spec._partition == self._partition
             and spec._gpus_per_node == self._gpus_per_node
+            and spec._cpus_per_task == self._cpus_per_task
+            and spec._mem == self._mem
             and self._jobs_active()
         )
 
@@ -318,6 +329,28 @@ def _jobs_active(self) -> bool:
 
         return True
 
+    def share_node(
+        self, tasks_per_node: int, gpus_per_task: int, partition: str
+    ) -> None:
+        """
+        Share a node with other jobs.
+        """
+        try:
+            import clusterscope
+        except ImportError:
+            raise RuntimeError(
+                "please install clusterscope to use share_node. `pip install clusterscope`"
+            )
+        self._exclusive = False
+
+        slurm_args = clusterscope.job_gen_task_slurm(
+            partition=partition,
+            gpus_per_task=gpus_per_task,
+            tasks_per_node=tasks_per_node,
+        )
+        self._cpus_per_task = slurm_args["cpus_per_task"]
+        self._mem = slurm_args["memory"]
+
     def _kill(self) -> None:
         """Cancel the SLURM job."""
         if self._slurm_job_id is not None:
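Taken together, a minimal sketch of what the new share_node() path computes before submission. The clusterscope call and its return keys ("cpus_per_task", "memory") are taken from the patch above; the partition name and task counts are hypothetical, and constructing the job object itself is out of scope here:

    # Sketch only: mirrors what share_node() computes before submission.
    # Requires `pip install clusterscope`; partition/task counts are hypothetical.
    import clusterscope

    slurm_args = clusterscope.job_gen_task_slurm(
        partition="gpu",      # hypothetical partition name
        gpus_per_task=1,
        tasks_per_node=8,
    )
    # These values land in _cpus_per_task / _mem and hence in the new directives:
    print(f"#SBATCH --cpus-per-task={slurm_args['cpus_per_task']}")
    print(f"#SBATCH --mem={slurm_args['memory']}")
    # share_node() also sets _exclusive = False, so "#SBATCH --exclusive" is omitted.

Calling job.share_node(tasks_per_node=8, gpus_per_task=1, partition="gpu") on an instance of the class modified above performs the same computation and stores the results on the instance, so the next submission picks them up.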