diff --git a/crates/wasi/src/preview2/host/network.rs b/crates/wasi/src/preview2/host/network.rs index 55dedb1b145d..e7e025e8eb90 100644 --- a/crates/wasi/src/preview2/host/network.rs +++ b/crates/wasi/src/preview2/host/network.rs @@ -218,9 +218,9 @@ pub(crate) mod util { use crate::preview2::bindings::sockets::network::ErrorCode; use crate::preview2::network::SocketAddressFamily; use crate::preview2::SocketResult; - use cap_net_ext::{Blocking, TcpListenerExt}; - use cap_std::net::{TcpListener, TcpStream, UdpSocket}; - use rustix::fd::AsFd; + use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt, UdpSocketExt}; + use io_lifetimes::AsSocketlike; + use rustix::fd::{AsFd, OwnedFd}; use rustix::io::Errno; use rustix::net::sockopt; @@ -302,42 +302,64 @@ pub(crate) mod util { * Syscalls wrappers with (opinionated) portability fixes. */ - pub fn tcp_bind(listener: &TcpListener, addr: &SocketAddr) -> std::io::Result<()> { - rustix::net::bind(listener, addr).map_err(|error| match error { + pub fn tcp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result { + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: + // - On Windows: call WSAStartup if not done before. + // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, + // or afterwards using ioctl or fcntl. Exact method depends on the platform. + + let listener = cap_std::net::TcpListener::new(family, blocking)?; + Ok(OwnedFd::from(listener)) + } + + pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result { + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: + // - On Windows: call WSAStartup if not done before. + // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, + // or afterwards using ioctl or fcntl. Exact method depends on the platform. + + let socket = cap_std::net::UdpSocket::new(family, blocking)?; + Ok(OwnedFd::from(socket)) + } + + pub fn tcp_bind(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> { + rustix::net::bind(sockfd, addr).map_err(|error| match error { // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. #[cfg(windows)] - Errno::NOBUFS => Errno::ADDRINUSE.into(), - _ => error.into(), + Errno::NOBUFS => Errno::ADDRINUSE, + _ => error, }) } - pub fn udp_bind(socket: &UdpSocket, addr: &SocketAddr) -> std::io::Result<()> { - rustix::net::bind(socket, addr).map_err(|error| { - match error { - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS - // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. - #[cfg(windows)] - Errno::NOBUFS => Errno::ADDRINUSE.into(), - _ => error.into(), - } + pub fn udp_bind(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> { + rustix::net::bind(sockfd, addr).map_err(|error| match error { + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS + // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. + #[cfg(windows)] + Errno::NOBUFS => Errno::ADDRINUSE, + _ => error, }) } - pub fn tcp_connect(listener: &TcpListener, addr: &SocketAddr) -> std::io::Result<()> { - rustix::net::connect(listener, addr).map_err(|error| match error { + pub fn tcp_connect(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> { + rustix::net::connect(sockfd, addr).map_err(|error| match error { // On POSIX, non-blocking `connect` returns `EINPROGRESS`. // Windows returns `WSAEWOULDBLOCK`. // // This normalized error code is depended upon by: tcp.rs #[cfg(windows)] - Errno::WOULDBLOCK => Errno::INPROGRESS.into(), - _ => error.into(), + Errno::WOULDBLOCK => Errno::INPROGRESS, + _ => error, }) } - pub fn tcp_listen(listener: &TcpListener, backlog: Option) -> std::io::Result<()> { - listener + pub fn tcp_listen(sockfd: Fd, backlog: Option) -> std::io::Result<()> { + // Delegate `listen` to cap_net_ext. That is a thin wrapper around rustix::net::listen, + // with a platform-dependent default value for the backlog size. + sockfd + .as_fd() + .as_socketlike_view::() .listen(backlog) .map_err(|error| match Errno::from_io_error(&error) { // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE @@ -351,15 +373,22 @@ pub(crate) mod util { // on Microsoft's documentation here. #[cfg(windows)] Some(Errno::MFILE) => Errno::NOBUFS.into(), + _ => error, }) } - pub fn tcp_accept( - listener: &TcpListener, + pub fn tcp_accept( + sockfd: Fd, blocking: Blocking, - ) -> std::io::Result<(TcpStream, SocketAddr)> { - listener + ) -> std::io::Result<(OwnedFd, SocketAddr)> { + // Delegate `accept` to cap_net_ext. They set the NONBLOCK and CLOEXEC flags + // for us. Either immediately as a flag to `accept`, or afterwards using + // ioctl or fcntl. Exact method depends on the platform. + + let (client, addr) = sockfd + .as_fd() + .as_socketlike_view::() .accept_with(blocking) .map_err(|error| match Errno::from_io_error(&error) { // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS @@ -393,7 +422,9 @@ pub(crate) mod util { ) => Errno::CONNABORTED.into(), _ => error, - }) + })?; + + Ok((client.into(), addr)) } pub fn udp_disconnect(sockfd: Fd) -> rustix::io::Result<()> { diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index b1d4ba31f28c..1fe8bc5a6771 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -11,7 +11,6 @@ use crate::preview2::{ }; use crate::preview2::{Pollable, SocketResult, WasiView}; use cap_net_ext::Blocking; -use cap_std::net::TcpListener; use io_lifetimes::AsSocketlike; use rustix::io::Errno; use rustix::net::sockopt; @@ -47,21 +46,18 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { { // Ensure that we're allowed to connect to this address. network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?; - let listener = &*socket.tcp_socket().as_socketlike_view::(); // Perform the OS bind call. - util::tcp_bind(listener, &local_address).map_err(|error| { - match Errno::from_io_error(&error) { - // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: - // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket - // - // The most common reasons for this error should have already - // been handled by our own validation slightly higher up in this - // function. This error mapping is here just in case there is - // an edge case we didn't catch. - Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument, - _ => ErrorCode::from(error), - } + util::tcp_bind(socket.tcp_socket(), &local_address).map_err(|error| match error { + // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: + // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket + // + // The most common reasons for this error should have already + // been handled by our own validation slightly higher up in this + // function. This error mapping is here just in case there is + // an edge case we didn't catch. + Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, + _ => ErrorCode::from(error), })?; } @@ -116,10 +112,9 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { // Ensure that we're allowed to connect to this address. network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?; - let listener = &*socket.tcp_socket().as_socketlike_view::(); // Do an OS `connect`. Our socket is non-blocking, so it'll either... - util::tcp_connect(listener, &remote_address) + util::tcp_connect(socket.tcp_socket(), &remote_address) }; match r { @@ -130,11 +125,11 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { return Ok(()); } // continue in progress, - Err(err) if Errno::from_io_error(&err) == Some(Errno::INPROGRESS) => {} + Err(err) if err == Errno::INPROGRESS => {} // or fail immediately. Err(err) => { - return Err(match Errno::from_io_error(&err) { - Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument.into(), // See `bind` implementation. + return Err(match err { + Errno::AFNOSUPPORT => ErrorCode::InvalidArgument.into(), // See `bind` implementation. _ => err.into(), }); } @@ -207,10 +202,7 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { | TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), } - { - let listener = &*socket.tcp_socket().as_socketlike_view::(); - util::tcp_listen(listener, socket.listen_backlog_size)?; - } + util::tcp_listen(socket.tcp_socket(), socket.listen_backlog_size)?; socket.tcp_state = TcpState::ListenStarted; @@ -250,9 +242,8 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { // Do the OS accept call. let tcp_socket = socket.tcp_socket(); - let (connection, _addr) = tcp_socket.try_io(Interest::READABLE, || { - let listener = &*tcp_socket.as_socketlike_view::(); - util::tcp_accept(listener, Blocking::No) + let (client_fd, _addr) = tcp_socket.try_io(Interest::READABLE, || { + util::tcp_accept(tcp_socket, Blocking::No) })?; #[cfg(target_os = "macos")] @@ -262,25 +253,25 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { // and only if a specific value was explicitly set on the listener. if let Some(size) = socket.receive_buffer_size { - _ = util::set_socket_recv_buffer_size(&connection, size); // Ignore potential error. + _ = util::set_socket_recv_buffer_size(&client_fd, size); // Ignore potential error. } if let Some(size) = socket.send_buffer_size { - _ = util::set_socket_send_buffer_size(&connection, size); // Ignore potential error. + _ = util::set_socket_send_buffer_size(&client_fd, size); // Ignore potential error. } // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. if let (SocketAddressFamily::Ipv6 { .. }, Some(ttl)) = (socket.family, socket.hop_limit) { - _ = util::set_ipv6_unicast_hops(&connection, ttl); // Ignore potential error. + _ = util::set_ipv6_unicast_hops(&client_fd, ttl); // Ignore potential error. } if let Some(value) = socket.keep_alive_idle_time { - _ = util::set_tcp_keepidle(&connection, value); // Ignore potential error. + _ = util::set_tcp_keepidle(&client_fd, value); // Ignore potential error. } } - let mut tcp_socket = TcpSocket::from_tcp_stream(connection, socket.family)?; + let mut tcp_socket = TcpSocket::from_fd(client_fd, socket.family)?; // Mark the socket as connected so that we can exit early from methods like `start-bind`. tcp_socket.tcp_state = TcpState::Connected; diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs index 09ca8b5b9875..b256fd861760 100644 --- a/crates/wasi/src/preview2/host/udp.rs +++ b/crates/wasi/src/preview2/host/udp.rs @@ -55,23 +55,18 @@ impl udp::HostUdpSocket for T { { check.check(&local_address, SocketAddrUse::UdpBind)?; - let udp_socket = &*socket - .udp_socket() - .as_socketlike_view::(); // Perform the OS bind call. - util::udp_bind(udp_socket, &local_address).map_err(|error| { - match Errno::from_io_error(&error) { - // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: - // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket - // - // The most common reasons for this error should have already - // been handled by our own validation slightly higher up in this - // function. This error mapping is here just in case there is - // an edge case we didn't catch. - Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument, - _ => ErrorCode::from(error), - } + util::udp_bind(socket.udp_socket(), &local_address).map_err(|error| match error { + // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: + // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket + // + // The most common reasons for this error should have already + // been handled by our own validation slightly higher up in this + // function. This error mapping is here just in case there is + // an edge case we didn't catch. + Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, + _ => ErrorCode::from(error), })?; } diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index 23061760c27b..fe803306fc3a 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -1,11 +1,11 @@ use super::network::SocketAddressFamily; use super::{HostInputStream, HostOutputStream, StreamError}; +use crate::preview2::host::network::util; use crate::preview2::{ with_ambient_tokio_runtime, AbortOnDropJoinHandle, InputStream, OutputStream, Subscribe, }; use anyhow::{Error, Result}; -use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt}; -use cap_std::net::TcpListener; +use cap_net_ext::{AddressFamily, Blocking}; use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; use rustix::net::sockopt; use std::io; @@ -261,36 +261,26 @@ impl TcpSocket { pub fn new(family: AddressFamily) -> io::Result { // Create a new host socket and set it to non-blocking, which is needed // by our async implementation. - let tcp_listener = TcpListener::new(family, Blocking::No)?; + let fd = util::tcp_socket(family, Blocking::No)?; let socket_address_family = match family { AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, AddressFamily::Ipv6 => SocketAddressFamily::Ipv6 { - v6only: sockopt::get_ipv6_v6only(&tcp_listener)?, + v6only: sockopt::get_ipv6_v6only(&fd)?, }, }; - Self::from_tcp_listener(tcp_listener, socket_address_family) + Self::from_fd(fd, socket_address_family) } /// Create a `TcpSocket` from an existing socket. /// /// The socket must be in non-blocking mode. - pub(crate) fn from_tcp_stream( - tcp_socket: cap_std::net::TcpStream, + pub(crate) fn from_fd( + fd: rustix::fd::OwnedFd, family: SocketAddressFamily, ) -> io::Result { - let tcp_listener = TcpListener::from(rustix::fd::OwnedFd::from(tcp_socket)); - Self::from_tcp_listener(tcp_listener, family) - } - - pub(crate) fn from_tcp_listener( - tcp_listener: cap_std::net::TcpListener, - family: SocketAddressFamily, - ) -> io::Result { - let fd = tcp_listener.into_raw_socketlike(); - let std_stream = unsafe { std::net::TcpStream::from_raw_socketlike(fd) }; - let stream = with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))?; + let stream = Self::setup_tokio_tcp_stream(fd)?; Ok(Self { inner: Arc::new(stream), @@ -308,6 +298,12 @@ impl TcpSocket { }) } + fn setup_tokio_tcp_stream(fd: rustix::fd::OwnedFd) -> io::Result { + let std_stream = + unsafe { std::net::TcpStream::from_raw_socketlike(fd.into_raw_socketlike()) }; + with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream)) + } + pub fn tcp_socket(&self) -> &tokio::net::TcpStream { &self.inner } diff --git a/crates/wasi/src/preview2/udp.rs b/crates/wasi/src/preview2/udp.rs index 70e35ba44e73..34d46402475b 100644 --- a/crates/wasi/src/preview2/udp.rs +++ b/crates/wasi/src/preview2/udp.rs @@ -1,7 +1,8 @@ +use crate::preview2::host::network::util; use crate::preview2::poll::Subscribe; use crate::preview2::with_ambient_tokio_runtime; use async_trait::async_trait; -use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; +use cap_net_ext::{AddressFamily, Blocking}; use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; use std::io; use std::net::SocketAddr; @@ -57,15 +58,19 @@ impl Subscribe for UdpSocket { impl UdpSocket { /// Create a new socket in the given family. pub fn new(family: AddressFamily) -> io::Result { - let socket = Self::new_tokio_socket(family)?; + // Create a new host socket and set it to non-blocking, which is needed + // by our async implementation. + let fd = util::udp_socket(family, Blocking::No)?; let socket_address_family = match family { AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, AddressFamily::Ipv6 => SocketAddressFamily::Ipv6 { - v6only: rustix::net::sockopt::get_ipv6_v6only(&socket)?, + v6only: rustix::net::sockopt::get_ipv6_v6only(&fd)?, }, }; + let socket = Self::setup_tokio_udp_socket(fd)?; + Ok(UdpSocket { inner: Arc::new(socket), udp_state: UdpState::Default, @@ -74,16 +79,10 @@ impl UdpSocket { }) } - fn new_tokio_socket(family: AddressFamily) -> io::Result { - // Create a new host socket and set it to non-blocking, which is needed - // by our async implementation. - let cap_std_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?; - let fd = cap_std_socket.into_raw_socketlike(); - let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) }; - let tokio_socket = - with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?; - - Ok(tokio_socket) + fn setup_tokio_udp_socket(fd: rustix::fd::OwnedFd) -> io::Result { + let std_socket = + unsafe { std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike()) }; + with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket)) } pub fn udp_socket(&self) -> &tokio::net::UdpSocket {