From e85239618c4eb24d9ddb2b6c1aeab8b6204f8b41 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 4 Sep 2024 12:14:12 +0200 Subject: [PATCH 1/7] torch distributed: add support for user-specified parameter synchronization --- returnn/torch/distributed.py | 134 +++++++++++++++++++++++++++++++---- 1 file changed, 121 insertions(+), 13 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 3bb2cf825..ed7946408 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -3,20 +3,78 @@ """ from __future__ import annotations -from typing import Optional, Any, Dict +from abc import abstractmethod, ABC +import logging +import numpy import os import socket -import logging +from typing import Callable, Optional, Any, Dict, Type, Union import torch from torch.nn.parallel import DistributedDataParallel -from returnn.config import Config -from returnn.util.basic import CollectionReadCheckCovered +from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError _logger = logging.getLogger("returnn.torch.distributed") +class ParamSynchronizer(ABC): + """ + Custom parameter synchronization primitive. + + Contains a callback that is called after every train step to synchronize model parameters + across processes/nodes. + """ + + @abstractmethod + def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs): + """ + `__init__` called after the default global process group is created. + Can be used to initialize any additional custom process (sub)groups. + + Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. + + :param rank: global rank of the current process across all nodes + :param size: global world size across all nodes + :param local_rank: local rank of the current process on the current node + :param local_rank: local world size on the current node + :param kwargs: any additional kwargs + """ + super().__init__() + + self.rank = rank + self.size = size + self.local_rank = local_rank + self.local_size = local_size + + def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel: + """ + Creates an associated `DistributedDataParallel` for the given module for gradient synchronization. + + This function can be left unimplemented if no gradient synchronization is done. + + Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. + """ + raise OptionalNotImplementedError + + @abstractmethod + def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs): + """ + Parameter synchronization callback called after every train step with updated model parameters. + + Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. 
+ + :param module: the NN being trained + :param train_step_idx: the current train step + :param kwargs: any additional kwargs + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs): + """forwards to :func:``step``""" + return self.step(*args, **kwargs) + + class DistributedContext: """ This class setups some helper functions for torch distributed training @@ -26,6 +84,9 @@ def __init__(self, options: Dict[str, Any]): import torch.distributed as dist self._opts = CollectionReadCheckCovered(options) + # Only used to generate forwards compatibility ensuring random kwargs, therefore + # the seed is not important + self._rng = numpy.random.default_rng() # when no backend is specified, both gloo and nccl backends will be created # the gloo backend will be used for collectives with CPU tensors and @@ -42,8 +103,13 @@ def __init__(self, options: Dict[str, Any]): % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size) ) + self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get( + "synchronizer", None + ) + self._custom_sync: Optional[Callable] = None self._reduce_type = self._opts.get("reduce_type", "grad") self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None) + if self._reduce_type == "param": assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, ( f"reduce_type param: param_sync_step must be a positive int," @@ -52,6 +118,23 @@ def __init__(self, options: Dict[str, Any]): _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}") elif self._reduce_type == "grad": _logger.info("reduce_type grad") + elif self._reduce_type == "custom": + if issubclass(self._custom_sync_class, ParamSynchronizer): + self._custom_sync = self._custom_sync_class( + rank=self._rank, + size=self._size, + local_rank=self._local_rank, + local_size=self._local_size, + **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}, + ) + elif isinstance(self._custom_sync_class, Callable): + self._custom_sync = self._custom_sync_class + else: + raise ValueError( + f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}" + ) + + _logger.info(f"reduce_type custom: {type(self._custom_sync)}") else: raise ValueError(f"invalid reduce_type {self._reduce_type!r}") @@ -70,6 +153,8 @@ def _check_no_unknown_opts(self): self._opts.get("options") if self._reduce_type == "param": self._opts.get("sync_on_cpu") + if self._reduce_type == "custom": + self._opts.get("synchronizer") self._opts.assert_all_read() @@ -102,7 +187,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis """ if self._reduce_type == "param": return None - assert self._reduce_type == "grad" + assert self._reduce_type in ["custom", "grad"] + + if self._reduce_type == "custom": + assert isinstance(self._custom_sync, (ParamSynchronizer, Callable)) + + if isinstance(self._custom_sync, ParamSynchronizer): + try: + return self._custom_sync.make_distributed_model( + module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None} + ) + except OptionalNotImplementedError: + pass + else: + # callable short form does not have support for DistributedDataParallel + pass + + return None + cls = self._opts.get("class", DistributedDataParallel) if cls is not DistributedDataParallel: _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.") @@ -115,7 +217,14 @@ def 
maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int): """one train step""" - if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): + if self._reduce_type == "custom": + with torch.no_grad(): # TODO: do we want this for all syncers? + self._custom_sync( + module=module, + train_step_idx=epoch_step_idx, + **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}, + ) + elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False)) @@ -127,7 +236,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]: """ :param Config|None config: :returns: the global context if Torch distributed is enabled, or None otherwise. - If we did not setup the context yet, it will automatically create it. + If we did not set up the context yet, it will automatically create it. """ global _is_set_up, _ctx if _is_set_up: @@ -155,7 +264,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False): if sync_on_cpu: for param in module.parameters(): - # Separately move each param to CPU (instead of the whole module), to safe CPU memory. + # Separately move each param to CPU (instead of the whole module), to save CPU memory. param_cpu = param.to(torch.device("cpu")) # On CPU, we are likely using Gloo, and Gloo does not support AVG dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM) @@ -166,12 +275,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False): if dist.get_backend() == "gloo": # Gloo does not support AVG reduce_op = dist.ReduceOp.SUM + elif hasattr(dist.ReduceOp, "AVG"): + reduce_op = dist.ReduceOp.AVG else: - if hasattr(dist.ReduceOp, "AVG"): - reduce_op = dist.ReduceOp.AVG - else: - # Older PyTorch versions do not have ReduceOp.AVG. - reduce_op = dist.ReduceOp.SUM + # Older PyTorch versions do not have ReduceOp.AVG. 
+ reduce_op = dist.ReduceOp.SUM for param in module.parameters(): dist.all_reduce(param.data, op=reduce_op) From 17b5a30a60b567ebb6a250858ade60f14128a1df Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 4 Sep 2024 13:39:58 +0200 Subject: [PATCH 2/7] introduce helper function for fwd compat kwargs --- returnn/torch/distributed.py | 14 ++++---------- returnn/util/basic.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index ed7946408..6ea383803 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -5,7 +5,6 @@ from __future__ import annotations from abc import abstractmethod, ABC import logging -import numpy import os import socket from typing import Callable, Optional, Any, Dict, Type, Union @@ -13,7 +12,7 @@ import torch from torch.nn.parallel import DistributedDataParallel -from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError +from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError, get_fwd_compat_kwargs _logger = logging.getLogger("returnn.torch.distributed") @@ -84,9 +83,6 @@ def __init__(self, options: Dict[str, Any]): import torch.distributed as dist self._opts = CollectionReadCheckCovered(options) - # Only used to generate forwards compatibility ensuring random kwargs, therefore - # the seed is not important - self._rng = numpy.random.default_rng() # when no backend is specified, both gloo and nccl backends will be created # the gloo backend will be used for collectives with CPU tensors and @@ -125,7 +121,7 @@ def __init__(self, options: Dict[str, Any]): size=self._size, local_rank=self._local_rank, local_size=self._local_size, - **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}, + **get_fwd_compat_kwargs(), ) elif isinstance(self._custom_sync_class, Callable): self._custom_sync = self._custom_sync_class @@ -194,9 +190,7 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis if isinstance(self._custom_sync, ParamSynchronizer): try: - return self._custom_sync.make_distributed_model( - module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None} - ) + return self._custom_sync.make_distributed_model(module=module, **get_fwd_compat_kwargs()) except OptionalNotImplementedError: pass else: @@ -222,7 +216,7 @@ def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: in self._custom_sync( module=module, train_step_idx=epoch_step_idx, - **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}, + **get_fwd_compat_kwargs(), ) elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False)) diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 26fb24484..a696d7242 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -4586,3 +4586,15 @@ def override_env_var(var_name: str, value: str): os.environ[var_name] = cur_val else: os.environ.pop(var_name) + + +_fwd_compat_rng = np.random.default_rng() + + +def get_fwd_compat_kwargs() -> Dict[str, Any]: + """ + Returns a dictionary suitable for passing as kwargs for any RETURNN userland + function where forwards compatibility wrt. additional arguments must be + ensured. 
+ """ + return {f"fwd_compatible_random_kwarg_{_fwd_compat_rng.integers(0, 100)}": None} From b5f28f48d05294e5b8d706ac03fa3acf4491f4cb Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 4 Sep 2024 14:29:07 +0200 Subject: [PATCH 3/7] remove parameter, must be defined on subclass instead --- returnn/torch/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 6ea383803..032e52d38 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -26,12 +26,12 @@ class ParamSynchronizer(ABC): """ @abstractmethod - def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs): + def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int): """ `__init__` called after the default global process group is created. Can be used to initialize any additional custom process (sub)groups. - Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. + Note the `__init__` is passed a randomly named kwarg on every invocation to ensure forwards compatibility. :param rank: global rank of the current process across all nodes :param size: global world size across all nodes From ccf6fa812fd4fe396d3c1b18cae69d0bc4dcdc82 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 4 Sep 2024 14:49:59 +0200 Subject: [PATCH 4/7] work around faulty lint --- returnn/torch/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 032e52d38..83633dfac 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -26,7 +26,7 @@ class ParamSynchronizer(ABC): """ @abstractmethod - def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int): + def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **_kwargs): """ `__init__` called after the default global process group is created. Can be used to initialize any additional custom process (sub)groups. 
@@ -37,7 +37,7 @@ def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int): :param size: global world size across all nodes :param local_rank: local rank of the current process on the current node :param local_rank: local world size on the current node - :param kwargs: any additional kwargs + :param _kwargs: any additional kwargs """ super().__init__() From e106103af17c8482d0a7b8955f1f0d535a378d7d Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 6 Sep 2024 11:54:45 +0200 Subject: [PATCH 5/7] chore: Simplify implementation to support only callback --- returnn/torch/distributed.py | 114 ++++------------------------------- 1 file changed, 13 insertions(+), 101 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 83633dfac..d1c587595 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -3,77 +3,19 @@ """ from __future__ import annotations -from abc import abstractmethod, ABC import logging import os import socket -from typing import Callable, Optional, Any, Dict, Type, Union +from typing import Callable, Optional, Any, Dict import torch from torch.nn.parallel import DistributedDataParallel -from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError, get_fwd_compat_kwargs +from returnn.util.basic import CollectionReadCheckCovered, get_fwd_compat_kwargs _logger = logging.getLogger("returnn.torch.distributed") -class ParamSynchronizer(ABC): - """ - Custom parameter synchronization primitive. - - Contains a callback that is called after every train step to synchronize model parameters - across processes/nodes. - """ - - @abstractmethod - def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **_kwargs): - """ - `__init__` called after the default global process group is created. - Can be used to initialize any additional custom process (sub)groups. - - Note the `__init__` is passed a randomly named kwarg on every invocation to ensure forwards compatibility. - - :param rank: global rank of the current process across all nodes - :param size: global world size across all nodes - :param local_rank: local rank of the current process on the current node - :param local_rank: local world size on the current node - :param _kwargs: any additional kwargs - """ - super().__init__() - - self.rank = rank - self.size = size - self.local_rank = local_rank - self.local_size = local_size - - def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel: - """ - Creates an associated `DistributedDataParallel` for the given module for gradient synchronization. - - This function can be left unimplemented if no gradient synchronization is done. - - Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. - """ - raise OptionalNotImplementedError - - @abstractmethod - def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs): - """ - Parameter synchronization callback called after every train step with updated model parameters. - - Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility. 
- - :param module: the NN being trained - :param train_step_idx: the current train step - :param kwargs: any additional kwargs - """ - raise NotImplementedError - - def __call__(self, *args, **kwargs): - """forwards to :func:``step``""" - return self.step(*args, **kwargs) - - class DistributedContext: """ This class setups some helper functions for torch distributed training @@ -99,10 +41,9 @@ def __init__(self, options: Dict[str, Any]): % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size) ) - self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get( - "synchronizer", None + self._custom_step_after_param_update: Optional[Callable] = self._opts.get( + "custom_step_after_param_update", None ) - self._custom_sync: Optional[Callable] = None self._reduce_type = self._opts.get("reduce_type", "grad") self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None) @@ -114,23 +55,10 @@ def __init__(self, options: Dict[str, Any]): _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}") elif self._reduce_type == "grad": _logger.info("reduce_type grad") - elif self._reduce_type == "custom": - if issubclass(self._custom_sync_class, ParamSynchronizer): - self._custom_sync = self._custom_sync_class( - rank=self._rank, - size=self._size, - local_rank=self._local_rank, - local_size=self._local_size, - **get_fwd_compat_kwargs(), - ) - elif isinstance(self._custom_sync_class, Callable): - self._custom_sync = self._custom_sync_class - else: - raise ValueError( - f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}" - ) - - _logger.info(f"reduce_type custom: {type(self._custom_sync)}") + elif self._reduce_type == "custom_step_after_param_update": + if not isinstance(self._custom_step_after_param_update, Callable): + raise ValueError(f"synchronizer must either be a callable") + _logger.info("reduce_type custom_step_after_param_update") else: raise ValueError(f"invalid reduce_type {self._reduce_type!r}") @@ -181,23 +109,9 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis :param module: original module :return: potentially wrapped module """ - if self._reduce_type == "param": - return None - assert self._reduce_type in ["custom", "grad"] - - if self._reduce_type == "custom": - assert isinstance(self._custom_sync, (ParamSynchronizer, Callable)) - - if isinstance(self._custom_sync, ParamSynchronizer): - try: - return self._custom_sync.make_distributed_model(module=module, **get_fwd_compat_kwargs()) - except OptionalNotImplementedError: - pass - else: - # callable short form does not have support for DistributedDataParallel - pass - + if self._reduce_type in ["param", "custom_step_after_param_update"]: return None + assert self._reduce_type == "grad" cls = self._opts.get("class", DistributedDataParallel) if cls is not DistributedDataParallel: @@ -211,12 +125,10 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int): """one train step""" - if self._reduce_type == "custom": + if self._reduce_type == "custom_step_after_param_update": with torch.no_grad(): # TODO: do we want this for all syncers? 
- self._custom_sync( - module=module, - train_step_idx=epoch_step_idx, - **get_fwd_compat_kwargs(), + self._custom_step_after_param_update( + module=module, train_step_idx=epoch_step_idx, **get_fwd_compat_kwargs() ) elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False)) From 4540bdd29ff787025dc2b56328f38bc520461df5 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 6 Sep 2024 16:29:33 +0200 Subject: [PATCH 6/7] fix text --- returnn/torch/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index d1c587595..67c4486d6 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -57,7 +57,7 @@ def __init__(self, options: Dict[str, Any]): _logger.info("reduce_type grad") elif self._reduce_type == "custom_step_after_param_update": if not isinstance(self._custom_step_after_param_update, Callable): - raise ValueError(f"synchronizer must either be a callable") + raise ValueError(f"custom step callback must be a callable, not {self._custom_step_after_param_update}") _logger.info("reduce_type custom_step_after_param_update") else: raise ValueError(f"invalid reduce_type {self._reduce_type!r}") From d666d337e24208fcb7e754f2a3cf1bb53305bdee Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 6 Sep 2024 16:30:09 +0200 Subject: [PATCH 7/7] remove torch.no_grad() for more flexibility --- returnn/torch/distributed.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 67c4486d6..58ac0dbe3 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -126,10 +126,9 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int): """one train step""" if self._reduce_type == "custom_step_after_param_update": - with torch.no_grad(): # TODO: do we want this for all syncers? - self._custom_step_after_param_update( - module=module, train_step_idx=epoch_step_idx, **get_fwd_compat_kwargs() - ) + self._custom_step_after_param_update( + module=module, train_step_idx=epoch_step_idx, **get_fwd_compat_kwargs() + ) elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))