[ROCDL] Add the global.atomic.fadd intrinsic in ROCDL #94486

Closed
giuseros wants to merge 4 commits

giuseros
@giuseros giuseros commented Jun 5, 2024

This PR adds the global.atomic.fadd intrinsic in ROCDL (which supports f32 and vector<2xf16>)

@llvmbot
llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Giuseppe Rossini (giuseros)

Changes

This PR adds the global.atomic.fadd intrinsic in ROCDL (which supports f32 and vector<2xf16>)


Full diff: https://github.com/llvm/llvm-project/pull/94486.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+15-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+20)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
 }
 
 //===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic
 
 def ROCDL_RawBufferAtomicFAddOp :
   ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+def ROCDL_GlobalAtomicFAddOp :
+  ROCDL_Op<"global.atomic.fadd">,
+  Arguments<(ins LLVM_Type:$ptr,
+                 LLVM_Type:$vdata)>{
+  string llvmBuilder = [{
+      auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+      auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+      createIntrinsicCall(builder,
+          llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===---------------------------------------------------------------------===//
 // Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
   p << " " << getOperands() << " : " << getVdata().getType();
 }
 
+// <operation> ::=
+//     `llvm.amdgcn.global.atomic.fadd.* %ptr, %vdata : result_type`
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+  Type type;
+  if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+  if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+  return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+  p << " " << getOperands() << " : " << getVdata().getType();
+}
+
 // <operation> ::=
 //     `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc,  %offset,
 //     %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+  // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+  rocdl.global.atomic.fadd %ptr, %vdata0: f32
+  // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+  llvm.return
+}
+
 llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32) {

@llvmbot
llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir

@krzysz00 krzysz00 self-requested a review June 5, 2024 18:38
@giuseros giuseros marked this pull request as draft June 6, 2024 10:56
@giuseros
giuseros commented Jun 6, 2024

I noticed that the use case I am dealing with also needs the output from the global.fadd (i.e., the original value in global memory). I am converting this to draft to avoid it being accidentally merged, and I will ping you guys back once I sort this out. Thanks!
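
As a rough illustration of the semantics mentioned above (an atomic floating-point add that also hands back the value previously held in memory), here is a minimal host-side C++ sketch. The helper name `fetch_add_returning_old` is hypothetical and is not part of this PR or of ROCDL; it emulates the operation with a compare-exchange loop on `std::atomic<float>` instead of calling any GPU intrinsic.

```cpp
#include <atomic>

// Hypothetical helper (not from the PR): atomically adds `value` to `cell`
// and returns the value that was stored before the addition.
float fetch_add_returning_old(std::atomic<float> &cell, float value) {
  float old = cell.load(std::memory_order_relaxed);
  // On failure, compare_exchange_weak reloads `old` with the current
  // contents, so each retry recomputes the sum from fresh data.
  while (!cell.compare_exchange_weak(old, old + value,
                                     std::memory_order_seq_cst,
                                     std::memory_order_relaxed)) {
  }
  return old;
}
```

The returned value corresponds to what the comment calls "the original value in global memory"; an op without a result would discard it.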

@giuseros giuseros marked this pull request as ready for review June 6, 2024 11:32
@giuseros
giuseros commented Jun 6, 2024

Done

@arsenm arsenm left a comment

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

@giuseros
giuseros commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm , the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a cas loop which is very slow
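
To make the "CAS loop" concern concrete, the following host-side C++ sketch shows the general shape of such an expansion for a packed pair of values. It rests on stated assumptions: the `Pair` struct (two `float`s in one 64-bit word) is only a stand-in for `vector<2xf16>`, the names `pack`, `unpack`, and `packed_fadd_cas` are hypothetical, and this is not the code LLVM actually emits. It only illustrates why every update needs a load, an add, and a compare-exchange that may have to be retried under contention.

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Stand-in for vector<2xf16>: two floats that together fill one 64-bit word.
struct Pair {
  float lo;
  float hi;
};
static_assert(sizeof(Pair) == sizeof(std::uint64_t), "Pair must be 64 bits");

// Reinterpret a Pair as raw bits (memcpy avoids aliasing problems).
static std::uint64_t pack(Pair p) {
  std::uint64_t bits;
  std::memcpy(&bits, &p, sizeof(bits));
  return bits;
}

// Reinterpret raw bits as a Pair.
static Pair unpack(std::uint64_t bits) {
  Pair p;
  std::memcpy(&p, &bits, sizeof(p));
  return p;
}

// CAS-loop emulation of a packed atomic add; returns the previous pair.
Pair packed_fadd_cas(std::atomic<std::uint64_t> &word, Pair delta) {
  std::uint64_t expected = word.load(std::memory_order_relaxed);
  for (;;) {
    Pair cur = unpack(expected);
    Pair next{cur.lo + delta.lo, cur.hi + delta.hi};
    // If another thread changed `word` in the meantime, `expected` is
    // refreshed with the new contents and the loop tries again.
    if (word.compare_exchange_weak(expected, pack(next),
                                   std::memory_order_seq_cst,
                                   std::memory_order_relaxed))
      return cur;
  }
}
```

A single hardware packed-add instruction would do the same update in one step, which is the motivation given in the comment above.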

@giuseros
giuseros commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm , the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a cas loop which is very slow

Or maybe it does?

@krzysz00
krzysz00 commented Jun 6, 2024

@giuseros I wonder if it's that MLIR's wrappers around atomicrmw don't support vectors ... which seems like an extension we could do

@arsenm
arsenm commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm , the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a cas loop which is very slow

Or maybe it does?

atomicrmw FP operations do since 4cb110a. I still need to implement the AMDGPU codegen changes to start using the vector instructions though (plus eventually the new metadata from #85052 will be needed).

@giuseros
giuseros commented Jun 6, 2024

Ok, given I am exactly after that vector instruction, how about we merge this PR and then we enable vector support for atomicrmw in MLIR (like @krzysz00 was suggesting) once it emits the vector instruction?

@giuseros
giuseros commented Jun 7, 2024

Hi @arsenm , is it ok for this to merge?

@arsenm
arsenm commented Jun 7, 2024

Hi @arsenm , is it ok for this to merge?

I guess, though I always prefer to just do whatever is needed to move towards the end goal instead of adding new throwaway code

@giuseros
giuseros commented Jun 7, 2024

Ok, after a chat with Matthew, we agree on closing this for now and trying to emit the vectorized atomicrmw intrinsic. If we urgently need the feature, we will get back to this.

@giuseros giuseros closed this Jun 7, 2024
@arsenm
arsenm commented Jun 8, 2024

Part 1 to start supporting the vector selection is in #94845

@arsenm
arsenm commented Jun 13, 2024

#95393 for the LDS case, #95394 for global and flat
