@@ -119,39 +119,44 @@ def _test_distrib_all_reduce(device):


 def _test_distrib_all_reduce_group(device):
-    if idist.get_world_size() > 1 and idist.backend() is not None:
-        ranks = [0, 1]
-        rank = idist.get_rank()
-        t = torch.tensor([rank], device=device)
-        bnd = idist.backend()
+    assert idist.get_world_size() > 1, idist.get_world_size()
+    assert idist.backend() is not None, idist.backend()

-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=group)
-        else:
+    ranks = [0, 1]
+    rank = idist.get_rank()
+    t = torch.tensor([rank], device=device)
+    bnd = idist.backend()
+
+    group = idist.new_group(ranks)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            res = idist.all_reduce(t, group=group)
+    else:
+        if rank in ranks:
+            # we should call all_reduce with group on the participating ranks only
+            # otherwise a warning is raised:
+            # UserWarning: Running all_reduce on global rank 2 which does not belong to the given group.
             res = idist.all_reduce(t, group=group)
             assert res == torch.tensor([sum(ranks)], device=device)

-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=ranks)
-        else:
+    t = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            res = idist.all_reduce(t, group=ranks)
+    else:
+        if rank in ranks:
             res = idist.all_reduce(t, group=ranks)
             assert res == torch.tensor([sum(ranks)], device=device)

-        ranks = "abc"
-
-        if bnd in ("nccl", "gloo", "mpi"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
-                res = idist.all_reduce(t, group="abc")
-        elif bnd in ("xla-tpu"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
-                res = idist.all_reduce(t, group="abc")
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group="abc")
+    if bnd in ("nccl", "gloo", "mpi"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+            idist.all_reduce(t, group="abc")
+    elif bnd in ("xla-tpu"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+            idist.all_reduce(t, group="abc")
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+            idist.all_reduce(t, group="abc")


 def _test_distrib_all_gather(device):
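
For orientation, the sub-group reduction this hunk exercises can be sketched in plain user code roughly as follows. This is a minimal sketch, not part of the change: the gloo backend, the two-process launch and the helper name reduce_over_group are illustrative assumptions.

    # Illustrative sketch (assumed setup: >= 2 processes, started e.g. with
    # idist.Parallel(backend="gloo", nproc_per_node=2); names are hypothetical).
    import torch
    import ignite.distributed as idist

    def reduce_over_group(local_value: float) -> torch.Tensor:
        ranks = [0, 1]                            # ranks taking part in the collective
        group = idist.new_group(ranks)            # build the process group once
        t = torch.tensor([local_value], device=idist.device())
        if idist.get_rank() in ranks:
            # only participating ranks should issue the collective, otherwise the
            # backend warns that the global rank does not belong to the given group
            return idist.all_reduce(t, group=group)
        return t
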
@@ -218,77 +223,76 @@ def _test_distrib_all_gather(device):


 def _test_distrib_all_gather_group(device):
-    if idist.get_world_size() > 1:
-        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
-        rank = idist.get_rank()
-        bnd = idist.backend()
+    assert idist.get_world_size() > 1, idist.get_world_size()

-        t = torch.tensor([rank], device=device)
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=group)
-        else:
-            res = idist.all_gather(t, group=group)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
+    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    rank = idist.get_rank()
+    bnd = idist.backend()

-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
+    t = torch.tensor([rank], device=device)
+    group = idist.new_group(ranks)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=group)
+    else:
+        res = idist.all_gather(t, group=group)
+        if rank in ranks:
+            assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
         else:
-            res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
+            assert res == t

-        t = {
-            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
-            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
-            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
-        }
-        if bnd in ("xla-tpu"):
-            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
-                res = idist.all_gather(t, group=ranks)
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
+    t = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=ranks)
+    else:
+        res = idist.all_gather(t, group=ranks)
+        if rank in ranks:
+            assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
        else:
+            assert res == t
+
+    t = {
+        "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+        "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+        "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+    }
+    if bnd in ("xla-tpu"):
+        with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
             res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert isinstance(res, list) and len(res) == len(ranks)
-                for i, obj in zip(ranks, res):
-                    assert isinstance(obj, dict)
-                    assert list(obj.keys()) == ["a", "b", "c"], obj
-                    expected_device = (
-                        device
-                        if torch.device(device).type == "cpu"
-                        else torch.device(f"{torch.device(device).type}:{i}")
-                    )
-                    expected = {
-                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
-                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
-                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
-                    }
-                    assert obj["a"] == expected["a"], (obj, expected)
-                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
-                    assert obj["c"] == expected["c"], (obj, expected)
-            else:
-                assert res == t
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=ranks)
+    else:
+        res = idist.all_gather(t, group=ranks)
+        if rank in ranks:
+            assert isinstance(res, list) and len(res) == len(ranks)
+            for i, obj in zip(sorted(ranks), res):
+                assert isinstance(obj, dict)
+                assert list(obj.keys()) == ["a", "b", "c"], obj
+                expected_device = (
+                    device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+                )
+                expected = {
+                    "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                    "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                    "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                }
+                assert obj["a"] == expected["a"], (obj, expected)
+                assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                assert obj["c"] == expected["c"], (obj, expected)
+        else:
+            assert res == t

-        if bnd in ("nccl", "gloo", "mpi"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
-                res = idist.all_gather(t, group="abc")
-        elif bnd in ("xla-tpu"):
-            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
-                res = idist.all_gather(t, group="abc")
-        elif bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group="abc")
+    if bnd in ("nccl", "gloo", "mpi"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+            res = idist.all_gather(t, group="abc")
+    elif bnd in ("xla-tpu"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+            res = idist.all_gather(t, group="abc")
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group="abc")


 def _test_idist_all_gather_tensors_with_shapes(device):
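
The revised assertions above encode the ordering contract of all_gather over a group: gathered values are laid out by ascending rank, whatever order the ranks were listed in, and ranks outside the group get their own input back. A minimal hedged sketch of that pattern, assuming a run with at least two processes (the function name is hypothetical):

    # Illustrative sketch; the "all ranks except 0" group mirrors the test above.
    import torch
    import ignite.distributed as idist

    def gather_rank_ids() -> torch.Tensor:
        # group listed in descending order on purpose
        ranks = list(range(idist.get_world_size() - 1, 0, -1))
        rank = idist.get_rank()
        t = torch.tensor([rank], device=idist.device())
        res = idist.all_gather(t, group=ranks)    # a plain list of ints is accepted
        if rank in ranks:
            # gathered values are ordered by ascending rank, hence sorted(ranks)
            assert torch.equal(res, torch.tensor(sorted(ranks), device=t.device))
        else:
            # ranks outside the group receive their own tensor back
            assert res == t
        return res
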
@@ -312,37 +316,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):


 def _test_idist_all_gather_tensors_with_shapes_group(device):
-    if idist.get_world_size() > 1:
-        torch.manual_seed(41)
+    assert idist.get_world_size(), idist.get_world_size()
+    torch.manual_seed(41)

-        rank = idist.get_rank()
-        ranks = list(range(1, idist.get_world_size()))
-        ws = idist.get_world_size()
-        bnd = idist.backend()
+    rank = idist.get_rank()
+    ranks = list(range(1, idist.get_world_size()))
+    ws = idist.get_world_size()
+    bnd = idist.backend()
+    if rank in ranks:
+        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        rank_tensor = reference[
+            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+        ]
+    else:
+        rank_tensor = torch.tensor([rank], device=device)
+    if bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    else:
+        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
         if rank in ranks:
-            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
-            rank_tensor = reference[
-                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
-            ]
-        else:
-            rank_tensor = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            for r in ranks:
+                r_tensor = reference[
+                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                ]
+                assert (r_tensor == tensors[r - 1]).all()
         else:
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-            if rank in ranks:
-                for r in ranks:
-                    r_tensor = reference[
-                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                    ]
-                    assert (r_tensor == tensors[r - 1]).all()
-            else:
-                assert [rank_tensor] == tensors
+            assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):
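
Finally, all_gather_tensors_with_shapes handles the ragged case where each rank contributes a tensor of a different shape. A hedged usage sketch, assuming the helper can be imported from ignite.distributed.utils and the job runs with at least two processes (the function name is hypothetical):

    # Illustrative sketch; the per-rank shapes and the "all ranks except 0" group
    # mirror the test above and are assumptions, not a fixed API requirement.
    import torch
    import ignite.distributed as idist
    from ignite.distributed.utils import all_gather_tensors_with_shapes

    def gather_ragged_tensors() -> list:
        rank = idist.get_rank()
        ranks = list(range(1, idist.get_world_size()))    # group: every rank except 0
        shapes = [[r + 1, r + 2, r + 3] for r in ranks]   # one distinct shape per member
        if rank in ranks:
            local = torch.ones(rank + 1, rank + 2, rank + 3, device=idist.device())
        else:
            # non-members pass a dummy tensor and get it back as a one-element list
            local = torch.tensor([rank], device=idist.device())
        return all_gather_tensors_with_shapes(local, shapes, ranks)
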