Skip to content

Commit

Permalink
Fix GenerationMixin.generate compatibility with pytorch profiler (h…
Browse files Browse the repository at this point in the history
…uggingface#31935)

use torch.compiler.is_compiling() when possible
  • Loading branch information
fxmarty authored and amyeroberts committed Jul 19, 2024
1 parent def7dff commit f392473
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False):
def is_torchdynamo_available():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401

return True
except Exception:
return False
return version.parse(_torch_version) >= version.parse("2.0.0")


def is_torch_compile_available():
Expand All @@ -665,9 +661,15 @@ def is_torchdynamo_compiling():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
if version.parse(_torch_version) >= version.parse("2.3.0"):
import torch

return torch.compiler.is_compiling()
else:
import torch._dynamo as dynamo # noqa: F401

return dynamo.is_compiling()
return dynamo.is_compiling()
except Exception:
return False

Expand Down

0 comments on commit f392473

Please sign in to comment.