Order of inputs of forward function problematic for jit with Classification models #1010
Comments
Thanks for giving such an in-depth review of the issue, it is very helpful. I indeed see this can be problematic; I'll have a look into it.
Thanks a lot for the details @dhpollack! As you probably guessed, the strange order of the arguments is the result of trying to minimize the number of breaking changes (for people who rely on argument positions rather than keywords) while adding additional functionality to the library. The resulting situation is not very satisfactory indeed.
[2.0] Reordering arguments for torch jit #1010 and future TF2.0 compatibility
#1195 seems to have solved this.
TL;DR

Due to the order of the arguments of `forward` in the classification models, the torch device gets hardcoded during jit tracing, or working around it causes unwanted overhead. Easy solution (but possibly breaking): reorder `forward` so that `labels` comes after the base model's arguments.

Long Version
The order of the inputs of the models is problematic for jit tracing, because the classification models split up the inputs of the base BERT model. Confusing in words, but easy to see in code:
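The split can be sketched with simplified signatures (these are illustrative stand-ins, not the actual transformers code):

```python
import inspect

# Base model: the encoder inputs are grouped together.
def bert_model_forward(input_ids, token_type_ids=None, attention_mask=None,
                       position_ids=None):
    ...

# Classification model: `labels` is spliced in between the base inputs.
def bert_for_classification_forward(input_ids, token_type_ids=None,
                                    attention_mask=None, labels=None,
                                    position_ids=None):
    ...

# With this order, reaching `position_ids` positionally (as jit tracing
# requires, since it only accepts a tuple) forces a dummy `labels` first:
params = list(inspect.signature(bert_for_classification_forward).parameters)
print(params.index("labels") < params.index("position_ids"))  # True
```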
The problem arises because `torch.jit.trace` does not use the Python logic when creating the embedding layer. This line, `position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)`, becomes `position_ids = torch.arange(seq_length, dtype=torch.long, device=torch.device("[device at time of jit]"))`. Importantly, `model.to(device)` will not change this hardcoded device in the embeddings. Thus the torch device gets hardcoded into the whole network and one can't use `model.to(device)` as expected.

One could circumvent this problem by explicitly passing `position_ids` at the time of tracing, but the `torch.jit.trace` function only takes a tuple of inputs. Because `labels` comes before `position_ids`, you cannot jit trace the function without putting in dummy labels and incurring the extra overhead of calculating the loss, which you don't want for a graph used solely for inference.

The simple solution is to change the order of the arguments so that the `labels` argument comes after the arguments of the base BERT model. Of course, this could break existing scripts that rely on the current order, although the current examples use kwargs, so it shouldn't be a problem. If this were done, then one could do:
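The idea can be demonstrated with a toy embedding module (a minimal sketch, not the actual BERT embeddings): when `position_ids` can be supplied in the traced input tuple, they stay a graph input instead of being baked in as a constant with a fixed device.

```python
import torch

class TinyEmbeddings(torch.nn.Module):
    """Toy stand-in for BERT embeddings: builds position_ids internally
    when none are passed, which is exactly what tracing hardcodes."""
    def __init__(self, vocab_size=100, hidden=8, max_positions=64):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden)
        self.position_embeddings = torch.nn.Embedding(max_positions, hidden)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long,
                                        device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        return self.word_embeddings(input_ids) + self.position_embeddings(position_ids)

input_ids = torch.zeros(1, 16, dtype=torch.long)
position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0)

# Passing position_ids explicitly at trace time keeps them a graph input,
# so no device is hardcoded into the embedding lookup:
traced = torch.jit.trace(TinyEmbeddings().eval(), (input_ids, position_ids))
out = traced(input_ids, position_ids)
```

With the reordered classification `forward`, the same two-element-plus-labels-last tuple could be traced on the full model without computing a dummy loss.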
Finally, and this is a judgement call, it's illogical to stick the `labels` parameter into the middle of the list of parameters; it probably should be at the end. But that is a minor, minor gripe in an otherwise fantastically built library.