Could torch.einsum gain a speed boost? #394
Hi @fyubang, could you post a link to the repo you are using so that we can have a look?
Sorry for forgetting the link; I used the code here:
Hi @ptrblck,
Thanks for the link, @fyubang.
We tried to compare the performance between an FP32 run and an
Using 8 V100 GPUs (each with 32GB), we could achieve a mean speed of ~2.65 iterations/second. However, supporting the
After changing the order of initialization, we could successfully run the script on the same machine, achieving ~3.70 iterations/second, which seems reasonable.
By "it did not work either", are you referring to the raised error or to a slower run using
CC @huggingface
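The exact "order of initialization" change is not shown above; as a hedged sketch of the usual apex pattern (the Linear model is just a stand-in for the actual XLNet script, and a torch.distributed process group is assumed to be initialized already), amp.initialize is called after moving the model to the GPU and before wrapping it in DistributedDataParallel:
import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

# stand-in model and optimizer; the real script fine-tunes XLNet
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# initialize amp first (O1 = mixed precision with automatic casting) ...
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# ... and only then wrap the model for distributed training
model = DDP(model)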
I reran the test using the xlnet:
and got the following numbers:
The performance benefit is indeed smaller and worth having a closer look at.
@ptrblck I tried to replace it with:
but it became even slower than
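The replacement snippet was lost above; whether it looked like this is an assumption, but one way to express the einsum 'ibnd,jbnd->ijbn' as a plain batched matmul (using the tensor shapes from the benchmark later in this thread) is:
import torch

a = torch.randn(24, 32, 40, 48, device='cuda', dtype=torch.half)  # (i, b, n, d)
b = torch.randn(64, 32, 40, 48, device='cuda', dtype=torch.half)  # (j, b, n, d)

# einsum formulation
out_einsum = torch.einsum('ibnd,jbnd->ijbn', a, b)

# equivalent batched matmul: (b, n, i, d) @ (b, n, d, j) -> (b, n, i, j), then back to (i, j, b, n)
out_matmul = torch.matmul(a.permute(1, 2, 0, 3), b.permute(1, 2, 3, 0)).permute(2, 3, 0, 1)

print(torch.allclose(out_einsum, out_matmul, atol=1e-2))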
@ptrblck
@fyubang Here is a small benchmark using 1) shapes that are multiples of 8 and 2) shapes that slightly miss this condition:
import time
import torch

# 1) dimensions are multiples of 8
I, J, K = 64, 1024, 1024
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)
nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
> 16.043us per iteration
# 2) dimensions just miss being multiples of 8
I, J, K = 63, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)
nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
> 39.476us per iteration
Could this also be the reason for the minor speedup in the XLNet?
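As a side note (a hedged sketch, not something proposed in this thread): if odd shapes like case 2) show up in practice, the operands can be zero-padded up to the next multiple of 8 and the result sliced back afterwards, which leaves the product unchanged:
import torch
import torch.nn.functional as F

def pad_dims_to_multiple_of_8(x):
    # zero-pad the last two dimensions up to the next multiple of 8
    pad_cols = (8 - x.size(-1) % 8) % 8
    pad_rows = (8 - x.size(-2) % 8) % 8
    return F.pad(x, (0, pad_cols, 0, pad_rows))

I, J, K = 63, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

# run the GEMM on padded (64, 1024) x (1024, 1024) operands, then slice back to (63, 1023)
C = torch.matmul(pad_dims_to_multiple_of_8(A), pad_dims_to_multiple_of_8(B))[:I, :K]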
@ptrblck
{
"attn_type": "bi",
"bi_data": false,
"clamp_len": -1,
"d_head": 64,
"d_inner": 4096,
"d_model": 1024,
"dropatt": 0.1,
"dropout": 0.1,
"ff_activation": "gelu",
"init": "normal",
"init_range": 0.1,
"init_std": 0.02,
"initializer_range": 0.02,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"mem_len": null,
"n_head": 16,
"n_layer": 24,
"n_token": 32000,
"reuse_len": null,
"same_length": false,
"untie_r": true
}
Since they are all multiples of 8, I don't think the "multiples of 8" condition is the problem.
Do you mean by "In fact, when I tried (i, j) matmul (j, k), it can always have a speed boost" that each FP16 matmul in this form will be faster than the corresponding FP32 matmul, regardless of the input shapes?
Could you try to add some warmup iterations before the actual timings?
Thanks for the information about xlnet. We'll look into it.
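To illustrate the warmup suggestion (a hedged sketch reusing the einsum shapes from the benchmark below; the first CUDA calls carry one-time launch and allocation overhead, so they are excluded from the timed loop):
import time
import torch

a = torch.randn(24, 32, 40, 48, device='cuda', dtype=torch.half)
b = torch.randn(64, 32, 40, 48, device='cuda', dtype=torch.half)

# warmup iterations: not timed, just to pay one-time initialization costs
for _ in range(50):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()

# timed iterations
t0 = time.time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print('{:.3f}us per iteration'.format((time.time() - t0) / 1000 * 1e6))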
@ptrblck For the second question:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
from time import time
# 1) fp32
a = torch.empty(24,32,40,48, dtype=torch.float32).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float32).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float32).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float32).to('cuda')
torch.cuda.synchronize()
st = time()
for _ in range(1000):
    c.matmul(d)
torch.cuda.synchronize()
print(time()-st)
torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)
# 2) fp16
a = torch.empty(24,32,40,48, dtype=torch.float16).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float16).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float16).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float16).to('cuda')
torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.matmul(c, d)
torch.cuda.synchronize()
print(time()-st)
torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)
my result is:
Here are my results for your calculations on a TITAN V:
@ptrblck Thanks for your reply.
Closed in favor of pytorch/pytorch#23061, as this does not seem to be amp-specific.
I am trying to fine-tune XLNet and found that memory usage was halved, but it was slower than FP32 (even when I doubled the batch size).
Environment: V100, CUDA 10.0, torch 1.1.
The environment is fine, because I tried BERT + FP16 and it was much faster than FP32.
I think the problem is in torch.einsum, but I am not sure.