@@ -430,22 +430,18 @@ def _test_distrib_integration_multiclass(device):
430430
431431 from ignite .engine import Engine
432432
433- rank = idist .get_rank ()
434- torch .manual_seed (12 )
435-
436433 def _test (average , n_epochs , metric_device ):
437434 n_iters = 60
438- s = 16
435+ batch_size = 16
439436 n_classes = 7
440437
441- offset = n_iters * s
442- y_true = torch .randint (0 , n_classes , size = (offset * idist .get_world_size (),)).to (device )
443- y_preds = torch .rand (offset * idist .get_world_size (), n_classes ).to (device )
438+ y_true = torch .randint (0 , n_classes , size = (n_iters * batch_size ,)).to (device )
439+ y_preds = torch .rand (n_iters * batch_size , n_classes ).to (device )
444440
445441 def update (engine , i ):
446442 return (
447- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , :],
448- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset ],
443+ y_preds [i * batch_size : (i + 1 ) * batch_size , :],
444+ y_true [i * batch_size : (i + 1 ) * batch_size ],
449445 )
450446
451447 engine = Engine (update )
@@ -457,6 +453,9 @@ def update(engine, i):
457453 data = list (range (n_iters ))
458454 engine .run (data = data , max_epochs = n_epochs )
459455
456+ y_preds = idist .all_gather (y_preds )
457+ y_true = idist .all_gather (y_true )
458+
460459 assert "re" in engine .state .metrics
461460 assert re ._updated is True
462461 res = engine .state .metrics ["re" ]
@@ -475,7 +474,9 @@ def update(engine, i):
475474 metric_devices = [torch .device ("cpu" )]
476475 if device .type != "xla" :
477476 metric_devices .append (idist .device ())
478- for _ in range (2 ):
477+ rank = idist .get_rank ()
478+ for i in range (2 ):
479+ torch .manual_seed (12 + rank + i )
479480 for metric_device in metric_devices :
480481 _test (average = False , n_epochs = 1 , metric_device = metric_device )
481482 _test (average = False , n_epochs = 2 , metric_device = metric_device )
@@ -491,22 +492,20 @@ def _test_distrib_integration_multilabel(device):
491492
492493 from ignite .engine import Engine
493494
494- rank = idist .get_rank ()
495495 torch .manual_seed (12 )
496496
497497 def _test (average , n_epochs , metric_device ):
498498 n_iters = 60
499- s = 16
499+ batch_size = 16
500500 n_classes = 7
501501
502- offset = n_iters * s
503- y_true = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
504- y_preds = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
502+ y_true = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
503+ y_preds = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
505504
506505 def update (engine , i ):
507506 return (
508- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
509- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
507+ y_preds [i * batch_size : (i + 1 ) * batch_size , ...],
508+ y_true [i * batch_size : (i + 1 ) * batch_size , ...],
510509 )
511510
512511 engine = Engine (update )
@@ -518,6 +517,9 @@ def update(engine, i):
518517 data = list (range (n_iters ))
519518 engine .run (data = data , max_epochs = n_epochs )
520519
520+ y_preds = idist .all_gather (y_preds )
521+ y_true = idist .all_gather (y_true )
522+
521523 assert "re" in engine .state .metrics
522524 assert re ._updated is True
523525 res = engine .state .metrics ["re" ]
@@ -540,7 +542,9 @@ def update(engine, i):
540542 metric_devices = ["cpu" ]
541543 if device .type != "xla" :
542544 metric_devices .append (idist .device ())
543- for _ in range (2 ):
545+ rank = idist .get_rank ()
546+ for i in range (2 ):
547+ torch .manual_seed (12 + rank + i )
544548 for metric_device in metric_devices :
545549 _test (average = False , n_epochs = 1 , metric_device = metric_device )
546550 _test (average = False , n_epochs = 2 , metric_device = metric_device )
0 commit comments