WIP: Add tensor descriptor API backed by device-side TMA creation #4916

Open

wants to merge 9 commits into main
Conversation

Contributor

@peterbell10 peterbell10 commented Oct 15, 2024

Commits in this PR

  1. WIP: Add tensor descriptor API backed by device-side TMA creation
  2. Fix lit failure
  3. More fixes
  4. Fix aot compile
  5. Fix hip compile
  6. Support allocating memory in noinline functions
  7. Move GlobalScratchAllocOp to TritonGPUIR
  8. Add descriptor lowering lit test
  9. Add note explaining the calling convention

PR chain

  1. 👉 WIP: Add tensor descriptor API backed by device-side TMA creation #4916 👈 YOU ARE HERE
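
For context, here is a rough sketch of the user-facing flow this PR is building toward: the kernel creates the TMA descriptor on the device, and the runtime obtains the global scratch memory backing descriptor creation from a user-registered allocation function. This is only a sketch; the spellings `triton.set_allocator` and `tl.make_tensor_descriptor` follow later Triton releases and may differ from the exact API in this WIP.

```python
from typing import Optional

import torch
import triton
import triton.language as tl

# User-provided allocator for global scratch memory (signature as in this PR's
# tests). The launcher computes `size`; we just return a buffer of that size.
def alloc_fn(size: int, align: int, stream: Optional[int]):
    return torch.empty(size, dtype=torch.int8, device="cuda")

triton.set_allocator(alloc_fn)  # registration hook; name assumed from later releases

@triton.jit
def tile_copy_kernel(in_ptr, out_ptr, M, N,
                     M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Device-side TMA descriptor creation (name assumed; lowered from this
    # PR's MakeTensorDescOp).
    in_desc = tl.make_tensor_descriptor(in_ptr, shape=[M, N], strides=[N, 1],
                                        block_shape=[M_BLOCK, N_BLOCK])
    out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[N, 1],
                                         block_shape=[M_BLOCK, N_BLOCK])
    block = in_desc.load([pid_m * M_BLOCK, pid_n * N_BLOCK])
    out_desc.store([pid_m * M_BLOCK, pid_n * N_BLOCK], block)
```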

@peterbell10 peterbell10 force-pushed the pb/pr-chain/wip_add_tensor_descriptor_api_backed_by__a0ac branch from 6d6fbfa to bdcbf91 on October 15, 2024 17:28
Base automatically changed from pb/pr-chain/frontend_factor_out_block_shape_validati_9ccb to main on October 15, 2024 21:38
@peterbell10 peterbell10 force-pushed the pb/pr-chain/wip_add_tensor_descriptor_api_backed_by__a0ac branch from bdcbf91 to d0cd5c2 on October 15, 2024 21:44
@peterbell10 peterbell10 force-pushed the pb/pr-chain/wip_add_tensor_descriptor_api_backed_by__a0ac branch from 53f5d50 to cae0fdf on October 17, 2024 14:49
Collaborator

@ThomasRaoux ThomasRaoux left a comment


Looks good, a few minor comments.

@@ -77,4 +77,18 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
];
}

def TritonNvidiaGPUGlobalScratchAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
Collaborator


Why is this NVIDIA-specific? It sounds like we could make it a generic pass.

callOp->getLoc(), rewriter, targetInfo, callOp));

auto opOffsetAttr = caller->getAttrOfType<mlir::IntegerAttr>(
"triton_nvidia_gpu.global_scratch_memory_offset");
Collaborator


It's a bit odd to have this in the triton_nvidia_gpu namespace, especially since this is shared code.

@@ -101,6 +102,102 @@ class TMAStoreLowering
}
};

class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
Collaborator


Can you add a lit test for this? It's a nice way to have an example of the IR.

@@ -368,7 +369,7 @@ inline bool isKernel(FunctionOpInterface funcOp) {
inline Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
if (!isKernel(funcOp)) {
return funcOp.getArgument(funcOp.getNumArguments() - 1);
return funcOp.getArgument(funcOp.getNumArguments() - 2);
Contributor


We may need to document this more explicitly: what is funcOp.getNumArguments() - 2, and what is funcOp.getNumArguments() - 1?

idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
tl.store(out_ptr + idx, block)

def alloc_fn(size: int, align: int, stream: Optional[int]):
Contributor


So the user must know how to figure out the size of the scratch buffer. Can you explain how this is compatible with our autotuner?

Contributor Author

@peterbell10 peterbell10 Oct 22, 2024


I'm not sure I understand your question. The user-provided allocation function is called with the correct size as computed by the launcher code. During autotuning, the allocation function will be called multiple times with different sizes to allocate new memory.
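
As a minimal sketch of that point (assuming the allocator is registered through a hook such as triton.set_allocator; the exact hook in this WIP may differ), the allocation function only sees the size the launcher computed and needs no knowledge of the autotuner's configs:

```python
from typing import Optional

import torch

requested_sizes = []  # purely to observe what the launcher asks for

def alloc_fn(size: int, align: int, stream: Optional[int]):
    # `size` is the total scratch size computed by the launcher for this
    # particular launch; the allocator does not inspect any triton.Config.
    requested_sizes.append(size)
    return torch.empty(size, dtype=torch.int8, device="cuda")
```

During autotuning, `requested_sizes` simply ends up with one entry per candidate launch, each recomputed by the launcher for that config.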

Contributor


My question is more about how we implement alloc_fn when using the autotuner, since the block sizes change based on the configurations passed via triton.Config.

Contributor Author

@peterbell10 peterbell10 Oct 22, 2024


The size passed to the allocation function is the total allocation size, so the allocation function doesn't need to know anything about the config.

See, for example, the next test down, where there is a larger grid and so I assert:

        assert size == 128 * (grid_m * grid_n)

Contributor


Just to make sure I understand correctly: if the maximum grid is 100x100, will we allocate 100x100x128 anyway for all configurations?

Contributor Author

@peterbell10 peterbell10 Oct 22, 2024


No. Say one config gives a 100x100 grid and another gives 200x100; then your allocation function will be called several times with size = 100*100*128 and several times with size = 200*100*128.
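
To make the arithmetic concrete (using the 128 bytes of scratch per program asserted in the test above):

```python
SCRATCH_PER_PROGRAM = 128  # bytes, matching `assert size == 128 * (grid_m * grid_n)` above

def expected_alloc_size(grid_m: int, grid_n: int) -> int:
    # Total allocation = per-program scratch * number of programs in the grid.
    return SCRATCH_PER_PROGRAM * grid_m * grid_n

assert expected_alloc_size(100, 100) == 100 * 100 * 128  # first config
assert expected_alloc_size(200, 100) == 200 * 100 * 128  # second config
```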

Contributor


I see. Makes sense

@peterbell10 peterbell10 force-pushed the pb/pr-chain/wip_add_tensor_descriptor_api_backed_by__a0ac branch from 57e7f2d to 214f259 on October 22, 2024 16:33
@peterbell10 peterbell10 force-pushed the pb/pr-chain/wip_add_tensor_descriptor_api_backed_by__a0ac branch from 214f259 to 837308f on October 22, 2024 18:18