diff --git a/tokio-test/src/lib.rs b/tokio-test/src/lib.rs index de3f0864a94..87e63861210 100644 --- a/tokio-test/src/lib.rs +++ b/tokio-test/src/lib.rs @@ -12,6 +12,7 @@ //! Tokio and Futures based testing utilities pub mod io; +pub mod stream_mock; mod macros; pub mod task; diff --git a/tokio-test/src/stream_mock.rs b/tokio-test/src/stream_mock.rs new file mode 100644 index 00000000000..0426470af27 --- /dev/null +++ b/tokio-test/src/stream_mock.rs @@ -0,0 +1,168 @@ +#![cfg(not(loom))] + +//! A mock stream implementing [`Stream`]. +//! +//! # Overview +//! This crate provides a `StreamMock` that can be used to test code that interacts with streams. +//! It allows you to mock the behavior of a stream and control the items it yields and the waiting +//! intervals between items. +//! +//! # Usage +//! To use the `StreamMock`, you need to create a builder using[`StreamMockBuilder`]. The builder +//! allows you to enqueue actions such as returning items or waiting for a certain duration. +//! +//! # Example +//! ```rust +//! +//! use futures_util::StreamExt; +//! use std::time::Duration; +//! use tokio_test::stream_mock::StreamMockBuilder; +//! +//! async fn test_stream_mock_wait() { +//! let mut stream_mock = StreamMockBuilder::new() +//! .next(1) +//! .wait(Duration::from_millis(300)) +//! .next(2) +//! .build(); +//! +//! assert_eq!(stream_mock.next().await, Some(1)); +//! let start = std::time::Instant::now(); +//! assert_eq!(stream_mock.next().await, Some(2)); +//! let elapsed = start.elapsed(); +//! assert!(elapsed >= Duration::from_millis(300)); +//! assert_eq!(stream_mock.next().await, None); +//! } +//! ``` + +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::Poll; +use std::time::Duration; + +use futures_core::{ready, Stream}; +use std::future::Future; +use tokio::time::{sleep_until, Instant, Sleep}; + +#[derive(Debug, Clone)] +enum Action { + Next(T), + Wait(Duration), +} + +/// A builder for [`StreamMock`] +#[derive(Debug, Clone)] +pub struct StreamMockBuilder { + actions: VecDeque>, +} + +impl StreamMockBuilder { + /// Create a new empty [`StreamMockBuilder`] + pub fn new() -> Self { + StreamMockBuilder::default() + } + + /// Queue an item to be returned by the stream + pub fn next(mut self, value: T) -> Self { + self.actions.push_back(Action::Next(value)); + self + } + + // Queue an item to be consumed by the sink, + // commented out until Sink is implemented. + // + // pub fn consume(mut self, value: T) -> Self { + // self.actions.push_back(Action::Consume(value)); + // self + // } + + /// Queue the stream to wait for a duration + pub fn wait(mut self, duration: Duration) -> Self { + self.actions.push_back(Action::Wait(duration)); + self + } + + /// Build the [`StreamMock`] + pub fn build(self) -> StreamMock { + StreamMock { + actions: self.actions, + sleep: None, + } + } +} + +impl Default for StreamMockBuilder { + fn default() -> Self { + StreamMockBuilder { + actions: VecDeque::new(), + } + } +} + +/// A mock stream implementing [`Stream`] +/// +/// See [`StreamMockBuilder`] for more information. +#[derive(Debug)] +pub struct StreamMock { + actions: VecDeque>, + sleep: Option>>, +} + +impl StreamMock { + fn next_action(&mut self) -> Option> { + self.actions.pop_front() + } +} + +impl Stream for StreamMock { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Try polling the sleep future first + if let Some(ref mut sleep) = self.sleep { + ready!(Pin::new(sleep).poll(cx)); + // Since we're ready, discard the sleep future + self.sleep.take(); + } + + match self.next_action() { + Some(action) => match action { + Action::Next(item) => Poll::Ready(Some(item)), + Action::Wait(duration) => { + // Set up a sleep future and schedule this future to be polled again for it. + self.sleep = Some(Box::pin(sleep_until(Instant::now() + duration))); + cx.waker().wake_by_ref(); + + Poll::Pending + } + }, + None => Poll::Ready(None), + } + } +} + +impl Drop for StreamMock { + fn drop(&mut self) { + // Avoid double panicking to make debugging easier. + if std::thread::panicking() { + return; + } + + let undropped_count = self + .actions + .iter() + .filter(|action| match action { + Action::Next(_) => true, + Action::Wait(_) => false, + }) + .count(); + + assert!( + undropped_count == 0, + "StreamMock was dropped before all actions were consumed, {} actions were not consumed", + undropped_count + ); + } +} diff --git a/tokio-test/tests/stream_mock.rs b/tokio-test/tests/stream_mock.rs new file mode 100644 index 00000000000..a54ea838a5b --- /dev/null +++ b/tokio-test/tests/stream_mock.rs @@ -0,0 +1,50 @@ +use futures_util::StreamExt; +use std::time::Duration; +use tokio_test::stream_mock::StreamMockBuilder; + +#[tokio::test] +async fn test_stream_mock_empty() { + let mut stream_mock = StreamMockBuilder::::new().build(); + + assert_eq!(stream_mock.next().await, None); + assert_eq!(stream_mock.next().await, None); +} + +#[tokio::test] +async fn test_stream_mock_items() { + let mut stream_mock = StreamMockBuilder::new().next(1).next(2).build(); + + assert_eq!(stream_mock.next().await, Some(1)); + assert_eq!(stream_mock.next().await, Some(2)); + assert_eq!(stream_mock.next().await, None); +} + +#[tokio::test] +async fn test_stream_mock_wait() { + let mut stream_mock = StreamMockBuilder::new() + .next(1) + .wait(Duration::from_millis(300)) + .next(2) + .build(); + + assert_eq!(stream_mock.next().await, Some(1)); + let start = std::time::Instant::now(); + assert_eq!(stream_mock.next().await, Some(2)); + let elapsed = start.elapsed(); + assert!(elapsed >= Duration::from_millis(300)); + assert_eq!(stream_mock.next().await, None); +} + +#[tokio::test] +#[should_panic(expected = "StreamMock was dropped before all actions were consumed")] +async fn test_stream_mock_drop_without_consuming_all() { + let stream_mock = StreamMockBuilder::new().next(1).next(2).build(); + drop(stream_mock); +} + +#[tokio::test] +#[should_panic(expected = "test panic was not masked")] +async fn test_stream_mock_drop_during_panic_doesnt_mask_panic() { + let _stream_mock = StreamMockBuilder::new().next(1).next(2).build(); + panic!("test panic was not masked"); +}