bevy/examples/shader/shader_instancing.rs
robtfm 10f5c92068
improve shader import model (#5703)
# Objective

operate on naga IR directly to improve handling of shader modules.
- give codespan reporting into imported modules
- allow glsl to be used from wgsl and vice-versa

the ultimate objective is to make it possible to 
- provide user hooks for core shader functions (to modify light
behaviour within the standard pbr pipeline, for example)
- make automatic binding slot allocation possible

but ... since this is already big, adds some value and (i think) is at
feature parity with the existing code, i wanted to push this now.

## Solution

i made a crate called naga_oil (https://github.com/robtfm/naga_oil -
unpublished for now, could be part of bevy) which manages modules by
- building each module independently to naga IR
- creating "header" files for each supported language, which are used to
build dependent modules/shaders
- make final shaders by combining the shader IR with the IR for imported
modules

then integrated this into bevy, replacing some of the existing shader
processing stuff. also reworked examples to reflect this.

## Migration Guide

shaders that don't use `#import` directives should work without changes.

the most notable user-facing difference is that imported
functions/variables/etc need to be qualified at point of use, and
there's no "leakage" of visible stuff into your shader scope from the
imports of your imports, so if you used things imported by your imports,
you now need to import them directly and qualify them.

the current strategy of including/'spreading' `mesh_vertex_output`
directly into a struct doesn't work any more, so these need to be
modified as per the examples (e.g. color_material.wgsl, or many others).
mesh data is assumed to be in bindgroup 2 by default, if mesh data is
bound into bindgroup 1 instead then the shader def `MESH_BINDGROUP_1`
needs to be added to the pipeline shader_defs.
2023-06-27 00:29:22 +00:00

274 lines
9.1 KiB
Rust

//! A shader that renders a mesh multiple times in one draw call.
use bevy::{
core_pipeline::core_3d::Transparent3d,
ecs::{
query::QueryItem,
system::{lifetimeless::*, SystemParamItem},
},
pbr::{MeshPipeline, MeshPipelineKey, MeshUniform, SetMeshBindGroup, SetMeshViewBindGroup},
prelude::*,
render::{
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{GpuBufferInfo, MeshVertexBufferLayout},
render_asset::RenderAssets,
render_phase::{
AddRenderCommand, DrawFunctions, PhaseItem, RenderCommand, RenderCommandResult,
RenderPhase, SetItemPipeline, TrackedRenderPass,
},
render_resource::*,
renderer::RenderDevice,
view::{ExtractedView, NoFrustumCulling},
Render, RenderApp, RenderSet,
},
};
use bytemuck::{Pod, Zeroable};
/// Entry point: builds the app with the default plugins plus the custom
/// instancing plugin, registers the startup scene and runs it.
fn main() {
    let mut app = App::new();
    app.add_plugins((DefaultPlugins, CustomMaterialPlugin));
    app.add_systems(Startup, setup);
    app.run();
}
/// Startup system: spawns one cube mesh carrying the data for a 10x10 grid of
/// instances, plus a 3D camera looking at the origin.
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
    // Build the per-instance data column by column; `x` and `y` each run over
    // 0.1, 0.2, ..., 1.0.
    let mut instances = Vec::with_capacity(100);
    for column in 1..=10 {
        for row in 1..=10 {
            let (x, y) = (column as f32 / 10.0, row as f32 / 10.0);
            instances.push(InstanceData {
                position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
                scale: 1.0,
                color: Color::hsla(x * 360., y, 0.5, 1.0).as_rgba_f32(),
            });
        }
    }

    let cube = meshes.add(Mesh::from(shape::Cube { size: 0.5 }));
    commands.spawn((
        cube,
        SpatialBundle::INHERITED_IDENTITY,
        InstanceMaterialData(instances),
        // NOTE: Built-in frustum culling only looks at the Aabb of the Mesh and the
        // GlobalTransform of this entity. The cube sits at the origin, so as soon as its
        // Aabb leaves the view frustum every instanced cube would be culled with it.
        // The per-instance placement lives in `InstanceMaterialData`, which the built-in
        // culling knows nothing about, so the `NoFrustumCulling` marker component is added
        // to switch that culling off and avoid incorrect results.
        NoFrustumCulling,
    ));

    // camera
    let camera_transform = Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y);
    commands.spawn(Camera3dBundle {
        transform: camera_transform,
        ..default()
    });
}
/// Component holding the list of per-instance records for one mesh entity.
///
/// `Deref` exposes the inner `Vec<InstanceData>` directly, so callers can use
/// `len()` / `as_slice()` on the component itself.
#[derive(Component, Deref)]
struct InstanceMaterialData(Vec<InstanceData>);
/// Extraction of `InstanceMaterialData` for the `ExtractComponentPlugin`
/// registered by `CustomMaterialPlugin`.
impl ExtractComponent for InstanceMaterialData {
    type Query = &'static InstanceMaterialData;
    type Filter = ();
    type Out = Self;

    /// Produces an owned copy of the queried component by cloning its
    /// instance list; always returns `Some`.
    fn extract_component(item: QueryItem<'_, Self::Query>) -> Option<Self> {
        let instances = item.0.clone();
        Some(Self(instances))
    }
}
/// Plugin that wires the custom instanced rendering into the render app.
pub struct CustomMaterialPlugin;

impl Plugin for CustomMaterialPlugin {
    /// Registers component extraction on the main app, then the draw command,
    /// the specialized-pipeline cache and the render systems on the render
    /// sub-app.
    fn build(&self, app: &mut App) {
        app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_render_command::<Transparent3d, DrawCustom>();
        render_app.init_resource::<SpecializedMeshPipelines<CustomPipeline>>();
        render_app.add_systems(
            Render,
            (
                queue_custom.in_set(RenderSet::Queue),
                prepare_instance_buffers.in_set(RenderSet::Prepare),
            ),
        );
    }

    /// Creates the `CustomPipeline` resource on the render sub-app.
    // NOTE(review): presumably done in `finish` because `CustomPipeline::from_world`
    // reads the `MeshPipeline` resource, which may not exist yet during `build` — confirm.
    fn finish(&self, app: &mut App) {
        let render_app = app.sub_app_mut(RenderApp);
        render_app.init_resource::<CustomPipeline>();
    }
}
/// One per-instance record as it is laid out in the instance vertex buffer.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` allow a slice of these to be
/// reinterpreted as raw bytes with `bytemuck::cast_slice`. The field order is
/// therefore the byte layout and must agree with the vertex attributes declared
/// in `CustomPipeline::specialize`.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct InstanceData {
// Position of the instance.
position: Vec3,
// Scale factor; shares the first 16-byte `Float32x4` attribute with `position`.
scale: f32,
// Four colour components, filled from `Color::as_rgba_f32` in `setup`.
color: [f32; 4],
}
#[allow(clippy::too_many_arguments)]
fn queue_custom(
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
custom_pipeline: Res<CustomPipeline>,
msaa: Res<Msaa>,
mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
pipeline_cache: Res<PipelineCache>,
meshes: Res<RenderAssets<Mesh>>,
material_meshes: Query<(Entity, &MeshUniform, &Handle<Mesh>), With<InstanceMaterialData>>,
mut views: Query<(&ExtractedView, &mut RenderPhase<Transparent3d>)>,
) {
let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
for (view, mut transparent_phase) in &mut views {
let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
let rangefinder = view.rangefinder3d();
for (entity, mesh_uniform, mesh_handle) in &material_meshes {
if let Some(mesh) = meshes.get(mesh_handle) {
let key =
view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology);
let pipeline = pipelines
.specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
.unwrap();
transparent_phase.add(Transparent3d {
entity,
pipeline,
draw_function: draw_custom,
distance: rangefinder.distance(&mesh_uniform.transform),
});
}
}
}
}
/// Component attached by `prepare_instance_buffers` that holds the uploaded
/// per-instance data for an entity.
#[derive(Component)]
pub struct InstanceBuffer {
// GPU buffer created from the entity's `InstanceMaterialData` bytes.
buffer: Buffer,
// Number of `InstanceData` records in `buffer`; used as the instance count when drawing.
length: usize,
}
/// Uploads each entity's `InstanceMaterialData` into a fresh GPU vertex buffer
/// and attaches it to the entity as an `InstanceBuffer`.
///
/// A new buffer is created for every matching entity each time this system runs.
fn prepare_instance_buffers(
    mut commands: Commands,
    query: Query<(Entity, &InstanceMaterialData)>,
    render_device: Res<RenderDevice>,
) {
    for (entity, instances) in &query {
        // Reinterpret the instance records as raw bytes for the upload.
        let descriptor = BufferInitDescriptor {
            label: Some("instance data buffer"),
            contents: bytemuck::cast_slice(instances.as_slice()),
            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
        };
        let instance_buffer = InstanceBuffer {
            buffer: render_device.create_buffer_with_data(&descriptor),
            length: instances.len(),
        };
        commands.entity(entity).insert(instance_buffer);
    }
}
/// Render-world resource describing the custom instancing pipeline.
#[derive(Resource)]
pub struct CustomPipeline {
// Handle to the instancing shader (loaded from "shaders/instancing.wgsl" in `from_world`).
shader: Handle<Shader>,
// Copy of the standard `MeshPipeline`, used as the base for specialization.
mesh_pipeline: MeshPipeline,
}
impl FromWorld for CustomPipeline {
    /// Builds the pipeline resource by loading the instancing shader through
    /// the `AssetServer` and cloning the world's `MeshPipeline`.
    fn from_world(world: &mut World) -> Self {
        let shader = world
            .resource::<AssetServer>()
            .load("shaders/instancing.wgsl");
        let mesh_pipeline = world.resource::<MeshPipeline>().clone();

        Self {
            shader,
            mesh_pipeline,
        }
    }
}
impl SpecializedMeshPipeline for CustomPipeline {
    type Key = MeshPipelineKey;

    /// Specializes the standard mesh pipeline for `key`/`layout`, then swaps in
    /// the instancing shader for both stages and appends a second,
    /// per-instance vertex buffer layout.
    ///
    /// # Errors
    /// Propagates any error returned by `MeshPipeline::specialize`.
    ///
    /// # Panics
    /// Panics if the base descriptor has no fragment state.
    fn specialize(
        &self,
        key: Self::Key,
        layout: &MeshVertexBufferLayout,
    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;

        let vertex = &mut descriptor.vertex;

        // Mesh data normally lives in bind group 2. This pipeline binds it in
        // group 1 instead, so the MESH_BINDGROUP_1 shader def is required for the
        // shader's bindings to line up.
        vertex.shader_defs.push("MESH_BINDGROUP_1".into());
        vertex.shader = self.shader.clone();

        // Per-instance attributes: two 16-byte `Float32x4` values, stepped once
        // per instance and matching the layout of `InstanceData`.
        let instance_attributes = vec![
            VertexAttribute {
                format: VertexFormat::Float32x4,
                offset: 0,
                shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
            },
            VertexAttribute {
                format: VertexFormat::Float32x4,
                offset: VertexFormat::Float32x4.size(),
                shader_location: 4,
            },
        ];
        vertex.buffers.push(VertexBufferLayout {
            array_stride: std::mem::size_of::<InstanceData>() as u64,
            step_mode: VertexStepMode::Instance,
            attributes: instance_attributes,
        });

        descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();

        Ok(descriptor)
    }
}
/// The render command tuple used to draw instanced meshes: set the pipeline,
/// bind the mesh view data at bind group 0 and the mesh data at bind group 1,
/// then issue the instanced draw.
type DrawCustom = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMeshBindGroup<1>,
DrawMeshInstanced,
);
/// Render command that draws a mesh once per record in the entity's
/// `InstanceBuffer`, using a single draw call.
pub struct DrawMeshInstanced;

impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
    type Param = SRes<RenderAssets<Mesh>>;
    type ViewWorldQuery = ();
    type ItemWorldQuery = (Read<Handle<Mesh>>, Read<InstanceBuffer>);

    /// Binds the mesh vertex buffer to slot 0 and the instance buffer to
    /// slot 1, then issues an indexed or non-indexed instanced draw depending
    /// on the mesh's buffer info.
    ///
    /// Returns `RenderCommandResult::Failure` when the GPU mesh for
    /// `mesh_handle` is not available, `Success` otherwise.
    #[inline]
    fn render<'w>(
        _item: &P,
        _view: (),
        (mesh_handle, instance_buffer): (&'w Handle<Mesh>, &'w InstanceBuffer),
        meshes: SystemParamItem<'w, '_, Self::Param>,
        pass: &mut TrackedRenderPass<'w>,
    ) -> RenderCommandResult {
        let Some(gpu_mesh) = meshes.into_inner().get(mesh_handle) else {
            return RenderCommandResult::Failure;
        };

        pass.set_vertex_buffer(0, gpu_mesh.vertex_buffer.slice(..));
        pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));

        // Every instance record is drawn, whichever draw path is taken.
        let instances = 0..instance_buffer.length as u32;

        match &gpu_mesh.buffer_info {
            GpuBufferInfo::NonIndexed => {
                pass.draw(0..gpu_mesh.vertex_count, instances);
            }
            GpuBufferInfo::Indexed {
                buffer,
                index_format,
                count,
            } => {
                pass.set_index_buffer(buffer.slice(..), 0, *index_format);
                pass.draw_indexed(0..*count, 0, instances);
            }
        }

        RenderCommandResult::Success
    }
}