-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][AMDGPU] Add support for AMD f16 math library calls #108809
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Daniel Hernandez-Juarez (dhernandez0) ChangesIn this PR we add support for AMD f16 math library calls (_ocml*_f16) Patch is 52.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108809.diff 6 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6be5548fdb60ef..8a9414d32ec611 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -17,10 +17,10 @@
namespace mlir {
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the fastMathFlag of that
/// Op. The function declaration is added in case it was not added before.
///
-/// If the input values are of f16 type, the value is first casted to f32, the
+/// If the input values are of unsupported type, the value is first casted to f32, the
/// function called and then the result casted back.
///
/// Example with NVVM:
@@ -41,9 +41,9 @@ template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
- StringRef f64Func, StringRef f32ApproxFunc)
+ StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +89,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
Type type = operand.getType();
- if (!isa<Float16Type>(type))
+ if (!isa<FloatType>(type))
+ return operand;
+
+ // if there's a f16 function, no need to cast f16 values
+ if (!f16Func.empty() && isa<Float16Type>(type))
+ return operand;
+
+ if (isa<Float64Type>(type) || isa<Float32Type>(type))
return operand;
return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
}
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+ if (isa<Float16Type>(type))
+ return f16Func;
if (isa<Float32Type>(type)) {
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
!f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32Func;
const std::string f64Func;
const std::string f32ApproxFunc;
+ const std::string f16Func;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 164622d77e6b62..f5650c35c3b3c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -336,10 +336,10 @@ template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
- StringRef f32ApproxFunc = "") {
+ StringRef f32ApproxFunc = "", StringRef f16Func = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc);
+ f32ApproxFunc, f16Func);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index fc3e1fc4f9d0c9..6b9e6b1192e050 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -346,9 +346,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc, f16Func);
}
void mlir::populateGpuToROCDLConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index b3b4d81e7ffa5b..1611a8835c91ef 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -39,16 +39,17 @@ template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
+ StringRef f16Func,
StringRef f32ApproxFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc);
+ f32ApproxFunc, f16Func);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
- // Handled by mathToLLVM: math::AbsFIOp
+ // Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +64,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
// Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
- "__ocml_acos_f64");
+ "__ocml_acos_f64", "__ocml_acos_f16");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
- "__ocml_acosh_f64");
+ "__ocml_acosh_f64", "__ocml_acosh_f16");
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
- "__ocml_asin_f64");
+ "__ocml_asin_f64", "__ocml_asin_f16");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
- "__ocml_asinh_f64");
+ "__ocml_asinh_f64", "__ocml_asinh_f16");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64");
+ "__ocml_atan_f64", "__ocml_atan_f16");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
- "__ocml_atanh_f64");
+ "__ocml_atanh_f64", "__ocml_atanh_f16");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64");
+ "__ocml_atan2_f64", "__ocml_atan2_f16");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64");
+ "__ocml_cbrt_f64", "__ocml_cbrt_f16");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
+ "__ocml_ceil_f64", "__ocml_ceil_f16");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64");
+ "__ocml_cos_f64", "__ocml_cos_f16");
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
- "__ocml_cosh_f64");
+ "__ocml_cosh_f64", "__ocml_cosh_f16");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
- "__ocml_sinh_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+ "__ocml_sinh_f64", "__ocml_sinh_f16");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "",
+ "__ocml_exp_f64", "__ocml_exp_f16");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
+ "__ocml_exp2_f64", "__ocml_exp2_f16");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64");
+ "__ocml_expm1_f64", "__ocml_expm1_f16");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+ "__ocml_floor_f64", "__ocml_floor_f16");
+ populateOpPatterns<math::LogOp>(converter, patterns, "",
+ "__ocml_log_f64", "__ocml_log_f16");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64");
+ "__ocml_log10_f64", "__ocml_log10_f16");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64");
+ "__ocml_log1p_f64", "__ocml_log1p_f16");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64");
+ "__ocml_log2_f64", "__ocml_log2_f16");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64");
+ "__ocml_pow_f64", "__ocml_pow_f16");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64");
+ "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64");
+ "__ocml_sin_f64", "__ocml_sin_f16");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
+ "__ocml_tanh_f64", "__ocml_tanh_f16");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64");
+ "__ocml_tan_f64", "__ocml_tan_f16");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64");
+ "__ocml_erf_f64", "__ocml_erf_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64");
+ "__ocml_fmod_f64", "__ocml_fmod_f16");
}
namespace {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index eb065cbab86789..0d3e9f4ea2bf39 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -162,11 +162,12 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @gpu_exp
func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
%result16 = math.exp %arg_f16 : f16
- // CHECK: llvm.intr.exp(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
@@ -178,11 +179,12 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log
func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
%result16 = math.log %arg_f16 : f16
- // CHECK: llvm.intr.log(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log %arg_f32 : f32
// CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32
%result64 = math.log %arg_f64 : f64
@@ -194,108 +196,113 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
// CHECK-LABEL: func @gpu_cbrt
- func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cbrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cbrt %arg_f32 : f32
// CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cbrt %arg_f64 : f64
// CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
// CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
// CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
// CHECK-LABEL: func @gpu_ceil
- func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.ceil %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
%result32 = math.ceil %arg_f32 : f32
// CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
%result64 = math.ceil %arg_f64 : f64
// CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
// CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
// CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
// CHECK-LABEL: func @gpu_floor
- func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.floor %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
%result32 = math.floor %arg_f32 : f32
// CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
%result64 = math.floor %arg_f64 : f64
// CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
// CHECK-LABEL: func @gpu_cos
- func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result64 : f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
// CHECK-LABEL: func @gpu_exp2
- func.func @gpu_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.exp2 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
%exp2_f32 = math.exp2 %arg_f32 : f32
// CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
%result32 = math.exp2 %exp2_f32 : f32
// CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
%result64 = math.exp2 %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
+
// Test that we handled properly operation with SymbolTable other than module op
gpu.module @test_module {
"test.symbol_scope"() ({
// CHECK: test.symbol_scope
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @gpu_sin
- func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %sin_f32 = math.sin %arg_f32 : f32
+ func.func @gpu_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %result16 = math.sin %arg_f16 : f16
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.sin %sin_f32 : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sin %arg_f64 : f64
+ %result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result64 = math.sin %arg_f64 : f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
"test.finish" () : () -> ()
}) : () -> ()
@@ -304,89 +311,102 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
// CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
// CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
// CHECK-LABEL: func @gpu_expm1
- func.func @gpu_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.expm1 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
%expm1_f32 = math.expm1 %arg_f32 : f32
// CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
%result32 = math.expm1 %expm1_f32 : f32
// CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
%result64 = math.expm1 %arg_f64 : f64
// CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.mod...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
752dc7c
to
4cce1f0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the two f16 calls that you are adding where f32 calls are not present, I am fine with whichever choice makes sense for the hardware / AMDGPU backend but the following needs to be updated accordingly,
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp#L338
It controls if for a given type we will make a llvm op or the library call. Otherwise which one happens may vary based on which pattern gets applied first or in the worst case we end up with the op not getting lowered at all and causing a crash.
ef36d24
to
7d0fc51
Compare
Thanks for your review! I've updated LowerGpuOpsToROCDLOps.cpp as suggested. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: please add a newline here.
In this PR we add support for AMD f16 math library calls (
__ocml_*_f16
)CC: @krzysz00 @manupak