[BACKEND] Propagate mma layout to following elementwise operations. #3973

Merged
merged 3 commits into triton-lang:main on Oct 22, 2024

Conversation

htyu
Collaborator

@htyu htyu commented May 22, 2024

For a matmul followed by arithmetic operations, such as acc += tl.dot(a, b), the mma layout of the dot result currently isn't propagated into the subsequent add. As a result, when the dot is inside a loop, there is a repeated layout conversion from mma to blocked on every iteration. I'm fixing this by allowing the mma layout to be propagated so that it can be reused.
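
To make the pattern concrete, here is a minimal single-tile Triton sketch of the kind of kernel being discussed; the kernel name, shapes, and strides are illustrative assumptions, not taken from this PR:

    import triton
    import triton.language as tl

    # Minimal single-tile matmul sketch (illustrative names/shapes, not from this PR).
    # The accumulator produced by tl.dot carries an mma layout; without layout
    # propagation, the `acc += ...` add below forces a convert_layout
    # (mma -> blocked) on every loop iteration.
    @triton.jit
    def matmul_kernel(a_ptr, b_ptr, c_ptr, K,
                      stride_am, stride_bk, stride_cm,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :]
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :]
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            acc += tl.dot(a, b)          # the add that should stay in mma layout
            a_ptrs += BLOCK_K
            b_ptrs += BLOCK_K * stride_bk
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :]
        tl.store(c_ptrs, acc)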

@htyu htyu requested a review from ptillet as a code owner May 22, 2024 21:28
@htyu htyu requested review from manman-ren and removed request for ptillet May 22, 2024 21:28
@htyu
Collaborator Author

htyu commented May 22, 2024

This mitigates performance issues seen with fp8_fast_accum=False in #2513.

@ThomasRaoux
Collaborator

This seems very ad hoc. The proper fix would be to remove the workaround altogether. We had planned for it but never got to it.

@htyu
Collaborator Author

htyu commented May 22, 2024

> This seems very ad hoc. The proper fix would be to remove the workaround altogether. We had planned for it but never got to it.

You mean to remove the workaround for reductionOp?

@ThomasRaoux
Collaborator

> You mean to remove the workaround for reductionOp?

Yes, basically remove hasConvertToMMATransisitiveUse, ideally.

@htyu
Collaborator Author

htyu commented May 22, 2024

> Yes, basically remove hasConvertToMMATransisitiveUse, ideally.

And reductionOp should be handled in backwards rematerialization?

@ThomasRaoux
Collaborator

> And reductionOp should be handled in backwards rematerialization?

Yes, if they still cause perf problems we should handle them separately.

@htyu
Collaborator Author

htyu commented May 22, 2024

> Yes, if they still cause perf problems we should handle them separately.

Sounds good. I'll work on that.

BTW, does test/TritonGPU/combine.mlir have the cases where you saw perf issue?

@ThomasRaoux
Collaborator

> BTW, does test/TritonGPU/combine.mlir have the cases where you saw the perf issue?

I don't think so

@htyu
Collaborator Author

htyu commented May 22, 2024

> I don't think so

It'd be handy if you could share a case so that I can make sure I'm not regressing it.

@ThomasRaoux
Collaborator

> It'd be handy if you could share a case so that I can make sure I'm not regressing it.

It was done a while ago and I'm not sure this still exists. I think if we can remove this workaround in terms of functionality we should be good to go, and we can fix up performance problems afterwards.

@htyu
Collaborator Author

htyu commented May 24, 2024

I'm making a test case to drive the implementation and I'd like to check with you if this is what we want to shoot for. Here is what I made based on my understanding of the existing code:

    #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [32, 8]}>
    #C2 = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
    %c = tt.dot %a, %b, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
    %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #C2>
    %2 = "tt.reduce" (%t) ({
    ^bb0(%arg1: f32, %arg2: f32):
      %add = arith.addf %arg1, %arg2 : f32
      tt.reduce.return %add : f32
    }) {axis = 0 : i32} : (tensor<128x128xf32, #C2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #C2}>>

Previously #C could not be propagated into tt.reduce. I'm enabling that now, and I'm getting:

    %c = tt.dot %a, %b, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
    %10 = "tt.reduce"(%c) <{axis = 0 : i32}> ({
    ^bb0(%arg12: f32, %arg13: f32):
      %15 = arith.addf %arg12, %arg13 : f32
      tt.reduce.return %15 : f32
    }) : (tensor<128x128xf32, #C>) -> tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #C}>>
    %11 = triton_gpu.convert_layout %10 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #C}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #C2}>>

and I'm working on reverting this by reverse-rematerializing the tt.reduce back to the #C2 encoding.

Does this sound right to you?

@ThomasRaoux
Collaborator

> Does this sound right to you?

I don't think you want to concentrate on that part at the moment. We should measure and figure out what to do about the layout of reductions, but only once we have better data. I think at the moment the workaround is still there mostly because of functional problems when removing it.

@htyu
Collaborator Author

htyu commented May 24, 2024

> I don't think you want to concentrate on that part at the moment. We should measure and figure out what to do about the layout of reductions, but only once we have better data. I think at the moment the workaround is still there mostly because of functional problems when removing it.

I thought the tt.reduce issue was the functional problem. Actually, the only Python test failure I'm seeing is related to tt.reduce, because of a codegen difference (assert "bar.sync" not in red_code). I'm not sure how to expose other functional issues.

@ThomasRaoux
Collaborator

> I thought the tt.reduce issue was the functional problem. Actually, the only Python test failure I'm seeing is related to tt.reduce, because of a codegen difference (assert "bar.sync" not in red_code). I'm not sure how to expose other functional issues.

Sorry, I didn't express it correctly. What I meant is that this works around some functional issues that are also related to reduce. For those we should fix the root cause. There may be performance problems, but I don't have data on those, and we should revisit based on what we see on important workloads.

@htyu
Collaborator Author

htyu commented May 24, 2024

So I took a look at the failing Python tests. The codegen of a reduceOp on an mma layout is indeed inefficient, since it triggers storeWarpReduceToSharedMemory. But otherwise there seems to be no correctness issue, as the results are as expected.

On second thought, I was wondering how beneficial it is to always propagate mma layouts. Reduction on an mma layout isn't efficient, and pointwise operations on mma aren't any better. It seems to me that propagation makes sense only if the mma layout is used again later (transitively, as defined currently).

@ThomasRaoux
Collaborator

> On second thought, I was wondering how beneficial it is to always propagate mma layouts. Reduction on an mma layout isn't efficient.

I don't think that's true. A reduction is less efficient only if the number of warps per CTA along the reduced dimension is > 1. Like any other layout, we want to propagate to avoid convert_layout ops; if we don't propagate, we will have more converts.

> It seems to me that propagation makes sense only if the mma layout is used again later (transitively, as defined currently).

I don't think I agree, but just so that I understand: are you suggesting we leave the code as is, then? I thought the point of the PR was that you were seeing improvements by propagating?

@htyu
Collaborator Author

htyu commented May 24, 2024

> Like any other layout, we want to propagate to avoid convert_layout ops; if we don't propagate, we will have more converts.

If there is no other use of the mma layout later, there will eventually be a convert from mma to another layout that cannot be eliminated, e.g. for tt.store. If there is another use of the mma layout, then it is transitive, and the current logic can handle that.

> I thought the point of the PR was that you were seeing improvements by propagating?

Yes, I was seeing a transitive use of mma that's not captured by the existing logic. It is an add on the dot output, in each loop iteration. Without the loop, I'm not sure propagating #mma to the add below is very beneficial; the convert is still needed by the store, it's just postponed.

c = dot(a,b), #mma
cblocked = convert c to #blocked
d = d + cblocked 
store d

@ThomasRaoux
Collaborator

> If there is no other use of the mma layout later, there will eventually be a convert from mma to another layout that cannot be eliminated, e.g. for tt.store. If there is another use of the mma layout, then it is transitive, and the current logic can handle that.

This would still block the forward propagation, as we wouldn't pick a layout for the transitive users of the dot op. The algorithm kind of relies on pushing the converts down to combine them.

> Yes, I was seeing a transitive use of mma that's not captured by the existing logic. It is an add on the dot output, in each loop iteration. Without the loop, I'm not sure propagating #mma to the add below is very beneficial; the convert is still needed by the store, it's just postponed.

How is this a transitive use of mma? c doesn't have any transitive uses that are dots, right?

@htyu
Collaborator Author

htyu commented May 24, 2024

> This would still block the forward propagation, as we wouldn't pick a layout for the transitive users of the dot op. The algorithm kind of relies on pushing the converts down to combine them.

It is blocked only by a reduceOp that has a small mma shape? Without such a reduceOp, the transitive users will be detected and the propagation will reach them.

> How is this a transitive use of mma? c doesn't have any transitive uses that are dots, right?

If it is in a for loop, the add is kind of a transitive use of c. Note that c is propagated to the add twice, once through cblocked and once through the yield. The logic in this change detects that.

for (d_init = 0)  
   c = dot(a,b), #mma
   cblocked = convert c to #blocked
   d = d_init + cblocked 
   yield d
store d

and with this patch it becomes

for (d_init = 0)  
   c = dot(a,b), #mma
   d = d_init + c 
   yield d

dblocked = convert d to #blocked
store dblocked

@ThomasRaoux
Collaborator

> It is blocked only by a reduceOp that has a small mma shape? Without such a reduceOp, the transitive users will be detected and the propagation will reach them.

Ah, I see. I think this is a bit arbitrary, to be honest :)

> If it is in a for loop, the add is kind of a transitive use of c. Note that c is propagated to the add twice, once through cblocked and once through the yield. The logic in this change detects that.

I see. Well, I'm also not convinced by the rationale.

I think a better way would be to always propagate, and for reductions we want to optimize the layout so that it doesn't cross the warp boundary along the reduced dimension. This is actually quite orthogonal to mma vs. non-mma layouts.

@htyu
Collaborator Author

htyu commented May 24, 2024

> for reductions we want to optimize the layout so that it doesn't cross the warp boundary along the reduced dimension.

Once a reduction gets an mma layout, how should we optimize it? Should we stop the propagation when such a reduction is seen?

I still have the earlier question: when there is no further use of the mma layout after the reduction, is it still worth propagating the mma over?

@ThomasRaoux
Collaborator

ThomasRaoux commented May 24, 2024

As I mentioned, having a reduction with an mma layout is fine; what matters is really the warps per CTA.

I guess I don't understand what you mean by "use of the mma layout" in this case. In the example you added, the use of the mma layout is just an elementwise op using the mma layout. After a reduction, the reduced tensor should have uses, otherwise the reduction would be dead.

@htyu
Collaborator Author

htyu commented May 24, 2024

> As I mentioned, having a reduction with an mma layout is fine; what matters is really the warps per CTA.

I see. If the mma layout has more warps along the reduction axis, how do we want to optimize it? Do we want to generate a new mma layout in the first place?

> I guess I don't understand what you mean by "use of the mma layout" in this case. In the example you added, the use of the mma layout is just an elementwise op using the mma layout. After a reduction, the reduced tensor should have uses, otherwise the reduction would be dead.

The result of the reduction is stored back to memory in a blocked layout, so the convert is not going away even if we propagate mma into the reduction. If the reduction output were ever broadcast and fed into another dot, I could see the benefit of leaving it in the mma layout.

@manman-ren
Collaborator

If we remove the workaround for hasConvertToMMATransisitiveUse, i.e. let the layout of an anchor op propagate to its result, it will fix the case that you are looking at, right? @htyu

The question is more about the performance of reductions if we remove hasConvertToMMATransisitiveUse. Can we make sure we get good performance for general cases with reductions? From Thomas' reply, it sounds like this is orthogonal to mma vs. non-mma layouts, and we can fix the performance issue by making sure the layout doesn't cross the warp boundary. It is not clear to me which place is the right one to fix this, though. There are related PRs:
#3694, which looks through if to fix sub-optimal code in some cases
#3768, which is a more restricted form of #3694

But we can remove the workaround for hasConvertToMMATransisitiveUse, fix functionality issues, then fix performance issues?

@htyu
Collaborator Author

htyu commented May 28, 2024

> If we remove the workaround for hasConvertToMMATransisitiveUse, i.e. let the layout of an anchor op propagate to its result, it will fix the case that you are looking at, right? @htyu

Yes, it fixes my case, but causes regressions in other cases.

> The question is more about the performance of reductions if we remove hasConvertToMMATransisitiveUse. Can we make sure we get good performance for general cases with reductions? From Thomas' reply, it sounds like this is orthogonal to mma vs. non-mma layouts, and we can fix the performance issue by making sure the layout doesn't cross the warp boundary. It is not clear to me which place is the right one to fix this, though.

Yeah, I'm not quite sure where the fix for avoiding reductions across warps should go. And my other question is whether it is always beneficial to propagate mma layouts if the result is going to be converted back to another layout anyway and there is no other use of the mma layout in between.

@ThomasRaoux
Collaborator

> And my other question is whether it is always beneficial to propagate mma layouts if the result is going to be converted back to another layout anyway and there is no other use of the mma layout in between.

I think with our current code this is the best heuristic.

> Yeah, I'm not quite sure where the fix for avoiding reductions across warps should go.

Can you share examples of the cases that regress performance? I need to think a bit more about the right phase ordering for this.

@manman-ren
Collaborator

> Yes, it fixes my case, but causes regressions in other cases.

Which other cases? I assume performance regressions?

> And my other question is whether it is always beneficial to propagate mma layouts if the result is going to be converted back to another layout anyway and there is no other use of the mma layout in between.

Will the same question apply to non-mma layouts?

> The result of the reduction is stored back to memory in a blocked layout, so the convert is not going away even if we propagate mma into the reduction.

Is this about the code snippet above?
c = dot(a,b), #mma
cblocked = convert c to #blocked
d = d + cblocked
store d

@htyu
Collaborator Author

htyu commented May 29, 2024

> Can you share examples of the cases that regress performance? I need to think a bit more about the right phase ordering for this.

Sure. It is one Python test: python/test/unit/language/test_core.py -k test_dot[1-64-64-64-4-False-False-softmax-ieee-float16-float32-1]. The softmax operation leads to a reduction op. The mma layout is:

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

and propagating it to the reduce will trigger storeWarpReduceToSharedMemory due to warpsPerCTA, as compared to the original blocked layout:

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>

After the reduce is done, the result will be converted to another blocked layout for storing.

TTGIR:

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

    %30 = tt.dot %28, %29, %27 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> loc(#loc15)
    %31 = triton_gpu.convert_layout %30 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> loc(#loc15)
    %32 = "tt.reduce"(%31) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32 loc(unknown), %arg9: f32 loc(unknown)):
      %43 = arith.maxnumf %arg8, %arg9 : f32 loc(#loc33)
      tt.reduce.return %43 : f32 loc(#loc29)
    }) : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc29)
    %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> loc(#loc19)
    %34 = tt.broadcast %33 : tensor<32x1xf32, #blocked> -> tensor<32x32xf32, #blocked> loc(#loc20)
    %35 = arith.subf %31, %34 : tensor<32x32xf32, #blocked> loc(#loc20)
    %36 = math.exp %35 : tensor<32x32xf32, #blocked> loc(#loc21)
    %37 = "tt.reduce"(%36) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32 loc(unknown), %arg9: f32 loc(unknown)):
      %43 = arith.addf %arg8, %arg9 : f32 loc(#loc34)
      tt.reduce.return %43 : f32 loc(#loc31)
    }) : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc31)
    %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> loc(#loc25)
    %39 = tt.broadcast %38 : tensor<32x1xf32, #blocked> -> tensor<32x32xf32, #blocked> loc(#loc26)
    %40 = arith.divf %36, %39 : tensor<32x32xf32, #blocked> loc(#loc26)
    %41 = arith.truncf %40 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> loc(#loc27)
    %42 = triton_gpu.convert_layout %41 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #blocked1> loc(#loc27)
    tt.store %22, %42 : tensor<32x32x!tt.ptr<f16>, #blocked1> loc(#loc27)

Since the last convert_layout cannot be eliminated, I was wondering how beneficial the propagation really is.
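
For reference, here is a minimal single-tile Triton sketch of the dot-plus-softmax pattern that this kind of test exercises; the kernel name, shapes, and addressing are illustrative assumptions, not the actual test kernel:

    import triton
    import triton.language as tl

    # Illustrative only: a dot followed by a row-softmax epilogue, which is the
    # shape of kernel that yields a tt.dot feeding tt.reduce (max, then sum) as
    # in the TTGIR above.
    @triton.jit
    def dot_softmax_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs[:, None] * BLOCK + offs[None, :])
        y = tl.load(y_ptr + offs[:, None] * BLOCK + offs[None, :])
        z = tl.dot(x, y)                      # result carries an mma layout
        row_max = tl.max(z, axis=1)           # first tt.reduce (arith.maxnumf)
        num = tl.exp(z - row_max[:, None])
        den = tl.sum(num, axis=1)             # second tt.reduce (arith.addf)
        out = (num / den[:, None]).to(tl.float16)
        tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], out)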

@ThomasRaoux
Collaborator

> Since the last convert_layout cannot be eliminated, I was wondering how beneficial the propagation really is.

But this is not really a meaningful real-life scenario. Have you seen actual performance regressions on benchmarks?

@htyu
Collaborator Author

htyu commented May 29, 2024

> But this is not really a meaningful real-life scenario. Have you seen actual performance regressions on benchmarks?

Ah, I haven't done extensive perf testing yet. I will kick off a PyTorch nightly perf job.

So you think we can just remove hasConvertToMMATransisitiveUse if perf testing comes back good, assuming no other correctness issues?

@ThomasRaoux
Collaborator

> So you think we can just remove hasConvertToMMATransisitiveUse if perf testing comes back good, assuming no other correctness issues?

Yes.

@htyu
Collaborator Author

htyu commented May 30, 2024

OK, so I did hit a correctness issue when running against a variant of a flash attention kernel. The symptom is:

error: size mismatch when packing elements for LLVM struct expected 32 but got 64

when lowering a convert_layout from #mma to #dot_op whose parent #mma layout has a different instrShape.

    #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}>
    #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 32]}>
    %134 = triton_gpu.convert_layout %133 {allocation.offset = 28672 : i32} : tensor<256x32xf8E5M2, #mma> -> tensor<256x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>

The error basically says the producer (the output of the last dot) and the consumer (the first operand of the next dot) expect a different number of values per thread, even though the tile shape is the same. The number of values per thread is computed based on instrShape.

It looks to me that the layout tensor<256x32xf8E5M2, #mma> is invalid, with instrShape = [16, 64, 32]. Am I correct? What does it mean when the mma instruction size is bigger than the tensor size?

Such a layout originates from propagating mma over a reduce op.

    ^bb0(%arg26: f32, %arg27: f32):
      %121 = arith.addf %arg26, %arg27 : f32
      tt.reduce.return %121 : f32
    }) : (tensor<256x32xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %104 = tt.expand_dims %101 {axis = 1 : i32} : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xf32, #mma>

The expand_dims results in a tensor layout tensor<256x1xf32, #mma>, which seems invalid too. Do you think that's the source of the error?

If we allow an instrShape bigger than the tensor size (not very intuitive, though), we probably need to adjust this logic to compute the correct number of elements per thread:

    elemsPerThread[1] = (instrMNK[1] / 4) * repN;
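
To make the mismatch concrete, here is a back-of-the-envelope sketch of the per-thread element count along N, assuming repN is derived roughly as ceil(N / (instrShape[1] * warpsPerCTA[1])) and clamped to at least 1; that derivation is my assumption, and the actual computation in the C++ backend may differ in detail:

    import math

    def elems_per_thread_n(shape_n, instr_n, warps_n):
        # Assumed rep count along N, clamped to at least 1.
        rep_n = max(1, math.ceil(shape_n / (instr_n * warps_n)))
        # Mirrors the quoted line: elemsPerThread[1] = (instrMNK[1] / 4) * repN
        return (instr_n // 4) * rep_n

    # tensor<256x32x...> with warpsPerCTA = [8, 1]:
    print(elems_per_thread_n(32, 64, 1))  # instrShape = [16, 64, 32] -> 16
    print(elems_per_thread_n(32, 32, 1))  # instrShape = [16, 32, 32] -> 8

Under those assumptions the two encodings disagree by a factor of two along N, which is consistent with the factor of two in the "expected 32 but got 64" error.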

@ThomasRaoux
Collaborator

Closing this for now. I think the best solution is to remove the hasConvertToMMATransisitiveUse workaround once the linear layout refactor is done.

@htyu
Collaborator Author

htyu commented Oct 7, 2024

> Closing this for now. I think the best solution is to remove the hasConvertToMMATransisitiveUse workaround once the linear layout refactor is done.

Looks like the linear layout issue has been fixed. I'm updating the PR to remove the hasConvertToMMATransisitiveUse workaround.

@htyu htyu requested a review from ThomasRaoux October 7, 2024 15:48
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Looks good, thanks! Give me a chance to run a few things internally; I'll approve later today.

@ThomasRaoux
Collaborator

Looks like there are still some bugs; I need to debug those. Sorry about that. I'll try to look at it at the end of the week.

@htyu
Collaborator Author

htyu commented Oct 9, 2024

> Looks like there are still some bugs; I need to debug those. Sorry about that. I'll try to look at it at the end of the week.

Thanks for the update. Are they correctness issues or perf issues?

@ThomasRaoux
Collaborator

> Thanks for the update. Are they correctness issues or perf issues?

Correctness.

Collaborator

@ThomasRaoux ThomasRaoux left a comment

Thanks, the functionality problem was actually just a precision difference due to different reduce ordering.

@ThomasRaoux ThomasRaoux merged commit 1064b59 into triton-lang:main Oct 22, 2024
7 checks passed