Commit 3c75ce0

Fix group as list of ints in torch dist collective ops (#3340)
1 parent 6401a59 commit 3c75ce0

File tree

6 files changed: +178 -157 lines changed
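For context, the behavior this commit fixes is passing `group` as a plain list of ranks to the ignite collective helpers on the native torch.distributed backend. A minimal, hypothetical sketch of the call pattern (the gloo backend and a 2-process spawn are assumptions for illustration, not part of the commit):

```python
import torch

import ignite.distributed as idist


def _worker(local_rank):
    rank = idist.get_rank()
    t = torch.tensor([rank], device=idist.device())
    # `group` given as a list of ranks; the comp model now converts it to a ProcessGroup internally
    res = idist.all_reduce(t, group=[0, 1])
    if rank in (0, 1):
        assert res.item() == 0 + 1


if __name__ == "__main__":
    # illustrative launcher: 2 CPU processes with the gloo backend
    idist.spawn("gloo", _worker, args=(), nproc_per_node=2)
```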

ignite/distributed/comp_models/native.py

+17 -11
@@ -408,6 +408,15 @@ def spawn(
             **spawn_kwargs,
         )

+    def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
+        if isinstance(group, list) and all(isinstance(item, int) for item in group):
+            group = self._do_new_group(group)
+        if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
+            raise ValueError(
+                f"Argument group should be list of int or ProcessGroup, got {type(group)}, group={group}"
+            )
+        return group
+
     _reduce_op_map = {
         "SUM": dist.ReduceOp.SUM,
         "PRODUCT": dist.ReduceOp.PRODUCT,
@@ -420,8 +429,8 @@ def spawn(
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
         if op not in self._reduce_op_map:
             raise ValueError(f"Unsupported reduction operation: '{op}'")
-        if group is not None and not isinstance(group, dist.ProcessGroup):
-            raise ValueError("Argument group should be list of int or ProcessGroup")
+        if group is not None:
+            group = self._setup_group(group)
         reduce_op = self._reduce_op_map[op]
         # We do if/else here for compatibility with older pytorch versions
         if group is not None:
@@ -431,15 +440,14 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         return tensor

     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+        if group is not None:
+            group = self._setup_group(group)
         if group == dist.GroupMember.NON_GROUP_MEMBER:
             return tensor
-
         if group is None:
             group_size = self.get_world_size()
-        elif isinstance(group, dist.ProcessGroup):
-            group_size = group.size()
         else:
-            raise ValueError("Argument group should be list of int or ProcessGroup")
+            group_size = group.size()
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(group_size)]
@@ -456,16 +464,14 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
                 "Current torch version does not implement dist.all_gather_object. "
                 "Required version should be >=1.7.0"
             )
-
+        if group is not None:
+            group = self._setup_group(group)
         if group == dist.GroupMember.NON_GROUP_MEMBER:
             return tensor
-
         if group is None:
             group_size = self.get_world_size()
-        elif isinstance(group, dist.ProcessGroup):
-            group_size = group.size()
         else:
-            raise ValueError("Argument group should be list of int or ProcessGroup")
+            group_size = group.size()
         output = [None for _ in range(group_size)]
         # We do if/else here for compatibility with older pytorch versions
         if group is not None:
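After `_setup_group`, `_do_all_gather` and `_do_all_gather_object` only need to distinguish `None` (whole world) from a concrete group whose `size()` determines how many buffers to preallocate. A hedged sketch of that preallocate-and-gather pattern written directly against `torch.distributed` (illustrative only; assumes an initialized backend and that the caller is a member of the group):

```python
from typing import List, Optional

import torch
import torch.distributed as dist


def gather_tensor(tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # None means "gather across the whole world"; otherwise size the output from the group.
    group_size = dist.get_world_size() if group is None else group.size()
    if tensor.ndimension() == 0:
        tensor = tensor.unsqueeze(0)
    output: List[torch.Tensor] = [torch.zeros_like(tensor) for _ in range(group_size)]
    if group is None:
        dist.all_gather(output, tensor)
    else:
        dist.all_gather(output, tensor, group=group)
    return torch.cat(output, dim=0)
```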

ignite/distributed/utils.py

-6
@@ -347,9 +347,6 @@ def all_reduce(
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)

-    if isinstance(group, list) and all(isinstance(item, int) for item in group):
-        group = _model.new_group(group)
-
     return _model.all_reduce(tensor, op, group=group)


@@ -429,9 +426,6 @@ def all_gather(
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)

-    if isinstance(group, list) and all(isinstance(item, int) for item in group):
-        group = _model.new_group(group)
-
     return _model.all_gather(tensor, group=group)


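With the list-to-group conversion moved into the comp models, the public helpers accept either form directly, and callers that already build a group via `idist.new_group` are unaffected. A short hedged usage sketch (ranks and tensors are illustrative; assumes an initialized multi-process run):

```python
import torch

import ignite.distributed as idist

ranks = [0, 1]  # illustrative subgroup
t = torch.tensor([idist.get_rank()])

# Either pass the rank list on each call...
res = idist.all_reduce(t, group=ranks)

# ...or create the group once and reuse it across several collectives.
group = idist.new_group(ranks)
res = idist.all_reduce(t, group=group)
gathered = idist.all_gather(t, group=group)
```

As the updated tests note for the native backend, the collectives should be called on the participating ranks only; otherwise torch emits a warning that the calling rank does not belong to the given group.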
tests/ignite/distributed/utils/__init__.py

+122 -118
@@ -119,39 +119,44 @@ def _test_distrib_all_reduce(device):


 def _test_distrib_all_reduce_group(device):
-    if idist.get_world_size() > 1 and idist.backend() is not None:
-        ranks = [0, 1]
-        rank = idist.get_rank()
-        t = torch.tensor([rank], device=device)
-        bnd = idist.backend()
+    assert idist.get_world_size() > 1, idist.get_world_size()
+    assert idist.backend() is not None, idist.backend()

-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=group)
-        else:
+    ranks = [0, 1]
+    rank = idist.get_rank()
+    t = torch.tensor([rank], device=device)
+    bnd = idist.backend()
+
+    group = idist.new_group(ranks)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            res = idist.all_reduce(t, group=group)
+    else:
+        if rank in ranks:
+            # we should call all_reduce with group on the participating ranks only
+            # otherwise a warning is raised:
+            # UserWarning: Running all_reduce on global rank 2 which does not belong to the given group.
             res = idist.all_reduce(t, group=group)
             assert res == torch.tensor([sum(ranks)], device=device)

-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=ranks)
-        else:
+    t = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            res = idist.all_reduce(t, group=ranks)
+    else:
+        if rank in ranks:
             res = idist.all_reduce(t, group=ranks)
             assert res == torch.tensor([sum(ranks)], device=device)

-        ranks = "abc"
-
-        if bnd in ("nccl", "gloo", "mpi"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
-                res = idist.all_reduce(t, group="abc")
-        elif bnd in ("xla-tpu"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
-                res = idist.all_reduce(t, group="abc")
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group="abc")
+    if bnd in ("nccl", "gloo", "mpi"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+            idist.all_reduce(t, group="abc")
+    elif bnd in ("xla-tpu"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+            idist.all_reduce(t, group="abc")
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            idist.all_reduce(t, group="abc")


 def _test_distrib_all_gather(device):
@@ -218,77 +223,76 @@ def _test_distrib_all_gather(device):


 def _test_distrib_all_gather_group(device):
-    if idist.get_world_size() > 1:
-        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
-        rank = idist.get_rank()
-        bnd = idist.backend()
+    assert idist.get_world_size() > 1, idist.get_world_size()

-        t = torch.tensor([rank], device=device)
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=group)
-        else:
-            res = idist.all_gather(t, group=group)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
+    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    rank = idist.get_rank()
+    bnd = idist.backend()

-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
+    t = torch.tensor([rank], device=device)
+    group = idist.new_group(ranks)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=group)
+    else:
+        res = idist.all_gather(t, group=group)
+        if rank in ranks:
+            assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
         else:
-            res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
+            assert res == t

-        t = {
-            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
-            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
-            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
-        }
-        if bnd in ("xla-tpu"):
-            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
-                res = idist.all_gather(t, group=ranks)
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
+    t = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=ranks)
+    else:
+        res = idist.all_gather(t, group=ranks)
+        if rank in ranks:
+            assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
         else:
+            assert res == t
+
+    t = {
+        "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+        "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+        "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+    }
+    if bnd in ("xla-tpu"):
+        with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
             res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert isinstance(res, list) and len(res) == len(ranks)
-                for i, obj in zip(ranks, res):
-                    assert isinstance(obj, dict)
-                    assert list(obj.keys()) == ["a", "b", "c"], obj
-                    expected_device = (
-                        device
-                        if torch.device(device).type == "cpu"
-                        else torch.device(f"{torch.device(device).type}:{i}")
-                    )
-                    expected = {
-                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
-                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
-                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
-                    }
-                    assert obj["a"] == expected["a"], (obj, expected)
-                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
-                    assert obj["c"] == expected["c"], (obj, expected)
-            else:
-                assert res == t
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=ranks)
+    else:
+        res = idist.all_gather(t, group=ranks)
+        if rank in ranks:
+            assert isinstance(res, list) and len(res) == len(ranks)
+            for i, obj in zip(sorted(ranks), res):
+                assert isinstance(obj, dict)
+                assert list(obj.keys()) == ["a", "b", "c"], obj
+                expected_device = (
+                    device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+                )
+                expected = {
+                    "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                    "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                    "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                }
+                assert obj["a"] == expected["a"], (obj, expected)
+                assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                assert obj["c"] == expected["c"], (obj, expected)
+        else:
+            assert res == t

-        if bnd in ("nccl", "gloo", "mpi"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
-                res = idist.all_gather(t, group="abc")
-        elif bnd in ("xla-tpu"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
-                res = idist.all_gather(t, group="abc")
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group="abc")
+    if bnd in ("nccl", "gloo", "mpi"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+            res = idist.all_gather(t, group="abc")
+    elif bnd in ("xla-tpu"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+            res = idist.all_gather(t, group="abc")
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group="abc")


 def _test_idist_all_gather_tensors_with_shapes(device):
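Note the expected values switch from `ranks` to `sorted(ranks)`: even when the group is declared as `[3, 2, 1]`, the gathered entries come back ordered by ascending global rank. A hedged sketch of the expectation the updated test encodes (assumes a 4-process run on the native backend):

```python
import torch

import ignite.distributed as idist

ranks = [3, 2, 1]  # declared in reverse order on purpose
rank = idist.get_rank()
t = torch.tensor([rank], device=idist.device())

res = idist.all_gather(t, group=ranks)
if rank in ranks:
    # results are ordered by rank, not by the order used to declare the group
    assert torch.equal(res, torch.tensor(sorted(ranks), device=idist.device()))  # tensor([1, 2, 3])
else:
    assert res == t  # non-member ranks get their input back unchanged
```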
@@ -312,37 +316,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):


 def _test_idist_all_gather_tensors_with_shapes_group(device):
-    if idist.get_world_size() > 1:
-        torch.manual_seed(41)
+    assert idist.get_world_size(), idist.get_world_size()
+    torch.manual_seed(41)

-        rank = idist.get_rank()
-        ranks = list(range(1, idist.get_world_size()))
-        ws = idist.get_world_size()
-        bnd = idist.backend()
+    rank = idist.get_rank()
+    ranks = list(range(1, idist.get_world_size()))
+    ws = idist.get_world_size()
+    bnd = idist.backend()
+    if rank in ranks:
+        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        rank_tensor = reference[
+            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+        ]
+    else:
+        rank_tensor = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    else:
+        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
         if rank in ranks:
-            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
-            rank_tensor = reference[
-                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
-            ]
-        else:
-            rank_tensor = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            for r in ranks:
+                r_tensor = reference[
+                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                ]
+                assert (r_tensor == tensors[r - 1]).all()
         else:
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-            if rank in ranks:
-                for r in ranks:
-                    r_tensor = reference[
-                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                    ]
-                    assert (r_tensor == tensors[r - 1]).all()
-            else:
-                assert [rank_tensor] == tensors
+            assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):
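`all_gather_tensors_with_shapes` is the variable-shape counterpart exercised here: each participating rank contributes a tensor of its own shape, and the caller supplies the list of shapes plus the participating ranks. A hedged usage sketch mirroring the test (shapes and ranks are illustrative; assumes an initialized multi-process run):

```python
import torch

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes

ranks = [1, 2, 3]  # illustrative subgroup; rank 0 does not participate
rank = idist.get_rank()

if rank in ranks:
    local = torch.randn(rank + 1, rank + 2, rank + 3)
else:
    local = torch.tensor([rank])

shapes = [[r + 1, r + 2, r + 3] for r in ranks]
tensors = all_gather_tensors_with_shapes(local, shapes, ranks)
# On ranks in `ranks`, tensors[i] has shape shapes[i]; other ranks get [local] back.
```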
