Merge pull request #189 from StarArawn/bind-group-reflect-fix

Reflect shader stage for bind groups.
This commit is contained in:
Carter Anderson 2020-08-20 12:57:38 -07:00 committed by GitHub
commit 1ebb7e44ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 11 deletions

View File

@ -25,16 +25,21 @@ impl PipelineLayout {
for shader_binding in shader_bind_group.bindings.iter() { for shader_binding in shader_bind_group.bindings.iter() {
if let Some(binding) = bind_group if let Some(binding) = bind_group
.bindings .bindings
.iter() .iter_mut()
.find(|binding| binding.index == shader_binding.index) .find(|binding| binding.index == shader_binding.index)
{ {
if binding != shader_binding { binding.shader_stage |= shader_binding.shader_stage;
if binding.bind_type != shader_binding.bind_type
|| binding.name != shader_binding.name
|| binding.index != shader_binding.index
{
panic!("Binding {} in BindGroup {} does not match across all shader types: {:?} {:?}", binding.index, bind_group.index, binding, shader_binding); panic!("Binding {} in BindGroup {} does not match across all shader types: {:?} {:?}", binding.index, bind_group.index, binding, shader_binding);
} }
} else { } else {
bind_group.bindings.push(shader_binding.clone()); bind_group.bindings.push(shader_binding.clone());
} }
} }
bind_group.update_id();
} }
None => { None => {
bind_groups.insert(shader_bind_group.index, shader_bind_group.clone()); bind_groups.insert(shader_bind_group.index, shader_bind_group.clone());

View File

@ -9,7 +9,8 @@ use bevy_core::AsBytes;
use spirv_reflect::{ use spirv_reflect::{
types::{ types::{
ReflectDescriptorBinding, ReflectDescriptorSet, ReflectDescriptorType, ReflectDimension, ReflectDescriptorBinding, ReflectDescriptorSet, ReflectDescriptorType, ReflectDimension,
ReflectInterfaceVariable, ReflectTypeDescription, ReflectTypeFlags, ReflectInterfaceVariable, ReflectShaderStageFlags, ReflectTypeDescription,
ReflectTypeFlags,
}, },
ShaderModule, ShaderModule,
}; };
@ -30,9 +31,10 @@ impl ShaderLayout {
match ShaderModule::load_u8_data(spirv_data.as_bytes()) { match ShaderModule::load_u8_data(spirv_data.as_bytes()) {
Ok(ref mut module) => { Ok(ref mut module) => {
let entry_point_name = module.get_entry_point_name(); let entry_point_name = module.get_entry_point_name();
let shader_stage = module.get_shader_stage();
let mut bind_groups = Vec::new(); let mut bind_groups = Vec::new();
for descriptor_set in module.enumerate_descriptor_sets(None).unwrap() { for descriptor_set in module.enumerate_descriptor_sets(None).unwrap() {
let bind_group = reflect_bind_group(&descriptor_set); let bind_group = reflect_bind_group(&descriptor_set, shader_stage);
bind_groups.push(bind_group); bind_groups.push(bind_group);
} }
@ -148,10 +150,13 @@ fn reflect_vertex_attribute_descriptor(
} }
} }
fn reflect_bind_group(descriptor_set: &ReflectDescriptorSet) -> BindGroupDescriptor { fn reflect_bind_group(
descriptor_set: &ReflectDescriptorSet,
shader_stage: ReflectShaderStageFlags,
) -> BindGroupDescriptor {
let mut bindings = Vec::new(); let mut bindings = Vec::new();
for descriptor_binding in descriptor_set.bindings.iter() { for descriptor_binding in descriptor_set.bindings.iter() {
let binding = reflect_binding(descriptor_binding); let binding = reflect_binding(descriptor_binding, shader_stage);
bindings.push(binding); bindings.push(binding);
} }
@ -168,7 +173,10 @@ fn reflect_dimension(type_description: &ReflectTypeDescription) -> TextureViewDi
} }
} }
fn reflect_binding(binding: &ReflectDescriptorBinding) -> BindingDescriptor { fn reflect_binding(
binding: &ReflectDescriptorBinding,
shader_stage: ReflectShaderStageFlags,
) -> BindingDescriptor {
let type_description = binding.type_description.as_ref().unwrap(); let type_description = binding.type_description.as_ref().unwrap();
let (name, bind_type) = match binding.descriptor_type { let (name, bind_type) = match binding.descriptor_type {
ReflectDescriptorType::UniformBuffer => ( ReflectDescriptorType::UniformBuffer => (
@ -198,12 +206,24 @@ fn reflect_binding(binding: &ReflectDescriptorBinding) -> BindingDescriptor {
_ => panic!("unsupported bind type {:?}", binding.descriptor_type), _ => panic!("unsupported bind type {:?}", binding.descriptor_type),
}; };
let mut shader_stage = match shader_stage {
ReflectShaderStageFlags::COMPUTE => BindingShaderStage::COMPUTE,
ReflectShaderStageFlags::VERTEX => BindingShaderStage::VERTEX,
ReflectShaderStageFlags::FRAGMENT => BindingShaderStage::FRAGMENT,
_ => panic!("Only one specified shader stage is supported."),
};
let name = name.to_string();
if name == "Camera" {
shader_stage = BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT;
}
BindingDescriptor { BindingDescriptor {
index: binding.binding, index: binding.binding,
bind_type, bind_type,
name: name.to_string(), name,
// TODO: We should be able to detect which shader program the binding is being used in.. shader_stage,
shader_stage: BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT,
} }
} }
@ -425,7 +445,7 @@ mod tests {
dimension: TextureViewDimension::D2, dimension: TextureViewDimension::D2,
component_type: TextureComponentType::Float, component_type: TextureComponentType::Float,
}, },
shader_stage: BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT, shader_stage: BindingShaderStage::VERTEX,
}] }]
), ),
] ]