Commit
add ort to debertav2 model config (#12)
* add ort config for debertav2 model

* remove prints

* remove old commented code

* fix run style error

* add flake ignore comment

* trial to fix blackify format error
harshithapv committed May 18, 2021
1 parent 8e5b0db commit 0b2532a
Showing 3 changed files with 105 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/pytorch/text-classification/run_glue.py
@@ -301,6 +301,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
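For context (not part of the commit): the `--ort` training flag shown above is specific to this ORT-enabled fork of transformers, and is forwarded onto the model config so that the DeBERTa-v2 modules changed below can read `config.ort` when choosing their dropout implementation. A minimal sketch of wiring the flag by hand, assuming this fork's behavior (the checkpoint name is only illustrative):

```python
# Minimal sketch, not part of the commit: setting the config flag that the
# DeBERTa-v2 modules below check. Assumes this ORT-enabled fork; the checkpoint
# name is illustrative.
from transformers import DebertaV2Config, DebertaV2ForSequenceClassification

config = DebertaV2Config.from_pretrained("microsoft/deberta-v2-xlarge", num_labels=2)
config.ort = True  # read by the modeling code as `config.ort`

model = DebertaV2ForSequenceClassification.from_pretrained(
    "microsoft/deberta-v2-xlarge", config=config
)
```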
51 changes: 51 additions & 0 deletions src/transformers/models/deberta_v2/jit_tracing.py
@@ -0,0 +1,51 @@
# flake8: noqa
# coding=utf-8
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Logging util @Author: penhe@microsoft.com
"""

""" Utils for torch jit tracing customer operators/functions"""
import os

import torch


def traceable(cls):
class _Function(object):
@staticmethod
def apply(*args):
if torch.onnx.is_in_onnx_export():
return cls.forward(_Function, *args)
else:
return cls.apply(*args)

@staticmethod
def save_for_backward(*args):
pass

return _Function


class TraceMode:
"""Trace context used when tracing modules contains customer operators/Functions"""

def __enter__(self):
os.environ["JIT_TRACE"] = "True"
return self

def __exit__(self, exp_value, exp_type, trace):
del os.environ["JIT_TRACE"]
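An illustrative use of the `traceable` decorator above: outside of `torch.onnx.export`, the wrapped class behaves like a normal `torch.autograd.Function`; during export, `apply` falls through to `forward` directly, so the custom backward never has to be traced. The toy `ClampGrad` Function is invented for this sketch and is not part of the commit.

```python
# Illustrative sketch, not part of the commit. The import path matches the file
# added above and only exists in this fork; ClampGrad is a toy Function.
import torch

from transformers.models.deberta_v2.jit_tracing import traceable


@traceable
class ClampGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # no-op during ONNX export (ctx is the wrapper class)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).to(grad_output.dtype)


# Eager path: dispatches to the real torch.autograd.Function.apply.
y = ClampGrad.apply(torch.randn(4, requires_grad=True))
# Inside torch.onnx.export, the wrapper calls ClampGrad's forward directly instead.
```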
65 changes: 53 additions & 12 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -19,11 +19,21 @@

import numpy as np
import torch
from torch import _softmax_backward_data, nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch import (
_softmax_backward_data,
nn,
)
from torch.nn import (
CrossEntropyLoss,
LayerNorm,
)

from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@@ -34,6 +44,7 @@
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_deberta_v2 import DebertaV2Config
from .jit_tracing import traceable


logger = logging.get_logger(__name__)
@@ -55,7 +66,10 @@ class ContextPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
self.dropout = StableDropout(config.pooler_dropout)
if config.ort:
self.dropout = TorchNNDropout(config.pooler_dropout)
else:
self.dropout = StableDropout(config.pooler_dropout)
self.config = config

def forward(self, hidden_states):
@@ -73,6 +87,7 @@ def output_dim(self):
return self.config.hidden_size


@traceable
# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
class XSoftmax(torch.autograd.Function):
"""
@@ -144,6 +159,7 @@ def get_mask(input, local_context):
return mask, dropout


@traceable
# Copied from transformers.models.deberta.modeling_deberta.XDropout
class XDropout(torch.autograd.Function):
"""Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
@@ -167,6 +183,11 @@ def backward(ctx, grad_output):
return grad_output, None


class TorchNNDropout(torch.nn.Dropout):
def __init__(self, drop_prob):
super().__init__(drop_prob)


# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(torch.nn.Module):
"""
@@ -223,7 +244,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -291,7 +315,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, input_tensor):
@@ -346,7 +373,10 @@ def __init__(self, config):
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, residual_states, input_mask):
@@ -584,16 +614,21 @@ def __init__(self, config):
self.pos_ebd_size = self.max_relative_positions
if self.position_buckets > 0:
self.pos_ebd_size = self.position_buckets

self.pos_dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.pos_dropout = StableDropout(config.hidden_dropout_prob)

if not self.share_att_key:
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = StableDropout(config.attention_probs_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.attention_probs_dropout_prob)
else:
self.dropout = StableDropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
@@ -816,7 +851,10 @@ def __init__(self, config):
if self.embedding_size != config.hidden_size:
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -1247,7 +1285,10 @@ def __init__(self, config):
self.classifier = torch.nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
if config.ort:
self.dropout = TorchNNDropout(drop_out)
else:
self.dropout = StableDropout(drop_out)

self.init_weights()

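Every hunk above applies the same substitution: wherever DeBERTa-v2 previously built a StableDropout, it now builds a TorchNNDropout (a thin torch.nn.Dropout subclass) when `config.ort` is set, presumably because StableDropout's custom XDropout autograd Function does not play well with ONNX Runtime's graph capture. A condensed sketch of the pattern (the `pick_dropout` helper is hypothetical, not part of the commit):

```python
# Condensed, illustrative sketch of the branch repeated in modeling_deberta_v2.py
# above; `pick_dropout` is a hypothetical helper, not part of the commit.
import torch

from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout


def pick_dropout(config, drop_prob):
    if getattr(config, "ort", False):
        # Plain nn.Dropout (what TorchNNDropout subclasses) for ONNX Runtime training.
        return torch.nn.Dropout(drop_prob)
    # DeBERTa's StableDropout, built on the custom XDropout autograd Function.
    return StableDropout(drop_prob)


# e.g. in a module's __init__:
#     self.dropout = pick_dropout(config, config.hidden_dropout_prob)
```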
