Skip to content

Commit

Permalink
FlattenUnordered: always replace inner wakers (#2726)
Browse files Browse the repository at this point in the history
  • Loading branch information
olegnn authored and taiki-e committed Mar 30, 2023
1 parent 890f893 commit a730a19
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 12 deletions.
23 changes: 11 additions & 12 deletions futures-util/src/stream/stream/flatten_unordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ impl WrappedWaker {
///
/// This function will modify waker's `inner_waker` via `UnsafeCell`, so
/// it should be used only during `POLLING` phase by one thread at the time.
unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) -> Waker {
unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
*self_arc.inner_waker.get() = cx.waker().clone().into();
waker(self_arc.clone())
}

/// Attempts to start the waking process for the waker with the given value.
Expand Down Expand Up @@ -414,6 +413,12 @@ where
}
};

// Safety: now state is `POLLING`.
unsafe {
WrappedWaker::replace_waker(this.stream_waker, cx);
WrappedWaker::replace_waker(this.inner_streams_waker, cx)
};

if poll_state_value & NEED_TO_POLL_STREAM != NONE {
let mut stream_waker = None;

Expand All @@ -431,13 +436,9 @@ where

break;
} else {
// Initialize base stream waker if it's not yet initialized
if stream_waker.is_none() {
// Safety: now state is `POLLING`.
stream_waker
.replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) });
}
let mut cx = Context::from_waker(stream_waker.as_ref().unwrap());
let mut cx = Context::from_waker(
stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
);

match this.stream.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(item)) => {
Expand Down Expand Up @@ -475,9 +476,7 @@ where
}

if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
// Safety: now state is `POLLING`.
let inner_streams_waker =
unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) };
let inner_streams_waker = waker(this.inner_streams_waker.clone());
let mut cx = Context::from_waker(&inner_streams_waker);

match this.inner_streams.as_mut().poll_next(&mut cx) {
Expand Down
1 change: 1 addition & 0 deletions futures/tests/no-std/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![cfg(nightly)]
#![no_std]
#![allow(useless_anonymous_reexport)]

#[cfg(feature = "futures-core-alloc")]
#[cfg(target_has_atomic = "ptr")]
Expand Down
43 changes: 43 additions & 0 deletions futures/tests/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt};
use futures::task::Poll;
use futures::{ready, FutureExt};
use futures_core::Stream;
use futures_executor::ThreadPool;
use futures_test::task::noop_context;

#[test]
Expand Down Expand Up @@ -65,6 +66,7 @@ fn flatten_unordered() {
use futures::task::*;
use std::convert::identity;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;

Expand Down Expand Up @@ -322,6 +324,47 @@ fn flatten_unordered() {
assert_eq!(values, (0..60).collect::<Vec<u8>>());
});
}

// nested `flatten_unordered`
let te = ThreadPool::new().unwrap();
let handle = te
.spawn_with_handle(async move {
let inner = stream::iter(0..10)
.then(|_| {
let task = Arc::new(AtomicBool::new(false));
let mut spawned = false;

future::poll_fn(move |cx| {
if !spawned {
let waker = cx.waker().clone();
let task = task.clone();

std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(500));
task.store(true, Ordering::Release);

waker.wake_by_ref()
});
spawned = true;
}

if task.load(Ordering::Acquire) {
Poll::Ready(Some(()))
} else {
Poll::Pending
}
})
})
.map(|_| stream::once(future::ready(())))
.flatten_unordered(None);

let stream = stream::once(future::ready(inner)).flatten_unordered(None);

assert_eq!(stream.count().await, 10);
})
.unwrap();

block_on(handle);
}

#[test]
Expand Down

0 comments on commit a730a19

Please sign in to comment.