@@ -1059,30 +1059,6 @@ def callback(
1059
1059
return work
1060
1060
1061
1061
1062
- class _ManagedWork (Work ):
1063
- def __init__ (self , manager : "Manager" , work : Work , default_result : object ) -> None :
1064
- super ().__init__ ()
1065
-
1066
- self ._manager = manager
1067
- self ._work = work
1068
- self ._default_result = default_result
1069
-
1070
- def wait (self , timeout : Optional [timedelta ] = None ) -> bool :
1071
- try :
1072
- if self ._work is not None :
1073
- if timeout is not None :
1074
- self ._work .wait (timeout )
1075
- else :
1076
- self ._work .wait ()
1077
- except Exception as e :
1078
- self ._manager .report_error (e )
1079
-
1080
- return True
1081
-
1082
- def get_future (self ) -> Future [object ]:
1083
- return self ._manager .wrap_future (self ._work .get_future (), self ._default_result )
1084
-
1085
-
1086
1062
class ManagedProcessGroup (ProcessGroupWrapper ):
1087
1063
"""
1088
1064
This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1102,23 +1078,7 @@ def __init__(self, manager: "Manager") -> None:
1102
1078
self ._manager = manager
1103
1079
1104
1080
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
1105
- # Ensure we have a valid quorum and are configured before trying to do
1106
- # any work.
1107
- self ._manager .wait_quorum ()
1108
-
1109
- if self ._manager .errored () is not None :
1110
- return _DummyWork (tensors )
1111
- try :
1112
- work = super ().allreduce (tensors , opts )
1113
- except Exception as e :
1114
- self ._manager .report_error (e )
1115
- return _DummyWork (tensors )
1116
-
1117
- return _ManagedWork (
1118
- self ._manager ,
1119
- work ,
1120
- tensors ,
1121
- )
1081
+ return self ._manager .allreduce (tensors )
1122
1082
1123
1083
def size (self ) -> int :
1124
1084
return self ._manager .num_participants ()
0 commit comments