Skip to content

Commit

Permalink
Fixes for fma function (#1580)
Browse files Browse the repository at this point in the history
* [hlsl-out] Write `mad` intrinsic for `fma` function

- This should be enough because we only support f32 for now.
- Adds a new test for WGSL functions, in the spirit of operators.wgsl.
- Closes #1579

* Add FMA feature to glsl backend

- I think this is right. Just iterate all known expressions in all
  functions and entry points to locate any `fma` function call.
  Should not need to walk the statement DAG.

* Transform GLSL fma function into an airthmetic expression when necessary

* Add tests for GLSL fma function tranformation

* Remove the hazard comment from the webgl test input

* Add helper method for fma function support checks

* Address review comment
  • Loading branch information
parasyte authored and kvark committed Dec 29, 2021
1 parent 132e247 commit 789aa16
Show file tree
Hide file tree
Showing 15 changed files with 222 additions and 4 deletions.
33 changes: 31 additions & 2 deletions src/back/glsl/features.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{BackendResult, Error, Version, Writer};
use crate::{
Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind,
ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction,
Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
};
use std::fmt::Write;

Expand Down Expand Up @@ -33,6 +33,8 @@ bitflags::bitflags! {
/// Arrays with a dynamic length
const DYNAMIC_ARRAY_SIZE = 1 << 16;
const MULTI_VIEW = 1 << 17;
/// Adds support for fused multiply-add
const FMA = 1 << 18;
}
}

Expand Down Expand Up @@ -98,6 +100,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(MULTI_VIEW, 140, 310);
check_feature!(FMA, 400, 310);

// Return an error if there are missing features
if missing.is_empty() {
Expand Down Expand Up @@ -201,6 +204,11 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_multiview : require")?;
}

if self.0.contains(Features::FMA) && version.is_es() {
// https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
}

Ok(())
}
}
Expand Down Expand Up @@ -347,6 +355,27 @@ impl<'a, W> Writer<'a, W> {
}
}

if self.options.version.supports_fma_function() {
let has_fma = self
.module
.functions
.iter()
.flat_map(|(_, f)| f.expressions.iter())
.chain(
self.module
.entry_points
.iter()
.flat_map(|e| e.function.expressions.iter()),
)
.any(|(_, e)| match *e {
Expression::Math { fun, .. } if fun == MathFunction::Fma => true,
_ => false,
});
if has_fma {
self.features.request(Features::FMA);
}
}

self.features.check_availability(self.options.version)
}

Expand Down
29 changes: 28 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ impl Version {
fn supports_std430_layout(&self) -> bool {
*self >= Version::Desktop(430) || *self >= Version::Embedded(310)
}

fn supports_fma_function(&self) -> bool {
*self >= Version::Desktop(400) || *self >= Version::Embedded(310)
}
}

impl PartialOrd for Version {
Expand Down Expand Up @@ -2471,7 +2475,30 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Fma => {
if self.options.version.supports_fma_function() {
// Use the fma function when available
"fma"
} else {
// No fma support. Transform the function call into an arithmetic expression
write!(self.out, "(")?;

self.write_expr(arg, ctx)?;
write!(self.out, " * ")?;

let arg1 =
arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
self.write_expr(arg1, ctx)?;
write!(self.out, " + ")?;

let arg2 =
arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
self.write_expr(arg2, ctx)?;
write!(self.out, ")")?;

return Ok(());
}
}
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Expand Down
2 changes: 1 addition & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Refract => Function::Regular("refract"),
// computational
Mf::Sign => Function::Regular("sign"),
Mf::Fma => Function::Regular("fma"),
Mf::Fma => Function::Regular("mad"),
Mf::Mix => Function::Regular("lerp"),
Mf::Step => Function::Regular("step"),
Mf::SmoothStep => Function::Regular("smoothstep"),
Expand Down
7 changes: 7 additions & 0 deletions tests/in/functions-webgl.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(
glsl: (
version: Embedded(300),
writer_flags: (bits: 0),
binding_map: {},
),
)
13 changes: 13 additions & 0 deletions tests/in/functions-webgl.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);

return fma(a, b, c);
}


[[stage(vertex)]]
fn main() {
let a = test_fma();
}
2 changes: 2 additions & 0 deletions tests/in/functions.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(
)
15 changes: 15 additions & 0 deletions tests/in/functions.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);

// Hazard: HLSL needs a different intrinsic function for f32 and f64
// See: https://github.com/gfx-rs/naga/issues/1579
return fma(a, b, c);
}


[[stage(compute), workgroup_size(1)]]
fn main() {
let a = test_fma();
}
18 changes: 18 additions & 0 deletions tests/out/glsl/functions-webgl.main.Vertex.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#version 300 es

precision highp float;
precision highp int;


vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return (a * b + c);
}

void main() {
vec2 _e0 = test_fma();
return;
}

21 changes: 21 additions & 0 deletions tests/out/glsl/functions.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#version 310 es
#extension GL_EXT_gpu_shader5 : require

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;


vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return fma(a, b, c);
}

void main() {
vec2 _e0 = test_fma();
return;
}

15 changes: 15 additions & 0 deletions tests/out/hlsl/functions.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

float2 test_fma()
{
float2 a = float2(2.0, 2.0);
float2 b = float2(0.5, 0.5);
float2 c = float2(0.5, 0.5);
return mad(a, b, c);
}

[numthreads(1, 1, 1)]
void main()
{
const float2 _e0 = test_fma();
return;
}
3 changes: 3 additions & 0 deletions tests/out/hlsl/functions.hlsl.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vertex=()
fragment=()
compute=(main:cs_5_1 )
18 changes: 18 additions & 0 deletions tests/out/msl/functions.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// language: metal1.1
#include <metal_stdlib>
#include <simd/simd.h>


metal::float2 test_fma(
) {
metal::float2 a = metal::float2(2.0, 2.0);
metal::float2 b = metal::float2(0.5, 0.5);
metal::float2 c = metal::float2(0.5, 0.5);
return metal::fma(a, b, c);
}

kernel void main_(
) {
metal::float2 _e0 = test_fma();
return;
}
33 changes: 33 additions & 0 deletions tests/out/spv/functions.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 20
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %16 "main"
OpExecutionMode %16 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 2.0
%5 = OpConstant %4 0.5
%6 = OpTypeVector %4 2
%9 = OpTypeFunction %6
%17 = OpTypeFunction %2
%8 = OpFunction %6 None %9
%7 = OpLabel
OpBranch %10
%10 = OpLabel
%11 = OpCompositeConstruct %6 %3 %3
%12 = OpCompositeConstruct %6 %5 %5
%13 = OpCompositeConstruct %6 %5 %5
%14 = OpExtInst %6 %1 Fma %11 %12 %13
OpReturnValue %14
OpFunctionEnd
%16 = OpFunction %2 None %17
%15 = OpLabel
OpBranch %18
%18 = OpLabel
%19 = OpFunctionCall %6 %8
OpReturn
OpFunctionEnd
12 changes: 12 additions & 0 deletions tests/out/wgsl/functions.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);
return fma(a, b, c);
}

[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
let _e0 = test_fma();
return;
}
5 changes: 5 additions & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,11 @@ fn convert_wgsl() {
"operators",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"functions",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("functions-webgl", Targets::GLSL),
(
"interpolate",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
Expand Down

0 comments on commit 789aa16

Please sign in to comment.