diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs index f0e971ac1252..907b8986b7de 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs @@ -47,8 +47,13 @@ fn main() { wait(client_sub); let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - let (n, _status) = streams::write(client_output, first_message).unwrap(); + let (n, status) = streams::write(client_output, &[]).unwrap(); + assert_eq!(n, 0); + assert_eq!(status, streams::StreamStatus::Open); + + let (n, status) = streams::write(client_output, first_message).unwrap(); assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice. + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(client_input); streams::drop_output_stream(client_output); @@ -57,7 +62,13 @@ fn main() { wait(sub); let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, _status) = streams::read(input, first_message.len() as u64).unwrap(); + + let (empty_data, status) = streams::read(input, 0).unwrap(); + assert!(empty_data.is_empty()); + assert_eq!(status, streams::StreamStatus::Open); + + let (data, status) = streams::read(input, first_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); tcp::drop_tcp_socket(accepted); streams::drop_input_stream(input); @@ -74,8 +85,9 @@ fn main() { wait(client_sub); let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - let (n, _status) = streams::write(client_output, second_message).unwrap(); + let (n, status) = streams::write(client_output, second_message).unwrap(); assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice. + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(client_input); streams::drop_output_stream(client_output); @@ -84,7 +96,8 @@ fn main() { wait(sub); let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, _status) = streams::read(input, second_message.len() as u64).unwrap(); + let (data, status) = streams::read(input, second_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(input); streams::drop_output_stream(output); diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs index d3db94ddcef2..47a569aed30c 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs @@ -49,8 +49,13 @@ fn main() { wait(client_sub); let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - let (n, _status) = streams::write(client_output, first_message).unwrap(); + let (n, status) = streams::write(client_output, &[]).unwrap(); + assert_eq!(n, 0); + assert_eq!(status, streams::StreamStatus::Open); + + let (n, status) = streams::write(client_output, first_message).unwrap(); assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice. + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(client_input); streams::drop_output_stream(client_output); @@ -59,7 +64,13 @@ fn main() { wait(sub); let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, _status) = streams::read(input, first_message.len() as u64).unwrap(); + + let (empty_data, status) = streams::read(input, 0).unwrap(); + assert!(empty_data.is_empty()); + assert_eq!(status, streams::StreamStatus::Open); + + let (data, status) = streams::read(input, first_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); tcp::drop_tcp_socket(accepted); streams::drop_input_stream(input); @@ -76,8 +87,9 @@ fn main() { wait(client_sub); let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - let (n, _status) = streams::write(client_output, second_message).unwrap(); + let (n, status) = streams::write(client_output, second_message).unwrap(); assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice. + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(client_input); streams::drop_output_stream(client_output); @@ -86,7 +98,8 @@ fn main() { wait(sub); let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, _status) = streams::read(input, second_message.len() as u64).unwrap(); + let (data, status) = streams::read(input, second_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); streams::drop_input_stream(input); streams::drop_output_stream(output); diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index a757a1c143ae..8d65663eabf3 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -42,6 +42,7 @@ impl tcp::Host for T { let network = table.get_network(network)?; let binder = network.0.tcp_binder(local_address)?; + // Perform the OS bind call. binder.bind_existing_tcp_listener(socket.tcp_socket())?; set_state(tcp_state, HostTcpState::BindStarted); @@ -84,7 +85,7 @@ impl tcp::Host for T { let network = table.get_network(network)?; let connecter = network.0.tcp_connecter(remote_address)?; - // Do a host `connect`. Our socket is non-blocking, so it'll either... + // Do an OS `connect`. Our socket is non-blocking, so it'll either... match connecter.connect_existing_tcp_listener(socket.tcp_socket()) { // succeed immediately, Ok(()) => { @@ -262,7 +263,7 @@ impl tcp::Host for T { HostTcpState::Listening(Pin::from(Box::new(new_join))), ); - // Do the host system call. + // Do the OS accept call. let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?; let tcp_socket = HostTcpSocket::from_tcp_stream(connection)?; @@ -282,8 +283,7 @@ impl tcp::Host for T { let table = self.table(); let socket = table.get_tcp_socket(this)?; let addr = socket - .inner - .tcp_socket + .tcp_socket() .as_socketlike_view::() .local_addr()?; Ok(addr.into()) @@ -293,8 +293,7 @@ impl tcp::Host for T { let table = self.table(); let socket = table.get_tcp_socket(this)?; let addr = socket - .inner - .tcp_socket + .tcp_socket() .as_socketlike_view::() .peer_addr()?; Ok(addr.into()) @@ -498,8 +497,7 @@ impl tcp::Host for T { }; socket - .inner - .tcp_socket + .tcp_socket() .as_socketlike_view::() .shutdown(how)?; Ok(()) diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index 92c812bbea28..f4cfc80ce479 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -189,6 +189,9 @@ impl HostTcpSocketInner { #[async_trait::async_trait] impl HostInputStream for Arc { fn read(&mut self, size: usize) -> anyhow::Result<(Bytes, StreamState)> { + if size == 0 { + return Ok((Bytes::new(), StreamState::Open)); + } let mut buf = BytesMut::zeroed(size); let r = self .tcp_socket() @@ -228,6 +231,9 @@ impl HostInputStream for Arc { #[async_trait::async_trait] impl HostOutputStream for Arc { fn write(&mut self, buf: Bytes) -> anyhow::Result<(usize, StreamState)> { + if buf.is_empty() { + return Ok((0, StreamState::Open)); + } let r = self .tcp_socket .as_socketlike_view::() @@ -313,8 +319,15 @@ pub(crate) fn read_result(r: io::Result) -> io::Result<(usize, StreamStat pub(crate) fn write_result(r: io::Result) -> io::Result<(usize, StreamState)> { match r { + // We special-case zero-write stores ourselves, so if we get a zero + // back from a `write`, it means the stream is closed on some + // platforms. Ok(0) => Ok((0, StreamState::Closed)), Ok(n) => Ok((n, StreamState::Open)), + #[cfg(not(windows))] + Err(e) if e.raw_os_error() == Some(rustix::io::Errno::PIPE.raw_os_error()) => { + Ok((0, StreamState::Closed)) + } Err(e) => Err(e), } }