Add _rank_not_in_group to idist #3339

Merged: 18 commits, Mar 12, 2025

ignite/distributed/comp_models/base.py: 7 additions & 0 deletions
@@ -298,6 +298,10 @@ def barrier(self) -> None:
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
pass

@abstractmethod
def _rank_not_in_group(self, group: Any) -> bool:
pass


class _SerialModel(ComputationModel):
"""Private class defines non-distributed computation model for code compatibility with other distributed models."""
@@ -396,3 +400,6 @@ def new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return self._do_new_group(ranks, **kwargs)
else:
raise ValueError("Argument ranks should be list of int")

def _rank_not_in_group(self, group: Any) -> bool:
return False
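The new abstract hook pairs with the group-aware collectives: before issuing a backend call, a computation model can ask whether the current process belongs to the group and, if not, hand the input back untouched. Below is a minimal sketch of that calling pattern; the class name and the elided backend call are illustrative only and not part of the diff.

```python
# Illustrative sketch only: how a backend model is expected to use the new
# _rank_not_in_group hook. _SomeDistModel is a placeholder, not a real class.
from typing import Any, Optional

import torch


class _SomeDistModel:  # hypothetical ComputationModel subclass
    def _rank_not_in_group(self, group: Optional[Any]) -> bool:
        # Backend-specific membership test; _SerialModel simply returns False.
        raise NotImplementedError

    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
        if group is not None and self._rank_not_in_group(group):
            # Ranks outside the group skip the collective and get their input back.
            return tensor
        ...  # backend all_gather call restricted to `group`
        return tensor
```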
ignite/distributed/comp_models/horovod.py: 27 additions & 4 deletions
@@ -1,3 +1,4 @@
import os
import warnings
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple

@@ -23,6 +24,9 @@
if has_hvd_support:
HOROVOD = "horovod"

# Enables dynamic process sets: new_group methods and passing group into collective ops
os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"

class _HorovodDistModel(ComputationModel):
"""Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model."""

@@ -155,6 +159,15 @@ def spawn(
**kwargs,
)

def _setup_group(self, group: Any) -> hvd.ProcessSet:
if isinstance(group, list) and all(isinstance(item, int) for item in group):
group = self._do_new_group(group)
if not isinstance(group, hvd.ProcessSet):
raise ValueError(
f"Argument group should be list of int or hvd.ProcessSet, got {type(group)}, group={group}"
)
return group

_reduce_op_map = {
"SUM": hvd.mpi_ops.Sum,
"AVERAGE": hvd.mpi_ops.Average,
@@ -187,19 +200,24 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:

def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
if group is not None:
raise NotImplementedError("all_gather with group for horovod is not implemented")
group = self._setup_group(group)
if self._rank_not_in_group(group):
return tensor
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
return hvd.allgather(tensor)
if group is not None:
return hvd.allgather(tensor, process_set=group)
else:
return hvd.allgather(tensor)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group is not None:
raise NotImplementedError("all_gather with group for horovod is not implemented")

return hvd.allgather_object(tensor)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return hvd.ProcessSet(ranks)
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> hvd.ProcessSet:
return hvd.add_process_set(ranks)

def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return hvd.broadcast(tensor, root_rank=src)
@@ -208,3 +226,8 @@ def barrier(self) -> None:
# https://github.com/horovod/horovod/issues/159#issuecomment-424834603
# hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")

def _rank_not_in_group(self, group: Optional[Any]) -> bool:
if group is None:
return False
return not group.included()
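The Horovod path above relies on dynamic process sets. The following is a hedged sketch of those calls in isolation; every call appears in the diff, but the rank selection and launch command (`horovodrun -np 2` or larger) are assumptions for illustration.

```python
# Sketch of the Horovod process-set calls used above; assumes horovodrun -np >= 2.
import os

# Must be set before hvd.init(), as done at import time in horovod.py above.
os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"

import torch
import horovod.torch as hvd

hvd.init()
# Example group: every rank except 0 (collective call, made by all ranks).
ps = hvd.add_process_set([r for r in range(hvd.size()) if r != 0])

t = torch.tensor([hvd.rank()])
if ps.included():  # the condition that _rank_not_in_group negates
    gathered = hvd.allgather(t, process_set=ps)
else:
    gathered = t  # mirrors the early return in _do_all_gather
print(hvd.rank(), gathered)
```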
ignite/distributed/comp_models/native.py: 6 additions & 3 deletions
@@ -408,7 +408,7 @@ def spawn(
**spawn_kwargs,
)

def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
def _setup_group(self, group: Any) -> dist.ProcessGroup:
if isinstance(group, list) and all(isinstance(item, int) for item in group):
group = self._do_new_group(group)
if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
@@ -442,7 +442,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
if group is not None:
group = self._setup_group(group)
if group == dist.GroupMember.NON_GROUP_MEMBER:
if self._rank_not_in_group(group):
return tensor
if group is None:
group_size = self.get_world_size()
@@ -466,7 +466,7 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
)
if group is not None:
group = self._setup_group(group)
if group == dist.GroupMember.NON_GROUP_MEMBER:
if self._rank_not_in_group(group):
return tensor
if group is None:
group_size = self.get_world_size()
@@ -491,6 +491,9 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
def barrier(self) -> None:
dist.barrier()

def _rank_not_in_group(self, group: Optional[Any]) -> bool:
return dist._rank_not_in_group(group)

def _expand_hostlist(nodelist: str) -> List[str]:
"""Expand a compressed hostlist string and returns all hosts listed.

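For the native backend, the diff leans on torch.distributed's own group machinery: `dist.new_group` hands non-members the `GroupMember.NON_GROUP_MEMBER` sentinel, and the private `dist._rank_not_in_group` helper recognizes it. A hedged sketch of that path follows; the gloo backend, the torchrun launch and the rank selection are assumptions, not part of the PR.

```python
# Sketch of the torch.distributed path used above; assumes e.g.
# torchrun --nproc_per_node=2 launching this script with the gloo backend.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()

# Example group: every rank except 0. new_group is collective: all ranks call it,
# and ranks outside `ranks` receive dist.GroupMember.NON_GROUP_MEMBER.
ranks = [r for r in range(dist.get_world_size()) if r != 0]
group = dist.new_group(ranks)

t = torch.tensor([rank])
if dist._rank_not_in_group(group):  # the private helper wrapped above
    gathered = t  # non-members return their input unchanged
else:
    out = [torch.empty_like(t) for _ in ranks]
    dist.all_gather(out, t, group=group)
    gathered = torch.cat(out)
print(rank, gathered)
dist.destroy_process_group()
```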
ignite/distributed/comp_models/xla.py: 3 additions & 0 deletions
@@ -175,3 +175,6 @@ def _check_group_type(self, group: Optional[Any]) -> bool:
if isinstance(group, list) and all(isinstance(item, int) for item in group):
return True
return False

def _rank_not_in_group(self, group: Any) -> bool:
return self.get_rank() not in group
ignite/distributed/utils.py: 12 additions & 4 deletions
@@ -2,10 +2,9 @@
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import distributed as dist

from ignite.distributed.comp_models import (
_SerialModel,
@@ -384,15 +383,15 @@ def all_gather_tensors_with_shapes(
if isinstance(group, list) and all(isinstance(item, int) for item in group):
group = _model.new_group(group)

if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
if _rank_not_in_group(group):
return [tensor]

max_shape = torch.tensor(shapes).amax(dim=0)
padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
padded_tensor = torch.nn.functional.pad(
tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
)
all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
all_padded_tensors: torch.Tensor = cast(torch.Tensor, _model.all_gather(padded_tensor, group=group))
return [
all_padded_tensors[
[
@@ -731,3 +730,12 @@ def download_dataset():

if current_rank == rank:
barrier()


def _rank_not_in_group(group: Optional[Union[Any, List[int]]]) -> bool:
"""Check if the current process's rank is not in a given group."""
if group is None:
return False
if isinstance(group, list) and all(isinstance(item, int) for item in group):
group = new_group(group)
return _model._rank_not_in_group(group)
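From the user's side, this helper is what lets group-aware collectives degrade gracefully on ranks outside the group, which the updated tests below exercise. A hedged end-to-end sketch of that behavior follows; the gloo backend, two local processes, and the rank selection are assumptions for illustration only.

```python
# Sketch of the user-facing behavior; assumes gloo and 2 local processes.
import torch

import ignite.distributed as idist


def check(local_rank):
    rank = idist.get_rank()
    ranks = [r for r in range(idist.get_world_size()) if r != 0]  # skip rank 0
    group = idist.new_group(ranks)

    t = torch.tensor([rank], device=idist.device())
    res = idist.all_gather(t, group=group)
    if rank in ranks:
        assert torch.equal(res, torch.tensor(ranks, device=idist.device()))
    else:
        # Ranks outside the group get their own tensor back instead of hanging.
        assert torch.equal(res, t)


with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
    parallel.run(check)
```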
tests/ignite/distributed/utils/__init__.py: 48 additions & 60 deletions
@@ -3,7 +3,7 @@
import torch.distributed as dist

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
from ignite.distributed.utils import _rank_not_in_group, all_gather_tensors_with_shapes, sync
from ignite.engine import Engine, Events


@@ -122,7 +122,7 @@ def _test_distrib_all_reduce_group(device):
assert idist.get_world_size() > 1, idist.get_world_size()
assert idist.backend() is not None, idist.backend()

ranks = [0, 1]
ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
rank = idist.get_rank()
t = torch.tensor([rank], device=device)
bnd = idist.backend()
@@ -225,32 +225,27 @@ def _test_distrib_all_gather(device):
def _test_distrib_all_gather_group(device):
assert idist.get_world_size() > 1, idist.get_world_size()

ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1]
ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
rank = idist.get_rank()
bnd = idist.backend()

t = torch.tensor([rank], device=device)
group = idist.new_group(ranks)
if bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group=group)
res = idist.all_gather(t, group=group)
if rank in ranks:
assert torch.equal(res, torch.tensor(ranks, device=device))
else:
res = idist.all_gather(t, group=group)
if rank in ranks:
assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
else:
assert res == t
assert res == t

t = torch.tensor([rank], device=device)
if bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group=ranks)
if bnd == "horovod":
res = idist.all_gather(t, group=group)
else:
res = idist.all_gather(t, group=ranks)
if rank in ranks:
assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
else:
assert res == t
if rank in ranks:
assert torch.equal(res, torch.tensor(ranks, device=device))
else:
assert res == t

t = {
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
@@ -262,12 +257,12 @@ def _test_distrib_all_gather_group(device):
res = idist.all_gather(t, group=ranks)
elif bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group=ranks)
res = idist.all_gather(t, group=group)
else:
res = idist.all_gather(t, group=ranks)
if rank in ranks:
assert isinstance(res, list) and len(res) == len(ranks)
for i, obj in zip(sorted(ranks), res):
for i, obj in zip(ranks, res):
assert isinstance(obj, dict)
assert list(obj.keys()) == ["a", "b", "c"], obj
expected_device = (
@@ -284,22 +279,20 @@ def _test_distrib_all_gather_group(device):
else:
assert res == t

if bnd in ("nccl", "gloo", "mpi"):
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
t = torch.tensor([rank], device=device)
if bnd in ("nccl", "gloo", "mpi", "horovod"):
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
res = idist.all_gather(t, group="abc")
elif bnd in ("xla-tpu"):
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
res = idist.all_gather(t, group="abc")
elif bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group="abc")


def _test_idist_all_gather_tensors_with_shapes(device):
torch.manual_seed(41)
rank = idist.get_rank()
ws = idist.get_world_size()
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
rank_tensor = reference[
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
@@ -312,41 +305,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
]
assert (r_tensor == tensors[r]).all()
assert r_tensor.allclose(tensors[r])


def _test_idist_all_gather_tensors_with_shapes_group(device):
assert idist.get_world_size(), idist.get_world_size()
torch.manual_seed(41)

rank = idist.get_rank()
ranks = list(range(1, idist.get_world_size()))
ranks = sorted(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [1, 2, 3]
ws = idist.get_world_size()
bnd = idist.backend()
if rank in ranks:
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
rank_tensor = reference[
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
]
else:
rank_tensor = torch.tensor([rank], device=device)
if bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)

tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
if rank in ranks:
for i, r in enumerate(ranks):
r_tensor = reference[
r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
]
assert r_tensor.allclose(tensors[i])
else:
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
if rank in ranks:
for r in ranks:
r_tensor = reference[
r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
]
assert (r_tensor == tensors[r - 1]).all()
else:
assert [rank_tensor] == tensors
assert [rank_tensor] == tensors


def _test_distrib_broadcast(device):
@@ -413,31 +402,30 @@ def _test_distrib_barrier(device):
assert tt.item() == true_res + 10.0


def _test_distrib_new_group(device):
def _test_distrib_group(device):
ranks = sorted(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [1, 2, 3]
if idist.get_world_size() > 1 and idist.backend() is not None:
bnd = idist.backend()
ranks = [0, 1]
rank = idist.get_rank()
g = idist.new_group(ranks)
if idist.has_native_dist_support and bnd in ("nccl", "gloo", "mpi"):
g1 = idist.new_group(ranks)
g2 = dist.new_group(ranks)

rank = idist.get_rank()
if rank in ranks:
assert g1.rank() == g2.rank()
# mapping between group ranks and global ranks
global_to_group = {r: i for i, r in enumerate(ranks)}
assert g.rank() == global_to_group[rank], (g.rank(), global_to_group, rank)

elif idist.has_xla_support and bnd in ("xla-tpu"):
assert idist.new_group(ranks) == [ranks]
assert g == [ranks]
elif idist.has_hvd_support and bnd in ("horovod"):
from horovod.common.process_sets import ProcessSet

g1 = idist.new_group(ranks)
g2 = ProcessSet(ranks)

rank = idist.get_rank()
if rank in ranks:
assert g1.ranks == g2.ranks
assert g.ranks == ranks

if rank in ranks:
assert not _rank_not_in_group(g)
else:
assert _rank_not_in_group(g)

elif idist.backend() is None:
ranks = [0, 1]
assert idist.new_group(ranks) == ranks

with pytest.raises(ValueError, match="Argument ranks should be list of int"):