@@ -1062,30 +1062,6 @@ def callback(
1062
1062
return work
1063
1063
1064
1064
1065
- class _ManagedWork (Work ):
1066
- def __init__ (self , manager : "Manager" , work : Work , default_result : object ) -> None :
1067
- super ().__init__ ()
1068
-
1069
- self ._manager = manager
1070
- self ._work = work
1071
- self ._default_result = default_result
1072
-
1073
- def wait (self , timeout : Optional [timedelta ] = None ) -> bool :
1074
- try :
1075
- if self ._work is not None :
1076
- if timeout is not None :
1077
- self ._work .wait (timeout )
1078
- else :
1079
- self ._work .wait ()
1080
- except Exception as e :
1081
- self ._manager .report_error (e )
1082
-
1083
- return True
1084
-
1085
- def get_future (self ) -> Future [object ]:
1086
- return self ._manager .wrap_future (self ._work .get_future (), self ._default_result )
1087
-
1088
-
1089
1065
class ManagedProcessGroup (ProcessGroupWrapper ):
1090
1066
"""
1091
1067
This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None:
1105
1081
self ._manager = manager
1106
1082
1107
1083
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
1108
- # Ensure we have a valid quorum and are configured before trying to do
1109
- # any work.
1110
- self ._manager .wait_quorum ()
1084
+ if isinstance (opts , ReduceOp ):
1085
+ return self ._manager .allreduce (tensors , reduce_op = opts )
1111
1086
1112
- if self ._manager .errored () is not None :
1113
- return _DummyWork (tensors )
1114
- try :
1115
- work = super ().allreduce (tensors , opts )
1116
- except Exception as e :
1117
- self ._manager .report_error (e )
1118
- return _DummyWork (tensors )
1087
+ if isinstance (opts , AllreduceOptions ):
1088
+ return self ._manager .allreduce (tensors , reduce_op = opts .reduceOp )
1119
1089
1120
- return _ManagedWork (
1121
- self ._manager ,
1122
- work ,
1123
- tensors ,
1124
- )
1090
+ assert False , "unreachable"
1125
1091
1126
1092
def size (self ) -> int :
1127
1093
return self ._manager .num_participants ()
0 commit comments