#define_import_path bevy_pbr::meshlet_bindings

#import bevy_pbr::mesh_types::Mesh
#import bevy_render::view::View
#import bevy_pbr::prepass_bindings::PreviousViewUniforms
#import bevy_pbr::utils::octahedral_decode_signed

// NOTE(review): the `<...>` template lists in this section were missing and have been
// restored from how each value is used in this file. Confirm element types and array
// lengths against the CPU-side definitions of these structs.

// A node of the meshlet BVH. Each array holds one entry per child slot.
struct BvhNode {
    // Per-child bounds, error, and child offset (see `MeshletAabbErrorOffset`).
    aabbs: array<MeshletAabbErrorOffset, 8>,
    // Per-child LOD bound.
    // NOTE(review): presumably a sphere (xyz = center, w = radius) - confirm.
    lod_bounds: array<vec4<f32>, 8>,
    // NOTE(review): presumably one 8-bit count per child, packed four to a word - confirm.
    child_counts: array<u32, 2>,
    _padding: vec2<u32>,
}

// Per-meshlet metadata describing where its data lives in the global meshlet buffers.
struct Meshlet {
    // Bit offset of the meshlet's first vertex position inside `meshlet_vertex_positions`.
    start_vertex_position_bit: u32,
    // Index of the meshlet's first vertex in `meshlet_vertex_normals` / `meshlet_vertex_uvs`.
    start_vertex_attribute_id: u32,
    // NOTE(review): presumably the first index id to pass to `get_meshlet_vertex_id`;
    // not read anywhere in this file - confirm against callers.
    start_index_id: u32,
    // Bits 0..8: vertex count. Bits 8..16: triangle count.
    packed_a: u32,
    // Four packed bytes: bits per position channel (x, y, z), then the quantization factor.
    packed_b: u32,
    // Per-channel minimum, added back to the unpacked (still quantized) vertex position.
    min_vertex_position_channel_x: f32,
    min_vertex_position_channel_y: f32,
    min_vertex_position_channel_z: f32,
}

// Returns the vertex count stored in the low 8 bits of `packed_a`.
fn get_meshlet_vertex_count(meshlet: ptr<function, Meshlet>) -> u32 {
    return extractBits((*meshlet).packed_a, 0u, 8u);
}

// Returns the triangle count stored in bits 8..16 of `packed_a`.
fn get_meshlet_triangle_count(meshlet: ptr<function, Meshlet>) -> u32 {
    return extractBits((*meshlet).packed_a, 8u, 8u);
}

// Per-meshlet data consumed by culling (indexed through `meshlet_cull_data`).
struct MeshletCullData {
    aabb: MeshletAabbErrorOffset,
    lod_group_sphere: vec4<f32>,
}

// Axis-aligned bounding box stored as a center and a half extent.
struct MeshletAabb {
    center: vec3<f32>,
    half_extent: vec3<f32>,
}

// An AABB with two extra values packed into the otherwise unused `w` components.
struct MeshletAabbErrorOffset {
    // xyz: AABB center. w: error value (see `get_aabb_error`).
    center_and_error: vec4<f32>,
    // xyz: AABB half extent. w: bit pattern of a u32 child offset (see `get_aabb_child_offset`).
    half_extent_and_child_offset: vec4<f32>,
}

// Extracts the plain AABB (center and half extent) from the packed representation.
fn get_aabb(aabb: ptr<function, MeshletAabbErrorOffset>) -> MeshletAabb {
    return MeshletAabb(
        (*aabb).center_and_error.xyz,
        (*aabb).half_extent_and_child_offset.xyz,
    );
}

// Returns the error value stored in `center_and_error.w`.
fn get_aabb_error(aabb: ptr<function, MeshletAabbErrorOffset>) -> f32 {
    return (*aabb).center_and_error.w;
}

// Returns the child offset stored in `half_extent_and_child_offset.w`.
// The value is reinterpreted bit-for-bit as a u32, not numerically converted.
fn get_aabb_child_offset(aabb: ptr<function, MeshletAabbErrorOffset>) -> u32 {
    return bitcast<u32>((*aabb).half_extent_and_child_offset.w);
}

// Three-component indirect dispatch arguments. `x` is atomic so that shader
// invocations can update it with atomic operations.
struct DispatchIndirectArgs {
    x: atomic<u32>,
    y: u32,
    z: u32,
}

// Indirect draw arguments. `instance_count` is atomic so that shader invocations
// can update it with atomic operations.
struct DrawIndirectArgs {
    vertex_count: u32,
    instance_count: atomic<u32>,
    first_vertex: u32,
    first_instance: u32,
}

// Either a BVH node or a meshlet, along with the instance it is associated with.
// Refers to BVH nodes in `meshlet_bvh_cull_queue` and `meshlet_second_pass_bvh_queue`, where `offset` is the index into `meshlet_bvh_nodes`.
// Refers to meshlets in `meshlet_meshlet_cull_queue` and `meshlet_raster_clusters`.
// In `meshlet_meshlet_cull_queue`, `offset` is the index into `meshlet_cull_data`.
// In `meshlet_raster_clusters`, `offset` is the index into `meshlets`.
struct InstancedOffset {
    instance_id: u32,
    offset: u32,
}

// Scale used when reversing vertex position quantization (see `get_meshlet_vertex_position`).
const CENTIMETERS_PER_METER = 100.0;

// NOTE(review): every `var<...>` address space / access mode and every `<...>` element
// type in the binding declarations below was missing and has been restored by inference:
// atomics, dispatch arguments and queues that a pass writes are `storage, read_write`;
// `*_read` counters and input tables are `storage, read`; `View` and
// `PreviousViewUniforms` are `uniform`; `Constants` is `push_constant`.
// Confirm each one against the matching bind group layout entry on the CPU side.

#ifdef MESHLET_INSTANCE_CULLING_PASS
struct Constants { scene_instance_count: u32 }
var<push_constant> constants: Constants;

// Cull data
@group(0) @binding(0) var depth_pyramid: texture_2d<f32>;
@group(0) @binding(1) var<uniform> view: View;
@group(0) @binding(2) var<uniform> previous_view: PreviousViewUniforms;

// Per entity instance data
@group(0) @binding(3) var<storage, read> meshlet_instance_uniforms: array<Mesh>;
@group(0) @binding(4) var<storage, read> meshlet_view_instance_visibility: array<u32>; // 1 bit per entity instance, packed as a bitmask
@group(0) @binding(5) var<storage, read> meshlet_instance_aabbs: array<MeshletAabb>;
@group(0) @binding(6) var<storage, read> meshlet_instance_bvh_root_nodes: array<u32>;

// BVH cull queue data
@group(0) @binding(7) var<storage, read_write> meshlet_bvh_cull_count_write: atomic<u32>;
@group(0) @binding(8) var<storage, read_write> meshlet_bvh_cull_dispatch: DispatchIndirectArgs;
@group(0) @binding(9) var<storage, read_write> meshlet_bvh_cull_queue: array<InstancedOffset>;

// Second pass queue data
#ifdef MESHLET_FIRST_CULLING_PASS
@group(0) @binding(10) var<storage, read_write> meshlet_second_pass_instance_count: atomic<u32>;
@group(0) @binding(11) var<storage, read_write> meshlet_second_pass_instance_dispatch: DispatchIndirectArgs;
@group(0) @binding(12) var<storage, read_write> meshlet_second_pass_instance_candidates: array<u32>;
#else
@group(0) @binding(10) var<storage, read> meshlet_second_pass_instance_count: u32;
@group(0) @binding(11) var<storage, read> meshlet_second_pass_instance_candidates: array<u32>;
#endif
#endif

#ifdef MESHLET_BVH_CULLING_PASS
struct Constants { read_from_front: u32, rightmost_slot: u32 }
var<push_constant> constants: Constants;

// Cull data
@group(0) @binding(0) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
@group(0) @binding(1) var<uniform> view: View;
@group(0) @binding(2) var<uniform> previous_view: PreviousViewUniforms;

// Global mesh data
@group(0) @binding(3) var<storage, read> meshlet_bvh_nodes: array<BvhNode>;

// Per entity instance data
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>;

// BVH cull queue data
@group(0) @binding(5) var<storage, read> meshlet_bvh_cull_count_read: u32;
@group(0) @binding(6) var<storage, read_write> meshlet_bvh_cull_count_write: atomic<u32>;
@group(0) @binding(7) var<storage, read_write> meshlet_bvh_cull_dispatch: DispatchIndirectArgs;
@group(0) @binding(8) var<storage, read_write> meshlet_bvh_cull_queue: array<InstancedOffset>;

// Meshlet cull queue data
@group(0) @binding(9) var<storage, read_write> meshlet_meshlet_cull_count_early: atomic<u32>;
@group(0) @binding(10) var<storage, read_write> meshlet_meshlet_cull_count_late: atomic<u32>;
@group(0) @binding(11) var<storage, read_write> meshlet_meshlet_cull_dispatch_early: DispatchIndirectArgs;
@group(0) @binding(12) var<storage, read_write> meshlet_meshlet_cull_dispatch_late: DispatchIndirectArgs;
@group(0) @binding(13) var<storage, read_write> meshlet_meshlet_cull_queue: array<InstancedOffset>;

// Second pass queue data
#ifdef MESHLET_FIRST_CULLING_PASS
@group(0) @binding(14) var<storage, read_write> meshlet_second_pass_bvh_count: atomic<u32>;
@group(0) @binding(15) var<storage, read_write> meshlet_second_pass_bvh_dispatch: DispatchIndirectArgs;
@group(0) @binding(16) var<storage, read_write> meshlet_second_pass_bvh_queue: array<InstancedOffset>;
#endif
#endif

#ifdef MESHLET_CLUSTER_CULLING_PASS
struct Constants { rightmost_slot: u32 }
var<push_constant> constants: Constants;

// Cull data
@group(0) @binding(0) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
@group(0) @binding(1) var<uniform> view: View;
@group(0) @binding(2) var<uniform> previous_view: PreviousViewUniforms;

// Global mesh data
@group(0) @binding(3) var<storage, read> meshlet_cull_data: array<MeshletCullData>;

// Per entity instance data
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>;

// Raster queue data
@group(0) @binding(5) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs;
@group(0) @binding(6) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs;
@group(0) @binding(7) var<storage, read> meshlet_previous_raster_counts: array<u32>;
@group(0) @binding(8) var<storage, read_write> meshlet_raster_clusters: array<InstancedOffset>;

// Meshlet cull queue data
@group(0) @binding(9) var<storage, read> meshlet_meshlet_cull_count_read: u32;

// Second pass queue data
#ifdef MESHLET_FIRST_CULLING_PASS
@group(0) @binding(10) var<storage, read_write> meshlet_meshlet_cull_count_write: atomic<u32>;
@group(0) @binding(11) var<storage, read_write> meshlet_meshlet_cull_dispatch: DispatchIndirectArgs;
@group(0) @binding(12) var<storage, read_write> meshlet_meshlet_cull_queue: array<InstancedOffset>;
#else
@group(0) @binding(10) var<storage, read> meshlet_meshlet_cull_queue: array<InstancedOffset>;
#endif
#endif

// NOTE(review): the `var<...>` address spaces / access modes and `<...>` element types
// below were missing and have been restored by inference; confirm each against the
// matching bind group layout entry on the CPU side.

#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS
@group(0) @binding(0) var<storage, read> meshlet_raster_clusters: array<InstancedOffset>; // Per cluster
@group(0) @binding(1) var<storage, read> meshlets: array<Meshlet>; // Per meshlet
@group(0) @binding(2) var<storage, read> meshlet_indices: array<u32>; // Many per meshlet
@group(0) @binding(3) var<storage, read> meshlet_vertex_positions: array<u32>; // Many per meshlet
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(5) var<storage, read> meshlet_previous_raster_counts: array<u32>;
@group(0) @binding(6) var<storage, read> meshlet_software_raster_cluster_count: u32;
// NOTE(review): the two branches below differed only in their (missing) storage texture
// template arguments. The texel formats and `atomic` access mode are a best-effort
// restoration - verify against the texture format the render pipeline actually creates.
#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS_OUTPUT
@group(0) @binding(7) var meshlet_visibility_buffer: texture_storage_2d<r64uint, atomic>;
#else
@group(0) @binding(7) var meshlet_visibility_buffer: texture_storage_2d<r32uint, atomic>;
#endif
@group(0) @binding(8) var<uniform> view: View;

// TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread?
// Returns the 8-bit vertex id stored at `index_id`.
// Indices are packed four per u32 word in `meshlet_indices`.
fn get_meshlet_vertex_id(index_id: u32) -> u32 {
    let packed_index = meshlet_indices[index_id >> 2u];
    let bit_offset = (index_id & 3u) * 8u;
    return extractBits(packed_index, bit_offset, 8u);
}

// Decodes the position of vertex `vertex_id` of `meshlet` from the bit-packed
// `meshlet_vertex_positions` stream and reverses its quantization.
fn get_meshlet_vertex_position(meshlet: ptr<function, Meshlet>, vertex_id: u32) -> vec3<f32> {
    // Get bitstream start for the vertex
    let unpacked = unpack4xU8((*meshlet).packed_b);
    let bits_per_channel = unpacked.xyz;
    let bits_per_vertex = bits_per_channel.x + bits_per_channel.y + bits_per_channel.z;
    var start_bit = (*meshlet).start_vertex_position_bit + (vertex_id * bits_per_vertex);

    // Read each vertex channel from the bitstream
    var vertex_position_packed = vec3(0u);
    for (var i = 0u; i < 3u; i++) {
        let lower_word_index = start_bit >> 5u;
        let lower_word_bit_offset = start_bit & 31u;
        var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset;
        // The channel straddles a word boundary: pull the remaining high bits from the next word.
        // This branch is only taken when `lower_word_bit_offset` is non-zero, so the shift
        // amount below is always less than 32.
        if lower_word_bit_offset + bits_per_channel[i] > 32u {
            next_32_bits |= meshlet_vertex_positions[lower_word_index + 1u] << (32u - lower_word_bit_offset);
        }
        vertex_position_packed[i] = extractBits(next_32_bits, 0u, bits_per_channel[i]);
        start_bit += bits_per_channel[i];
    }

    // Remap [0, range_max - range_min] vec3 to [range_min, range_max] vec3
    var vertex_position = vec3<f32>(vertex_position_packed) + vec3(
        (*meshlet).min_vertex_position_channel_x,
        (*meshlet).min_vertex_position_channel_y,
        (*meshlet).min_vertex_position_channel_z,
    );

    // Reverse vertex quantization
    let vertex_position_quantization_factor = unpacked.w;
    vertex_position /= f32(1u << vertex_position_quantization_factor) * CENTIMETERS_PER_METER;

    return vertex_position;
}
#endif

// NOTE(review): the `var<...>` access modes and `<...>` element types below were missing
// and have been restored by inference. Buffer element types follow from the bit operations
// and built-in calls used on them in this file; the storage texture format and access mode
// are a best-effort restoration - confirm against the material pass bind group layout.

#ifdef MESHLET_MESH_MATERIAL_PASS
@group(2) @binding(0) var meshlet_visibility_buffer: texture_storage_2d<r64uint, read>;
@group(2) @binding(1) var<storage, read> meshlet_raster_clusters: array<InstancedOffset>; // Per cluster
@group(2) @binding(2) var<storage, read> meshlets: array<Meshlet>; // Per meshlet
@group(2) @binding(3) var<storage, read> meshlet_indices: array<u32>; // Many per meshlet
@group(2) @binding(4) var<storage, read> meshlet_vertex_positions: array<u32>; // Many per meshlet
@group(2) @binding(5) var<storage, read> meshlet_vertex_normals: array<u32>; // Many per meshlet
@group(2) @binding(6) var<storage, read> meshlet_vertex_uvs: array<vec2<f32>>; // Many per meshlet
@group(2) @binding(7) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance

// TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread?
// Returns the 8-bit vertex id stored at `index_id`.
// Indices are packed four per u32 word in `meshlet_indices`.
fn get_meshlet_vertex_id(index_id: u32) -> u32 {
    let packed_index = meshlet_indices[index_id >> 2u];
    let bit_offset = (index_id & 3u) * 8u;
    return extractBits(packed_index, bit_offset, 8u);
}

// Decodes the position of vertex `vertex_id` of `meshlet` from the bit-packed
// `meshlet_vertex_positions` stream and reverses its quantization.
fn get_meshlet_vertex_position(meshlet: ptr<function, Meshlet>, vertex_id: u32) -> vec3<f32> {
    // Get bitstream start for the vertex
    let unpacked = unpack4xU8((*meshlet).packed_b);
    let bits_per_channel = unpacked.xyz;
    let bits_per_vertex = bits_per_channel.x + bits_per_channel.y + bits_per_channel.z;
    var start_bit = (*meshlet).start_vertex_position_bit + (vertex_id * bits_per_vertex);

    // Read each vertex channel from the bitstream
    var vertex_position_packed = vec3(0u);
    for (var i = 0u; i < 3u; i++) {
        let lower_word_index = start_bit >> 5u;
        let lower_word_bit_offset = start_bit & 31u;
        var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset;
        // The channel straddles a word boundary: pull the remaining high bits from the next word.
        // This branch is only taken when `lower_word_bit_offset` is non-zero, so the shift
        // amount below is always less than 32.
        if lower_word_bit_offset + bits_per_channel[i] > 32u {
            next_32_bits |= meshlet_vertex_positions[lower_word_index + 1u] << (32u - lower_word_bit_offset);
        }
        vertex_position_packed[i] = extractBits(next_32_bits, 0u, bits_per_channel[i]);
        start_bit += bits_per_channel[i];
    }

    // Remap [0, range_max - range_min] vec3 to [range_min, range_max] vec3
    var vertex_position = vec3<f32>(vertex_position_packed) + vec3(
        (*meshlet).min_vertex_position_channel_x,
        (*meshlet).min_vertex_position_channel_y,
        (*meshlet).min_vertex_position_channel_z,
    );

    // Reverse vertex quantization
    let vertex_position_quantization_factor = unpacked.w;
    vertex_position /= f32(1u << vertex_position_quantization_factor) * CENTIMETERS_PER_METER;

    return vertex_position;
}

// Returns the normal of vertex `vertex_id` of `meshlet`.
// Normals are stored as two 16-bit snorm values in one u32 and octahedral-decoded.
fn get_meshlet_vertex_normal(meshlet: ptr<function, Meshlet>, vertex_id: u32) -> vec3<f32> {
    let packed_normal = meshlet_vertex_normals[(*meshlet).start_vertex_attribute_id + vertex_id];
    return octahedral_decode_signed(unpack2x16snorm(packed_normal));
}

// Returns the UV coordinates of vertex `vertex_id` of `meshlet`.
fn get_meshlet_vertex_uv(meshlet: ptr<function, Meshlet>, vertex_id: u32) -> vec2<f32> {
    return meshlet_vertex_uvs[(*meshlet).start_vertex_attribute_id + vertex_id];
}
#endif