From c63cf1f52fad1b5754edd5e12a0884d4ca6edc3e Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Thu, 4 Jan 2024 17:33:53 -0700 Subject: [PATCH] fix `fd_read` and `fd_write` for `wasm32-wasi-threads` (#7750) Previously, `first_non_empty_{c}io_vec` always returned `Ok(None)` for buffers residing in shared memories since they cannot, in general, safely be represented as slices. That caused e.g. `wasi-libc` to spin forever when trying to write to stdout using `fd_write` since it always got `Ok(0)` and never made progress. This commit changes the return type of both functions to use `GuestPtr` instead of `GuestSlice{Mut}`, allowing safe access to shared guest memory. Big thanks to Alex Crichton for narrowing this down and suggesting the fix. Fixes #7745 Signed-off-by: Joel Dice --- crates/wasi/src/preview2/preview1.rs | 53 ++++++++++++++++------------ 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/crates/wasi/src/preview2/preview1.rs b/crates/wasi/src/preview2/preview1.rs index 8a0c64bd713c..83da227715bd 100644 --- a/crates/wasi/src/preview2/preview1.rs +++ b/crates/wasi/src/preview2/preview1.rs @@ -19,7 +19,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use wasmtime::component::Resource; use wiggle::tracing::instrument; -use wiggle::{GuestError, GuestPtr, GuestSlice, GuestSliceMut, GuestStrCow, GuestType}; +use wiggle::{GuestError, GuestPtr, GuestStrCow, GuestType}; #[derive(Debug)] struct File { @@ -81,10 +81,13 @@ impl BlockingMode { &self, host: &mut (impl streams::Host + poll::Host), output_stream: Resource, - mut bytes: &[u8], + bytes: GuestPtr<'_, [u8]>, ) -> StreamResult { use streams::HostOutputStream as Streams; + let bytes = bytes.as_cow().map_err(|e| StreamError::Trap(e.into()))?; + let mut bytes = &bytes[..]; + match self { BlockingMode::Blocking => { let total = bytes.len(); @@ -814,29 +817,27 @@ fn read_string<'a>(ptr: impl Borrow>) -> Result { // Find first non-empty buffer. fn first_non_empty_ciovec<'a, 'b>( ciovs: &'a types::CiovecArray<'b>, -) -> Result>> { +) -> Result>> { for iov in ciovs.iter() { let iov = iov?.read()?; if iov.buf_len == 0 { continue; } - return Ok(iov.buf.as_array(iov.buf_len).as_slice()?); + return Ok(Some(iov.buf.as_array(iov.buf_len))); } Ok(None) } // Find first non-empty buffer. -fn first_non_empty_iovec<'a>( - iovs: &types::IovecArray<'a>, -) -> Result>> { +fn first_non_empty_iovec<'a>(iovs: &types::IovecArray<'a>) -> Result>> { iovs.iter() .map(|iov| { let iov = iov?.read()?; if iov.buf_len == 0 { return Ok(None); } - let slice = iov.buf.as_array(iov.buf_len).as_slice_mut()?; - Ok(slice) + let slice = iov.buf.as_array(iov.buf_len); + Ok(Some(slice)) }) .find_map(Result::transpose) .transpose() @@ -1317,7 +1318,7 @@ impl< ) -> Result { let t = self.transact()?; let desc = t.get_descriptor(fd)?; - let (mut buf, read) = match desc { + let (buf, read) = match desc { Descriptor::File(File { fd, blocking_mode, @@ -1338,7 +1339,9 @@ impl< .context("failed to call `read-via-stream`") .unwrap_or_else(types::Error::trap) })?; - let read = blocking_mode.read(self, stream, buf.len()).await?; + let read = blocking_mode + .read(self, stream, buf.len().try_into()?) + .await?; let n = read.len().try_into()?; let pos = pos.checked_add(n).ok_or(types::Errno::Overflow)?; position.store(pos, Ordering::Relaxed); @@ -1351,16 +1354,18 @@ impl< let Some(buf) = first_non_empty_iovec(iovs)? else { return Ok(0); }; - let read = BlockingMode::Blocking.read(self, stream, buf.len()).await?; + let read = BlockingMode::Blocking + .read(self, stream, buf.len().try_into()?) + .await?; (buf, read) } _ => return Err(types::Errno::Badf.into()), }; - if read.len() > buf.len() { + if read.len() > buf.len().try_into()? { return Err(types::Errno::Range.into()); } - let (buf, _) = buf.split_at_mut(read.len()); - buf.copy_from_slice(&read); + let buf = buf.get_range(0..u32::try_from(read.len())?).unwrap(); + buf.copy_from_slice(&read)?; let n = read.len().try_into()?; Ok(n) } @@ -1376,7 +1381,7 @@ impl< ) -> Result { let t = self.transact()?; let desc = t.get_descriptor(fd)?; - let (mut buf, read) = match desc { + let (buf, read) = match desc { Descriptor::File(File { fd, blocking_mode, .. }) if t.view.table().get(fd)?.is_file() => { @@ -1392,7 +1397,9 @@ impl< .context("failed to call `read-via-stream`") .unwrap_or_else(types::Error::trap) })?; - let read = blocking_mode.read(self, stream, buf.len()).await?; + let read = blocking_mode + .read(self, stream, buf.len().try_into()?) + .await?; (buf, read) } Descriptor::Stdin { .. } => { @@ -1401,11 +1408,11 @@ impl< } _ => return Err(types::Errno::Badf.into()), }; - if read.len() > buf.len() { + if read.len() > buf.len().try_into()? { return Err(types::Errno::Range.into()); } - let (buf, _) = buf.split_at_mut(read.len()); - buf.copy_from_slice(&read); + let buf = buf.get_range(0..u32::try_from(read.len())?).unwrap(); + buf.copy_from_slice(&read)?; let n = read.len().try_into()?; Ok(n) } @@ -1452,7 +1459,7 @@ impl< })?; (stream, pos) }; - let n = blocking_mode.write(self, stream, &buf).await?; + let n = blocking_mode.write(self, stream, buf).await?; if append { let len = self.stat(fd2).await?; position.store(len.size, Ordering::Relaxed); @@ -1470,7 +1477,7 @@ impl< return Ok(0); }; let n = BlockingMode::Blocking - .write(self, stream, &buf) + .write(self, stream, buf) .await? .try_into()?; Ok(n) @@ -1505,7 +1512,7 @@ impl< .context("failed to call `write-via-stream`") .unwrap_or_else(types::Error::trap) })?; - blocking_mode.write(self, stream, &buf).await? + blocking_mode.write(self, stream, buf).await? } Descriptor::Stdout { .. } | Descriptor::Stderr { .. } => { // NOTE: legacy implementation returns SPIPE here