Speed up GELU computation with torch.jit #2988
Conversation
Codecov Report
@@ Coverage Diff @@
## master #2988 +/- ##
==========================================
- Coverage 77.16% 76.11% -1.06%
==========================================
Files 98 98
Lines 15997 15997
==========================================
- Hits 12344 12176 -168
- Misses 3653 3821 +168
Any reason why the other activation functions (swish, _gelu_python) do not need jit? (I have no experience with JIT, so this is a genuine question. When should jit.script be used, and when shouldn't it?)
Indeed, it's possible to wrap both activations you mentioned with torch.jit as well; I've benchmarked both below. Answering your question on the use of jit.script: it tends to pay off when a function is a chain of small pointwise operations that the fuser can merge into fewer kernels, while for very simple functions the overhead can outweigh the gains. TL;DR: it's worth measuring each case, which is what the benchmarks below do.
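For reference, wrapping an activation in TorchScript is a one-line change. Below is a minimal sketch; the function bodies are the standard definitions of these activations and may differ slightly from the ones in the repository:

```python
import math

import torch


def swish(x):
    return x * torch.sigmoid(x)


def _gelu_python(x):
    # Exact GELU, expressed through the Gaussian CDF via erf.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# torch.jit.script compiles the Python functions into TorchScript graphs,
# letting the fuser merge the pointwise operations into fewer kernels.
swish_jit = torch.jit.script(swish)
gelu_jit = torch.jit.script(_gelu_python)
```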
Here are the results for both activations (done on a 1080Ti; I've updated the gist with two scripts):
_gelu_python
swish
Same benchmarks on an RTX 2080Ti:
_gelu_python
swish
It seems like it makes sense to compile _gelu_python, while for swish the benefits are negligible.
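For context, here is a minimal sketch of the kind of GPU timing used for such measurements (not the gist itself; the tensor size and iteration count are arbitrary, and a CUDA device is assumed):

```python
import math

import torch


def _gelu_python(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


def bench(fn, x, n_iter=1000):
    for _ in range(10):  # warm-up so the JIT can specialize and fuse before timing
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iter  # milliseconds per forward call


x = torch.randn(4096, 4096, device="cuda")
for fn in (_gelu_python, swish):
    print(fn.__name__, "eager:", bench(fn, x), "scripted:", bench(torch.jit.script(fn), x))
```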
We only use _gelu_python for torch < 1.4.
I've tested the current implementation with pytorch==1.0.0 on CPU, and it indeed breaks because torch.jit did not support Python floats at that time. I have two possible solutions for this; @sshleifer, which would be the best one?
First: slightly modify gelu_python and gelu_new to be backwards-compatible.
Second: use torch.jit.script only with pytorch>=1.4; we won't need to wrap _gelu_python at all, since it is only used for older versions.
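A sketch of the second option, gating the compilation on the installed PyTorch version (this assumes the packaging module is available for version parsing):

```python
import math

import torch
from packaging import version


def _gelu_new(x):
    # GPT-2-style tanh approximation of GELU, written with plain Python float literals.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


# Older releases of torch.jit (e.g. 1.0.x) could not script Python floats,
# so only compile on versions known to support it and fall back to eager mode otherwise.
if version.parse(torch.__version__) >= version.parse("1.4"):
    gelu_new = torch.jit.script(_gelu_new)
else:
    gelu_new = _gelu_new
```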
I've changed the PR so that gelu_new gets JIT-compiled only on pytorch>=1.4. Benchmarking results are the same, with a 3-4x faster forward and 3x faster backward pass for this activation (although no speedup on CPU float32). @sshleifer, is it ready to be merged now?
In my opinion, yes. LGTM.
Great, thanks for iterating on that and getting so deep into it @mryab!
Hello! Unfortunately, we'll have to revert this PR, as jitting an activation function prevents the model from being pickled. This has already been an issue in several cases.
Nonetheless, thank you for your contribution and for such a detailed study of what was to be gained from it.
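A minimal sketch of the failure mode being described, with illustrative names; the exact exception type and message depend on the PyTorch version:

```python
import math
import pickle

import torch
import torch.nn as nn


@torch.jit.script
def jit_gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = jit_gelu  # the ScriptFunction is stored in the module's state

    def forward(self, x):
        return self.act(x)


try:
    pickle.dumps(Block())  # expected to fail: ScriptFunction objects are not picklable
except Exception as err:
    print(f"pickling failed with {type(err).__name__}: {err}")
```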
Currently, the implementation of the GELU activation uses several unfused pointwise operations. In my experiments, computing this activation takes about 10% of forward time for GPT2-like networks for inputs of size similar to (32, 128). This PR speeds up the execution of gelu_new during both the forward (~3-5x) and backward (~2-3x) passes with the help of torch.jit, which can benefit both training and inference.
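To make the "unfused pointwise operations" point concrete, here is roughly what gelu_new looks like (a sketch; the exact code in the repository may differ), with the eager-mode kernel launches annotated:

```python
import math

import torch


def gelu_new(x):
    # In eager mode, each pointwise op below launches its own kernel:
    inner = x + 0.044715 * torch.pow(x, 3.0)                                # pow, mul, add
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner))  # mul, tanh, add, mul, mul


# torch.jit.script lets the fuser combine the whole chain into far fewer kernels,
# which is where the forward/backward speedup comes from.
gelu_new = torch.jit.script(gelu_new)
```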
Below are the benchmarking results, done on pytorch v1.4.0 and transformers v2.5.0 with RTX 2080Ti and GTX 1080Ti. The benchmarking code is available here.
1080Ti:
RTX 2080Ti: