42
42
from common import mpi_env_rank_and_size , skip_or_fail_gpu_test , temppath
43
43
44
44
_1_5_api = LooseVersion (torch .__version__ ) >= LooseVersion ('1.5.0' )
45
+ _1_10_api = LooseVersion (torch .__version__ ) >= LooseVersion ('1.10.0' )
45
46
46
47
ccl_supported_types = set ([torch .ByteTensor , torch .CharTensor , torch .ShortTensor ,
47
48
torch .IntTensor , torch .LongTensor , torch .FloatTensor ,
@@ -62,6 +63,11 @@ def __init__(self, *args, **kwargs):
62
63
super (TorchTests , self ).__init__ (* args , ** kwargs )
63
64
warnings .simplefilter ('module' )
64
65
66
+ def tearDown (self ):
67
+ if _1_10_api and hvd .is_initialized ():
68
+ # To fix https://github.com/horovod/horovod/issues/3149
69
+ hvd .join ()
70
+
65
71
def convert_cpu_fp16_to_fp32 (self , * values ):
66
72
# PyTorch doesn't support any CPU ops on FP16 tensors.
67
73
# In case we need to do ops, we will convert tensor to FP32 here.
@@ -612,9 +618,6 @@ def test_horovod_allreduce_duplicate_name_error(self):
612
618
assert False , 'hvd.allreduce_async did not throw error'
613
619
except (torch .FatalError , ValueError ):
614
620
pass
615
- if LooseVersion (torch .__version__ ) >= LooseVersion ('1.10.0' ):
616
- # To fix https://github.com/horovod/horovod/issues/3149
617
- hvd .join ()
618
621
619
622
def test_horovod_allreduce_grad (self ):
620
623
"""Test the correctness of the allreduce gradient."""
@@ -1221,9 +1224,6 @@ def test_horovod_allgather_duplicate_name_error(self):
1221
1224
assert False , 'hvd.allgather_async did not throw error'
1222
1225
except (torch .FatalError , ValueError ):
1223
1226
pass
1224
- if LooseVersion (torch .__version__ ) >= LooseVersion ('1.10.0' ):
1225
- # To fix https://github.com/horovod/horovod/issues/3149
1226
- hvd .join ()
1227
1227
1228
1228
def test_horovod_allgather_grad (self ):
1229
1229
"""Test the correctness of the allgather gradient."""
@@ -1534,9 +1534,6 @@ def test_horovod_broadcast_duplicate_name_error(self):
1534
1534
assert False , 'hvd.broadcast_async did not throw error'
1535
1535
except (torch .FatalError , ValueError ):
1536
1536
pass
1537
- if LooseVersion (torch .__version__ ) >= LooseVersion ('1.10.0' ):
1538
- # To fix https://github.com/horovod/horovod/issues/3149
1539
- hvd .join ()
1540
1537
1541
1538
def test_horovod_broadcast_grad (self ):
1542
1539
"""Test the correctness of the broadcast gradient."""
0 commit comments