Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for fma function #1580

Merged
merged 7 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/back/glsl/features.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{BackendResult, Error, Version, Writer};
use crate::{
Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind,
ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction,
Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
};
use std::fmt::Write;

Expand Down Expand Up @@ -33,6 +33,8 @@ bitflags::bitflags! {
/// Arrays with a dynamic length
const DYNAMIC_ARRAY_SIZE = 1 << 16;
const MULTI_VIEW = 1 << 17;
/// Adds support for fused multiply-add
const FMA = 1 << 18;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need thus flag? If FMA isn't natively supported, we are emulating it anyway. So it seems to me that this flag isn't getting us anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For one thing, I'm following precedent with existing features that need extensions, and fma is only supported on GLES 3.1+ with:

#version 310 es
#extension GL_EXT_gpu_shader5 : require

This PR does not emulate fma on GLSL in every case, but decides if it must emulate it or else it requests the extension when necessary. This fixes that unusual validation error you noted earlier: https://github.com/gfx-rs/naga/runs/4446174291?check_suite_focus=true

ERROR: 0:13: 'fma' : required extension not requested: Possible extensions include:
GL_EXT_gpu_shader5
GL_OES_gpu_shader5

I chose to use the existing feature flag infrastructure to write this extension, rather than coming up with something unique just for this case. Is there something better I could have done here?

The FMA feature flag name is probably too narrow, honestly. GL_EXT_gpu_shader5 enables a lot more than just the fma function, and the feature flag can be used to support all of it on GLES.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The capability flags in GLSL backend are meant to be requirements. I.e. shader requires A, B, C, and we want to check if we can work with this shader at all.
The case for FMA is different. The backend always supports FMA instruction. The only thing different is a code path taken. Therefore, there is no case where GLSL backend would check for this capability and report it missing. It's not a real capability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand what you mean. Let me try to rephrase it; The fma function from the frontend is always supported by the backend (GLSL) even if it has to fallback to an arithmetic transformation (i.e., "emulated"). This PR uses a feature flag (capability) in another sense, that it can enable the use of the GLSL fma function on particular versions of the backend. Which are not how feature flags are used elsewhere.

Would that be an accurate way to describe the situation?

I'll have to think on it if I need to use some other mechanism to enable the extension for GLES. I do see an extension enabled that is not controlled by feature flags:

naga/src/back/glsl/mod.rs

Lines 442 to 450 in 7c8bedc

// Write the additional extensions
if self
.options
.writer_flags
.contains(WriterFlags::TEXTURE_SHADOW_LOD)
{
// https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt
writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?;
}
I suppose that would be an appropriate place to put this?

}
}

Expand Down Expand Up @@ -98,6 +100,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(MULTI_VIEW, 140, 310);
check_feature!(FMA, 400, 310);

// Return an error if there are missing features
if missing.is_empty() {
Expand Down Expand Up @@ -201,6 +204,11 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_multiview : require")?;
}

if self.0.contains(Features::FMA) && version.is_es() {
// https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
}

Ok(())
}
}
Expand Down Expand Up @@ -347,6 +355,27 @@ impl<'a, W> Writer<'a, W> {
}
}

if self.options.version.supports_fma_function() {
let has_fma = self
.module
.functions
.iter()
.flat_map(|(_, f)| f.expressions.iter())
.chain(
self.module
.entry_points
.iter()
.flat_map(|e| e.function.expressions.iter()),
)
.any(|(_, e)| match *e {
Expression::Math { fun, .. } if fun == MathFunction::Fma => true,
_ => false,
});
if has_fma {
self.features.request(Features::FMA);
}
}

self.features.check_availability(self.options.version)
}

Expand Down
29 changes: 28 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ impl Version {
fn supports_std430_layout(&self) -> bool {
*self >= Version::Desktop(430) || *self >= Version::Embedded(310)
}

fn supports_fma_function(&self) -> bool {
*self >= Version::Desktop(400) || *self >= Version::Embedded(310)
}
}

impl PartialOrd for Version {
Expand Down Expand Up @@ -2433,7 +2437,30 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Fma => {
if self.options.version.supports_fma_function() {
// Use the fma function when available
"fma"
} else {
// No fma support. Transform the function call into an arithmetic expression
write!(self.out, "(")?;

self.write_expr(arg, ctx)?;
write!(self.out, " * ")?;

let arg1 =
arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
self.write_expr(arg1, ctx)?;
write!(self.out, " + ")?;

let arg2 =
arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
self.write_expr(arg2, ctx)?;
write!(self.out, ")")?;

return Ok(());
}
}
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Expand Down
2 changes: 1 addition & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1861,7 +1861,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Refract => Function::Regular("refract"),
// computational
Mf::Sign => Function::Regular("sign"),
Mf::Fma => Function::Regular("fma"),
Mf::Fma => Function::Regular("mad"),
Mf::Mix => Function::Regular("lerp"),
Mf::Step => Function::Regular("step"),
Mf::SmoothStep => Function::Regular("smoothstep"),
Expand Down
7 changes: 7 additions & 0 deletions tests/in/functions-webgl.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(
glsl: (
version: Embedded(300),
writer_flags: (bits: 0),
binding_map: {},
),
)
13 changes: 13 additions & 0 deletions tests/in/functions-webgl.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);

return fma(a, b, c);
}


[[stage(vertex)]]
fn main() {
let a = test_fma();
}
2 changes: 2 additions & 0 deletions tests/in/functions.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(
)
15 changes: 15 additions & 0 deletions tests/in/functions.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
fn test_fma() -> vec2<f32> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably move a few things from operators.wgsl into here

let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);

// Hazard: HLSL needs a different intrinsic function for f32 and f64
// See: https://github.com/gfx-rs/naga/issues/1579
return fma(a, b, c);
}


[[stage(compute), workgroup_size(1)]]
fn main() {
let a = test_fma();
}
18 changes: 18 additions & 0 deletions tests/out/glsl/functions-webgl.main.Vertex.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#version 300 es

precision highp float;
precision highp int;


vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return (a * b + c);
}

void main() {
vec2 _e0 = test_fma();
return;
}

21 changes: 21 additions & 0 deletions tests/out/glsl/functions.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#version 310 es
#extension GL_EXT_gpu_shader5 : require

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;


vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return fma(a, b, c);
}

void main() {
vec2 _e0 = test_fma();
return;
}

15 changes: 15 additions & 0 deletions tests/out/hlsl/functions.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

float2 test_fma()
{
float2 a = float2(2.0, 2.0);
float2 b = float2(0.5, 0.5);
float2 c = float2(0.5, 0.5);
return mad(a, b, c);
}

[numthreads(1, 1, 1)]
void main()
{
const float2 _e0 = test_fma();
return;
}
3 changes: 3 additions & 0 deletions tests/out/hlsl/functions.hlsl.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vertex=()
fragment=()
compute=(main:cs_5_1 )
18 changes: 18 additions & 0 deletions tests/out/msl/functions.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// language: metal1.1
#include <metal_stdlib>
#include <simd/simd.h>


metal::float2 test_fma(
) {
metal::float2 a = metal::float2(2.0, 2.0);
metal::float2 b = metal::float2(0.5, 0.5);
metal::float2 c = metal::float2(0.5, 0.5);
return metal::fma(a, b, c);
}

kernel void main_(
) {
metal::float2 _e0 = test_fma();
return;
}
33 changes: 33 additions & 0 deletions tests/out/spv/functions.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 20
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %16 "main"
OpExecutionMode %16 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 2.0
%5 = OpConstant %4 0.5
%6 = OpTypeVector %4 2
%9 = OpTypeFunction %6
%17 = OpTypeFunction %2
%8 = OpFunction %6 None %9
%7 = OpLabel
OpBranch %10
%10 = OpLabel
%11 = OpCompositeConstruct %6 %3 %3
%12 = OpCompositeConstruct %6 %5 %5
%13 = OpCompositeConstruct %6 %5 %5
%14 = OpExtInst %6 %1 Fma %11 %12 %13
OpReturnValue %14
OpFunctionEnd
%16 = OpFunction %2 None %17
%15 = OpLabel
OpBranch %18
%18 = OpLabel
%19 = OpFunctionCall %6 %8
OpReturn
OpFunctionEnd
12 changes: 12 additions & 0 deletions tests/out/wgsl/functions.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);
return fma(a, b, c);
}

[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
let _e0 = test_fma();
return;
}
5 changes: 5 additions & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,11 @@ fn convert_wgsl() {
"operators",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"functions",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("functions-webgl", Targets::GLSL),
(
"interpolate",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
Expand Down