Skip to content

Commit 0368421

Browse files
authored
Fix test_gpu_distributed_initialize() to get a more proper GPUs number (#503)
1 parent 58a5c56 commit 0368421

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/multiprocess_gpu_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_gpu_distributed_initialize(self):
5050

5151
port = portpicker.pick_unused_port()
5252
# ROCm will fail if HIP_VISIBLE_DEVICES is set to a single non existent device
53-
num_gpus = 4 if not jtu.is_device_rocm() else min(4, len(jax.devices()))
53+
num_gpus = 4 if not jtu.is_device_rocm() else min(4, jax.local_device_count())
5454
num_gpus_per_task = 1
5555
num_tasks = num_gpus // num_gpus_per_task
5656

0 commit comments

Comments
 (0)