This commit is contained in:
Emerson Coskey 2025-07-18 09:50:21 +08:00 committed by GitHub
commit 2541897196
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 316 additions and 387 deletions

View File

@ -1,6 +1,8 @@
use core::{ops::Deref, result::Result};
use crate::FullscreenShader; use crate::FullscreenShader;
use bevy_app::{App, Plugin}; use bevy_app::{App, Plugin};
use bevy_asset::{embedded_asset, load_embedded_asset, AssetServer, Handle}; use bevy_asset::{embedded_asset, load_embedded_asset, AssetServer};
use bevy_ecs::prelude::*; use bevy_ecs::prelude::*;
use bevy_render::{ use bevy_render::{
render_resource::{ render_resource::{
@ -24,8 +26,7 @@ impl Plugin for BlitPlugin {
}; };
render_app render_app
.allow_ambiguous_resource::<SpecializedRenderPipelines<BlitPipeline>>() .allow_ambiguous_resource::<BlitPipeline>()
.init_resource::<SpecializedRenderPipelines<BlitPipeline>>()
.add_systems(RenderStartup, init_blit_pipeline); .add_systems(RenderStartup, init_blit_pipeline);
} }
} }
@ -34,8 +35,7 @@ impl Plugin for BlitPlugin {
pub struct BlitPipeline { pub struct BlitPipeline {
pub layout: BindGroupLayout, pub layout: BindGroupLayout,
pub sampler: Sampler, pub sampler: Sampler,
pub fullscreen_shader: FullscreenShader, pub specialized_cache: SpecializedCache<RenderPipeline, BlitSpecializer>,
pub fragment_shader: Handle<Shader>,
} }
pub fn init_blit_pipeline( pub fn init_blit_pipeline(
@ -57,11 +57,23 @@ pub fn init_blit_pipeline(
let sampler = render_device.create_sampler(&SamplerDescriptor::default()); let sampler = render_device.create_sampler(&SamplerDescriptor::default());
let base_descriptor = RenderPipelineDescriptor {
label: Some("blit pipeline".into()),
layout: vec![layout.clone()],
vertex: fullscreen_shader.to_vertex_state(),
fragment: Some(FragmentState {
shader: load_embedded_asset!(asset_server.deref(), "blit.wgsl"),
..default()
}),
..default()
};
let specialized_cache = SpecializedCache::new(BlitSpecializer, base_descriptor);
commands.insert_resource(BlitPipeline { commands.insert_resource(BlitPipeline {
layout, layout,
sampler, sampler,
fullscreen_shader: fullscreen_shader.clone(), specialized_cache,
fragment_shader: load_embedded_asset!(asset_server.as_ref(), "blit.wgsl"),
}); });
} }
@ -79,35 +91,34 @@ impl BlitPipeline {
} }
} }
#[derive(PartialEq, Eq, Hash, Clone, Copy)] pub struct BlitSpecializer;
pub struct BlitPipelineKey {
impl Specializer<RenderPipeline> for BlitSpecializer {
type Key = BlitKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut <RenderPipeline as Specializable>::Descriptor,
) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.samples;
descriptor.fragment_mut()?.set_target(
0,
ColorTargetState {
format: key.texture_format,
blend: key.blend_state,
write_mask: ColorWrites::ALL,
},
);
Ok(key)
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy, SpecializerKey)]
pub struct BlitKey {
pub texture_format: TextureFormat, pub texture_format: TextureFormat,
pub blend_state: Option<BlendState>, pub blend_state: Option<BlendState>,
pub samples: u32, pub samples: u32,
} }
impl SpecializedRenderPipeline for BlitPipeline {
type Key = BlitPipelineKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("blit pipeline".into()),
layout: vec![self.layout.clone()],
vertex: self.fullscreen_shader.to_vertex_state(),
fragment: Some(FragmentState {
shader: self.fragment_shader.clone(),
targets: vec![Some(ColorTargetState {
format: key.texture_format,
blend: key.blend_state,
write_mask: ColorWrites::ALL,
})],
..default()
}),
multisample: MultisampleState {
count: key.samples,
..default()
},
..default()
}
}
}

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
blit::{BlitPipeline, BlitPipelineKey}, blit::{BlitKey, BlitPipeline},
core_2d::graph::{Core2d, Node2d}, core_2d::graph::{Core2d, Node2d},
core_3d::graph::{Core3d, Node3d}, core_3d::graph::{Core3d, Node3d},
}; };
@ -119,22 +119,23 @@ pub struct MsaaWritebackBlitPipeline(CachedRenderPipelineId);
fn prepare_msaa_writeback_pipelines( fn prepare_msaa_writeback_pipelines(
mut commands: Commands, mut commands: Commands,
pipeline_cache: Res<PipelineCache>, pipeline_cache: Res<PipelineCache>,
mut pipelines: ResMut<SpecializedRenderPipelines<BlitPipeline>>, mut blit_pipeline: ResMut<BlitPipeline>,
blit_pipeline: Res<BlitPipeline>,
view_targets: Query<(Entity, &ViewTarget, &ExtractedCamera, &Msaa)>, view_targets: Query<(Entity, &ViewTarget, &ExtractedCamera, &Msaa)>,
) { ) -> Result<(), BevyError> {
for (entity, view_target, camera, msaa) in view_targets.iter() { for (entity, view_target, camera, msaa) in view_targets.iter() {
// only do writeback if writeback is enabled for the camera and this isn't the first camera in the target, // only do writeback if writeback is enabled for the camera and this isn't the first camera in the target,
// as there is nothing to write back for the first camera. // as there is nothing to write back for the first camera.
if msaa.samples() > 1 && camera.msaa_writeback && camera.sorted_camera_index_for_target > 0 if msaa.samples() > 1 && camera.msaa_writeback && camera.sorted_camera_index_for_target > 0
{ {
let key = BlitPipelineKey { let key = BlitKey {
texture_format: view_target.main_texture_format(), texture_format: view_target.main_texture_format(),
samples: msaa.samples(), samples: msaa.samples(),
blend_state: None, blend_state: None,
}; };
let pipeline = pipelines.specialize(&pipeline_cache, &blit_pipeline, key); let pipeline = blit_pipeline
.specialized_cache
.specialize(&pipeline_cache, key)?;
commands commands
.entity(entity) .entity(entity)
.insert(MsaaWritebackBlitPipeline(pipeline)); .insert(MsaaWritebackBlitPipeline(pipeline));
@ -146,4 +147,5 @@ fn prepare_msaa_writeback_pipelines(
.remove::<MsaaWritebackBlitPipeline>(); .remove::<MsaaWritebackBlitPipeline>();
} }
} }
Ok(())
} }

View File

@ -1,4 +1,4 @@
use crate::blit::{BlitPipeline, BlitPipelineKey}; use crate::blit::{BlitKey, BlitPipeline};
use bevy_app::prelude::*; use bevy_app::prelude::*;
use bevy_ecs::prelude::*; use bevy_ecs::prelude::*;
use bevy_platform::collections::HashSet; use bevy_platform::collections::HashSet;
@ -39,10 +39,9 @@ pub struct ViewUpscalingPipeline(CachedRenderPipelineId);
fn prepare_view_upscaling_pipelines( fn prepare_view_upscaling_pipelines(
mut commands: Commands, mut commands: Commands,
mut pipeline_cache: ResMut<PipelineCache>, mut pipeline_cache: ResMut<PipelineCache>,
mut pipelines: ResMut<SpecializedRenderPipelines<BlitPipeline>>, mut blit_pipeline: ResMut<BlitPipeline>,
blit_pipeline: Res<BlitPipeline>,
view_targets: Query<(Entity, &ViewTarget, Option<&ExtractedCamera>)>, view_targets: Query<(Entity, &ViewTarget, Option<&ExtractedCamera>)>,
) { ) -> Result<(), BevyError> {
let mut output_textures = <HashSet<_>>::default(); let mut output_textures = <HashSet<_>>::default();
for (entity, view_target, camera) in view_targets.iter() { for (entity, view_target, camera) in view_targets.iter() {
let out_texture_id = view_target.out_texture().id(); let out_texture_id = view_target.out_texture().id();
@ -73,12 +72,14 @@ fn prepare_view_upscaling_pipelines(
None None
}; };
let key = BlitPipelineKey { let key = BlitKey {
texture_format: view_target.out_texture_format(), texture_format: view_target.out_texture_format(),
blend_state, blend_state,
samples: 1, samples: 1,
}; };
let pipeline = pipelines.specialize(&pipeline_cache, &blit_pipeline, key); let pipeline = blit_pipeline
.specialized_cache
.specialize(&pipeline_cache, key)?;
// Ensure the pipeline is loaded before continuing the frame to prevent frames without any GPU work submitted // Ensure the pipeline is loaded before continuing the frame to prevent frames without any GPU work submitted
pipeline_cache.block_on_render_pipeline(pipeline); pipeline_cache.block_on_render_pipeline(pipeline);
@ -87,4 +88,6 @@ fn prepare_view_upscaling_pipelines(
.entity(entity) .entity(entity)
.insert(ViewUpscalingPipeline(pipeline)); .insert(ViewUpscalingPipeline(pipeline));
} }
Ok(())
} }

View File

@ -20,8 +20,6 @@ const SPECIALIZE_ALL_IDENT: &str = "all";
const KEY_ATTR_IDENT: &str = "key"; const KEY_ATTR_IDENT: &str = "key";
const KEY_DEFAULT_IDENT: &str = "default"; const KEY_DEFAULT_IDENT: &str = "default";
const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor";
enum SpecializeImplTargets { enum SpecializeImplTargets {
All, All,
Specific(Vec<Path>), Specific(Vec<Path>),
@ -87,7 +85,6 @@ struct FieldInfo {
ty: Type, ty: Type,
member: Member, member: Member,
key: Key, key: Key,
use_base_descriptor: bool,
} }
impl FieldInfo { impl FieldInfo {
@ -117,15 +114,6 @@ impl FieldInfo {
parse_quote!(#ty: #specialize_path::Specializer<#target_path>) parse_quote!(#ty: #specialize_path::Specializer<#target_path>)
} }
} }
fn get_base_descriptor_predicate(
&self,
specialize_path: &Path,
target_path: &Path,
) -> WherePredicate {
let ty = &self.ty;
parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>)
}
} }
fn get_field_info( fn get_field_info(
@ -151,12 +139,8 @@ fn get_field_info(
let mut use_key_field = true; let mut use_key_field = true;
let mut key = Key::Index(key_index); let mut key = Key::Index(key_index);
let mut use_base_descriptor = false;
for attr in &field.attrs { for attr in &field.attrs {
match &attr.meta { match &attr.meta {
Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => {
use_base_descriptor = true;
}
Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => { Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => {
let owned_tokens = tokens.clone().into(); let owned_tokens = tokens.clone().into();
let Ok(parsed_key) = syn::parse::<Key>(owned_tokens) else { let Ok(parsed_key) = syn::parse::<Key>(owned_tokens) else {
@ -190,7 +174,6 @@ fn get_field_info(
ty: field_ty, ty: field_ty,
member: field_member, member: field_member,
key, key,
use_base_descriptor,
}); });
} }
@ -261,41 +244,18 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
}) })
.collect(); .collect();
let base_descriptor_fields = field_info
.iter()
.filter(|field| field.use_base_descriptor)
.collect::<Vec<_>>();
if base_descriptor_fields.len() > 1 {
return syn::Error::new(
Span::call_site(),
"Too many #[base_descriptor] attributes found. It must be present on exactly one field",
)
.into_compile_error()
.into();
}
let base_descriptor_field = base_descriptor_fields.first().copied();
match targets { match targets {
SpecializeImplTargets::All => { SpecializeImplTargets::All => impl_specialize_all(
let specialize_impl = impl_specialize_all( &specialize_path,
&specialize_path, &ecs_path,
&ecs_path, &ast,
&ast, &field_info,
&field_info, &key_patterns,
&key_patterns, &key_tuple_idents,
&key_tuple_idents, ),
); SpecializeImplTargets::Specific(targets) => targets
let get_base_descriptor_impl = base_descriptor_field .iter()
.map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info)) .map(|target| {
.unwrap_or_default();
[specialize_impl, get_base_descriptor_impl]
.into_iter()
.collect()
}
SpecializeImplTargets::Specific(targets) => {
let specialize_impls = targets.iter().map(|target| {
impl_specialize_specific( impl_specialize_specific(
&specialize_path, &specialize_path,
&ecs_path, &ecs_path,
@ -305,14 +265,8 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
&key_patterns, &key_patterns,
&key_tuple_idents, &key_tuple_idents,
) )
}); })
let get_base_descriptor_impls = targets.iter().filter_map(|target| { .collect(),
base_descriptor_field.map(|field_info| {
impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target)
})
});
specialize_impls.chain(get_base_descriptor_impls).collect()
}
} }
} }
@ -406,56 +360,6 @@ fn impl_specialize_specific(
}) })
} }
fn impl_get_base_descriptor_specific(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
target_path: &Path,
) -> TokenStream {
let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote!(
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
))
}
fn impl_get_base_descriptor_all(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
) -> TokenStream {
let target_path = Path::from(format_ident!("T"));
let struct_name = &ast.ident;
let mut generics = ast.generics.clone();
generics.params.insert(
0,
parse_quote!(#target_path: #specialize_path::Specializable),
);
let where_clause = generics.make_where_clause();
where_clause.predicates.push(
base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path),
);
let (_, type_generics, _) = ast.generics.split_for_impl();
let (impl_generics, _, where_clause) = &generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote! {
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
})
}
pub fn impl_specializer_key(input: TokenStream) -> TokenStream { pub fn impl_specializer_key(input: TokenStream) -> TokenStream {
let bevy_render_path: Path = crate::bevy_render_path(); let bevy_render_path: Path = crate::bevy_render_path();
let specialize_path = { let specialize_path = {

View File

@ -96,6 +96,7 @@ use render_asset::{
extract_render_asset_bytes_per_frame, reset_render_asset_bytes_per_frame, extract_render_asset_bytes_per_frame, reset_render_asset_bytes_per_frame,
RenderAssetBytesPerFrame, RenderAssetBytesPerFrameLimiter, RenderAssetBytesPerFrame, RenderAssetBytesPerFrameLimiter,
}; };
use render_resource::init_empty_bind_group_layout;
use renderer::{RenderAdapter, RenderDevice, RenderQueue}; use renderer::{RenderAdapter, RenderDevice, RenderQueue};
use settings::RenderResources; use settings::RenderResources;
use sync_world::{ use sync_world::{
@ -465,6 +466,8 @@ impl Plugin for RenderPlugin {
Render, Render,
reset_render_asset_bytes_per_frame.in_set(RenderSystems::Cleanup), reset_render_asset_bytes_per_frame.in_set(RenderSystems::Cleanup),
); );
render_app.add_systems(RenderStartup, init_empty_bind_group_layout);
} }
app.register_type::<alpha::AlphaMode>() app.register_type::<alpha::AlphaMode>()

View File

@ -1,4 +1,6 @@
use crate::define_atomic_id; use crate::{define_atomic_id, renderer::RenderDevice};
use bevy_ecs::system::Res;
use bevy_platform::sync::OnceLock;
use bevy_utils::WgpuWrapper; use bevy_utils::WgpuWrapper;
use core::ops::Deref; use core::ops::Deref;
@ -62,3 +64,19 @@ impl Deref for BindGroupLayout {
&self.value &self.value
} }
} }
static EMPTY_BIND_GROUP_LAYOUT: OnceLock<BindGroupLayout> = OnceLock::new();
pub(crate) fn init_empty_bind_group_layout(render_device: Res<RenderDevice>) {
let layout = render_device.create_bind_group_layout(Some("empty_bind_group_layout"), &[]);
EMPTY_BIND_GROUP_LAYOUT
.set(layout)
.expect("init_empty_bind_group_layout was called more than once");
}
pub fn empty_bind_group_layout() -> BindGroupLayout {
EMPTY_BIND_GROUP_LAYOUT
.get()
.expect("init_empty_bind_group_layout was not called")
.clone()
}

View File

@ -1,4 +1,4 @@
use super::ShaderDefVal; use super::{empty_bind_group_layout, ShaderDefVal};
use crate::mesh::VertexBufferLayout; use crate::mesh::VertexBufferLayout;
use crate::{ use crate::{
define_atomic_id, define_atomic_id,
@ -7,7 +7,9 @@ use crate::{
use alloc::borrow::Cow; use alloc::borrow::Cow;
use bevy_asset::Handle; use bevy_asset::Handle;
use bevy_utils::WgpuWrapper; use bevy_utils::WgpuWrapper;
use core::iter;
use core::ops::Deref; use core::ops::Deref;
use thiserror::Error;
use wgpu::{ use wgpu::{
ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState, PushConstantRange, ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState, PushConstantRange,
}; };
@ -112,6 +114,20 @@ pub struct RenderPipelineDescriptor {
pub zero_initialize_workgroup_memory: bool, pub zero_initialize_workgroup_memory: bool,
} }
#[derive(Copy, Clone, Debug, Error)]
#[error("RenderPipelineDescriptor has no FragmentState configured")]
pub struct NoFragmentStateError;
impl RenderPipelineDescriptor {
pub fn fragment_mut(&mut self) -> Result<&mut FragmentState, NoFragmentStateError> {
self.fragment.as_mut().ok_or(NoFragmentStateError)
}
pub fn set_layout(&mut self, index: usize, layout: BindGroupLayout) {
filling_set_at(&mut self.layout, index, empty_bind_group_layout(), layout);
}
}
#[derive(Clone, Debug, Eq, PartialEq, Default)] #[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct VertexState { pub struct VertexState {
/// The compiled shader module for this stage. /// The compiled shader module for this stage.
@ -137,6 +153,12 @@ pub struct FragmentState {
pub targets: Vec<Option<ColorTargetState>>, pub targets: Vec<Option<ColorTargetState>>,
} }
impl FragmentState {
pub fn set_target(&mut self, index: usize, target: ColorTargetState) {
filling_set_at(&mut self.targets, index, None, Some(target));
}
}
/// Describes a compute pipeline. /// Describes a compute pipeline.
#[derive(Clone, Debug, PartialEq, Eq, Default)] #[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComputePipelineDescriptor { pub struct ComputePipelineDescriptor {
@ -153,3 +175,11 @@ pub struct ComputePipelineDescriptor {
/// If this is false, reading from workgroup variables before writing to them will result in garbage values. /// If this is false, reading from workgroup variables before writing to them will result in garbage values.
pub zero_initialize_workgroup_memory: bool, pub zero_initialize_workgroup_memory: bool,
} }
// utility function to set a value at the specified index, extending with
// a filler value if the index is out of bounds.
fn filling_set_at<T: Clone>(vec: &mut Vec<T>, index: usize, filler: T, value: T) {
let num_to_fill = (index + 1).saturating_sub(vec.len());
vec.extend(iter::repeat_n(filler, num_to_fill));
vec[index] = value;
}

View File

@ -2,11 +2,7 @@ use super::{
CachedComputePipelineId, CachedRenderPipelineId, ComputePipeline, ComputePipelineDescriptor, CachedComputePipelineId, CachedRenderPipelineId, ComputePipeline, ComputePipelineDescriptor,
PipelineCache, RenderPipeline, RenderPipelineDescriptor, PipelineCache, RenderPipeline, RenderPipelineDescriptor,
}; };
use bevy_ecs::{ use bevy_ecs::error::BevyError;
error::BevyError,
resource::Resource,
world::{FromWorld, World},
};
use bevy_platform::{ use bevy_platform::{
collections::{ collections::{
hash_map::{Entry, VacantEntry}, hash_map::{Entry, VacantEntry},
@ -260,113 +256,22 @@ macro_rules! impl_specialization_key_tuple {
// TODO: How to we fake_variadics this? // TODO: How to we fake_variadics this?
all_tuples!(impl_specialization_key_tuple, 0, 12, T); all_tuples!(impl_specialization_key_tuple, 0, 12, T);
/// Defines a specializer that can also provide a "base descriptor".
///
/// In order to be composable, [`Specializer`] implementers don't create full
/// descriptors, only transform them. However, [`SpecializedCache`]s need a
/// "base descriptor" at creation time in order to have something for the
/// [`Specializer`] to work off of. This trait allows [`SpecializedCache`]
/// to impl [`FromWorld`] for [`Specializer`]s that also satisfy [`FromWorld`]
/// and [`GetBaseDescriptor`].
///
/// This trait can be also derived with `#[derive(Specializer)]`, by marking
/// a field with `#[base_descriptor]` to use its [`GetBaseDescriptor`] implementation.
///
/// Example:
/// ```rust
/// # use bevy_ecs::error::BevyError;
/// # use bevy_render::render_resource::Specializer;
/// # use bevy_render::render_resource::GetBaseDescriptor;
/// # use bevy_render::render_resource::SpecializerKey;
/// # use bevy_render::render_resource::RenderPipeline;
/// # use bevy_render::render_resource::RenderPipelineDescriptor;
/// struct A;
/// struct B;
///
/// impl Specializer<RenderPipeline> for A {
/// # type Key = ();
/// #
/// # fn specialize(
/// # &self,
/// # key: (),
/// # _descriptor: &mut RenderPipelineDescriptor
/// # ) -> Result<(), BevyError> {
/// # Ok(key)
/// # }
/// // ...
/// }
///
/// impl Specializer<RenderPipeline> for B {
/// # type Key = ();
/// #
/// # fn specialize(
/// # &self,
/// # key: (),
/// # _descriptor: &mut RenderPipelineDescriptor
/// # ) -> Result<(), BevyError> {
/// # Ok(key)
/// # }
/// // ...
/// }
///
/// impl GetBaseDescriptor<RenderPipeline> for B {
/// fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
/// # todo!()
/// // ...
/// }
/// }
///
///
/// #[derive(Specializer)]
/// #[specialize(RenderPipeline)]
/// struct C {
/// a: A,
/// #[base_descriptor]
/// b: B,
/// }
///
/// /*
/// The generated implementation:
/// impl GetBaseDescriptor for C {
/// fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
/// self.b.base_descriptor()
/// }
/// }
/// */
/// ```
pub trait GetBaseDescriptor<T: Specializable>: Specializer<T> {
fn get_base_descriptor(&self) -> T::Descriptor;
}
pub type SpecializerFn<T, S> =
fn(<S as Specializer<T>>::Key, &mut <T as Specializable>::Descriptor) -> Result<(), BevyError>;
/// A cache for specializable resources. For a given key type the resulting /// A cache for specializable resources. For a given key type the resulting
/// resource will only be created if it is missing, retrieving it from the /// resource will only be created if it is missing, retrieving it from the
/// cache otherwise. /// cache otherwise.
#[derive(Resource)]
pub struct SpecializedCache<T: Specializable, S: Specializer<T>> { pub struct SpecializedCache<T: Specializable, S: Specializer<T>> {
specializer: S, specializer: S,
user_specializer: Option<SpecializerFn<T, S>>,
base_descriptor: T::Descriptor, base_descriptor: T::Descriptor,
primary_cache: HashMap<S::Key, T::CachedId>, primary_cache: HashMap<S::Key, T::CachedId>,
secondary_cache: HashMap<Canonical<S::Key>, T::CachedId>, secondary_cache: HashMap<Canonical<S::Key>, T::CachedId>,
} }
impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> { impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
/// Creates a new [`SpecializedCache`] from a [`Specializer`], /// Creates a new [`SpecializedCache`] from a [`Specializer`] and a base descriptor.
/// an optional "user specializer", and a base descriptor. The
/// user specializer is applied after the [`Specializer`], with
/// the same key.
#[inline] #[inline]
pub fn new( pub fn new(specializer: S, base_descriptor: T::Descriptor) -> Self {
specializer: S,
user_specializer: Option<SpecializerFn<T, S>>,
base_descriptor: T::Descriptor,
) -> Self {
Self { Self {
specializer, specializer,
user_specializer,
base_descriptor, base_descriptor,
primary_cache: Default::default(), primary_cache: Default::default(),
secondary_cache: Default::default(), secondary_cache: Default::default(),
@ -385,7 +290,6 @@ impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
Entry::Occupied(entry) => Ok(entry.get().clone()), Entry::Occupied(entry) => Ok(entry.get().clone()),
Entry::Vacant(entry) => Self::specialize_slow( Entry::Vacant(entry) => Self::specialize_slow(
&self.specializer, &self.specializer,
self.user_specializer,
self.base_descriptor.clone(), self.base_descriptor.clone(),
pipeline_cache, pipeline_cache,
key, key,
@ -398,7 +302,6 @@ impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
#[cold] #[cold]
fn specialize_slow( fn specialize_slow(
specializer: &S, specializer: &S,
user_specializer: Option<SpecializerFn<T, S>>,
base_descriptor: T::Descriptor, base_descriptor: T::Descriptor,
pipeline_cache: &PipelineCache, pipeline_cache: &PipelineCache,
key: S::Key, key: S::Key,
@ -408,10 +311,6 @@ impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
let mut descriptor = base_descriptor.clone(); let mut descriptor = base_descriptor.clone();
let canonical_key = specializer.specialize(key.clone(), &mut descriptor)?; let canonical_key = specializer.specialize(key.clone(), &mut descriptor)?;
if let Some(user_specializer) = user_specializer {
(user_specializer)(key, &mut descriptor)?;
}
// if the whole key is canonical, the secondary cache isn't needed. // if the whole key is canonical, the secondary cache isn't needed.
if <S::Key as SpecializerKey>::IS_CANONICAL { if <S::Key as SpecializerKey>::IS_CANONICAL {
return Ok(primary_entry return Ok(primary_entry
@ -447,19 +346,3 @@ impl<T: Specializable, S: Specializer<T>> SpecializedCache<T, S> {
Ok(id) Ok(id)
} }
} }
/// [`SpecializedCache`] implements [`FromWorld`] for [`Specializer`]s
/// that also satisfy [`FromWorld`] and [`GetBaseDescriptor`]. This will
/// create a [`SpecializedCache`] with no user specializer, and the base
/// descriptor take from the specializer's [`GetBaseDescriptor`] implementation.
impl<T, S> FromWorld for SpecializedCache<T, S>
where
T: Specializable,
S: FromWorld + Specializer<T> + GetBaseDescriptor<T>,
{
fn from_world(world: &mut World) -> Self {
let specializer = S::from_world(world);
let base_descriptor = specializer.get_base_descriptor();
Self::new(specializer, None, base_descriptor)
}
}

View File

@ -25,8 +25,8 @@ use bevy::{
}, },
render_resource::{ render_resource::{
BufferUsages, Canonical, ColorTargetState, ColorWrites, CompareFunction, BufferUsages, Canonical, ColorTargetState, ColorWrites, CompareFunction,
DepthStencilState, FragmentState, GetBaseDescriptor, IndexFormat, PipelineCache, DepthStencilState, FragmentState, IndexFormat, PipelineCache, RawBufferVec,
RawBufferVec, RenderPipeline, RenderPipelineDescriptor, SpecializedCache, Specializer, RenderPipeline, RenderPipelineDescriptor, SpecializedCache, Specializer,
SpecializerKey, TextureFormat, VertexAttribute, VertexBufferLayout, VertexFormat, SpecializerKey, TextureFormat, VertexAttribute, VertexBufferLayout, VertexFormat,
VertexState, VertexStepMode, VertexState, VertexStepMode,
}, },
@ -165,9 +165,8 @@ fn main() {
.add_systems(Startup, setup); .add_systems(Startup, setup);
// We make sure to add these to the render app, not the main app. // We make sure to add these to the render app, not the main app.
app.get_sub_app_mut(RenderApp) app.sub_app_mut(RenderApp)
.unwrap() .init_resource::<CustomPhasePipeline>()
.init_resource::<SpecializedCache<RenderPipeline, CustomPhaseSpecializer>>()
.add_render_command::<Opaque3d, DrawCustomPhaseItemCommands>() .add_render_command::<Opaque3d, DrawCustomPhaseItemCommands>()
.add_systems( .add_systems(
Render, Render,
@ -212,9 +211,9 @@ fn prepare_custom_phase_item_buffers(mut commands: Commands) {
/// the opaque render phases of each view. /// the opaque render phases of each view.
fn queue_custom_phase_item( fn queue_custom_phase_item(
pipeline_cache: Res<PipelineCache>, pipeline_cache: Res<PipelineCache>,
mut pipeline: ResMut<CustomPhasePipeline>,
mut opaque_render_phases: ResMut<ViewBinnedRenderPhases<Opaque3d>>, mut opaque_render_phases: ResMut<ViewBinnedRenderPhases<Opaque3d>>,
opaque_draw_functions: Res<DrawFunctions<Opaque3d>>, opaque_draw_functions: Res<DrawFunctions<Opaque3d>>,
mut specializer: ResMut<SpecializedCache<RenderPipeline, CustomPhaseSpecializer>>,
views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>, views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>,
mut next_tick: Local<Tick>, mut next_tick: Local<Tick>,
) { ) {
@ -237,7 +236,9 @@ fn queue_custom_phase_item(
// some per-view settings, such as whether the view is HDR, but for // some per-view settings, such as whether the view is HDR, but for
// simplicity's sake we simply hard-code the view's characteristics, // simplicity's sake we simply hard-code the view's characteristics,
// with the exception of number of MSAA samples. // with the exception of number of MSAA samples.
let Ok(pipeline_id) = specializer.specialize(&pipeline_cache, CustomPhaseKey(*msaa)) let Ok(pipeline_id) = pipeline
.specialized_cache
.specialize(&pipeline_cache, CustomPhaseKey(*msaa))
else { else {
continue; continue;
}; };
@ -275,44 +276,23 @@ fn queue_custom_phase_item(
} }
} }
/// Holds a reference to our shader. struct CustomPhaseSpecializer;
///
/// This is loaded at app creation time. #[derive(Resource)]
struct CustomPhaseSpecializer { struct CustomPhasePipeline {
shader: Handle<Shader>, /// the `specialized_cache` holds onto the shader handle through the base descriptor
specialized_cache: SpecializedCache<RenderPipeline, CustomPhaseSpecializer>,
} }
impl FromWorld for CustomPhaseSpecializer { impl FromWorld for CustomPhasePipeline {
fn from_world(world: &mut World) -> Self { fn from_world(world: &mut World) -> Self {
let asset_server = world.resource::<AssetServer>(); let asset_server = world.resource::<AssetServer>();
Self { let shader = asset_server.load("shaders/custom_phase_item.wgsl");
shader: asset_server.load("shaders/custom_phase_item.wgsl"),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)] let base_descriptor = RenderPipelineDescriptor {
struct CustomPhaseKey(Msaa);
impl Specializer<RenderPipeline> for CustomPhaseSpecializer {
type Key = CustomPhaseKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut RenderPipelineDescriptor,
) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.0.samples();
Ok(key)
}
}
impl GetBaseDescriptor<RenderPipeline> for CustomPhaseSpecializer {
fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("custom render pipeline".into()), label: Some("custom render pipeline".into()),
vertex: VertexState { vertex: VertexState {
shader: self.shader.clone(), shader: shader.clone(),
buffers: vec![VertexBufferLayout { buffers: vec![VertexBufferLayout {
array_stride: size_of::<Vertex>() as u64, array_stride: size_of::<Vertex>() as u64,
step_mode: VertexStepMode::Vertex, step_mode: VertexStepMode::Vertex,
@ -333,7 +313,7 @@ impl GetBaseDescriptor<RenderPipeline> for CustomPhaseSpecializer {
..default() ..default()
}, },
fragment: Some(FragmentState { fragment: Some(FragmentState {
shader: self.shader.clone(), shader: shader.clone(),
targets: vec![Some(ColorTargetState { targets: vec![Some(ColorTargetState {
// Ordinarily, you'd want to check whether the view has the // Ordinarily, you'd want to check whether the view has the
// HDR format and substitute the appropriate texture format // HDR format and substitute the appropriate texture format
@ -354,7 +334,27 @@ impl GetBaseDescriptor<RenderPipeline> for CustomPhaseSpecializer {
bias: default(), bias: default(),
}), }),
..default() ..default()
} };
let specialized_cache = SpecializedCache::new(CustomPhaseSpecializer, base_descriptor);
Self { specialized_cache }
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, SpecializerKey)]
struct CustomPhaseKey(Msaa);
impl Specializer<RenderPipeline> for CustomPhaseSpecializer {
type Key = CustomPhaseKey;
fn specialize(
&self,
key: Self::Key,
descriptor: &mut RenderPipelineDescriptor,
) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.0.samples();
Ok(key)
} }
} }

View File

@ -6,7 +6,7 @@ pull_requests: [17373]
The existing pipeline specialization APIs (`SpecializedRenderPipeline` etc.) have The existing pipeline specialization APIs (`SpecializedRenderPipeline` etc.) have
been replaced with a single `Specializer` trait and `SpecializedCache` collection: been replaced with a single `Specializer` trait and `SpecializedCache` collection:
```rs ```rust
pub trait Specializer<T: Specializable>: Send + Sync + 'static { pub trait Specializer<T: Specializable>: Send + Sync + 'static {
type Key: SpecializerKey; type Key: SpecializerKey;
fn specialize( fn specialize(
@ -19,20 +19,55 @@ pub trait Specializer<T: Specializable>: Send + Sync + 'static {
pub struct SpecializedCache<T: Specializable, S: Specializer<T>>{ ... }; pub struct SpecializedCache<T: Specializable, S: Specializer<T>>{ ... };
``` ```
The main difference is the change from *producing* a pipeline descriptor to For more info on specialization, see the docs for `bevy_render::render_resources::Specializer`
*mutating* one based on a key. The "base descriptor" that the `SpecializedCache`
passes to the `Specializer` can either be specified manually with `Specializer::new`
or by implementing `GetBaseDescriptor`. There's also a new trait for specialization
keys, `SpecializeKey`, that can be derived with the included macro in most cases.
Composing multiple different specializers together with the `derive(Specializer)` ## Mutation and Base Descriptors
macro can be a lot more powerful (see the `Specialize` docs), but migrating
individual specializers is fairly simple. All static parts of the pipeline
should be specified in the base descriptor, while the `Specializer` impl
should mutate the key as little as necessary to match the key.
```rs The main difference between the old and new trait is that instead of
*producing* a pipeline descriptor, `Specializer`s *mutate* existing descriptors
based on a key. As such, `SpecializedCache::new` takes in a "base descriptor"
to act as the template from which the specializer creates pipeline variants.
When migrating, the "static" parts of the pipeline (that don't depend
on the key) should become part of the base descriptor, while the specializer
itself should only change the parts demanded by the key. In the full example
below, instead of creating the entire pipeline descriptor the specializer
only changes the msaa sample count and the bind group layout.
## Composing Specializers
`Specializer`s can also be *composed* with the included derive macro to combine
their effects! This is a great way to encapsulate and reuse specialization logic,
though the rest of this guide will focus on migrating "standalone" specializers.
```rust
pub struct MsaaSpecializer {...}
impl Specialize<RenderPipeline> for MsaaSpecializer {...}
pub struct MeshLayoutSpecializer {...}
impl Specialize<RenderPipeline> for MeshLayoutSpecializer {...}
#[derive(Specializer)]
#[specialize(RenderPipeline)]
pub struct MySpecializer { pub struct MySpecializer {
msaa: MsaaSpecializer,
mesh_layout: MeshLayoutSpecializer,
}
```
## Misc Changes
The analogue of `SpecializedRenderPipelines`, `SpecializedCache`, is no longer a
Bevy `Resource`. Instead, the cache should be stored in a user-created `Resource`
(shown below) or even in a `Component` depending on the use case.
## Full Migration Example
Before:
```rust
#[derive(Resource)]
pub struct MyPipeline {
layout: BindGroupLayout, layout: BindGroupLayout,
layout_msaa: BindGroupLayout, layout_msaa: BindGroupLayout,
vertex: Handle<Shader>, vertex: Handle<Shader>,
@ -41,65 +76,131 @@ pub struct MySpecializer {
// before // before
#[derive(Clone, Copy, PartialEq, Eq, Hash)] #[derive(Clone, Copy, PartialEq, Eq, Hash)]
// after pub struct MyPipelineKey {
#[derive(Clone, Copy, PartialEq, Eq, Hash, SpecializerKey)]
pub struct MyKey {
blend_state: BlendState,
msaa: Msaa, msaa: Msaa,
} }
impl FromWorld for MySpecializer { impl FromWorld for MyPipeline {
fn from_world(&mut World) -> Self { fn from_world(world: &mut World) -> Self {
... let render_device = world.resource::<RenderDevice>();
let asset_server = world.resource::<AssetServer>();
let layout = render_device.create_bind_group_layout(...);
let layout_msaa = render_device.create_bind_group_layout(...);
let vertex = asset_server.load("vertex.wgsl");
let fragment = asset_server.load("fragment.wgsl");
Self {
layout,
layout_msaa,
vertex,
fragment,
}
} }
} }
// before impl SpecializedRenderPipeline for MyPipeline {
impl SpecializedRenderPipeline for MySpecializer { type Key = MyPipelineKey;
type Key = MyKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor { fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
RenderPipelineDescriptor { RenderPipelineDescriptor {
label: Some("my_pipeline".into()), label: Some("my_pipeline".into()),
layout: vec![ layout: vec![
if key.msaa.samples() > 0 { if key.msaa.samples() > 1 {
self.layout_msaa.clone() self.layout_msaa.clone()
} else { } else {
self.layout.clone() self.layout.clone()
} }
], ],
push_constant_ranges: vec![],
vertex: VertexState { vertex: VertexState {
shader: self.vertex.clone(), shader: self.vertex.clone(),
shader_defs: vec![], ..default()
entry_point: "vertex".into(),
buffers: vec![],
}, },
primitive: Default::default(),
depth_stencil: None,
multisample: MultisampleState { multisample: MultisampleState {
count: key.msaa.samples(), count: key.msaa.samples(),
..Default::default() ..default()
}, },
fragment: Some(FragmentState { fragment: Some(FragmentState {
shader: self.fragment.clone(), shader: self.fragment.clone(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState { targets: vec![Some(ColorTargetState {
format: TextureFormat::Rgba8Unorm, format: TextureFormat::Rgba8Unorm,
blend: Some(key.blend_state), blend: None,
write_mask: ColorWrites::all(), write_mask: ColorWrites::all(),
})], })],
..default()
}), }),
zero_initialize_workgroup_memory: false, ..default()
}, },
} }
} }
app.init_resource::<SpecializedRenderPipelines<MySpecializer>>(); render_app
.init_resource::<MyPipeline>();
.init_resource::<SpecializedRenderPipelines<MySpecializer>>();
```
After:
```rust
#[derive(Resource)]
pub struct MyPipeline {
// the base_descriptor and specializer each hold onto the static
// wgpu resources (layout, shader handles), so we don't need
// explicit fields for them here. However, real-world cases
// may still need to expose them as fields to create bind groups
// from, for example.
variants: SpecializedCache<RenderPipeline, MySpecializer>,
}
pub struct MySpecializer {
layout: BindGroupLayout,
layout_msaa: BindGroupLayout,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, SpecializerKey)]
pub struct MyPipelineKey {
msaa: Msaa,
}
impl FromWorld for MyPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.resource::<RenderDevice>();
let asset_server = world.resource::<AssetServer>();
let layout = render_device.create_bind_group_layout(...);
let layout_msaa = render_device.create_bind_group_layout(...);
let vertex = asset_server.load("vertex.wgsl");
let fragment = asset_server.load("fragment.wgsl");
let base_descriptor = RenderPipelineDescriptor {
label: Some("my_pipeline".into()),
vertex: VertexState {
shader: vertex.clone(),
..default()
},
fragment: Some(FragmentState {
shader: fragment.clone(),
..default()
}),
..default()
},
let variants = SpecializedCache::new(
MySpecializer {
layout: layout.clone(),
layout_msaa: layout_msaa.clone(),
},
base_descriptor,
);
Self {
variants
}
}
}
// after
impl Specializer<RenderPipeline> for MySpecializer { impl Specializer<RenderPipeline> for MySpecializer {
type Key = MyKey; type Key = MyKey;
@ -109,45 +210,19 @@ impl Specializer<RenderPipeline> for MySpecializer {
descriptor: &mut RenderPipeline, descriptor: &mut RenderPipeline,
) -> Result<Canonical<Self::Key>, BevyError> { ) -> Result<Canonical<Self::Key>, BevyError> {
descriptor.multisample.count = key.msaa.samples(); descriptor.multisample.count = key.msaa.samples();
descriptor.layout[0] = if key.msaa.samples() > 0 {
let layout = if key.msaa.samples() > 1 {
self.layout_msaa.clone() self.layout_msaa.clone()
} else { } else {
self.layout.clone() self.layout.clone()
}; };
descriptor.fragment.targets[0].as_mut().unwrap().blend_mode = key.blend_state;
descriptor.set_layout(0, layout);
Ok(key) Ok(key)
} }
} }
impl GetBaseDescriptor for MySpecializer { render_app.init_resource::<MyPipeline>();
fn get_base_descriptor(&self) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("my_pipeline".into()),
layout: vec![self.layout.clone()],
push_constant_ranges: vec![],
vertex: VertexState {
shader: self.vertex.clone(),
shader_defs: vec![],
entry_point: "vertex".into(),
buffers: vec![],
},
primitive: Default::default(),
depth_stencil: None,
multisample: MultiSampleState::default(),
fragment: Some(FragmentState {
shader: self.fragment.clone(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {
format: TextureFormat::Rgba8Unorm,
blend: None,
write_mask: ColorWrites::all(),
})],
}),
zero_initialize_workgroup_memory: false,
},
}
}
app.init_resource::<SpecializedCache<RenderPipeline, MySpecializer>>();
``` ```