Use FusedMatMul When Transpose is Between First Dim and Contiguous Batch Dims #9734
Conversation
This is a nice change!! Some side notes here FYI: I found that APEX and other libs I investigated last week also do this trick. It is applied to models whose self-attention input has shape [seq, batch, num_head, head_dim]. For the BERT-large case we would remove at least two Transposes plus a scaling multiply (sqrt(num_head)). @iK1D @SherlockNoMad
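To make the layout concrete, here is a small numpy sketch (assumed shapes for illustration, not code from the PR or APEX) showing how computing attention scores from a [seq, batch, num_head, head_dim] input produces exactly the two Transpose perms, [1,2,0,3] and [1,2,3,0], that this PR fuses:

```python
import numpy as np

# Illustration: self-attention input laid out as [seq, batch, num_head, head_dim].
S, B, H, D = 7, 2, 4, 8
q = np.random.rand(S, B, H, D)
k = np.random.rand(S, B, H, D)

# scores[b, h] = Q_bh @ K_bh^T, obtained via perm [1,2,0,3] on Q
# and perm [1,2,3,0] on K -- the patterns this PR fuses into FusedMatMul.
scores = np.transpose(q, (1, 2, 0, 3)) @ np.transpose(k, (1, 2, 3, 0))
print(scores.shape)  # (2, 4, 7, 7)
```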
static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
// is_trans is whether to transpose the 2 dims used to MatMul.
// is_trans_batch is whether to transpose 1st dim and batch dims (dim-1 to dim-rank-2).
// For example:
It would be nice if we could give a more descriptive comment covering exactly which cases we target to fuse. An example FYI:
/* Here we check input and mask dimensions are as expected:
and we need a definition for the 'batch' here
I think it is better to use a different word than "batch" because "batch" is used with respect to the training batch. Maybe something like "range" would be okay.
Maybe "circular permutation" is clearer: 1->0, 2->1, ..., r->r-1, 0->r.
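As a sketch of this suggestion (a hypothetical helper, not code in the PR): the circular permutation that shifts every axis one position left and wraps axis 0 to the end corresponds to a Transpose `perm` attribute of [1, 2, ..., r-1, 0]:

```python
import numpy as np

def circular_perm(rank):
    # Hypothetical helper: output axis i takes input axis perm[i],
    # so "1->0, 2->1, ..., 0->last" reads as perm = [1, 2, ..., rank-1, 0].
    return list(range(1, rank)) + [0]

print(circular_perm(4))  # [1, 2, 3, 0]

x = np.arange(24).reshape(2, 3, 4)           # rank-3 tensor
y = np.transpose(x, circular_perm(3))
print(y.shape)                               # (3, 4, 2)
```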
CUDA's APIs (GemmBatched, GemmStridedBatched) use the same name, and our MatMul code also calls them batches. I think we should still call it "batch" here, but add more comments to explain.
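The cuBLAS naming can be illustrated with a small numpy emulation — a sketch of GemmStridedBatched's addressing scheme (assuming row-major matrices and element rather than byte strides), not ORT's or cuBLAS's actual code. The key point is that each "batch" matrix is located purely by a constant stride from the previous one:

```python
import numpy as np

def gemm_strided_batched(A, B, batch, m, k, n, stride_a, stride_b):
    # Emulation of GemmStridedBatched semantics: the i-th GEMM reads A at
    # flat offset i*stride_a and B at flat offset i*stride_b, so batch i
    # is addressed by a constant stride -- no per-batch pointer list needed.
    A_flat, B_flat = A.ravel(), B.ravel()
    out = np.empty((batch, m, n), dtype=A.dtype)
    for i in range(batch):
        a = A_flat[i * stride_a : i * stride_a + m * k].reshape(m, k)
        b = B_flat[i * stride_b : i * stride_b + k * n].reshape(k, n)
        out[i] = a @ b
    return out

A = np.random.rand(4, 2, 3)   # batch=4, m=2, k=3
B = np.random.rand(4, 3, 5)   # batch=4, k=3, n=5
emu = gemm_strided_batched(A, B, 4, 2, 3, 5, stride_a=6, stride_b=15)
assert np.allclose(A @ B, emu)
```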
}

if (!is_trans_on_last_two_dims) {
  return nullptr;
// Transpose node can be fused to MatMul when the batch dimensions have same order before and after transpose. |
nit: change to "the batch dims keep the same relative order before and after transpose"?
Introducing the notion of "circular permutation" is really helpful to understand the code here.
// is_trans is whether to transpose the 2 dims used to MatMul.
// is_trans_batch is whether to transpose 1st dim and batch dims (dim-1 to dim-rank-2).
// For example:
// is_trans=False, is_trans_batch=False: [0,1,2,3]
We should not do the fusion for the [0,1,2,3,...] case, right?
left_ld_factor_ = right_ld_factor_ = 1;

if (trans_batch_a || trans_batch_b) {
  ORT_ENFORCE(left_num_dims > 2 && left_num_dims == right_num_dims, "Two input should have same rank and rank >= 3 if transBatchA or transBatchB is true");
Suggested change ("Two input" -> "Two inputs"):
ORT_ENFORCE(left_num_dims > 2 && left_num_dims == right_num_dims, "Two inputs should have same rank and rank >= 3 if transBatchA or transBatchB is true");
The change looks great overall! There are a few things I need your help to confirm:
I didn't check the big graph carefully, but from the numbers, yes, it's 4 for each layer. From the code, the fusion is added to both the training and inference transformer lists, so ideally backward is also covered. But we build the gradient graph after the training transformers and use FusedMatMul instead of MatMul in the backward graph, so I think it's rare to have such a fusable case in backward.
This is good.
I don't quite get the idea. Could you please give an example? I.e., what's the 'perm' attribute of the Transpose nodes?
This change supports [1,2,0,3] and [1,2,3,0]. In the future, not in this PR, could we also consider permutations like [2,0,1,3] or [2,3,0,1]?
I have the comment below in the code to explain which cases we can fuse. For [2,0,1,3] or [2,3,0,1], it's not possible to get the strideA, strideB, lda, ldb parameters for GemmStridedBatched, so we cannot fuse such cases.
// Transpose node can be fused to MatMul when the batch dims keep same relative orders before and after transpose.
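This constraint can be checked numerically. Below is an illustrative sketch (not the PR's code): for a C-contiguous tensor, a transposed perm is expressible with GemmStridedBatched only if the flat offsets of the batch matrices form an arithmetic sequence (a constant strideA). That holds for [1,2,0,3] but not for [2,0,1,3]:

```python
import numpy as np

def batch_offsets(shape, perm):
    # For a C-contiguous tensor of `shape`, transpose by `perm`, treat the
    # last two output dims as the GEMM dims, and return the flat offset of
    # element [b..., 0, 0] for every batch index b.
    x = np.arange(np.prod(shape)).reshape(shape)  # each value == its flat offset
    y = np.transpose(x, perm)
    batch = int(np.prod(y.shape[:-2]))
    return y[..., 0, 0].reshape(batch)

def is_strided_batch(offsets):
    # GemmStridedBatched needs offset[i] == i * strideA for a constant strideA.
    diffs = np.diff(offsets)
    return len(diffs) == 0 or bool(np.all(diffs == diffs[0]))

shape = (2, 3, 4, 5)
print(is_strided_batch(batch_offsets(shape, (1, 2, 0, 3))))  # True: fusable
print(is_strided_batch(batch_offsets(shape, (2, 0, 1, 3))))  # False: no constant strideA
```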
Sorry for the late response! LGTM!! :)
The current FusedMatMul supports Transpose only on the last 2 dims. When the 2-D arrays for MatMul are formed from the 1st and last dims, and the batch dims are contiguous in the original tensor, we can also use GemmStridedBatched to compute the result without doing the Transpose. The perm pattern of such a Transpose is [1,2,0,3] or [1,2,3,0]. This PR supports these cases using FusedMatMul.
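The idea can be verified with a numpy sketch (an emulation under assumed shapes, not the ORT implementation): for perm [1,2,0,3], every batch matrix of the transposed input is a constant-stride view of the original buffer, so a strided-batched GEMM over the untransposed data reproduces Transpose + MatMul exactly:

```python
import numpy as np

# A has shape [d0, B1, B2, K]; Transpose perm [1, 2, 0, 3] makes (B1, B2)
# the batch dims and (d0, K) the GEMM dims. In the original buffer each
# batch matrix starts K elements after the previous one (strideA = K) and
# has row stride lda = B1*B2*K, so no transpose needs to be materialized.
rng = np.random.default_rng(0)
d0, B1, B2, K, N = 4, 2, 3, 5, 6
A = rng.standard_normal((d0, B1, B2, K))
B = rng.standard_normal((B1, B2, K, N))

reference = np.transpose(A, (1, 2, 0, 3)) @ B  # Transpose + MatMul

A_flat = A.ravel()
out = np.empty((B1 * B2, d0, N))
lda = B1 * B2 * K   # row stride of each batch matrix (in elements)
strideA = K         # stride between consecutive batch matrices (in elements)
for b in range(B1 * B2):
    a = np.lib.stride_tricks.as_strided(
        A_flat[b * strideA:], shape=(d0, K),
        strides=(lda * A.itemsize, A.itemsize))
    out[b] = a @ B.reshape(B1 * B2, K, N)[b]

assert np.allclose(reference, out.reshape(B1, B2, d0, N))
```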
For a perf comparison using a module with Add + EinSum("ks,ksm->sm") + MSELoss, with K = 16, S = 7840, M = 2048: before the changes it's ~7 ms per step; after the changes it's ~4.5 ms per step, which is similar perf to PyTorch.
Using ULR-XL (16 layers) for a perf test: before the changes, the execution graph has 195 Transpose nodes, 16 MatMul nodes and 306 FusedMatMul nodes. After the changes, the numbers are 131 Transpose nodes and 322 FusedMatMul nodes. From nvvp profiling, the execution time per step reduces from ~913 ms to ~882 ms, a ~4% gain. The gain comes from the reduced Transpose compute, and the newly fused FusedMatMul nodes use GemmStridedBatched, which has perf comparable to the original MatMul nodes.