diff --git a/examples/cifar100/main_horovod.py b/examples/cifar100/main_horovod.py index ef9dda7..92ff704 100644 --- a/examples/cifar100/main_horovod.py +++ b/examples/cifar100/main_horovod.py @@ -124,7 +124,7 @@ def forward(self, x): print ("use cuda!!") print ("local rank {}, rank {}".format(hvd.local_rank(),hvd.rank())) -torch.cuda.set_device(hvd.rank()) +torch.cuda.set_device(hvd.local_rank()) torch.cuda.manual_seed(1111) best_acc = 0