Skip to content

Commit

Permalink
add storage_texture option to as_bind_group macro
Browse files Browse the repository at this point in the history
Changes:

- Add storage_texture option to as_bind_group macro
- Use it to generate the bind group layout for the compute shader example
  • Loading branch information
HugoPeters1024 committed Oct 4, 2023
1 parent b6ead2b commit 040cb3e
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 31 deletions.
117 changes: 117 additions & 0 deletions crates/bevy_render/macros/src/as_bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use syn::{

const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
Expand All @@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
enum BindingType {
Uniform,
Texture,
StorageTexture,
Sampler,
Storage,
}
Expand Down Expand Up @@ -139,6 +141,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
BindingType::Uniform
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
BindingType::Texture
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
BindingType::StorageTexture
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
BindingType::Sampler
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
Expand Down Expand Up @@ -262,6 +266,43 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
}
});
}
BindingType::StorageTexture => {
let StorageTextureAttrs {
dimension,
image_format,
access,
visibility,
} = get_storage_texture_binding_attr(nested_meta_items)?;

let visibility =
visibility.hygienic_quote(&quote! { #render_path::render_resource });

let fallback_image = get_fallback_image(&render_path, dimension);

binding_impls.push(quote! {
#render_path::render_resource::OwnedBindingResource::TextureView({
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
if let Some(handle) = handle {
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
} else {
#fallback_image.texture_view.clone()
}
})
});

binding_layouts.push(quote! {
#render_path::render_resource::BindGroupLayoutEntry {
binding: #binding_index,
visibility: #visibility,
ty: #render_path::render_resource::BindingType::StorageTexture {
access: #render_path::render_resource::StorageTextureAccess::#access,
format: #render_path::render_resource::TextureFormat::#image_format,
view_dimension: #render_path::render_resource::#dimension,
},
count: None,
}
});
}
BindingType::Texture => {
let TextureAttrs {
dimension,
Expand Down Expand Up @@ -593,6 +634,10 @@ impl ShaderStageVisibility {
fn vertex_fragment() -> Self {
Self::Flags(VisibilityFlags::vertex_fragment())
}

fn compute() -> Self {
Self::Flags(VisibilityFlags::compute())
}
}

impl VisibilityFlags {
Expand All @@ -603,6 +648,13 @@ impl VisibilityFlags {
..Default::default()
}
}

fn compute() -> Self {
Self {
compute: true,
..Default::default()
}
}
}

impl ShaderStageVisibility {
Expand Down Expand Up @@ -749,7 +801,72 @@ impl Default for TextureAttrs {
}
}

struct StorageTextureAttrs {
dimension: BindingTextureDimension,
// Parsing of the image_format parameter is deferred to the type checker,
// which will error if the format is not member of the TextureFormat enum.
image_format: proc_macro2::TokenStream,
// Parsing of the access parameter is deferred to the type checker,
// which will error if the access is not member of the StorageTextureAccess enum.
access: proc_macro2::TokenStream,
visibility: ShaderStageVisibility,
}

impl Default for StorageTextureAttrs {
fn default() -> Self {
Self {
dimension: Default::default(),
image_format: quote! { Rgba8Unorm },
access: quote! { ReadWrite },
visibility: ShaderStageVisibility::compute(),
}
}
}

fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
let mut storage_texture_attrs = StorageTextureAttrs::default();

for meta in metas {
use syn::Meta::{List, NameValue};
match meta {
// Parse #[storage_texture(0, dimension = "...")].
NameValue(m) if m.path == DIMENSION => {
let value = get_lit_str(DIMENSION, &m.value)?;
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
}
// Parse #[storage_texture(0, format = ...))].
NameValue(m) if m.path == IMAGE_FORMAT => {
storage_texture_attrs.image_format = m.value.into_token_stream();
}
// Parse #[storage_texture(0, access = ...))].
NameValue(m) if m.path == ACCESS => {
storage_texture_attrs.access = m.value.into_token_stream();
}
// Parse #[storage_texture(0, visibility(...))].
List(m) if m.path == VISIBILITY => {
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
}
NameValue(m) => {
return Err(Error::new_spanned(
m.path,
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
));
}
_ => {
return Err(Error::new_spanned(
meta,
"Not a name value pair: `foo = \"...\"`",
));
}
}
}

Ok(storage_texture_attrs)
}

const DIMENSION: Symbol = Symbol("dimension");
const IMAGE_FORMAT: Symbol = Symbol("image_format");
const ACCESS: Symbol = Symbol("access");
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
const FILTERABLE: Symbol = Symbol("filterable");
const MULTISAMPLED: Symbol = Symbol("multisampled");
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_render/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {

#[proc_macro_derive(
AsBindGroup,
attributes(uniform, texture, sampler, bind_group_data, storage)
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
)]
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
Expand Down
16 changes: 16 additions & 0 deletions crates/bevy_render/src/render_resource/bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ impl Deref for BindGroup {
/// values: Vec<f32>,
/// #[storage(4, read_only, buffer)]
/// buffer: Buffer,
/// #[storage_texture(5)]
/// storage_texture: Handle<Image>,
/// }
/// ```
///
Expand All @@ -96,6 +98,7 @@ impl Deref for BindGroup {
/// @group(1) @binding(1) var color_texture: texture_2d<f32>;
/// @group(1) @binding(2) var color_sampler: sampler;
/// @group(1) @binding(3) var<storage> values: array<f32>;
/// @group(1) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
/// ```
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
/// are generally bound to group 1.
Expand All @@ -122,6 +125,19 @@ impl Deref for BindGroup {
/// | `multisampled` = ... | `true`, `false` | `false` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
///
/// * `storage_texture(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
/// [`None`], the [`FallbackImage`] resource will be used instead.
///
/// | Arguments | Values | Default |
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
///
/// * `sampler(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`](crate::render_resource::Sampler) GPU
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
Expand Down
53 changes: 23 additions & 30 deletions examples/shader/compute_shader_game_of_life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use bevy::{
render_graph::{self, RenderGraph},
render_resource::*,
renderer::{RenderContext, RenderDevice},
texture::FallbackImage,
Render, RenderApp, RenderSet,
},
window::WindowPlugin,
Expand Down Expand Up @@ -63,7 +64,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
});
commands.spawn(Camera2dBundle::default());

commands.insert_resource(GameOfLifeImage(image));
commands.insert_resource(GameOfLifeImage { texture: image });
}

pub struct GameOfLifeComputePlugin;
Expand Down Expand Up @@ -93,28 +94,34 @@ impl Plugin for GameOfLifeComputePlugin {
}
}

#[derive(Resource, Clone, Deref, ExtractResource)]
struct GameOfLifeImage(Handle<Image>);
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
struct GameOfLifeImage {
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
texture: Handle<Image>,
}

#[derive(Resource)]
struct GameOfLifeImageBindGroup(BindGroup);
struct GameOfLifeImageBindGroup(PreparedBindGroup<()>);

fn prepare_bind_group(
mut commands: Commands,
pipeline: Res<GameOfLifePipeline>,
gpu_images: Res<RenderAssets<Image>>,
game_of_life_image: Res<GameOfLifeImage>,
render_device: Res<RenderDevice>,
fallback_image: Res<FallbackImage>,
) {
let view = gpu_images.get(&game_of_life_image.0).unwrap();
let bind_group = render_device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &pipeline.texture_bind_group_layout,
entries: &[BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(&view.texture_view),
}],
});
// When `AsBindGroup` is derived, `as_bind_group` will never return
// an error, so we can safely unwrap.
let bind_group = game_of_life_image
.as_bind_group(
&pipeline.texture_bind_group_layout,
&render_device,
&gpu_images,
&fallback_image,
)
.ok()
.unwrap();
commands.insert_resource(GameOfLifeImageBindGroup(bind_group));
}

Expand All @@ -127,22 +134,8 @@ pub struct GameOfLifePipeline {

impl FromWorld for GameOfLifePipeline {
fn from_world(world: &mut World) -> Self {
let texture_bind_group_layout =
world
.resource::<RenderDevice>()
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries: &[BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::StorageTexture {
access: StorageTextureAccess::ReadWrite,
format: TextureFormat::Rgba8Unorm,
view_dimension: TextureViewDimension::D2,
},
count: None,
}],
});
let render_device = world.resource::<RenderDevice>();
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
let shader = world
.resource::<AssetServer>()
.load("shaders/game_of_life.wgsl");
Expand Down Expand Up @@ -229,7 +222,7 @@ impl render_graph::Node for GameOfLifeNode {
.command_encoder()
.begin_compute_pass(&ComputePassDescriptor::default());

pass.set_bind_group(0, texture_bind_group, &[]);
pass.set_bind_group(0, &texture_bind_group.bind_group, &[]);

// select the pipeline based on the current state
match self.state {
Expand Down

0 comments on commit 040cb3e

Please sign in to comment.