Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions python/monarch/_src/job/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
import os
import subprocess
import sys
from typing import Any, cast, Dict, FrozenSet, List, Optional, Sequence
from typing import Any, Dict, FrozenSet, List, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers
from monarch._src.actor.host_mesh import HostMesh
from monarch._src.job.job import JobState, JobTrait


Expand Down Expand Up @@ -55,6 +54,8 @@ def __init__(
log_dir: Optional[str] = None,
exclusive: bool = True,
gpus_per_node: Optional[int] = None,
cpus_per_task: Optional[int] = None,
mem: Optional[str] = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -84,6 +85,8 @@ def __init__(
self._log_dir: str = log_dir if log_dir is not None else os.getcwd()
self._exclusive = exclusive
self._gpus_per_node = gpus_per_node
self._cpus_per_task = cpus_per_task
self._mem = mem
# Track the single SLURM job ID and all allocated hostnames
self._slurm_job_id: Optional[str] = None
self._all_hostnames: List[str] = []
Expand Down Expand Up @@ -128,12 +131,33 @@ def _submit_slurm_job(self, num_nodes: int) -> str:
if self._gpus_per_node is not None:
sbatch_directives.append(f"#SBATCH --gpus-per-node={self._gpus_per_node}")

if self._cpus_per_task is not None:
sbatch_directives.append(f"#SBATCH --cpus-per-task={self._cpus_per_task}")

if self._mem is not None:
sbatch_directives.append(f"#SBATCH --mem={self._mem}")

if self._exclusive:
sbatch_directives.append("#SBATCH --exclusive")

if self._partition:
if self._partition is not None:
sbatch_directives.append(f"#SBATCH --partition={self._partition}")

if (
not self._exclusive
and self._partition is not None
and self._gpus_per_node is not None
):
gpus_per_task = self._gpus_per_node // self._ntasks_per_node
assert (
self._partition
), "Slurm partition must be set for jobs that share nodes with other jobs"
self.share_node(
tasks_per_node=self._ntasks_per_node,
gpus_per_task=gpus_per_task,
partition=self._partition,
)

# Add any additional slurm args as directives
for arg in self._slurm_args:
if arg.startswith("-"):
Expand Down Expand Up @@ -297,6 +321,8 @@ def can_run(self, spec: "JobTrait") -> bool:
and spec._time_limit == self._time_limit
and spec._partition == self._partition
and spec._gpus_per_node == self._gpus_per_node
and spec._cpus_per_task == self._cpus_per_task
and spec._mem == self._mem
and self._jobs_active()
)

Expand All @@ -318,6 +344,28 @@ def _jobs_active(self) -> bool:

return True

def share_node(
self, tasks_per_node: int, gpus_per_task: int, partition: str
) -> None:
"""
Share a node with other jobs.
"""
try:
import clusterscope
except ImportError:
raise RuntimeError(
"please install clusterscope to use share_node. `pip install clusterscope`"
)
self._exclusive = False

slurm_args = clusterscope.job_gen_task_slurm(
partition=partition,
gpus_per_task=gpus_per_task,
tasks_per_node=tasks_per_node,
)
self._cpus_per_task = slurm_args["cpus_per_task"]
self._mem = slurm_args["memory"]

def _kill(self) -> None:
"""Cancel the SLURM job."""
if self._slurm_job_id is not None:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ torchx-nightly
lark
tabulate
opentelemetry-api
clusterscope