Fix meshlet vertex attribute interpolation (#13775)

# Objective

- Mikktspace requires that we normalize world normals/tangents _before_
interpolation across vertices, and then do _not_ normalize after. I had
it backwards.
- We do not need (or at least are not supposed to need) a second set of
barycentrics for motion vectors. If you think about the typical raster
pipeline, in the vertex shader we calculate previous_world_position, and
then it gets interpolated using the current triangle's barycentrics.

## Solution

- Fix normal/tangent processing 
- Reuse barycentrics for motion vector calculations
- Not implementing this for 0.14, but long term I aim to remove explicit
vertex tangents and calculate them in the shader on the fly.

## Testing

- I tested out some of the normal maps we have in the repo. It didn't seem
to make a difference, but mikktspace is all about correctness across
various baking tools. I probably just didn't have any of the ones that
would cause it to break.
- Didn't test motion vectors as there's a known bug with the depth
buffer and meshlets that I'm waiting on the render graph rewrite to fix.
This commit is contained in:
JMS55 2024-06-10 13:18:43 -07:00 committed by François
parent 7cd90990f9
commit 93f48edbc3
No known key found for this signature in database
2 changed files with 50 additions and 35 deletions

View File

@ -13,8 +13,8 @@
unpack_meshlet_vertex, unpack_meshlet_vertex,
}, },
mesh_view_bindings::view, mesh_view_bindings::view,
mesh_functions::mesh_position_local_to_world, mesh_functions::{mesh_position_local_to_world, sign_determinant_model_3x3m},
mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT, mesh_types::{Mesh, MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT},
view_transformations::{position_world_to_clip, frag_coord_to_ndc}, view_transformations::{position_world_to_clip, frag_coord_to_ndc},
} }
#import bevy_render::maths::{affine3_to_square, mat2x4_f32_to_mat3x3_unpack} #import bevy_render::maths::{affine3_to_square, mat2x4_f32_to_mat3x3_unpack}
@ -99,6 +99,7 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let cluster_id = packed_ids >> 6u; let cluster_id = packed_ids >> 6u;
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let meshlet = meshlets[meshlet_id]; let meshlet = meshlets[meshlet_id];
let triangle_id = extractBits(packed_ids, 0u, 6u); let triangle_id = extractBits(packed_ids, 0u, 6u);
let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u); let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z)); let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z));
@ -108,9 +109,9 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);
let instance_id = meshlet_cluster_instance_ids[cluster_id]; let instance_id = meshlet_cluster_instance_ids[cluster_id];
let instance_uniform = meshlet_instance_uniforms[instance_id]; var instance_uniform = meshlet_instance_uniforms[instance_id];
let world_from_local = affine3_to_square(instance_uniform.world_from_local);
let world_from_local = affine3_to_square(instance_uniform.world_from_local);
let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0)); let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0)); let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0)); let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0));
@ -126,27 +127,19 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
); );
let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics; let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics;
let vertex_normal = mat3x3(vertex_1.normal, vertex_2.normal, vertex_3.normal) * partial_derivatives.barycentrics; let world_normal = mat3x3(
let world_normal = normalize( normal_local_to_world(vertex_1.normal, &instance_uniform),
mat2x4_f32_to_mat3x3_unpack( normal_local_to_world(vertex_2.normal, &instance_uniform),
instance_uniform.local_from_world_transpose_a, normal_local_to_world(vertex_3.normal, &instance_uniform),
instance_uniform.local_from_world_transpose_b, ) * partial_derivatives.barycentrics;
) * vertex_normal
);
let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics; let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics;
let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx; let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx;
let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy; let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy;
let vertex_tangent = mat3x4(vertex_1.tangent, vertex_2.tangent, vertex_3.tangent) * partial_derivatives.barycentrics; let world_tangent = mat3x4(
let world_tangent = vec4( tangent_local_to_world(vertex_1.tangent, world_from_local, instance_uniform.flags),
normalize( tangent_local_to_world(vertex_2.tangent, world_from_local, instance_uniform.flags),
mat3x3( tangent_local_to_world(vertex_3.tangent, world_from_local, instance_uniform.flags),
world_from_local[0].xyz, ) * partial_derivatives.barycentrics;
world_from_local[1].xyz,
world_from_local[2].xyz
) * vertex_tangent.xyz
),
vertex_tangent.w * (f32(bool(instance_uniform.flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0)
);
#ifdef PREPASS_FRAGMENT #ifdef PREPASS_FRAGMENT
#ifdef MOTION_VECTOR_PREPASS #ifdef MOTION_VECTOR_PREPASS
@ -154,15 +147,7 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0)); let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0));
let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0)); let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0));
let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0)); let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0));
let previous_clip_position_1 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_1.xyz, 1.0); let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * partial_derivatives.barycentrics;
let previous_clip_position_2 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_2.xyz, 1.0);
let previous_clip_position_3 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_3.xyz, 1.0);
let previous_partial_derivatives = compute_partial_derivatives(
array(previous_clip_position_1, previous_clip_position_2, previous_clip_position_3),
frag_coord_ndc,
view.viewport.zw,
);
let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * previous_partial_derivatives.barycentrics;
let motion_vector = calculate_motion_vector(world_position, previous_world_position); let motion_vector = calculate_motion_vector(world_position, previous_world_position);
#endif #endif
#endif #endif
@ -184,4 +169,34 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
#endif #endif
); );
} }
// Transforms a single vertex normal into world space and normalizes the result.
//
// The matrix applied is the 3x3 unpacked from the instance's
// `local_from_world_transpose_a` / `local_from_world_transpose_b` fields.
// A normal whose components are all zero is returned unchanged, so `normalize`
// is never called on a zero-length vector.
fn normal_local_to_world(vertex_normal: vec3<f32>, instance_uniform: ptr<function, Mesh>) -> vec3<f32> {
    // Guard: nothing to transform when every component is exactly zero.
    if !any(vertex_normal != vec3<f32>(0.0)) {
        return vertex_normal;
    }

    let normal_matrix = mat2x4_f32_to_mat3x3_unpack(
        (*instance_uniform).local_from_world_transpose_a,
        (*instance_uniform).local_from_world_transpose_b,
    );
    return normalize(normal_matrix * vertex_normal);
}
// Transforms a single vertex tangent into world space.
//
// The xyz direction is multiplied by the upper-left 3x3 of `world_from_local`
// and normalized. The w component is multiplied by the value returned from
// `sign_determinant_model_3x3m(mesh_flags)`. A tangent whose four components
// are all zero is returned unchanged.
fn tangent_local_to_world(vertex_tangent: vec4<f32>, world_from_local: mat4x4<f32>, mesh_flags: u32) -> vec4<f32> {
    // Guard: pass an all-zero tangent straight through instead of normalizing it.
    if !any(vertex_tangent != vec4<f32>(0.0)) {
        return vertex_tangent;
    }

    // Upper-left 3x3 of the model matrix (translation column dropped).
    let linear_part = mat3x3<f32>(
        world_from_local[0].xyz,
        world_from_local[1].xyz,
        world_from_local[2].xyz,
    );
    let world_direction = normalize(linear_part * vertex_tangent.xyz);
    let signed_w = vertex_tangent.w * sign_determinant_model_3x3m(mesh_flags);
    return vec4<f32>(world_direction, signed_w);
}
#endif #endif

View File

@ -55,11 +55,11 @@ fn mesh_normal_local_to_world(vertex_normal: vec3<f32>, instance_index: u32) ->
// Calculates the sign of the determinant of the 3x3 model matrix based on a // Calculates the sign of the determinant of the 3x3 model matrix based on a
// mesh flag // mesh flag
fn sign_determinant_model_3x3m(instance_index: u32) -> f32 { fn sign_determinant_model_3x3m(mesh_flags: u32) -> f32 {
// bool(u32) is false if 0u else true // bool(u32) is false if 0u else true
// f32(bool) is 1.0 if true else 0.0 // f32(bool) is 1.0 if true else 0.0
// * 2.0 - 1.0 remaps 0.0 or 1.0 to -1.0 or 1.0 respectively // * 2.0 - 1.0 remaps 0.0 or 1.0 to -1.0 or 1.0 respectively
return f32(bool(mesh[instance_index].flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0; return f32(bool(mesh_flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0;
} }
fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: vec4<f32>, instance_index: u32) -> vec4<f32> { fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: vec4<f32>, instance_index: u32) -> vec4<f32> {
@ -76,12 +76,12 @@ fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: ve
mat3x3<f32>( mat3x3<f32>(
world_from_local[0].xyz, world_from_local[0].xyz,
world_from_local[1].xyz, world_from_local[1].xyz,
world_from_local[2].xyz world_from_local[2].xyz,
) * vertex_tangent.xyz ) * vertex_tangent.xyz
), ),
// NOTE: Multiplying by the sign of the determinant of the 3x3 model matrix accounts for // NOTE: Multiplying by the sign of the determinant of the 3x3 model matrix accounts for
// situations such as negative scaling. // situations such as negative scaling.
vertex_tangent.w * sign_determinant_model_3x3m(instance_index) vertex_tangent.w * sign_determinant_model_3x3m(mesh[instance_index].flags)
); );
} else { } else {
return vertex_tangent; return vertex_tangent;