460 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			460 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| use crate::render::{
 | |
|     pipeline::{
 | |
|         BindGroupDescriptor, BindType, BindingDescriptor, InputStepMode, UniformProperty,
 | |
|         UniformPropertyType, VertexAttributeDescriptor, VertexBufferDescriptor, VertexFormat,
 | |
|     },
 | |
|     texture::TextureViewDimension,
 | |
| };
 | |
| use spirv_reflect::{
 | |
|     types::{
 | |
|         ReflectDescriptorBinding, ReflectDescriptorSet, ReflectDescriptorType, ReflectDimension,
 | |
|         ReflectInterfaceVariable, ReflectTypeDescription, ReflectTypeFlags,
 | |
|     },
 | |
|     ShaderModule,
 | |
| };
 | |
| use std::collections::HashSet;
 | |
| use zerocopy::AsBytes;
 | |
| // use rspirv::{binary::Parser, dr::Loader, lift::LiftContext};
 | |
| 
 | |
// TODO: use rspirv once its structured representation is ready. That would let us remove spirv_reflect, which is a non-Rust dependency.
 | |
| // pub fn get_shader_layout(spirv_data: &[u32]) {
 | |
| //     let mut loader = Loader::new();  // You can use your own consumer here.
 | |
| //     {
 | |
| //         let p = Parser::new(spirv_data.as_bytes(), &mut loader);
 | |
| //         p.parse().unwrap();
 | |
| //     }
 | |
| //     let module = loader.module();
 | |
| //     let structured = LiftContext::convert(&module).unwrap();
 | |
| //     println!("{:?}", structured.types);
 | |
| // }
 | |
| 
 | |
| #[derive(Debug, Clone, PartialEq, Eq)]
 | |
| pub struct ShaderLayout {
 | |
|     pub bind_groups: Vec<BindGroupDescriptor>,
 | |
|     pub vertex_buffer_descriptors: Vec<VertexBufferDescriptor>,
 | |
|     pub entry_point: String,
 | |
| }
 | |
| 
 | |
| impl ShaderLayout {
 | |
|     pub fn from_spirv(spirv_data: &[u32]) -> ShaderLayout {
 | |
|         match ShaderModule::load_u8_data(spirv_data.as_bytes()) {
 | |
|             Ok(ref mut module) => {
 | |
|                 let entry_point_name = module.get_entry_point_name();
 | |
|                 let mut bind_groups = Vec::new();
 | |
|                 for descriptor_set in module.enumerate_descriptor_sets(None).unwrap() {
 | |
|                     let bind_group = reflect_bind_group(&descriptor_set);
 | |
|                     bind_groups.push(bind_group);
 | |
|                 }
 | |
| 
 | |
|                 let mut vertex_attribute_descriptors = Vec::new();
 | |
|                 for input_variable in module.enumerate_input_variables(None).unwrap() {
 | |
|                     let vertex_attribute_descriptor =
 | |
|                         reflect_vertex_attribute_descriptor(&input_variable);
 | |
|                     vertex_attribute_descriptors.push(vertex_attribute_descriptor);
 | |
|                 }
 | |
| 
 | |
|                 vertex_attribute_descriptors
 | |
|                     .sort_by(|a, b| a.shader_location.cmp(&b.shader_location));
 | |
| 
 | |
|                 let mut visited_buffer_descriptors = HashSet::new();
 | |
|                 let mut vertex_buffer_descriptors = Vec::new();
 | |
|                 let mut current_descriptor: Option<VertexBufferDescriptor> = None;
 | |
|                 for vertex_attribute_descriptor in vertex_attribute_descriptors.drain(..) {
 | |
|                     let mut instance = false;
 | |
|                     let current_buffer_name = {
 | |
|                         let parts = vertex_attribute_descriptor
 | |
|                             .name
 | |
|                             .splitn(3, "_")
 | |
|                             .collect::<Vec<&str>>();
 | |
|                         if parts.len() == 3 {
 | |
|                             if parts[0] == "I" {
 | |
|                                 instance = true;
 | |
|                                 parts[1].to_string()
 | |
|                             } else {
 | |
|                                 parts[0].to_string()
 | |
|                             }
 | |
|                         } else if parts.len() == 2 {
 | |
|                             parts[0].to_string()
 | |
|                         } else {
 | |
|                             panic!("Vertex attributes must follow the form BUFFERNAME_PROPERTYNAME. For example: Vertex_Position");
 | |
|                         }
 | |
|                     };
 | |
| 
 | |
|                     if let Some(current) = current_descriptor.as_mut() {
 | |
|                         if ¤t.name == ¤t_buffer_name {
 | |
|                             current.attributes.push(vertex_attribute_descriptor);
 | |
|                             continue;
 | |
|                         } else {
 | |
|                             if visited_buffer_descriptors.contains(¤t_buffer_name) {
 | |
|                                 panic!("Vertex attribute buffer names must be consecutive.")
 | |
|                             }
 | |
|                         }
 | |
|                     }
 | |
| 
 | |
|                     if let Some(current) = current_descriptor.take() {
 | |
|                         visited_buffer_descriptors.insert(current.name.to_string());
 | |
|                         vertex_buffer_descriptors.push(current);
 | |
|                     }
 | |
| 
 | |
|                     current_descriptor = Some(VertexBufferDescriptor {
 | |
|                         attributes: vec![vertex_attribute_descriptor],
 | |
|                         name: current_buffer_name,
 | |
|                         step_mode: if instance {
 | |
|                             InputStepMode::Instance
 | |
|                         } else {
 | |
|                             InputStepMode::Vertex
 | |
|                         },
 | |
|                         stride: 0,
 | |
|                     })
 | |
|                 }
 | |
| 
 | |
|                 if let Some(current) = current_descriptor.take() {
 | |
|                     visited_buffer_descriptors.insert(current.name.to_string());
 | |
|                     vertex_buffer_descriptors.push(current);
 | |
|                 }
 | |
| 
 | |
|                 for vertex_buffer_descriptor in vertex_buffer_descriptors.iter_mut() {
 | |
|                     calculate_offsets(vertex_buffer_descriptor);
 | |
|                 }
 | |
| 
 | |
|                 ShaderLayout {
 | |
|                     bind_groups,
 | |
|                     vertex_buffer_descriptors,
 | |
|                     entry_point: entry_point_name,
 | |
|                 }
 | |
|             }
 | |
|             Err(err) => panic!("Failed to reflect shader layout: {:?}", err),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn calculate_offsets(vertex_buffer_descriptor: &mut VertexBufferDescriptor) {
 | |
|     let mut offset = 0;
 | |
|     for attribute in vertex_buffer_descriptor.attributes.iter_mut() {
 | |
|         attribute.offset = offset;
 | |
|         offset += attribute.format.get_size();
 | |
|     }
 | |
| 
 | |
|     vertex_buffer_descriptor.stride = offset;
 | |
| }
 | |
| 
 | |
| fn reflect_vertex_attribute_descriptor(
 | |
|     input_variable: &ReflectInterfaceVariable,
 | |
| ) -> VertexAttributeDescriptor {
 | |
|     VertexAttributeDescriptor {
 | |
|         name: input_variable.name.clone(),
 | |
|         format: reflect_vertex_format(input_variable.type_description.as_ref().unwrap()),
 | |
|         offset: 0,
 | |
|         shader_location: input_variable.location,
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn reflect_bind_group(descriptor_set: &ReflectDescriptorSet) -> BindGroupDescriptor {
 | |
|     let mut bindings = Vec::new();
 | |
|     for descriptor_binding in descriptor_set.bindings.iter() {
 | |
|         let binding = reflect_binding(descriptor_binding);
 | |
|         bindings.push(binding);
 | |
|     }
 | |
| 
 | |
|     BindGroupDescriptor::new(descriptor_set.set, bindings)
 | |
| }
 | |
| 
 | |
| fn reflect_dimension(type_description: &ReflectTypeDescription) -> TextureViewDimension {
 | |
|     match type_description.traits.image.dim {
 | |
|         ReflectDimension::Type1d => TextureViewDimension::D1,
 | |
|         ReflectDimension::Type2d => TextureViewDimension::D2,
 | |
|         ReflectDimension::Type3d => TextureViewDimension::D3,
 | |
|         ReflectDimension::Cube => TextureViewDimension::Cube,
 | |
|         dimension => panic!("unsupported image dimension: {:?}", dimension),
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn reflect_binding(binding: &ReflectDescriptorBinding) -> BindingDescriptor {
 | |
|     let type_description = binding.type_description.as_ref().unwrap();
 | |
|     let (name, bind_type) = match binding.descriptor_type {
 | |
|         ReflectDescriptorType::UniformBuffer => (
 | |
|             &type_description.type_name,
 | |
|             BindType::Uniform {
 | |
|                 dynamic: false,
 | |
|                 properties: vec![reflect_uniform(type_description)],
 | |
|             },
 | |
|         ),
 | |
|         ReflectDescriptorType::SampledImage => (
 | |
|             &binding.name,
 | |
|             BindType::SampledTexture {
 | |
|                 dimension: reflect_dimension(type_description),
 | |
|                 multisampled: false,
 | |
|             },
 | |
|         ),
 | |
|         ReflectDescriptorType::Sampler => (&binding.name, BindType::Sampler),
 | |
|         _ => panic!("unsupported bind type {:?}", binding.descriptor_type),
 | |
|     };
 | |
| 
 | |
|     BindingDescriptor {
 | |
|         index: binding.binding,
 | |
|         bind_type,
 | |
|         name: name.to_string(),
 | |
|     }
 | |
| }
 | |
| 
 | |
/// Scalar kind underlying a reflected numeric (scalar, vector, or matrix) type.
#[derive(Debug)]
enum NumberType {
    /// Signed integer.
    Int,
    /// Unsigned integer.
    UInt,
    /// Floating point.
    Float,
}
 | |
| 
 | |
| fn reflect_uniform(type_description: &ReflectTypeDescription) -> UniformProperty {
 | |
|     let uniform_property_type = if type_description
 | |
|         .type_flags
 | |
|         .contains(ReflectTypeFlags::STRUCT)
 | |
|     {
 | |
|         reflect_uniform_struct(type_description)
 | |
|     } else {
 | |
|         reflect_uniform_numeric(type_description)
 | |
|     };
 | |
| 
 | |
|     UniformProperty {
 | |
|         name: type_description.type_name.to_string(),
 | |
|         property_type: uniform_property_type,
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn reflect_uniform_struct(type_description: &ReflectTypeDescription) -> UniformPropertyType {
 | |
|     let mut properties = Vec::new();
 | |
|     for member in type_description.members.iter() {
 | |
|         properties.push(reflect_uniform(member));
 | |
|     }
 | |
| 
 | |
|     UniformPropertyType::Struct(properties)
 | |
| }
 | |
| 
 | |
| fn reflect_uniform_numeric(type_description: &ReflectTypeDescription) -> UniformPropertyType {
 | |
|     let traits = &type_description.traits;
 | |
|     let number_type = if type_description.type_flags.contains(ReflectTypeFlags::INT) {
 | |
|         match traits.numeric.scalar.signedness {
 | |
|             0 => NumberType::UInt,
 | |
|             1 => NumberType::Int,
 | |
|             signedness => panic!("unexpected signedness {}", signedness),
 | |
|         }
 | |
|     } else if type_description
 | |
|         .type_flags
 | |
|         .contains(ReflectTypeFlags::FLOAT)
 | |
|     {
 | |
|         NumberType::Float
 | |
|     } else {
 | |
|         panic!("unexpected type flag {:?}", type_description.type_flags);
 | |
|     };
 | |
| 
 | |
|     // TODO: handle scalar width here
 | |
| 
 | |
|     if type_description
 | |
|         .type_flags
 | |
|         .contains(ReflectTypeFlags::MATRIX)
 | |
|     {
 | |
|         match (
 | |
|             number_type,
 | |
|             traits.numeric.matrix.column_count,
 | |
|             traits.numeric.matrix.row_count,
 | |
|         ) {
 | |
|             (NumberType::Float, 3, 3) => UniformPropertyType::Mat3,
 | |
|             (NumberType::Float, 4, 4) => UniformPropertyType::Mat4,
 | |
|             (number_type, column_count, row_count) => panic!(
 | |
|                 "unexpected uniform property matrix format {:?} {}x{}",
 | |
|                 number_type, column_count, row_count
 | |
|             ),
 | |
|         }
 | |
|     } else {
 | |
|         match (number_type, traits.numeric.vector.component_count) {
 | |
|             (NumberType::Int, 1) => UniformPropertyType::Int,
 | |
|             (NumberType::Float, 3) => UniformPropertyType::Vec3,
 | |
|             (NumberType::Float, 4) => UniformPropertyType::Vec4,
 | |
|             (NumberType::UInt, 4) => UniformPropertyType::UVec4,
 | |
|             (number_type, component_count) => panic!(
 | |
|                 "unexpected uniform property format {:?} {}",
 | |
|                 number_type, component_count
 | |
|             ),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn reflect_vertex_format(type_description: &ReflectTypeDescription) -> VertexFormat {
 | |
|     let traits = &type_description.traits;
 | |
|     let number_type = if type_description.type_flags.contains(ReflectTypeFlags::INT) {
 | |
|         match traits.numeric.scalar.signedness {
 | |
|             0 => NumberType::UInt,
 | |
|             1 => NumberType::Int,
 | |
|             signedness => panic!("unexpected signedness {}", signedness),
 | |
|         }
 | |
|     } else if type_description
 | |
|         .type_flags
 | |
|         .contains(ReflectTypeFlags::FLOAT)
 | |
|     {
 | |
|         NumberType::Float
 | |
|     } else {
 | |
|         panic!("unexpected type flag {:?}", type_description.type_flags);
 | |
|     };
 | |
| 
 | |
|     let width = traits.numeric.scalar.width;
 | |
| 
 | |
|     match (number_type, traits.numeric.vector.component_count, width) {
 | |
|         (NumberType::UInt, 2, 8) => VertexFormat::Uchar2,
 | |
|         (NumberType::UInt, 4, 8) => VertexFormat::Uchar4,
 | |
|         (NumberType::Int, 2, 8) => VertexFormat::Char2,
 | |
|         (NumberType::Int, 4, 8) => VertexFormat::Char4,
 | |
|         (NumberType::UInt, 2, 16) => VertexFormat::Ushort2,
 | |
|         (NumberType::UInt, 4, 16) => VertexFormat::Ushort4,
 | |
|         (NumberType::Int, 2, 16) => VertexFormat::Short2,
 | |
|         (NumberType::Int, 8, 16) => VertexFormat::Short4,
 | |
|         (NumberType::Float, 2, 16) => VertexFormat::Half2,
 | |
|         (NumberType::Float, 4, 16) => VertexFormat::Half4,
 | |
|         (NumberType::Float, 0, 32) => VertexFormat::Float,
 | |
|         (NumberType::Float, 2, 32) => VertexFormat::Float2,
 | |
|         (NumberType::Float, 3, 32) => VertexFormat::Float3,
 | |
|         (NumberType::Float, 4, 32) => VertexFormat::Float4,
 | |
|         (NumberType::UInt, 0, 32) => VertexFormat::Uint,
 | |
|         (NumberType::UInt, 2, 32) => VertexFormat::Uint2,
 | |
|         (NumberType::UInt, 3, 32) => VertexFormat::Uint3,
 | |
|         (NumberType::UInt, 4, 32) => VertexFormat::Uint4,
 | |
|         (NumberType::Int, 0, 32) => VertexFormat::Int,
 | |
|         (NumberType::Int, 2, 32) => VertexFormat::Int2,
 | |
|         (NumberType::Int, 3, 32) => VertexFormat::Int3,
 | |
|         (NumberType::Int, 4, 32) => VertexFormat::Int4,
 | |
|         (number_type, component_count, width) => panic!(
 | |
|             "unexpected uniform property format {:?} {} {}",
 | |
|             number_type, component_count, width
 | |
|         ),
 | |
|     }
 | |
| }
 | |
| 
 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::render::shader::{Shader, ShaderStage};

    /// Reflecting a vertex shader groups inputs into a per-vertex buffer and a
    /// per-instance buffer, and yields one bind group per descriptor set.
    #[test]
    fn test_reflection() {
        let shader = Shader::from_glsl(
            ShaderStage::Vertex,
            r#"
            #version 450
            layout(location = 0) in vec4 Vertex_Position;
            layout(location = 1) in uvec4 Vertex_Normal;
            layout(location = 2) in uvec4 I_TestInstancing_Property;

            layout(location = 0) out vec4 v_Position;
            layout(set = 0, binding = 0) uniform Camera {
                mat4 ViewProj;
            };
            layout(set = 1, binding = 0) uniform texture2D Texture;

            void main() {
                v_Position = Vertex_Position;
                gl_Position = ViewProj * v_Position;
            }
        "#,
        )
        .get_spirv_shader(None);

        // Per-vertex buffer: Position (16 bytes) then Normal (16 bytes).
        let vertex_buffer = VertexBufferDescriptor {
            name: "Vertex".to_string(),
            attributes: vec![
                VertexAttributeDescriptor {
                    name: "Vertex_Position".to_string(),
                    format: VertexFormat::Float4,
                    offset: 0,
                    shader_location: 0,
                },
                VertexAttributeDescriptor {
                    name: "Vertex_Normal".to_string(),
                    format: VertexFormat::Uint4,
                    offset: 16,
                    shader_location: 1,
                },
            ],
            step_mode: InputStepMode::Vertex,
            stride: 32,
        };

        // The `I_` prefix marks a per-instance buffer.
        let instance_buffer = VertexBufferDescriptor {
            name: "TestInstancing".to_string(),
            attributes: vec![VertexAttributeDescriptor {
                name: "I_TestInstancing_Property".to_string(),
                format: VertexFormat::Uint4,
                offset: 0,
                shader_location: 2,
            }],
            step_mode: InputStepMode::Instance,
            stride: 16,
        };

        let camera_group = BindGroupDescriptor::new(
            0,
            vec![BindingDescriptor {
                index: 0,
                name: "Camera".to_string(),
                bind_type: BindType::Uniform {
                    dynamic: false,
                    properties: vec![UniformProperty {
                        name: "Camera".to_string(),
                        property_type: UniformPropertyType::Struct(vec![UniformProperty {
                            name: "".to_string(),
                            property_type: UniformPropertyType::Mat4,
                        }]),
                    }],
                },
            }],
        );

        let texture_group = BindGroupDescriptor::new(
            1,
            vec![BindingDescriptor {
                index: 0,
                name: "Texture".to_string(),
                bind_type: BindType::SampledTexture {
                    multisampled: false,
                    dimension: TextureViewDimension::D2,
                },
            }],
        );

        let expected = ShaderLayout {
            entry_point: "main".to_string(),
            vertex_buffer_descriptors: vec![vertex_buffer, instance_buffer],
            bind_groups: vec![camera_group, texture_group],
        };

        let layout = shader.reflect_layout().unwrap();
        assert_eq!(layout, expected);
    }

    /// Attributes of one buffer separated by another buffer's attribute must be rejected.
    #[test]
    #[should_panic(expected = "Vertex attribute buffer names must be consecutive.")]
    fn test_reflection_consecutive_buffer_validation() {
        let shader = Shader::from_glsl(
            ShaderStage::Vertex,
            r#"
            #version 450
            layout(location = 0) in vec4 Vertex_Position;
            layout(location = 1) in uvec4 Other_Property;
            layout(location = 2) in uvec4 Vertex_Normal;

            layout(location = 0) out vec4 v_Position;
            layout(set = 0, binding = 0) uniform Camera {
                mat4 ViewProj;
            };
            layout(set = 1, binding = 0) uniform texture2D Texture;

            void main() {
                v_Position = Vertex_Position;
                gl_Position = ViewProj * v_Position;
            }
        "#,
        )
        .get_spirv_shader(None);

        let _layout = shader.reflect_layout().unwrap();
    }
}
 | 
