[BACKEND] Replace isMmaToDotShortcut with linear layout based logic #4951
base: main
@zhanglx13 and @antiagainst you may want to take a look as well.
@ThomasRaoux feel free to run a regression test on the PR. I don't think there should be any issues since I only changed the register access order, but I just wanted to catch potential problems early.
@@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
ret.push_back(v);
}
}
// FIXME [Dot LL]
Great!
@@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
// For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
if (dot.getOpIdx() == 0) {
si = llvm::SmallVector<unsigned>{0, 8, 4, 12, 1, 9, 5, 13,
                                 2, 10, 6, 14, 3, 11, 7, 15};
// Original register layout:
Thank you for making the comments more explicit!
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 8]);
Off-by-one error: you are accessing i + 8 and not accessing i + 1. Can you write a test that exercises this path?
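For reference, a minimal sketch of the property such a test could check: a register reorder should read every slot exactly once. This is not the test added in the PR. The names isPermutation and checkReorders are hypothetical, and the "flawed" offset list is made up to mirror the comment above (it reads i + 8 and never reads i + 1). The si table is the one quoted in the kWidth = 8 hunk.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns true if `idx` contains every value in [0, idx.size()) exactly once,
// i.e. if it describes a pure reordering with no skipped or repeated slot.
bool isPermutation(const std::vector<unsigned> &idx) {
  std::vector<bool> seen(idx.size(), false);
  for (unsigned v : idx) {
    if (v >= idx.size() || seen[v])
      return false; // out of range, or the same slot read twice
    seen[v] = true;
  }
  return true;
}

// Hypothetical self-check, not part of the Triton test suite.
void checkReorders() {
  // Index table quoted in the kWidth = 8 hunk.
  const std::vector<unsigned> si = {0, 8, 4, 12, 1, 9, 5, 13,
                                    2, 10, 6, 14, 3, 11, 7, 15};
  assert(isPermutation(si));

  // Made-up offsets for a group of 8 values: reads offset 8, skips offset 1.
  const std::vector<unsigned> flawed = {0, 8, 3, 2, 4, 5, 6, 7};
  assert(!isPermutation(flawed));
}
```

A check like this only covers the index bookkeeping; an end-to-end test on real data is still needed to confirm that the chosen order is the right one.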
Yeah. I wonder why no test case catches this problem.
fp16->fp32 should be covered by 9357902 now.
After thinking a bit about it, I think I understand why padding fixes the issues we were seeing when the inputs and outputs have a different number of registers.
The issue stems from the function getInjectiveMat (triton/lib/Tools/LinearLayout.cpp, line 119 in 1064b59):
getInjectiveMat(const LinearLayout &layout) {
This function makes both matrices injective by extending their codomains. This is an issue if the inDims have different dimensions, because the extended codomains will then differ, and the Gaussian elimination only makes sense when both matrices share the same codomain!
The padding patch mitigates this in the cases we found in practice: it so happens that the padding perfectly matches all the free variables of the two matrices, so getInjectiveMat turns this transformation into the identity, which is perfect.
I don't think that this is the correct approach in general, but it's clearly an improvement over the previous state, so approving. I think I have a solution for the general problem, but I'll implement that at a later point.
Also, thank you @Jokeren for finding the adversarial examples and adding tests for them!
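To make the codomain mismatch concrete, here is a toy sketch over GF(2). It is not Triton's getInjectiveMat. It assumes that "extending the codomain" means adding one output bit per free variable, and the names F2Matrix, rankF2, and injectiveCodomainDim are hypothetical. Matrices are limited to 32 columns for simplicity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// A matrix over GF(2): each row is a bitmask, bit c is the entry in column c.
// Columns play the role of input bits, rows the role of output bits.
struct F2Matrix {
  std::vector<uint32_t> rows;
  unsigned numCols; // must be <= 32 in this sketch
};

// Rank over GF(2) via Gaussian elimination on a copy of the rows.
unsigned rankF2(F2Matrix m) {
  unsigned rank = 0;
  for (unsigned c = 0; c < m.numCols && rank < m.rows.size(); ++c) {
    // Look for a row at or below `rank` with a 1 in column c.
    std::size_t p = rank;
    while (p < m.rows.size() && !((m.rows[p] >> c) & 1u))
      ++p;
    if (p == m.rows.size())
      continue; // no pivot: column c is a free variable
    std::swap(m.rows[rank], m.rows[p]);
    // Clear column c in every other row (XOR is addition in GF(2)).
    for (std::size_t r = 0; r < m.rows.size(); ++r)
      if (r != rank && ((m.rows[r] >> c) & 1u))
        m.rows[r] ^= m.rows[rank];
    ++rank;
  }
  return rank;
}

// Codomain dimension after adding one output bit per free variable
// (assumed model of "making the matrix injective").
unsigned injectiveCodomainDim(const F2Matrix &m) {
  unsigned freeVars = m.numCols - rankF2(m);
  return static_cast<unsigned>(m.rows.size()) + freeVars;
}
```

In this toy model, a 2x2 identity has no free variables and keeps a 2-bit codomain, while the same map with a third, all-zero input column gains one free variable and ends up with a 3-bit codomain, so the two extended matrices can no longer be compared directly. Padding the first matrix with a zero column gives it the same free variable, and the codomains agree again.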
This PR fixes the cvtReordersRegisters method, which previously could not return true for two layouts with different numbers of registers. With this update, we can remove the legacy isMmaToDotShortcut and its associated shortcut conversion.
Additionally, we store the dot operand results in the access order to improve code clarity.
Going forward, we intend to eliminate unnecessary shortcut conversions and replace them with the use of transferWithinThread.