Skip to content

Commit b115778

Browse files
committed
Fix group as list of ints in torch dist collective ops
1 parent 6401a59 commit b115778

File tree

4 files changed

+113
-109
lines changed

4 files changed

+113
-109
lines changed

ignite/distributed/comp_models/native.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,15 @@ def spawn(
408408
**spawn_kwargs,
409409
)
410410

411+
def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
412+
if isinstance(group, list) and all(isinstance(item, int) for item in group):
413+
group = self._do_new_group(group)
414+
if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
415+
raise ValueError(
416+
f"Argument group should be list of int or ProcessGroup, got {type(group)}, group={group}"
417+
)
418+
return group
419+
411420
_reduce_op_map = {
412421
"SUM": dist.ReduceOp.SUM,
413422
"PRODUCT": dist.ReduceOp.PRODUCT,
@@ -420,8 +429,8 @@ def spawn(
420429
def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
421430
if op not in self._reduce_op_map:
422431
raise ValueError(f"Unsupported reduction operation: '{op}'")
423-
if group is not None and not isinstance(group, dist.ProcessGroup):
424-
raise ValueError("Argument group should be list of int or ProcessGroup")
432+
if group is not None:
433+
group = self._setup_group(group)
425434
reduce_op = self._reduce_op_map[op]
426435
# We do if/else here for compatibility with older pytorch versions
427436
if group is not None:
@@ -431,15 +440,14 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
431440
return tensor
432441

433442
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
443+
if group is not None:
444+
group = self._setup_group(group)
434445
if group == dist.GroupMember.NON_GROUP_MEMBER:
435446
return tensor
436-
437447
if group is None:
438448
group_size = self.get_world_size()
439-
elif isinstance(group, dist.ProcessGroup):
440-
group_size = group.size()
441449
else:
442-
raise ValueError("Argument group should be list of int or ProcessGroup")
450+
group_size = group.size()
443451
if tensor.ndimension() == 0:
444452
tensor = tensor.unsqueeze(0)
445453
output = [torch.zeros_like(tensor) for _ in range(group_size)]
@@ -456,16 +464,14 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
456464
"Current torch version does not implement dist.all_gather_object. "
457465
"Required version should be >=1.7.0"
458466
)
459-
467+
if group is not None:
468+
group = self._setup_group(group)
460469
if group == dist.GroupMember.NON_GROUP_MEMBER:
461470
return tensor
462-
463471
if group is None:
464472
group_size = self.get_world_size()
465-
elif isinstance(group, dist.ProcessGroup):
466-
group_size = group.size()
467473
else:
468-
raise ValueError("Argument group should be list of int or ProcessGroup")
474+
group_size = group.size()
469475
output = [None for _ in range(group_size)]
470476
# We do if/else here for compatibility with older pytorch versions
471477
if group is not None:

ignite/distributed/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,6 @@ def all_reduce(
347347
if _need_to_sync and isinstance(_model, _SerialModel):
348348
sync(temporary=True)
349349

350-
if isinstance(group, list) and all(isinstance(item, int) for item in group):
351-
group = _model.new_group(group)
352-
353350
return _model.all_reduce(tensor, op, group=group)
354351

355352

@@ -429,9 +426,6 @@ def all_gather(
429426
if _need_to_sync and isinstance(_model, _SerialModel):
430427
sync(temporary=True)
431428

432-
if isinstance(group, list) and all(isinstance(item, int) for item in group):
433-
group = _model.new_group(group)
434-
435429
return _model.all_gather(tensor, group=group)
436430

437431

tests/ignite/distributed/utils/__init__.py

Lines changed: 95 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -119,39 +119,44 @@ def _test_distrib_all_reduce(device):
119119

120120

121121
def _test_distrib_all_reduce_group(device):
122-
if idist.get_world_size() > 1 and idist.backend() is not None:
123-
ranks = [0, 1]
124-
rank = idist.get_rank()
125-
t = torch.tensor([rank], device=device)
126-
bnd = idist.backend()
122+
assert idist.get_world_size() > 1, idist.get_world_size()
123+
assert idist.backend() is not None, idist.backend()
127124

128-
group = idist.new_group(ranks)
129-
if bnd in ("horovod"):
130-
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
131-
res = idist.all_reduce(t, group=group)
132-
else:
125+
ranks = [0, 1]
126+
rank = idist.get_rank()
127+
t = torch.tensor([rank], device=device)
128+
bnd = idist.backend()
129+
130+
group = idist.new_group(ranks)
131+
if bnd in ("horovod"):
132+
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
133+
res = idist.all_reduce(t, group=group)
134+
else:
135+
if rank in ranks:
136+
# we should call all_reduce with group on the participating ranks only
137+
# otherwise a warning is raised:
138+
# UserWarning: Running all_reduce on global rank 2 which does not belong to the given group.
133139
res = idist.all_reduce(t, group=group)
134140
assert res == torch.tensor([sum(ranks)], device=device)
135141

136-
t = torch.tensor([rank], device=device)
137-
if bnd in ("horovod"):
138-
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
139-
res = idist.all_reduce(t, group=ranks)
140-
else:
142+
t = torch.tensor([rank], device=device)
143+
if bnd in ("horovod"):
144+
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
145+
res = idist.all_reduce(t, group=ranks)
146+
else:
147+
if rank in ranks:
141148
res = idist.all_reduce(t, group=ranks)
142149
assert res == torch.tensor([sum(ranks)], device=device)
143150

144-
ranks = "abc"
145-
146-
if bnd in ("nccl", "gloo", "mpi"):
147-
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
148-
res = idist.all_reduce(t, group="abc")
149-
elif bnd in ("xla-tpu"):
150-
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
151-
res = idist.all_reduce(t, group="abc")
152-
elif bnd in ("horovod"):
153-
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
154-
res = idist.all_reduce(t, group="abc")
151+
if bnd in ("nccl", "gloo", "mpi"):
152+
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
153+
idist.all_reduce(t, group="abc")
154+
elif bnd in ("xla-tpu"):
155+
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
156+
idist.all_reduce(t, group="abc")
157+
elif bnd in ("horovod"):
158+
with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
159+
idist.all_reduce(t, group="abc")
155160

156161

157162
def _test_distrib_all_gather(device):
@@ -218,77 +223,76 @@ def _test_distrib_all_gather(device):
218223

219224

220225
def _test_distrib_all_gather_group(device):
221-
if idist.get_world_size() > 1:
222-
ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1]
223-
rank = idist.get_rank()
224-
bnd = idist.backend()
226+
assert idist.get_world_size() > 1, idist.get_world_size()
225227

226-
t = torch.tensor([rank], device=device)
227-
group = idist.new_group(ranks)
228-
if bnd in ("horovod"):
229-
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
230-
res = idist.all_gather(t, group=group)
231-
else:
232-
res = idist.all_gather(t, group=group)
233-
if rank in ranks:
234-
assert torch.equal(res, torch.tensor(ranks, device=device))
235-
else:
236-
assert res == t
228+
ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1]
229+
rank = idist.get_rank()
230+
bnd = idist.backend()
237231

238-
t = torch.tensor([rank], device=device)
239-
if bnd in ("horovod"):
240-
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
241-
res = idist.all_gather(t, group=ranks)
232+
t = torch.tensor([rank], device=device)
233+
group = idist.new_group(ranks)
234+
if bnd in ("horovod"):
235+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
236+
res = idist.all_gather(t, group=group)
237+
else:
238+
res = idist.all_gather(t, group=group)
239+
if rank in ranks:
240+
assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
242241
else:
243-
res = idist.all_gather(t, group=ranks)
244-
if rank in ranks:
245-
assert torch.equal(res, torch.tensor(ranks, device=device))
246-
else:
247-
assert res == t
242+
assert res == t
248243

249-
t = {
250-
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
251-
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
252-
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
253-
}
254-
if bnd in ("xla-tpu"):
255-
with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
256-
res = idist.all_gather(t, group=ranks)
257-
elif bnd in ("horovod"):
258-
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
259-
res = idist.all_gather(t, group=ranks)
244+
t = torch.tensor([rank], device=device)
245+
if bnd in ("horovod"):
246+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
247+
res = idist.all_gather(t, group=ranks)
248+
else:
249+
res = idist.all_gather(t, group=ranks)
250+
if rank in ranks:
251+
assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
260252
else:
253+
assert res == t
254+
255+
t = {
256+
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
257+
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
258+
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
259+
}
260+
if bnd in ("xla-tpu"):
261+
with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
261262
res = idist.all_gather(t, group=ranks)
262-
if rank in ranks:
263-
assert isinstance(res, list) and len(res) == len(ranks)
264-
for i, obj in zip(ranks, res):
265-
assert isinstance(obj, dict)
266-
assert list(obj.keys()) == ["a", "b", "c"], obj
267-
expected_device = (
268-
device
269-
if torch.device(device).type == "cpu"
270-
else torch.device(f"{torch.device(device).type}:{i}")
271-
)
272-
expected = {
273-
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
274-
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
275-
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
276-
}
277-
assert obj["a"] == expected["a"], (obj, expected)
278-
assert (obj["b"] == expected["b"]).all(), (obj, expected)
279-
assert obj["c"] == expected["c"], (obj, expected)
280-
else:
281-
assert res == t
282-
283-
if bnd in ("nccl", "gloo", "mpi"):
284-
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
285-
res = idist.all_gather(t, group="abc")
286-
elif bnd in ("xla-tpu"):
287-
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
288-
res = idist.all_gather(t, group="abc")
289-
elif bnd in ("horovod"):
290-
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
291-
res = idist.all_gather(t, group="abc")
263+
elif bnd in ("horovod"):
264+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
265+
res = idist.all_gather(t, group=ranks)
266+
else:
267+
res = idist.all_gather(t, group=ranks)
268+
if rank in ranks:
269+
assert isinstance(res, list) and len(res) == len(ranks)
270+
for i, obj in zip(sorted(ranks), res):
271+
assert isinstance(obj, dict)
272+
assert list(obj.keys()) == ["a", "b", "c"], obj
273+
expected_device = (
274+
device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
275+
)
276+
expected = {
277+
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
278+
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
279+
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
280+
}
281+
assert obj["a"] == expected["a"], (obj, expected)
282+
assert (obj["b"] == expected["b"]).all(), (obj, expected)
283+
assert obj["c"] == expected["c"], (obj, expected)
284+
else:
285+
assert res == t
286+
287+
if bnd in ("nccl", "gloo", "mpi"):
288+
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
289+
res = idist.all_gather(t, group="abc")
290+
elif bnd in ("xla-tpu"):
291+
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
292+
res = idist.all_gather(t, group="abc")
293+
elif bnd in ("horovod"):
294+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
295+
res = idist.all_gather(t, group="abc")
292296

293297

294298
def _test_idist_all_gather_tensors_with_shapes(device):

tests/run_cpu_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fi
2222
# Run 2 processes with --dist=each
2323
CUDA_VISIBLE_DEVICES="" run_tests \
2424
--core_args "-m distributed -vvv tests/ignite" \
25-
--world_size 2 \
25+
--world_size 4 \
2626
--cache_dir ".cpu-distrib" \
2727
--skip_distrib_tests 0 \
2828
--use_coverage 1 \

0 commit comments

Comments
 (0)