Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformers v4.45 support #2023

Merged
merged 22 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))
# Run code quality checks
style_check:
black --check .
ruff .
ruff check .

style:
black .
ruff . --fix
ruff check . --fix

# Run tests for the library
test:
Expand Down
84 changes: 77 additions & 7 deletions optimum/bettertransformer/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,71 @@ def gpt2_wrapped_scaled_dot_product(
return sdpa_result, None


# Adapted from transformers.models.gptj.modeling_gptj.GPTJAttention._attn
def gptj_wrapped_scaled_dot_product(
Comment on lines +95 to +96
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add docstring on how this custom implementation solves the problem.

self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
raise_on_head_mask(head_mask)
batch_size = query.shape[0]

mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)

# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
if not check_if_transformers_greater("4.44.99"):
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
if attention_mask is not None:
attention_mask = causal_mask + attention_mask

else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
sdpa_result = sdpa_result.to(value.dtype)

return sdpa_result, None


# Adapted from transformers.models.bark.modeling_bark.BarkSelfAttention._attn
def bark_wrapped_scaled_dot_product(
self,
Expand Down Expand Up @@ -195,7 +260,7 @@ def codegen_wrapped_scaled_dot_product(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask`` in
# in this case, which is the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
Expand All @@ -207,15 +272,20 @@ def codegen_wrapped_scaled_dot_product(
# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
if not check_if_transformers_greater("4.44.99"):
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, mask_value)
causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
Expand Down
35 changes: 31 additions & 4 deletions optimum/bettertransformer/models/decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
codegen_wrapped_scaled_dot_product,
gpt2_wrapped_scaled_dot_product,
gpt_neo_wrapped_scaled_dot_product,
gptj_wrapped_scaled_dot_product,
opt_forward,
t5_forward,
)
Expand Down Expand Up @@ -82,7 +83,7 @@ def forward(self, *args, **kwargs):


class GPTJAttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPTJAttention, nn.Module):
_attn = gpt2_wrapped_scaled_dot_product
_attn = gptj_wrapped_scaled_dot_product

def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super().__init__(config)
Expand All @@ -96,14 +97,22 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
"out_proj",
"attn_dropout",
"resid_dropout",
"bias",
"scale_attn",
"masked_bias",
]
# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

# Attribute only for transformers<4.45
if hasattr(layer, "bias"):
submodules.append("bias")
if hasattr(layer, "masked_bias"):
submodules.append("masked_bias")

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand All @@ -127,6 +136,11 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["rotary_emb", "query_key_value", "dense", "bias", "masked_bias", "norm_factor"]

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down Expand Up @@ -155,6 +169,11 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "k_proj", "v_proj", "q_proj", "out_proj", "bias", "masked_bias"]

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_id"):
submodules.append("layer_id")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down Expand Up @@ -238,12 +257,20 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super(BetterTransformerBaseLayer, self).__init__(config)

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "causal_mask", "scale_attn"]
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "scale_attn"]

# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

# Attribute only for transformers<4.45
if hasattr(layer, "causal_mask"):
submodules.append("causal_mask")

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down
4 changes: 4 additions & 0 deletions optimum/bettertransformer/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def transform(
The converted model if the conversion has been successful.
"""

logger.warning(
"The class `optimum.bettertransformers.transformation.BetterTransformer` is deprecated and will be removed in a future release."
)

hf_config = model.config
if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]:
raise ValueError(
Expand Down
18 changes: 18 additions & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import numpy as np
import onnx
import transformers
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available

Expand All @@ -34,6 +35,7 @@
DEFAULT_DUMMY_SHAPES,
ONNX_WEIGHTS_NAME,
TORCH_MINIMUM_VERSION,
check_if_transformers_greater,
is_diffusers_available,
is_torch_onnx_support_available,
logging,
Expand Down Expand Up @@ -999,6 +1001,10 @@ def onnx_export_from_model(
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
if check_if_transformers_greater("4.44.99"):
raise ImportError(
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}"
)
Comment on lines +1004 to +1007
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disabling onnx export as additional fixes are needed (but don't want to block the latest transformers release for other subpackages) @michaelbenayoun @xenova

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok to me, maybe specify that it's temporary.


TasksManager.standardize_model_attributes(model)

Expand Down Expand Up @@ -1120,6 +1126,18 @@ def onnx_export_from_model(
if isinstance(atol, dict):
atol = atol[task.replace("-with-past", "")]

if check_if_transformers_greater("4.44.99"):
misplaced_generation_parameters = model.config._get_non_default_generation_parameters()
if model.can_generate() and len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(model.generation_config, param_name, param_value)
setattr(model.config, param_name, None)

# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
Expand Down
11 changes: 6 additions & 5 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class AlbertOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 11
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class ConvBertOnnxConfig(BertOnnxConfig):
Expand Down Expand Up @@ -171,11 +171,11 @@ class MPNetOnnxConfig(DistilBertOnnxConfig):


class RobertaOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class CamembertOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class FlaubertOnnxConfig(BertOnnxConfig):
Expand All @@ -187,7 +187,7 @@ class IBertOnnxConfig(DistilBertOnnxConfig):


class XLMRobertaOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class DebertaOnnxConfig(BertOnnxConfig):
Expand Down Expand Up @@ -257,7 +257,7 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")


Expand Down Expand Up @@ -564,6 +564,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
Expand Down
3 changes: 3 additions & 0 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ def from_pretrained(
export = from_transformers

if len(model_id.split("@")) == 2:
logger.warning(
f"Specifying the `revision` as @{model_id.split('@')[1]} is deprecated and will be removed in v1.23, please use the `revision` argument instead."
)
if revision is not None:
logger.warning(
f"The argument `revision` was set to {revision} but will be ignored for {model_id.split('@')[1]}"
Expand Down
Loading
Loading