Composable Pipeline Specialization (#17373)

Currently, our specialization API works through a series of wrapper
structs and traits, which make things confusing to follow and difficult
to generalize.

This pr takes a different approach, where "specializers" (types that
implement `Specialize`) are composable, but "flat" rather than composed
of a series of wrappers. The key is that specializers don't *produce*
pipeline descriptors, but instead *modify* existing ones:

```rs
pub trait Specialize<T: Specializable> {
    type Key: SpecializeKey;
    
    fn specialize(
        &self, 
        key: Self::Key, 
        descriptor: &mut T::Descriptor
    ) -> Result<Canonical<Self::Key>, BevyError>;
}
```

This lets us use some derive magic to stick multiple specializers
together:

```rs
pub struct A;
pub struct B;

impl Specialize<RenderPipeline> for A { ... }
impl Specialize<RenderPipeline> for B { ... }

#[derive(Specialize)]
#[specialize(RenderPipeline)]
struct C {
    // specialization is applied in struct field order
    applied_first: A,
    applied_second: B,
}

type C::Key = (A::Key, B::Key);

```

This approach is much easier to understand, IMO, and also lets us
separate concerns better. Specializers can be placed in fully separate
crates/modules, and key computation can be shared as well.

The only real breaking change here is that since specializers only
modify descriptors, we need a "base" descriptor to work off of. This can
either be manually supplied when constructing a `Specializer` (the new
collection replacing `Specialized[Render/Compute]Pipelines`), or
supplied by implementing `HasBaseDescriptor` on a specializer. See
`examples/shader/custom_phase_item.rs` for an example implementation.

## Testing

- Did some simple manual testing of the derive macro, it seems robust.

---

## Showcase

```rs
#[derive(Specialize, HasBaseDescriptor)]
#[specialize(RenderPipeline)]
pub struct SpecializeMeshMaterial<M: Material> {
    // set mesh bind group layout and shader defs
    mesh: SpecializeMesh,
    // set view bind group layout and shader defs
    view: SpecializeView,
    // since type SpecializeMaterial::Key = (), 
    // we can hide it from the wrapper's external API
    #[key(default)]
    // defer to the GetBaseDescriptor impl of SpecializeMaterial, 
    // since it carries the vertex and fragment handles
    #[base_descriptor]
    // set material bind group layout, etc
    material: SpecializeMaterial<M>,
}

// implementation generated by the derive macro
impl <M: Material> Specialize<RenderPipeline> for SpecializeMeshMaterial<M> {
    type Key = (MeshKey, ViewKey);

    fn specialize(
        &self, 
        key: Self::Key, 
        descriptor: &mut RenderPipelineDescriptor
    ) -> Result<Canonical<Self::Key>, BevyError>  {
        let mesh_key = self.mesh.specialize(key.0, descriptor)?;
        let view_key = self.view.specialize(key.1, descriptor)?;
        let _ = self.material.specialize((), descriptor)?;
        Ok((mesh_key, view_key))
    }
}

impl <M: Material> HasBaseDescriptor<RenderPipeline> for SpecializeMeshMaterial<M> {
    fn base_descriptor(&self) -> RenderPipelineDescriptor {
        self.material.base_descriptor()
    }
}
```

---------

Co-authored-by: Tim Overbeek <158390905+Bleachfuel@users.noreply.github.com>
This commit is contained in:
Emerson Coskey 2025-06-30 18:32:44 -07:00 committed by GitHub
parent f98727c1b1
commit bdd3ef71b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1103 additions and 38 deletions

View File

@ -4,6 +4,7 @@
mod as_bind_group;
mod extract_component;
mod extract_resource;
mod specialize;
use bevy_macro_utils::{derive_label, BevyManifest};
use proc_macro::TokenStream;
@ -14,6 +15,10 @@ pub(crate) fn bevy_render_path() -> syn::Path {
BevyManifest::shared().get_path("bevy_render")
}
/// Resolves the `bevy_ecs` crate path as seen from the user's crate,
/// so generated code works whether bevy is used directly or re-exported.
pub(crate) fn bevy_ecs_path() -> syn::Path {
    BevyManifest::shared().get_path("bevy_ecs")
}
#[proc_macro_derive(ExtractResource)]
pub fn derive_extract_resource(input: TokenStream) -> TokenStream {
extract_resource::derive_extract_resource(input)
@ -102,6 +107,20 @@ pub fn derive_render_sub_graph(input: TokenStream) -> TokenStream {
derive_label(input, "RenderSubGraph", &trait_path)
}
/// Derive macro generating an impl of the specializer trait
/// (NOTE(review): the runtime trait in `render_resource` is named
/// `Specializer`, while this derive is registered as `Specialize` —
/// confirm the intended naming)
///
/// This only works for structs whose members all implement the specializer trait
#[proc_macro_derive(Specialize, attributes(specialize, key, base_descriptor))]
pub fn derive_specialize(input: TokenStream) -> TokenStream {
    specialize::impl_specialize(input)
}
/// Derive macro generating the most common impl of the trait `SpecializerKey`:
/// the key is its own canonical form (`IS_CANONICAL = true`, `Canonical = Self`)
#[proc_macro_derive(SpecializerKey)]
pub fn derive_specializer_key(input: TokenStream) -> TokenStream {
    specialize::impl_specializer_key(input)
}
#[proc_macro_derive(ShaderLabel)]
pub fn derive_shader_label(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

View File

@ -0,0 +1,483 @@
use bevy_macro_utils::fq_std::{FQDefault, FQResult};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{
parse,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
Data, DataStruct, DeriveInput, Expr, Fields, Ident, Index, Member, Meta, MetaList, Pat, Path,
Token, Type, WherePredicate,
};
const SPECIALIZE_ATTR_IDENT: &str = "specialize";
const SPECIALIZE_ALL_IDENT: &str = "all";
const KEY_ATTR_IDENT: &str = "key";
const KEY_DEFAULT_IDENT: &str = "default";
const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor";
// The target list parsed from `#[specialize(..)]`: either `all` (one fully
// generic impl) or a specific list of specializable target types.
enum SpecializeImplTargets {
    All,
    Specific(Vec<Path>),
}
impl Parse for SpecializeImplTargets {
    // Parses the contents of `#[specialize(...)]`: the literal ident `all`,
    // or a comma-separated list of target type paths.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let paths = input.parse_terminated(Path::parse, Token![,])?;
        let is_all = matches!(paths.first(), Some(path) if path.is_ident(SPECIALIZE_ALL_IDENT));
        if is_all {
            return Ok(Self::All);
        }
        Ok(Self::Specific(paths.into_iter().collect()))
    }
}
// How a field's key is produced inside the generated `specialize` body.
#[derive(Clone)]
enum Key {
    // The field's key is the wrapper's entire key (single exposed field).
    Whole,
    // `#[key(default)]`: use the key type's `Default`; hidden from the wrapper's key.
    Default,
    // The field's key is element `n` of the wrapper's key tuple.
    Index(Index),
    // `#[key(<expr>)]`: a user-supplied expression; hidden from the wrapper's key.
    Custom(Expr),
}
impl Key {
    /// The expression the generated code passes as this field's key.
    /// `key` names the `specialize` method's key parameter.
    fn expr(&self) -> Expr {
        match self {
            Key::Whole => parse_quote!(key),
            Key::Default => parse_quote!(#FQDefault::default()),
            Key::Index(index) => {
                // Tuple projection, e.g. `key.0`, `key.1`, ...
                let member = Member::Unnamed(index.clone());
                parse_quote!(key.#member)
            }
            Key::Custom(expr) => expr.clone(),
        }
    }
}
const KEY_ERROR_MSG: &str = "Invalid key override. Must be either `default` or a valid Rust expression of the correct key type";
// Parses the contents of `#[key(...)]`: the ident `default`, or an
// arbitrary expression of the field's key type.
impl Parse for Key {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        if let Ok(ident) = input.parse::<Ident>() {
            if ident == KEY_DEFAULT_IDENT {
                Ok(Key::Default)
            } else {
                Err(syn::Error::new_spanned(ident, KEY_ERROR_MSG))
            }
        } else {
            // Not a bare ident: fall back to a full expression, attaching the
            // generic key-override message to any underlying parse error.
            input.parse::<Expr>().map(Key::Custom).map_err(|mut err| {
                err.extend(syn::Error::new(err.span(), KEY_ERROR_MSG));
                err
            })
        }
    }
}
// Everything the codegen needs to know about one struct field.
#[derive(Clone)]
struct FieldInfo {
    // The field's type; must implement the specializer trait for each target.
    ty: Type,
    // Named or tuple-index member used to access the field via `self`.
    member: Member,
    // How this field's key is produced (see `Key`).
    key: Key,
    // Set by `#[base_descriptor]`: delegate the base-descriptor impl to this field.
    use_base_descriptor: bool,
}
impl FieldInfo {
    /// The key type this field contributes to the wrapper's key tuple, or
    /// `None` when the key is overridden (`Default`/`Custom`) and therefore
    /// hidden from the wrapper's external API.
    ///
    /// NOTE: the runtime trait (defined in `render_resource::specialize`) is
    /// named `Specializer`; referencing `Specialize` here would not resolve
    /// in the generated code.
    fn key_ty(&self, specialize_path: &Path, target_path: &Path) -> Option<Type> {
        let ty = &self.ty;
        matches!(self.key, Key::Whole | Key::Index(_))
            .then_some(parse_quote!(<#ty as #specialize_path::Specializer<#target_path>>::Key))
    }
    /// The binding ident for this field's canonical key, or `None` when the
    /// field doesn't contribute to the wrapper's key.
    fn key_ident(&self, ident: Ident) -> Option<Ident> {
        matches!(self.key, Key::Whole | Key::Index(_)).then_some(ident)
    }
    /// The `self.<field>.specialize(<key>, descriptor)` call, written as a
    /// fully-qualified trait method call.
    fn specialize_expr(&self, specialize_path: &Path, target_path: &Path) -> Expr {
        let FieldInfo {
            ty, member, key, ..
        } = &self;
        let key_expr = key.expr();
        parse_quote!(<#ty as #specialize_path::Specializer<#target_path>>::specialize(&self.#member, #key_expr, descriptor))
    }
    /// Where-clause predicate requiring the field type to be a specializer
    /// for the target (with a `Default` key when using `#[key(default)]`).
    fn specialize_predicate(&self, specialize_path: &Path, target_path: &Path) -> WherePredicate {
        let ty = &self.ty;
        if matches!(&self.key, Key::Default) {
            parse_quote!(#ty: #specialize_path::Specializer<#target_path, Key: #FQDefault>)
        } else {
            parse_quote!(#ty: #specialize_path::Specializer<#target_path>)
        }
    }
    /// Where-clause predicate for the `#[base_descriptor]` field.
    fn get_base_descriptor_predicate(
        &self,
        specialize_path: &Path,
        target_path: &Path,
    ) -> WherePredicate {
        let ty = &self.ty;
        parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>)
    }
}
/// Collects per-field codegen info, resolving `#[key(...)]` and
/// `#[base_descriptor]` attributes and assigning key-tuple indices.
fn get_field_info(fields: &Fields, targets: &SpecializeImplTargets) -> syn::Result<Vec<FieldInfo>> {
    let mut field_info: Vec<FieldInfo> = Vec::new();
    // Number of fields whose key is exposed in the wrapper's key tuple.
    let mut used_count = 0;
    // Index of the last field with an exposed key; only meaningful when
    // `used_count == 1`, in which case that field's key becomes the whole key.
    let mut single_index = 0;
    for (index, field) in fields.iter().enumerate() {
        let field_ty = field.ty.clone();
        // Tuple structs use positional members (`self.0`); named structs the ident.
        let field_member = field.ident.clone().map_or(
            Member::Unnamed(Index {
                index: index as u32,
                span: field.span(),
            }),
            Member::Named,
        );
        // Slot this field's key would occupy in the wrapper's key tuple.
        let key_index = Index {
            index: used_count,
            span: field.span(),
        };
        let mut use_key_field = true;
        let mut key = Key::Index(key_index);
        let mut use_base_descriptor = false;
        for attr in &field.attrs {
            match &attr.meta {
                Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => {
                    use_base_descriptor = true;
                }
                Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => {
                    let owned_tokens = tokens.clone().into();
                    let Ok(parsed_key) = parse::<Key>(owned_tokens) else {
                        return Err(syn::Error::new(
                            attr.span(),
                            "Invalid key override attribute",
                        ));
                    };
                    key = parsed_key;
                    // Custom key expressions can't be checked against an unknown
                    // target type, so they're rejected for `#[specialize(all)]`.
                    if matches!(
                        (&key, &targets),
                        (Key::Custom(_), SpecializeImplTargets::All)
                    ) {
                        return Err(syn::Error::new(
                            attr.span(),
                            "#[key(default)] is the only key override type allowed with #[specialize(all)]",
                        ));
                    }
                    // Overridden keys don't occupy a slot in the key tuple.
                    use_key_field = false;
                }
                _ => {}
            }
        }
        if use_key_field {
            used_count += 1;
            single_index = index;
        }
        field_info.push(FieldInfo {
            ty: field_ty,
            member: field_member,
            key,
            use_base_descriptor,
        });
    }
    // With exactly one exposed key, skip the 1-tuple and use that key directly.
    if used_count == 1 {
        field_info[single_index].key = Key::Whole;
    }
    Ok(field_info)
}
fn get_struct_fields<'a>(ast: &'a DeriveInput, derive_name: &str) -> syn::Result<&'a Fields> {
match &ast.data {
Data::Struct(DataStruct { fields, .. }) => Ok(fields),
Data::Enum(data_enum) => Err(syn::Error::new(
data_enum.enum_token.span(),
format!("#[derive({derive_name})] only supports structs."),
)),
Data::Union(data_union) => Err(syn::Error::new(
data_union.union_token.span(),
format!("#[derive({derive_name})] only supports structs."),
)),
}
}
fn get_specialize_targets(
ast: &DeriveInput,
derive_name: &str,
) -> syn::Result<SpecializeImplTargets> {
let specialize_attr = ast.attrs.iter().find_map(|attr| {
if attr.path().is_ident(SPECIALIZE_ATTR_IDENT) {
if let Meta::List(meta_list) = &attr.meta {
return Some(meta_list);
}
}
None
});
let Some(specialize_meta_list) = specialize_attr else {
return Err(syn::Error::new(
Span::call_site(),
format!("#[derive({derive_name})] must be accompanied by #[specialize(..targets)].\n Example usages: #[specialize(RenderPipeline)], #[specialize(all)]")
));
};
parse::<SpecializeImplTargets>(specialize_meta_list.tokens.clone().into())
}
// Like `?`, but for functions returning `TokenStream`: unwraps a
// `syn::Result`, converting any error into a compile-error token stream.
macro_rules! guard {
    ($expr: expr) => {
        match $expr {
            Ok(__val) => __val,
            Err(err) => return err.to_compile_error().into(),
        }
    };
}
/// Entry point of the specializer derive: generates specializer impls (and
/// optionally base-descriptor impls) for the annotated struct, either for a
/// specific list of targets or generically for all of them.
pub fn impl_specialize(input: TokenStream) -> TokenStream {
    let bevy_render_path: Path = crate::bevy_render_path();
    // All generated paths are rooted at `bevy_render::render_resource`.
    let specialize_path = {
        let mut path = bevy_render_path.clone();
        path.segments.push(format_ident!("render_resource").into());
        path
    };
    let ecs_path = crate::bevy_ecs_path();
    let ast = parse_macro_input!(input as DeriveInput);
    let targets = guard!(get_specialize_targets(&ast, "Specialize"));
    let fields = guard!(get_struct_fields(&ast, "Specialize"));
    let field_info = guard!(get_field_info(fields, &targets));
    // `key{i}` binding for each field contributing to the key tuple,
    // `None` for fields whose key is overridden.
    let key_idents: Vec<Option<Ident>> = field_info
        .iter()
        .enumerate()
        .map(|(i, field_info)| field_info.key_ident(format_ident!("key{i}")))
        .collect();
    // Idents actually returned in the canonical key tuple.
    let key_tuple_idents: Vec<Ident> = key_idents.iter().flatten().cloned().collect();
    let ignore_pat: Pat = parse_quote!(_);
    // Binding pattern for each field's specialize result: `keyN` or `_`.
    let key_patterns: Vec<Pat> = key_idents
        .iter()
        .map(|key_ident| match key_ident {
            Some(key_ident) => parse_quote!(#key_ident),
            None => ignore_pat.clone(),
        })
        .collect();
    let base_descriptor_fields = field_info
        .iter()
        .filter(|field| field.use_base_descriptor)
        .collect::<Vec<_>>();
    // At most one field may carry `#[base_descriptor]`.
    if base_descriptor_fields.len() > 1 {
        return syn::Error::new(
            Span::call_site(),
            "Too many #[base_descriptor] attributes found. It must be present on exactly one field",
        )
        .into_compile_error()
        .into();
    }
    let base_descriptor_field = base_descriptor_fields.first().copied();
    match targets {
        SpecializeImplTargets::All => {
            // One impl, generic over every specializable target `T`.
            let specialize_impl = impl_specialize_all(
                &specialize_path,
                &ecs_path,
                &ast,
                &field_info,
                &key_patterns,
                &key_tuple_idents,
            );
            let get_base_descriptor_impl = base_descriptor_field
                .map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info))
                .unwrap_or_default();
            [specialize_impl, get_base_descriptor_impl]
                .into_iter()
                .collect()
        }
        SpecializeImplTargets::Specific(targets) => {
            // One impl per listed target type.
            let specialize_impls = targets.iter().map(|target| {
                impl_specialize_specific(
                    &specialize_path,
                    &ecs_path,
                    &ast,
                    &field_info,
                    target,
                    &key_patterns,
                    &key_tuple_idents,
                )
            });
            let get_base_descriptor_impls = targets.iter().filter_map(|target| {
                base_descriptor_field.map(|field_info| {
                    impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target)
                })
            });
            specialize_impls.chain(get_base_descriptor_impls).collect()
        }
    }
}
/// Generates `impl<T: Specializable, ..> Specializer<T> for Struct<..>`,
/// fully generic over the specialization target.
fn impl_specialize_all(
    specialize_path: &Path,
    ecs_path: &Path,
    ast: &DeriveInput,
    field_info: &[FieldInfo],
    key_patterns: &[Pat],
    key_tuple_idents: &[Ident],
) -> TokenStream {
    // The generic target parameter, named `T`.
    let target_path = Path::from(format_ident!("T"));
    // Key tuple element types, one per field with an exposed key.
    let key_elems: Vec<Type> = field_info
        .iter()
        .filter_map(|field_info| field_info.key_ty(specialize_path, &target_path))
        .collect();
    // One `self.field.specialize(..)` call per field, in declaration order.
    let specialize_exprs: Vec<Expr> = field_info
        .iter()
        .map(|field_info| field_info.specialize_expr(specialize_path, &target_path))
        .collect();
    let struct_name = &ast.ident;
    let mut generics = ast.generics.clone();
    generics.params.insert(
        0,
        parse_quote!(#target_path: #specialize_path::Specializable),
    );
    if !field_info.is_empty() {
        // Each field type must itself be a specializer for `T`.
        let where_clause = generics.make_where_clause();
        for field in field_info {
            where_clause
                .predicates
                .push(field.specialize_predicate(specialize_path, &target_path));
        }
    }
    // Type generics come from the original generics (without `T`); impl
    // generics and where clause include the added target parameter.
    let (_, type_generics, _) = ast.generics.split_for_impl();
    let (impl_generics, _, where_clause) = &generics.split_for_impl();
    // The runtime trait is named `Specializer`; `Specialize` would not resolve.
    TokenStream::from(quote! {
        impl #impl_generics #specialize_path::Specializer<#target_path> for #struct_name #type_generics #where_clause {
            type Key = (#(#key_elems),*);
            fn specialize(
                &self,
                key: Self::Key,
                descriptor: &mut <#target_path as #specialize_path::Specializable>::Descriptor
            ) -> #FQResult<#specialize_path::Canonical<Self::Key>, #ecs_path::error::BevyError> {
                #(let #key_patterns = #specialize_exprs?;)*
                #FQResult::Ok((#(#key_tuple_idents),*))
            }
        }
    })
}
/// Generates `impl Specializer<Target> for Struct` for one concrete target
/// listed in `#[specialize(..targets)]`.
fn impl_specialize_specific(
    specialize_path: &Path,
    ecs_path: &Path,
    ast: &DeriveInput,
    field_info: &[FieldInfo],
    target_path: &Path,
    key_patterns: &[Pat],
    key_tuple_idents: &[Ident],
) -> TokenStream {
    // Key tuple element types, one per field with an exposed key.
    let key_elems: Vec<Type> = field_info
        .iter()
        .filter_map(|field_info| field_info.key_ty(specialize_path, target_path))
        .collect();
    // One `self.field.specialize(..)` call per field, in declaration order.
    let specialize_exprs: Vec<Expr> = field_info
        .iter()
        .map(|field_info| field_info.specialize_expr(specialize_path, target_path))
        .collect();
    let struct_name = &ast.ident;
    let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
    // The runtime trait is named `Specializer`; `Specialize` would not resolve.
    TokenStream::from(quote! {
        impl #impl_generics #specialize_path::Specializer<#target_path> for #struct_name #type_generics #where_clause {
            type Key = (#(#key_elems),*);
            fn specialize(
                &self,
                key: Self::Key,
                descriptor: &mut <#target_path as #specialize_path::Specializable>::Descriptor
            ) -> #FQResult<#specialize_path::Canonical<Self::Key>, #ecs_path::error::BevyError> {
                #(let #key_patterns = #specialize_exprs?;)*
                #FQResult::Ok((#(#key_tuple_idents),*))
            }
        }
    })
}
/// Generates `impl GetBaseDescriptor<Target> for Struct`, delegating to the
/// `#[base_descriptor]` field.
fn impl_get_base_descriptor_specific(
    specialize_path: &Path,
    ast: &DeriveInput,
    base_descriptor_field_info: &FieldInfo,
    target_path: &Path,
) -> TokenStream {
    let struct_name = &ast.ident;
    let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
    let field_ty = &base_descriptor_field_info.ty;
    let field_member = &base_descriptor_field_info.member;
    // The trait method is `get_base_descriptor` (see the runtime
    // `GetBaseDescriptor` trait); the delegated call must use that same name —
    // the original `base_descriptor` call would not resolve.
    TokenStream::from(quote!(
        impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
            fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
                <#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
            }
        }
    ))
}
/// Generates a fully generic `impl<T: Specializable> GetBaseDescriptor<T>`,
/// delegating to the `#[base_descriptor]` field.
fn impl_get_base_descriptor_all(
    specialize_path: &Path,
    ast: &DeriveInput,
    base_descriptor_field_info: &FieldInfo,
) -> TokenStream {
    let target_path = Path::from(format_ident!("T"));
    let struct_name = &ast.ident;
    let mut generics = ast.generics.clone();
    generics.params.insert(
        0,
        parse_quote!(#target_path: #specialize_path::Specializable),
    );
    // The field type must provide a base descriptor for `T`.
    let where_clause = generics.make_where_clause();
    where_clause.predicates.push(
        base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path),
    );
    let (_, type_generics, _) = ast.generics.split_for_impl();
    let (impl_generics, _, where_clause) = &generics.split_for_impl();
    let field_ty = &base_descriptor_field_info.ty;
    let field_member = &base_descriptor_field_info.member;
    // Delegate under the trait's actual method name, `get_base_descriptor` —
    // the original `base_descriptor` call would not resolve.
    TokenStream::from(quote! {
        impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
            fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
                <#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
            }
        }
    })
}
/// Body of `#[derive(SpecializerKey)]`: generates the canonical-key impl
/// (`IS_CANONICAL = true`, `Canonical = Self`).
pub fn impl_specializer_key(input: TokenStream) -> TokenStream {
    let bevy_render_path: Path = crate::bevy_render_path();
    // Generated paths are rooted at `bevy_render::render_resource`.
    let specialize_path = {
        let mut path = bevy_render_path.clone();
        path.segments.push(format_ident!("render_resource").into());
        path
    };
    let ast = parse_macro_input!(input as DeriveInput);
    let ident = &ast.ident;
    // Propagate the type's generics so the derive also works for generic key
    // types; the original impl ignored them and only compiled for non-generic keys.
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    TokenStream::from(quote!(
        impl #impl_generics #specialize_path::SpecializerKey for #ident #type_generics #where_clause {
            const IS_CANONICAL: bool = true;
            type Canonical = Self;
        }
    ))
}

View File

@ -12,6 +12,7 @@ mod pipeline_cache;
mod pipeline_specializer;
pub mod resource_macros;
mod shader;
mod specialize;
mod storage_buffer;
mod texture;
mod uniform_buffer;
@ -28,6 +29,7 @@ pub use pipeline::*;
pub use pipeline_cache::*;
pub use pipeline_specializer::*;
pub use shader::*;
pub use specialize::*;
pub use storage_buffer::*;
pub use texture::*;
pub use uniform_buffer::*;

View File

@ -138,7 +138,7 @@ pub struct FragmentState {
}
/// Describes a compute pipeline.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComputePipelineDescriptor {
pub label: Option<Cow<'static, str>>,
pub layout: Vec<BindGroupLayout>,

View File

@ -0,0 +1,400 @@
use super::{
CachedComputePipelineId, CachedRenderPipelineId, ComputePipeline, ComputePipelineDescriptor,
PipelineCache, RenderPipeline, RenderPipelineDescriptor,
};
use bevy_ecs::{
error::BevyError,
resource::Resource,
world::{FromWorld, World},
};
use bevy_platform::{
collections::{
hash_map::{Entry, VacantEntry},
HashMap,
},
hash::FixedHasher,
};
use core::{hash::Hash, marker::PhantomData};
use tracing::error;
use variadics_please::all_tuples;
pub use bevy_render_macros::{Specialize, SpecializerKey};
/// Defines a type that is able to be "specialized" and cached by creating and transforming
/// its descriptor type. This is implemented for [`RenderPipeline`] and [`ComputePipeline`], and
/// likely will not have much utility for other types.
pub trait Specializable {
    /// The descriptor type specializers transform before queuing.
    type Descriptor: PartialEq + Clone + Send + Sync;
    /// The id handed back by the [`PipelineCache`] for a queued descriptor.
    type CachedId: Clone + Send + Sync;
    /// Queues the descriptor for creation, returning its cached id.
    fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId;
    /// Retrieves the descriptor a cached id was created from.
    fn get_descriptor(pipeline_cache: &PipelineCache, id: Self::CachedId) -> &Self::Descriptor;
}
// Render pipelines specialize through `RenderPipelineDescriptor` and are
// queued/fetched via the render-pipeline side of the shared `PipelineCache`.
impl Specializable for RenderPipeline {
    type Descriptor = RenderPipelineDescriptor;
    type CachedId = CachedRenderPipelineId;
    fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId {
        pipeline_cache.queue_render_pipeline(descriptor)
    }
    fn get_descriptor(
        pipeline_cache: &PipelineCache,
        id: CachedRenderPipelineId,
    ) -> &Self::Descriptor {
        pipeline_cache.get_render_pipeline_descriptor(id)
    }
}
// Compute pipelines specialize through `ComputePipelineDescriptor` and are
// queued/fetched via the compute-pipeline side of the shared `PipelineCache`.
impl Specializable for ComputePipeline {
    type Descriptor = ComputePipelineDescriptor;
    type CachedId = CachedComputePipelineId;
    fn queue(pipeline_cache: &PipelineCache, descriptor: Self::Descriptor) -> Self::CachedId {
        pipeline_cache.queue_compute_pipeline(descriptor)
    }
    fn get_descriptor(
        pipeline_cache: &PipelineCache,
        id: CachedComputePipelineId,
    ) -> &Self::Descriptor {
        pipeline_cache.get_compute_pipeline_descriptor(id)
    }
}
/// Defines a type that is able to transform descriptors for a specializable
/// type T, based on a hashable key type.
///
/// This is mainly used when "specializing" render
/// pipelines, i.e. specifying shader defs and binding layout based on the key,
/// the result of which can then be cached and accessed quickly later.
///
/// This trait can be derived with `#[derive(Specializer)]` for structs whose
/// fields all implement [`Specializer`]. The key type will be a tuple of the keys
/// of each field, and their specialization logic will be applied in field
/// order. Since derive macros can't have generic parameters, the derive macro
/// requires an additional `#[specialize(..targets)]` attribute to specify a
/// list of types to target for the implementation. `#[specialize(all)]` is
/// also allowed, and will generate a fully generic implementation at the cost
/// of slightly worse error messages.
///
/// Additionally, each field can optionally take a `#[key]` attribute to
/// specify a "key override". This will "hide" that field's key from being
/// exposed by the wrapper, and always use the value given by the attribute.
/// Values for this attribute may either be `default` which will use the key's
/// [`Default`] implementation, or a valid rust
/// expression of the key type.
///
/// Example:
/// ```rs
/// # use super::RenderPipeline;
/// # use super::RenderPipelineDescriptor;
/// # use bevy_ecs::error::BevyError;
///
/// struct A;
/// struct B;
/// #[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)]
/// struct BKey;
///
/// impl Specializer<RenderPipeline> for A {
///     type Key = ();
///
///     fn specialize(&self, key: (), descriptor: &mut RenderPipelineDescriptor) -> Result<(), BevyError> {
///         # let _ = (key, descriptor);
///         //...
///         Ok(())
///     }
/// }
///
/// impl Specializer<RenderPipeline> for B {
///     type Key = BKey;
///
///     fn specialize(&self, _key: BKey, _descriptor: &mut RenderPipelineDescriptor) -> Result<BKey, BevyError> {
///         //...
///         Ok(BKey)
///     }
/// }
///
/// #[derive(Specializer)]
/// #[specialize(RenderPipeline)]
/// struct C {
///     #[key(default)]
///     a: A,
///     b: B,
/// }
///
/// /*
/// The generated implementation:
/// impl Specializer<RenderPipeline> for C {
///     type Key = BKey;
///     fn specialize(
///         &self,
///         key: Self::Key,
///         descriptor: &mut RenderPipelineDescriptor
///     ) -> Result<Canonical<Self::Key>, BevyError> {
///         let _ = self.a.specialize((), descriptor)?;
///         let key = self.b.specialize(key, descriptor)?;
///         Ok(key)
///     }
/// }
/// */
/// ```
pub trait Specializer<T: Specializable>: Send + Sync + 'static {
    /// The hashable key this specializer is driven by.
    type Key: SpecializerKey;
    /// Transforms `descriptor` in place according to `key`, returning the
    /// canonical form of the key on success.
    fn specialize(
        &self,
        key: Self::Key,
        descriptor: &mut T::Descriptor,
    ) -> Result<Canonical<Self::Key>, BevyError>;
}
/// Defines a type that is able to be used as a key for types that `impl Specializer`
///
/// **Most types should implement this trait with `IS_CANONICAL = true` and `Canonical = Self`**.
/// This is the implementation generated by `#[derive(SpecializerKey)]`
///
/// In this case, "canonical" means that each unique value of this type will produce
/// a unique specialized result, which isn't true in general. `MeshVertexBufferLayout`
/// is a good example of a type that's `Eq + Hash`, but that isn't canonical: vertex
/// attributes could be specified in any order, or there could be more attributes
/// provided than the specialized pipeline requires. Its `Canonical` key type would
/// be `VertexBufferLayout`, the final layout required by the pipeline.
///
/// Processing keys into canonical keys this way allows the `SpecializedCache` to reuse
/// resources more eagerly where possible.
pub trait SpecializerKey: Clone + Hash + Eq {
    /// Denotes whether this key is canonical or not. This should be `true`
    /// if and only if `Canonical = Self`.
    const IS_CANONICAL: bool;
    /// The canonical key type to convert this into during specialization.
    type Canonical: Hash + Eq;
}
/// Shorthand for a [`SpecializerKey`]'s canonical key type.
pub type Canonical<T> = <T as SpecializerKey>::Canonical;
// `()` is the no-op specializer: it leaves the descriptor untouched, which is
// useful when a cache only ever needs its base descriptor.
impl<T: Specializable> Specializer<T> for () {
    type Key = ();
    fn specialize(
        &self,
        _key: Self::Key,
        _descriptor: &mut T::Descriptor,
    ) -> Result<(), BevyError> {
        Ok(())
    }
}
// `PhantomData<V>` is also a no-op specializer, letting derived wrappers carry
// a generic marker field without affecting the key or the descriptor.
impl<T: Specializable, V: Send + Sync + 'static> Specializer<T> for PhantomData<V> {
    type Key = ();
    fn specialize(
        &self,
        _key: Self::Key,
        _descriptor: &mut T::Descriptor,
    ) -> Result<(), BevyError> {
        Ok(())
    }
}
// Implements `SpecializerKey` for tuples of keys: a tuple key is canonical
// only if every element key is canonical.
macro_rules! impl_specialization_key_tuple {
    ($($T:ident),*) => {
        impl <$($T: SpecializerKey),*> SpecializerKey for ($($T,)*) {
            const IS_CANONICAL: bool = true $(&& <$T as SpecializerKey>::IS_CANONICAL)*;
            type Canonical = ($(Canonical<$T>,)*);
        }
    };
}
// Cover tuple arities 0 through 12 (the 0-tuple `()` included).
all_tuples!(impl_specialization_key_tuple, 0, 12, T);
/// Defines a specializer that can also provide a "base descriptor".
///
/// In order to be composable, [`Specializer`] implementers don't create full
/// descriptors, only transform them. However, [`SpecializedCache`]s need a
/// "base descriptor" at creation time in order to have something for the
/// [`Specializer`] implementation to work off of. This trait allows
/// [`SpecializedCache`] to impl [`FromWorld`] for [`Specializer`]
/// implementations that also satisfy [`FromWorld`] and [`GetBaseDescriptor`].
///
/// This trait can be also derived with `#[derive(Specializer)]`, by marking
/// a field with `#[base_descriptor]` to use its [`GetBaseDescriptor`] implementation.
///
/// Example:
/// ```rs
/// struct A;
/// struct B;
///
/// impl Specializer<RenderPipeline> for A {
///     type Key = ();
///
///     fn specialize(&self, _key: (), _descriptor: &mut RenderPipelineDescriptor) -> Result<(), BevyError> {
///         //...
///         # Ok(())
///     }
/// }
///
/// impl Specializer<RenderPipeline> for B {
///     type Key = u32;
///
///     fn specialize(&self, _key: u32, _descriptor: &mut RenderPipelineDescriptor) -> Result<u32, BevyError> {
///         //...
///         # Ok(_key)
///     }
/// }
///
/// impl GetBaseDescriptor<RenderPipeline> for B {
///     fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
///         # todo!()
///         //...
///     }
/// }
///
///
/// #[derive(Specializer)]
/// #[specialize(RenderPipeline)]
/// struct C {
///     #[key(default)]
///     a: A,
///     #[base_descriptor]
///     b: B,
/// }
///
/// /*
/// The generated implementation:
/// impl GetBaseDescriptor<RenderPipeline> for C {
///     fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
///         self.b.get_base_descriptor()
///     }
/// }
/// */
/// ```
pub trait GetBaseDescriptor<T: Specializable>: Specializer<T> {
    /// Produces the descriptor that specialization starts from.
    fn get_base_descriptor(&self) -> T::Descriptor;
}
/// Signature of the optional "user specializer" run after the main
/// [`Specializer`] with the same key.
pub type SpecializerFn<T, S> =
    fn(<S as Specializer<T>>::Key, &mut <T as Specializable>::Descriptor) -> Result<(), BevyError>;
/// A cache for specializable resources. For a given key type the resulting
/// resource will only be created if it is missing, retrieving it from the
/// cache otherwise.
#[derive(Resource)]
pub struct SpecializedCache<T: Specializable, S: Specializer<T>> {
    specializer: S,
    // Optional extra transform applied after `specializer`, with the same key.
    user_specializer: Option<SpecializerFn<T, S>>,
    // Descriptor copied as the starting point for every specialization.
    base_descriptor: T::Descriptor,
    // Maps user-facing keys directly to cached ids.
    primary_cache: HashMap<S::Key, T::CachedId>,
    // Maps canonical keys to cached ids; only consulted when `S::Key` is
    // non-canonical, so equivalent keys can share one resource.
    secondary_cache: HashMap<Canonical<S::Key>, T::CachedId>,
}
impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
    /// Creates a new [`SpecializedCache`] from a [`Specializer`],
    /// an optional "user specializer", and a base descriptor. The
    /// user specializer is applied after the [`Specializer`], with
    /// the same key.
    #[inline]
    pub fn new(
        specializer: S,
        user_specializer: Option<SpecializerFn<T, S>>,
        base_descriptor: T::Descriptor,
    ) -> Self {
        Self {
            specializer,
            user_specializer,
            base_descriptor,
            primary_cache: Default::default(),
            secondary_cache: Default::default(),
        }
    }

    /// Specializes a resource given the [`Specializer`]'s key type.
    ///
    /// Cache hits return the stored id; misses run the specializer on a copy
    /// of the base descriptor and queue the result with the pipeline cache.
    #[inline]
    pub fn specialize(
        &mut self,
        pipeline_cache: &PipelineCache,
        key: S::Key,
    ) -> Result<T::CachedId, BevyError> {
        let entry = self.primary_cache.entry(key.clone());
        match entry {
            Entry::Occupied(entry) => Ok(entry.get().clone()),
            Entry::Vacant(entry) => Self::specialize_slow(
                &self.specializer,
                self.user_specializer,
                self.base_descriptor.clone(),
                pipeline_cache,
                key,
                entry,
                &mut self.secondary_cache,
            ),
        }
    }

    /// Cache-miss path: runs the specializer(s) and fills both cache levels.
    #[cold]
    fn specialize_slow(
        specializer: &S,
        user_specializer: Option<SpecializerFn<T, S>>,
        base_descriptor: T::Descriptor,
        pipeline_cache: &PipelineCache,
        key: S::Key,
        primary_entry: VacantEntry<S::Key, T::CachedId, FixedHasher>,
        secondary_cache: &mut HashMap<Canonical<S::Key>, T::CachedId>,
    ) -> Result<T::CachedId, BevyError> {
        // `base_descriptor` is already an owned copy made by the caller, so
        // mutate it directly instead of cloning it a second time.
        let mut descriptor = base_descriptor;
        let canonical_key = specializer.specialize(key.clone(), &mut descriptor)?;
        // The user specializer runs after the main one, with the same key.
        if let Some(user_specializer) = user_specializer {
            (user_specializer)(key, &mut descriptor)?;
        }
        // if the whole key is canonical, the secondary cache isn't needed.
        if <S::Key as SpecializerKey>::IS_CANONICAL {
            return Ok(primary_entry
                .insert(<T as Specializable>::queue(pipeline_cache, descriptor))
                .clone());
        }
        // Non-canonical keys are resolved through the secondary cache so that
        // distinct user keys with the same canonical form share one resource.
        let id = match secondary_cache.entry(canonical_key) {
            Entry::Occupied(entry) => {
                // In debug builds, verify the cached descriptor matches the one
                // just generated; a mismatch means the specializer used key
                // information its canonical key doesn't capture.
                if cfg!(debug_assertions) {
                    let stored_descriptor =
                        <T as Specializable>::get_descriptor(pipeline_cache, entry.get().clone());
                    if &descriptor != stored_descriptor {
                        error!(
                            "Invalid Specializer<{}> impl for {}: the cached descriptor \
                            is not equal to the generated descriptor for the given key. \
                            This means the Specializer implementation uses unused information \
                            from the key to specialize the pipeline. This is not allowed \
                            because it would invalidate the cache.",
                            core::any::type_name::<T>(),
                            core::any::type_name::<S>()
                        );
                    }
                }
                entry.into_mut().clone()
            }
            Entry::Vacant(entry) => entry
                .insert(<T as Specializable>::queue(pipeline_cache, descriptor))
                .clone(),
        };
        primary_entry.insert(id.clone());
        Ok(id)
    }
}
/// [`SpecializedCache`] implements [`FromWorld`] for [`Specializer`]s
/// that also satisfy [`FromWorld`] and [`GetBaseDescriptor`]. This will
/// create a [`SpecializedCache`] with no user specializer, and the base
/// descriptor taken from the specializer's [`GetBaseDescriptor`] implementation.
impl<T, S> FromWorld for SpecializedCache<T, S>
where
    T: Specializable,
    S: FromWorld + Specializer<T> + GetBaseDescriptor<T>,
{
    fn from_world(world: &mut World) -> Self {
        let specializer = S::from_world(world);
        let base_descriptor = specializer.get_base_descriptor();
        Self::new(specializer, None, base_descriptor)
    }
}

View File

@ -24,11 +24,11 @@ use bevy::{
ViewBinnedRenderPhases,
},
render_resource::{
BufferUsages, ColorTargetState, ColorWrites, CompareFunction, DepthStencilState,
FragmentState, IndexFormat, MultisampleState, PipelineCache, PrimitiveState,
RawBufferVec, RenderPipelineDescriptor, SpecializedRenderPipeline,
SpecializedRenderPipelines, TextureFormat, VertexAttribute, VertexBufferLayout,
VertexFormat, VertexState, VertexStepMode,
BufferUsages, Canonical, ColorTargetState, ColorWrites, CompareFunction,
DepthStencilState, FragmentState, GetBaseDescriptor, IndexFormat, MultisampleState,
PipelineCache, PrimitiveState, RawBufferVec, RenderPipeline, RenderPipelineDescriptor,
SpecializedCache, Specializer, SpecializerKey, TextureFormat, VertexAttribute,
VertexBufferLayout, VertexFormat, VertexState, VertexStepMode,
},
renderer::{RenderDevice, RenderQueue},
view::{self, ExtractedView, RenderVisibleEntities, VisibilityClass},
@ -49,14 +49,6 @@ use bytemuck::{Pod, Zeroable};
#[component(on_add = view::add_visibility_class::<CustomRenderedEntity>)]
struct CustomRenderedEntity;
/// Holds a reference to our shader.
///
/// This is loaded at app creation time.
#[derive(Resource)]
struct CustomPhasePipeline {
shader: Handle<Shader>,
}
/// A [`RenderCommand`] that binds the vertex and index buffers and issues the
/// draw command for our custom phase item.
struct DrawCustomPhaseItem;
@ -175,8 +167,7 @@ fn main() {
// We make sure to add these to the render app, not the main app.
app.get_sub_app_mut(RenderApp)
.unwrap()
.init_resource::<CustomPhasePipeline>()
.init_resource::<SpecializedRenderPipelines<CustomPhasePipeline>>()
.init_resource::<SpecializedCache<RenderPipeline, CustomPhaseSpecializer>>()
.add_render_command::<Opaque3d, DrawCustomPhaseItemCommands>()
.add_systems(
Render,
@ -221,10 +212,9 @@ fn prepare_custom_phase_item_buffers(mut commands: Commands) {
/// the opaque render phases of each view.
fn queue_custom_phase_item(
pipeline_cache: Res<PipelineCache>,
custom_phase_pipeline: Res<CustomPhasePipeline>,
mut opaque_render_phases: ResMut<ViewBinnedRenderPhases<Opaque3d>>,
opaque_draw_functions: Res<DrawFunctions<Opaque3d>>,
mut specialized_render_pipelines: ResMut<SpecializedRenderPipelines<CustomPhasePipeline>>,
mut specializer: ResMut<SpecializedCache<RenderPipeline, CustomPhaseSpecializer>>,
views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>,
mut next_tick: Local<Tick>,
) {
@ -247,11 +237,10 @@ fn queue_custom_phase_item(
// some per-view settings, such as whether the view is HDR, but for
// simplicity's sake we simply hard-code the view's characteristics,
// with the exception of number of MSAA samples.
let pipeline_id = specialized_render_pipelines.specialize(
&pipeline_cache,
&custom_phase_pipeline,
*msaa,
);
let Ok(pipeline_id) = specializer.specialize(&pipeline_cache, CustomPhaseKey(*msaa))
else {
continue;
};
// Bump the change tick in order to force Bevy to rebuild the bin.
let this_tick = next_tick.get() + 1;
@ -286,10 +275,40 @@ fn queue_custom_phase_item(
}
}
impl SpecializedRenderPipeline for CustomPhasePipeline {
type Key = Msaa;
/// Holds a reference to our shader.
///
/// This is loaded at app creation time.
struct CustomPhaseSpecializer {
    // Handle to the `shaders/custom_phase_item.wgsl` asset, loaded by this
    // type's `FromWorld` implementation.
    shader: Handle<Shader>,
}
fn specialize(&self, msaa: Self::Key) -> RenderPipelineDescriptor {
impl FromWorld for CustomPhaseSpecializer {
    /// Creates the specializer, requesting the custom phase shader from the
    /// asset server so it is available by the time pipelines are specialized.
    fn from_world(world: &mut World) -> Self {
        let shader = world
            .resource::<AssetServer>()
            .load("shaders/custom_phase_item.wgsl");
        Self { shader }
    }
}
/// Specialization key for the custom phase pipeline. The only per-variant
/// input is the view's MSAA setting, which the `Specializer` impl uses to
/// set the pipeline's multisample count.
#[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)]
struct CustomPhaseKey(Msaa);
impl Specializer<RenderPipeline> for CustomPhaseSpecializer {
type Key = CustomPhaseKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut RenderPipelineDescriptor,
) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.0.samples();
Ok(key)
}
}
impl GetBaseDescriptor<RenderPipeline> for CustomPhaseSpecializer {
fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("custom render pipeline".into()),
layout: vec![],
@ -340,7 +359,7 @@ impl SpecializedRenderPipeline for CustomPhasePipeline {
bias: default(),
}),
multisample: MultisampleState {
count: msaa.samples(),
count: 0,
mask: !0,
alpha_to_coverage_enabled: false,
},
@ -375,14 +394,3 @@ impl FromWorld for CustomPhaseItemBuffers {
}
}
}
impl FromWorld for CustomPhasePipeline {
    /// Initializes the pipeline resource by requesting the custom phase
    /// shader asset from the [`AssetServer`].
    fn from_world(world: &mut World) -> Self {
        // Load and compile the shader in the background.
        let asset_server = world.resource::<AssetServer>();
        CustomPhasePipeline {
            shader: asset_server.load("shaders/custom_phase_item.wgsl"),
        }
    }
}

View File

@ -0,0 +1,153 @@
---
title: Composable Specialization
pull_requests: [17373]
---
The existing pipeline specialization APIs (`SpecializedRenderPipeline` etc.) have
been replaced with a single `Specializer` trait and `SpecializedCache` collection:
```rs
pub trait Specializer<T: Specializable>: Send + Sync + 'static {
type Key: SpecializerKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut T::Descriptor,
) -> Result<Canonical<Self::Key>, BevyError>;
}
pub struct SpecializedCache<T: Specializable, S: Specializer<T>>{ ... };
```
The main difference is the change from *producing* a pipeline descriptor to
*mutating* one based on a key. The "base descriptor" that the `SpecializedCache`
passes to the `Specializer` can either be specified manually with `SpecializedCache::new`
or by implementing `GetBaseDescriptor`. There's also a new trait for specialization
keys, `SpecializerKey`, that can be derived with the included macro in most cases.
Composing multiple different specializers together with the `derive(Specializer)`
macro can be a lot more powerful (see the `Specializer` docs), but migrating
individual specializers is fairly simple. All static parts of the pipeline
should be specified in the base descriptor, while the `Specializer` impl
should mutate the descriptor as little as necessary to match the key.
```rs
pub struct MySpecializer {
layout: BindGroupLayout,
layout_msaa: BindGroupLayout,
vertex: Handle<Shader>,
fragment: Handle<Shader>,
}
// before
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
// after
#[derive(Clone, Copy, PartialEq, Eq, Hash, SpecializerKey)]
pub struct MyKey {
blend_state: BlendState,
msaa: Msaa,
}
impl FromWorld for MySpecializer {
fn from_world(world: &mut World) -> Self {
...
}
}
// before
impl SpecializedRenderPipeline for MySpecializer {
type Key = MyKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("my_pipeline".into()),
layout: vec![
if key.msaa.samples() > 0 {
self.layout_msaa.clone()
} else {
self.layout.clone()
}
],
push_constant_ranges: vec![],
vertex: VertexState {
shader: self.vertex.clone(),
shader_defs: vec![],
entry_point: "vertex".into(),
buffers: vec![],
},
primitive: Default::default(),
depth_stencil: None,
multisample: MultisampleState {
count: key.msaa.samples(),
..Default::default()
},
fragment: Some(FragmentState {
shader: self.fragment.clone(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {
format: TextureFormat::Rgba8Unorm,
blend: Some(key.blend_state),
write_mask: ColorWrites::all(),
})],
}),
zero_initialize_workgroup_memory: false,
}
}
}
app.init_resource::<SpecializedRenderPipelines<MySpecializer>>();
// after
impl Specializer<RenderPipeline> for MySpecializer {
type Key = MyKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut RenderPipelineDescriptor,
) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.msaa.samples();
descriptor.layout[0] = if key.msaa.samples() > 0 {
self.layout_msaa.clone()
} else {
self.layout.clone()
};
descriptor.fragment.as_mut().unwrap().targets[0].as_mut().unwrap().blend =
    Some(key.blend_state);
Ok(key)
}
}
impl GetBaseDescriptor<RenderPipeline> for MySpecializer {
fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("my_pipeline".into()),
layout: vec![self.layout.clone()],
push_constant_ranges: vec![],
vertex: VertexState {
shader: self.vertex.clone(),
shader_defs: vec![],
entry_point: "vertex".into(),
buffers: vec![],
},
primitive: Default::default(),
depth_stencil: None,
multisample: MultisampleState::default(),
fragment: Some(FragmentState {
shader: self.fragment.clone(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {
format: TextureFormat::Rgba8Unorm,
blend: None,
write_mask: ColorWrites::all(),
})],
}),
zero_initialize_workgroup_memory: false,
}
}
}
app.init_resource::<SpecializedCache<RenderPipeline, MySpecializer>>();
```