
Speed up GELU computation with torch.jit #2988

Merged · 3 commits merged into master on Apr 3, 2020

Conversation

@mryab (Contributor) commented Feb 23, 2020

Currently, the implementation of the GELU activation uses several unfused pointwise operations. In my experiments, computing this activation takes about 10% of the forward time for GPT-2-like networks with inputs of size similar to (32, 128). This PR speeds up the execution of gelu_new during both the forward (~3-5x) and backward (~2-3x) passes with the help of torch.jit, which might be helpful for both training and inference.
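For reference, a minimal sketch of the kind of change involved (not the exact diff in this PR; the name gelu_new_jit is illustrative): wrapping the tanh-approximation GELU in torch.jit.script so that TorchScript can fuse the chain of pointwise operations into fewer kernels.

import math

import torch


@torch.jit.script
def gelu_new_jit(x):
    # GPT-2-style tanh approximation of GELU; under torch.jit.script the chain of
    # pointwise ops can be fused instead of launching one kernel per operation.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))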

Below are the benchmarking results, obtained on pytorch v1.4.0 and transformers v2.5.0 with an RTX 2080Ti and a GTX 1080Ti. The benchmarking code is available here.
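A micro-benchmark of this kind roughly amounts to timing the forward and backward passes with CUDA synchronization around each measurement. Below is a hedged sketch of such a harness (not the linked gist itself; the warm-up and repeat counts are illustrative):

import time

import torch


def bench(activation, shape, dtype, repeats=100, device="cuda"):
    """Return mean forward/backward time in seconds for one activation on random input."""
    x = torch.randn(*shape, dtype=dtype, device=device, requires_grad=True)

    # Warm-up so JIT compilation and CUDA context setup are not measured.
    for _ in range(10):
        activation(x).sum().backward()
        x.grad = None
    torch.cuda.synchronize()

    fwd = bwd = 0.0
    for _ in range(repeats):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        out = activation(x)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        out.sum().backward()
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
        x.grad = None
    return fwd / repeats, bwd / repeats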

1080Ti:

torch.float32   (32, 128)       gelu 2.6e-04    4.1e-04 jit 1.1e-04     1.8e-04 speedup forward 2.50    backward 2.27
torch.float32   (32, 512)       gelu 2.6e-04    4.1e-04 jit 6.5e-05     1.5e-04 speedup forward 4.06    backward 2.67
torch.float32   (32, 1024)      gelu 2.6e-04    4.0e-04 jit 6.7e-05     1.6e-04 speedup forward 3.94    backward 2.59
torch.float32   (32, 4096)      gelu 2.5e-04    3.9e-04 jit 6.6e-05     1.6e-04 speedup forward 3.75    backward 2.51
torch.float32   (256, 128)      gelu 2.7e-04    4.1e-04 jit 6.7e-05     1.6e-04 speedup forward 3.96    backward 2.61
torch.float32   (256, 512)      gelu 2.5e-04    4.0e-04 jit 6.5e-05     1.5e-04 speedup forward 3.88    backward 2.57
torch.float32   (256, 1024)     gelu 2.5e-04    4.0e-04 jit 6.2e-05     1.5e-04 speedup forward 4.05    backward 2.62
torch.float32   (256, 4096)     gelu 2.6e-04    4.2e-04 jit 1.0e-04     1.7e-04 speedup forward 2.52    backward 2.45
torch.float32   (1024, 128)     gelu 2.5e-04    3.9e-04 jit 6.5e-05     1.5e-04 speedup forward 3.82    backward 2.57
torch.float32   (1024, 512)     gelu 2.5e-04    3.8e-04 jit 7.2e-05     1.5e-04 speedup forward 3.43    backward 2.52
torch.float32   (1024, 1024)    gelu 2.6e-04    4.2e-04 jit 1.0e-04     1.7e-04 speedup forward 2.52    backward 2.44
torch.float32   (1024, 4096)    gelu 8.8e-04    1.3e-03 jit 3.2e-04     3.5e-04 speedup forward 2.71    backward 3.79
torch.float32   (8192, 128)     gelu 2.6e-04    4.2e-04 jit 1.0e-04     1.7e-04 speedup forward 2.51    backward 2.43
torch.float32   (8192, 512)     gelu 8.8e-04    1.3e-03 jit 3.2e-04     3.5e-04 speedup forward 2.72    backward 3.80
torch.float32   (8192, 1024)    gelu 1.7e-03    2.5e-03 jit 6.4e-04     5.9e-04 speedup forward 2.69    backward 4.30
torch.float32   (8192, 4096)    gelu 6.7e-03    1.0e-02 jit 2.7e-03     2.5e-03 speedup forward 2.53    backward 4.05
torch.float16   (32, 128)       gelu 2.6e-04    4.0e-04 jit 9.4e-05     1.8e-04 speedup forward 2.79    backward 2.24
torch.float16   (32, 512)       gelu 2.5e-04    3.9e-04 jit 6.2e-05     1.4e-04 speedup forward 4.09    backward 2.74
torch.float16   (32, 1024)      gelu 2.6e-04    4.0e-04 jit 6.2e-05     1.5e-04 speedup forward 4.22    backward 2.68
torch.float16   (32, 4096)      gelu 2.4e-04    3.8e-04 jit 6.3e-05     1.5e-04 speedup forward 3.84    backward 2.56
torch.float16   (256, 128)      gelu 2.6e-04    4.0e-04 jit 6.1e-05     1.4e-04 speedup forward 4.34    backward 2.81
torch.float16   (256, 512)      gelu 2.5e-04    3.9e-04 jit 6.4e-05     1.5e-04 speedup forward 3.98    backward 2.59
torch.float16   (256, 1024)     gelu 2.4e-04    3.7e-04 jit 6.3e-05     1.4e-04 speedup forward 3.82    backward 2.65
torch.float16   (256, 4096)     gelu 2.3e-04    3.2e-04 jit 7.6e-05     1.4e-04 speedup forward 3.00    backward 2.32
torch.float16   (1024, 128)     gelu 2.2e-04    3.2e-04 jit 6.3e-05     1.4e-04 speedup forward 3.47    backward 2.32
torch.float16   (1024, 512)     gelu 2.2e-04    3.2e-04 jit 6.3e-05     1.4e-04 speedup forward 3.47    backward 2.31
torch.float16   (1024, 1024)    gelu 2.3e-04    3.2e-04 jit 7.6e-05     1.4e-04 speedup forward 3.01    backward 2.31
torch.float16   (1024, 4096)    gelu 5.4e-04    8.9e-04 jit 2.2e-04     2.6e-04 speedup forward 2.44    backward 3.40
torch.float16   (8192, 128)     gelu 2.5e-04    3.8e-04 jit 7.6e-05     1.5e-04 speedup forward 3.29    backward 2.61
torch.float16   (8192, 512)     gelu 5.4e-04    8.9e-04 jit 2.2e-04     2.5e-04 speedup forward 2.43    backward 3.49
torch.float16   (8192, 1024)    gelu 1.0e-03    1.7e-03 jit 4.8e-04     4.6e-04 speedup forward 2.18    backward 3.60
torch.float16   (8192, 4096)    gelu 4.2e-03    6.5e-03 jit 2.3e-03     2.0e-03 speedup forward 1.83    backward 3.30

RTX 2080Ti:

torch.float32   (32, 128)       gelu 3.0e-04    6.2e-04 jit 1.2e-04     2.2e-04 speedup forward 2.50    backward 2.80
torch.float32   (32, 512)       gelu 3.2e-04    6.8e-04 jit 6.8e-05     2.1e-04 speedup forward 4.66    backward 3.20
torch.float32   (32, 1024)      gelu 3.4e-04    7.2e-04 jit 6.8e-05     2.1e-04 speedup forward 4.96    backward 3.38
torch.float32   (32, 4096)      gelu 3.3e-04    7.0e-04 jit 6.4e-05     1.8e-04 speedup forward 5.07    backward 3.83
torch.float32   (256, 128)      gelu 3.3e-04    6.9e-04 jit 6.5e-05     1.9e-04 speedup forward 5.07    backward 3.57
torch.float32   (256, 512)      gelu 3.0e-04    6.2e-04 jit 6.4e-05     1.9e-04 speedup forward 4.73    backward 3.21
torch.float32   (256, 1024)     gelu 3.3e-04    6.9e-04 jit 6.6e-05     2.1e-04 speedup forward 4.95    backward 3.35
torch.float32   (256, 4096)     gelu 3.3e-04    6.8e-04 jit 9.3e-05     2.2e-04 speedup forward 3.53    backward 3.09
torch.float32   (1024, 128)     gelu 3.1e-04    6.2e-04 jit 6.5e-05     1.9e-04 speedup forward 4.70    backward 3.32
torch.float32   (1024, 512)     gelu 3.4e-04    6.4e-04 jit 7.7e-05     1.9e-04 speedup forward 4.41    backward 3.30
torch.float32   (1024, 1024)    gelu 3.1e-04    6.1e-04 jit 9.5e-05     2.2e-04 speedup forward 3.26    backward 2.73
torch.float32   (1024, 4096)    gelu 6.2e-04    9.9e-04 jit 2.7e-04     3.1e-04 speedup forward 2.26    backward 3.15
torch.float32   (8192, 128)     gelu 3.1e-04    4.9e-04 jit 9.7e-05     1.9e-04 speedup forward 3.13    backward 2.55
torch.float32   (8192, 512)     gelu 6.1e-04    1.0e-03 jit 2.7e-04     3.4e-04 speedup forward 2.27    backward 2.99
torch.float32   (8192, 1024)    gelu 1.2e-03    1.9e-03 jit 5.3e-04     5.5e-04 speedup forward 2.21    backward 3.38
torch.float32   (8192, 4096)    gelu 4.5e-03    6.7e-03 jit 2.2e-03     1.6e-03 speedup forward 2.04    backward 4.24
torch.float16   (32, 128)       gelu 3.2e-04    6.3e-04 jit 1.1e-04     2.2e-04 speedup forward 2.84    backward 2.92
torch.float16   (32, 512)       gelu 3.3e-04    6.9e-04 jit 6.2e-05     1.6e-04 speedup forward 5.23    backward 4.29
torch.float16   (32, 1024)      gelu 3.0e-04    5.9e-04 jit 6.5e-05     1.7e-04 speedup forward 4.58    backward 3.46
torch.float16   (32, 4096)      gelu 3.0e-04    6.1e-04 jit 6.4e-05     1.8e-04 speedup forward 4.63    backward 3.34
torch.float16   (256, 128)      gelu 3.0e-04    5.9e-04 jit 6.4e-05     1.7e-04 speedup forward 4.61    backward 3.49
torch.float16   (256, 512)      gelu 3.0e-04    5.9e-04 jit 6.3e-05     1.7e-04 speedup forward 4.68    backward 3.41
torch.float16   (256, 1024)     gelu 2.9e-04    5.7e-04 jit 6.5e-05     1.6e-04 speedup forward 4.40    backward 3.54
torch.float16   (256, 4096)     gelu 2.9e-04    5.5e-04 jit 7.5e-05     2.0e-04 speedup forward 3.87    backward 2.74
torch.float16   (1024, 128)     gelu 3.7e-04    6.3e-04 jit 8.0e-05     2.3e-04 speedup forward 4.59    backward 2.75
torch.float16   (1024, 512)     gelu 3.4e-04    6.0e-04 jit 6.6e-05     1.6e-04 speedup forward 5.13    backward 3.81
torch.float16   (1024, 1024)    gelu 3.0e-04    5.9e-04 jit 7.2e-05     1.9e-04 speedup forward 4.12    backward 3.08
torch.float16   (1024, 4096)    gelu 4.1e-04    6.9e-04 jit 1.6e-04     2.6e-04 speedup forward 2.49    backward 2.68
torch.float16   (8192, 128)     gelu 3.6e-04    6.6e-04 jit 7.0e-05     1.8e-04 speedup forward 5.08    backward 3.73
torch.float16   (8192, 512)     gelu 4.1e-04    7.0e-04 jit 1.6e-04     2.5e-04 speedup forward 2.57    backward 2.76
torch.float16   (8192, 1024)    gelu 7.4e-04    1.2e-03 jit 3.2e-04     4.1e-04 speedup forward 2.30    backward 2.81
torch.float16   (8192, 4096)    gelu 2.8e-03    3.9e-03 jit 1.5e-03     1.2e-03 speedup forward 1.86    backward 3.34

@codecov-io commented Feb 24, 2020

Codecov Report

Merging #2988 into master will decrease coverage by 1.05%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master    #2988      +/-   ##
==========================================
- Coverage   77.16%   76.11%   -1.06%     
==========================================
  Files          98       98              
  Lines       15997    15997              
==========================================
- Hits        12344    12176     -168     
- Misses       3653     3821     +168
Impacted Files Coverage Δ
src/transformers/activations.py 75% <100%> (-12.5%) ⬇️
src/transformers/modeling_tf_pytorch_utils.py 8.72% <0%> (-81.21%) ⬇️
src/transformers/modeling_roberta.py 85.71% <0%> (-10%) ⬇️
src/transformers/modeling_xlnet.py 73.48% <0%> (-2.3%) ⬇️
src/transformers/modeling_ctrl.py 96.03% <0%> (-2.21%) ⬇️
src/transformers/modeling_openai.py 80.2% <0%> (-1.35%) ⬇️
src/transformers/modeling_utils.py 92.2% <0%> (-0.17%) ⬇️


@BramVanroy (Collaborator) commented:
Any reason why the other activation functions (swish, _gelu_python) do not need jit? (I have no experience with JIT, so this is a genuine question. When should jit.script be used, and when shouldn't it?)

@mryab (Contributor, Author) commented Feb 24, 2020

Indeed, it's possible to wrap both activations you mentioned with torch.jit; in the case of _gelu_python it's likely to yield a similar reduction in execution time. I will come back with benchmarking results and, if you think it's a good idea, add JIT compilation to this PR.

Answering your question on the use of jit.script: it usually makes sense for optimizing functions with many elementwise ops, as they tend to get fused into a single kernel, which eliminates unnecessary memory accesses. There are other advantages, e.g. removing Python overhead and releasing the GIL as a result; if you're interested, this tutorial and this blog post give a good overview of the underlying optimizations.

TL;DR: jit.script is useful when you have TorchScript-friendly functions/modules with lots of custom PyTorch code; if your code uses unsupported Python features, you either leave it as is or use torch.jit.trace.
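To make the distinction concrete, here is an illustrative sketch (the function names are made up for the example): script compiles the Python source, so data-dependent branching survives compilation, while trace records the ops executed for one example input and freezes that path.

import torch


@torch.jit.script
def scripted_act(x, use_tanh_approx: bool):
    # script compiles the source, so this branch is preserved in the compiled graph.
    if use_tanh_approx:
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
    return x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))


def plain_swish(x):
    return x * torch.sigmoid(x)


# trace only records the ops run for this particular example input; unsupported
# Python features are tolerated, but any Python-level branching would be baked in.
traced_swish = torch.jit.trace(plain_swish, torch.randn(8, 8))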

Talking of swish, there is something I'd like to mention: its current implementation can be made more memory-efficient (see this and this) at the cost of losing torch.jit/torch.onnx support. Not sure if swish will benefit much from JIT compilation — would memory savings be useful then?
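For context, those memory-efficient swish variants are roughly of the following shape (a sketch of the general technique, not the exact linked code): a custom autograd.Function saves only the input and recomputes sigmoid in the backward pass, trading a little extra compute for activation memory. As noted above, custom Functions like this don't play well with torch.jit.script or ONNX export.

import torch


class _MemoryEfficientSwish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep only the input, not sigmoid(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)  # recomputed here instead of being stored in forward
        return grad_output * s * (1.0 + x * (1.0 - s))


def memory_efficient_swish(x):
    return _MemoryEfficientSwish.apply(x)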

@mryab (Contributor, Author) commented Feb 24, 2020

Here are the results for both activations (measured on the 1080Ti; I've updated the gist with two scripts):

_gelu_python

torch.float32   (32, 512)       gelu 1.9e-04    3.7e-04 jit 7.0e-05     1.6e-04 speedup forward 2.69    backward 2.36
torch.float32   (32, 1024)      gelu 1.9e-04    3.7e-04 jit 6.9e-05     1.6e-04 speedup forward 2.71    backward 2.33
torch.float32   (32, 4096)      gelu 1.8e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.66    backward 2.29
torch.float32   (256, 128)      gelu 1.9e-04    3.6e-04 jit 7.0e-05     1.6e-04 speedup forward 2.66    backward 2.30
torch.float32   (256, 512)      gelu 1.8e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.65    backward 2.31
torch.float32   (256, 1024)     gelu 1.8e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.67    backward 2.30
torch.float32   (256, 4096)     gelu 1.7e-04    3.6e-04 jit 9.8e-05     1.5e-04 speedup forward 1.74    backward 2.33
torch.float32   (1024, 128)     gelu 1.8e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.67    backward 2.30
torch.float32   (1024, 512)     gelu 1.9e-04    3.6e-04 jit 7.3e-05     1.6e-04 speedup forward 2.55    backward 2.34
torch.float32   (1024, 1024)    gelu 1.7e-04    3.5e-04 jit 9.9e-05     1.6e-04 speedup forward 1.74    backward 2.29
torch.float32   (1024, 4096)    gelu 5.1e-04    1.1e-03 jit 3.1e-04     2.9e-04 speedup forward 1.65    backward 3.78
torch.float32   (8192, 128)     gelu 1.7e-04    3.6e-04 jit 1.0e-04     1.5e-04 speedup forward 1.74    backward 2.30
torch.float32   (8192, 512)     gelu 5.1e-04    1.1e-03 jit 3.1e-04     2.9e-04 speedup forward 1.65    backward 3.78
torch.float32   (8192, 1024)    gelu 9.8e-04    2.1e-03 jit 6.1e-04     4.6e-04 speedup forward 1.61    backward 4.43
torch.float32   (8192, 4096)    gelu 3.8e-03    8.1e-03 jit 2.6e-03     1.9e-03 speedup forward 1.46    backward 4.15
torch.float16   (32, 128)       gelu 1.9e-04    3.6e-04 jit 9.6e-05     1.8e-04 speedup forward 1.94    backward 1.98
torch.float16   (32, 512)       gelu 1.8e-04    3.6e-04 jit 6.8e-05     1.5e-04 speedup forward 2.73    backward 2.38
torch.float16   (32, 1024)      gelu 1.9e-04    3.6e-04 jit 7.0e-05     1.6e-04 speedup forward 2.66    backward 2.28
torch.float16   (32, 4096)      gelu 1.9e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.68    backward 2.33
torch.float16   (256, 128)      gelu 1.9e-04    3.6e-04 jit 7.0e-05     1.6e-04 speedup forward 2.66    backward 2.29
torch.float16   (256, 512)      gelu 1.9e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.67    backward 2.30
torch.float16   (256, 1024)     gelu 1.9e-04    3.7e-04 jit 7.0e-05     1.6e-04 speedup forward 2.68    backward 2.31
torch.float16   (256, 4096)     gelu 1.9e-04    3.7e-04 jit 7.4e-05     1.5e-04 speedup forward 2.56    backward 2.43
torch.float16   (1024, 128)     gelu 1.9e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.67    backward 2.28
torch.float16   (1024, 512)     gelu 1.9e-04    3.6e-04 jit 6.9e-05     1.6e-04 speedup forward 2.69    backward 2.30
torch.float16   (1024, 1024)    gelu 1.9e-04    3.7e-04 jit 7.4e-05     1.5e-04 speedup forward 2.56    backward 2.40
torch.float16   (1024, 4096)    gelu 3.3e-04    8.1e-04 jit 2.1e-04     2.3e-04 speedup forward 1.62    backward 3.50
torch.float16   (8192, 128)     gelu 1.9e-04    3.7e-04 jit 7.4e-05     1.6e-04 speedup forward 2.56    backward 2.34
torch.float16   (8192, 512)     gelu 3.4e-04    8.1e-04 jit 2.1e-04     2.3e-04 speedup forward 1.62    backward 3.51
torch.float16   (8192, 1024)    gelu 6.3e-04    1.5e-03 jit 4.5e-04     3.7e-04 speedup forward 1.39    backward 4.06
torch.float16   (8192, 4096)    gelu 2.5e-03    5.9e-03 jit 2.2e-03     1.5e-03 speedup forward 1.11    backward 3.93

swish

torch.float32   (32, 128)       swish 5.9e-05   1.8e-04 jit 1.0e-04     1.8e-04 speedup forward 0.59    backward 1.01
torch.float32   (32, 512)       swish 5.8e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.08    backward 1.30
torch.float32   (32, 1024)      swish 5.8e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.08    backward 1.31
torch.float32   (32, 4096)      swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.08    backward 1.33
torch.float32   (256, 128)      swish 5.8e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.08    backward 1.33
torch.float32   (256, 512)      swish 5.9e-05   1.8e-04 jit 5.4e-05     1.3e-04 speedup forward 1.09    backward 1.36
torch.float32   (256, 1024)     swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.09    backward 1.32
torch.float32   (256, 4096)     swish 8.6e-05   2.2e-04 jit 7.4e-05     1.4e-04 speedup forward 1.17    backward 1.57
torch.float32   (1024, 128)     swish 5.8e-05   1.8e-04 jit 5.5e-05     1.4e-04 speedup forward 1.07    backward 1.28
torch.float32   (1024, 512)     swish 6.7e-05   1.9e-04 jit 5.6e-05     1.4e-04 speedup forward 1.20    backward 1.31
torch.float32   (1024, 1024)    swish 8.6e-05   2.2e-04 jit 7.4e-05     1.4e-04 speedup forward 1.17    backward 1.56
torch.float32   (1024, 4096)    swish 2.6e-04   5.8e-04 jit 2.0e-04     2.4e-04 speedup forward 1.33    backward 2.39
torch.float32   (8192, 128)     swish 8.8e-05   2.2e-04 jit 7.4e-05     1.4e-04 speedup forward 1.18    backward 1.63
torch.float32   (8192, 512)     swish 2.6e-04   5.7e-04 jit 2.0e-04     2.4e-04 speedup forward 1.34    backward 2.36
torch.float32   (8192, 1024)    swish 4.9e-04   1.0e-03 jit 3.7e-04     3.9e-04 speedup forward 1.32    backward 2.69
torch.float32   (8192, 4096)    swish 1.9e-03   4.1e-03 jit 1.5e-03     1.6e-03 speedup forward 1.25    backward 2.56
torch.float16   (32, 128)       swish 5.8e-05   1.8e-04 jit 9.5e-05     1.7e-04 speedup forward 0.62    backward 1.06
torch.float16   (32, 512)       swish 5.8e-05   1.8e-04 jit 5.4e-05     1.3e-04 speedup forward 1.09    backward 1.35
torch.float16   (32, 1024)      swish 5.9e-05   1.8e-04 jit 5.4e-05     1.3e-04 speedup forward 1.10    backward 1.32
torch.float16   (32, 4096)      swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.10    backward 1.30
torch.float16   (256, 128)      swish 5.8e-05   1.8e-04 jit 5.3e-05     1.3e-04 speedup forward 1.09    backward 1.33
torch.float16   (256, 512)      swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.10    backward 1.29
torch.float16   (256, 1024)     swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.09    backward 1.30
torch.float16   (256, 4096)     swish 7.5e-05   1.8e-04 jit 8.1e-05     1.4e-04 speedup forward 0.93    backward 1.31
torch.float16   (1024, 128)     swish 5.9e-05   1.8e-04 jit 5.4e-05     1.4e-04 speedup forward 1.10    backward 1.29
torch.float16   (1024, 512)     swish 5.9e-05   1.8e-04 jit 5.9e-05     1.4e-04 speedup forward 1.00    backward 1.30
torch.float16   (1024, 1024)    swish 7.3e-05   1.8e-04 jit 8.1e-05     1.4e-04 speedup forward 0.91    backward 1.30
torch.float16   (1024, 4096)    swish 2.1e-04   4.3e-04 jit 2.1e-04     2.1e-04 speedup forward 0.99    backward 2.08
torch.float16   (8192, 128)     swish 7.4e-05   1.8e-04 jit 8.1e-05     1.3e-04 speedup forward 0.91    backward 1.37
torch.float16   (8192, 512)     swish 2.1e-04   4.2e-04 jit 2.1e-04     2.1e-04 speedup forward 0.99    backward 2.03
torch.float16   (8192, 1024)    swish 3.8e-04   7.5e-04 jit 3.7e-04     3.0e-04 speedup forward 1.02    backward 2.47
torch.float16   (8192, 4096)    swish 1.4e-03   2.8e-03 jit 1.4e-03     1.1e-03 speedup forward 1.06    backward 2.60

@mryab (Contributor, Author) commented Feb 25, 2020

Same benchmarks on RTX 2080Ti:

_gelu_python

torch.float32   (32, 128)       gelu 2.1e-04    5.9e-04 jit 1.2e-04     2.2e-04 speedup forward 1.79    backward 2.63
torch.float32   (32, 512)       gelu 2.3e-04    6.0e-04 jit 6.5e-05     1.6e-04 speedup forward 3.59    backward 3.76
torch.float32   (32, 1024)      gelu 2.3e-04    5.8e-04 jit 6.4e-05     1.6e-04 speedup forward 3.54    backward 3.73
torch.float32   (32, 4096)      gelu 1.7e-04    3.3e-04 jit 6.2e-05     1.4e-04 speedup forward 2.65    backward 2.38
torch.float32   (256, 128)      gelu 1.7e-04    3.6e-04 jit 6.6e-05     1.9e-04 speedup forward 2.59    backward 1.94
torch.float32   (256, 512)      gelu 2.5e-04    6.7e-04 jit 6.7e-05     2.0e-04 speedup forward 3.71    backward 3.34
torch.float32   (256, 1024)     gelu 2.3e-04    6.1e-04 jit 6.6e-05     1.9e-04 speedup forward 3.41    backward 3.25
torch.float32   (256, 4096)     gelu 2.1e-04    5.3e-04 jit 9.2e-05     2.0e-04 speedup forward 2.33    backward 2.64
torch.float32   (1024, 128)     gelu 2.1e-04    5.0e-04 jit 6.5e-05     1.9e-04 speedup forward 3.25    backward 2.70
torch.float32   (1024, 512)     gelu 2.2e-04    5.2e-04 jit 6.7e-05     1.8e-04 speedup forward 3.21    backward 2.91
torch.float32   (1024, 1024)    gelu 2.4e-04    6.1e-04 jit 9.2e-05     2.0e-04 speedup forward 2.56    backward 3.06
torch.float32   (1024, 4096)    gelu 4.0e-04    9.3e-04 jit 2.7e-04     3.6e-04 speedup forward 1.44    backward 2.58
torch.float32   (8192, 128)     gelu 2.3e-04    5.7e-04 jit 9.4e-05     2.2e-04 speedup forward 2.44    backward 2.63
torch.float32   (8192, 512)     gelu 4.0e-04    9.3e-04 jit 2.7e-04     3.4e-04 speedup forward 1.47    backward 2.76
torch.float32   (8192, 1024)    gelu 7.4e-04    1.6e-03 jit 5.5e-04     4.8e-04 speedup forward 1.36    backward 3.42
torch.float32   (8192, 4096)    gelu 2.8e-03    5.8e-03 jit 2.2e-03     1.3e-03 speedup forward 1.26    backward 4.55
torch.float16   (32, 128)       gelu 2.4e-04    6.7e-04 jit 1.1e-04     2.0e-04 speedup forward 2.16    backward 3.29
torch.float16   (32, 512)       gelu 2.4e-04    5.0e-04 jit 7.6e-05     1.8e-04 speedup forward 3.11    backward 2.80
torch.float16   (32, 1024)      gelu 2.1e-04    5.4e-04 jit 6.4e-05     1.8e-04 speedup forward 3.31    backward 3.03
torch.float16   (32, 4096)      gelu 2.2e-04    5.7e-04 jit 6.5e-05     1.9e-04 speedup forward 3.40    backward 3.04
torch.float16   (256, 128)      gelu 2.1e-04    5.3e-04 jit 7.1e-05     2.0e-04 speedup forward 2.93    backward 2.61
torch.float16   (256, 512)      gelu 2.2e-04    4.8e-04 jit 7.9e-05     2.1e-04 speedup forward 2.83    backward 2.27
torch.float16   (256, 1024)     gelu 2.2e-04    5.8e-04 jit 6.4e-05     1.8e-04 speedup forward 3.35    backward 3.28
torch.float16   (256, 4096)     gelu 1.9e-04    4.5e-04 jit 6.5e-05     1.6e-04 speedup forward 2.93    backward 2.85
torch.float16   (1024, 128)     gelu 1.9e-04    4.5e-04 jit 6.4e-05     1.7e-04 speedup forward 2.99    backward 2.73
torch.float16   (1024, 512)     gelu 1.9e-04    4.4e-04 jit 5.9e-05     1.5e-04 speedup forward 3.18    backward 2.97
torch.float16   (1024, 1024)    gelu 2.1e-04    5.2e-04 jit 6.5e-05     1.6e-04 speedup forward 3.16    backward 3.23
torch.float16   (1024, 4096)    gelu 2.8e-04    6.4e-04 jit 1.5e-04     2.4e-04 speedup forward 1.83    backward 2.60
torch.float16   (8192, 128)     gelu 2.1e-04    5.4e-04 jit 6.4e-05     1.8e-04 speedup forward 3.27    backward 2.96
torch.float16   (8192, 512)     gelu 2.8e-04    6.7e-04 jit 1.5e-04     2.4e-04 speedup forward 1.83    backward 2.79
torch.float16   (8192, 1024)    gelu 4.8e-04    1.1e-03 jit 3.0e-04     3.5e-04 speedup forward 1.57    backward 3.03
torch.float16   (8192, 4096)    gelu 1.8e-03    3.4e-03 jit 1.5e-03     8.8e-04 speedup forward 1.14    backward 3.91

swish

torch.float32   (32, 128)       swish 7.5e-05   2.6e-04 jit 1.1e-04     2.0e-04 speedup forward 0.71   backward 1.32
torch.float32   (32, 512)       swish 7.5e-05   2.6e-04 jit 5.8e-05     1.7e-04 speedup forward 1.31   backward 1.57
torch.float32   (32, 1024)      swish 7.2e-05   2.5e-04 jit 5.8e-05     1.6e-04 speedup forward 1.24   backward 1.50
torch.float32   (32, 4096)      swish 7.1e-05   2.6e-04 jit 6.1e-05     1.9e-04 speedup forward 1.17   backward 1.38
torch.float32   (256, 128)      swish 7.2e-05   2.5e-04 jit 5.7e-05     1.7e-04 speedup forward 1.26   backward 1.50
torch.float32   (256, 512)      swish 7.4e-05   2.7e-04 jit 5.9e-05     1.8e-04 speedup forward 1.25   backward 1.55
torch.float32   (256, 1024)     swish 7.3e-05   2.6e-04 jit 6.2e-05     2.0e-04 speedup forward 1.18   backward 1.35
torch.float32   (256, 4096)     swish 8.5e-05   2.7e-04 jit 6.5e-05     1.6e-04 speedup forward 1.31   backward 1.75
torch.float32   (1024, 128)     swish 7.4e-05   2.7e-04 jit 5.8e-05     1.8e-04 speedup forward 1.27   backward 1.47
torch.float32   (1024, 512)     swish 7.5e-05   2.8e-04 jit 6.4e-05     2.2e-04 speedup forward 1.16   backward 1.29
torch.float32   (1024, 1024)    swish 9.2e-05   3.3e-04 jit 7.0e-05     2.1e-04 speedup forward 1.32   backward 1.59
torch.float32   (1024, 4096)    swish 1.9e-04   5.7e-04 jit 1.6e-04     2.7e-04 speedup forward 1.24   backward 2.10
torch.float32   (8192, 128)     swish 9.1e-05   3.2e-04 jit 7.2e-05     2.0e-04 speedup forward 1.26   backward 1.54
torch.float32   (8192, 512)     swish 1.9e-04   5.5e-04 jit 1.6e-04     2.7e-04 speedup forward 1.20   backward 2.03
torch.float32   (8192, 1024)    swish 3.5e-04   8.8e-04 jit 3.2e-04     3.9e-04 speedup forward 1.09   backward 2.24
torch.float32   (8192, 4096)    swish 1.3e-03   2.7e-03 jit 1.3e-03     1.0e-03 speedup forward 0.99   backward 2.62
torch.float16   (32, 128)       swish 7.0e-05   2.5e-04 jit 1.0e-04     2.1e-04 speedup forward 0.69   backward 1.18
torch.float16   (32, 512)       swish 6.9e-05   2.4e-04 jit 6.6e-05     1.8e-04 speedup forward 1.05   backward 1.38
torch.float16   (32, 1024)      swish 7.0e-05   2.4e-04 jit 6.0e-05     1.7e-04 speedup forward 1.18   backward 1.43
torch.float16   (32, 4096)      swish 6.9e-05   2.5e-04 jit 6.0e-05     1.8e-04 speedup forward 1.14   backward 1.37
torch.float16   (256, 128)      swish 6.5e-05   2.4e-04 jit 5.8e-05     1.6e-04 speedup forward 1.12   backward 1.48
torch.float16   (256, 512)      swish 7.1e-05   2.6e-04 jit 6.0e-05     1.8e-04 speedup forward 1.20   backward 1.41
torch.float16   (256, 1024)     swish 6.8e-05   2.5e-04 jit 6.0e-05     1.8e-04 speedup forward 1.14   backward 1.37
torch.float16   (256, 4096)     swish 7.1e-05   2.5e-04 jit 9.7e-05     2.1e-04 speedup forward 0.73   backward 1.20
torch.float16   (1024, 128)     swish 7.0e-05   2.5e-04 jit 6.0e-05     1.8e-04 speedup forward 1.17   backward 1.42
torch.float16   (1024, 512)     swish 7.2e-05   2.6e-04 jit 6.8e-05     1.7e-04 speedup forward 1.06   backward 1.49
torch.float16   (1024, 1024)    swish 6.7e-05   2.4e-04 jit 9.7e-05     2.1e-04 speedup forward 0.69   backward 1.14
torch.float16   (1024, 4096)    swish 1.3e-04   3.6e-04 jit 1.9e-04     1.8e-04 speedup forward 0.69   backward 1.98
torch.float16   (8192, 128)     swish 7.0e-05   2.4e-04 jit 9.7e-05     1.9e-04 speedup forward 0.73   backward 1.26
torch.float16   (8192, 512)     swish 1.3e-04   3.5e-04 jit 1.9e-04     2.2e-04 speedup forward 0.66   backward 1.62
torch.float16   (8192, 1024)    swish 2.1e-04   5.7e-04 jit 3.5e-04     3.1e-04 speedup forward 0.62   backward 1.82
torch.float16   (8192, 4096)    swish 7.6e-04   1.6e-03 jit 1.3e-03     7.2e-04 speedup forward 0.60   backward 2.17

It seems to make sense to compile _gelu_python, while for swish the benefits are negligible.

@sshleifer (Contributor) commented Feb 25, 2020

We only use _gelu_python for torch < 1.4.
My only concern with this PR is that it might break on early pytorch versions or on CPU; can you test it under those circumstances?

@mryab (Contributor, Author) commented Mar 11, 2020

I've tested the current implementation with pytorch==1.0.0 on CPU, and it indeed breaks, because torch.jit did not support Python floats at that time. I have two possible solutions for this; @sshleifer, which would be the best one?

First: slightly modify gelu_python and gelu_new to be backwards-compatible

import math

import torch


@torch.jit.script
def jit_gelu_python(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        This is now written in C in torch.nn.functional
        Also see https://arxiv.org/abs/1606.08415
    """
    gelu_const = torch.sqrt(torch.full((), 2.0, dtype=x.dtype, device=x.device))
    return x * 0.5 * (1.0 + torch.erf(x / gelu_const))


@torch.jit.script
def jit_gelu(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    gelu_const = torch.sqrt(torch.full((), 2.0 / math.pi, dtype=x.dtype, device=x.device))
    return 0.5 * x * (1 + torch.tanh(gelu_const * (x + 0.044715 * torch.pow(x, 3))))

Second: use torch.jit.script only with pytorch>1.4.0. We won't need to wrap gelu, as it already has a native implementation, and for gelu_new we'll add a single check.
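The second option might look roughly like the sketch below (the exact version check and names in the merged change may differ):

import math

import torch
from packaging import version


def _gelu_new(x):
    """Tanh approximation of GELU used by gelu_new (GPT-2 style)."""
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


# JIT-compile only on versions where TorchScript handles this code path;
# older releases fall back to the plain Python implementation.
if version.parse(torch.__version__) >= version.parse("1.4"):
    gelu_new = torch.jit.script(_gelu_new)
else:
    gelu_new = _gelu_new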

@mryab (Contributor, Author) commented Apr 3, 2020

I've changed the PR so that gelu_new gets JIT-compiled only on pytorch>=1.4. Benchmarking results are the same, with 3-4x faster forward and 3x faster backward passes for this activation (although there is no speedup on CPU with float32). @sshleifer, is it ready to be merged now?

@sshleifer (Contributor) commented:
In my opinion, yes. LGTM.
@LysandreJik @julien-c this is a backwards-compatible speedup.

@LysandreJik (Member) left a comment:

Great, thanks for iterating on that and getting so deep into it @mryab !

@LysandreJik merged commit c6acd24 into huggingface:master on Apr 3, 2020
@mryab deleted the jit-gelu branch on April 3, 2020 at 20:48
@LysandreJik (Member) commented:
Hello! Unfortunately, we'll have to revert this PR as jitting an activation function prevents the model from being pickled.

This has already been an issue in several cases:

Nonetheless, thank you for your contribution and for such a detailed study of what was to be gained from it.
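For illustration, the failure mode is roughly the following: a ScriptFunction stored as a plain module attribute cannot go through standard pickle or deepcopy (a hedged sketch; the exact error message varies across PyTorch versions):

import pickle

import torch


@torch.jit.script
def jit_gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))


layer = torch.nn.Linear(4, 4)
layer.act = jit_gelu_new  # plain attribute; ends up in the module's __dict__

try:
    pickle.dumps(layer)
except Exception as exc:  # typically a "cannot pickle ScriptFunction"-style error
    print("pickling failed:", exc)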

lopuhin added a commit to lopuhin/transformer-lm that referenced this pull request May 30, 2020