Skip to content

Commit 39e8e1c

Browse files
committed
addressing review comments; added value checks for batch and launch arguments
1 parent 3e8395a commit 39e8e1c

File tree

7 files changed

+168
-34
lines changed

7 files changed

+168
-34
lines changed

smartsim/settings/arguments/batch/lsf.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ def set_smts(self, smts: int) -> None:
7272
takes precedence.
7373
7474
:param smts: SMT (e.g on Summit: 1, 2, or 4)
75+
:raises TypeError: if not an int
7576
"""
77+
if not isinstance(smts, int):
78+
raise TypeError("smts argument was not of type int")
7679
self.set("alloc_flags", str(smts))
7780

7881
def set_project(self, project: str) -> None:
@@ -81,7 +84,10 @@ def set_project(self, project: str) -> None:
8184
This sets ``-P``.
8285
8386
:param time: project name
87+
:raises TypeError: if not a str
8488
"""
89+
if not isinstance(project, str):
90+
raise TypeError("project argument was not of type str")
8591
self.set("P", project)
8692

8793
def set_account(self, account: str) -> None:
@@ -90,7 +96,10 @@ def set_account(self, account: str) -> None:
9096
this function is an alias for `set_project`.
9197
9298
:param account: project name
99+
:raises TypeError: if not a str
93100
"""
101+
if not isinstance(account, str):
102+
raise TypeError("account argument was not of type str")
94103
return self.set_project(account)
95104

96105
def set_nodes(self, num_nodes: int) -> None:
@@ -99,7 +108,10 @@ def set_nodes(self, num_nodes: int) -> None:
99108
This sets ``-nnodes``.
100109
101110
:param nodes: number of nodes
111+
:raises TypeError: if not an int
102112
"""
113+
if not isinstance(num_nodes, int):
114+
raise TypeError("num_nodes argument was not of type int")
103115
self.set("nnodes", str(num_nodes))
104116

105117
def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
@@ -110,10 +122,11 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
110122
"""
111123
if isinstance(host_list, str):
112124
host_list = [host_list.strip()]
113-
if not isinstance(host_list, list):
125+
if not (
126+
isinstance(host_list, list)
127+
and all(isinstance(item, str) for item in host_list)
128+
):
114129
raise TypeError("host_list argument must be a list of strings")
115-
if not all(isinstance(host, str) for host in host_list):
116-
raise TypeError("host_list argument must be list of strings")
117130
self.set("m", '"' + " ".join(host_list) + '"')
118131

119132
def set_tasks(self, tasks: int) -> None:
@@ -122,7 +135,10 @@ def set_tasks(self, tasks: int) -> None:
122135
This sets ``-n``
123136
124137
:param tasks: number of tasks
138+
:raises TypeError: if not an int
125139
"""
140+
if not isinstance(tasks, int):
141+
raise TypeError("tasks argument was not of type int")
126142
self.set("n", str(tasks))
127143

128144
def set_queue(self, queue: str) -> None:
@@ -131,7 +147,10 @@ def set_queue(self, queue: str) -> None:
131147
This sets ``-q``
132148
133149
:param queue: The queue to submit the job on
150+
:raises TypeError: if not a str
134151
"""
152+
if not isinstance(queue, str):
153+
raise TypeError("queue argument was not of type str")
135154
self.set("q", queue)
136155

137156
def format_batch_args(self) -> t.List[str]:

smartsim/settings/arguments/batch/pbs.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from __future__ import annotations
2828

29+
import re
2930
import typing as t
3031
from copy import deepcopy
3132

@@ -61,7 +62,10 @@ def set_nodes(self, num_nodes: int) -> None:
6162
nodes here is sets the 'nodes' resource.
6263
6364
:param num_nodes: number of nodes
65+
:raises TypeError: if not an int
6466
"""
67+
if not isinstance(num_nodes, int):
68+
raise TypeError("num_nodes argument was not of type int")
6569

6670
self.set("nodes", str(num_nodes))
6771

@@ -73,9 +77,10 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
7377
"""
7478
if isinstance(host_list, str):
7579
host_list = [host_list.strip()]
76-
if not isinstance(host_list, list):
77-
raise TypeError("host_list argument must be a list of strings")
78-
if not all(isinstance(host, str) for host in host_list):
80+
if not (
81+
isinstance(host_list, list)
82+
and all(isinstance(item, str) for item in host_list)
83+
):
7984
raise TypeError("host_list argument must be a list of strings")
8085
self.set("hostname", ",".join(host_list))
8186

@@ -89,14 +94,22 @@ def set_walltime(self, walltime: str) -> None:
8994
this value will be overridden
9095
9196
:param walltime: wall time
97+
:raises ValueError: if walltime format is invalid
9298
"""
93-
self.set("walltime", walltime)
99+
pattern = r"^\d{2}:\d{2}:\d{2}$"
100+
if walltime and re.match(pattern, walltime):
101+
self.set("walltime", walltime)
102+
else:
103+
raise ValueError("Invalid walltime format. Please use 'HH:MM:SS' format.")
94104

95105
def set_queue(self, queue: str) -> None:
96106
"""Set the queue for the batch job
97107
98108
:param queue: queue name
109+
:raises TypeError: if not a str
99110
"""
111+
if not isinstance(queue, str):
112+
raise TypeError("queue argument was not of type str")
100113
self.set("q", str(queue))
101114

102115
def set_ncpus(self, num_cpus: int) -> None:
@@ -107,14 +120,20 @@ def set_ncpus(self, num_cpus: int) -> None:
107120
this value will be overridden
108121
109122
:param num_cpus: number of cpus per node in select
123+
:raises TypeError: if not an int
110124
"""
125+
if not isinstance(num_cpus, int):
126+
raise TypeError("num_cpus argument was not of type int")
111127
self.set("ppn", str(num_cpus))
112128

113129
def set_account(self, account: str) -> None:
114130
"""Set the account for this batch job
115131
116132
:param acct: account id
133+
:raises TypeError: if not a str
117134
"""
135+
if not isinstance(account, str):
136+
raise TypeError("account argument was not of type str")
118137
self.set("A", str(account))
119138

120139
def format_batch_args(self) -> t.List[str]:

smartsim/settings/arguments/batch/slurm.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def set_walltime(self, walltime: str) -> None:
5656
format = "HH:MM:SS"
5757
5858
:param walltime: wall time
59+
:raises ValueError: if walltime format is invalid
5960
"""
6061
pattern = r"^\d{2}:\d{2}:\d{2}$"
6162
if walltime and re.match(pattern, walltime):
@@ -69,7 +70,10 @@ def set_nodes(self, num_nodes: int) -> None:
6970
This sets ``--nodes``.
7071
7172
:param num_nodes: number of nodes
73+
:raises TypeError: if not an int
7274
"""
75+
if not isinstance(num_nodes, int):
76+
raise TypeError("num_nodes argument was not of type int")
7377
self.set("nodes", str(num_nodes))
7478

7579
def set_account(self, account: str) -> None:
@@ -78,7 +82,10 @@ def set_account(self, account: str) -> None:
7882
This sets ``--account``.
7983
8084
:param account: account id
85+
:raises TypeError: if not a str
8186
"""
87+
if not isinstance(account, str):
88+
raise TypeError("account argument was not of type str")
8289
self.set("account", account)
8390

8491
def set_partition(self, partition: str) -> None:
@@ -96,7 +103,10 @@ def set_queue(self, queue: str) -> None:
96103
Sets the partition for the slurm batch job
97104
98105
:param queue: the partition to run the batch job on
106+
:raises TypeError: if not a str
99107
"""
108+
if not isinstance(queue, str):
109+
raise TypeError("queue argument was not of type str")
100110
return self.set_partition(queue)
101111

102112
def set_cpus_per_task(self, cpus_per_task: int) -> None:
@@ -118,10 +128,13 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
118128
"""
119129
if isinstance(host_list, str):
120130
host_list = [host_list.strip()]
121-
if not isinstance(host_list, list):
131+
132+
if not (
133+
isinstance(host_list, list)
134+
and all(isinstance(item, str) for item in host_list)
135+
):
122136
raise TypeError("host_list argument must be a list of strings")
123-
if not all(isinstance(host, str) for host in host_list):
124-
raise TypeError("host_list argument must be list of strings")
137+
125138
self.set("nodelist", ",".join(host_list))
126139

127140
def format_batch_args(self) -> t.List[str]:

smartsim/settings/batch_settings.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def __init__(
118118
except ValueError:
119119
raise ValueError(f"Invalid scheduler type: {batch_scheduler}") from None
120120

121-
if batch_args:
121+
if batch_args is not None:
122122
if not (
123-
isinstance(batch_args, t.Mapping)
123+
isinstance(batch_args, dict)
124124
and all(isinstance(key, str) for key, val in batch_args.items())
125125
):
126126
raise TypeError(
@@ -152,10 +152,7 @@ def env_vars(self, value: t.Dict[str, str | None]) -> None:
152152

153153
if not (
154154
isinstance(value, t.Mapping)
155-
and all(
156-
isinstance(key, str) and isinstance(val, str)
157-
for key, val in value.items()
158-
)
155+
and all(isinstance(key, str) for key, val in value.items())
159156
):
160157
raise TypeError("env_vars argument was not of type dic of str and str")
161158

smartsim/settings/launch_settings.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
except ValueError:
128128
raise ValueError(f"Invalid launcher type: {launcher}")
129129

130-
if launch_args:
130+
if launch_args is not None:
131131
if not (
132132
isinstance(launch_args, t.Mapping)
133133
and all(isinstance(key, str) for key, val in launch_args.items())
@@ -175,11 +175,8 @@ def env_vars(self, value: dict[str, str | None]) -> None:
175175
:param value: The new environment mapping
176176
"""
177177
if not (
178-
isinstance(value, t.Mapping)
179-
and all(
180-
isinstance(key, str) and isinstance(val, str)
181-
for key, val in value.items()
182-
)
178+
isinstance(value, dict)
179+
and all(isinstance(key, str) for key, val in value.items())
183180
):
184181
raise TypeError("env_vars argument was not of type dic of str and str")
185182

@@ -227,14 +224,6 @@ def update_env(self, env_vars: t.Dict[str, str | None]) -> None:
227224
:param env_vars: environment variables to update or add
228225
:raises TypeError: if env_vars values cannot be coerced to strings
229226
"""
230-
if not (
231-
isinstance(env_vars, t.Mapping)
232-
and all(
233-
isinstance(key, str) and isinstance(val, str)
234-
for key, val in env_vars.items()
235-
)
236-
):
237-
raise TypeError("env_vars argument was not of type dic of str and str")
238227

239228
# Coerce env_vars values to str as a convenience to user
240229
for env, val in env_vars.items():

tests/temp_tests/test_settings/test_batchSettings.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,16 @@ def test_type_batch_scheduler():
9090
)
9191

9292

93-
def test_type_batch_args():
94-
batch_args = "invalid"
93+
@pytest.mark.parametrize(
94+
"batch_args",
95+
[
96+
pytest.param("invalid", id="invalid"),
97+
pytest.param("", id="empty string"),
98+
pytest.param(0, id="0"),
99+
pytest.param([], id="empty list"),
100+
],
101+
)
102+
def test_type_batch_args(batch_args):
95103
with pytest.raises(
96104
TypeError, match="batch_args argument was not of type mapping of str and str"
97105
):
@@ -108,3 +116,84 @@ def test_type_env_vars():
108116
TypeError, match="env_vars argument was not of type dic of str and str"
109117
):
110118
BatchSettings(batch_scheduler="slurm", env_vars=env_vars)
119+
120+
121+
@pytest.mark.parametrize(
122+
"scheduler",
123+
[
124+
pytest.param("slurm", id="slurm scheduler"),
125+
pytest.param("lsf", id="bsub scheduler"),
126+
pytest.param("pbs", id="qsub scheduler"),
127+
],
128+
)
129+
def test_batch_arguments_type_set_nodes(scheduler):
130+
bs = BatchSettings(batch_scheduler=scheduler, env_vars={"ENV": "VAR"})
131+
with pytest.raises(TypeError, match="num_nodes argument was not of type int"):
132+
bs.batch_args.set_nodes("invalid")
133+
134+
135+
@pytest.mark.parametrize(
136+
"scheduler",
137+
[
138+
pytest.param("slurm", id="slurm scheduler"),
139+
pytest.param("lsf", id="bsub scheduler"),
140+
pytest.param("pbs", id="qsub scheduler"),
141+
],
142+
)
143+
def test_batch_arguments_type_set_account(scheduler):
144+
bs = BatchSettings(batch_scheduler=scheduler, env_vars={"ENV": "VAR"})
145+
146+
with pytest.raises(TypeError, match="account argument was not of type str"):
147+
bs.batch_args.set_account(27)
148+
149+
150+
@pytest.mark.parametrize(
151+
"scheduler",
152+
[
153+
pytest.param("slurm", id="slurm scheduler"),
154+
pytest.param("lsf", id="bsub scheduler"),
155+
pytest.param("pbs", id="qsub scheduler"),
156+
],
157+
)
158+
def test_batch_arguments_type_set_queue(scheduler):
159+
bs = BatchSettings(batch_scheduler=scheduler, env_vars={"ENV": "VAR"})
160+
with pytest.raises(TypeError, match="queue argument was not of type str"):
161+
bs.batch_args.set_queue(27)
162+
163+
164+
@pytest.mark.parametrize(
165+
"scheduler",
166+
[
167+
pytest.param("slurm", id="slurm scheduler"),
168+
pytest.param("lsf", id="bsub scheduler"),
169+
pytest.param("pbs", id="qsub scheduler"),
170+
],
171+
)
172+
def test_batch_arguments_type_set_hostlist(scheduler):
173+
bs = BatchSettings(batch_scheduler=scheduler, env_vars={"ENV": "VAR"})
174+
with pytest.raises(TypeError, match="host_list argument must be a list of strings"):
175+
bs.batch_args.set_hostlist([25, 37])
176+
177+
178+
def test_batch_arguments_type_set_ncpus():
179+
bs = BatchSettings(batch_scheduler="pbs", env_vars={"ENV": "VAR"})
180+
with pytest.raises(TypeError, match="num_cpus argument was not of type int"):
181+
bs.batch_args.set_ncpus("invalid")
182+
183+
184+
def test_batch_arguments_type_set_smts():
185+
bs = BatchSettings(batch_scheduler="lsf", env_vars={"ENV": "VAR"})
186+
with pytest.raises(TypeError, match="smts argument was not of type int"):
187+
bs.batch_args.set_smts("invalid")
188+
189+
190+
def test_batch_arguments_type_set_project():
191+
bs = BatchSettings(batch_scheduler="lsf", env_vars={"ENV": "VAR"})
192+
with pytest.raises(TypeError, match="project argument was not of type str"):
193+
bs.batch_args.set_project(27)
194+
195+
196+
def test_batch_arguments_type_set_tasks():
197+
bs = BatchSettings(batch_scheduler="lsf", env_vars={"ENV": "VAR"})
198+
with pytest.raises(TypeError, match="tasks argument was not of type int"):
199+
bs.batch_args.set_tasks("invalid")

0 commit comments

Comments
 (0)