(Adoped) Remove panics and optimise mesh picking (#18232)

_Note from BD103: this PR was adopted from #16148. The majority of this
PR's description is copied from the original._

# Objective

Adds tests to cover various mesh picking cases and removes sources of
panics.

It should prevent users being able to trigger panics in `bevy_picking`
code via bad mesh data such as #15891, and is a follow up to my comments
in [#15800
(review)](https://github.com/bevyengine/bevy/pull/15800#pullrequestreview-2361694213).

This is motivated by #15979

## Testing

Adds 8 new tests to cover `ray_mesh_intersection` code.

## Changes from original PR

I reverted the changes to the benchmarks, since that was the largest
factor blocking it merging. I'll open a follow-up issue so that those
benchmark changes can be implemented.

---------

Co-authored-by: Trent <2771466+tbillington@users.noreply.github.com>
This commit is contained in:
BD103 2025-03-10 17:55:40 -04:00 committed by GitHub
parent 690858166c
commit e24191dd89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -66,160 +66,134 @@ pub fn ray_mesh_intersection<I: TryInto<usize> + Clone + Copy>(
indices: Option<&[I]>,
backface_culling: Backfaces,
) -> Option<RayMeshHit> {
// The ray cast can hit the same mesh many times, so we need to track which hit is
// closest to the camera, and record that.
let mut closest_hit_distance = f32::MAX;
let mut closest_hit = None;
let world_to_mesh = mesh_transform.inverse();
let mesh_space_ray = Ray3d::new(
let ray = Ray3d::new(
world_to_mesh.transform_point3(ray.origin),
Dir3::new(world_to_mesh.transform_vector3(*ray.direction)).ok()?,
);
if let Some(indices) = indices {
let closest_hit = if let Some(indices) = indices {
// The index list must be a multiple of three. If not, the mesh is malformed and the raycast
// result might be nonsensical.
if indices.len() % 3 != 0 {
return None;
}
for triangle in indices.chunks_exact(3) {
let [a, b, c] = [
triangle[0].try_into().ok()?,
triangle[1].try_into().ok()?,
triangle[2].try_into().ok()?,
];
indices
.chunks_exact(3)
.fold(
(f32::MAX, None),
|(closest_distance, closest_hit), triangle| {
let [Ok(a), Ok(b), Ok(c)] = [
triangle[0].try_into(),
triangle[1].try_into(),
triangle[2].try_into(),
] else {
return (closest_distance, closest_hit);
};
let triangle_index = Some(a);
let tri_vertex_positions = &[
Vec3::from(positions[a]),
Vec3::from(positions[b]),
Vec3::from(positions[c]),
];
let tri_normals = vertex_normals.map(|normals| {
[
Vec3::from(normals[a]),
Vec3::from(normals[b]),
Vec3::from(normals[c]),
]
});
let tri_vertices = match [positions.get(a), positions.get(b), positions.get(c)]
{
[Some(a), Some(b), Some(c)] => {
[Vec3::from(*a), Vec3::from(*b), Vec3::from(*c)]
}
_ => return (closest_distance, closest_hit),
};
let Some(hit) = triangle_intersection(
tri_vertex_positions,
tri_normals.as_ref(),
closest_hit_distance,
&mesh_space_ray,
backface_culling,
) else {
continue;
};
closest_hit = Some(RayMeshHit {
point: mesh_transform.transform_point3(hit.point),
normal: mesh_transform.transform_vector3(hit.normal),
barycentric_coords: hit.barycentric_coords,
distance: mesh_transform
.transform_vector3(mesh_space_ray.direction * hit.distance)
.length(),
triangle: hit.triangle.map(|tri| {
[
mesh_transform.transform_point3(tri[0]),
mesh_transform.transform_point3(tri[1]),
mesh_transform.transform_point3(tri[2]),
]
}),
triangle_index,
});
closest_hit_distance = hit.distance;
}
match ray_triangle_intersection(&ray, &tri_vertices, backface_culling) {
Some(hit) if hit.distance >= 0. && hit.distance < closest_distance => {
(hit.distance, Some((a, hit)))
}
_ => (closest_distance, closest_hit),
}
},
)
.1
} else {
for (i, triangle) in positions.chunks_exact(3).enumerate() {
let &[a, b, c] = triangle else {
continue;
};
let triangle_index = Some(i);
let tri_vertex_positions = &[Vec3::from(a), Vec3::from(b), Vec3::from(c)];
let tri_normals = vertex_normals.map(|normals| {
[
Vec3::from(normals[i]),
Vec3::from(normals[i + 1]),
Vec3::from(normals[i + 2]),
]
});
positions
.chunks_exact(3)
.enumerate()
.fold(
(f32::MAX, None),
|(closest_distance, closest_hit), (tri_idx, triangle)| {
let tri_vertices = [
Vec3::from(triangle[0]),
Vec3::from(triangle[1]),
Vec3::from(triangle[2]),
];
let Some(hit) = triangle_intersection(
tri_vertex_positions,
tri_normals.as_ref(),
closest_hit_distance,
&mesh_space_ray,
backface_culling,
) else {
continue;
};
closest_hit = Some(RayMeshHit {
point: mesh_transform.transform_point3(hit.point),
normal: mesh_transform.transform_vector3(hit.normal),
barycentric_coords: hit.barycentric_coords,
distance: mesh_transform
.transform_vector3(mesh_space_ray.direction * hit.distance)
.length(),
triangle: hit.triangle.map(|tri| {
[
mesh_transform.transform_point3(tri[0]),
mesh_transform.transform_point3(tri[1]),
mesh_transform.transform_point3(tri[2]),
]
}),
triangle_index,
});
closest_hit_distance = hit.distance;
}
}
closest_hit
}
fn triangle_intersection(
tri_vertices: &[Vec3; 3],
tri_normals: Option<&[Vec3; 3]>,
max_distance: f32,
ray: &Ray3d,
backface_culling: Backfaces,
) -> Option<RayMeshHit> {
let hit = ray_triangle_intersection(ray, tri_vertices, backface_culling)?;
if hit.distance < 0.0 || hit.distance > max_distance {
return None;
match ray_triangle_intersection(&ray, &tri_vertices, backface_culling) {
Some(hit) if hit.distance >= 0. && hit.distance < closest_distance => {
(hit.distance, Some((tri_idx, hit)))
}
_ => (closest_distance, closest_hit),
}
},
)
.1
};
let point = ray.get_point(hit.distance);
let u = hit.barycentric_coords.0;
let v = hit.barycentric_coords.1;
let w = 1.0 - u - v;
let barycentric = Vec3::new(u, v, w);
closest_hit.and_then(|(tri_idx, hit)| {
let [a, b, c] = match indices {
Some(indices) => {
let triangle = indices.get((tri_idx * 3)..(tri_idx * 3 + 3))?;
let normal = if let Some(normals) = tri_normals {
normals[1] * u + normals[2] * v + normals[0] * w
} else {
(tri_vertices[1] - tri_vertices[0])
.cross(tri_vertices[2] - tri_vertices[0])
.normalize()
};
let [Ok(a), Ok(b), Ok(c)] = [
triangle[0].try_into(),
triangle[1].try_into(),
triangle[2].try_into(),
] else {
return None;
};
Some(RayMeshHit {
point,
normal,
barycentric_coords: barycentric,
distance: hit.distance,
triangle: Some(*tri_vertices),
triangle_index: None,
[a, b, c]
}
None => [tri_idx * 3, tri_idx * 3 + 1, tri_idx * 3 + 2],
};
let tri_vertices = match [positions.get(a), positions.get(b), positions.get(c)] {
[Some(a), Some(b), Some(c)] => [Vec3::from(*a), Vec3::from(*b), Vec3::from(*c)],
_ => return None,
};
let tri_normals = vertex_normals.and_then(|normals| {
let [Some(a), Some(b), Some(c)] = [normals.get(a), normals.get(b), normals.get(c)]
else {
return None;
};
Some([Vec3::from(*a), Vec3::from(*b), Vec3::from(*c)])
});
let point = ray.get_point(hit.distance);
let u = hit.barycentric_coords.0;
let v = hit.barycentric_coords.1;
let w = 1.0 - u - v;
let barycentric = Vec3::new(u, v, w);
let normal = if let Some(normals) = tri_normals {
normals[1] * u + normals[2] * v + normals[0] * w
} else {
(tri_vertices[1] - tri_vertices[0])
.cross(tri_vertices[2] - tri_vertices[0])
.normalize()
};
Some(RayMeshHit {
point: mesh_transform.transform_point3(point),
normal: mesh_transform.transform_vector3(normal),
barycentric_coords: barycentric,
distance: mesh_transform
.transform_vector3(ray.direction * hit.distance)
.length(),
triangle: Some(tri_vertices.map(|v| mesh_transform.transform_point3(v))),
triangle_index: Some(a),
})
})
}
/// Takes a ray and triangle and computes the intersection.
#[inline]
fn ray_triangle_intersection(
ray: &Ray3d,
triangle: &[Vec3; 3],
@ -313,6 +287,7 @@ pub fn ray_aabb_intersection_3d(ray: Ray3d, aabb: &Aabb3d, model_to_world: &Mat4
#[cfg(test)]
mod tests {
use bevy_math::Vec3;
use bevy_transform::components::GlobalTransform;
use super::*;
@ -336,4 +311,174 @@ mod tests {
let result = ray_triangle_intersection(&ray, &triangle, Backfaces::Cull);
assert!(result.is_none());
}
#[test]
fn ray_mesh_intersection_simple() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals = None;
let indices: Option<&[u16]> = None;
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_indices() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals = None;
let indices: Option<&[u16]> = Some(&[0, 1, 2]);
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_indices_vertex_normals() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals: Option<&[[f32; 3]]> =
Some(&[[-1., 0., 0.], [-1., 0., 0.], [-1., 0., 0.]]);
let indices: Option<&[u16]> = Some(&[0, 1, 2]);
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_vertex_normals() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals: Option<&[[f32; 3]]> =
Some(&[[-1., 0., 0.], [-1., 0., 0.], [-1., 0., 0.]]);
let indices: Option<&[u16]> = None;
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_missing_vertex_normals() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals: Option<&[[f32; 3]]> = Some(&[]);
let indices: Option<&[u16]> = None;
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_indices_missing_vertex_normals() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals: Option<&[[f32; 3]]> = Some(&[]);
let indices: Option<&[u16]> = Some(&[0, 1, 2]);
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_some());
}
#[test]
fn ray_mesh_intersection_not_enough_indices() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals = None;
let indices: Option<&[u16]> = Some(&[0]);
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_none());
}
#[test]
fn ray_mesh_intersection_bad_indices() {
let ray = Ray3d::new(Vec3::ZERO, Dir3::X);
let mesh_transform = GlobalTransform::IDENTITY.compute_matrix();
let positions = &[V0, V1, V2];
let vertex_normals = None;
let indices: Option<&[u16]> = Some(&[0, 1, 3]);
let backface_culling = Backfaces::Cull;
let result = ray_mesh_intersection(
ray,
&mesh_transform,
positions,
vertex_normals,
indices,
backface_culling,
);
assert!(result.is_none());
}
}