render: add SpecializedPipeline and SpecializedShader types

This commit is contained in:
Carter Anderson 2020-06-17 13:27:10 -07:00
parent e57fdca1bc
commit e89c693c4d
2 changed files with 49 additions and 37 deletions

View File

@ -22,15 +22,21 @@ pub struct ShaderSpecialization {
pub shader_defs: HashSet<String>,
}
struct SpecializedShader {
shader: Handle<Shader>,
specialization: ShaderSpecialization,
}
struct SpecializedPipeline {
pipeline: Handle<PipelineDescriptor>,
specialization: PipelineSpecialization,
}
// TODO: consider using (Typeid, fieldinfo.index) in place of string for hashes
#[derive(Default)]
pub struct PipelineCompiler {
pub shader_source_to_compiled:
HashMap<Handle<Shader>, Vec<(ShaderSpecialization, Handle<Shader>)>>,
pub specialized_pipelines: HashMap<
Handle<PipelineDescriptor>,
Vec<(PipelineSpecialization, Handle<PipelineDescriptor>)>,
>,
specialized_shaders: HashMap<Handle<Shader>, Vec<SpecializedShader>>,
specialized_pipelines: HashMap<Handle<PipelineDescriptor>, Vec<SpecializedPipeline>>,
}
impl PipelineCompiler {
@ -40,8 +46,8 @@ impl PipelineCompiler {
shader_handle: &Handle<Shader>,
shader_specialization: &ShaderSpecialization,
) -> Handle<Shader> {
let compiled_shaders = self
.shader_source_to_compiled
let specialized_shaders = self
.specialized_shaders
.entry(*shader_handle)
.or_insert_with(|| Vec::new());
@ -52,15 +58,15 @@ impl PipelineCompiler {
return *shader_handle;
}
if let Some((_shader_specialization, compiled_shader)) =
compiled_shaders
if let Some(specialized_shader) =
specialized_shaders
.iter()
.find(|(current_shader_specialization, _compiled_shader)| {
*current_shader_specialization == *shader_specialization
.find(|current_specialized_shader| {
current_specialized_shader.specialization == *shader_specialization
})
{
// if shader has already been compiled with current configuration, use existing shader
*compiled_shader
specialized_shader.shader
} else {
// if no shader exists with the current configuration, create new shader and compile
let shader_def_vec = shader_specialization
@ -69,9 +75,12 @@ impl PipelineCompiler {
.cloned()
.collect::<Vec<String>>();
let compiled_shader = shader.get_spirv_shader(Some(&shader_def_vec));
let compiled_handle = shaders.add(compiled_shader);
compiled_shaders.push((shader_specialization.clone(), compiled_handle));
compiled_handle
let specialized_handle = shaders.add(compiled_shader);
specialized_shaders.push(SpecializedShader {
shader: specialized_handle,
specialization: shader_specialization.clone(),
});
specialized_handle
}
}
@ -86,13 +95,13 @@ impl PipelineCompiler {
render_resource_bindings: &RenderResourceBindings,
) -> Handle<PipelineDescriptor> {
let source_descriptor = pipelines.get(&source_pipeline).unwrap();
let mut compiled_descriptor = source_descriptor.clone();
compiled_descriptor.shader_stages.vertex = self.compile_shader(
let mut specialized_descriptor = source_descriptor.clone();
specialized_descriptor.shader_stages.vertex = self.compile_shader(
shaders,
&compiled_descriptor.shader_stages.vertex,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
);
compiled_descriptor.shader_stages.fragment = compiled_descriptor
specialized_descriptor.shader_stages.fragment = specialized_descriptor
.shader_stages
.fragment
.as_ref()
@ -104,35 +113,38 @@ impl PipelineCompiler {
)
});
compiled_descriptor.reflect_layout(
specialized_descriptor.reflect_layout(
shaders,
true,
Some(vertex_buffer_descriptors),
Some(render_resource_bindings),
);
compiled_descriptor.primitive_topology = pipeline_specialization.primitive_topology;
let compiled_pipeline_handle =
specialized_descriptor.primitive_topology = pipeline_specialization.primitive_topology;
let specialized_pipeline_handle =
if *pipeline_specialization == PipelineSpecialization::default() {
pipelines.set(source_pipeline, compiled_descriptor);
pipelines.set(source_pipeline, specialized_descriptor);
source_pipeline
} else {
pipelines.add(compiled_descriptor)
pipelines.add(specialized_descriptor)
};
render_resource_context.create_render_pipeline(
compiled_pipeline_handle,
pipelines.get(&compiled_pipeline_handle).unwrap(),
specialized_pipeline_handle,
pipelines.get(&specialized_pipeline_handle).unwrap(),
&shaders,
);
let compiled_pipelines = self
let specialized_pipelines = self
.specialized_pipelines
.entry(source_pipeline)
.or_insert_with(|| Vec::new());
compiled_pipelines.push((pipeline_specialization.clone(), compiled_pipeline_handle));
specialized_pipelines.push(SpecializedPipeline {
pipeline: specialized_pipeline_handle,
specialization: pipeline_specialization.clone(),
});
compiled_pipeline_handle
specialized_pipeline_handle
}
fn compile_render_pipelines(
@ -145,17 +157,17 @@ impl PipelineCompiler {
) {
for render_pipeline in render_pipelines.pipelines.iter_mut() {
let source_pipeline = render_pipeline.pipeline;
let compiled_pipeline_handle = if let Some((_shader_defs, compiled_pipeline_handle)) =
let compiled_pipeline_handle = if let Some(specialized_pipeline) =
self.specialized_pipelines
.get_mut(&source_pipeline)
.and_then(|specialized_pipelines| {
specialized_pipelines.iter().find(
|(pipeline_specialization, _compiled_pipeline_handle)| {
*pipeline_specialization == render_pipeline.specialization
|current_specialized_pipeline| {
current_specialized_pipeline.specialization == render_pipeline.specialization
},
)
}) {
*compiled_pipeline_handle
specialized_pipeline.pipeline
} else {
self.compile_pipeline(
render_resource_context,
@ -177,7 +189,7 @@ impl PipelineCompiler {
pipeline_handle: Handle<PipelineDescriptor>,
) -> Option<impl Iterator<Item = &Handle<PipelineDescriptor>>> {
if let Some(compiled_pipelines) = self.specialized_pipelines.get(&pipeline_handle) {
Some(compiled_pipelines.iter().map(|(_, handle)| handle))
Some(compiled_pipelines.iter().map(|specialized_pipeline| &specialized_pipeline.pipeline))
} else {
None
}
@ -189,7 +201,7 @@ impl PipelineCompiler {
.map(|compiled_pipelines| {
compiled_pipelines
.iter()
.map(|(_, pipeline_handle)| pipeline_handle)
.map(|specialized_pipeline| &specialized_pipeline.pipeline)
})
.flatten()
}

View File

@ -8,7 +8,7 @@ fn main() {
.add_default_plugins()
.add_plugin(FrameTimeDiagnosticsPlugin::default())
.add_startup_system(setup.system())
// .add_system(text_update_system.system())
.add_system(text_update_system.system())
.run();
}