Skip to content

Commit

Permalink
More thorough typechecking of the struct returned by `atomicCompareEx…
Browse files Browse the repository at this point in the history
…changeWeak`.
  • Loading branch information
aweinstock314 committed Dec 12, 2022
1 parent 8ce2ed7 commit 8e06c82
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 45 deletions.
16 changes: 10 additions & 6 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1626,9 +1626,14 @@ impl Parser {

let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult {
kind,
width,
comparison: None,
ty: ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
),
comparison: false,
},
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
};
Expand Down Expand Up @@ -1898,9 +1903,8 @@ impl Parser {
NagaSpan::UNDEFINED,
);
crate::Expression::AtomicResult {
kind,
width,
comparison: Some(struct_ty),
ty: struct_ty,
comparison: true,
}
}
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
Expand Down
6 changes: 1 addition & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1401,11 +1401,7 @@ pub enum Expression {
/// Result of calling another function.
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult {
kind: ScalarKind,
width: Bytes,
comparison: Option<Handle<Type>>,
},
AtomicResult { ty: Handle<Type>, comparison: bool },
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
///
Expand Down
12 changes: 1 addition & 11 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,17 +644,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult {
kind,
width,
comparison,
} => {
if let Some(struct_ty) = comparison {
TypeResolution::Handle(struct_ty)
} else {
TypeResolution::Value(Ti::Scalar { kind, width })
}
}
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
Expand Down
31 changes: 20 additions & 11 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags};
#[cfg(feature = "validate")]
use crate::arena::UniqueArena;
use crate::valid::validate_atomic_compare_exchange_struct;

use crate::{
arena::{BadHandle, Handle},
Expand Down Expand Up @@ -115,8 +116,8 @@ pub enum ExpressionError {
WrongArgumentCount(crate::MathFunction),
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
#[error("Atomic result type can't be {0:?} of {1} bytes")]
InvalidAtomicResultType(crate::ScalarKind, crate::Bytes),
#[error("Atomic result type can't be {0:?}")]
InvalidAtomicResultType(Handle<crate::Type>),
#[error("Shader requires capability {0:?}")]
MissingCapabilities(super::Capabilities),
}
Expand Down Expand Up @@ -1389,19 +1390,27 @@ impl super::Validator {
ShaderStages::all()
}
E::CallResult(function) => other_infos[function.index()].available_stages,
E::AtomicResult {
kind,
width,
comparison: _,
} => {
let good = match kind {
crate::ScalarKind::Uint | crate::ScalarKind::Sint => {
self.check_width(kind, width)
E::AtomicResult { ty, comparison } => {
let scalar_predicate = |ty: &crate::TypeInner| match ty {
&crate::TypeInner::Scalar {
kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint),
width,
} => self.check_width(kind, width),
_ => false,
};
let good = match &module.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
&module.types,
members,
scalar_predicate,
)
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidAtomicResultType(kind, width));
return Err(ExpressionError::InvalidAtomicResultType(ty));
}
ShaderStages::all()
}
Expand Down
33 changes: 21 additions & 12 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(feature = "validate")]
use crate::arena::{Arena, UniqueArena};
use crate::arena::{BadHandle, Handle};
use crate::valid::validate_atomic_compare_exchange_struct;

use super::{
analyzer::{UniformityDisruptor, UniformityRequirements},
Expand Down Expand Up @@ -363,18 +364,26 @@ impl super::Validator {
.into_other());
}
match context.expressions[result] {
//TODO: does the result of an atomicCompareExchange need additional validation, or does the existing validation for
// the struct type it returns suffice?
crate::Expression::AtomicResult {
kind,
width,
comparison: Some(_),
} if kind == ptr_kind && width == ptr_width => {}
crate::Expression::AtomicResult {
kind,
width,
comparison: None,
} if kind == ptr_kind && width == ptr_width => {}
crate::Expression::AtomicResult { ty, comparison }
if {
let scalar_predicate = |ty: &crate::TypeInner| {
*ty == crate::TypeInner::Scalar {
kind: ptr_kind,
width: ptr_width,
}
};
match &context.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
context.types,
members,
scalar_predicate,
)
}
_ => false,
}
} => {}
_ => {
return Err(AtomicError::ResultTypeMismatch(result)
.with_span_handle(result, context.expressions)
Expand Down
16 changes: 16 additions & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,19 @@ impl Validator {
Ok(mod_info)
}
}

fn validate_atomic_compare_exchange_struct(
types: &UniqueArena<crate::Type>,
members: &[crate::StructMember],
scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
members.len() == 2
&& members[0].name.as_deref() == Some("old_value")
&& scalar_predicate(&types[members[0].ty].inner)
&& members[1].name.as_deref() == Some("exchanged")
&& types[members[1].ty].inner
== crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
}
}

0 comments on commit 8e06c82

Please sign in to comment.