diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 5a63b3ff35679f..f47d4b7a5fb7df 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False): def is_torchdynamo_available(): if not is_torch_available(): return False - try: - import torch._dynamo as dynamo # noqa: F401 - return True - except Exception: - return False + return version.parse(_torch_version) >= version.parse("2.0.0") def is_torch_compile_available(): @@ -665,9 +661,15 @@ def is_torchdynamo_compiling(): if not is_torch_available(): return False try: - import torch._dynamo as dynamo # noqa: F401 + # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible. + if version.parse(_torch_version) >= version.parse("2.3.0"): + import torch + + return torch.compiler.is_compiling() + else: + import torch._dynamo as dynamo # noqa: F401 - return dynamo.is_compiling() + return dynamo.is_compiling() except Exception: return False