Skip to content

Commit

Permalink
Adds subgroup_operations tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lichtso authored and exrook committed Oct 22, 2023
1 parent 5d1e4ee commit da80ce8
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/tests/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod scissor_tests;
mod shader;
mod shader_primitive_index;
mod shader_view_format;
mod subgroup_operations;
mod texture_bounds;
mod transfer;
mod vertex_indices;
Expand Down
108 changes: 108 additions & 0 deletions tests/tests/subgroup_operations/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::{borrow::Cow, num::NonZeroU64};

use wasm_bindgen_test::*;
use wgpu_test::{initialize_test, TestParameters};

const THREAD_COUNT: u64 = 128;

#[test]
#[wasm_bindgen_test]
fn subgroup_operations() {
initialize_test(
TestParameters::default()
.features(wgpu::Features::SUBGROUP_OPERATIONS)
.limits(wgpu::Limits::downlevel_defaults()),
|ctx| {
let device = &ctx.device;

let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind group layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(
THREAD_COUNT * std::mem::size_of::<u32>() as u64,
),
},
count: None,
}],
});

let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
});

let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("main"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});

let compute_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &cs_module,
entry_point: "main",
});

let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buffer.as_entire_binding(),
}],
layout: &bind_group_layout,
label: Some("bind group"),
});

let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups(THREAD_COUNT as u32, 1, 1);
}

let mapping_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Mapping buffer"),
size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(
&storage_buffer,
0,
&mapping_buffer,
0,
THREAD_COUNT * std::mem::size_of::<u32>() as u64,
);
ctx.queue.submit(Some(encoder.finish()));

mapping_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| ());
ctx.device.poll(wgpu::Maintain::Wait);
let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range();
let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view);
assert_eq!(result, &[27; THREAD_COUNT as usize]);
},
)
}
109 changes: 109 additions & 0 deletions tests/tests/subgroup_operations/shader.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
@group(0)
@binding(0)
var<storage, read_write> storage_buffer: array<u32>;

@compute
@workgroup_size(128)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
) {
var passed = 0u;
var expected: u32;

passed += u32(num_subgroups == 128u / subgroup_size);
passed += u32(subgroup_id == global_id.x / subgroup_size);
passed += u32(subgroup_invocation_id == global_id.x % subgroup_size);

var expected_ballot = vec4<u32>(0u);
for(var i = 0u; i < subgroup_size; i += 1u) {
expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u);
}
passed += u32(dot(vec4<u32>(1u), vec4<u32>(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u);

passed += u32(subgroupAll(true));
passed += u32(!subgroupAll(subgroup_invocation_id != 0u));

passed += u32(subgroupAny(subgroup_invocation_id == 0u));
passed += u32(!subgroupAny(false));

expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupAdd(global_id.x + 1u) == expected);

expected = 1u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupMul(global_id.x + 1u) == expected);

expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u);
}
passed += u32(subgroupMax(global_id.x + 1u) == expected);

expected = 0xFFFFFFFFu;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u);
}
passed += u32(subgroupMin(global_id.x + 1u) == expected);

expected = 0xFFFFFFFFu;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected &= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupAnd(global_id.x + 1u) == expected);

expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected |= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupOr(global_id.x + 1u) == expected);

expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected ^= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupXor(global_id.x + 1u) == expected);

expected = 0u;
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupPrefixExclusiveAdd(global_id.x + 1u) == expected);

expected = 1u;
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupPrefixExclusiveMul(global_id.x + 1u) == expected);

expected = 0u;
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupPrefixInclusiveAdd(global_id.x + 1u) == expected);

expected = 1u;
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
passed += u32(subgroupPrefixInclusiveMul(global_id.x + 1u) == expected);

passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u);
passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u);
passed += u32(subgroupBroadcast(subgroup_invocation_id, 4u) == 4u);
passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
passed += u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
passed += u32(subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
passed += u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));

storage_buffer[global_id.x] = passed;
}

0 comments on commit da80ce8

Please sign in to comment.