Fix size_hint for partially consumed QueryIter and QueryCombinationIter (#5214)
# Objective Fix #5149 ## Solution Instead of returning the **total count** of elements in the `QueryIter` in `size_hint`, we return the **count of remaining elements**. This Fixes #5149 even when #5148 gets merged. - https://github.com/bevyengine/bevy/issues/5149 - https://github.com/bevyengine/bevy/pull/5148 --- ## Changelog - Fix partially consumed `QueryIter` and `QueryCombinationIter` having invalid `size_hint` Co-authored-by: Nicola Papale <nicopap@users.noreply.github.com>
This commit is contained in:
parent
e0c3c6d166
commit
15ea93a348
@ -56,13 +56,7 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Iterator for QueryIter<'w, 's
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
let max_size = self
|
||||
.query_state
|
||||
.matched_archetype_ids
|
||||
.iter()
|
||||
.map(|id| self.archetypes[*id].len())
|
||||
.sum();
|
||||
|
||||
let max_size = self.cursor.max_remaining(self.tables, self.archetypes);
|
||||
let archetype_query = Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL;
|
||||
let min_size = if archetype_query { max_size } else { 0 };
|
||||
(min_size, Some(max_size))
|
||||
@ -351,11 +345,16 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery, const K: usize>
|
||||
return None;
|
||||
}
|
||||
|
||||
// first, iterate from last to first until next item is found
|
||||
// PERF: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
|
||||
// when Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL
|
||||
//
|
||||
// let `i` be the index of `c`, the last cursor in `self.cursors` that
|
||||
// returns `K-i` or more elements.
|
||||
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
|
||||
// If no such `c` exists, return `None`
|
||||
'outer: for i in (0..K).rev() {
|
||||
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
|
||||
Some(_) => {
|
||||
// walk forward up to last element, propagating cursor state forward
|
||||
for j in (i + 1)..K {
|
||||
self.cursors[j] = self.cursors[j - 1].clone_cursor();
|
||||
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
|
||||
@ -409,36 +408,29 @@ impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> Itera
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
if K == 0 {
|
||||
return (0, Some(0));
|
||||
}
|
||||
|
||||
let max_size: usize = self
|
||||
.query_state
|
||||
.matched_archetype_ids
|
||||
.iter()
|
||||
.map(|id| self.archetypes[*id].len())
|
||||
.sum();
|
||||
|
||||
if max_size < K {
|
||||
return (0, Some(0));
|
||||
}
|
||||
if max_size == K {
|
||||
return (1, Some(1));
|
||||
}
|
||||
|
||||
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
|
||||
// See https://en.wikipedia.org/wiki/Binomial_coefficient
|
||||
// See https://blog.plover.com/math/choose.html for implementation
|
||||
// It was chosen to reduce overflow potential.
|
||||
fn choose(n: usize, k: usize) -> Option<usize> {
|
||||
if k > n || n == 0 {
|
||||
return Some(0);
|
||||
}
|
||||
let k = k.min(n - k);
|
||||
let ks = 1..=k;
|
||||
let ns = (n - k + 1..=n).rev();
|
||||
ks.zip(ns)
|
||||
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
|
||||
}
|
||||
let smallest = K.min(max_size - K);
|
||||
let max_combinations = choose(max_size, smallest);
|
||||
// sum_i=0..k choose(cursors[i].remaining, k-i)
|
||||
let max_combinations = self
|
||||
.cursors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.try_fold(0, |acc, (i, cursor)| {
|
||||
let n = cursor.max_remaining(self.tables, self.archetypes);
|
||||
Some(acc + choose(n, K - i)?)
|
||||
});
|
||||
|
||||
let archetype_query = F::IS_ARCHETYPAL && Q::IS_ARCHETYPAL;
|
||||
let known_max = max_combinations.unwrap_or(usize::MAX);
|
||||
@ -452,11 +444,7 @@ where
|
||||
F: ArchetypeFilter,
|
||||
{
|
||||
fn len(&self) -> usize {
|
||||
self.query_state
|
||||
.matched_archetype_ids
|
||||
.iter()
|
||||
.map(|id| self.archetypes[*id].len())
|
||||
.sum()
|
||||
self.size_hint().0
|
||||
}
|
||||
}
|
||||
|
||||
@ -571,6 +559,21 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryIterationCursor<'w, 's,
|
||||
}
|
||||
}
|
||||
|
||||
/// How many values will this cursor return at most?
|
||||
///
|
||||
/// Note that if `Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL`, the return value
|
||||
/// will be **the exact count of remaining values**.
|
||||
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
|
||||
let remaining_matched: usize = if Self::IS_DENSE {
|
||||
let ids = self.table_id_iter.clone();
|
||||
ids.map(|id| tables[*id].entity_count()).sum()
|
||||
} else {
|
||||
let ids = self.archetype_id_iter.clone();
|
||||
ids.map(|id| archetypes[*id].len()).sum()
|
||||
};
|
||||
remaining_matched + self.current_len - self.current_index
|
||||
}
|
||||
|
||||
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
|
||||
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
|
||||
/// # Safety
|
||||
|
||||
@ -96,100 +96,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn query_filtered_exactsizeiterator_len() {
|
||||
fn assert_all_sizes_iterator_equal(
|
||||
iterator: impl ExactSizeIterator,
|
||||
expected_size: usize,
|
||||
query_type: &'static str,
|
||||
) {
|
||||
let len = iterator.len();
|
||||
let size_hint_0 = iterator.size_hint().0;
|
||||
let size_hint_1 = iterator.size_hint().1;
|
||||
// `count` tests that not only it is the expected value, but also
|
||||
// the value is accurate to what the query returns.
|
||||
let count = iterator.count();
|
||||
// This will show up when one of the asserts in this function fails
|
||||
println!(
|
||||
r#"query declared sizes:
|
||||
for query: {query_type}
|
||||
expected: {expected_size}
|
||||
len: {len}
|
||||
size_hint().0: {size_hint_0}
|
||||
size_hint().1: {size_hint_1:?}
|
||||
count(): {count}"#
|
||||
);
|
||||
assert_eq!(len, expected_size);
|
||||
assert_eq!(size_hint_0, expected_size);
|
||||
assert_eq!(size_hint_1, Some(expected_size));
|
||||
assert_eq!(count, expected_size);
|
||||
}
|
||||
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
|
||||
where
|
||||
Q: ReadOnlyWorldQuery,
|
||||
F: ReadOnlyWorldQuery,
|
||||
F::ReadOnly: ArchetypeFilter,
|
||||
{
|
||||
let mut query = world.query_filtered::<Q, F>();
|
||||
let iter = query.iter(world);
|
||||
let query_type = type_name::<QueryState<Q, F>>();
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
|
||||
}
|
||||
|
||||
let mut world = World::new();
|
||||
world.spawn((A(1), B(1)));
|
||||
world.spawn(A(2));
|
||||
world.spawn(A(3));
|
||||
|
||||
assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
|
||||
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);
|
||||
|
||||
let mut world = World::new();
|
||||
world.spawn((A(1), B(1), C(1)));
|
||||
world.spawn((A(2), B(2)));
|
||||
world.spawn((A(3), B(3)));
|
||||
world.spawn((A(4), C(4)));
|
||||
world.spawn((A(5), C(5)));
|
||||
world.spawn((A(6), C(6)));
|
||||
world.spawn(A(7));
|
||||
world.spawn(A(8));
|
||||
world.spawn(A(9));
|
||||
world.spawn(A(10));
|
||||
|
||||
// With/Without for B and C
|
||||
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
|
||||
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
|
||||
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
|
||||
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);
|
||||
|
||||
// With/Without (And) combinations
|
||||
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
|
||||
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
|
||||
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
|
||||
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);
|
||||
|
||||
// With/Without Or<()> combinations
|
||||
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
|
||||
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
|
||||
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
|
||||
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
|
||||
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
|
||||
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);
|
||||
|
||||
for i in 11..14 {
|
||||
world.spawn((A(i), D(i)));
|
||||
}
|
||||
|
||||
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
|
||||
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);
|
||||
|
||||
// a fair amount of entities
|
||||
for i in 14..20 {
|
||||
world.spawn((C(i), D(i)));
|
||||
}
|
||||
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_filtered_combination_size() {
|
||||
fn choose(n: usize, k: usize) -> usize {
|
||||
if n == 0 || k == 0 || n < k {
|
||||
return 0;
|
||||
@ -200,25 +106,30 @@ mod tests {
|
||||
}
|
||||
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
|
||||
where
|
||||
Q: WorldQuery,
|
||||
Q: ReadOnlyWorldQuery,
|
||||
F: ReadOnlyWorldQuery,
|
||||
F::ReadOnly: ArchetypeFilter,
|
||||
{
|
||||
let mut query = world.query_filtered::<Q, F>();
|
||||
let iter = query.iter_combinations::<K>(world);
|
||||
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
|
||||
let iter = query.iter_combinations::<K>(world);
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type);
|
||||
let iter = query.iter_combinations::<K>(world);
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type);
|
||||
let iter = query.iter_combinations::<K>(world);
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type);
|
||||
}
|
||||
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
|
||||
where
|
||||
Q: WorldQuery,
|
||||
Q: ReadOnlyWorldQuery,
|
||||
F: ReadOnlyWorldQuery,
|
||||
F::ReadOnly: ArchetypeFilter,
|
||||
{
|
||||
let mut query = world.query_filtered::<Q, F>();
|
||||
let iter = query.iter(world);
|
||||
let query_type = type_name::<QueryState<Q, F>>();
|
||||
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
|
||||
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type);
|
||||
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type);
|
||||
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type);
|
||||
|
||||
let expected = expected_size;
|
||||
assert_combination::<Q, F, 0>(world, choose(expected, 0));
|
||||
@ -226,13 +137,29 @@ mod tests {
|
||||
assert_combination::<Q, F, 2>(world, choose(expected, 2));
|
||||
assert_combination::<Q, F, 5>(world, choose(expected, 5));
|
||||
assert_combination::<Q, F, 43>(world, choose(expected, 43));
|
||||
assert_combination::<Q, F, 128>(world, choose(expected, 128));
|
||||
assert_combination::<Q, F, 64>(world, choose(expected, 64));
|
||||
}
|
||||
fn assert_all_sizes_iterator_equal(
|
||||
iterator: impl Iterator,
|
||||
fn assert_all_exact_sizes_iterator_equal(
|
||||
iterator: impl ExactSizeIterator,
|
||||
expected_size: usize,
|
||||
skip: usize,
|
||||
query_type: &'static str,
|
||||
) {
|
||||
let len = iterator.len();
|
||||
println!("len: {len}");
|
||||
assert_all_sizes_iterator_equal(iterator, expected_size, skip, query_type);
|
||||
assert_eq!(len, expected_size);
|
||||
}
|
||||
fn assert_all_sizes_iterator_equal(
|
||||
mut iterator: impl Iterator,
|
||||
expected_size: usize,
|
||||
skip: usize,
|
||||
query_type: &'static str,
|
||||
) {
|
||||
let expected_size = expected_size.saturating_sub(skip);
|
||||
for _ in 0..skip {
|
||||
iterator.next();
|
||||
}
|
||||
let size_hint_0 = iterator.size_hint().0;
|
||||
let size_hint_1 = iterator.size_hint().1;
|
||||
// `count` tests that not only it is the expected value, but also
|
||||
@ -240,12 +167,12 @@ mod tests {
|
||||
let count = iterator.count();
|
||||
// This will show up when one of the asserts in this function fails
|
||||
println!(
|
||||
r#"query declared sizes:
|
||||
for query: {query_type}
|
||||
expected: {expected_size}
|
||||
size_hint().0: {size_hint_0}
|
||||
size_hint().1: {size_hint_1:?}
|
||||
count(): {count}"#
|
||||
"query declared sizes: \n\
|
||||
for query: {query_type} \n\
|
||||
expected: {expected_size} \n\
|
||||
size_hint().0: {size_hint_0} \n\
|
||||
size_hint().1: {size_hint_1:?} \n\
|
||||
count(): {count}"
|
||||
);
|
||||
assert_eq!(size_hint_0, expected_size);
|
||||
assert_eq!(size_hint_1, Some(expected_size));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user