bevy/examples/shader/shader_instancing.rs
Robert Swain 0a11af9375
Reduce the size of MeshUniform to improve performance (#9416)
# Objective

- Significantly reduce the size of MeshUniform by only including
necessary data.

## Solution

Local-to-world (model) transforms are affine. This means they only need a
4x3 matrix to represent them.

`MeshUniform` stores the current and previous model transforms, and the
inverse transpose of the current model transform, all as 4x4 matrices.
Instead, we can store the current and previous model transforms as 4x3
matrices, and we only need the upper-left 3x3 part of the inverse
transpose of the current model transform. This change reduces the
serialized `MeshUniform` size from 208 bytes to 144 bytes, a saving of
over 30% (64 / 208 ≈ 31%) in data to serialize, VRAM bandwidth, and VRAM
space.

## Benchmarks

On an M1 Max, running `many_cubes -- sphere` (main is in yellow, this PR
is in red):
<img width="1484" alt="Screenshot 2023-08-11 at 02 36 43"
src="https://github.com/bevyengine/bevy/assets/302146/7d99c7b3-f2bb-4004-a8d0-4c00f755cb0d">
A reduction in frame time of ~14%.

---

## Changelog

- Changed: Redefined `MeshUniform` to improve performance by using 4x3
affine transforms and reconstructing 4x4 matrices in the shader. Helper
functions were added to `bevy_pbr::mesh_functions` to unpack the data:
`affine_to_square` expands the packed 4x3 affine transform (stored as a
3x4 matrix) into a 4x4 matrix, and `mat2x4_f32_to_mat3x3` unpacks a 3x3
matrix stored as a `mat2x4` plus an `f32` back into a `mat3x3`.

## Migration Guide

Shader code before:
```
var model = mesh[instance_index].model;
```

Shader code after:
```
#import bevy_pbr::mesh_functions affine_to_square

var model = affine_to_square(mesh[instance_index].model);
```
2023-08-15 06:00:23 +00:00

275 lines
9.2 KiB
Rust

//! A shader that renders a mesh multiple times in one draw call.
use bevy::{
core_pipeline::core_3d::Transparent3d,
ecs::{
query::QueryItem,
system::{lifetimeless::*, SystemParamItem},
},
pbr::{MeshPipeline, MeshPipelineKey, MeshTransforms, SetMeshBindGroup, SetMeshViewBindGroup},
prelude::*,
render::{
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{GpuBufferInfo, MeshVertexBufferLayout},
render_asset::RenderAssets,
render_phase::{
AddRenderCommand, DrawFunctions, PhaseItem, RenderCommand, RenderCommandResult,
RenderPhase, SetItemPipeline, TrackedRenderPass,
},
render_resource::*,
renderer::RenderDevice,
view::{ExtractedView, NoFrustumCulling},
Render, RenderApp, RenderSet,
},
};
use bytemuck::{Pod, Zeroable};
/// Builds the app with the default plugins plus the custom instancing plugin,
/// registers the scene setup system, and runs it.
fn main() {
    let mut app = App::new();
    app.add_plugins((DefaultPlugins, CustomMaterialPlugin));
    app.add_systems(Startup, setup);
    app.run();
}
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
commands.spawn((
meshes.add(Mesh::from(shape::Cube { size: 0.5 })),
SpatialBundle::INHERITED_IDENTITY,
InstanceMaterialData(
(1..=10)
.flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
.map(|(x, y)| InstanceData {
position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
scale: 1.0,
color: Color::hsla(x * 360., y, 0.5, 1.0).as_rgba_f32(),
})
.collect(),
),
// NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
// As the cube is at the origin, if its Aabb moves outside the view frustum, all the
// instanced cubes will be culled.
// The InstanceMaterialData contains the 'GlobalTransform' information for this custom
// instancing, and that is not taken into account with the built-in frustum culling.
// We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
// component to avoid incorrect culling.
NoFrustumCulling,
));
// camera
commands.spawn(Camera3dBundle {
transform: Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
..default()
});
}
/// The list of instances to draw for an entity.
///
/// Dereferences to the inner `Vec<InstanceData>`.
#[derive(Component, Deref)]
struct InstanceMaterialData(Vec<InstanceData>);

impl ExtractComponent for InstanceMaterialData {
    type Query = &'static InstanceMaterialData;
    type Filter = ();
    type Out = Self;

    /// Copies the main-world instance list into an owned value for the render world.
    fn extract_component(item: QueryItem<'_, Self::Query>) -> Option<Self> {
        let instances = item.0.clone();
        Some(Self(instances))
    }
}
/// Wires up instanced rendering. It does three things:
/// - extracts `InstanceMaterialData` into the render world;
/// - registers the `DrawCustom` render command for `Transparent3d`;
/// - schedules the queue and prepare systems.
pub struct CustomMaterialPlugin;

impl Plugin for CustomMaterialPlugin {
    fn build(&self, app: &mut App) {
        app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_render_command::<Transparent3d, DrawCustom>();
        render_app.init_resource::<SpecializedMeshPipelines<CustomPipeline>>();
        render_app.add_systems(
            Render,
            (
                queue_custom.in_set(RenderSet::Queue),
                prepare_instance_buffers.in_set(RenderSet::Prepare),
            ),
        );
    }

    fn finish(&self, app: &mut App) {
        // `CustomPipeline::from_world` reads the render world's `MeshPipeline` resource.
        // It is presumably initialized here, rather than in `build`, so that resource
        // already exists.
        let render_app = app.sub_app_mut(RenderApp);
        render_app.init_resource::<CustomPipeline>();
    }
}
/// Per-instance data uploaded as-is into the instance vertex buffer.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` let `prepare_instance_buffers` reinterpret a
/// slice of these as bytes with `bytemuck::cast_slice`.
///
/// The field order is the GPU-visible layout and must match the `VertexBufferLayout`
/// built in `CustomPipeline::specialize`:
/// - `position` and `scale` are read together as the first `Float32x4` attribute
///   (offset 0, shader location 3);
/// - `color` is read as the second `Float32x4` attribute (offset 16, shader location 4).
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct InstanceData {
/// Instance position; packed with `scale` into a single `Float32x4` attribute.
position: Vec3,
/// Uniform scale factor; the fourth component of the first attribute.
scale: f32,
/// RGBA color components (produced by `Color::as_rgba_f32` in `setup`).
color: [f32; 4],
}
#[allow(clippy::too_many_arguments)]
fn queue_custom(
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
custom_pipeline: Res<CustomPipeline>,
msaa: Res<Msaa>,
mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
pipeline_cache: Res<PipelineCache>,
meshes: Res<RenderAssets<Mesh>>,
material_meshes: Query<(Entity, &MeshTransforms, &Handle<Mesh>), With<InstanceMaterialData>>,
mut views: Query<(&ExtractedView, &mut RenderPhase<Transparent3d>)>,
) {
let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
for (view, mut transparent_phase) in &mut views {
let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
let rangefinder = view.rangefinder3d();
for (entity, mesh_transforms, mesh_handle) in &material_meshes {
if let Some(mesh) = meshes.get(mesh_handle) {
let key =
view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology);
let pipeline = pipelines
.specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
.unwrap();
transparent_phase.add(Transparent3d {
entity,
pipeline,
draw_function: draw_custom,
distance: rangefinder
.distance_translation(&mesh_transforms.transform.translation),
});
}
}
}
}
#[derive(Component)]
pub struct InstanceBuffer {
buffer: Buffer,
length: usize,
}
fn prepare_instance_buffers(
mut commands: Commands,
query: Query<(Entity, &InstanceMaterialData)>,
render_device: Res<RenderDevice>,
) {
for (entity, instance_data) in &query {
let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("instance data buffer"),
contents: bytemuck::cast_slice(instance_data.as_slice()),
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
});
commands.entity(entity).insert(InstanceBuffer {
buffer,
length: instance_data.len(),
});
}
}
/// Render-world resource holding the instancing shader and a copy of the standard
/// `MeshPipeline`, which `specialize` builds on.
#[derive(Resource)]
pub struct CustomPipeline {
    shader: Handle<Shader>,
    mesh_pipeline: MeshPipeline,
}

impl FromWorld for CustomPipeline {
    /// Loads `shaders/instancing.wgsl` and clones the render world's `MeshPipeline`.
    fn from_world(world: &mut World) -> Self {
        let shader = world
            .resource::<AssetServer>()
            .load("shaders/instancing.wgsl");
        let mesh_pipeline = world.resource::<MeshPipeline>().clone();
        Self {
            shader,
            mesh_pipeline,
        }
    }
}
impl SpecializedMeshPipeline for CustomPipeline {
    type Key = MeshPipelineKey;

    /// Builds on the standard mesh pipeline for `key` and `layout`. It swaps in the
    /// instancing shader for both stages and appends a per-instance vertex buffer
    /// layout after the mesh's own buffers.
    fn specialize(
        &self,
        key: Self::Key,
        layout: &MeshVertexBufferLayout,
    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
        let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;

        let vertex = &mut descriptor.vertex;
        // Mesh bindings normally live in bind group 2. `DrawCustom` binds them at
        // group 1 (`SetMeshBindGroup<1>`), so the `MESH_BINDGROUP_1` shader def is
        // needed for the shader to declare them in that group.
        vertex.shader_defs.push("MESH_BINDGROUP_1".into());
        vertex.shader = self.shader.clone();

        // `position` + `scale` are read as one Float32x4, and `color` as a second one
        // starting right after it.
        let instance_attributes = vec![
            VertexAttribute {
                format: VertexFormat::Float32x4,
                offset: 0,
                // Locations 0-2 are already used by the Position, Normal and UV attributes.
                shader_location: 3,
            },
            VertexAttribute {
                format: VertexFormat::Float32x4,
                offset: VertexFormat::Float32x4.size(),
                shader_location: 4,
            },
        ];
        // The instance data advances once per instance rather than once per vertex.
        vertex.buffers.push(VertexBufferLayout {
            array_stride: std::mem::size_of::<InstanceData>() as u64,
            step_mode: VertexStepMode::Instance,
            attributes: instance_attributes,
        });

        descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
        Ok(descriptor)
    }
}
/// Render command sequence for an instanced draw:
/// 1. set the item's pipeline;
/// 2. bind the view at group 0 and the mesh at group 1;
/// 3. issue the instanced draw.
type DrawCustom = (
    SetItemPipeline,
    SetMeshViewBindGroup<0>,
    SetMeshBindGroup<1>,
    DrawMeshInstanced,
);

/// Binds the mesh vertex buffer (slot 0) and the instance buffer (slot 1), then issues
/// one draw call covering every instance in the entity's `InstanceBuffer`.
pub struct DrawMeshInstanced;

impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
    type Param = SRes<RenderAssets<Mesh>>;
    type ViewWorldQuery = ();
    type ItemWorldQuery = (Read<Handle<Mesh>>, Read<InstanceBuffer>);

    /// Returns `Failure` when no GPU mesh is available for the item's handle, and
    /// `Success` after issuing the draw otherwise.
    #[inline]
    fn render<'w>(
        _item: &P,
        _view: (),
        (mesh_handle, instance_buffer): (&'w Handle<Mesh>, &'w InstanceBuffer),
        meshes: SystemParamItem<'w, '_, Self::Param>,
        pass: &mut TrackedRenderPass<'w>,
    ) -> RenderCommandResult {
        let Some(gpu_mesh) = meshes.into_inner().get(mesh_handle) else {
            return RenderCommandResult::Failure;
        };

        let instances = 0..instance_buffer.length as u32;
        pass.set_vertex_buffer(0, gpu_mesh.vertex_buffer.slice(..));
        pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));

        match &gpu_mesh.buffer_info {
            GpuBufferInfo::Indexed {
                buffer,
                index_format,
                count,
            } => {
                pass.set_index_buffer(buffer.slice(..), 0, *index_format);
                pass.draw_indexed(0..*count, 0, instances);
            }
            GpuBufferInfo::NonIndexed => pass.draw(0..gpu_mesh.vertex_count, instances),
        }

        RenderCommandResult::Success
    }
}