Skip to content

Commit ac4c5fd

Browse files
Copilotmanujosephv
andcommitted
Add tests and update documentation for devices field
Co-authored-by: manujosephv <[email protected]>
1 parent 96cd400 commit ac4c5fd

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ class TrainerConfig:
287287
'cpu','gpu','tpu','ipu', 'mps', 'auto'. Defaults to 'auto'.
288288
Choices are: [`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`].
289289
290-
devices (Optional[int]): Number of devices to train on (int). -1 uses all available devices. By
291-
default, uses all available devices (-1)
290+
devices (Union[int, List[int]]): Number of devices to train on (int), or list of device indices.
291+
-1 uses all available devices. By default, uses all available devices (-1)
292292
293293
devices_list (Optional[List[int]]): List of devices to train on (list). If specified, takes
294294
precedence over `devices` argument. Defaults to None

tests/test_config.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/usr/bin/env python
2+
"""Tests for config classes."""
3+
4+
import pytest
5+
from omegaconf import OmegaConf
6+
7+
from pytorch_tabular.config import TrainerConfig
8+
9+
10+
class TestTrainerConfig:
11+
"""Tests for TrainerConfig class."""
12+
13+
def test_devices_list_to_devices_conversion(self):
14+
"""Test that devices_list is properly converted to devices."""
15+
# Test with a list of devices
16+
trainer_config = TrainerConfig(devices_list=[0, 1])
17+
assert trainer_config.devices == [0, 1]
18+
19+
# Wrap with OmegaConf as done in TabularModel
20+
config = OmegaConf.structured(trainer_config)
21+
assert config.devices == [0, 1]
22+
23+
def test_devices_list_multiple_gpus(self):
24+
"""Test devices_list with multiple GPU IDs as documented."""
25+
trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4])
26+
assert trainer_config.devices == [1, 2, 3, 4]
27+
28+
config = OmegaConf.structured(trainer_config)
29+
assert config.devices == [1, 2, 3, 4]
30+
31+
def test_devices_int_value(self):
32+
"""Test that devices accepts integer values."""
33+
trainer_config = TrainerConfig(devices=2)
34+
assert trainer_config.devices == 2
35+
36+
config = OmegaConf.structured(trainer_config)
37+
assert config.devices == 2
38+
39+
def test_devices_default_value(self):
40+
"""Test that devices has default value of -1."""
41+
trainer_config = TrainerConfig()
42+
assert trainer_config.devices == -1
43+
44+
config = OmegaConf.structured(trainer_config)
45+
assert config.devices == -1
46+
47+
def test_devices_list_single_device(self):
48+
"""Test devices_list with a single device."""
49+
trainer_config = TrainerConfig(devices_list=[0])
50+
assert trainer_config.devices == [0]
51+
52+
config = OmegaConf.structured(trainer_config)
53+
assert config.devices == [0]
54+
55+
def test_devices_list_precedence(self):
56+
"""Test that devices_list takes precedence over devices."""
57+
# When both are provided, devices_list should take precedence
58+
trainer_config = TrainerConfig(devices=2, devices_list=[0, 1])
59+
assert trainer_config.devices == [0, 1]
60+
61+
config = OmegaConf.structured(trainer_config)
62+
assert config.devices == [0, 1]
63+
64+
def test_omegaconf_merge_compatibility(self):
65+
"""Test that config works correctly with OmegaConf.merge."""
66+
trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10)
67+
config = OmegaConf.structured(trainer_config)
68+
69+
# Simulate merging as done in TabularModel
70+
merged = OmegaConf.merge(
71+
OmegaConf.to_container(config),
72+
{"accelerator": "gpu"}
73+
)
74+
75+
assert merged.devices == [0, 1]
76+
assert merged.max_epochs == 10
77+
assert merged.accelerator == "gpu"

0 commit comments

Comments
 (0)