From 2c8f50185d76060de758b60cc4c9bdb3ded69074 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Sun, 14 Jul 2024 15:44:38 +0200 Subject: [PATCH] Fix `GenerationMixin.generate` compatibility with pytorch profiler (#31935) use torch.compiler.is_compiling() when possible --- src/transformers/utils/import_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 5a63b3ff35679f..f47d4b7a5fb7df 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False): def is_torchdynamo_available(): if not is_torch_available(): return False - try: - import torch._dynamo as dynamo # noqa: F401 - return True - except Exception: - return False + return version.parse(_torch_version) >= version.parse("2.0.0") def is_torch_compile_available(): @@ -665,9 +661,15 @@ def is_torchdynamo_compiling(): if not is_torch_available(): return False try: - import torch._dynamo as dynamo # noqa: F401 + # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible. + if version.parse(_torch_version) >= version.parse("2.3.0"): + import torch + + return torch.compiler.is_compiling() + else: + import torch._dynamo as dynamo # noqa: F401 - return dynamo.is_compiling() + return dynamo.is_compiling() except Exception: return False