[BACKEND] Propagate mma layout to following elementwise operations. #3973
Conversation
Mitigating performance issues seen with fp8_fast_accum=False in #2513.
This seems very ad hoc. The proper fix would be to remove the workaround altogether. We had planned for it but never got to it.
You mean to remove the workaround for reductionOp?
Yes, basically remove it.
And reductionOp should be handled in backwards rematerialization?
Yes, if they still cause perf problems we should handle them separately.
Sounds good. I'll work on that. BTW, does ...?
I don't think so.
It'd be handy if you could share a case so that I can make sure I'm not regressing it.
It was done a while ago and I'm not sure this still exists. I think if we can remove this workaround without breaking functionality we should be good to go, and we can fix up the performance problems afterwards.
I'm making a test case to drive the implementation and I'd like to check with you if this is what we want to shoot for. Here is what I made based on my understanding of the existing code:
Previously ..., and I'm working on reverting this by reverse-rematerializing the ... Does this sound right to you?
I don't think you want to concentrate on that part at the moment. We should measure and figure out what to do about the layout of reductions, but that should happen once we have better data. I think at the moment the workaround is still there mostly because of functional problems when removing it.
I thought the tt.reduce issue was the functional problem. Actually, the only python test failure I'm seeing is related to tt.reduce, because of a codegen difference (triton/python/test/unit/language/test_core.py, line 3184 at 74ad278).
Sorry, I didn't express it correctly. What I meant is that this works around some functional issues that are also related to reduce. For those we should fix the root cause. There may be performance problems, but I don't have data about those and we should revisit based on what we see on important workloads.
So I took a look at the failing python tests. The codegen of reduceOp on mma layout is indeed inefficient as it triggers ... On second thought, I was wondering how beneficial it is to always propagate mma layouts. Reduction on mma layout isn't efficient. Pointwise operations on mma aren't better either. It seems to me that propagation makes sense only if the mma layout is used again later (transitively, as defined currently).
I don't think that's true. Reduction is less efficient only if the number of warps per CTA along the reduced dimension is > 1. Like any other layout, we want to propagate it to avoid layout conversions. If we don't propagate, we will have more converts.
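For illustration only (not code from this thread), here is a minimal Triton sketch of the shape being discussed: a dot whose result is reduced along one axis and then stored. Whether the reduction is expensive under an mma layout depends on how many warps the compiler places along the reduced axis (warpsPerCTA), and the final store still goes through a blocked layout. All names and tile sizes are made up.

```python
import triton
import triton.language as tl


@triton.jit
def dot_then_rowsum(a_ptr, b_ptr, out_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One program handles the whole M x N tile; sizes are kept tiny for clarity.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])

    c = tl.dot(a, b)          # the dot result is the tensor that gets an mma layout
    row = tl.sum(c, axis=1)   # reduction over axis 1: cheap if warpsPerCTA along
                              # that axis is 1, otherwise it crosses warp boundaries
    tl.store(out_ptr + offs_m, row)  # stored via a blocked layout, so one convert
                                     # remains even if the mma layout is propagated
```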
I don't think I agree, but just so that I understand, are you suggesting we leave the code as is then? I thought the point of the PR was that you were seeing improvements by propagating?
If there is no other use of the mma layout later, there will be an eventual convert from mma to another layout that cannot be removed, e.g. for tt.store. If there is another use of the mma, then it is transitive and the current logic can already handle that.
Yes, I was seeing a transitive use of mma that's not captured by the existing logic. It is about an ...
This would still block the forward propagation as we wouldn't pick a layout for the transitive users of the dot op. The algorithm kind of relies on pushing the converts down to combine them.
How is this a transitive use of mma?
It is blocked only by a reduceOp that has a small mma shape. Without such a reduceOp, the transitive users would be detected and the propagation would reach them.
If in a ...
and with this patch it becomes ...
Ah I see, I think this is a bit arbitrary to be honest :)
I see, well I'm also not convinced about the rationale. I think a better way would be to always propagate, and for reductions we want to optimize the layout to not cross the warp boundary along a given dimension. This is actually quite orthogonal to mma vs non-mma layout.
Once a reduction gets an mma layout, how should we optimize it? Should we stop the propagation when such a reduction is seen? I still have the question: when there is no further use of the mma layout after the reduction, is it still worth propagating the mma over it?
As I mentioned, having a reduction with mma layout is fine; what matters is really the warps per CTA. I guess I don't understand what you mean by "use of mma layout" in this case. In the example you add, the use of the mma layout is just an elementwise op using the mma layout. After a reduction the reduced tensor should have uses, otherwise the reduction would be dead.
I see. If the mma layout has more warps along the reduction axis, how do we want to optimize it? Do we want to generate a new mma layout in the first place?
The result of the reduction is stored back to memory, in blocked layout. So the convert is not going away, even if we propagate mma into the reduction. If the reduction output was ever broadcast and fed into another dot, I could see the benefit of leaving it in mma layout. |
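As a sketch of that last case (again, not code from the PR): in an attention-style kernel the row reduction's output is broadcast back over the tile and the result feeds a second dot, so the values around the reduction do have a downstream mma consumer. All names below are hypothetical.

```python
import triton
import triton.language as tl


@triton.jit
def attention_like(q_ptr, k_ptr, v_ptr, out_ptr,
                   M: tl.constexpr, N: tl.constexpr, D: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_d = tl.arange(0, D)
    q = tl.load(q_ptr + offs_m[:, None] * D + offs_d[None, :])
    k = tl.load(k_ptr + offs_d[:, None] * N + offs_n[None, :])
    v = tl.load(v_ptr + offs_n[:, None] * D + offs_d[None, :])

    qk = tl.dot(q, k)                  # first dot -> mma layout
    m = tl.max(qk, axis=1)             # row reduction of the mma-layout tensor
    p = tl.exp(qk - m[:, None])        # the reduced value is broadcast back ...
    acc = tl.dot(p.to(v.dtype), v)     # ... and fed into another dot, so keeping
                                       # the chain in mma layout can pay off
    tl.store(out_ptr + offs_m[:, None] * D + offs_d[None, :], acc)
```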
If we remove the workaround for hasConvertToMMATransisitiveUse, i.e. let the layout of an anchor op propagate to the result, it will fix the case that you are looking at, right? @htyu The question is more around the performance related to reduction if we remove hasConvertToMMATransisitiveUse. Can we make sure we get good performance for general cases with reduction? From Thomas' reply, it sounds like this is orthogonal to mma vs non-mma layout, and we can fix the performance issue by making sure the layout does not cross the warp boundary. It is not clear to me which place is the right place to fix this though. There are related PRs: ... But we can remove the workaround for hasConvertToMMATransisitiveUse, fix functionality issues, then fix performance issues?
Yes, it fixes my case, but causes regressions in other cases.
Yeah, I'm not quite sure where the fix for avoiding reductions across warps should go. And my other question is whether it is always beneficial to propagate mma layouts if the result is going to be converted back to another layout when there is no other use of the mma in between.
I think with our current code this is the best heuristic.
Can you share examples of cases that regress performance? I need to think a bit more about the right phase ordering for this.
Which other cases? I assume performance regression?
Will the same question apply to non-mma layouts?
Is this about the code snippet above?
Sure. It is one python test: ...
After the reduce is done, the result will be converted to another blocked layout for storing. TTGIR: ...
Since the last convert layout cannot be removed, I was wondering how beneficial the propagation is.
But this is not really a meaningful real-life scenario. Have you seen actual performance regressions on benchmarks?
Ah, I haven't done extensive perf testing. I will kick off a pytorch nightly perf job. So you think we can just remove the workaround?
Yes.
OK, so I did hit a correctness issue when running against a variant of a flash attention kernel. The symptom is an error like ...
when lowering a convert layout from #mma to #dot_op whose parent #mma layout has a different instrShape.
The error basically says the producer (the output of the last dot) and the consumer (the first operand of the next dot) expect a different number of values per thread, though the tile shape is the same. The number of values per thread is computed based on ... It looks to me that this layout ... Such a layout originates from propagating mma over a reduce op.
The ... If we allow ... (see triton/lib/Dialect/TritonGPU/IR/Dialect.cpp, line 870 at 74ad278).
Closing this for now, I think the best solution is to remove the workaround.
Looks like the linear layout issue has been fixed. I'm updating the PR to remove the workaround.
Looks good, thanks! Give me a chance to run a few things internally, I'll approve later today.
Looks like there are still some bugs, I need to debug those. Sorry about that. I'll try to look at it at the end of the week.
Thanks for the update. Are they correctness issues or perf issues?
Correctness.
Thanks, the functionality problem was actually just a precision difference due to different reduce ordering.
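For context, a tiny illustration (not from the PR) of why a different reduce order can change the result: floating-point addition is not associative, so reducing the same values in a different order can differ by a small amount.

```python
import numpy as np

# Same three values, two reduction orders, different fp32 results.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0 -- c is absorbed when added to b first
```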
For matmul with following arithmetic operations such as `acc += tl.dot(a, b)`, currently the mma layout of the `dot` result isn't propagated into the subsequent `add`. As a result, when the dot is inside a loop there will be repeated layout conversions from mma to blocked. I'm fixing this by allowing the mma layout to be propagated so that it can be reused.
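A minimal sketch (mine, not taken from the PR) of the pattern described above: a K-loop accumulating with `acc += tl.dot(a, b)`, where the `+=` is an elementwise add consuming the dot result. Without propagation, each iteration converts the dot output from mma to blocked and back; with the mma layout propagated to the add, the accumulator can stay in mma layout across the loop. Pointer names and tile sizes are made up.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_accumulate(a_ptr, b_ptr, c_ptr,
                      M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                      BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((M, N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])
        # The add below is a separate elementwise op on the dot result; the PR
        # lets it inherit the dot's mma layout instead of converting to blocked.
        acc += tl.dot(a, b)

    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```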