Skip to content

Commit

Permalink
Handle zero-length reads and writes, and other cleanups.
Browse files Browse the repository at this point in the history
  • Loading branch information
sunfishcode committed Aug 17, 2023
1 parent 42403be commit 37b0521
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 16 deletions.
21 changes: 17 additions & 4 deletions crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,13 @@ fn main() {
wait(client_sub);
let (client_input, client_output) = tcp::finish_connect(client).unwrap();

let (n, _status) = streams::write(client_output, first_message).unwrap();
let (n, status) = streams::write(client_output, &[]).unwrap();
assert_eq!(n, 0);
assert_eq!(status, streams::StreamStatus::Open);

let (n, status) = streams::write(client_output, first_message).unwrap();
assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice.
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(client_input);
streams::drop_output_stream(client_output);
Expand All @@ -57,7 +62,13 @@ fn main() {

wait(sub);
let (accepted, input, output) = tcp::accept(sock).unwrap();
let (data, _status) = streams::read(input, first_message.len() as u64).unwrap();

let (empty_data, status) = streams::read(input, 0).unwrap();
assert!(empty_data.is_empty());
assert_eq!(status, streams::StreamStatus::Open);

let (data, status) = streams::read(input, first_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);

tcp::drop_tcp_socket(accepted);
streams::drop_input_stream(input);
Expand All @@ -74,8 +85,9 @@ fn main() {
wait(client_sub);
let (client_input, client_output) = tcp::finish_connect(client).unwrap();

let (n, _status) = streams::write(client_output, second_message).unwrap();
let (n, status) = streams::write(client_output, second_message).unwrap();
assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice.
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(client_input);
streams::drop_output_stream(client_output);
Expand All @@ -84,7 +96,8 @@ fn main() {

wait(sub);
let (accepted, input, output) = tcp::accept(sock).unwrap();
let (data, _status) = streams::read(input, second_message.len() as u64).unwrap();
let (data, status) = streams::read(input, second_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(input);
streams::drop_output_stream(output);
Expand Down
21 changes: 17 additions & 4 deletions crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ fn main() {
wait(client_sub);
let (client_input, client_output) = tcp::finish_connect(client).unwrap();

let (n, _status) = streams::write(client_output, first_message).unwrap();
let (n, status) = streams::write(client_output, &[]).unwrap();
assert_eq!(n, 0);
assert_eq!(status, streams::StreamStatus::Open);

let (n, status) = streams::write(client_output, first_message).unwrap();
assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice.
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(client_input);
streams::drop_output_stream(client_output);
Expand All @@ -59,7 +64,13 @@ fn main() {

wait(sub);
let (accepted, input, output) = tcp::accept(sock).unwrap();
let (data, _status) = streams::read(input, first_message.len() as u64).unwrap();

let (empty_data, status) = streams::read(input, 0).unwrap();
assert!(empty_data.is_empty());
assert_eq!(status, streams::StreamStatus::Open);

let (data, status) = streams::read(input, first_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);

tcp::drop_tcp_socket(accepted);
streams::drop_input_stream(input);
Expand All @@ -76,8 +87,9 @@ fn main() {
wait(client_sub);
let (client_input, client_output) = tcp::finish_connect(client).unwrap();

let (n, _status) = streams::write(client_output, second_message).unwrap();
let (n, status) = streams::write(client_output, second_message).unwrap();
assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice.
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(client_input);
streams::drop_output_stream(client_output);
Expand All @@ -86,7 +98,8 @@ fn main() {

wait(sub);
let (accepted, input, output) = tcp::accept(sock).unwrap();
let (data, _status) = streams::read(input, second_message.len() as u64).unwrap();
let (data, status) = streams::read(input, second_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);

streams::drop_input_stream(input);
streams::drop_output_stream(output);
Expand Down
14 changes: 6 additions & 8 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl<T: WasiView> tcp::Host for T {
let network = table.get_network(network)?;
let binder = network.0.tcp_binder(local_address)?;

// Perform the OS bind call.
binder.bind_existing_tcp_listener(socket.tcp_socket())?;

set_state(tcp_state, HostTcpState::BindStarted);
Expand Down Expand Up @@ -84,7 +85,7 @@ impl<T: WasiView> tcp::Host for T {
let network = table.get_network(network)?;
let connecter = network.0.tcp_connecter(remote_address)?;

// Do a host `connect`. Our socket is non-blocking, so it'll either...
// Do an OS `connect`. Our socket is non-blocking, so it'll either...
match connecter.connect_existing_tcp_listener(socket.tcp_socket()) {
// succeed immediately,
Ok(()) => {
Expand Down Expand Up @@ -262,7 +263,7 @@ impl<T: WasiView> tcp::Host for T {
HostTcpState::Listening(Pin::from(Box::new(new_join))),
);

// Do the host system call.
// Do the OS accept call.
let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?;
let tcp_socket = HostTcpSocket::from_tcp_stream(connection)?;

Expand All @@ -282,8 +283,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;
let addr = socket
.inner
.tcp_socket
.tcp_socket()
.as_socketlike_view::<std::net::TcpStream>()
.local_addr()?;
Ok(addr.into())
Expand All @@ -293,8 +293,7 @@ impl<T: WasiView> tcp::Host for T {
let table = self.table();
let socket = table.get_tcp_socket(this)?;
let addr = socket
.inner
.tcp_socket
.tcp_socket()
.as_socketlike_view::<std::net::TcpStream>()
.peer_addr()?;
Ok(addr.into())
Expand Down Expand Up @@ -498,8 +497,7 @@ impl<T: WasiView> tcp::Host for T {
};

socket
.inner
.tcp_socket
.tcp_socket()
.as_socketlike_view::<std::net::TcpStream>()
.shutdown(how)?;
Ok(())
Expand Down
13 changes: 13 additions & 0 deletions crates/wasi/src/preview2/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ impl HostTcpSocketInner {
#[async_trait::async_trait]
impl HostInputStream for Arc<HostTcpSocketInner> {
fn read(&mut self, size: usize) -> anyhow::Result<(Bytes, StreamState)> {
if size == 0 {
return Ok((Bytes::new(), StreamState::Open));
}
let mut buf = BytesMut::zeroed(size);
let r = self
.tcp_socket()
Expand Down Expand Up @@ -228,6 +231,9 @@ impl HostInputStream for Arc<HostTcpSocketInner> {
#[async_trait::async_trait]
impl HostOutputStream for Arc<HostTcpSocketInner> {
fn write(&mut self, buf: Bytes) -> anyhow::Result<(usize, StreamState)> {
if buf.is_empty() {
return Ok((0, StreamState::Open));
}
let r = self
.tcp_socket
.as_socketlike_view::<TcpStream>()
Expand Down Expand Up @@ -313,8 +319,15 @@ pub(crate) fn read_result(r: io::Result<usize>) -> io::Result<(usize, StreamStat

pub(crate) fn write_result(r: io::Result<usize>) -> io::Result<(usize, StreamState)> {
match r {
// We special-case zero-write stores ourselves, so if we get a zero
// back from a `write`, it means the stream is closed on some
// platforms.
Ok(0) => Ok((0, StreamState::Closed)),
Ok(n) => Ok((n, StreamState::Open)),
#[cfg(not(windows))]
Err(e) if e.raw_os_error() == Some(rustix::io::Errno::PIPE.raw_os_error()) => {
Ok((0, StreamState::Closed))
}
Err(e) => Err(e),
}
}

0 comments on commit 37b0521

Please sign in to comment.