Fix lost Waker instances with async stdio streams #8782

Merged 2 commits on Jun 12, 2024

Changes from 1 commit
145 changes: 104 additions & 41 deletions crates/wasi/src/stdio.rs
@@ -8,11 +8,9 @@ use crate::{
HostInputStream, HostOutputStream, StreamError, StreamResult, Subscribe, WasiImpl, WasiView,
};
use bytes::Bytes;
use std::future::Future;
use std::io::IsTerminal;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::sync::Arc;
use tokio::sync::Mutex;
use wasmtime::component::Resource;

/// A trait used to represent the standard input to a guest program.
@@ -61,6 +59,33 @@ impl StdinStream for pipe::ClosedInputStream {
}

/// An impl of [`StdinStream`] built on top of [`crate::pipe::AsyncReadStream`].
//
// Note the usage of `tokio::sync::Mutex` here as opposed to a
// `std::sync::Mutex`. This is intentionally done to implement the `Subscribe`
// variant of this trait. Note that in doing so we're left with the quandary of
// how to implement the methods of `HostInputStream`, since those methods are
// not `async`. They're currently implemented with `try_lock`, which raises the
// question of what to do on contention: currently a trap is returned.
//
// Why should it be ok to return a trap? In general contention shouldn't
// trap, since concurrent access is something that can legitimately happen. The
// current assumption, though, is that WASI stdin/stdout streams are special
// enough that the contention case should never come up in practice. Currently
// in WASI there is no actual concurrency: there are just the items in a single
// `Store`, and that store owns all of its I/O in a single Tokio task. There is
// no way to spawn multiple Tokio tasks that use the same store, so at the very
// least there's zero parallelism, and with only a single task there's no
// concurrency either.
//
// This `AsyncStdinStream` wrapper is only intended to be used by the WASI
// bindings themselves. It's possible for the host to take this and work with it
// on its own task, but that's niche enough that it isn't designed for.
//
// Overall that means that the guest is either calling `Subscribe` or calling
// `HostInputStream` methods, so at present there should never be contention
// between the two. This may all change in the future with WASI 0.3, but
// perhaps we'll have a better story for stdio at that time (see the doc block
// on the `HostOutputStream` impl below).
pub struct AsyncStdinStream(Arc<Mutex<crate::pipe::AsyncReadStream>>);
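As an aside on the note above: the reason a `tokio::sync::Mutex` can back the `Subscribe` implementation further down is that its guard may be held across an `.await` and remains `Send`, so the `async fn ready` body can be boxed into the `Send` future that `#[async_trait]` produces; a `std::sync::MutexGuard` is `!Send` and would also block the executor while waiting. A minimal standalone sketch of just that property (none of these names come from Wasmtime; `ready_like` is made up for illustration):

use std::sync::Arc;
use tokio::sync::Mutex;

async fn ready_like(shared: Arc<Mutex<Vec<u8>>>) {
    // The lock is held across an `.await`, mirroring the shape of
    // `self.0.lock().await.ready().await` in the impl below.
    let mut guard = shared.lock().await;
    tokio::task::yield_now().await;
    guard.push(0);
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(Mutex::new(Vec::new()));
    // `tokio::spawn` requires a `Send` future; this compiles because the
    // tokio guard, unlike a `std::sync::MutexGuard`, is `Send`.
    tokio::spawn(ready_like(shared.clone())).await.unwrap();
    assert_eq!(shared.lock().await.len(), 1);
}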

impl AsyncStdinStream {
@@ -79,30 +104,24 @@ impl StdinStream for AsyncStdinStream {
}

impl HostInputStream for AsyncStdinStream {
fn read(&mut self, size: usize) -> Result<bytes::Bytes, crate::StreamError> {
self.0.lock().unwrap().read(size)
fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
match self.0.try_lock() {
Ok(mut stream) => stream.read(size),
Err(_) => Err(StreamError::trap("concurrent reads are not supported")),
}
}
fn skip(&mut self, size: usize) -> Result<usize, crate::StreamError> {
self.0.lock().unwrap().skip(size)
fn skip(&mut self, size: usize) -> Result<usize, StreamError> {
match self.0.try_lock() {
Ok(mut stream) => stream.skip(size),
Err(_) => Err(StreamError::trap("concurrent skips are not supported")),
}
}
}

#[async_trait::async_trait]
impl Subscribe for AsyncStdinStream {
fn ready<'a, 'b>(&'a mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'b>>
where
Self: 'b,
'a: 'b,
{
struct F(AsyncStdinStream);
impl Future for F {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut inner = self.0 .0.lock().unwrap();
let mut fut = inner.ready();
fut.as_mut().poll(cx)
}
}
Box::pin(F(Self(self.0.clone())))
async fn ready(&mut self) {
self.0.lock().await.ready().await
}
}
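
This `Subscribe` impl is where the lost `Waker` of the PR title came from: the removed version constructed a brand-new `ready()` future inside every call to `poll` and dropped it before returning, so the waker that future had registered was dropped along with it, and a readiness event arriving between polls had nothing left to wake. The `async fn ready` above instead keeps a single future alive for the whole wait. A minimal, self-contained sketch of the anti-pattern, using `tokio::sync::Notify` purely for illustration (nothing here is Wasmtime code):

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::Notify;

// Same shape as the removed `F`: `poll` builds a fresh inner future on
// every call and drops it on the way out.
struct PollsFreshFuture(Arc<Notify>);

impl Future for PollsFreshFuture {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // `notified()` is created here and dropped when this call returns,
        // and the waker it registered with the `Notify` goes with it. If
        // `notify_waiters()` fires between two polls, no waker is installed
        // at that moment, nothing wakes this task, and it may never be
        // polled again.
        let mut inner = Box::pin(self.0.notified());
        inner.as_mut().poll(cx)
    }
}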

@@ -300,6 +319,10 @@ impl Subscribe for OutputStream {
/// A wrapper of [`crate::pipe::AsyncWriteStream`] that implements
/// [`StdoutStream`]. Note that the [`HostOutputStream`] impl for this is not
/// correct when used for interleaved async IO.
//
// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to
// the `try_lock()` calls below in the implementation of `HostOutputStream`. For
// more information see the documentation on `AsyncStdinStream`.
pub struct AsyncStdoutStream(Arc<Mutex<crate::pipe::AsyncWriteStream>>);

impl AsyncStdoutStream {
@@ -334,32 +357,29 @@ impl StdoutStream for AsyncStdoutStream {
// this comment to correct it: sorry about that.
impl HostOutputStream for AsyncStdoutStream {
fn check_write(&mut self) -> Result<usize, StreamError> {
self.0.lock().unwrap().check_write()
match self.0.try_lock() {
Ok(mut stream) => stream.check_write(),
Err(_) => Err(StreamError::trap("concurrent writes are not supported")),
}
}
fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
self.0.lock().unwrap().write(bytes)
match self.0.try_lock() {
Ok(mut stream) => stream.write(bytes),
Err(_) => Err(StreamError::trap("concurrent writes not supported yet")),
}
}
fn flush(&mut self) -> Result<(), StreamError> {
self.0.lock().unwrap().flush()
match self.0.try_lock() {
Ok(mut stream) => stream.flush(),
Err(_) => Err(StreamError::trap("concurrent flushes not supported yet")),
}
}
}

#[async_trait::async_trait]
impl Subscribe for AsyncStdoutStream {
fn ready<'a, 'b>(&'a mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'b>>
where
Self: 'b,
'a: 'b,
{
struct F(AsyncStdoutStream);
impl Future for F {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut inner = self.0 .0.lock().unwrap();
let mut fut = inner.ready();
fut.as_mut().poll(cx)
}
}
Box::pin(F(Self(self.0.clone())))
async fn ready(&mut self) {
self.0.lock().await.ready().await
}
}

@@ -464,6 +484,13 @@ where

#[cfg(test)]
mod test {
use crate::stdio::StdoutStream;
use crate::write_stream::AsyncWriteStream;
use crate::{AsyncStdoutStream, HostOutputStream};
use anyhow::Result;
use bytes::Bytes;
use tokio::io::AsyncReadExt;

#[test]
fn memory_stdin_stream() {
// A StdinStream has the property that there are multiple
@@ -492,6 +519,7 @@ mod test {
let read4 = view2.read(10).expect("read fourth 10 bytes");
assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes");
}

#[tokio::test]
async fn async_stdin_stream() {
// A StdinStream has the property that there are multiple
@@ -530,4 +558,39 @@ mod test {
let read4 = view2.read(10).expect("read fourth 10 bytes");
assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes");
}

#[tokio::test]
async fn async_stdout_stream_unblocks() {
let (mut read, write) = tokio::io::duplex(1024);
let stdout = AsyncStdoutStream::new(AsyncWriteStream::new(1024, write));

let task = tokio::task::spawn(async move {
let mut stream = stdout.stream();
blocking_write_and_flush(&mut *stream, "x".into())
.await
.unwrap();
});

let mut buf = [0; 100];
let n = read.read(&mut buf).await.unwrap();
A reviewer (Contributor) commented:
I'd be more comfortable with this test if the amount read was in excess of the buffer sizes in the io::duplex and the AsyncWriteStream. You could just lower those to 32 each.

assert_eq!(&buf[..n], b"x");

task.await.unwrap();
}
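
A sketch of the variant the review comment above asks for (illustration only, not the code that was merged; the test name is invented, and it reuses this module's imports plus the `blocking_write_and_flush` helper below): shrink both the duplex pipe and the `AsyncWriteStream` budget to 32 bytes and write more data than either can hold, so the blocking write can only finish once the reader has drained the pipe.

#[tokio::test]
async fn async_stdout_stream_unblocks_with_backpressure() {
    let (mut read, write) = tokio::io::duplex(32);
    let stdout = AsyncStdoutStream::new(AsyncWriteStream::new(32, write));

    // 128 bytes is well in excess of both 32-byte buffers.
    let payload = vec![b'x'; 128];
    let expected = payload.clone();

    let task = tokio::task::spawn(async move {
        let mut stream = stdout.stream();
        blocking_write_and_flush(&mut *stream, Bytes::from(payload))
            .await
            .unwrap();
    });

    // The writer can only finish once this loop frees space in the pipe.
    let mut received = Vec::new();
    let mut buf = [0; 16];
    while received.len() < expected.len() {
        let n = read.read(&mut buf).await.unwrap();
        assert!(n > 0, "pipe closed before the full payload arrived");
        received.extend_from_slice(&buf[..n]);
    }
    assert_eq!(received, expected);

    task.await.unwrap();
}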

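    // A test-local approximation of WASI's `blocking-write-and-flush`: wait
    // for write budget, write the payload in permitted chunks, then flush and
    // wait once more so the flush has actually been processed.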
async fn blocking_write_and_flush(
s: &mut dyn HostOutputStream,
mut bytes: Bytes,
) -> Result<()> {
while !bytes.is_empty() {
let permit = s.write_ready().await?;
let len = bytes.len().min(permit);
let chunk = bytes.split_to(len);
s.write(chunk)?;
}

s.flush()?;
s.write_ready().await?;
Ok(())
}
}