Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Fix size_hint for partially consumed QueryIter and QueryCombinationIter #5214

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,7 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Iterator for QueryIter<'w, 's
}

fn size_hint(&self) -> (usize, Option<usize>) {
let max_size = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

let max_size = self.cursor.max_remaining(self.tables, self.archetypes);
let archetype_query = Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL;
let min_size = if archetype_query { max_size } else { 0 };
(min_size, Some(max_size))
Expand Down Expand Up @@ -348,11 +342,16 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery, const K: usize>
return None;
}

// first, iterate from last to first until next item is found
// PERF: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
// when Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL
//
// let `i` be the index of `c`, the last cursor in `self.cursors` that
// returns `K-i` or more elements.
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
// If no such `c` exists, return `None`
'outer: for i in (0..K).rev() {
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
Some(_) => {
// walk forward up to last element, propagating cursor state forward
for j in (i + 1)..K {
self.cursors[j] = self.cursors[j - 1].clone_cursor();
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
Expand Down Expand Up @@ -406,36 +405,29 @@ impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> Itera
}

fn size_hint(&self) -> (usize, Option<usize>) {
if K == 0 {
return (0, Some(0));
}

let max_size: usize = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

if max_size < K {
return (0, Some(0));
}
if max_size == K {
return (1, Some(1));
}

// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
fn choose(n: usize, k: usize) -> Option<usize> {
if k > n || n == 0 {
return Some(0);
}
let k = k.min(n - k);
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}
let smallest = K.min(max_size - K);
let max_combinations = choose(max_size, smallest);
// sum_i=0..k choose(cursors[i].remaining, k-i)
let max_combinations = self
.cursors
.iter()
.enumerate()
.try_fold(0, |acc, (i, cursor)| {
let n = cursor.max_remaining(self.tables, self.archetypes);
Some(acc + choose(n, K - i)?)
});

let archetype_query = F::IS_ARCHETYPAL && Q::IS_ARCHETYPAL;
let known_max = max_combinations.unwrap_or(usize::MAX);
Expand All @@ -449,11 +441,7 @@ where
F: ArchetypeFilter,
{
fn len(&self) -> usize {
self.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum()
self.size_hint().0
}
}

Expand Down Expand Up @@ -568,6 +556,21 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryIterationCursor<'w, 's,
}
}

/// How many values will this cursor return at most?
///
/// Note that if `Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL`, the return value
/// will be **the exact count of remaining values**.
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let remaining_matched: usize = if Self::IS_DENSE {
let ids = self.table_id_iter.clone();
ids.map(|id| tables[*id].entity_count()).sum()
} else {
let ids = self.archetype_id_iter.clone();
ids.map(|id| archetypes[*id].len()).sum()
};
nicopap marked this conversation as resolved.
Show resolved Hide resolved
remaining_matched + self.current_len - self.current_index
}

// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
/// # Safety
Expand Down
143 changes: 35 additions & 108 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,100 +57,6 @@ mod tests {

#[test]
fn query_filtered_exactsizeiterator_len() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this test was redundant with query_filtered_combination_size because we merged the two, and is why it's deleted.

fn assert_all_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
query_type: &'static str,
) {
let len = iterator.len();
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
len: {len}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
);
assert_eq!(len, expected_size);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
assert_eq!(count, expected_size);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
}

let mut world = World::new();
world.spawn((A(1), B(1)));
world.spawn(A(2));
world.spawn(A(3));

assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);

let mut world = World::new();
world.spawn((A(1), B(1), C(1)));
world.spawn((A(2), B(2)));
world.spawn((A(3), B(3)));
world.spawn((A(4), C(4)));
world.spawn((A(5), C(5)));
world.spawn((A(6), C(6)));
world.spawn(A(7));
world.spawn(A(8));
world.spawn(A(9));
world.spawn(A(10));

// With/Without for B and C
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);

// With/Without (And) combinations
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);

// With/Without Or<()> combinations
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);

for i in 11..14 {
world.spawn((A(i), D(i)));
}

assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);

// a fair amount of entities
for i in 14..20 {
world.spawn((C(i), D(i)));
}
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
}

#[test]
fn query_filtered_combination_size() {
fn choose(n: usize, k: usize) -> usize {
if n == 0 || k == 0 || n < k {
return 0;
Expand All @@ -161,52 +67,73 @@ mod tests {
}
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
where
Q: WorldQuery,
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter_combinations::<K>(world);
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: WorldQuery,
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery,
F::ReadOnly: ArchetypeFilter,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type);
assert_all_exact_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type);

let expected = expected_size;
assert_combination::<Q, F, 0>(world, choose(expected, 0));
assert_combination::<Q, F, 1>(world, choose(expected, 1));
assert_combination::<Q, F, 2>(world, choose(expected, 2));
assert_combination::<Q, F, 5>(world, choose(expected, 5));
assert_combination::<Q, F, 43>(world, choose(expected, 43));
assert_combination::<Q, F, 128>(world, choose(expected, 128));
assert_combination::<Q, F, 64>(world, choose(expected, 64));
}
fn assert_all_exact_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
skip: usize,
query_type: &'static str,
) {
let len = iterator.len();
println!("len: {len}");
assert_all_sizes_iterator_equal(iterator, expected_size, skip, query_type);
assert_eq!(len, expected_size);
}
fn assert_all_sizes_iterator_equal(
iterator: impl Iterator,
mut iterator: impl Iterator,
expected_size: usize,
skip: usize,
query_type: &'static str,
) {
let expected_size = expected_size.saturating_sub(skip);
for _ in 0..skip {
iterator.next();
}
nicopap marked this conversation as resolved.
Show resolved Hide resolved
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
"query declared sizes: \n\
for query: {query_type} \n\
expected: {expected_size} \n\
size_hint().0: {size_hint_0} \n\
size_hint().1: {size_hint_1:?} \n\
count(): {count}"
);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
Expand Down