diff --git a/crates/test-programs/tests/wasi-sockets.rs b/crates/test-programs/tests/wasi-sockets.rs index 8484a4f98f77..3ff9367d610c 100644 --- a/crates/test-programs/tests/wasi-sockets.rs +++ b/crates/test-programs/tests/wasi-sockets.rs @@ -76,6 +76,16 @@ async fn tcp_v6() { run("tcp_v6").await.unwrap(); } +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn udp_v4() { + run("udp_v4").await.unwrap(); +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn udp_v6() { + run("udp_v6").await.unwrap(); +} + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn ip_name_lookup() { run("ip_name_lookup").await.unwrap(); diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs index cf02fdd79663..21daef42691b 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs @@ -24,5 +24,5 @@ fn main() { sock.finish_bind().unwrap(); - example_body(net, sock, IpAddressFamily::Ipv4) + example_body_tcp(net, sock, IpAddressFamily::Ipv4) } diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs index 807db9825f1e..696e6cc0f5b8 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs @@ -1,4 +1,4 @@ -//! Like v4.rs, but with IPv6. +//! Like tcp_v4.rs, but with IPv6. use wasi::io::poll; use wasi::sockets::network::{IpAddressFamily, IpSocketAddress, Ipv6SocketAddress}; @@ -26,5 +26,5 @@ fn main() { sock.finish_bind().unwrap(); - example_body(net, sock, IpAddressFamily::Ipv6) + example_body_tcp(net, sock, IpAddressFamily::Ipv6) } diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/udp_v4.rs b/crates/test-programs/wasi-sockets-tests/src/bin/udp_v4.rs new file mode 100644 index 000000000000..86d606cad96f --- /dev/null +++ b/crates/test-programs/wasi-sockets-tests/src/bin/udp_v4.rs @@ -0,0 +1,28 @@ +//! A simple UDP testcase, using IPv4. + +use wasi::io::poll; +use wasi::sockets::network::{IpAddressFamily, IpSocketAddress, Ipv4SocketAddress}; +use wasi::sockets::{instance_network, udp_create_socket}; +use wasi_sockets_tests::*; + +fn main() { + let net = instance_network::instance_network(); + + let sock = udp_create_socket::create_udp_socket(IpAddressFamily::Ipv4).unwrap(); + + let addr = IpSocketAddress::Ipv4(Ipv4SocketAddress { + port: 0, // use any free port + address: (127, 0, 0, 1), // localhost + }); + + let sub = sock.subscribe(); + + sock.start_bind(&net, addr).unwrap(); + + poll::poll_one(&sub); + drop(sub); + + sock.finish_bind().unwrap(); + + example_body_udp(net, sock, IpAddressFamily::Ipv4) +} diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/udp_v6.rs b/crates/test-programs/wasi-sockets-tests/src/bin/udp_v6.rs new file mode 100644 index 000000000000..58d455997283 --- /dev/null +++ b/crates/test-programs/wasi-sockets-tests/src/bin/udp_v6.rs @@ -0,0 +1,30 @@ +//! Like udp_v4.rs, but with IPv6. + +use wasi::io::poll; +use wasi::sockets::network::{IpAddressFamily, IpSocketAddress, Ipv6SocketAddress}; +use wasi::sockets::{instance_network, udp_create_socket}; +use wasi_sockets_tests::*; + +fn main() { + let net = instance_network::instance_network(); + + let sock = udp_create_socket::create_udp_socket(IpAddressFamily::Ipv6).unwrap(); + + let addr = IpSocketAddress::Ipv6(Ipv6SocketAddress { + port: 0, // use any free port + address: (0, 0, 0, 0, 0, 0, 0, 1), // localhost + flow_info: 0, + scope_id: 0, + }); + + let sub = sock.subscribe(); + + sock.start_bind(&net, addr).unwrap(); + + poll::poll_one(&sub); + drop(sub); + + sock.finish_bind().unwrap(); + + example_body_udp(net, sock, IpAddressFamily::Ipv6) +} diff --git a/crates/test-programs/wasi-sockets-tests/src/lib.rs b/crates/test-programs/wasi-sockets-tests/src/lib.rs index a46cff830c0f..c3ac2669917d 100644 --- a/crates/test-programs/wasi-sockets-tests/src/lib.rs +++ b/crates/test-programs/wasi-sockets-tests/src/lib.rs @@ -2,7 +2,7 @@ wit_bindgen::generate!("test-command-with-sockets" in "../../wasi/wit"); use wasi::io::poll; use wasi::io::streams; -use wasi::sockets::{network, tcp, tcp_create_socket}; +use wasi::sockets::{network, tcp, tcp_create_socket, udp, udp_create_socket}; pub fn write(output: &streams::OutputStream, mut bytes: &[u8]) -> Result<(), streams::StreamError> { let pollable = output.subscribe(); @@ -24,7 +24,7 @@ pub fn write(output: &streams::OutputStream, mut bytes: &[u8]) -> Result<(), str Ok(()) } -pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::IpAddressFamily) { +pub fn example_body_tcp(net: tcp::Network, sock: tcp::TcpSocket, family: network::IpAddressFamily) { let first_message = b"Hello, world!"; let second_message = b"Greetings, planet!"; @@ -95,3 +95,90 @@ pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::Ip // Check that we sent and recieved our message! assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. } + +pub fn example_body_udp(net: udp::Network, sock: udp::UdpSocket, family: network::IpAddressFamily) { + let first_message = b"Hello, world!"; + let second_message = b"Greetings, planet!"; + + let sub = sock.subscribe(); + + let addr = sock.local_address().unwrap(); + + let client = udp_create_socket::create_udp_socket(family).unwrap(); + let client_sub = client.subscribe(); + + client.start_connect(&net, addr).unwrap(); + poll::poll_one(&client_sub); + client.finish_connect().unwrap(); + + let _client_addr = client.local_address().unwrap(); + + let n = client + .send(&[ + udp::Datagram { + data: vec![], + remote_address: addr, + }, + udp::Datagram { + data: first_message.to_vec(), + remote_address: addr, + }, + ]) + .unwrap(); + assert_eq!(n, 2); + + drop(client_sub); + drop(client); + + poll::poll_one(&sub); + let datagrams = sock.receive(2).unwrap(); + let mut datagrams = datagrams.into_iter(); + let (first, second) = match (datagrams.next(), datagrams.next(), datagrams.next()) { + (Some(first), Some(second), None) => (first, second), + (Some(_first), None, None) => panic!("only one datagram received"), + (None, None, None) => panic!("no datagrams received"), + _ => panic!("invalid datagram sequence received"), + }; + + assert!(first.data.is_empty()); + + // TODO: Verify the `remote_address` + //assert_eq!(first.remote_address, client_addr); + + // Check that we sent and recieved our message! + assert_eq!(second.data, first_message); // Not guaranteed to work but should work in practice. + + // TODO: Verify the `remote_address` + //assert_eq!(second.remote_address, client_addr); + + // Another client + let client = udp_create_socket::create_udp_socket(family).unwrap(); + let client_sub = client.subscribe(); + + client.start_connect(&net, addr).unwrap(); + poll::poll_one(&client_sub); + client.finish_connect().unwrap(); + + let n = client + .send(&[udp::Datagram { + data: second_message.to_vec(), + remote_address: addr, + }]) + .unwrap(); + assert_eq!(n, 1); + + drop(client_sub); + drop(client); + + poll::poll_one(&sub); + let datagrams = sock.receive(2).unwrap(); + let mut datagrams = datagrams.into_iter(); + let first = match (datagrams.next(), datagrams.next()) { + (Some(first), None) => first, + (None, None) => panic!("no datagrams received"), + _ => panic!("invalid datagram sequence received"), + }; + + // Check that we sent and recieved our message! + assert_eq!(first.data, second_message); // Not guaranteed to work but should work in practice. +} diff --git a/crates/wasi/src/preview2/command.rs b/crates/wasi/src/preview2/command.rs index 811e3cf18e2c..898311157354 100644 --- a/crates/wasi/src/preview2/command.rs +++ b/crates/wasi/src/preview2/command.rs @@ -48,6 +48,8 @@ pub fn add_to_linker(l: &mut wasmtime::component::Linker) -> any crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?; @@ -65,6 +67,7 @@ pub mod sync { "wasi:filesystem/types": crate::preview2::bindings::sync_io::filesystem::types, "wasi:filesystem/preopens": crate::preview2::bindings::filesystem::preopens, "wasi:sockets/tcp": crate::preview2::bindings::sockets::tcp, + "wasi:sockets/udp": crate::preview2::bindings::sockets::udp, "wasi:clocks/monotonic_clock": crate::preview2::bindings::clocks::monotonic_clock, "wasi:io/poll": crate::preview2::bindings::sync_io::io::poll, "wasi:io/streams": crate::preview2::bindings::sync_io::io::streams, @@ -107,6 +110,8 @@ pub mod sync { crate::preview2::bindings::cli::terminal_stderr::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::tcp_create_socket::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp::add_to_linker(l, |t| t)?; + crate::preview2::bindings::sockets::udp_create_socket::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::instance_network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::network::add_to_linker(l, |t| t)?; crate::preview2::bindings::sockets::ip_name_lookup::add_to_linker(l, |t| t)?; diff --git a/crates/wasi/src/preview2/host/mod.rs b/crates/wasi/src/preview2/host/mod.rs index 138166731565..651d2cd38e0c 100644 --- a/crates/wasi/src/preview2/host/mod.rs +++ b/crates/wasi/src/preview2/host/mod.rs @@ -8,3 +8,5 @@ mod network; mod random; mod tcp; mod tcp_create_socket; +mod udp; +mod udp_create_socket; diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs new file mode 100644 index 000000000000..38dadc938808 --- /dev/null +++ b/crates/wasi/src/preview2/host/udp.rs @@ -0,0 +1,488 @@ +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; + +use crate::preview2::{ + bindings::{ + sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}, + sockets::udp, + }, + udp::UdpState, + Table, +}; +use crate::preview2::{Pollable, SocketResult, WasiView}; +use cap_net_ext::PoolExt; +use io_lifetimes::AsSocketlike; +use rustix::io::Errno; +use rustix::net::sockopt; +use wasmtime::component::Resource; + +/// Theoretical maximum byte size of a UDP datagram, the real limit is lower, +/// but we do not account for e.g. the transport layer here for simplicity. +/// In practice, datagrams are typically less than 1500 bytes. +const MAX_UDP_DATAGRAM_SIZE: usize = 65535; + +fn start_bind( + table: &mut Table, + this: Resource, + network: Resource, + local_address: IpSocketAddress, +) -> SocketResult<()> { + let socket = table.get_resource(&this)?; + match socket.udp_state { + UdpState::Default => {} + _ => return Err(ErrorCode::NotInProgress.into()), + } + + let network = table.get_resource(&network)?; + let binder = network.pool.udp_binder(local_address)?; + + // Perform the OS bind call. + binder.bind_existing_udp_socket( + &*socket + .udp_socket() + .as_socketlike_view::(), + )?; + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::BindStarted; + + Ok(()) +} + +fn finish_bind(table: &mut Table, this: Resource) -> SocketResult<()> { + let socket = table.get_resource_mut(&this)?; + match socket.udp_state { + UdpState::BindStarted => {} + _ => return Err(ErrorCode::NotInProgress.into()), + } + + socket.udp_state = UdpState::Bound; + + Ok(()) +} + +fn address_family(table: &Table, this: Resource) -> SocketResult { + let socket = table.get_resource(&this)?; + + // If `SO_DOMAIN` is available, use it. + // + // TODO: OpenBSD also supports this; upstream PRs are posted. + #[cfg(not(any( + windows, + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + )))] + { + use rustix::net::AddressFamily; + + let family = sockopt::get_socket_domain(socket.udp_socket())?; + let family = match family { + AddressFamily::INET => IpAddressFamily::Ipv4, + AddressFamily::INET6 => IpAddressFamily::Ipv6, + _ => return Err(ErrorCode::NotSupported.into()), + }; + Ok(family) + } + + // When `SO_DOMAIN` is not available, emulate it. + #[cfg(any( + windows, + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + { + if let Ok(_) = sockopt::get_ipv6_unicast_hops(socket.udp_socket()) { + return Ok(IpAddressFamily::Ipv6); + } + if let Ok(_) = sockopt::get_ip_ttl(socket.udp_socket()) { + return Ok(IpAddressFamily::Ipv4); + } + Err(ErrorCode::NotSupported.into()) + } +} + +impl udp::Host for T {} + +impl crate::preview2::host::udp::udp::HostUdpSocket for T { + fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: IpSocketAddress, + ) -> SocketResult<()> { + start_bind(self.table_mut(), this, network, local_address) + } + + fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { + finish_bind(self.table_mut(), this) + } + + fn start_connect( + &mut self, + this: Resource, + network: Resource, + remote_address: IpSocketAddress, + ) -> SocketResult<()> { + let table = self.table_mut(); + let r = { + let socket = table.get_resource(&this)?; + match socket.udp_state { + UdpState::Default => { + let family = address_family(table, Resource::new_borrow(this.rep()))?; + let addr = match family { + IpAddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(), + IpAddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(), + }; + start_bind( + table, + Resource::new_borrow(this.rep()), + Resource::new_borrow(network.rep()), + SocketAddr::new(addr, 0).into(), + )?; + finish_bind(table, Resource::new_borrow(this.rep()))?; + } + UdpState::BindStarted => { + finish_bind(table, Resource::new_borrow(this.rep()))?; + } + UdpState::Bound => {} + UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()), + _ => return Err(ErrorCode::NotInProgress.into()), + } + + let socket = table.get_resource(&this)?; + let network = table.get_resource(&network)?; + let connecter = network.pool.udp_connecter(remote_address)?; + + // Do an OS `connect`. Our socket is non-blocking, so it'll either... + { + let view = &*socket + .udp_socket() + .as_socketlike_view::(); + let r = connecter.connect_existing_udp_socket(view); + r + } + }; + + match r { + // succeed immediately, + Ok(()) => { + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::ConnectReady; + return Ok(()); + } + // continue in progress, + Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {} + // or fail immediately. + Err(err) => return Err(err.into()), + } + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::Connecting; + + Ok(()) + } + + fn finish_connect(&mut self, this: Resource) -> SocketResult<()> { + let table = self.table_mut(); + let socket = table.get_resource_mut(&this)?; + + match socket.udp_state { + UdpState::ConnectReady => {} + UdpState::Connecting => { + // Do a `poll` to test for completion, using a timeout of zero + // to avoid blocking. + match rustix::event::poll( + &mut [rustix::event::PollFd::new( + socket.udp_socket(), + rustix::event::PollFlags::OUT, + )], + 0, + ) { + Ok(0) => return Err(ErrorCode::WouldBlock.into()), + Ok(_) => (), + Err(err) => Err(err).unwrap(), + } + + // Check whether the connect succeeded. + match sockopt::get_socket_error(socket.udp_socket()) { + Ok(Ok(())) => {} + Err(err) | Ok(Err(err)) => return Err(err.into()), + } + } + _ => return Err(ErrorCode::NotInProgress.into()), + }; + + socket.udp_state = UdpState::Connected; + Ok(()) + } + + fn receive( + &mut self, + this: Resource, + max_results: u64, + ) -> SocketResult> { + if max_results == 0 { + return Ok(vec![]); + } + + let table = self.table(); + let socket = table.get_resource(&this)?; + + let udp_socket = socket.udp_socket(); + let mut datagrams = Vec::with_capacity(max_results.try_into().unwrap_or(usize::MAX)); + let mut buf = [0; MAX_UDP_DATAGRAM_SIZE]; + match socket.udp_state { + UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()), + UdpState::Bound | UdpState::Connecting | UdpState::ConnectReady => { + for i in 0..max_results { + match udp_socket.try_recv_from(&mut buf) { + Ok((size, remote_address)) => datagrams.push(udp::Datagram { + data: buf[..size].into(), + remote_address: remote_address.into(), + }), + Err(_e) if i > 0 => { + return Ok(datagrams); + } + Err(e) => return Err(e.into()), + } + } + } + UdpState::Connected => { + let remote_address = udp_socket.peer_addr().map(Into::into)?; + for i in 0..max_results { + match udp_socket.try_recv(&mut buf) { + Ok(size) => datagrams.push(udp::Datagram { + data: buf[..size].into(), + remote_address, + }), + Err(_e) if i > 0 => { + return Ok(datagrams); + } + Err(e) => return Err(e.into()), + } + } + } + } + Ok(datagrams) + } + + fn send( + &mut self, + this: Resource, + datagrams: Vec, + ) -> SocketResult { + if datagrams.is_empty() { + return Ok(0); + }; + let table = self.table(); + let socket = table.get_resource(&this)?; + + let udp_socket = socket.udp_socket(); + let mut count = 0; + match socket.udp_state { + UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()), + UdpState::Bound | UdpState::Connecting | UdpState::ConnectReady => { + for udp::Datagram { + data, + remote_address, + } in datagrams + { + match udp_socket.try_send_to(&data, remote_address.into()) { + Ok(_size) => count += 1, + Err(_e) if count > 0 => { + return Ok(count); + } + Err(e) => return Err(e.into()), + } + } + } + UdpState::Connected => { + let peer_addr = udp_socket.peer_addr()?; + for udp::Datagram { + data, + remote_address, + } in datagrams + { + if SocketAddr::from(remote_address) != peer_addr { + // From WIT documentation: + // If at least one datagram has been sent successfully, this function never returns an error. + if count == 0 { + return Err(ErrorCode::AlreadyConnected.into()); + } else { + return Ok(count); + } + } + match udp_socket.try_send(&data) { + Ok(_size) => count += 1, + Err(_e) if count > 0 => { + return Ok(count); + } + Err(e) => return Err(e.into()), + } + } + } + } + Ok(count) + } + + fn local_address(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + let addr = socket + .udp_socket() + .as_socketlike_view::() + .local_addr()?; + Ok(addr.into()) + } + + fn remote_address(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + let addr = socket + .udp_socket() + .as_socketlike_view::() + .peer_addr()?; + Ok(addr.into()) + } + + fn address_family( + &mut self, + this: Resource, + ) -> Result { + let family = address_family(self.table(), this)?; + Ok(family) + } + + fn ipv6_only(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_ipv6_v6only(socket.udp_socket())?) + } + + fn set_ipv6_only(&mut self, this: Resource, value: bool) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::set_ipv6_v6only(socket.udp_socket(), value)?) + } + + fn unicast_hop_limit(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + + // We don't track whether the socket is IPv4 or IPv6 so try one and + // fall back to the other. + match sockopt::get_ipv6_unicast_hops(socket.udp_socket()) { + Ok(value) => Ok(value), + Err(Errno::NOPROTOOPT) => { + let value = sockopt::get_ip_ttl(socket.udp_socket())?; + let value = value.try_into().unwrap(); + Ok(value) + } + Err(err) => Err(err.into()), + } + } + + fn set_unicast_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + + // We don't track whether the socket is IPv4 or IPv6 so try one and + // fall back to the other. + match sockopt::set_ipv6_unicast_hops(socket.udp_socket(), Some(value)) { + Ok(()) => Ok(()), + Err(Errno::NOPROTOOPT) => Ok(sockopt::set_ip_ttl(socket.udp_socket(), value.into())?), + Err(err) => Err(err.into()), + } + } + + fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_socket_recv_buffer_size(socket.udp_socket())? as u64) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + let value = value.try_into().map_err(|_| ErrorCode::OutOfMemory)?; + Ok(sockopt::set_socket_recv_buffer_size( + socket.udp_socket(), + value, + )?) + } + + fn send_buffer_size(&mut self, this: Resource) -> SocketResult { + let table = self.table(); + let socket = table.get_resource(&this)?; + Ok(sockopt::get_socket_send_buffer_size(socket.udp_socket())? as u64) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> SocketResult<()> { + let table = self.table(); + let socket = table.get_resource(&this)?; + let value = value.try_into().map_err(|_| ErrorCode::OutOfMemory)?; + Ok(sockopt::set_socket_send_buffer_size( + socket.udp_socket(), + value, + )?) + } + + fn subscribe(&mut self, this: Resource) -> anyhow::Result> { + crate::preview2::poll::subscribe(self.table_mut(), this) + } + + fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { + let table = self.table_mut(); + + // As in the filesystem implementation, we assume closing a socket + // doesn't block. + let dropped = table.delete_resource(this)?; + + // If we might have an `event::poll` waiting on the socket, wake it up. + #[cfg(not(unix))] + { + match dropped.udp_state { + UdpState::Default + | UdpState::BindStarted + | UdpState::Bound + | UdpState::ConnectReady => {} + + UdpState::Connecting | UdpState::Connected => { + match rustix::net::shutdown(&*dropped.inner, rustix::net::Shutdown::ReadWrite) { + Ok(()) | Err(Errno::NOTCONN) => {} + Err(err) => Err(err).unwrap(), + } + } + } + } + + drop(dropped); + + Ok(()) + } +} + +// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`. +// +#[cfg(not(windows))] +const INPROGRESS: Errno = Errno::INPROGRESS; + +// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`. +// +#[cfg(windows)] +const INPROGRESS: Errno = Errno::WOULDBLOCK; diff --git a/crates/wasi/src/preview2/host/udp_create_socket.rs b/crates/wasi/src/preview2/host/udp_create_socket.rs new file mode 100644 index 000000000000..7e57e19d5297 --- /dev/null +++ b/crates/wasi/src/preview2/host/udp_create_socket.rs @@ -0,0 +1,15 @@ +use crate::preview2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket}; +use crate::preview2::udp::UdpSocket; +use crate::preview2::{SocketResult, WasiView}; +use wasmtime::component::Resource; + +impl udp_create_socket::Host for T { + fn create_udp_socket( + &mut self, + address_family: IpAddressFamily, + ) -> SocketResult> { + let socket = UdpSocket::new(address_family.into())?; + let socket = self.table_mut().push_resource(socket)?; + Ok(socket) + } +} diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index aecc414ea4b2..0f3059597dc0 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -36,6 +36,7 @@ mod stdio; mod stream; mod table; mod tcp; +mod udp; mod write_stream; pub use self::clocks::{HostMonotonicClock, HostWallClock}; @@ -160,6 +161,7 @@ pub mod bindings { with: { "wasi:sockets/network/network": super::network::Network, "wasi:sockets/tcp/tcp-socket": super::tcp::TcpSocket, + "wasi:sockets/udp/udp-socket": super::udp::UdpSocket, "wasi:filesystem/types/directory-entry-stream": super::filesystem::ReaddirIterator, "wasi:filesystem/types/descriptor": super::filesystem::Descriptor, "wasi:io/streams/input-stream": super::stream::InputStream, diff --git a/crates/wasi/src/preview2/udp.rs b/crates/wasi/src/preview2/udp.rs new file mode 100644 index 000000000000..b146a47ccd14 --- /dev/null +++ b/crates/wasi/src/preview2/udp.rs @@ -0,0 +1,87 @@ +use crate::preview2::poll::Subscribe; +use crate::preview2::with_ambient_tokio_runtime; +use async_trait::async_trait; +use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; +use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; +use std::io; +use std::sync::Arc; +use tokio::io::Interest; + +/// The state of a UDP socket. +/// +/// This represents the various states a socket can be in during the +/// activities of binding, and connecting. +pub(crate) enum UdpState { + /// The initial state for a newly-created socket. + Default, + + /// Binding started via `start_bind`. + BindStarted, + + /// Binding finished via `finish_bind`. The socket has an address but + /// is not yet listening for connections. + Bound, + + /// An outgoing connection is started via `start_connect`. + Connecting, + + /// An outgoing connection is ready to be established. + ConnectReady, + + /// An outgoing connection has been established. + Connected, +} + +/// A host UDP socket, plus associated bookkeeping. +/// +/// The inner state is wrapped in an Arc because the same underlying socket is +/// used for implementing the stream types. +pub struct UdpSocket { + /// The part of a `UdpSocket` which is reference-counted so that we + /// can pass it to async tasks. + pub(crate) inner: Arc, + + /// The current state in the bind/connect progression. + pub(crate) udp_state: UdpState, +} + +#[async_trait] +impl Subscribe for UdpSocket { + async fn ready(&mut self) { + // Some states are ready immediately. + match self.udp_state { + UdpState::BindStarted | UdpState::ConnectReady => return, + _ => {} + } + + // FIXME: Add `Interest::ERROR` when we update to tokio 1.32. + self.inner + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .expect("failed to await UDP socket readiness"); + } +} + +impl UdpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> io::Result { + // Create a new host socket and set it to non-blocking, which is needed + // by our async implementation. + let udp_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?; + Self::from_udp_socket(udp_socket) + } + + pub fn from_udp_socket(udp_socket: cap_std::net::UdpSocket) -> io::Result { + let fd = udp_socket.into_raw_socketlike(); + let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) }; + let socket = with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?; + Ok(Self { + inner: Arc::new(socket), + udp_state: UdpState::Default, + }) + } + + pub fn udp_socket(&self) -> &tokio::net::UdpSocket { + &self.inner + } +} diff --git a/crates/wasi/wit/test.wit b/crates/wasi/wit/test.wit index 03073513f8e6..500fd92ae84d 100644 --- a/crates/wasi/wit/test.wit +++ b/crates/wasi/wit/test.wit @@ -37,6 +37,8 @@ world test-command-with-sockets { import wasi:cli/stderr import wasi:sockets/tcp import wasi:sockets/tcp-create-socket + import wasi:sockets/udp + import wasi:sockets/udp-create-socket import wasi:sockets/network import wasi:sockets/instance-network import wasi:sockets/ip-name-lookup