WIP: Add tensor descriptor API backed by device-side TMA creation #4916
base: main
Conversation
Looks good, a few minor comments.
@@ -77,4 +77,18 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> {
  ];
}

def TritonNvidiaGPUGlobalScratchAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
Why is this NVIDIA-specific? It sounds like we could make it a generic pass.
        callOp->getLoc(), rewriter, targetInfo, callOp));

    auto opOffsetAttr = caller->getAttrOfType<mlir::IntegerAttr>(
        "triton_nvidia_gpu.global_scratch_memory_offset");
It's a bit weird to have this in the `triton_nvidia_gpu` namespace, especially since this is in shared code.
@@ -101,6 +102,102 @@ class TMAStoreLowering
  }
};

class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
Can you add a lit test for this? It's a nice way to have an example of the IR.
@@ -368,7 +369,7 @@ inline bool isKernel(FunctionOpInterface funcOp) {
 inline Value getStackPointer(RewriterBase &rewriter,
                              FunctionOpInterface funcOp) {
   if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+    return funcOp.getArgument(funcOp.getNumArguments() - 2);
We may need to document this more explicitly: what is `funcOp.getNumArguments() - 2`, and what is `funcOp.getNumArguments() - 1`?
    idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
    tl.store(out_ptr + idx, block)

def alloc_fn(size: int, align: int, stream: Optional[int]):
So the user must know how to figure out the size of the scratch buffer. Can you explain how this is compatible with our Autotuner?
Not sure I understand your question. The user-provided allocation function is called with the correct size as computed by the launcher code. During autotuning, the allocation function will be called multiple times with different sizes to allocate new memory.
My question is more about how we implement the `alloc_fn` when using the autotuner. The block sizes are changed based on the configurations passed by `triton.Config`.
The size passed to the allocation function is the total allocation size, so the allocation function doesn't need to know anything about the config.
See, for example, the next test further down, where there is a larger grid, so I assert:
assert size == 128 * (grid_m * grid_n)
Just to make sure I understand it correctly: if the maximum grid size is 100x100, then we will allocate 100x100x128 anyway for all configurations?
No. Say one config is 100x100 and another is 200x100; then your allocation function will be called several times with size = 100*100*128 and several times with size = 200*100*128.
I see. Makes sense.