diff --git a/apex/__init__.py b/apex/__init__.py
index 74851f5b3..be739694b 100644
--- a/apex/__init__.py
+++ b/apex/__init__.py
@@ -24,7 +24,6 @@
 # load time) the error message is timely and visible.
 from . import optimizers
 from . import normalization
-from . import transformer
 
 
 # Logging utilities for apex.transformer module
diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py
index 4434e3604..39d5d7668 100644
--- a/apex/transformer/utils.py
+++ b/apex/transformer/utils.py
@@ -8,6 +8,8 @@
-# The following 4 lines are for backward comparability with
-# older PyTorch.
+# The following lines are for backward compatibility with
+# older PyTorch.
 if "all_gather_into_tensor" not in dir(torch.distributed):
+    if not torch.distributed.is_available():
+        raise RuntimeError("PyTorch Distributed is not available or disabled.")
     torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
 
 def ensure_divisibility(numerator, denominator):
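
For context, a minimal usage sketch (not part of the patch) of the collective that the shim above keeps working; gather_example is a hypothetical helper, and it assumes a process group has already been set up via torch.distributed.init_process_group:

    import torch
    import torch.distributed as dist
    import apex.transformer.utils  # noqa: F401 -- installs the alias on older PyTorch

    def gather_example(local: torch.Tensor, world_size: int) -> torch.Tensor:
        # Gather equal-sized shards from every rank into one flat tensor.
        out = torch.empty(world_size * local.numel(),
                          dtype=local.dtype, device=local.device)
        # Native API on newer PyTorch (1.13+); on older builds this resolves
        # to the _all_gather_base alias installed by the import above.
        dist.all_gather_into_tensor(out, local)
        return out

The early is_available() guard turns what would otherwise be a confusing AttributeError on torch.distributed._all_gather_base into an actionable error on builds compiled without distributed support.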