Fix lost Waker instances with async stdio streams (#8782)
* Fix lost `Waker` instances with async stdio streams

This commit fixes a bug in the `Subscribe` trait implementation for the
`AsyncStd{in,out}Stream` structures in the `wasmtime-wasi` crate.
Previously these implementations created a fresh future for each call to
`poll` and then dropped it, which could lose wakeups because the waker
registered by that future is destroyed along with it (a minimal sketch of
the pattern follows below). The fix is to use a `tokio::sync::Mutex`
instead of a `std::sync::Mutex` and to leave some comments about why
contention isn't expected.

Closes #8781

* Reduce sizes used in tests
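
As a rough illustration of the bug described above (a sketch only, not the
`wasmtime-wasi` code: `Inner`, `LossyReady`, and `fixed_ready` are hypothetical
stand-ins, and it assumes the `tokio` crate with the `rt` and `macros`
features), the buggy shape rebuilds the inner `ready()` future on every `poll`
and drops it on return, discarding whatever waker it registered, while the
fixed shape is an `async fn` that keeps one future alive across polls:

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct Inner;

impl Inner {
    // Stand-in for an async readiness check such as `AsyncReadStream::ready()`;
    // here it just yields back to the scheduler once.
    async fn ready(&mut self) {
        tokio::task::yield_now().await;
    }
}

// Buggy shape: every call to `poll` builds a brand-new `ready()` future, polls
// it once, and drops it on return. If that inner future had parked the waker
// to be notified of a later readiness event, the registration dies with the
// future and nothing ever wakes this task again.
struct LossyReady(Arc<std::sync::Mutex<Inner>>);

impl Future for LossyReady {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut inner = self.0.lock().unwrap();
        let mut fut = Box::pin(inner.ready());
        fut.as_mut().poll(cx) // `fut` is dropped here, along with its waker
    }
}

// Fixed shape: a plain `async fn` keeps a single `ready()` future alive across
// polls, so any waker it registers survives until it is actually used. The
// `tokio::sync::Mutex` allows holding the lock across the `.await`.
async fn fixed_ready(inner: &Arc<tokio::sync::Mutex<Inner>>) {
    inner.lock().await.ready().await
}

#[tokio::main]
async fn main() {
    let inner = Arc::new(tokio::sync::Mutex::new(Inner));
    fixed_ready(&inner).await;
    let _unpolled = LossyReady(Arc::new(std::sync::Mutex::new(Inner)));
}

The actual fix in the diff below follows the second shape: `Subscribe::ready`
becomes an `async fn` that awaits the inner stream's `ready()` while holding a
`tokio::sync::Mutex`.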
alexcrichton authored Jun 12, 2024
1 parent 2835a34 commit 34d2a08
Showing 1 changed file with 104 additions and 41 deletions.
145 changes: 104 additions & 41 deletions crates/wasi/src/stdio.rs
@@ -8,11 +8,9 @@ use crate::{
    HostInputStream, HostOutputStream, StreamError, StreamResult, Subscribe, WasiImpl, WasiView,
};
use bytes::Bytes;
-use std::future::Future;
use std::io::IsTerminal;
-use std::pin::Pin;
-use std::sync::{Arc, Mutex};
-use std::task::{Context, Poll};
+use std::sync::Arc;
+use tokio::sync::Mutex;
use wasmtime::component::Resource;

/// A trait used to represent the standard input to a guest program.
@@ -61,6 +59,33 @@ impl StdinStream for pipe::ClosedInputStream {
}

/// An impl of [`StdinStream`] built on top of [`crate::pipe::AsyncReadStream`].
+//
+// Note the usage of `tokio::sync::Mutex` here as opposed to a
+// `std::sync::Mutex`. This is intentionally done to implement the `Subscribe`
+// variant of this trait. Note that in doing so we're left with the quandary of
+// how to implement methods of `HostInputStream` since those methods are not
+// `async`. They're currently implemented with `try_lock`, which then raises the
+// question of what to do on contention. Currently traps are returned.
+//
+// Why should it be ok to return a trap? In general concurrency/contention
+// shouldn't return a trap since it should be able to happen normally. The
+// current assumption, though, is that WASI stdin/stdout streams are special
+// enough that the contention case should never come up in practice. Currently
+// in WASI there is no actual concurrency; there are just the items in a single
+// `Store`, and that store owns all of its I/O in a single Tokio task. There's
+// no means to actually spawn multiple Tokio tasks that use the same store. This
+// means at the very least that there's zero parallelism. Due to the lack of
+// multiple tasks that also means that there's no concurrency either.
+//
+// This `AsyncStdinStream` wrapper is only intended to be used by the WASI
+// bindings themselves. It's possible for the host to take this and work with it
+// on its own task, but that's niche enough that it's not designed for.
+//
+// Overall that means that the guest is either calling `Subscribe` or it's
+// calling `HostInputStream` methods. This means that there should never be
+// contention between the two at this time. This may all change in the future
+// with WASI 0.3, but perhaps we'll have a better story for stdio at that time
+// (see the doc block on the `HostOutputStream` impl below).
pub struct AsyncStdinStream(Arc<Mutex<crate::pipe::AsyncReadStream>>);

impl AsyncStdinStream {
@@ -79,30 +104,24 @@ impl StdinStream for AsyncStdinStream {
}

impl HostInputStream for AsyncStdinStream {
-    fn read(&mut self, size: usize) -> Result<bytes::Bytes, crate::StreamError> {
-        self.0.lock().unwrap().read(size)
+    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
+        match self.0.try_lock() {
+            Ok(mut stream) => stream.read(size),
+            Err(_) => Err(StreamError::trap("concurrent reads are not supported")),
+        }
    }
-    fn skip(&mut self, size: usize) -> Result<usize, crate::StreamError> {
-        self.0.lock().unwrap().skip(size)
+    fn skip(&mut self, size: usize) -> Result<usize, StreamError> {
+        match self.0.try_lock() {
+            Ok(mut stream) => stream.skip(size),
+            Err(_) => Err(StreamError::trap("concurrent skips are not supported")),
+        }
    }
}

+#[async_trait::async_trait]
impl Subscribe for AsyncStdinStream {
-    fn ready<'a, 'b>(&'a mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'b>>
-    where
-        Self: 'b,
-        'a: 'b,
-    {
-        struct F(AsyncStdinStream);
-        impl Future for F {
-            type Output = ();
-            fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-                let mut inner = self.0 .0.lock().unwrap();
-                let mut fut = inner.ready();
-                fut.as_mut().poll(cx)
-            }
-        }
-        Box::pin(F(Self(self.0.clone())))
+    async fn ready(&mut self) {
+        self.0.lock().await.ready().await
    }
}

@@ -300,6 +319,10 @@ impl Subscribe for OutputStream {
/// A wrapper of [`crate::pipe::AsyncWriteStream`] that implements
/// [`StdoutStream`]. Note that the [`HostOutputStream`] impl for this is not
/// correct when used for interleaved async IO.
+//
+// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to
+// the `try_lock()` calls below in the implementation of `HostOutputStream`. For
+// more information see the documentation on `AsyncStdinStream`.
pub struct AsyncStdoutStream(Arc<Mutex<crate::pipe::AsyncWriteStream>>);

impl AsyncStdoutStream {
@@ -334,32 +357,29 @@ impl StdoutStream for AsyncStdoutStream {
// this comment to correct it: sorry about that.
impl HostOutputStream for AsyncStdoutStream {
    fn check_write(&mut self) -> Result<usize, StreamError> {
-        self.0.lock().unwrap().check_write()
+        match self.0.try_lock() {
+            Ok(mut stream) => stream.check_write(),
+            Err(_) => Err(StreamError::trap("concurrent writes are not supported")),
+        }
    }
    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
-        self.0.lock().unwrap().write(bytes)
+        match self.0.try_lock() {
+            Ok(mut stream) => stream.write(bytes),
+            Err(_) => Err(StreamError::trap("concurrent writes not supported yet")),
+        }
    }
    fn flush(&mut self) -> Result<(), StreamError> {
-        self.0.lock().unwrap().flush()
+        match self.0.try_lock() {
+            Ok(mut stream) => stream.flush(),
+            Err(_) => Err(StreamError::trap("concurrent flushes not supported yet")),
+        }
    }
}

+#[async_trait::async_trait]
impl Subscribe for AsyncStdoutStream {
-    fn ready<'a, 'b>(&'a mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'b>>
-    where
-        Self: 'b,
-        'a: 'b,
-    {
-        struct F(AsyncStdoutStream);
-        impl Future for F {
-            type Output = ();
-            fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-                let mut inner = self.0 .0.lock().unwrap();
-                let mut fut = inner.ready();
-                fut.as_mut().poll(cx)
-            }
-        }
-        Box::pin(F(Self(self.0.clone())))
+    async fn ready(&mut self) {
+        self.0.lock().await.ready().await
    }
}

@@ -464,6 +484,13 @@ where

#[cfg(test)]
mod test {
+    use crate::stdio::StdoutStream;
+    use crate::write_stream::AsyncWriteStream;
+    use crate::{AsyncStdoutStream, HostOutputStream};
+    use anyhow::Result;
+    use bytes::Bytes;
+    use tokio::io::AsyncReadExt;
+
    #[test]
    fn memory_stdin_stream() {
        // A StdinStream has the property that there are multiple
Expand Down Expand Up @@ -492,6 +519,7 @@ mod test {
        let read4 = view2.read(10).expect("read fourth 10 bytes");
        assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes");
    }
+
    #[tokio::test]
    async fn async_stdin_stream() {
        // A StdinStream has the property that there are multiple
@@ -530,4 +558,39 @@ mod test {
        let read4 = view2.read(10).expect("read fourth 10 bytes");
        assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes");
    }
+
+    #[tokio::test]
+    async fn async_stdout_stream_unblocks() {
+        let (mut read, write) = tokio::io::duplex(32);
+        let stdout = AsyncStdoutStream::new(AsyncWriteStream::new(32, write));
+
+        let task = tokio::task::spawn(async move {
+            let mut stream = stdout.stream();
+            blocking_write_and_flush(&mut *stream, "x".into())
+                .await
+                .unwrap();
+        });
+
+        let mut buf = [0; 100];
+        let n = read.read(&mut buf).await.unwrap();
+        assert_eq!(&buf[..n], b"x");
+
+        task.await.unwrap();
+    }
+
+    async fn blocking_write_and_flush(
+        s: &mut dyn HostOutputStream,
+        mut bytes: Bytes,
+    ) -> Result<()> {
+        while !bytes.is_empty() {
+            let permit = s.write_ready().await?;
+            let len = bytes.len().min(permit);
+            let chunk = bytes.split_to(len);
+            s.write(chunk)?;
+        }
+
+        s.flush()?;
+        s.write_ready().await?;
+        Ok(())
+    }
}

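For host embedders, a rough sketch of wiring `AsyncStdoutStream` into a WASI
context, mirroring the test above (not part of the commit; it assumes the
crate's `WasiCtxBuilder::stdout` builder method and the `pipe::AsyncWriteStream`
re-export from `wasmtime-wasi` of this era, and must run inside a Tokio
runtime because `AsyncWriteStream::new` spawns a background writer task):

use wasmtime_wasi::pipe::AsyncWriteStream;
use wasmtime_wasi::{AsyncStdoutStream, WasiCtxBuilder};

#[tokio::main]
async fn main() {
    // Route the guest's stdout into an in-memory duplex pipe with a small
    // write budget, so the host can read what the guest prints.
    let (mut host_read, guest_write) = tokio::io::duplex(64);
    let stdout = AsyncStdoutStream::new(AsyncWriteStream::new(64, guest_write));

    let wasi = WasiCtxBuilder::new().stdout(stdout).build();
    // ... hand `wasi` to a `Store` and component linker as usual ...
    let _ = (wasi, &mut host_read);
}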