
Order of inputs of forward function problematic for jit with Classification models #1010

Closed
dhpollack opened this issue Aug 12, 2019 · 3 comments


@dhpollack
Contributor

TL;DR

Due to the order of the arguments of `forward` in the classification models, the device gets hardcoded during JIT tracing, or avoiding that causes unwanted overhead. Easy (but possibly breaking) solution:

```python
# change this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...

# to this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...
```

Long Version

The order of the models' inputs is problematic for JIT tracing, because the classification models insert `labels` in the middle of the base BERT model's inputs. This is confusing in words, but easy to see in code:

```python
# base BERT
class BertModel(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
        ...

# classification BERT
# notice the order where labels comes in
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...
```

The problem arises because `torch.jit.trace` records only the tensor operations executed for the example inputs, not the Python logic around them. In the embedding layer, the line `position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)` is therefore recorded as `position_ids = torch.arange(seq_length, dtype=torch.long, device=torch.device("[device at time of jit]"))`. Importantly, `model.to(device)` does not change this hardcoded device in the embeddings, so the device used during tracing is baked into the whole network and `model.to(device)` cannot be used as expected.

One could circumvent this by passing `position_ids` explicitly at tracing time, but `torch.jit.trace` only accepts a tuple of positional inputs. Because `labels` comes before `position_ids`, you cannot trace the model without also passing dummy labels, which adds the overhead of computing the loss: something you don't want in a graph used solely for inference.
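To make the positional constraint concrete, here is a minimal, torch-free sketch. It assumes, as described above, that the example tuple given to `torch.jit.trace` is bound to `forward`'s parameters purely by position. `classifier_forward` and `positional_prefix` are hypothetical names used only for illustration; the stand-in simply copies the classification signature quoted above.

```python
import inspect

# Hypothetical stand-in reproducing only the classification forward()
# signature quoted above; the body is irrelevant for this check.
def classifier_forward(self, input_ids, token_type_ids=None, attention_mask=None,
                       labels=None, position_ids=None, head_mask=None):
    pass

def positional_prefix(func, target, skip=("self",)):
    """Hypothetical helper: list the parameters that must be supplied, in order,
    when arguments are passed purely by position and `target` must be reached."""
    names = [name for name in inspect.signature(func).parameters if name not in skip]
    return names[: names.index(target) + 1]

prefix = positional_prefix(classifier_forward, "position_ids")
print(prefix)
# ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'position_ids']
print("labels" in prefix)  # True: a dummy labels tensor would have to be traced too
```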

The simple solution is to reorder the arguments so that `labels` comes after all the arguments of the base BERT model. Of course, this could break existing scripts that pass these arguments positionally, although the current examples use keyword arguments, so it should not be a problem for them.

```python
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...
```
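One way to see both the benefit and the breaking-change risk is to bind the same four positional inputs against the current and proposed signatures with Python's `inspect` module. This is a sketch using hypothetical stand-in functions (`old_forward`, `new_forward`, `bind_positional`) and strings in place of tensors; it only checks argument binding, not tracing itself.

```python
import inspect

# Hypothetical stand-ins with the current and proposed signatures from this issue.
def old_forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
    pass

def new_forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
    pass

# Strings stand in for the four example tensors passed positionally.
example = ("ids", "segments", "mask", "positions")

def bind_positional(func, args):
    # Mimic func(model, *args): bind by position, then drop `self`.
    bound = inspect.signature(func).bind(None, *args)
    return {name: value for name, value in bound.arguments.items() if name != "self"}

old_binding = bind_positional(old_forward, example)
new_binding = bind_positional(new_forward, example)
print(old_binding["labels"])        # positions (the 4th input lands in labels)
print(new_binding["position_ids"])  # positions

# Callers that use keyword arguments bind identically under both orders.
kwargs = dict(input_ids="ids", attention_mask="mask", labels="y")
old_kw = dict(inspect.signature(old_forward).bind(None, **kwargs).arguments)
new_kw = dict(inspect.signature(new_forward).bind(None, **kwargs).arguments)
print(old_kw == new_kw)  # True
```

Under the current order the fourth input silently becomes `labels`; under the proposed order it becomes `position_ids`, while keyword-argument callers are unaffected by the reordering.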

If this were done, one could then trace any classification model like this:

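```python
import torch

# model = any of the classification models, with the reordered forward() signature
model.eval()  # trace the inference graph (disables dropout)
msl = 15  # max sequence length, which gets hardcoded into the network
inputs = (
    torch.ones(1, msl, dtype=torch.long),  # input_ids
    torch.ones(1, msl, dtype=torch.long),  # token_type_ids (segment ids)
    torch.ones(1, msl, dtype=torch.long),  # attention_mask
    torch.arange(msl, dtype=torch.long).unsqueeze(0),  # position_ids, passed explicitly
)
traced_model = torch.jit.trace(model, inputs)  # example inputs given as a tuple
```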

Finally (and this is a judgement call), it's illogical to put the `labels` parameter in the middle of the parameter list; it should probably be at the end. But that is a minor, minor gripe about an otherwise fantastically built library.

@LysandreJik
Member

Thanks for giving such an in-depth review of the issue, it is very helpful. I indeed see this can be problematic, I'll have a look into it.

@thomwolf
Member

thomwolf commented Aug 19, 2019

Thanks a lot for the details @dhpollack!

As you probably guessed, the strange order of the arguments is the result of trying to minimize the number of breaking changes (for people who pass the keyword arguments by position) while adding functionality to the library.

The resulting situation is not very satisfactory indeed.
Personally, I think it's probably time to reorder the keyword arguments.

@dhpollack
Contributor Author

#1195 seems to have solved this.
