Skip to content

Commit 7234bab

Browse files
authored
fix device mesh overrides (#254)
1 parent ee2b322 commit 7234bab

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

torchft/device_mesh.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
44

55
import torch
6+
from torch._C._distributed_c10d import Backend as C10dBackend
67
from torch.distributed import (
78
DeviceMesh,
89
ProcessGroup as BaseProcessGroup,
@@ -145,7 +146,13 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
145146
assert self.mesh is not None
146147
return self.mesh.get_group(self._real_mesh_dim(dim))
147148

148-
def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
149+
def _flatten(
150+
self,
151+
mesh_dim_name: Optional[str] = None,
152+
backend_override: Union[
153+
None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
154+
] = None,
155+
) -> "DeviceMesh":
149156
flatten_mesh = _FlattenDeviceMesh(self)
150157
if mesh_dim_name is None:
151158
raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`")
@@ -261,7 +268,13 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
261268
def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
262269
raise NotImplementedError
263270

264-
def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
271+
def _flatten(
272+
self,
273+
mesh_dim_name: Optional[str] = None,
274+
backend_override: Union[
275+
None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
276+
] = None,
277+
) -> "DeviceMesh":
265278
raise NotImplementedError
266279

267280
def size(self, mesh_dim: Optional[int] = None) -> int:

0 commit comments

Comments
 (0)