diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 9db046b749..a9ccb35b79 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -67,6 +67,8 @@ def setup_distributed(): device_mesh = init_device_mesh("cuda", (world_size,)) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index fa3d30410b..f04b791273 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -46,6 +46,8 @@ def setup_distributed(): ) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh