@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
+from torch._C._distributed_c10d import Backend as C10dBackend
 from torch.distributed import (
     DeviceMesh,
     ProcessGroup as BaseProcessGroup,
@@ -145,7 +146,13 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
         assert self.mesh is not None
         return self.mesh.get_group(self._real_mesh_dim(dim))
 
-    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+    def _flatten(
+        self,
+        mesh_dim_name: Optional[str] = None,
+        backend_override: Union[
+            None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
+        ] = None,
+    ) -> "DeviceMesh":
         flatten_mesh = _FlattenDeviceMesh(self)
         if mesh_dim_name is None:
             raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`")
@@ -261,7 +268,13 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
     def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
         raise NotImplementedError
 
-    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+    def _flatten(
+        self,
+        mesh_dim_name: Optional[str] = None,
+        backend_override: Union[
+            None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
+        ] = None,
+    ) -> "DeviceMesh":
         raise NotImplementedError
 
     def size(self, mesh_dim: Optional[int] = None) -> int:
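
For orientation, a hypothetical usage sketch of the widened _flatten signature follows. None of it is code from this PR: the mesh shape, the "dp"/"cp" dim names, and the flattened name "dp_cp" are illustrative assumptions, and it presumes a PyTorch build whose DeviceMesh._flatten already accepts backend_override in the same four forms as the union type above.

    # Hypothetical sketch, not from this PR. Assumes torch.distributed is
    # already initialized with a CUDA/NCCL setup spanning 8 ranks.
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Illustrative 2x4 mesh with assumed dim names.
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "cp"))

    # backend_override is typed to accept four shapes of value:
    #   None                  -> keep the default backend (status quo)
    #   "gloo"                -> a backend name as a string
    #   <C10dBackend.Options> -> options for the default backend
    #   ("nccl", <options>)   -> backend name plus its options
    nccl_opts = dist.ProcessGroupNCCL.Options()
    flat = mesh["dp", "cp"]._flatten("dp_cp", backend_override=("nccl", nccl_opts))

Note that mesh_dim_name also gains a default of None, presumably to keep the signature aligned with the upstream DeviceMesh._flatten it overrides; ManagedDeviceMesh still rejects an omitted name through the existing ValueError, and the second override keeps raising NotImplementedError regardless of arguments.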
|