From 462a34d1532d01f0f33242ea7ee0b12875dc46fa Mon Sep 17 00:00:00 2001 From: Avi Weinstock Date: Mon, 12 Dec 2022 14:29:55 -0500 Subject: [PATCH] More thorough typechecking of the struct returned by `atomicCompareExchangeWeak`. --- src/front/wgsl/mod.rs | 16 ++++++++++------ src/lib.rs | 6 +----- src/proc/typifier.rs | 12 +----------- src/valid/expression.rs | 35 +++++++++++++++++++++++------------ src/valid/function.rs | 35 +++++++++++++++++++++++------------ src/valid/mod.rs | 17 +++++++++++++++++ 6 files changed, 75 insertions(+), 46 deletions(-) diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 91feb9dfd6..118dfeb17c 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1626,9 +1626,14 @@ impl Parser { let expression = match *ctx.resolve_type(value)? { crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult { - kind, - width, - comparison: None, + ty: ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + NagaSpan::UNDEFINED, + ), + comparison: false, }, _ => return Err(Error::InvalidAtomicOperandType(value_span)), }; @@ -1898,9 +1903,8 @@ impl Parser { NagaSpan::UNDEFINED, ); crate::Expression::AtomicResult { - kind, - width, - comparison: Some(struct_ty), + ty: struct_ty, + comparison: true, } } _ => return Err(Error::InvalidAtomicOperandType(value_span)), diff --git a/src/lib.rs b/src/lib.rs index 3c61fb2e69..ee02546835 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1401,11 +1401,7 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { - kind: ScalarKind, - width: Bytes, - comparison: Option>, - }, + AtomicResult { ty: Handle, comparison: bool }, /// Get the length of an array. /// The expression must resolve to a pointer to an array with a dynamic size. /// diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 7669708e8a..9df538cc2b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -644,17 +644,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, - crate::Expression::AtomicResult { - kind, - width, - comparison, - } => { - if let Some(struct_ty) = comparison { - TypeResolution::Handle(struct_ty) - } else { - TypeResolution::Value(Ti::Scalar { kind, width }) - } - } + crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(), crate::Expression::Relational { fun, argument } => match fun { diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 20a6237d97..a3afc0535a 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,5 +1,8 @@ #[cfg(feature = "validate")] -use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags}; +use super::{ + compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages, + TypeFlags, +}; #[cfg(feature = "validate")] use crate::arena::UniqueArena; @@ -115,8 +118,8 @@ pub enum ExpressionError { WrongArgumentCount(crate::MathFunction), #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")] InvalidArgumentType(crate::MathFunction, u32, Handle), - #[error("Atomic result type can't be {0:?} of {1} bytes")] - InvalidAtomicResultType(crate::ScalarKind, crate::Bytes), + #[error("Atomic result type can't be {0:?}")] + InvalidAtomicResultType(Handle), #[error("Shader requires capability {0:?}")] MissingCapabilities(super::Capabilities), } @@ -1389,19 +1392,27 @@ impl super::Validator { ShaderStages::all() } E::CallResult(function) => other_infos[function.index()].available_stages, - E::AtomicResult { - kind, - width, - comparison: _, - } => { - let good = match kind { - crate::ScalarKind::Uint | crate::ScalarKind::Sint => { - self.check_width(kind, width) + E::AtomicResult { ty, comparison } => { + let scalar_predicate = |ty: &crate::TypeInner| match ty { + &crate::TypeInner::Scalar { + kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint), + width, + } => self.check_width(kind, width), + _ => false, + }; + let good = match &module.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + &module.types, + members, + scalar_predicate, + ) } _ => false, }; if !good { - return Err(ExpressionError::InvalidAtomicResultType(kind, width)); + return Err(ExpressionError::InvalidAtomicResultType(ty)); } ShaderStages::all() } diff --git a/src/valid/function.rs b/src/valid/function.rs index ee02315e7a..0f0a7b89f5 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -2,6 +2,9 @@ use crate::arena::{Arena, UniqueArena}; use crate::arena::{BadHandle, Handle}; +#[cfg(feature = "validate")] +use super::validate_atomic_compare_exchange_struct; + use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, @@ -363,18 +366,26 @@ impl super::Validator { .into_other()); } match context.expressions[result] { - //TODO: does the result of an atomicCompareExchange need additional validation, or does the existing validation for - // the struct type it returns suffice? - crate::Expression::AtomicResult { - kind, - width, - comparison: Some(_), - } if kind == ptr_kind && width == ptr_width => {} - crate::Expression::AtomicResult { - kind, - width, - comparison: None, - } if kind == ptr_kind && width == ptr_width => {} + crate::Expression::AtomicResult { ty, comparison } + if { + let scalar_predicate = |ty: &crate::TypeInner| { + *ty == crate::TypeInner::Scalar { + kind: ptr_kind, + width: ptr_width, + } + }; + match &context.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + context.types, + members, + scalar_predicate, + ) + } + _ => false, + } + } => {} _ => { return Err(AtomicError::ResultTypeMismatch(result) .with_span_handle(result, context.expressions) diff --git a/src/valid/mod.rs b/src/valid/mod.rs index be27316299..17f3d8da0a 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -412,3 +412,20 @@ impl Validator { Ok(mod_info) } } + +#[cfg(feature = "validate")] +fn validate_atomic_compare_exchange_struct( + types: &UniqueArena, + members: &[crate::StructMember], + scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool, +) -> bool { + members.len() == 2 + && members[0].name.as_deref() == Some("old_value") + && scalar_predicate(&types[members[0].ty].inner) + && members[1].name.as_deref() == Some("exchanged") + && types[members[1].ty].inner + == crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + } +}