diff --git a/crates/bevy_render/macros/src/lib.rs b/crates/bevy_render/macros/src/lib.rs index c58a8fd146..35990f465f 100644 --- a/crates/bevy_render/macros/src/lib.rs +++ b/crates/bevy_render/macros/src/lib.rs @@ -4,6 +4,7 @@ mod as_bind_group; mod extract_component; mod extract_resource; +mod specialize; use bevy_macro_utils::{derive_label, BevyManifest}; use proc_macro::TokenStream; @@ -14,6 +15,10 @@ pub(crate) fn bevy_render_path() -> syn::Path { BevyManifest::shared().get_path("bevy_render") } +pub(crate) fn bevy_ecs_path() -> syn::Path { + BevyManifest::shared().get_path("bevy_ecs") +} + #[proc_macro_derive(ExtractResource)] pub fn derive_extract_resource(input: TokenStream) -> TokenStream { extract_resource::derive_extract_resource(input) @@ -102,6 +107,20 @@ pub fn derive_render_sub_graph(input: TokenStream) -> TokenStream { derive_label(input, "RenderSubGraph", &trait_path) } +/// Derive macro generating an impl of the trait `Specialize` +/// +/// This only works for structs whose members all implement `Specialize` +#[proc_macro_derive(Specialize, attributes(specialize, key, base_descriptor))] +pub fn derive_specialize(input: TokenStream) -> TokenStream { + specialize::impl_specialize(input) +} + +/// Derive macro generating the most common impl of the trait `SpecializerKey` +#[proc_macro_derive(SpecializerKey)] +pub fn derive_specializer_key(input: TokenStream) -> TokenStream { + specialize::impl_specializer_key(input) +} + #[proc_macro_derive(ShaderLabel)] pub fn derive_shader_label(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/crates/bevy_render/macros/src/specialize.rs b/crates/bevy_render/macros/src/specialize.rs new file mode 100644 index 0000000000..092de6e8d7 --- /dev/null +++ b/crates/bevy_render/macros/src/specialize.rs @@ -0,0 +1,483 @@ +use bevy_macro_utils::fq_std::{FQDefault, FQResult}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::{format_ident, quote}; +use syn::{ + parse, + parse::{Parse, ParseStream}, + parse_macro_input, parse_quote, + spanned::Spanned, + Data, DataStruct, DeriveInput, Expr, Fields, Ident, Index, Member, Meta, MetaList, Pat, Path, + Token, Type, WherePredicate, +}; + +const SPECIALIZE_ATTR_IDENT: &str = "specialize"; +const SPECIALIZE_ALL_IDENT: &str = "all"; + +const KEY_ATTR_IDENT: &str = "key"; +const KEY_DEFAULT_IDENT: &str = "default"; + +const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor"; + +enum SpecializeImplTargets { + All, + Specific(Vec), +} + +impl Parse for SpecializeImplTargets { + fn parse(input: ParseStream) -> syn::Result { + let paths = input.parse_terminated(Path::parse, Token![,])?; + if paths + .first() + .is_some_and(|p| p.is_ident(SPECIALIZE_ALL_IDENT)) + { + Ok(SpecializeImplTargets::All) + } else { + Ok(SpecializeImplTargets::Specific(paths.into_iter().collect())) + } + } +} + +#[derive(Clone)] +enum Key { + Whole, + Default, + Index(Index), + Custom(Expr), +} + +impl Key { + fn expr(&self) -> Expr { + match self { + Key::Whole => parse_quote!(key), + Key::Default => parse_quote!(#FQDefault::default()), + Key::Index(index) => { + let member = Member::Unnamed(index.clone()); + parse_quote!(key.#member) + } + Key::Custom(expr) => expr.clone(), + } + } +} + +const KEY_ERROR_MSG: &str = "Invalid key override. Must be either `default` or a valid Rust expression of the correct key type"; + +impl Parse for Key { + fn parse(input: ParseStream) -> syn::Result { + if let Ok(ident) = input.parse::() { + if ident == KEY_DEFAULT_IDENT { + Ok(Key::Default) + } else { + Err(syn::Error::new_spanned(ident, KEY_ERROR_MSG)) + } + } else { + input.parse::().map(Key::Custom).map_err(|mut err| { + err.extend(syn::Error::new(err.span(), KEY_ERROR_MSG)); + err + }) + } + } +} + +#[derive(Clone)] +struct FieldInfo { + ty: Type, + member: Member, + key: Key, + use_base_descriptor: bool, +} + +impl FieldInfo { + fn key_ty(&self, specialize_path: &Path, target_path: &Path) -> Option { + let ty = &self.ty; + matches!(self.key, Key::Whole | Key::Index(_)) + .then_some(parse_quote!(<#ty as #specialize_path::Specialize<#target_path>>::Key)) + } + + fn key_ident(&self, ident: Ident) -> Option { + matches!(self.key, Key::Whole | Key::Index(_)).then_some(ident) + } + + fn specialize_expr(&self, specialize_path: &Path, target_path: &Path) -> Expr { + let FieldInfo { + ty, member, key, .. + } = &self; + let key_expr = key.expr(); + parse_quote!(<#ty as #specialize_path::Specialize<#target_path>>::specialize(&self.#member, #key_expr, descriptor)) + } + + fn specialize_predicate(&self, specialize_path: &Path, target_path: &Path) -> WherePredicate { + let ty = &self.ty; + if matches!(&self.key, Key::Default) { + parse_quote!(#ty: #specialize_path::Specialize<#target_path, Key: #FQDefault>) + } else { + parse_quote!(#ty: #specialize_path::Specialize<#target_path>) + } + } + + fn get_base_descriptor_predicate( + &self, + specialize_path: &Path, + target_path: &Path, + ) -> WherePredicate { + let ty = &self.ty; + parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>) + } +} + +fn get_field_info(fields: &Fields, targets: &SpecializeImplTargets) -> syn::Result> { + let mut field_info: Vec = Vec::new(); + let mut used_count = 0; + let mut single_index = 0; + for (index, field) in fields.iter().enumerate() { + let field_ty = field.ty.clone(); + let field_member = field.ident.clone().map_or( + Member::Unnamed(Index { + index: index as u32, + span: field.span(), + }), + Member::Named, + ); + let key_index = Index { + index: used_count, + span: field.span(), + }; + + let mut use_key_field = true; + let mut key = Key::Index(key_index); + let mut use_base_descriptor = false; + for attr in &field.attrs { + match &attr.meta { + Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => { + use_base_descriptor = true; + } + Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => { + let owned_tokens = tokens.clone().into(); + let Ok(parsed_key) = parse::(owned_tokens) else { + return Err(syn::Error::new( + attr.span(), + "Invalid key override attribute", + )); + }; + key = parsed_key; + if matches!( + (&key, &targets), + (Key::Custom(_), SpecializeImplTargets::All) + ) { + return Err(syn::Error::new( + attr.span(), + "#[key(default)] is the only key override type allowed with #[specialize(all)]", + )); + } + use_key_field = false; + } + _ => {} + } + } + + if use_key_field { + used_count += 1; + single_index = index; + } + + field_info.push(FieldInfo { + ty: field_ty, + member: field_member, + key, + use_base_descriptor, + }); + } + + if used_count == 1 { + field_info[single_index].key = Key::Whole; + } + + Ok(field_info) +} + +fn get_struct_fields<'a>(ast: &'a DeriveInput, derive_name: &str) -> syn::Result<&'a Fields> { + match &ast.data { + Data::Struct(DataStruct { fields, .. }) => Ok(fields), + Data::Enum(data_enum) => Err(syn::Error::new( + data_enum.enum_token.span(), + format!("#[derive({derive_name})] only supports structs."), + )), + Data::Union(data_union) => Err(syn::Error::new( + data_union.union_token.span(), + format!("#[derive({derive_name})] only supports structs."), + )), + } +} + +fn get_specialize_targets( + ast: &DeriveInput, + derive_name: &str, +) -> syn::Result { + let specialize_attr = ast.attrs.iter().find_map(|attr| { + if attr.path().is_ident(SPECIALIZE_ATTR_IDENT) { + if let Meta::List(meta_list) = &attr.meta { + return Some(meta_list); + } + } + None + }); + let Some(specialize_meta_list) = specialize_attr else { + return Err(syn::Error::new( + Span::call_site(), + format!("#[derive({derive_name})] must be accompanied by #[specialize(..targets)].\n Example usages: #[specialize(RenderPipeline)], #[specialize(all)]") + )); + }; + parse::(specialize_meta_list.tokens.clone().into()) +} + +macro_rules! guard { + ($expr: expr) => { + match $expr { + Ok(__val) => __val, + Err(err) => return err.to_compile_error().into(), + } + }; +} + +pub fn impl_specialize(input: TokenStream) -> TokenStream { + let bevy_render_path: Path = crate::bevy_render_path(); + let specialize_path = { + let mut path = bevy_render_path.clone(); + path.segments.push(format_ident!("render_resource").into()); + path + }; + + let ecs_path = crate::bevy_ecs_path(); + + let ast = parse_macro_input!(input as DeriveInput); + let targets = guard!(get_specialize_targets(&ast, "Specialize")); + let fields = guard!(get_struct_fields(&ast, "Specialize")); + let field_info = guard!(get_field_info(fields, &targets)); + + let key_idents: Vec> = field_info + .iter() + .enumerate() + .map(|(i, field_info)| field_info.key_ident(format_ident!("key{i}"))) + .collect(); + let key_tuple_idents: Vec = key_idents.iter().flatten().cloned().collect(); + let ignore_pat: Pat = parse_quote!(_); + let key_patterns: Vec = key_idents + .iter() + .map(|key_ident| match key_ident { + Some(key_ident) => parse_quote!(#key_ident), + None => ignore_pat.clone(), + }) + .collect(); + + let base_descriptor_fields = field_info + .iter() + .filter(|field| field.use_base_descriptor) + .collect::>(); + + if base_descriptor_fields.len() > 1 { + return syn::Error::new( + Span::call_site(), + "Too many #[base_descriptor] attributes found. It must be present on exactly one field", + ) + .into_compile_error() + .into(); + } + + let base_descriptor_field = base_descriptor_fields.first().copied(); + + match targets { + SpecializeImplTargets::All => { + let specialize_impl = impl_specialize_all( + &specialize_path, + &ecs_path, + &ast, + &field_info, + &key_patterns, + &key_tuple_idents, + ); + let get_base_descriptor_impl = base_descriptor_field + .map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info)) + .unwrap_or_default(); + [specialize_impl, get_base_descriptor_impl] + .into_iter() + .collect() + } + SpecializeImplTargets::Specific(targets) => { + let specialize_impls = targets.iter().map(|target| { + impl_specialize_specific( + &specialize_path, + &ecs_path, + &ast, + &field_info, + target, + &key_patterns, + &key_tuple_idents, + ) + }); + let get_base_descriptor_impls = targets.iter().filter_map(|target| { + base_descriptor_field.map(|field_info| { + impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target) + }) + }); + specialize_impls.chain(get_base_descriptor_impls).collect() + } + } +} + +fn impl_specialize_all( + specialize_path: &Path, + ecs_path: &Path, + ast: &DeriveInput, + field_info: &[FieldInfo], + key_patterns: &[Pat], + key_tuple_idents: &[Ident], +) -> TokenStream { + let target_path = Path::from(format_ident!("T")); + let key_elems: Vec = field_info + .iter() + .filter_map(|field_info| field_info.key_ty(specialize_path, &target_path)) + .collect(); + let specialize_exprs: Vec = field_info + .iter() + .map(|field_info| field_info.specialize_expr(specialize_path, &target_path)) + .collect(); + + let struct_name = &ast.ident; + let mut generics = ast.generics.clone(); + generics.params.insert( + 0, + parse_quote!(#target_path: #specialize_path::Specializable), + ); + + if !field_info.is_empty() { + let where_clause = generics.make_where_clause(); + for field in field_info { + where_clause + .predicates + .push(field.specialize_predicate(specialize_path, &target_path)); + } + } + + let (_, type_generics, _) = ast.generics.split_for_impl(); + let (impl_generics, _, where_clause) = &generics.split_for_impl(); + + TokenStream::from(quote! { + impl #impl_generics #specialize_path::Specialize<#target_path> for #struct_name #type_generics #where_clause { + type Key = (#(#key_elems),*); + + fn specialize( + &self, + key: Self::Key, + descriptor: &mut <#target_path as #specialize_path::Specializable>::Descriptor + ) -> #FQResult<#specialize_path::Canonical, #ecs_path::error::BevyError> { + #(let #key_patterns = #specialize_exprs?;)* + #FQResult::Ok((#(#key_tuple_idents),*)) + } + } + }) +} + +fn impl_specialize_specific( + specialize_path: &Path, + ecs_path: &Path, + ast: &DeriveInput, + field_info: &[FieldInfo], + target_path: &Path, + key_patterns: &[Pat], + key_tuple_idents: &[Ident], +) -> TokenStream { + let key_elems: Vec = field_info + .iter() + .filter_map(|field_info| field_info.key_ty(specialize_path, target_path)) + .collect(); + let specialize_exprs: Vec = field_info + .iter() + .map(|field_info| field_info.specialize_expr(specialize_path, target_path)) + .collect(); + + let struct_name = &ast.ident; + let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); + + TokenStream::from(quote! { + impl #impl_generics #specialize_path::Specialize<#target_path> for #struct_name #type_generics #where_clause { + type Key = (#(#key_elems),*); + + fn specialize( + &self, + key: Self::Key, + descriptor: &mut <#target_path as #specialize_path::Specializable>::Descriptor + ) -> #FQResult<#specialize_path::Canonical, #ecs_path::error::BevyError> { + #(let #key_patterns = #specialize_exprs?;)* + #FQResult::Ok((#(#key_tuple_idents),*)) + } + } + }) +} + +fn impl_get_base_descriptor_specific( + specialize_path: &Path, + ast: &DeriveInput, + base_descriptor_field_info: &FieldInfo, + target_path: &Path, +) -> TokenStream { + let struct_name = &ast.ident; + let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); + let field_ty = &base_descriptor_field_info.ty; + let field_member = &base_descriptor_field_info.member; + TokenStream::from(quote!( + impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause { + fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor { + <#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::base_descriptor(&self.#field_member) + } + } + )) +} + +fn impl_get_base_descriptor_all( + specialize_path: &Path, + ast: &DeriveInput, + base_descriptor_field_info: &FieldInfo, +) -> TokenStream { + let target_path = Path::from(format_ident!("T")); + let struct_name = &ast.ident; + let mut generics = ast.generics.clone(); + generics.params.insert( + 0, + parse_quote!(#target_path: #specialize_path::Specializable), + ); + + let where_clause = generics.make_where_clause(); + where_clause.predicates.push( + base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path), + ); + + let (_, type_generics, _) = ast.generics.split_for_impl(); + let (impl_generics, _, where_clause) = &generics.split_for_impl(); + let field_ty = &base_descriptor_field_info.ty; + let field_member = &base_descriptor_field_info.member; + TokenStream::from(quote! { + impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause { + fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor { + <#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::base_descriptor(&self.#field_member) + } + } + }) +} + +pub fn impl_specializer_key(input: TokenStream) -> TokenStream { + let bevy_render_path: Path = crate::bevy_render_path(); + let specialize_path = { + let mut path = bevy_render_path.clone(); + path.segments.push(format_ident!("render_resource").into()); + path + }; + + let ast = parse_macro_input!(input as DeriveInput); + let ident = ast.ident; + TokenStream::from(quote!( + impl #specialize_path::SpecializerKey for #ident { + const IS_CANONICAL: bool = true; + type Canonical = Self; + } + )) +} diff --git a/crates/bevy_render/src/render_resource/mod.rs b/crates/bevy_render/src/render_resource/mod.rs index 09be66e840..9233d9e4c4 100644 --- a/crates/bevy_render/src/render_resource/mod.rs +++ b/crates/bevy_render/src/render_resource/mod.rs @@ -12,6 +12,7 @@ mod pipeline_cache; mod pipeline_specializer; pub mod resource_macros; mod shader; +mod specialize; mod storage_buffer; mod texture; mod uniform_buffer; @@ -28,6 +29,7 @@ pub use pipeline::*; pub use pipeline_cache::*; pub use pipeline_specializer::*; pub use shader::*; +pub use specialize::*; pub use storage_buffer::*; pub use texture::*; pub use uniform_buffer::*; diff --git a/crates/bevy_render/src/render_resource/pipeline.rs b/crates/bevy_render/src/render_resource/pipeline.rs index b76174cac3..35020a43c3 100644 --- a/crates/bevy_render/src/render_resource/pipeline.rs +++ b/crates/bevy_render/src/render_resource/pipeline.rs @@ -138,7 +138,7 @@ pub struct FragmentState { } /// Describes a compute pipeline. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct ComputePipelineDescriptor { pub label: Option>, pub layout: Vec, diff --git a/crates/bevy_render/src/render_resource/specialize.rs b/crates/bevy_render/src/render_resource/specialize.rs new file mode 100644 index 0000000000..c2269b1a78 --- /dev/null +++ b/crates/bevy_render/src/render_resource/specialize.rs @@ -0,0 +1,400 @@ +use super::{ + CachedComputePipelineId, CachedRenderPipelineId, ComputePipeline, ComputePipelineDescriptor, + PipelineCache, RenderPipeline, RenderPipelineDescriptor, +}; +use bevy_ecs::{ + error::BevyError, + resource::Resource, + world::{FromWorld, World}, +}; +use bevy_platform::{ + collections::{ + hash_map::{Entry, VacantEntry}, + HashMap, + }, + hash::FixedHasher, +}; +use core::{hash::Hash, marker::PhantomData}; +use tracing::error; +use variadics_please::all_tuples; + +pub use bevy_render_macros::{Specialize, SpecializerKey}; + +/// Defines a type that is able to be "specialized" and cached by creating and transforming +/// its descriptor type. This is implemented for [`RenderPipeline`] and [`ComputePipeline`], and +/// likely will not have much utility for other types. +pub trait Specializable { + type Descriptor: PartialEq + Clone + Send + Sync; + type CachedId: Clone + Send + Sync; + fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId; + fn get_descriptor(pipeline_cache: &PipelineCache, id: Self::CachedId) -> &Self::Descriptor; +} + +impl Specializable for RenderPipeline { + type Descriptor = RenderPipelineDescriptor; + type CachedId = CachedRenderPipelineId; + + fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId { + pipeline_cache.queue_render_pipeline(descriptor) + } + + fn get_descriptor( + pipeline_cache: &PipelineCache, + id: CachedRenderPipelineId, + ) -> &Self::Descriptor { + pipeline_cache.get_render_pipeline_descriptor(id) + } +} + +impl Specializable for ComputePipeline { + type Descriptor = ComputePipelineDescriptor; + + type CachedId = CachedComputePipelineId; + + fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId { + pipeline_cache.queue_compute_pipeline(descriptor) + } + + fn get_descriptor( + pipeline_cache: &PipelineCache, + id: CachedComputePipelineId, + ) -> &Self::Descriptor { + pipeline_cache.get_compute_pipeline_descriptor(id) + } +} + +/// Defines a type that is able to transform descriptors for a specializable +/// type T, based on a hashable key type. +/// +/// This is mainly used when "specializing" render +/// pipelines, i.e. specifying shader defs and binding layout based on the key, +/// the result of which can then be cached and accessed quickly later. +/// +/// This trait can be derived with `#[derive(Specializer)]` for structs whose +/// fields all implement [`Specializer`]. The key type will be tuple of the keys +/// of each field, and their specialization logic will be applied in field +/// order. Since derive macros can't have generic parameters, the derive macro +/// requires an additional `#[specialize(..targets)]` attribute to specify a +/// list of types to target for the implementation. `#[specialize(all)]` is +/// also allowed, and will generate a fully generic implementation at the cost +/// of slightly worse error messages. +/// +/// Additionally, each field can optionally take a `#[key]` attribute to +/// specify a "key override". This will "hide" that field's key from being +/// exposed by the wrapper, and always use the value given by the attribute. +/// Values for this attribute may either be `default` which will use the key's +/// [`Default`] implementation, or a valid rust +/// expression of the key type. +/// +/// Example: +/// ```rs +/// # use super::RenderPipeline; +/// # use super::RenderPipelineDescriptor; +/// # use bevy_ecs::error::BevyError; +/// +/// struct A; +/// struct B; +/// #[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)] +/// struct BKey; +/// +/// impl Specializer for A { +/// type Key = (); +/// +/// fn specializer(&self, key: (), descriptor: &mut RenderPipelineDescriptor) -> Result<(), BevyError> { +/// # let _ = (key, descriptor); +/// //... +/// Ok(()) +/// } +/// } +/// +/// impl Specializer for B { +/// type Key = BKey; +/// +/// fn specialize(&self, _key: Bkey, _descriptor: &mut RenderPipelineDescriptor) -> Result { +/// # let _ = (key, descriptor); +/// //... +/// Ok(BKey) +/// } +/// } +/// +/// #[derive(Specializer)] +/// #[specialize(RenderPipeline)] +/// struct C { +/// #[key(default)] +/// a: A, +/// b: B, +/// } +/// +/// /* +/// The generated implementation: +/// impl Specializer for C { +/// type Key = BKey; +/// fn specialize( +/// &self, +/// key: Self::Key, +/// descriptor: &mut RenderPipelineDescriptor +/// ) -> Result, BevyError> { +/// let _ = self.a.specialize((), descriptor); +/// let key = self.b.specialize(key, descriptor); +/// Ok(key) +/// } +/// } +/// */ +/// ``` +pub trait Specializer: Send + Sync + 'static { + type Key: SpecializerKey; + fn specialize( + &self, + key: Self::Key, + descriptor: &mut T::Descriptor, + ) -> Result, BevyError>; +} + +/// Defines a type that is able to be used as a key for types that `impl Specialize` +/// +/// **Most types should implement this trait with `IS_CANONICAL = true` and `Canonical = Self`**. +/// This is the implementation generated by `#[derive(SpecializerKey)]` +/// +/// In this case, "canonical" means that each unique value of this type will produce +/// a unique specialized result, which isn't true in general. `MeshVertexBufferLayout` +/// is a good example of a type that's `Eq + Hash`, but that isn't canonical: vertex +/// attributes could be specified in any order, or there could be more attributes +/// provided than the specialized pipeline requires. Its `Canonical` key type would +/// be `VertexBufferLayout`, the final layout required by the pipeline. +/// +/// Processing keys into canonical keys this way allows the `SpecializedCache` to reuse +/// resources more eagerly where possible. +pub trait SpecializerKey: Clone + Hash + Eq { + /// Denotes whether this key is canonical or not. This should only be `true` + /// if and only if `Canonical = Self`. + const IS_CANONICAL: bool; + + /// The canonical key type to convert this into during specialization. + type Canonical: Hash + Eq; +} + +pub type Canonical = ::Canonical; + +impl Specializer for () { + type Key = (); + + fn specialize( + &self, + _key: Self::Key, + _descriptor: &mut T::Descriptor, + ) -> Result<(), BevyError> { + Ok(()) + } +} + +impl Specializer for PhantomData { + type Key = (); + + fn specialize( + &self, + _key: Self::Key, + _descriptor: &mut T::Descriptor, + ) -> Result<(), BevyError> { + Ok(()) + } +} + +macro_rules! impl_specialization_key_tuple { + ($($T:ident),*) => { + impl <$($T: SpecializerKey),*> SpecializerKey for ($($T,)*) { + const IS_CANONICAL: bool = true $(&& <$T as SpecializerKey>::IS_CANONICAL)*; + type Canonical = ($(Canonical<$T>,)*); + } + }; +} + +all_tuples!(impl_specialization_key_tuple, 0, 12, T); + +/// Defines a specializer that can also provide a "base descriptor". +/// +/// In order to be composable, [`Specializer`] implementers don't create full +/// descriptors, only transform them. However, [`SpecializedCache`]s need a +/// "base descriptor" at creation time in order to have something for the +/// [`Specializer`] implementation to work off of. This trait allows +/// [`SpecializedCache`] to impl [`FromWorld`] for [`Specializer`] +/// implementations that also satisfy [`FromWorld`] and [`GetBaseDescriptor`]. +/// +/// This trait can be also derived with `#[derive(Specializer)]`, by marking +/// a field with `#[base_descriptor]` to use its [`GetBaseDescriptor`] implementation. +/// +/// Example: +/// ```rs +/// struct A; +/// struct B; +/// +/// impl Specializer for A { +/// type Key = (); +/// +/// fn specialize(&self, _key: (), _descriptor: &mut RenderPipelineDescriptor) { +/// //... +/// } +/// } +/// +/// impl Specializer for B { +/// type Key = u32; +/// +/// fn specialize(&self, _key: u32, _descriptor: &mut RenderPipelineDescriptor) { +/// //... +/// } +/// } +/// +/// impl GetBaseDescriptor for B { +/// fn get_base_descriptor(&self) -> RenderPipelineDescriptor { +/// # todo!() +/// //... +/// } +/// } +/// +/// +/// #[derive(Specializer)] +/// #[specialize(RenderPipeline)] +/// struct C { +/// #[key(default)] +/// a: A, +/// #[base_descriptor] +/// b: B, +/// } +/// +/// /* +/// The generated implementation: +/// impl GetBaseDescriptor for C { +/// fn get_base_descriptor(&self) -> RenderPipelineDescriptor { +/// self.b.base_descriptor() +/// } +/// } +/// */ +/// ``` +pub trait GetBaseDescriptor: Specializer { + fn get_base_descriptor(&self) -> T::Descriptor; +} + +pub type SpecializerFn = + fn(>::Key, &mut ::Descriptor) -> Result<(), BevyError>; + +/// A cache for specializable resources. For a given key type the resulting +/// resource will only be created if it is missing, retrieving it from the +/// cache otherwise. +#[derive(Resource)] +pub struct SpecializedCache> { + specializer: S, + user_specializer: Option>, + base_descriptor: T::Descriptor, + primary_cache: HashMap, + secondary_cache: HashMap, T::CachedId>, +} + +impl> SpecializedCache { + /// Creates a new [`SpecializedCache`] from a [`Specializer`], + /// an optional "user specializer", and a base descriptor. The + /// user specializer is applied after the [`Specializer`], with + /// the same key. + #[inline] + pub fn new( + specializer: S, + user_specializer: Option>, + base_descriptor: T::Descriptor, + ) -> Self { + Self { + specializer, + user_specializer, + base_descriptor, + primary_cache: Default::default(), + secondary_cache: Default::default(), + } + } + + /// Specializes a resource given the [`Specializer`]'s key type. + #[inline] + pub fn specialize( + &mut self, + pipeline_cache: &PipelineCache, + key: S::Key, + ) -> Result { + let entry = self.primary_cache.entry(key.clone()); + match entry { + Entry::Occupied(entry) => Ok(entry.get().clone()), + Entry::Vacant(entry) => Self::specialize_slow( + &self.specializer, + self.user_specializer, + self.base_descriptor.clone(), + pipeline_cache, + key, + entry, + &mut self.secondary_cache, + ), + } + } + + #[cold] + fn specialize_slow( + specializer: &S, + user_specializer: Option>, + base_descriptor: T::Descriptor, + pipeline_cache: &PipelineCache, + key: S::Key, + primary_entry: VacantEntry, + secondary_cache: &mut HashMap, T::CachedId>, + ) -> Result { + let mut descriptor = base_descriptor.clone(); + let canonical_key = specializer.specialize(key.clone(), &mut descriptor)?; + + if let Some(user_specializer) = user_specializer { + (user_specializer)(key, &mut descriptor)?; + } + + // if the whole key is canonical, the secondary cache isn't needed. + if ::IS_CANONICAL { + return Ok(primary_entry + .insert(::queue(pipeline_cache, descriptor)) + .clone()); + } + + let id = match secondary_cache.entry(canonical_key) { + Entry::Occupied(entry) => { + if cfg!(debug_assertions) { + let stored_descriptor = + ::get_descriptor(pipeline_cache, entry.get().clone()); + if &descriptor != stored_descriptor { + error!( + "Invalid Specializer<{}> impl for {}: the cached descriptor \ + is not equal to the generated descriptor for the given key. \ + This means the Specializer implementation uses unused information \ + from the key to specialize the pipeline. This is not allowed \ + because it would invalidate the cache.", + core::any::type_name::(), + core::any::type_name::() + ); + } + } + entry.into_mut().clone() + } + Entry::Vacant(entry) => entry + .insert(::queue(pipeline_cache, descriptor)) + .clone(), + }; + + primary_entry.insert(id.clone()); + Ok(id) + } +} + +/// [`SpecializedCache`] implements [`FromWorld`] for [`Specializer`]s +/// that also satisfy [`FromWorld`] and [`GetBaseDescriptor`]. This will +/// create a [`SpecializedCache`] with no user specializer, and the base +/// descriptor take from the specializer's [`GetBaseDescriptor`] implementation. +impl FromWorld for SpecializedCache +where + T: Specializable, + S: FromWorld + Specializer + GetBaseDescriptor, +{ + fn from_world(world: &mut World) -> Self { + let specializer = S::from_world(world); + let base_descriptor = specializer.get_base_descriptor(); + Self::new(specializer, None, base_descriptor) + } +} diff --git a/examples/shader/custom_phase_item.rs b/examples/shader/custom_phase_item.rs index b363a7c27f..f06aba1403 100644 --- a/examples/shader/custom_phase_item.rs +++ b/examples/shader/custom_phase_item.rs @@ -24,11 +24,11 @@ use bevy::{ ViewBinnedRenderPhases, }, render_resource::{ - BufferUsages, ColorTargetState, ColorWrites, CompareFunction, DepthStencilState, - FragmentState, IndexFormat, MultisampleState, PipelineCache, PrimitiveState, - RawBufferVec, RenderPipelineDescriptor, SpecializedRenderPipeline, - SpecializedRenderPipelines, TextureFormat, VertexAttribute, VertexBufferLayout, - VertexFormat, VertexState, VertexStepMode, + BufferUsages, Canonical, ColorTargetState, ColorWrites, CompareFunction, + DepthStencilState, FragmentState, GetBaseDescriptor, IndexFormat, MultisampleState, + PipelineCache, PrimitiveState, RawBufferVec, RenderPipeline, RenderPipelineDescriptor, + SpecializedCache, Specializer, SpecializerKey, TextureFormat, VertexAttribute, + VertexBufferLayout, VertexFormat, VertexState, VertexStepMode, }, renderer::{RenderDevice, RenderQueue}, view::{self, ExtractedView, RenderVisibleEntities, VisibilityClass}, @@ -49,14 +49,6 @@ use bytemuck::{Pod, Zeroable}; #[component(on_add = view::add_visibility_class::)] struct CustomRenderedEntity; -/// Holds a reference to our shader. -/// -/// This is loaded at app creation time. -#[derive(Resource)] -struct CustomPhasePipeline { - shader: Handle, -} - /// A [`RenderCommand`] that binds the vertex and index buffers and issues the /// draw command for our custom phase item. struct DrawCustomPhaseItem; @@ -175,8 +167,7 @@ fn main() { // We make sure to add these to the render app, not the main app. app.get_sub_app_mut(RenderApp) .unwrap() - .init_resource::() - .init_resource::>() + .init_resource::>() .add_render_command::() .add_systems( Render, @@ -221,10 +212,9 @@ fn prepare_custom_phase_item_buffers(mut commands: Commands) { /// the opaque render phases of each view. fn queue_custom_phase_item( pipeline_cache: Res, - custom_phase_pipeline: Res, mut opaque_render_phases: ResMut>, opaque_draw_functions: Res>, - mut specialized_render_pipelines: ResMut>, + mut specializer: ResMut>, views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>, mut next_tick: Local, ) { @@ -247,11 +237,10 @@ fn queue_custom_phase_item( // some per-view settings, such as whether the view is HDR, but for // simplicity's sake we simply hard-code the view's characteristics, // with the exception of number of MSAA samples. - let pipeline_id = specialized_render_pipelines.specialize( - &pipeline_cache, - &custom_phase_pipeline, - *msaa, - ); + let Ok(pipeline_id) = specializer.specialize(&pipeline_cache, CustomPhaseKey(*msaa)) + else { + continue; + }; // Bump the change tick in order to force Bevy to rebuild the bin. let this_tick = next_tick.get() + 1; @@ -286,10 +275,40 @@ fn queue_custom_phase_item( } } -impl SpecializedRenderPipeline for CustomPhasePipeline { - type Key = Msaa; +/// Holds a reference to our shader. +/// +/// This is loaded at app creation time. +struct CustomPhaseSpecializer { + shader: Handle, +} - fn specialize(&self, msaa: Self::Key) -> RenderPipelineDescriptor { +impl FromWorld for CustomPhaseSpecializer { + fn from_world(world: &mut World) -> Self { + let asset_server = world.resource::(); + Self { + shader: asset_server.load("shaders/custom_phase_item.wgsl"), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)] +struct CustomPhaseKey(Msaa); + +impl Specializer for CustomPhaseSpecializer { + type Key = CustomPhaseKey; + + fn specialize( + &self, + key: Self::Key, + descriptor: &mut RenderPipelineDescriptor, + ) -> Result, BevyError> { + descriptor.multisample.count = key.0.samples(); + Ok(key) + } +} + +impl GetBaseDescriptor for CustomPhaseSpecializer { + fn get_base_descriptor(&self) -> RenderPipelineDescriptor { RenderPipelineDescriptor { label: Some("custom render pipeline".into()), layout: vec![], @@ -340,7 +359,7 @@ impl SpecializedRenderPipeline for CustomPhasePipeline { bias: default(), }), multisample: MultisampleState { - count: msaa.samples(), + count: 0, mask: !0, alpha_to_coverage_enabled: false, }, @@ -375,14 +394,3 @@ impl FromWorld for CustomPhaseItemBuffers { } } } - -impl FromWorld for CustomPhasePipeline { - fn from_world(world: &mut World) -> Self { - // Load and compile the shader in the background. - let asset_server = world.resource::(); - - CustomPhasePipeline { - shader: asset_server.load("shaders/custom_phase_item.wgsl"), - } - } -} diff --git a/release-content/migration-guides/composable_specialization.md b/release-content/migration-guides/composable_specialization.md new file mode 100644 index 0000000000..f87beef8cb --- /dev/null +++ b/release-content/migration-guides/composable_specialization.md @@ -0,0 +1,153 @@ +--- +title: Composable Specialization +pull_requests: [17373] +--- + +The existing pipeline specialization APIs (`SpecializedRenderPipeline` etc.) have +been replaced with a single `Specializer` trait and `SpecializedCache` collection: + +```rs +pub trait Specializer: Send + Sync + 'static { + type Key: SpecializerKey; + fn specialize( + &self, + key: Self::Key, + descriptor: &mut T::Descriptor, + ) -> Result, BevyError>; +} + +pub struct SpecializedCache>{ ... }; +``` + +The main difference is the change from *producing* a pipeline descriptor to +*mutating* one based on a key. The "base descriptor" that the `SpecializedCache` +passes to the `Specializer` can either be specified manually with `Specializer::new` +or by implementing `GetBaseDescriptor`. There's also a new trait for specialization +keys, `SpecializeKey`, that can be derived with the included macro in most cases. + +Composing multiple different specializers together with the `derive(Specializer)` +macro can be a lot more powerful (see the `Specialize` docs), but migrating +individual specializers is fairly simple. All static parts of the pipeline +should be specified in the base descriptor, while the `Specializer` impl +should mutate the key as little as necessary to match the key. + +```rs +pub struct MySpecializer { + layout: BindGroupLayout, + layout_msaa: BindGroupLayout, + vertex: Handle, + fragment: Handle, +} + +// before +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +// after +#[derive(Clone, Copy, PartialEq, Eq, Hash, SpecializerKey)] + +pub struct MyKey { + blend_state: BlendState, + msaa: Msaa, +} + +impl FromWorld for MySpecializer { + fn from_world(&mut World) -> Self { + ... + } +} + +// before +impl SpecializedRenderPipeline for MySpecializer { + type Key = MyKey; + + fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor { + RenderPipelineDescriptor { + label: Some("my_pipeline".into()), + layout: vec![ + if key.msaa.samples() > 0 { + self.layout_msaa.clone() + } else { + self.layout.clone() + } + ], + push_constant_ranges: vec![], + vertex: VertexState { + shader: self.vertex.clone(), + shader_defs: vec![], + entry_point: "vertex".into(), + buffers: vec![], + }, + primitive: Default::default(), + depth_stencil: None, + multisample: MultisampleState { + count: key.msaa.samples(), + ..Default::default() + }, + fragment: Some(FragmentState { + shader: self.fragment.clone(), + shader_defs: vec![], + entry_point: "fragment".into(), + targets: vec![Some(ColorTargetState { + format: TextureFormat::Rgba8Unorm, + blend: Some(key.blend_state), + write_mask: ColorWrites::all(), + })], + }), + zero_initialize_workgroup_memory: false, + }, + } +} + +app.init_resource::>(); + +// after +impl Specializer for MySpecializer { + type Key = MyKey; + + fn specialize( + &self, + key: Self::Key, + descriptor: &mut RenderPipeline, + ) -> Result, BevyError> { + descriptor.multisample.count = key.msaa.samples(); + descriptor.layout[0] = if key.msaa.samples() > 0 { + self.layout_msaa.clone() + } else { + self.layout.clone() + }; + descriptor.fragment.targets[0].as_mut().unwrap().blend_mode = key.blend_state; + Ok(key) + } +} + +impl GetBaseDescriptor for MySpecializer { + fn get_base_descriptor(&self) -> RenderPipelineDescriptor { + RenderPipelineDescriptor { + label: Some("my_pipeline".into()), + layout: vec![self.layout.clone()], + push_constant_ranges: vec![], + vertex: VertexState { + shader: self.vertex.clone(), + shader_defs: vec![], + entry_point: "vertex".into(), + buffers: vec![], + }, + primitive: Default::default(), + depth_stencil: None, + multisample: MultiSampleState::default(), + fragment: Some(FragmentState { + shader: self.fragment.clone(), + shader_defs: vec![], + entry_point: "fragment".into(), + targets: vec![Some(ColorTargetState { + format: TextureFormat::Rgba8Unorm, + blend: None, + write_mask: ColorWrites::all(), + })], + }), + zero_initialize_workgroup_memory: false, + }, + } +} + +app.init_resource::>(); +```