Skip to content

Commit

Permalink
Fix size_hint for partially consumed QueryIter and QueryCombinationIt…
Browse files Browse the repository at this point in the history
…er (bevyengine#5214)

# Objective

Fix bevyengine#5149

## Solution

Instead of returning the **total count** of elements in the `QueryIter` in
`size_hint`, we return the **count of remaining elements**. This
Fixes bevyengine#5149 even when bevyengine#5148 gets merged.

- bevyengine#5149
- bevyengine#5148

---

## Changelog

- Fix partially consumed `QueryIter` and `QueryCombinationIter` having invalid `size_hint`


Co-authored-by: Nicola Papale <nicopap@users.noreply.github.com>
  • Loading branch information
2 people authored and ItsDoot committed Feb 1, 2023
1 parent 407d428 commit 1321cba
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 142 deletions.
71 changes: 37 additions & 34 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,7 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Iterator for QueryIter<'w, 's
}

fn size_hint(&self) -> (usize, Option<usize>) {
let max_size = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

let max_size = self.cursor.max_remaining(self.tables, self.archetypes);
let archetype_query = Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL;
let min_size = if archetype_query { max_size } else { 0 };
(min_size, Some(max_size))
Expand Down Expand Up @@ -351,11 +345,16 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery, const K: usize>
return None;
}

// first, iterate from last to first until next item is found
// PERF: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
// when Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL
//
// let `i` be the index of `c`, the last cursor in `self.cursors` that
// returns `K-i` or more elements.
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
// If no such `c` exists, return `None`
'outer: for i in (0..K).rev() {
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
Some(_) => {
// walk forward up to last element, propagating cursor state forward
for j in (i + 1)..K {
self.cursors[j] = self.cursors[j - 1].clone_cursor();
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
Expand Down Expand Up @@ -409,36 +408,29 @@ impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> Itera
}

fn size_hint(&self) -> (usize, Option<usize>) {
if K == 0 {
return (0, Some(0));
}

let max_size: usize = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

if max_size < K {
return (0, Some(0));
}
if max_size == K {
return (1, Some(1));
}

// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
fn choose(n: usize, k: usize) -> Option<usize> {
if k > n || n == 0 {
return Some(0);
}
let k = k.min(n - k);
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}
let smallest = K.min(max_size - K);
let max_combinations = choose(max_size, smallest);
// sum_i=0..k choose(cursors[i].remaining, k-i)
let max_combinations = self
.cursors
.iter()
.enumerate()
.try_fold(0, |acc, (i, cursor)| {
let n = cursor.max_remaining(self.tables, self.archetypes);
Some(acc + choose(n, K - i)?)
});

let archetype_query = F::IS_ARCHETYPAL && Q::IS_ARCHETYPAL;
let known_max = max_combinations.unwrap_or(usize::MAX);
Expand All @@ -452,11 +444,7 @@ where
F: ArchetypeFilter,
{
fn len(&self) -> usize {
self.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum()
self.size_hint().0
}
}

Expand Down Expand Up @@ -571,6 +559,21 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryIterationCursor<'w, 's,
}
}

/// How many values will this cursor return at most?
///
/// Note that if `Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL`, the return value
/// will be **the exact count of remaining values**.
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let remaining_matched: usize = if Self::IS_DENSE {
let ids = self.table_id_iter.clone();
ids.map(|id| tables[*id].entity_count()).sum()
} else {
let ids = self.archetype_id_iter.clone();
ids.map(|id| archetypes[*id].len()).sum()
};
remaining_matched + self.current_len - self.current_index
}

// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
/// # Safety
Expand Down
143 changes: 35 additions & 108 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,100 +96,6 @@ mod tests {

#[test]
fn query_filtered_exactsizeiterator_len() {
fn assert_all_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
query_type: &'static str,
) {
let len = iterator.len();
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
len: {len}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
);
assert_eq!(len, expected_size);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
assert_eq!(count, expected_size);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
}

let mut world = World::new();
world.spawn((A(1), B(1)));
world.spawn(A(2));
world.spawn(A(3));

assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);

let mut world = World::new();
world.spawn((A(1), B(1), C(1)));
world.spawn((A(2), B(2)));
world.spawn((A(3), B(3)));
world.spawn((A(4), C(4)));
world.spawn((A(5), C(5)));
world.spawn((A(6), C(6)));
world.spawn(A(7));
world.spawn(A(8));
world.spawn(A(9));
world.spawn(A(10));

// With/Without for B and C
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);

// With/Without (And) combinations
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);

// With/Without Or<()> combinations
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);

for i in 11..14 {
world.spawn((A(i), D(i)));
}

assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);

// a fair amount of entities
for i in 14..20 {
world.spawn((C(i), D(i)));
}
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
}

#[test]
fn query_filtered_combination_size() {
fn choose(n: usize, k: usize) -> usize {
if n == 0 || k == 0 || n < k {
return 0;
Expand All @@ -200,52 +106,73 @@ mod tests {
}
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
where
Q: WorldQuery,
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter_combinations::<K>(world);
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: WorldQuery,
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type);

let expected = expected_size;
assert_combination::<Q, F, 0>(world, choose(expected, 0));
assert_combination::<Q, F, 1>(world, choose(expected, 1));
assert_combination::<Q, F, 2>(world, choose(expected, 2));
assert_combination::<Q, F, 5>(world, choose(expected, 5));
assert_combination::<Q, F, 43>(world, choose(expected, 43));
assert_combination::<Q, F, 128>(world, choose(expected, 128));
assert_combination::<Q, F, 64>(world, choose(expected, 64));
}
fn assert_all_exact_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
skip: usize,
query_type: &'static str,
) {
let len = iterator.len();
println!("len: {len}");
assert_all_sizes_iterator_equal(iterator, expected_size, skip, query_type);
assert_eq!(len, expected_size);
}
fn assert_all_sizes_iterator_equal(
iterator: impl Iterator,
mut iterator: impl Iterator,
expected_size: usize,
skip: usize,
query_type: &'static str,
) {
let expected_size = expected_size.saturating_sub(skip);
for _ in 0..skip {
iterator.next();
}
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
"query declared sizes: \n\
for query: {query_type} \n\
expected: {expected_size} \n\
size_hint().0: {size_hint_0} \n\
size_hint().1: {size_hint_1:?} \n\
count(): {count}"
);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
Expand Down

0 comments on commit 1321cba

Please sign in to comment.