diff --git a/examples/shader/gpu_readback.rs b/examples/shader/gpu_readback.rs index 964776291e..07473e199e 100644 --- a/examples/shader/gpu_readback.rs +++ b/examples/shader/gpu_readback.rs @@ -15,7 +15,7 @@ use bevy::{ renderer::{RenderContext, RenderDevice}, storage::{GpuShaderStorageBuffer, ShaderStorageBuffer}, texture::GpuImage, - Render, RenderApp, RenderSystems, + Render, RenderApp, RenderStartup, RenderSystems, }, }; @@ -41,24 +41,22 @@ fn main() { // We need a plugin to organize all the systems and render node required for this example struct GpuReadbackPlugin; impl Plugin for GpuReadbackPlugin { - fn build(&self, _app: &mut App) {} - - fn finish(&self, app: &mut App) { - let render_app = app.sub_app_mut(RenderApp); - render_app.init_resource::().add_systems( - Render, - prepare_bind_group - .in_set(RenderSystems::PrepareBindGroups) - // We don't need to recreate the bind group every frame - .run_if(not(resource_exists::)), - ); - - // Add the compute node as a top level node to the render graph - // This means it will only execute once per frame + fn build(&self, app: &mut App) { + let Some(render_app) = app.get_sub_app_mut(RenderApp) else { + return; + }; render_app - .world_mut() - .resource_mut::() - .add_node(ComputeNodeLabel, ComputeNode::default()); + .add_systems( + RenderStartup, + (init_compute_pipeline, add_compute_render_graph_node), + ) + .add_systems( + Render, + prepare_bind_group + .in_set(RenderSystems::PrepareBindGroups) + // We don't need to recreate the bind group every frame + .run_if(not(resource_exists::)), + ); } } @@ -127,6 +125,13 @@ fn setup( commands.insert_resource(ReadbackImage(image)); } +fn add_compute_render_graph_node(mut render_graph: ResMut) { + // Add the compute node as a top-level node to the render graph. This means it will only execute + // once per frame. Normally, adding a node would use the `RenderGraphApp::add_render_graph_node` + // method, but it does not allow adding as a top-level node. + render_graph.add_node(ComputeNodeLabel, ComputeNode::default()); +} + #[derive(Resource)] struct GpuBufferBindGroup(BindGroup); @@ -158,29 +163,30 @@ struct ComputePipeline { pipeline: CachedComputePipelineId, } -impl FromWorld for ComputePipeline { - fn from_world(world: &mut World) -> Self { - let render_device = world.resource::(); - let layout = render_device.create_bind_group_layout( - None, - &BindGroupLayoutEntries::sequential( - ShaderStages::COMPUTE, - ( - storage_buffer::>(false), - texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly), - ), +fn init_compute_pipeline( + mut commands: Commands, + render_device: Res, + asset_server: Res, + pipeline_cache: Res, +) { + let layout = render_device.create_bind_group_layout( + None, + &BindGroupLayoutEntries::sequential( + ShaderStages::COMPUTE, + ( + storage_buffer::>(false), + texture_storage_2d(TextureFormat::R32Uint, StorageTextureAccess::WriteOnly), ), - ); - let shader = world.load_asset(SHADER_ASSET_PATH); - let pipeline_cache = world.resource::(); - let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { - label: Some("GPU readback compute shader".into()), - layout: vec![layout.clone()], - shader: shader.clone(), - ..default() - }); - ComputePipeline { layout, pipeline } - } + ), + ); + let shader = asset_server.load(SHADER_ASSET_PATH); + let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("GPU readback compute shader".into()), + layout: vec![layout.clone()], + shader: shader.clone(), + ..default() + }); + commands.insert_resource(ComputePipeline { layout, pipeline }); } /// Label to identify the node in the render graph