bevy_pbr: Avoid copying structs and using registers in shaders (#7069)

# Objective

- The #7064 PR had poor performance on an M1 Max in MacOS due to significant overuse of registers resulting in 'register spilling' where data that would normally be stored in registers on the GPU is instead stored in VRAM. The latency to read from/write to VRAM instead of registers incurs a significant performance penalty.
- Use of registers is a limiting factor in shader performance. Assignment of a struct from memory to a local variable can incur copies. Passing a variable that has struct type as an argument to a function can also incur copies. As such, these two cases can incur increased register usage and decreased performance.

## Solution

- Remove/avoid a number of assignments of light struct type data to local variables.
- Remove/avoid a number of passing light struct type variables/data as value arguments to shader functions.
This commit is contained in:
Robert Swain 2023-01-02 22:07:33 +00:00
parent b833bdab17
commit b44b606d29
3 changed files with 51 additions and 49 deletions

View File

@ -196,38 +196,35 @@ fn pbr(
// point lights // point lights
for (var i: u32 = offset_and_counts[0]; i < offset_and_counts[0] + offset_and_counts[1]; i = i + 1u) { for (var i: u32 = offset_and_counts[0]; i < offset_and_counts[0] + offset_and_counts[1]; i = i + 1u) {
let light_id = get_light_id(i); let light_id = get_light_id(i);
let light = point_lights.data[light_id];
var shadow: f32 = 1.0; var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) { && (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_point_shadow(light_id, in.world_position, in.world_normal); shadow = fetch_point_shadow(light_id, in.world_position, in.world_normal);
} }
let light_contrib = point_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color); let light_contrib = point_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow; light_accum = light_accum + light_contrib * shadow;
} }
// spot lights // spot lights
for (var i: u32 = offset_and_counts[0] + offset_and_counts[1]; i < offset_and_counts[0] + offset_and_counts[1] + offset_and_counts[2]; i = i + 1u) { for (var i: u32 = offset_and_counts[0] + offset_and_counts[1]; i < offset_and_counts[0] + offset_and_counts[1] + offset_and_counts[2]; i = i + 1u) {
let light_id = get_light_id(i); let light_id = get_light_id(i);
let light = point_lights.data[light_id];
var shadow: f32 = 1.0; var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) { && (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_spot_shadow(light_id, in.world_position, in.world_normal); shadow = fetch_spot_shadow(light_id, in.world_position, in.world_normal);
} }
let light_contrib = spot_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color); let light_contrib = spot_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow; light_accum = light_accum + light_contrib * shadow;
} }
let n_directional_lights = lights.n_directional_lights; let n_directional_lights = lights.n_directional_lights;
for (var i: u32 = 0u; i < n_directional_lights; i = i + 1u) { for (var i: u32 = 0u; i < n_directional_lights; i = i + 1u) {
let light = lights.directional_lights[i];
var shadow: f32 = 1.0; var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) { && (lights.directional_lights[i].flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_directional_shadow(i, in.world_position, in.world_normal); shadow = fetch_directional_shadow(i, in.world_position, in.world_normal);
} }
let light_contrib = directional_light(light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color); let light_contrib = directional_light(i, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow; light_accum = light_accum + light_contrib * shadow;
} }

View File

@ -150,22 +150,23 @@ fn perceptualRoughnessToRoughness(perceptualRoughness: f32) -> f32 {
} }
fn point_light( fn point_light(
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>, world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32> R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
) -> vec3<f32> { ) -> vec3<f32> {
let light_to_frag = light.position_radius.xyz - world_position.xyz; let light = &point_lights.data[light_id];
let light_to_frag = (*light).position_radius.xyz - world_position.xyz;
let distance_square = dot(light_to_frag, light_to_frag); let distance_square = dot(light_to_frag, light_to_frag);
let rangeAttenuation = let rangeAttenuation =
getDistanceAttenuation(distance_square, light.color_inverse_square_range.w); getDistanceAttenuation(distance_square, (*light).color_inverse_square_range.w);
// Specular. // Specular.
// Representative Point Area Lights. // Representative Point Area Lights.
// see http://blog.selfshadow.com/publications/s2013-shading-course/karis/s2013_pbs_epic_notes_v2.pdf p14-16 // see http://blog.selfshadow.com/publications/s2013-shading-course/karis/s2013_pbs_epic_notes_v2.pdf p14-16
let a = roughness; let a = roughness;
let centerToRay = dot(light_to_frag, R) * R - light_to_frag; let centerToRay = dot(light_to_frag, R) * R - light_to_frag;
let closestPoint = light_to_frag + centerToRay * saturate(light.position_radius.w * inverseSqrt(dot(centerToRay, centerToRay))); let closestPoint = light_to_frag + centerToRay * saturate((*light).position_radius.w * inverseSqrt(dot(centerToRay, centerToRay)));
let LspecLengthInverse = inverseSqrt(dot(closestPoint, closestPoint)); let LspecLengthInverse = inverseSqrt(dot(closestPoint, closestPoint));
let normalizationFactor = a / saturate(a + (light.position_radius.w * 0.5 * LspecLengthInverse)); let normalizationFactor = a / saturate(a + ((*light).position_radius.w * 0.5 * LspecLengthInverse));
let specularIntensity = normalizationFactor * normalizationFactor; let specularIntensity = normalizationFactor * normalizationFactor;
var L: vec3<f32> = closestPoint * LspecLengthInverse; // normalize() equivalent? var L: vec3<f32> = closestPoint * LspecLengthInverse; // normalize() equivalent?
@ -197,40 +198,44 @@ fn point_light(
// I = Φ / 4 π // I = Φ / 4 π
// The derivation of this can be seen here: https://google.github.io/filament/Filament.html#mjx-eqn-pointLightLuminousPower // The derivation of this can be seen here: https://google.github.io/filament/Filament.html#mjx-eqn-pointLightLuminousPower
// NOTE: light.color.rgb is premultiplied with light.intensity / 4 π (which would be the luminous intensity) on the CPU // NOTE: (*light).color.rgb is premultiplied with (*light).intensity / 4 π (which would be the luminous intensity) on the CPU
// TODO compensate for energy loss https://google.github.io/filament/Filament.html#materialsystem/improvingthebrdfs/energylossinspecularreflectance // TODO compensate for energy loss https://google.github.io/filament/Filament.html#materialsystem/improvingthebrdfs/energylossinspecularreflectance
return ((diffuse + specular_light) * light.color_inverse_square_range.rgb) * (rangeAttenuation * NoL); return ((diffuse + specular_light) * (*light).color_inverse_square_range.rgb) * (rangeAttenuation * NoL);
} }
fn spot_light( fn spot_light(
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>, world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32> R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
) -> vec3<f32> { ) -> vec3<f32> {
// reuse the point light calculations // reuse the point light calculations
let point_light = point_light(world_position, light, roughness, NdotV, N, V, R, F0, diffuseColor); let point_light = point_light(world_position, light_id, roughness, NdotV, N, V, R, F0, diffuseColor);
let light = &point_lights.data[light_id];
// reconstruct spot dir from x/z and y-direction flag // reconstruct spot dir from x/z and y-direction flag
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y); var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
spot_dir.y = sqrt(max(0.0, 1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z)); spot_dir.y = sqrt(max(0.0, 1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z));
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) { if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
spot_dir.y = -spot_dir.y; spot_dir.y = -spot_dir.y;
} }
let light_to_frag = light.position_radius.xyz - world_position.xyz; let light_to_frag = (*light).position_radius.xyz - world_position.xyz;
// calculate attenuation based on filament formula https://google.github.io/filament/Filament.html#listing_glslpunctuallight // calculate attenuation based on filament formula https://google.github.io/filament/Filament.html#listing_glslpunctuallight
// spot_scale and spot_offset have been precomputed // spot_scale and spot_offset have been precomputed
// note we normalize here to get "l" from the filament listing. spot_dir is already normalized // note we normalize here to get "l" from the filament listing. spot_dir is already normalized
let cd = dot(-spot_dir, normalize(light_to_frag)); let cd = dot(-spot_dir, normalize(light_to_frag));
let attenuation = saturate(cd * light.light_custom_data.z + light.light_custom_data.w); let attenuation = saturate(cd * (*light).light_custom_data.z + (*light).light_custom_data.w);
let spot_attenuation = attenuation * attenuation; let spot_attenuation = attenuation * attenuation;
return point_light * spot_attenuation; return point_light * spot_attenuation;
} }
fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> { fn directional_light(light_id: u32, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> {
let incident_light = light.direction_to_light.xyz; let light = &lights.directional_lights[light_id];
let incident_light = (*light).direction_to_light.xyz;
let half_vector = normalize(incident_light + view); let half_vector = normalize(incident_light + view);
let NoL = saturate(dot(normal, incident_light)); let NoL = saturate(dot(normal, incident_light));
@ -241,5 +246,5 @@ fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal
let specularIntensity = 1.0; let specularIntensity = 1.0;
let specular_light = specular(F0, roughness, half_vector, NdotV, NoL, NoH, LoH, specularIntensity); let specular_light = specular(F0, roughness, half_vector, NdotV, NoL, NoH, LoH, specularIntensity);
return (specular_light + diffuse) * light.color.rgb * NoL; return (specular_light + diffuse) * (*light).color.rgb * NoL;
} }

View File

@ -1,23 +1,23 @@
#define_import_path bevy_pbr::shadows #define_import_path bevy_pbr::shadows
fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 { fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = point_lights.data[light_id]; let light = &point_lights.data[light_id];
// because the shadow maps align with the axes and the frustum planes are at 45 degrees // because the shadow maps align with the axes and the frustum planes are at 45 degrees
// we can get the worldspace depth by taking the largest absolute axis // we can get the worldspace depth by taking the largest absolute axis
let surface_to_light = light.position_radius.xyz - frag_position.xyz; let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;
let surface_to_light_abs = abs(surface_to_light); let surface_to_light_abs = abs(surface_to_light);
let distance_to_light = max(surface_to_light_abs.x, max(surface_to_light_abs.y, surface_to_light_abs.z)); let distance_to_light = max(surface_to_light_abs.x, max(surface_to_light_abs.y, surface_to_light_abs.z));
// The normal bias here is already scaled by the texel size at 1 world unit from the light. // The normal bias here is already scaled by the texel size at 1 world unit from the light.
// The texel size increases proportionally with distance from the light so multiplying by // The texel size increases proportionally with distance from the light so multiplying by
// distance to light scales the normal bias to the texel size at the fragment distance. // distance to light scales the normal bias to the texel size at the fragment distance.
let normal_offset = light.shadow_normal_bias * distance_to_light * surface_normal.xyz; let normal_offset = (*light).shadow_normal_bias * distance_to_light * surface_normal.xyz;
let depth_offset = light.shadow_depth_bias * normalize(surface_to_light.xyz); let depth_offset = (*light).shadow_depth_bias * normalize(surface_to_light.xyz);
let offset_position = frag_position.xyz + normal_offset + depth_offset; let offset_position = frag_position.xyz + normal_offset + depth_offset;
// similar largest-absolute-axis trick as above, but now with the offset fragment position // similar largest-absolute-axis trick as above, but now with the offset fragment position
let frag_ls = light.position_radius.xyz - offset_position.xyz; let frag_ls = (*light).position_radius.xyz - offset_position.xyz;
let abs_position_ls = abs(frag_ls); let abs_position_ls = abs(frag_ls);
let major_axis_magnitude = max(abs_position_ls.x, max(abs_position_ls.y, abs_position_ls.z)); let major_axis_magnitude = max(abs_position_ls.x, max(abs_position_ls.y, abs_position_ls.z));
@ -25,7 +25,7 @@ fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: v
// projection * vec4(0, 0, -major_axis_magnitude, 1.0) // projection * vec4(0, 0, -major_axis_magnitude, 1.0)
// and keeping only the terms that have any impact on the depth. // and keeping only the terms that have any impact on the depth.
// Projection-agnostic approach: // Projection-agnostic approach:
let zw = -major_axis_magnitude * light.light_custom_data.xy + light.light_custom_data.zw; let zw = -major_axis_magnitude * (*light).light_custom_data.xy + (*light).light_custom_data.zw;
let depth = zw.x / zw.y; let depth = zw.x / zw.y;
// do the lookup, using HW PCF and comparison // do the lookup, using HW PCF and comparison
@ -42,27 +42,27 @@ fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: v
} }
fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 { fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = point_lights.data[light_id]; let light = &point_lights.data[light_id];
let surface_to_light = light.position_radius.xyz - frag_position.xyz; let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;
// construct the light view matrix // construct the light view matrix
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y); var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
// reconstruct spot dir from x/z and y-direction flag // reconstruct spot dir from x/z and y-direction flag
spot_dir.y = sqrt(1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z); spot_dir.y = sqrt(1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z);
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) { if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
spot_dir.y = -spot_dir.y; spot_dir.y = -spot_dir.y;
} }
// view matrix z_axis is the reverse of transform.forward() // view matrix z_axis is the reverse of transform.forward()
let fwd = -spot_dir; let fwd = -spot_dir;
let distance_to_light = dot(fwd, surface_to_light); let distance_to_light = dot(fwd, surface_to_light);
let offset_position = let offset_position =
-surface_to_light -surface_to_light
+ (light.shadow_depth_bias * normalize(surface_to_light)) + ((*light).shadow_depth_bias * normalize(surface_to_light))
+ (surface_normal.xyz * light.shadow_normal_bias) * distance_to_light; + (surface_normal.xyz * (*light).shadow_normal_bias) * distance_to_light;
// the construction of the up and right vectors needs to precisely mirror the code // the construction of the up and right vectors needs to precisely mirror the code
// in render/light.rs:spot_light_view_matrix // in render/light.rs:spot_light_view_matrix
var sign = -1.0; var sign = -1.0;
if (fwd.z >= 0.0) { if (fwd.z >= 0.0) {
@ -74,14 +74,14 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
let right_dir = vec3<f32>(-b, -sign - fwd.y * fwd.y * a, fwd.y); let right_dir = vec3<f32>(-b, -sign - fwd.y * fwd.y * a, fwd.y);
let light_inv_rot = mat3x3<f32>(right_dir, up_dir, fwd); let light_inv_rot = mat3x3<f32>(right_dir, up_dir, fwd);
// because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate // because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate
// the product of the transpose with a vector we can just post-multiply instead of pre-multplying. // the product of the transpose with a vector we can just post-multiply instead of pre-multplying.
// this allows us to keep the matrix construction code identical between CPU and GPU. // this allows us to keep the matrix construction code identical between CPU and GPU.
let projected_position = offset_position * light_inv_rot; let projected_position = offset_position * light_inv_rot;
// divide xy by perspective matrix "f" and by -projected.z (projected.z is -projection matrix's w) // divide xy by perspective matrix "f" and by -projected.z (projected.z is -projection matrix's w)
// to get ndc coordinates // to get ndc coordinates
let f_div_minus_z = 1.0 / (light.spot_light_tan_angle * -projected_position.z); let f_div_minus_z = 1.0 / ((*light).spot_light_tan_angle * -projected_position.z);
let shadow_xy_ndc = projected_position.xy * f_div_minus_z; let shadow_xy_ndc = projected_position.xy * f_div_minus_z;
// convert to uv coordinates // convert to uv coordinates
let shadow_uv = shadow_xy_ndc * vec2<f32>(0.5, -0.5) + vec2<f32>(0.5, 0.5); let shadow_uv = shadow_xy_ndc * vec2<f32>(0.5, -0.5) + vec2<f32>(0.5, 0.5);
@ -90,23 +90,23 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
let depth = 0.1 / -projected_position.z; let depth = 0.1 / -projected_position.z;
#ifdef NO_ARRAY_TEXTURES_SUPPORT #ifdef NO_ARRAY_TEXTURES_SUPPORT
return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler, return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler,
shadow_uv, depth); shadow_uv, depth);
#else #else
return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler, return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler,
shadow_uv, i32(light_id) + lights.spot_light_shadowmap_offset, depth); shadow_uv, i32(light_id) + lights.spot_light_shadowmap_offset, depth);
#endif #endif
} }
fn fetch_directional_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 { fn fetch_directional_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = lights.directional_lights[light_id]; let light = &lights.directional_lights[light_id];
// The normal bias is scaled to the texel size. // The normal bias is scaled to the texel size.
let normal_offset = light.shadow_normal_bias * surface_normal.xyz; let normal_offset = (*light).shadow_normal_bias * surface_normal.xyz;
let depth_offset = light.shadow_depth_bias * light.direction_to_light.xyz; let depth_offset = (*light).shadow_depth_bias * (*light).direction_to_light.xyz;
let offset_position = vec4<f32>(frag_position.xyz + normal_offset + depth_offset, frag_position.w); let offset_position = vec4<f32>(frag_position.xyz + normal_offset + depth_offset, frag_position.w);
let offset_position_clip = light.view_projection * offset_position; let offset_position_clip = (*light).view_projection * offset_position;
if (offset_position_clip.w <= 0.0) { if (offset_position_clip.w <= 0.0) {
return 1.0; return 1.0;
} }