diff --git a/crates/bevy_render/macros/src/as_bind_group.rs b/crates/bevy_render/macros/src/as_bind_group.rs index 5602efd8939dc..6cfae011e3d01 100644 --- a/crates/bevy_render/macros/src/as_bind_group.rs +++ b/crates/bevy_render/macros/src/as_bind_group.rs @@ -11,6 +11,7 @@ use syn::{ const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform"); const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture"); +const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture"); const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler"); const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage"); const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data"); @@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data"); enum BindingType { Uniform, Texture, + StorageTexture, Sampler, Storage, } @@ -139,6 +141,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result { BindingType::Uniform } else if attr_ident == TEXTURE_ATTRIBUTE_NAME { BindingType::Texture + } else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME { + BindingType::StorageTexture } else if attr_ident == SAMPLER_ATTRIBUTE_NAME { BindingType::Sampler } else if attr_ident == STORAGE_ATTRIBUTE_NAME { @@ -262,6 +266,43 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result { } }); } + BindingType::StorageTexture => { + let StorageTextureAttrs { + dimension, + image_format, + access, + visibility, + } = get_storage_texture_binding_attr(nested_meta_items)?; + + let visibility = + visibility.hygienic_quote("e! { #render_path::render_resource }); + + let fallback_image = get_fallback_image(&render_path, dimension); + + binding_impls.push(quote! { + #render_path::render_resource::OwnedBindingResource::TextureView({ + let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into(); + if let Some(handle) = handle { + images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone() + } else { + #fallback_image.texture_view.clone() + } + }) + }); + + binding_layouts.push(quote! { + #render_path::render_resource::BindGroupLayoutEntry { + binding: #binding_index, + visibility: #visibility, + ty: #render_path::render_resource::BindingType::StorageTexture { + access: #render_path::render_resource::StorageTextureAccess::#access, + format: #render_path::render_resource::TextureFormat::#image_format, + view_dimension: #render_path::render_resource::#dimension, + }, + count: None, + } + }); + } BindingType::Texture => { let TextureAttrs { dimension, @@ -593,6 +634,10 @@ impl ShaderStageVisibility { fn vertex_fragment() -> Self { Self::Flags(VisibilityFlags::vertex_fragment()) } + + fn compute() -> Self { + Self::Flags(VisibilityFlags::compute()) + } } impl VisibilityFlags { @@ -603,6 +648,13 @@ impl VisibilityFlags { ..Default::default() } } + + fn compute() -> Self { + Self { + compute: true, + ..Default::default() + } + } } impl ShaderStageVisibility { @@ -749,7 +801,72 @@ impl Default for TextureAttrs { } } +struct StorageTextureAttrs { + dimension: BindingTextureDimension, + // Parsing of the image_format parameter is deferred to the type checker, + // which will error if the format is not member of the TextureFormat enum. + image_format: proc_macro2::TokenStream, + // Parsing of the access parameter is deferred to the type checker, + // which will error if the access is not member of the StorageTextureAccess enum. + access: proc_macro2::TokenStream, + visibility: ShaderStageVisibility, +} + +impl Default for StorageTextureAttrs { + fn default() -> Self { + Self { + dimension: Default::default(), + image_format: quote! { Rgba8Unorm }, + access: quote! { ReadWrite }, + visibility: ShaderStageVisibility::compute(), + } + } +} + +fn get_storage_texture_binding_attr(metas: Vec) -> Result { + let mut storage_texture_attrs = StorageTextureAttrs::default(); + + for meta in metas { + use syn::Meta::{List, NameValue}; + match meta { + // Parse #[storage_texture(0, dimension = "...")]. + NameValue(m) if m.path == DIMENSION => { + let value = get_lit_str(DIMENSION, &m.value)?; + storage_texture_attrs.dimension = get_texture_dimension_value(value)?; + } + // Parse #[storage_texture(0, format = ...))]. + NameValue(m) if m.path == IMAGE_FORMAT => { + storage_texture_attrs.image_format = m.value.into_token_stream(); + } + // Parse #[storage_texture(0, access = ...))]. + NameValue(m) if m.path == ACCESS => { + storage_texture_attrs.access = m.value.into_token_stream(); + } + // Parse #[storage_texture(0, visibility(...))]. + List(m) if m.path == VISIBILITY => { + storage_texture_attrs.visibility = get_visibility_flag_value(&m)?; + } + NameValue(m) => { + return Err(Error::new_spanned( + m.path, + "Not a valid name. Available attributes: `dimension`, `image_format`, `access`.", + )); + } + _ => { + return Err(Error::new_spanned( + meta, + "Not a name value pair: `foo = \"...\"`", + )); + } + } + } + + Ok(storage_texture_attrs) +} + const DIMENSION: Symbol = Symbol("dimension"); +const IMAGE_FORMAT: Symbol = Symbol("image_format"); +const ACCESS: Symbol = Symbol("access"); const SAMPLE_TYPE: Symbol = Symbol("sample_type"); const FILTERABLE: Symbol = Symbol("filterable"); const MULTISAMPLED: Symbol = Symbol("multisampled"); diff --git a/crates/bevy_render/macros/src/lib.rs b/crates/bevy_render/macros/src/lib.rs index 89eec6b220c9a..97126ba830bf4 100644 --- a/crates/bevy_render/macros/src/lib.rs +++ b/crates/bevy_render/macros/src/lib.rs @@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream { #[proc_macro_derive( AsBindGroup, - attributes(uniform, texture, sampler, bind_group_data, storage) + attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage) )] pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/crates/bevy_render/src/render_resource/bind_group.rs b/crates/bevy_render/src/render_resource/bind_group.rs index 213c39782679c..c90d23a01043a 100644 --- a/crates/bevy_render/src/render_resource/bind_group.rs +++ b/crates/bevy_render/src/render_resource/bind_group.rs @@ -86,6 +86,8 @@ impl Deref for BindGroup { /// values: Vec, /// #[storage(4, read_only, buffer)] /// buffer: Buffer, +/// #[storage_texture(5)] +/// storage_texture: Handle, /// } /// ``` /// @@ -96,6 +98,7 @@ impl Deref for BindGroup { /// @group(1) @binding(1) var color_texture: texture_2d; /// @group(1) @binding(2) var color_sampler: sampler; /// @group(1) @binding(3) var values: array; +/// @group(1) @binding(5) var storage_texture: texture_storage_2d; /// ``` /// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups /// are generally bound to group 1. @@ -122,6 +125,19 @@ impl Deref for BindGroup { /// | `multisampled` = ... | `true`, `false` | `false` | /// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` | /// +/// * `storage_texture(BINDING_INDEX, arguments)` +/// * This field's [`Handle`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture) +/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into>>`]. In practice, +/// most fields should be a [`Handle`](bevy_asset::Handle) or [`Option>`]. If the value of an [`Option>`] is +/// [`None`], the [`FallbackImage`] resource will be used instead. +/// +/// | Arguments | Values | Default | +/// |------------------------|--------------------------------------------------------------------------------------------|---------------| +/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` | +/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` | +/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` | +/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` | +/// /// * `sampler(BINDING_INDEX, arguments)` /// * This field's [`Handle`](bevy_asset::Handle) will be used to look up the matching [`Sampler`](crate::render_resource::Sampler) GPU /// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into>>`]. In practice, diff --git a/examples/shader/compute_shader_game_of_life.rs b/examples/shader/compute_shader_game_of_life.rs index 10c368771474b..7f47fb59e3a3b 100644 --- a/examples/shader/compute_shader_game_of_life.rs +++ b/examples/shader/compute_shader_game_of_life.rs @@ -11,6 +11,7 @@ use bevy::{ render_graph::{self, RenderGraph}, render_resource::*, renderer::{RenderContext, RenderDevice}, + texture::FallbackImage, Render, RenderApp, RenderSet, }, window::WindowPlugin, @@ -63,7 +64,7 @@ fn setup(mut commands: Commands, mut images: ResMut>) { }); commands.spawn(Camera2dBundle::default()); - commands.insert_resource(GameOfLifeImage(image)); + commands.insert_resource(GameOfLifeImage { texture: image }); } pub struct GameOfLifeComputePlugin; @@ -93,11 +94,14 @@ impl Plugin for GameOfLifeComputePlugin { } } -#[derive(Resource, Clone, Deref, ExtractResource)] -struct GameOfLifeImage(Handle); +#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)] +struct GameOfLifeImage { + #[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)] + texture: Handle, +} #[derive(Resource)] -struct GameOfLifeImageBindGroup(BindGroup); +struct GameOfLifeImageBindGroup(PreparedBindGroup<()>); fn prepare_bind_group( mut commands: Commands, @@ -105,16 +109,19 @@ fn prepare_bind_group( gpu_images: Res>, game_of_life_image: Res, render_device: Res, + fallback_image: Res, ) { - let view = gpu_images.get(&game_of_life_image.0).unwrap(); - let bind_group = render_device.create_bind_group(&BindGroupDescriptor { - label: None, - layout: &pipeline.texture_bind_group_layout, - entries: &[BindGroupEntry { - binding: 0, - resource: BindingResource::TextureView(&view.texture_view), - }], - }); + // When `AsBindGroup` is derived, `as_bind_group` will never return + // an error, so we can safely unwrap. + let bind_group = game_of_life_image + .as_bind_group( + &pipeline.texture_bind_group_layout, + &render_device, + &gpu_images, + &fallback_image, + ) + .ok() + .unwrap(); commands.insert_resource(GameOfLifeImageBindGroup(bind_group)); } @@ -127,22 +134,8 @@ pub struct GameOfLifePipeline { impl FromWorld for GameOfLifePipeline { fn from_world(world: &mut World) -> Self { - let texture_bind_group_layout = - world - .resource::() - .create_bind_group_layout(&BindGroupLayoutDescriptor { - label: None, - entries: &[BindGroupLayoutEntry { - binding: 0, - visibility: ShaderStages::COMPUTE, - ty: BindingType::StorageTexture { - access: StorageTextureAccess::ReadWrite, - format: TextureFormat::Rgba8Unorm, - view_dimension: TextureViewDimension::D2, - }, - count: None, - }], - }); + let render_device = world.resource::(); + let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device); let shader = world .resource::() .load("shaders/game_of_life.wgsl"); @@ -229,7 +222,7 @@ impl render_graph::Node for GameOfLifeNode { .command_encoder() .begin_compute_pass(&ComputePassDescriptor::default()); - pass.set_bind_group(0, texture_bind_group, &[]); + pass.set_bind_group(0, &texture_bind_group.bind_group, &[]); // select the pipeline based on the current state match self.state {