diff --git a/inplace_abn/abn.py b/inplace_abn/abn.py
index caab17e..e778dea 100644
--- a/inplace_abn/abn.py
+++ b/inplace_abn/abn.py
@@ -4,6 +4,7 @@ import torch.nn.functional as functional
 
 from .functions import *
 
+_default_group = distributed.group.WORLD if hasattr(distributed, "group") else None
 
 class ABN(nn.Module):
@@ -138,7 +139,7 @@ class InPlaceABNSync(ABN):
     """
 
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
-                 activation_param=0.01, group=distributed.group.WORLD):
+                 activation_param=0.01, group=_default_group):
         super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, activation_param)
         self.group = group
diff --git a/inplace_abn/functions.py b/inplace_abn/functions.py
index d2cb4c5..51ccf81 100644
--- a/inplace_abn/functions.py
+++ b/inplace_abn/functions.py
@@ -3,6 +3,7 @@ from torch.autograd.function import once_differentiable
 
 from . import _backend
 
+_default_group = distributed.group.WORLD if hasattr(distributed, "group") else None
 
 def _activation_from_name(activation):
@@ -152,7 +153,7 @@ def inplace_abn(x, weight, bias, running_mean, running_var,
 def inplace_abn_sync(x, weight, bias, running_mean, running_var,
                      training=True, momentum=0.1, eps=1e-05, activation="leaky_relu",
                      activation_param=0.01,
-                     group=distributed.group.WORLD):
+                     group=_default_group):
     return InPlaceABN.apply(x, weight, bias, running_mean, running_var, training, momentum,
                             eps, activation, activation_param, group)
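The intent of the `_default_group` fallback is to stop evaluating `distributed.group.WORLD` at import time on PyTorch builds where `torch.distributed` does not expose `group`. A minimal usage sketch follows; it assumes the patched inplace_abn package is installed and importable, and the explicit-group call is illustrative rather than part of this diff.

```python
# Minimal sketch, assuming the patched inplace_abn package builds and imports.
from inplace_abn import InPlaceABNSync

# Python evaluates default arguments when the defining module is imported, so the
# old `group=distributed.group.WORLD` default raised AttributeError on builds where
# torch.distributed has no `group` attribute. With this patch, the default becomes
# None in that case and construction succeeds.
layer = InPlaceABNSync(64)

# An explicit process group can still be passed once distributed is initialized,
# e.g. after torch.distributed.init_process_group(...):
#   group = torch.distributed.new_group(ranks=[0, 1])
#   layer = InPlaceABNSync(64, group=group)
```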