diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b70..e656ce8f62313f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -342,6 +342,7 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
 //===---------------------------------------------------------------------===//
 
 def ROCDLBufferRsrc : LLVM_PointerInAddressSpace<8>;
+def ROCDLGlobalPtr : LLVM_PointerInAddressSpace<1>;
 
 def ROCDL_MakeBufferRsrcOp :
   ROCDL_IntrOp<"make.buffer.rsrc", [], [0], [Pure], 1>,
@@ -516,7 +517,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
 }
 
 //===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// gfx9x global/buffer atomic floating point add intrinsics
 
 def ROCDL_RawBufferAtomicFAddOp :
   ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +535,18 @@ def ROCDL_RawBufferAtomicFAddOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+// Wraps the `llvm.amdgcn.global.atomic.fadd` intrinsic. `$ptr` must be a
+// pointer in address space 1 (ROCDLGlobalPtr), and the result type is
+// constrained to equal the type of `$vdata`; only that type is spelled in the
+// assembly, e.g. `rocdl.global.atomic.fadd %ptr, %vdata : f32`.
+def ROCDL_GlobalAtomicFAddOp :
+  ROCDL_IntrOp<"global.atomic.fadd", [0], [0, 1],
+               [AllTypesMatch<["res", "vdata"]>], 1>,
+  Arguments<(ins ROCDLGlobalPtr:$ptr,
+             LLVM_Type:$vdata)> {
+  let assemblyFormat = "operands attr-dict `:` type($res)";
+}
+
 //===---------------------------------------------------------------------===//
 // Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a0..c940d01a0a6145 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,17 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+// Verifies that rocdl.global.atomic.fadd translates to a call of the overloaded
+// llvm.amdgcn.global.atomic.fadd intrinsic for both f32 and vector<2xf16> data.
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr<1>) {
+  // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p1.f32(ptr addrspace(1) %{{.*}}, float %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata0 : f32
+  // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p1.v2f16(ptr addrspace(1) %{{.*}}, <2 x half> %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata1 : vector<2xf16>
+  llvm.return
+}
+
 llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
                        %offset : i32, %soffset : i32,
                        %vdata1 : i32) {