Skip to content

Commit

Permalink
Address review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
sunfishcode committed Aug 17, 2023
1 parent ce95b4e commit 42403be
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 56 deletions.
1 change: 1 addition & 0 deletions crates/wasi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ libc = { workspace = true }
[target.'cfg(windows)'.dependencies]
io-extras = { workspace = true }
windows-sys = { workspace = true }
rustix = { workspace = true, features = ["net"], optional = true }

[features]
default = ["sync", "preview2", "preview1-on-preview2"]
Expand Down
11 changes: 6 additions & 5 deletions crates/wasi/src/preview2/host/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ impl From<io::Error> for network::Error {

// Errors we don't expect to see here.
io::ErrorKind::Interrupted | io::ErrorKind::ConnectionAborted => {
panic!("transient errors should be skipped")
// Transient errors should be skipped.
return Self::trap(error.into());
}

// Errors not expected from network APIs.
Expand All @@ -49,11 +50,11 @@ impl From<io::Error> for network::Error {
| io::ErrorKind::BrokenPipe
| io::ErrorKind::NotFound
| io::ErrorKind::UnexpectedEof
| io::ErrorKind::AlreadyExists => ErrorCode::Unknown,
| io::ErrorKind::AlreadyExists => return Self::trap(error.into()),

// Errors that don't correspond to a Rust `io::ErrorKind`.
io::ErrorKind::Other => match error.raw_os_error() {
None => ErrorCode::Unknown,
None => return Self::trap(error.into()),
Some(libc::ENOBUFS) | Some(libc::ENOMEM) => ErrorCode::OutOfMemory,
Some(libc::EOPNOTSUPP) => ErrorCode::NotSupported,
Some(libc::ENETUNREACH) | Some(libc::EHOSTUNREACH) | Some(libc::ENETDOWN) => {
Expand All @@ -62,9 +63,9 @@ impl From<io::Error> for network::Error {
Some(libc::ECONNRESET) => ErrorCode::ConnectionReset,
Some(libc::ECONNREFUSED) => ErrorCode::ConnectionRefused,
Some(libc::EADDRINUSE) => ErrorCode::AddressInUse,
Some(_) => panic!("unknown error {:?}", error),
Some(_) => return Self::trap(error.into()),
},
_ => panic!("unknown error {:?}", error),
_ => return Self::trap(error.into()),
}
.into()
}
Expand Down
123 changes: 73 additions & 50 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::any::Any;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLockWriteGuard;
#[cfg(unix)]
use tokio::task::spawn;
#[cfg(not(unix))]
Expand All @@ -32,7 +33,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let tcp_state = socket.tcp_state_write_lock();
match &*tcp_state {
HostTcpState::Default => {}
_ => return Err(ErrorCode::NotInProgress.into()),
Expand All @@ -43,26 +44,25 @@ impl<T: WasiView> tcp::Host for T {

binder.bind_existing_tcp_listener(socket.tcp_socket())?;

*tcp_state = HostTcpState::BindStarted;
socket.inner.sender.send(()).unwrap();
set_state(tcp_state, HostTcpState::BindStarted);
socket.notify();

Ok(())
}

// TODO: Bind and listen aren't really blocking operations; figure this
// out at the spec level.
fn finish_bind(&mut self, this: tcp::TcpSocket) -> Result<(), network::Error> {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
match &mut *tcp_state {
HostTcpState::BindStarted => {
*tcp_state = HostTcpState::Bound;
Ok(())
}
_ => Err(ErrorCode::NotInProgress.into()),
let tcp_state = socket.tcp_state_write_lock();
match &*tcp_state {
HostTcpState::BindStarted => {}
_ => return Err(ErrorCode::NotInProgress.into()),
}

set_state(tcp_state, HostTcpState::Bound);

Ok(())
}

fn start_connect(
Expand All @@ -74,7 +74,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let tcp_state = socket.tcp_state_write_lock();
match &*tcp_state {
HostTcpState::Default => {}
HostTcpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
Expand All @@ -88,7 +88,8 @@ impl<T: WasiView> tcp::Host for T {
match connecter.connect_existing_tcp_listener(socket.tcp_socket()) {
// succeed immediately,
Ok(()) => {
*tcp_state = HostTcpState::ConnectReady(Ok(()));
set_state(tcp_state, HostTcpState::ConnectReady(Ok(())));
socket.notify();
return Ok(());
}
// continue in progress,
Expand All @@ -98,7 +99,7 @@ impl<T: WasiView> tcp::Host for T {
Err(err) => return Err(err.into()),
}

// The connect is continuing in progres. Set up the join handle.
// The connect is continuing in progress. Set up the join handle.

let clone = socket.clone_inner();

Expand All @@ -117,8 +118,7 @@ impl<T: WasiView> tcp::Host for T {
Err(err) => Err(err),
};

*clone.tcp_state.write().unwrap() = HostTcpState::ConnectReady(result);
clone.sender.send(()).unwrap();
clone.set_state_and_notify(HostTcpState::ConnectReady(result));
});

#[cfg(not(unix))]
Expand All @@ -140,11 +140,13 @@ impl<T: WasiView> tcp::Host for T {
Err(err) => Err(err.into()),
};

*clone.tcp_state.write().unwrap() = HostTcpState::ConnectReady(result);
clone.sender.send(()).unwrap();
clone.set_state_and_notify(HostTcpState::ConnectReady(result));
});

*tcp_state = HostTcpState::Connecting(Pin::from(Box::new(join)));
set_state(
tcp_state,
HostTcpState::Connecting(Pin::from(Box::new(join))),
);

Ok(())
}
Expand All @@ -156,7 +158,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let mut tcp_state = socket.tcp_state_write_lock();
match &mut *tcp_state {
HostTcpState::ConnectReady(_) => {}
HostTcpState::Connecting(join) => match maybe_unwrap_future(join) {
Expand All @@ -171,14 +173,14 @@ impl<T: WasiView> tcp::Host for T {
// Extract the connection result.
let result = match old_state {
HostTcpState::ConnectReady(result) => result,
_ => panic!(),
_ => unreachable!(),
};

// Report errors, resetting the state if needed.
match result {
Ok(()) => {}
Err(err) => {
*tcp_state = HostTcpState::Default;
set_state(tcp_state, HostTcpState::Default);
return Err(err.into());
}
}
Expand All @@ -200,7 +202,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let tcp_state = socket.tcp_state_write_lock();
match &*tcp_state {
HostTcpState::Bound => {}
HostTcpState::ListenStarted => return Err(ErrorCode::AlreadyListening.into()),
Expand All @@ -210,8 +212,8 @@ impl<T: WasiView> tcp::Host for T {

socket.tcp_socket().listen(None)?;

*tcp_state = HostTcpState::ListenStarted;
socket.inner.sender.send(()).unwrap();
set_state(tcp_state, HostTcpState::ListenStarted);
socket.notify();

Ok(())
}
Expand All @@ -220,16 +222,18 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let tcp_state = socket.tcp_state_write_lock();

match &mut *tcp_state {
match &*tcp_state {
HostTcpState::ListenStarted => {}
_ => return Err(ErrorCode::NotInProgress.into()),
}

let new_join = spawn_task_to_wait_for_connections(socket.clone_inner());
*tcp_state = HostTcpState::Listening(Pin::from(Box::new(new_join)));
drop(tcp_state);
set_state(
tcp_state,
HostTcpState::Listening(Pin::from(Box::new(new_join))),
);

Ok(())
}
Expand All @@ -241,7 +245,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let mut tcp_state = socket.inner.tcp_state.write().unwrap();
let mut tcp_state = socket.tcp_state_write_lock();
match &mut *tcp_state {
HostTcpState::ListenReady(_) => {}
HostTcpState::Listening(join) => match maybe_unwrap_future(join) {
Expand All @@ -253,8 +257,10 @@ impl<T: WasiView> tcp::Host for T {
}

let new_join = spawn_task_to_wait_for_connections(socket.clone_inner());
*tcp_state = HostTcpState::Listening(Pin::from(Box::new(new_join)));
drop(tcp_state);
set_state(
tcp_state,
HostTcpState::Listening(Pin::from(Box::new(new_join))),
);

// Do the host system call.
let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?;
Expand Down Expand Up @@ -347,7 +353,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;

let tcp_state = socket.inner.tcp_state.read().unwrap();
let tcp_state = socket.tcp_state_read_lock();
match &*tcp_state {
HostTcpState::Listening(_) => {}
_ => return Err(ErrorCode::NotInProgress.into()),
Expand Down Expand Up @@ -506,7 +512,8 @@ impl<T: WasiView> tcp::Host for T {
// doesn't block.
let dropped = table.delete_tcp_socket(this)?;

// On non-Unix platforms, do a `shutdown` to wake up `poll`.
// On non-Unix platforms, do a `shutdown` to wake up any `poll` calls
// that are waiting.
#[cfg(not(unix))]
rustix::net::shutdown(&dropped.inner.tcp_socket, rustix::net::Shutdown::ReadWrite).unwrap();

Expand All @@ -520,14 +527,13 @@ impl<T: WasiView> tcp::Host for T {
/// can be `accept`ed.
fn spawn_task_to_wait_for_connections(socket: Arc<HostTcpSocketInner>) -> JoinHandle<()> {
#[cfg(unix)]
let new_join = spawn(async move {
let join = spawn(async move {
socket.tcp_socket.readable().await.unwrap().retain_ready();
*socket.tcp_state.write().unwrap() = HostTcpState::ListenReady(Ok(()));
socket.sender.send(()).unwrap();
socket.set_state_and_notify(HostTcpState::ListenReady(Ok(())));
});

#[cfg(not(unix))]
let new_join = spawn_blocking(move || {
let join = spawn_blocking(move || {
let result = match rustix::event::poll(
&mut [rustix::event::PollFd::new(
&socket.tcp_socket,
Expand All @@ -538,11 +544,16 @@ fn spawn_task_to_wait_for_connections(socket: Arc<HostTcpSocketInner>) -> JoinHa
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
};
*socket.tcp_state.write().unwrap() = HostTcpState::ListenReady(result);
socket.sender.send(()).unwrap();
socket.set_state_and_notify(HostTcpState::ListenReady(result));
});

new_join
join
}

/// Set `*tcp_state` to `new_state` and consume `tcp_state`.
fn set_state(tcp_state: RwLockWriteGuard<HostTcpState>, new_state: HostTcpState) {
let mut tcp_state = tcp_state;
*tcp_state = new_state;
}

/// Given a future, return the finished value if it's already ready, or
Expand All @@ -553,16 +564,28 @@ fn maybe_unwrap_future<F: std::future::Future + std::marker::Unpin>(
use std::ptr;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

unsafe fn clone(_ptr: *const ()) -> RawWaker {
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
RawWaker::new(std::ptr::null(), &VTABLE)
// Create a no-op Waker. This is derived from [code in std] and can
// be replaced with `std::task::Waker::noop()` when the "noop_waker"
// feature is stablized.
//
// [code in std]: https://github.com/rust-lang/rust/blob/27fb598d51d4566a725e4868eaf5d2e15775193e/library/core/src/task/wake.rs#L349
fn noop_waker() -> Waker {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// Cloning just returns a new no-op raw waker
|_| RAW,
// `wake` does nothing
|_| {},
// `wake_by_ref` does nothing
|_| {},
// Dropping does nothing as we don't allocate anything
|_| {},
);
const RAW: RawWaker = RawWaker::new(ptr::null(), &VTABLE);

unsafe { Waker::from_raw(RAW) }
}
unsafe fn wake(_ptr: *const ()) {}
unsafe fn wake_by_ref(_ptr: *const ()) {}
unsafe fn drop(_ptr: *const ()) {}

let waker = unsafe { Waker::from_raw(clone(ptr::null() as _)) };

let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match future.as_mut().poll(&mut cx) {
Poll::Ready(val) => Some(val),
Expand Down
29 changes: 28 additions & 1 deletion crates/wasi/src/preview2/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cap_std::net::{TcpListener, TcpStream};
use io_lifetimes::AsSocketlike;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use system_interface::io::IoExt;
use tokio::sync::watch::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -128,9 +128,23 @@ impl HostTcpSocket {
self.inner.tcp_socket()
}

pub fn notify(&self) {
self.inner.notify()
}

pub fn clone_inner(&self) -> Arc<HostTcpSocketInner> {
Arc::clone(&self.inner)
}

/// Acquire a reader lock for `self.tcp_state`.
pub fn tcp_state_read_lock(&self) -> RwLockReadGuard<HostTcpState> {
self.inner.tcp_state.read().unwrap()
}

/// Acquire a writer lock for `self.tcp_state`.
pub fn tcp_state_write_lock(&self) -> RwLockWriteGuard<HostTcpState> {
self.inner.tcp_state.write().unwrap()
}
}

impl HostTcpSocketInner {
Expand All @@ -144,6 +158,19 @@ impl HostTcpSocketInner {
tcp_socket
}

pub fn notify(&self) {
self.sender.send(()).unwrap()
}

pub fn set_state(&self, new_state: HostTcpState) {
*self.tcp_state.write().unwrap() = new_state;
}

pub fn set_state_and_notify(&self, new_state: HostTcpState) {
self.set_state(new_state);
self.notify()
}

/// Spawn a task on tokio's blocking thread for performing blocking
/// syscalls on the underlying [`cap_std::net::TcpListener`].
#[cfg(not(unix))]
Expand Down

0 comments on commit 42403be

Please sign in to comment.