Shader Processor: process imported shader (#3290)
# Objective - I want to be able to use `#ifdef` and other processor directives in an imported shader ## Solution - Process imported shader strings Co-authored-by: François <8672791+mockersf@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									b5d7ff2d75
								
							
						
					
					
						commit
						a3c53e689d
					
				@ -360,16 +360,16 @@ impl ShaderProcessor {
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
 | 
			
		||||
        let shader_defs_unique = HashSet::<String>::from_iter(shader_defs.iter().cloned());
 | 
			
		||||
        let mut scopes = vec![true];
 | 
			
		||||
        let mut final_string = String::new();
 | 
			
		||||
        for line in shader_str.split('\n') {
 | 
			
		||||
            if let Some(cap) = self.ifdef_regex.captures(line) {
 | 
			
		||||
                let def = cap.get(1).unwrap();
 | 
			
		||||
                scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str()));
 | 
			
		||||
                scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str()));
 | 
			
		||||
            } else if let Some(cap) = self.ifndef_regex.captures(line) {
 | 
			
		||||
                let def = cap.get(1).unwrap();
 | 
			
		||||
                scopes.push(*scopes.last().unwrap() && !shader_defs.contains(def.as_str()));
 | 
			
		||||
                scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str()));
 | 
			
		||||
            } else if self.else_regex.is_match(line) {
 | 
			
		||||
                let mut is_parent_scope_truthy = true;
 | 
			
		||||
                if scopes.len() > 1 {
 | 
			
		||||
@ -388,19 +388,32 @@ impl ShaderProcessor {
 | 
			
		||||
                .captures(line)
 | 
			
		||||
            {
 | 
			
		||||
                let import = ShaderImport::AssetPath(cap.get(1).unwrap().as_str().to_string());
 | 
			
		||||
                apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
 | 
			
		||||
                self.apply_import(
 | 
			
		||||
                    import_handles,
 | 
			
		||||
                    shaders,
 | 
			
		||||
                    &import,
 | 
			
		||||
                    shader,
 | 
			
		||||
                    shader_defs,
 | 
			
		||||
                    &mut final_string,
 | 
			
		||||
                )?;
 | 
			
		||||
            } else if let Some(cap) = SHADER_IMPORT_PROCESSOR
 | 
			
		||||
                .import_custom_path_regex
 | 
			
		||||
                .captures(line)
 | 
			
		||||
            {
 | 
			
		||||
                let import = ShaderImport::Custom(cap.get(1).unwrap().as_str().to_string());
 | 
			
		||||
                apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
 | 
			
		||||
                self.apply_import(
 | 
			
		||||
                    import_handles,
 | 
			
		||||
                    shaders,
 | 
			
		||||
                    &import,
 | 
			
		||||
                    shader,
 | 
			
		||||
                    shader_defs,
 | 
			
		||||
                    &mut final_string,
 | 
			
		||||
                )?;
 | 
			
		||||
            } else if *scopes.last().unwrap() {
 | 
			
		||||
                final_string.push_str(line);
 | 
			
		||||
                final_string.push('\n');
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        final_string.pop();
 | 
			
		||||
 | 
			
		||||
        if scopes.len() != 1 {
 | 
			
		||||
@ -417,45 +430,51 @@ impl ShaderProcessor {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn apply_import(
 | 
			
		||||
    import_handles: &HashMap<ShaderImport, Handle<Shader>>,
 | 
			
		||||
    shaders: &HashMap<Handle<Shader>, Shader>,
 | 
			
		||||
    import: &ShaderImport,
 | 
			
		||||
    shader: &Shader,
 | 
			
		||||
    final_string: &mut String,
 | 
			
		||||
) -> Result<(), ProcessShaderError> {
 | 
			
		||||
    let imported_shader = import_handles
 | 
			
		||||
        .get(import)
 | 
			
		||||
        .and_then(|handle| shaders.get(handle))
 | 
			
		||||
        .ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
 | 
			
		||||
    match &shader.source {
 | 
			
		||||
        Source::Wgsl(_) => {
 | 
			
		||||
            if let Source::Wgsl(import_source) = &imported_shader.source {
 | 
			
		||||
                final_string.push_str(import_source);
 | 
			
		||||
            } else {
 | 
			
		||||
                return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
 | 
			
		||||
    fn apply_import(
 | 
			
		||||
        &self,
 | 
			
		||||
        import_handles: &HashMap<ShaderImport, Handle<Shader>>,
 | 
			
		||||
        shaders: &HashMap<Handle<Shader>, Shader>,
 | 
			
		||||
        import: &ShaderImport,
 | 
			
		||||
        shader: &Shader,
 | 
			
		||||
        shader_defs: &[String],
 | 
			
		||||
        final_string: &mut String,
 | 
			
		||||
    ) -> Result<(), ProcessShaderError> {
 | 
			
		||||
        let imported_shader = import_handles
 | 
			
		||||
            .get(import)
 | 
			
		||||
            .and_then(|handle| shaders.get(handle))
 | 
			
		||||
            .ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
 | 
			
		||||
        let imported_processed =
 | 
			
		||||
            self.process(imported_shader, shader_defs, shaders, import_handles)?;
 | 
			
		||||
 | 
			
		||||
        match &shader.source {
 | 
			
		||||
            Source::Wgsl(_) => {
 | 
			
		||||
                if let ProcessedShader::Wgsl(import_source) = &imported_processed {
 | 
			
		||||
                    final_string.push_str(import_source);
 | 
			
		||||
                } else {
 | 
			
		||||
                    return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Source::Glsl(_, _) => {
 | 
			
		||||
                if let ProcessedShader::Glsl(import_source, _) = &imported_processed {
 | 
			
		||||
                    final_string.push_str(import_source);
 | 
			
		||||
                } else {
 | 
			
		||||
                    return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Source::SpirV(_) => {
 | 
			
		||||
                return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Source::Glsl(_, _) => {
 | 
			
		||||
            if let Source::Glsl(import_source, _) = &imported_shader.source {
 | 
			
		||||
                final_string.push_str(import_source);
 | 
			
		||||
            } else {
 | 
			
		||||
                return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Source::SpirV(_) => {
 | 
			
		||||
            return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
    use bevy_asset::Handle;
 | 
			
		||||
    use bevy_asset::{Handle, HandleUntyped};
 | 
			
		||||
    use bevy_reflect::TypeUuid;
 | 
			
		||||
    use bevy_utils::HashMap;
 | 
			
		||||
    use naga::ShaderStage;
 | 
			
		||||
 | 
			
		||||
@ -1081,4 +1100,106 @@ fn vertex(
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn process_import_ifdef() {
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const FOO: &str = r"
 | 
			
		||||
#ifdef IMPORT_MISSING
 | 
			
		||||
fn in_import_missing() { }
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef IMPORT_PRESENT
 | 
			
		||||
fn in_import_present() { }
 | 
			
		||||
#endif
 | 
			
		||||
";
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const INPUT: &str = r"
 | 
			
		||||
#import FOO
 | 
			
		||||
#ifdef MAIN_MISSING
 | 
			
		||||
fn in_main_missing() { }
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef MAIN_PRESENT
 | 
			
		||||
fn in_main_present() { }
 | 
			
		||||
#endif
 | 
			
		||||
";
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const EXPECTED: &str = r"
 | 
			
		||||
 | 
			
		||||
fn in_import_present() { }
 | 
			
		||||
fn in_main_present() { }
 | 
			
		||||
";
 | 
			
		||||
        let processor = ShaderProcessor::default();
 | 
			
		||||
        let mut shaders = HashMap::default();
 | 
			
		||||
        let mut import_handles = HashMap::default();
 | 
			
		||||
        let foo_handle = Handle::<Shader>::default();
 | 
			
		||||
        shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
 | 
			
		||||
        import_handles.insert(
 | 
			
		||||
            ShaderImport::Custom("FOO".to_string()),
 | 
			
		||||
            foo_handle.clone_weak(),
 | 
			
		||||
        );
 | 
			
		||||
        let result = processor
 | 
			
		||||
            .process(
 | 
			
		||||
                &Shader::from_wgsl(INPUT),
 | 
			
		||||
                &["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()],
 | 
			
		||||
                &shaders,
 | 
			
		||||
                &import_handles,
 | 
			
		||||
            )
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn process_import_in_import() {
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const BAR: &str = r"
 | 
			
		||||
#ifdef DEEP
 | 
			
		||||
fn inner_import() { }
 | 
			
		||||
#endif
 | 
			
		||||
";
 | 
			
		||||
        const FOO: &str = r"
 | 
			
		||||
#import BAR
 | 
			
		||||
fn import() { }
 | 
			
		||||
";
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const INPUT: &str = r"
 | 
			
		||||
#import FOO
 | 
			
		||||
fn in_main() { }
 | 
			
		||||
";
 | 
			
		||||
        #[rustfmt::skip]
 | 
			
		||||
        const EXPECTED: &str = r"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
fn inner_import() { }
 | 
			
		||||
fn import() { }
 | 
			
		||||
fn in_main() { }
 | 
			
		||||
";
 | 
			
		||||
        let processor = ShaderProcessor::default();
 | 
			
		||||
        let mut shaders = HashMap::default();
 | 
			
		||||
        let mut import_handles = HashMap::default();
 | 
			
		||||
        {
 | 
			
		||||
            let bar_handle = Handle::<Shader>::default();
 | 
			
		||||
            shaders.insert(bar_handle.clone_weak(), Shader::from_wgsl(BAR));
 | 
			
		||||
            import_handles.insert(
 | 
			
		||||
                ShaderImport::Custom("BAR".to_string()),
 | 
			
		||||
                bar_handle.clone_weak(),
 | 
			
		||||
            );
 | 
			
		||||
        }
 | 
			
		||||
        {
 | 
			
		||||
            let foo_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1).typed();
 | 
			
		||||
            shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
 | 
			
		||||
            import_handles.insert(
 | 
			
		||||
                ShaderImport::Custom("FOO".to_string()),
 | 
			
		||||
                foo_handle.clone_weak(),
 | 
			
		||||
            );
 | 
			
		||||
        }
 | 
			
		||||
        let result = processor
 | 
			
		||||
            .process(
 | 
			
		||||
                &Shader::from_wgsl(INPUT),
 | 
			
		||||
                &["DEEP".to_string()],
 | 
			
		||||
                &shaders,
 | 
			
		||||
                &import_handles,
 | 
			
		||||
            )
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user