Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove data based dependencies in T5 for ORT #4

Merged
merged 4 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/seq2seq/run_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ class PretrainedConfig(object):

- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
BFloat16 scalars (only used by some TensorFlow models).

Onnxruntime specific parameters

- **ort** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use ORT.
"""
model_type: str = ""
is_composition: bool = False
Expand All @@ -186,6 +190,7 @@ def __init__(self, **kwargs):
self.output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.ort = kwargs.pop("ort", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
"tie_word_embeddings", True
Expand Down
40 changes: 31 additions & 9 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.ort = config.ort
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
Expand Down Expand Up @@ -643,9 +644,16 @@ def forward(
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
Expand All @@ -670,9 +678,16 @@ def forward(
hidden_states = cross_attention_outputs[0]

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

# Combine self attn and cross attn key value states
if present_key_value_state is not None:
Expand All @@ -685,9 +700,16 @@ def forward(
hidden_states = self.layer[-1](hidden_states)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

outputs = (hidden_states,)

Expand Down