
[mlir][AMDGPU] Add support for AMD f16 math library calls #108809

Open · wants to merge 2 commits into main from amd_f16_math_lib_calls

Conversation

@dhernandez0 (Contributor) commented Sep 16, 2024

This PR adds support for AMD f16 math library calls (__ocml_*_f16).

CC: @krzysz00 @manupak
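A minimal sketch of what this enables at the registration site (mirroring the MathToROCDL changes in the diff below; the sine lowering is used purely as an illustration):

// Each op lowering now also receives its __ocml_*_f16 entry point, so f16
// operands become a direct library call instead of being extended to f32 and
// truncated back.
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                "__ocml_sin_f64", "__ocml_sin_f16");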


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot (Collaborator) commented Sep 16, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Daniel Hernandez-Juarez (dhernandez0)

Changes

This PR adds support for AMD f16 math library calls (__ocml_*_f16).


Patch is 52.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108809.diff

6 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+15-5)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+2-2)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+2-2)
  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+32-29)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+102-61)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+131-69)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6be5548fdb60ef..8a9414d32ec611 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -17,10 +17,10 @@
 namespace mlir {
 
 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the fastMathFlag of that
 /// Op. The function declaration is added in case it was not added before.
 ///
-/// If the input values are of f16 type, the value is first casted to f32, the
+/// If the input values are of unsupported type, the value is first casted to f32, the
 /// function called and then the result casted back.
 ///
 /// Example with NVVM:
@@ -41,9 +41,9 @@ template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
-                                StringRef f64Func, StringRef f32ApproxFunc)
+                                StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func)
       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
+        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +89,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 private:
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     Type type = operand.getType();
-    if (!isa<Float16Type>(type))
+    if (!isa<FloatType>(type))
+      return operand;
+
+    // If there's an f16 function, there is no need to cast f16 values.
+    if (!f16Func.empty() && isa<Float16Type>(type))
+      return operand;
+
+    if (isa<Float64Type>(type) || isa<Float32Type>(type))
       return operand;
 
     return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 
   StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+    if (isa<Float16Type>(type))
+      return f16Func;
     if (isa<Float32Type>(type)) {
       if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
           !f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   const std::string f32Func;
   const std::string f64Func;
   const std::string f32ApproxFunc;
+  const std::string f16Func;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 164622d77e6b62..f5650c35c3b3c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -336,10 +336,10 @@ template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
                                StringRef f64Func,
-                               StringRef f32ApproxFunc = "") {
+                               StringRef f32ApproxFunc = "", StringRef f16Func = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index fc3e1fc4f9d0c9..6b9e6b1192e050 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -346,9 +346,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func) {
+                               StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func) {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32ApproxFunc, f16Func);
 }
 
 void mlir::populateGpuToROCDLConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index b3b4d81e7ffa5b..1611a8835c91ef 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -39,16 +39,17 @@ template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
                                StringRef f64Func,
+                               StringRef f16Func,
                                StringRef f32ApproxFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
-  // Handled by mathToLLVM: math::AbsFIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +64,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64");
+                                   "__ocml_acos_f64", "__ocml_acos_f16");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64");
+                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
   populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64");
+                                   "__ocml_asin_f64", "__ocml_asin_f16");
   populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64");
+                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
+                                   "__ocml_atan_f64", "__ocml_atan_f16");
   populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64");
+                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
+                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
+                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
+                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
   populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
+                                  "__ocml_cos_f64", "__ocml_cos_f16");
   populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64");
+                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "",
+                                  "__ocml_exp_f64", "__ocml_exp_f16");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
+                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
+                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+                                    "__ocml_floor_f64", "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, "",
+                                  "__ocml_log_f64", "__ocml_log_f16");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
+                                    "__ocml_log10_f64", "__ocml_log10_f16");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
+                                    "__ocml_log1p_f64", "__ocml_log1p_f16");
   populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
+                                   "__ocml_log2_f64", "__ocml_log2_f16");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
+                                   "__ocml_pow_f64", "__ocml_pow_f16");
   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
+                                    "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
+                                  "__ocml_sin_f64", "__ocml_sin_f16");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
+                                   "__ocml_tanh_f64", "__ocml_tanh_f16");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
+                                  "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+                                  "__ocml_erf_f64", "__ocml_erf_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
+                                    "__ocml_fmod_f64", "__ocml_fmod_f16");
 }
 
 namespace {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index eb065cbab86789..0d3e9f4ea2bf39 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -162,11 +162,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_exp
   func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.exp %arg_f16 : f16
-    // CHECK: llvm.intr.exp(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.exp %arg_f32 : f32
     // CHECK: llvm.intr.exp(%{{.*}})  : (f32) -> f32
     %result64 = math.exp %arg_f64 : f64
@@ -178,11 +179,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log
   func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.log %arg_f16 : f16
-    // CHECK: llvm.intr.log(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log %arg_f32 : f32
     // CHECK: llvm.intr.log(%{{.*}})  : (f32) -> f32
     %result64 = math.log %arg_f64 : f64
@@ -194,108 +196,113 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cbrt
-  func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cbrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cbrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cbrt %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_ceil
-  func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.ceil %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.ceil %arg_f32 : f32
     // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.ceil %arg_f64 : f64
     // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_floor
-  func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.floor %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.floor %arg_f32 : f32
     // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.floor %arg_f64 : f64
     // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cos
-  func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cos %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cos %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result64 : f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_exp2
-  func.func @gpu_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.exp2 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
     %exp2_f32 = math.exp2 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
     %result32 = math.exp2 %exp2_f32 : f32
     // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.exp2 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
+
 // Test that we handled properly operation with SymbolTable other than module op
 gpu.module @test_module {
   "test.symbol_scope"() ({
     // CHECK: test.symbol_scope
+    // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
     // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
     // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
     // CHECK-LABEL: func @gpu_sin
-    func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-      %sin_f32 = math.sin %arg_f32 : f32
+    func.func @gpu_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+      // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+      %result16 = math.sin %arg_f16 : f16
       // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
-      %result32 = math.sin %sin_f32 : f32
-      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
-      %result64 = math.sin %arg_f64 : f64
+      %result32 = math.sin %arg_f32 : f32
       // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
-      func.return %result32, %result64 : f32, f64
+      %result64 = math.sin %arg_f64 : f64
+      func.return %result16, %result32, %result64 : f16, f32, f64
     }
     "test.finish" () : () -> ()
   }) : () -> ()
@@ -304,89 +311,102 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_expm1
-  func.func @gpu_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.expm1 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
     %expm1_f32 = math.expm1 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
     %result32 = math.expm1 %expm1_f32 : f32
     // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.expm1 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.mod...
[truncated]


github-actions bot commented Sep 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@dhernandez0 force-pushed the amd_f16_math_lib_calls branch 2 times, most recently from 752dc7c to 4cce1f0 on September 17, 2024 at 08:13

@nirvedhmeshram (Contributor) left a comment

Regarding the two f16 calls you are adding where f32 calls are not present: I am fine with whichever choice makes sense for the hardware / AMDGPU backend, but the following needs to be updated accordingly:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp#L338
It controls whether, for a given type, we emit an LLVM op or the library call. Otherwise, which one happens may vary based on which pattern gets applied first, or in the worst case the op is not lowered at all and we crash.
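For reference, a rough sketch of the kind of legality rule being pointed at (op list and predicate are illustrative only, not the actual contents of that file): since f16 math ops now lower to __ocml_*_f16 calls, the corresponding LLVM intrinsic ops should stop being accepted for f16 operands, so the library-call pattern is the one that must apply.

// Illustrative sketch only: accept llvm.intr.exp / llvm.intr.log for f32
// operands, forcing f16 (and f64) through the __ocml_* library-call patterns.
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
  return llvm::any_of(op->getOperandTypes(),
                      [](Type t) { return isa<Float32Type>(t); });
});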

@dhernandez0 (Contributor, Author)

Thanks for your review! I've updated LowerGpuOpsToROCDLOps.cpp as suggested.

@nirvedhmeshram (Contributor) left a comment

LGTM!

}
}

}
A Contributor commented on this diff:

nit: please add a newline here.

4 participants