diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 10fd5d72aa..262548eab9 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2079,8 +2079,50 @@ impl<'w> BlockContext<'w> { value_id, ) } - crate::AtomicFunction::Exchange { compare: Some(_) } => { - return Err(Error::FeatureNotImplemented("atomic CompareExchange")); + crate::AtomicFunction::Exchange { compare: Some(cmp) } => { + let scalar_type_id = match *value_inner { + crate::TypeInner::Scalar { kind, width } => { + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind, + width, + pointer_space: None, + })) + } + _ => unimplemented!(), + }; + let bool_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + pointer_space: None, + })); + + let cas_result_id = self.gen_id(); + let equality_result_id = self.gen_id(); + let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange); + cas_instr.set_type(scalar_type_id); + cas_instr.set_result(cas_result_id); + cas_instr.add_operand(pointer_id); + cas_instr.add_operand(scope_constant_id); + cas_instr.add_operand(semantics_id); // semantics if equal + cas_instr.add_operand(semantics_id); // semantics if not equal + cas_instr.add_operand(value_id); + cas_instr.add_operand(self.cached[cmp]); + block.body.push(cas_instr); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + equality_result_id, + cas_result_id, + self.cached[cmp], + )); + Instruction::composite_construct( + result_type_id, + id, + &[cas_result_id, equality_result_id], + ) } }; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 2873e6c73c..118dfeb17c 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1626,8 +1626,13 @@ impl Parser { let expression = match *ctx.resolve_type(value)? { crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult { - kind, - width, + ty: ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + NagaSpan::UNDEFINED, + ), comparison: false, }, _ => return Err(Error::InvalidAtomicOperandType(value_span)), @@ -1857,9 +1862,48 @@ impl Parser { let expression = match *ctx.resolve_type(value)? { crate::TypeInner::Scalar { kind, width } => { + let bool_ty = ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }, + }, + NagaSpan::UNDEFINED, + ); + let scalar_ty = ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + NagaSpan::UNDEFINED, + ); + let struct_ty = ctx.types.insert( + crate::Type { + name: Some("__atomic_compare_exchange_result".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + }, + NagaSpan::UNDEFINED, + ); crate::Expression::AtomicResult { - kind, - width, + ty: struct_ty, comparison: true, } } diff --git a/src/lib.rs b/src/lib.rs index e122d1224c..ee02546835 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1401,11 +1401,7 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { - kind: ScalarKind, - width: Bytes, - comparison: bool, - }, + AtomicResult { ty: Handle, comparison: bool }, /// Get the length of an array. /// The expression must resolve to a pointer to an array with a dynamic size. /// diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9a5922ea76..9df538cc2b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -644,21 +644,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, - crate::Expression::AtomicResult { - kind, - width, - comparison, - } => { - if comparison { - TypeResolution::Value(Ti::Vector { - size: crate::VectorSize::Bi, - kind, - width, - }) - } else { - TypeResolution::Value(Ti::Scalar { kind, width }) - } - } + crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(), crate::Expression::Relational { fun, argument } => match fun { diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 20a6237d97..a3afc0535a 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,5 +1,8 @@ #[cfg(feature = "validate")] -use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags}; +use super::{ + compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages, + TypeFlags, +}; #[cfg(feature = "validate")] use crate::arena::UniqueArena; @@ -115,8 +118,8 @@ pub enum ExpressionError { WrongArgumentCount(crate::MathFunction), #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")] InvalidArgumentType(crate::MathFunction, u32, Handle), - #[error("Atomic result type can't be {0:?} of {1} bytes")] - InvalidAtomicResultType(crate::ScalarKind, crate::Bytes), + #[error("Atomic result type can't be {0:?}")] + InvalidAtomicResultType(Handle), #[error("Shader requires capability {0:?}")] MissingCapabilities(super::Capabilities), } @@ -1389,19 +1392,27 @@ impl super::Validator { ShaderStages::all() } E::CallResult(function) => other_infos[function.index()].available_stages, - E::AtomicResult { - kind, - width, - comparison: _, - } => { - let good = match kind { - crate::ScalarKind::Uint | crate::ScalarKind::Sint => { - self.check_width(kind, width) + E::AtomicResult { ty, comparison } => { + let scalar_predicate = |ty: &crate::TypeInner| match ty { + &crate::TypeInner::Scalar { + kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint), + width, + } => self.check_width(kind, width), + _ => false, + }; + let good = match &module.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + &module.types, + members, + scalar_predicate, + ) } _ => false, }; if !good { - return Err(ExpressionError::InvalidAtomicResultType(kind, width)); + return Err(ExpressionError::InvalidAtomicResultType(ty)); } ShaderStages::all() } diff --git a/src/valid/function.rs b/src/valid/function.rs index 5684f670fe..0f0a7b89f5 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -2,6 +2,9 @@ use crate::arena::{Arena, UniqueArena}; use crate::arena::{BadHandle, Handle}; +#[cfg(feature = "validate")] +use super::validate_atomic_compare_exchange_struct; + use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, @@ -363,12 +366,26 @@ impl super::Validator { .into_other()); } match context.expressions[result] { - //TODO: support atomic result with comparison - crate::Expression::AtomicResult { - kind, - width, - comparison: false, - } if kind == ptr_kind && width == ptr_width => {} + crate::Expression::AtomicResult { ty, comparison } + if { + let scalar_predicate = |ty: &crate::TypeInner| { + *ty == crate::TypeInner::Scalar { + kind: ptr_kind, + width: ptr_width, + } + }; + match &context.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + context.types, + members, + scalar_predicate, + ) + } + _ => false, + } + } => {} _ => { return Err(AtomicError::ResultTypeMismatch(result) .with_span_handle(result, context.expressions) diff --git a/src/valid/mod.rs b/src/valid/mod.rs index be27316299..17f3d8da0a 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -412,3 +412,20 @@ impl Validator { Ok(mod_info) } } + +#[cfg(feature = "validate")] +fn validate_atomic_compare_exchange_struct( + types: &UniqueArena, + members: &[crate::StructMember], + scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool, +) -> bool { + members.len() == 2 + && members[0].name.as_deref() == Some("old_value") + && scalar_predicate(&types[members[0].ty].inner) + && members[1].name.as_deref() == Some("exchanged") + && types[members[1].ty].inner + == crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + } +} diff --git a/tests/in/atomicCompareExchange.wgsl b/tests/in/atomicCompareExchange.wgsl new file mode 100644 index 0000000000..35b91aaf7f --- /dev/null +++ b/tests/in/atomicCompareExchange.wgsl @@ -0,0 +1,34 @@ +let SIZE: u32 = 128u; + +@group(0) @binding(0) +var arr_i32: array, SIZE>; +@group(0) @binding(1) +var arr_u32: array, SIZE>; + +@compute @workgroup_size(1) +fn test_atomic_compare_exchange_i32() { + for(var i = 0u; i < SIZE; i++) { + var old = atomicLoad(&arr_i32[i]); + var exchanged = false; + while(!exchanged) { + let new_ = bitcast(bitcast(old) + 1.0); + let result = atomicCompareExchangeWeak(&arr_i32[i], old, new_); + old = result.old_value; + exchanged = result.exchanged; + } + } +} + +@compute @workgroup_size(1) +fn test_atomic_compare_exchange_u32() { + for(var i = 0u; i < SIZE; i++) { + var old = atomicLoad(&arr_u32[i]); + var exchanged = false; + while(!exchanged) { + let new_ = bitcast(bitcast(old) + 1.0); + let result = atomicCompareExchangeWeak(&arr_u32[i], old, new_); + old = result.old_value; + exchanged = result.exchanged; + } + } +} diff --git a/tests/out/spv/atomicCompareExchange.spvasm b/tests/out/spv/atomicCompareExchange.spvasm new file mode 100644 index 0000000000..9587f9aaa6 --- /dev/null +++ b/tests/out/spv/atomicCompareExchange.spvasm @@ -0,0 +1,188 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 116 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %31 "test_atomic_compare_exchange_i32" +OpEntryPoint GLCompute %79 "test_atomic_compare_exchange_u32" +OpExecutionMode %31 LocalSize 1 1 1 +OpExecutionMode %79 LocalSize 1 1 1 +OpDecorate %12 ArrayStride 4 +OpDecorate %13 ArrayStride 4 +OpMemberDecorate %14 0 Offset 0 +OpMemberDecorate %14 1 Offset 4 +OpMemberDecorate %15 0 Offset 0 +OpMemberDecorate %15 1 Offset 4 +OpDecorate %16 DescriptorSet 0 +OpDecorate %16 Binding 0 +OpDecorate %17 Block +OpMemberDecorate %17 0 Offset 0 +OpDecorate %19 DescriptorSet 0 +OpDecorate %19 Binding 1 +OpDecorate %20 Block +OpMemberDecorate %20 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpConstant %4 128 +%5 = OpConstant %4 0 +%6 = OpConstant %4 1 +%8 = OpTypeBool +%7 = OpConstantFalse %8 +%10 = OpTypeFloat 32 +%9 = OpConstant %10 1.0 +%11 = OpTypeInt 32 1 +%12 = OpTypeArray %11 %3 +%13 = OpTypeArray %4 %3 +%14 = OpTypeStruct %11 %8 +%15 = OpTypeStruct %4 %8 +%17 = OpTypeStruct %12 +%18 = OpTypePointer StorageBuffer %17 +%16 = OpVariable %18 StorageBuffer +%20 = OpTypeStruct %13 +%21 = OpTypePointer StorageBuffer %20 +%19 = OpVariable %21 StorageBuffer +%23 = OpTypePointer Function %4 +%25 = OpTypePointer Function %11 +%26 = OpConstantNull %11 +%28 = OpTypePointer Function %8 +%29 = OpConstantNull %8 +%32 = OpTypeFunction %2 +%33 = OpTypePointer StorageBuffer %12 +%35 = OpTypePointer StorageBuffer %13 +%46 = OpTypePointer StorageBuffer %11 +%49 = OpConstant %11 1 +%50 = OpConstant %4 64 +%75 = OpConstantNull %4 +%77 = OpConstantNull %8 +%91 = OpTypePointer StorageBuffer %4 +%31 = OpFunction %2 None %32 +%30 = OpLabel +%22 = OpVariable %23 Function %5 +%24 = OpVariable %25 Function %26 +%27 = OpVariable %28 Function %29 +%34 = OpAccessChain %33 %16 %5 +OpBranch %36 +%36 = OpLabel +OpBranch %37 +%37 = OpLabel +OpLoopMerge %38 %40 None +OpBranch %39 +%39 = OpLabel +%41 = OpLoad %4 %22 +%42 = OpULessThan %8 %41 %3 +OpSelectionMerge %43 None +OpBranchConditional %42 %43 %44 +%44 = OpLabel +OpBranch %38 +%43 = OpLabel +%45 = OpLoad %4 %22 +%47 = OpAccessChain %46 %34 %45 +%48 = OpAtomicLoad %11 %47 %49 %50 +OpStore %24 %48 +OpStore %27 %7 +OpBranch %51 +%51 = OpLabel +OpLoopMerge %52 %54 None +OpBranch %53 +%53 = OpLabel +%55 = OpLoad %8 %27 +%56 = OpLogicalNot %8 %55 +OpSelectionMerge %57 None +OpBranchConditional %56 %57 %58 +%58 = OpLabel +OpBranch %52 +%57 = OpLabel +%59 = OpLoad %11 %24 +%60 = OpBitcast %10 %59 +%61 = OpFAdd %10 %60 %9 +%62 = OpBitcast %11 %61 +%63 = OpLoad %4 %22 +%64 = OpLoad %11 %24 +%66 = OpAccessChain %46 %34 %63 +%67 = OpAtomicCompareExchange %11 %66 %49 %50 %50 %62 %64 +%68 = OpIEqual %8 %67 %64 +%65 = OpCompositeConstruct %14 %67 %68 +%69 = OpCompositeExtract %11 %65 0 +OpStore %24 %69 +%70 = OpCompositeExtract %8 %65 1 +OpStore %27 %70 +OpBranch %54 +%54 = OpLabel +OpBranch %51 +%52 = OpLabel +OpBranch %40 +%40 = OpLabel +%71 = OpLoad %4 %22 +%72 = OpIAdd %4 %71 %6 +OpStore %22 %72 +OpBranch %37 +%38 = OpLabel +OpReturn +OpFunctionEnd +%79 = OpFunction %2 None %32 +%78 = OpLabel +%73 = OpVariable %23 Function %5 +%74 = OpVariable %23 Function %75 +%76 = OpVariable %28 Function %77 +%80 = OpAccessChain %35 %19 %5 +OpBranch %81 +%81 = OpLabel +OpBranch %82 +%82 = OpLabel +OpLoopMerge %83 %85 None +OpBranch %84 +%84 = OpLabel +%86 = OpLoad %4 %73 +%87 = OpULessThan %8 %86 %3 +OpSelectionMerge %88 None +OpBranchConditional %87 %88 %89 +%89 = OpLabel +OpBranch %83 +%88 = OpLabel +%90 = OpLoad %4 %73 +%92 = OpAccessChain %91 %80 %90 +%93 = OpAtomicLoad %4 %92 %49 %50 +OpStore %74 %93 +OpStore %76 %7 +OpBranch %94 +%94 = OpLabel +OpLoopMerge %95 %97 None +OpBranch %96 +%96 = OpLabel +%98 = OpLoad %8 %76 +%99 = OpLogicalNot %8 %98 +OpSelectionMerge %100 None +OpBranchConditional %99 %100 %101 +%101 = OpLabel +OpBranch %95 +%100 = OpLabel +%102 = OpLoad %4 %74 +%103 = OpBitcast %10 %102 +%104 = OpFAdd %10 %103 %9 +%105 = OpBitcast %4 %104 +%106 = OpLoad %4 %73 +%107 = OpLoad %4 %74 +%109 = OpAccessChain %91 %80 %106 +%110 = OpAtomicCompareExchange %4 %109 %49 %50 %50 %105 %107 +%111 = OpIEqual %8 %110 %107 +%108 = OpCompositeConstruct %15 %110 %111 +%112 = OpCompositeExtract %4 %108 0 +OpStore %74 %112 +%113 = OpCompositeExtract %8 %108 1 +OpStore %76 %113 +OpBranch %97 +%97 = OpLabel +OpBranch %94 +%95 = OpLabel +OpBranch %85 +%85 = OpLabel +%114 = OpLoad %4 %73 +%115 = OpIAdd %4 %114 %6 +OpStore %73 %115 +OpBranch %82 +%83 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/atomicCompareExchange.wgsl b/tests/out/wgsl/atomicCompareExchange.wgsl new file mode 100644 index 0000000000..b1e2d06a4a --- /dev/null +++ b/tests/out/wgsl/atomicCompareExchange.wgsl @@ -0,0 +1,92 @@ +struct gen___atomic_compare_exchange_result { + old_value: i32, + exchanged: bool, +} + +struct gen___atomic_compare_exchange_result_1 { + old_value: u32, + exchanged: bool, +} + +let SIZE: u32 = 128u; + +@group(0) @binding(0) +var arr_i32_: array,SIZE>; +@group(0) @binding(1) +var arr_u32_: array,SIZE>; + +@compute @workgroup_size(1, 1, 1) +fn test_atomic_compare_exchange_i32_() { + var i: u32 = 0u; + var old: i32; + var exchanged: bool; + + loop { + let _e5 = i; + if (_e5 < SIZE) { + } else { + break; + } + let _e10 = i; + let _e12 = atomicLoad((&arr_i32_[_e10])); + old = _e12; + exchanged = false; + loop { + let _e16 = exchanged; + if !(_e16) { + } else { + break; + } + let _e18 = old; + let new_ = bitcast((bitcast(_e18) + 1.0)); + let _e23 = i; + let _e25 = old; + let _e26 = atomicCompareExchangeWeak((&arr_i32_[_e23]), _e25, new_); + old = _e26.old_value; + exchanged = _e26.exchanged; + } + continuing { + let _e7 = i; + i = (_e7 + 1u); + } + } + return; +} + +@compute @workgroup_size(1, 1, 1) +fn test_atomic_compare_exchange_u32_() { + var i_1: u32 = 0u; + var old_1: u32; + var exchanged_1: bool; + + loop { + let _e5 = i_1; + if (_e5 < SIZE) { + } else { + break; + } + let _e10 = i_1; + let _e12 = atomicLoad((&arr_u32_[_e10])); + old_1 = _e12; + exchanged_1 = false; + loop { + let _e16 = exchanged_1; + if !(_e16) { + } else { + break; + } + let _e18 = old_1; + let new_1 = bitcast((bitcast(_e18) + 1.0)); + let _e23 = i_1; + let _e25 = old_1; + let _e26 = atomicCompareExchangeWeak((&arr_u32_[_e23]), _e25, new_1); + old_1 = _e26.old_value; + exchanged_1 = _e26.exchanged; + } + continuing { + let _e7 = i_1; + i_1 = (_e7 + 1u); + } + } + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 197a29692f..b437a5589d 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -492,6 +492,7 @@ fn convert_wgsl() { "access", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ("atomicCompareExchange", Targets::SPIRV | Targets::WGSL), ( "padding", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,