Skip to content

Commit

Permalink
Auto merge of rust-lang#104818 - scottmcm:refactor-extend-func, r=the…
Browse files Browse the repository at this point in the history
…8472

Stop peeling the last iteration of the loop in `Vec::resize_with`

`resize_with` uses the `ExtendWith` code that peels the last iteration:
https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2525-L2529

But that's kinda weird for `ExtendFunc` because it does the same thing on the last iteration anyway:
https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2494-L2502

So this just has it use the normal `extend`-from-`TrustedLen` code instead.

r? `@ghost`
  • Loading branch information
bors committed Nov 27, 2022
2 parents c0e9c86 + 9d68a1a commit faf1891
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 44 deletions.
46 changes: 35 additions & 11 deletions library/alloc/src/vec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@ impl<T, A: Allocator> Vec<T, A> {
{
let len = self.len();
if new_len > len {
self.extend_with(new_len - len, ExtendFunc(f));
self.extend_trusted(iter::repeat_with(f).take(new_len - len));
} else {
self.truncate(new_len);
}
Expand Down Expand Up @@ -2491,16 +2491,6 @@ impl<T: Clone> ExtendWith<T> for ExtendElement<T> {
}
}

struct ExtendFunc<F>(F);
impl<T, F: FnMut() -> T> ExtendWith<T> for ExtendFunc<F> {
fn next(&mut self) -> T {
(self.0)()
}
fn last(mut self) -> T {
(self.0)()
}
}

impl<T, A: Allocator> Vec<T, A> {
#[cfg(not(no_global_oom_handling))]
/// Extend the vector by `n` values, using the given generator.
Expand Down Expand Up @@ -2870,6 +2860,40 @@ impl<T, A: Allocator> Vec<T, A> {
}
}

// specific extend for `TrustedLen` iterators, called both by the specializations
// and internal places where resolving specialization makes compilation slower
#[cfg(not(no_global_oom_handling))]
fn extend_trusted(&mut self, iterator: impl iter::TrustedLen<Item = T>) {
let (low, high) = iterator.size_hint();
if let Some(additional) = high {
debug_assert_eq!(
low,
additional,
"TrustedLen iterator's size hint is not exact: {:?}",
(low, high)
);
self.reserve(additional);
unsafe {
let ptr = self.as_mut_ptr();
let mut local_len = SetLenOnDrop::new(&mut self.len);
iterator.for_each(move |element| {
ptr::write(ptr.add(local_len.current_len()), element);
// Since the loop executes user code which can panic we have to update
// the length every step to correctly drop what we've written.
// NB can't overflow since we would have had to alloc the address space
local_len.increment_len(1);
});
}
} else {
// Per TrustedLen contract a `None` upper bound means that the iterator length
// truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
// Since the other branch already panics eagerly (via `reserve()`) we do the same here.
// This avoids additional codegen for a fallback code path which would eventually
// panic anyway.
panic!("capacity overflow");
}
}

/// Creates a splicing iterator that replaces the specified range in the vector
/// with the given `replace_with` iterator and yields the removed items.
/// `replace_with` does not need to be the same length as `range`.
Expand Down
5 changes: 5 additions & 0 deletions library/alloc/src/vec/set_len_on_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> {
pub(super) fn increment_len(&mut self, increment: usize) {
self.local_len += increment;
}

#[inline]
pub(super) fn current_len(&self) -> usize {
self.local_len
}
}

impl Drop for SetLenOnDrop<'_> {
Expand Down
34 changes: 2 additions & 32 deletions library/alloc/src/vec/spec_extend.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::alloc::Allocator;
use core::iter::TrustedLen;
use core::ptr::{self};
use core::slice::{self};

use super::{IntoIter, SetLenOnDrop, Vec};
use super::{IntoIter, Vec};

// Specialization trait used for Vec::extend
pub(super) trait SpecExtend<T, I> {
Expand All @@ -24,36 +23,7 @@ where
I: TrustedLen<Item = T>,
{
default fn spec_extend(&mut self, iterator: I) {
// This is the case for a TrustedLen iterator.
let (low, high) = iterator.size_hint();
if let Some(additional) = high {
debug_assert_eq!(
low,
additional,
"TrustedLen iterator's size hint is not exact: {:?}",
(low, high)
);
self.reserve(additional);
unsafe {
let mut ptr = self.as_mut_ptr().add(self.len());
let mut local_len = SetLenOnDrop::new(&mut self.len);
iterator.for_each(move |element| {
ptr::write(ptr, element);
ptr = ptr.add(1);
// Since the loop executes user code which can panic we have to bump the pointer
// after each step.
// NB can't overflow since we would have had to alloc the address space
local_len.increment_len(1);
});
}
} else {
// Per TrustedLen contract a `None` upper bound means that the iterator length
// truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
// Since the other branch already panics eagerly (via `reserve()`) we do the same here.
// This avoids additional codegen for a fallback code path which would eventually
// panic anyway.
panic!("capacity overflow");
}
self.extend_trusted(iterator)
}
}

Expand Down
21 changes: 20 additions & 1 deletion library/core/src/iter/adapters/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ where
#[inline]
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
R: Try<Output = Acc>,
{
Expand All @@ -100,6 +99,26 @@ where

impl_fold_via_try_fold! { fold -> try_fold }

#[inline]
fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
// The default implementation would use a unit accumulator, so we can
// avoid a stateful closure by folding over the remaining number
// of items we wish to return instead.
fn check<'a, Item>(
mut action: impl FnMut(Item) + 'a,
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
move |more, x| {
action(x);
more.checked_sub(1)
}
}

let remaining = self.n;
if remaining > 0 {
self.iter.try_fold(remaining - 1, check(f));
}
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
Expand Down
17 changes: 17 additions & 0 deletions library/core/src/iter/sources/repeat_with.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::iter::{FusedIterator, TrustedLen};
use crate::ops::Try;

/// Creates a new iterator that repeats elements of type `A` endlessly by
/// applying the provided closure, the repeater, `F: FnMut() -> A`.
Expand Down Expand Up @@ -89,6 +90,22 @@ impl<A, F: FnMut() -> A> Iterator for RepeatWith<F> {
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::MAX, None)
}

#[inline]
fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
where
Fold: FnMut(Acc, Self::Item) -> R,
R: Try<Output = Acc>,
{
// This override isn't strictly needed, but avoids the need to optimize
// away the `next`-always-returns-`Some` and emphasizes that the `?`
// is the only way to exit the loop.

loop {
let item = (self.repeater)();
init = fold(init, item)?;
}
}
}

#[stable(feature = "iterator_repeat_with", since = "1.28.0")]
Expand Down
20 changes: 20 additions & 0 deletions library/core/tests/iter/adapters/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,23 @@ fn test_take_try_folds() {
assert_eq!(iter.try_for_each(Err), Err(2));
assert_eq!(iter.try_for_each(Err), Ok(()));
}

#[test]
fn test_byref_take_consumed_items() {
let mut inner = 10..90;

let mut count = 0;
inner.by_ref().take(0).for_each(|_| count += 1);
assert_eq!(count, 0);
assert_eq!(inner, 10..90);

let mut count = 0;
inner.by_ref().take(10).for_each(|_| count += 1);
assert_eq!(count, 10);
assert_eq!(inner, 20..90);

let mut count = 0;
inner.by_ref().take(100).for_each(|_| count += 1);
assert_eq!(count, 70);
assert_eq!(inner, 90..90);
}
7 changes: 7 additions & 0 deletions src/test/codegen/repeat-trusted-len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec<u8> {
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false)
iter::repeat(42).take(100000).collect()
}

// CHECK-LABEL: @repeat_with_take_collect
#[no_mangle]
pub fn repeat_with_take_collect() -> Vec<u8> {
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false)
iter::repeat_with(|| 13).take(12345).collect()
}

0 comments on commit faf1891

Please sign in to comment.