Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wasi-sockets: Factor out cap-std #7687

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 58 additions & 27 deletions crates/wasi/src/preview2/host/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ pub(crate) mod util {
use crate::preview2::bindings::sockets::network::ErrorCode;
use crate::preview2::network::SocketAddressFamily;
use crate::preview2::SocketResult;
use cap_net_ext::{Blocking, TcpListenerExt};
use cap_std::net::{TcpListener, TcpStream, UdpSocket};
use rustix::fd::AsFd;
use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt, UdpSocketExt};
use io_lifetimes::AsSocketlike;
use rustix::fd::{AsFd, OwnedFd};
use rustix::io::Errno;
use rustix::net::sockopt;

Expand Down Expand Up @@ -302,42 +302,64 @@ pub(crate) mod util {
* Syscalls wrappers with (opinionated) portability fixes.
*/

pub fn tcp_bind(listener: &TcpListener, addr: &SocketAddr) -> std::io::Result<()> {
rustix::net::bind(listener, addr).map_err(|error| match error {
pub fn tcp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result<OwnedFd> {
// Delegate socket creation to cap_net_ext. They handle a couple of things for us:
// - On Windows: call WSAStartup if not done before.
// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,
// or afterwards using ioctl or fcntl. Exact method depends on the platform.

let listener = cap_std::net::TcpListener::new(family, blocking)?;
Ok(OwnedFd::from(listener))
}

pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result<OwnedFd> {
// Delegate socket creation to cap_net_ext. They handle a couple of things for us:
// - On Windows: call WSAStartup if not done before.
// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,
// or afterwards using ioctl or fcntl. Exact method depends on the platform.

let socket = cap_std::net::UdpSocket::new(family, blocking)?;
Ok(OwnedFd::from(socket))
}

pub fn tcp_bind<Fd: AsFd>(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> {
rustix::net::bind(sockfd, addr).map_err(|error| match error {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
// Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
#[cfg(windows)]
Errno::NOBUFS => Errno::ADDRINUSE.into(),
_ => error.into(),
Errno::NOBUFS => Errno::ADDRINUSE,
_ => error,
})
}

pub fn udp_bind(socket: &UdpSocket, addr: &SocketAddr) -> std::io::Result<()> {
rustix::net::bind(socket, addr).map_err(|error| {
match error {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
// Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
#[cfg(windows)]
Errno::NOBUFS => Errno::ADDRINUSE.into(),
_ => error.into(),
}
pub fn udp_bind<Fd: AsFd>(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> {
rustix::net::bind(sockfd, addr).map_err(|error| match error {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
// Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
#[cfg(windows)]
Errno::NOBUFS => Errno::ADDRINUSE,
_ => error,
})
}

pub fn tcp_connect(listener: &TcpListener, addr: &SocketAddr) -> std::io::Result<()> {
rustix::net::connect(listener, addr).map_err(|error| match error {
pub fn tcp_connect<Fd: AsFd>(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> {
rustix::net::connect(sockfd, addr).map_err(|error| match error {
// On POSIX, non-blocking `connect` returns `EINPROGRESS`.
// Windows returns `WSAEWOULDBLOCK`.
//
// This normalized error code is depended upon by: tcp.rs
#[cfg(windows)]
Errno::WOULDBLOCK => Errno::INPROGRESS.into(),
_ => error.into(),
Errno::WOULDBLOCK => Errno::INPROGRESS,
_ => error,
})
}

pub fn tcp_listen(listener: &TcpListener, backlog: Option<i32>) -> std::io::Result<()> {
listener
pub fn tcp_listen<Fd: AsFd>(sockfd: Fd, backlog: Option<i32>) -> std::io::Result<()> {
// Delegate `listen` to cap_net_ext. That is a thin wrapper around rustix::net::listen,
// with a platform-dependent default value for the backlog size.
sockfd
.as_fd()
.as_socketlike_view::<cap_std::net::TcpListener>()
.listen(backlog)
.map_err(|error| match Errno::from_io_error(&error) {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
Expand All @@ -351,15 +373,22 @@ pub(crate) mod util {
// on Microsoft's documentation here.
#[cfg(windows)]
Some(Errno::MFILE) => Errno::NOBUFS.into(),

_ => error,
})
}

pub fn tcp_accept(
listener: &TcpListener,
pub fn tcp_accept<Fd: AsFd>(
sockfd: Fd,
blocking: Blocking,
) -> std::io::Result<(TcpStream, SocketAddr)> {
listener
) -> std::io::Result<(OwnedFd, SocketAddr)> {
// Delegate `accept` to cap_net_ext. They set the NONBLOCK and CLOEXEC flags
// for us. Either immediately as a flag to `accept`, or afterwards using
// ioctl or fcntl. Exact method depends on the platform.

let (client, addr) = sockfd
.as_fd()
.as_socketlike_view::<cap_std::net::TcpListener>()
.accept_with(blocking)
.map_err(|error| match Errno::from_io_error(&error) {
// From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
Expand Down Expand Up @@ -393,7 +422,9 @@ pub(crate) mod util {
) => Errno::CONNABORTED.into(),

_ => error,
})
})?;

Ok((client.into(), addr))
}

pub fn udp_disconnect<Fd: AsFd>(sockfd: Fd) -> rustix::io::Result<()> {
Expand Down
53 changes: 22 additions & 31 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::preview2::{
};
use crate::preview2::{Pollable, SocketResult, WasiView};
use cap_net_ext::Blocking;
use cap_std::net::TcpListener;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
Expand Down Expand Up @@ -47,21 +46,18 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {
{
// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?;
let listener = &*socket.tcp_socket().as_socketlike_view::<TcpListener>();

// Perform the OS bind call.
util::tcp_bind(listener, &local_address).map_err(|error| {
match Errno::from_io_error(&error) {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument,
_ => ErrorCode::from(error),
}
util::tcp_bind(socket.tcp_socket(), &local_address).map_err(|error| match error {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument,
_ => ErrorCode::from(error),
})?;
}

Expand Down Expand Up @@ -116,10 +112,9 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {

// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?;
let listener = &*socket.tcp_socket().as_socketlike_view::<TcpListener>();

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
util::tcp_connect(listener, &remote_address)
util::tcp_connect(socket.tcp_socket(), &remote_address)
};

match r {
Expand All @@ -130,11 +125,11 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {
return Ok(());
}
// continue in progress,
Err(err) if Errno::from_io_error(&err) == Some(Errno::INPROGRESS) => {}
Err(err) if err == Errno::INPROGRESS => {}
// or fail immediately.
Err(err) => {
return Err(match Errno::from_io_error(&err) {
Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument.into(), // See `bind` implementation.
return Err(match err {
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument.into(), // See `bind` implementation.
_ => err.into(),
});
}
Expand Down Expand Up @@ -207,10 +202,7 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {
| TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
}

{
let listener = &*socket.tcp_socket().as_socketlike_view::<TcpListener>();
util::tcp_listen(listener, socket.listen_backlog_size)?;
}
util::tcp_listen(socket.tcp_socket(), socket.listen_backlog_size)?;

socket.tcp_state = TcpState::ListenStarted;

Expand Down Expand Up @@ -250,9 +242,8 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {

// Do the OS accept call.
let tcp_socket = socket.tcp_socket();
let (connection, _addr) = tcp_socket.try_io(Interest::READABLE, || {
let listener = &*tcp_socket.as_socketlike_view::<TcpListener>();
util::tcp_accept(listener, Blocking::No)
let (client_fd, _addr) = tcp_socket.try_io(Interest::READABLE, || {
util::tcp_accept(tcp_socket, Blocking::No)
})?;

#[cfg(target_os = "macos")]
Expand All @@ -262,25 +253,25 @@ impl<T: WasiView> crate::preview2::host::tcp::tcp::HostTcpSocket for T {
// and only if a specific value was explicitly set on the listener.

if let Some(size) = socket.receive_buffer_size {
_ = util::set_socket_recv_buffer_size(&connection, size); // Ignore potential error.
_ = util::set_socket_recv_buffer_size(&client_fd, size); // Ignore potential error.
}

if let Some(size) = socket.send_buffer_size {
_ = util::set_socket_send_buffer_size(&connection, size); // Ignore potential error.
_ = util::set_socket_send_buffer_size(&client_fd, size); // Ignore potential error.
}

// For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
if let (SocketAddressFamily::Ipv6 { .. }, Some(ttl)) = (socket.family, socket.hop_limit)
{
_ = util::set_ipv6_unicast_hops(&connection, ttl); // Ignore potential error.
_ = util::set_ipv6_unicast_hops(&client_fd, ttl); // Ignore potential error.
}

if let Some(value) = socket.keep_alive_idle_time {
_ = util::set_tcp_keepidle(&connection, value); // Ignore potential error.
_ = util::set_tcp_keepidle(&client_fd, value); // Ignore potential error.
}
}

let mut tcp_socket = TcpSocket::from_tcp_stream(connection, socket.family)?;
let mut tcp_socket = TcpSocket::from_fd(client_fd, socket.family)?;

// Mark the socket as connected so that we can exit early from methods like `start-bind`.
tcp_socket.tcp_state = TcpState::Connected;
Expand Down
25 changes: 10 additions & 15 deletions crates/wasi/src/preview2/host/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,18 @@ impl<T: WasiView> udp::HostUdpSocket for T {

{
check.check(&local_address, SocketAddrUse::UdpBind)?;
let udp_socket = &*socket
.udp_socket()
.as_socketlike_view::<cap_std::net::UdpSocket>();

// Perform the OS bind call.
util::udp_bind(udp_socket, &local_address).map_err(|error| {
match Errno::from_io_error(&error) {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument,
_ => ErrorCode::from(error),
}
util::udp_bind(socket.udp_socket(), &local_address).map_err(|error| match error {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument,
_ => ErrorCode::from(error),
})?;
}

Expand Down
32 changes: 14 additions & 18 deletions crates/wasi/src/preview2/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use super::network::SocketAddressFamily;
use super::{HostInputStream, HostOutputStream, StreamError};
use crate::preview2::host::network::util;
use crate::preview2::{
with_ambient_tokio_runtime, AbortOnDropJoinHandle, InputStream, OutputStream, Subscribe,
};
use anyhow::{Error, Result};
use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt};
use cap_std::net::TcpListener;
use cap_net_ext::{AddressFamily, Blocking};
use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike};
use rustix::net::sockopt;
use std::io;
Expand Down Expand Up @@ -261,36 +261,26 @@ impl TcpSocket {
pub fn new(family: AddressFamily) -> io::Result<Self> {
// Create a new host socket and set it to non-blocking, which is needed
// by our async implementation.
let tcp_listener = TcpListener::new(family, Blocking::No)?;
let fd = util::tcp_socket(family, Blocking::No)?;

let socket_address_family = match family {
AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
AddressFamily::Ipv6 => SocketAddressFamily::Ipv6 {
v6only: sockopt::get_ipv6_v6only(&tcp_listener)?,
v6only: sockopt::get_ipv6_v6only(&fd)?,
},
};

Self::from_tcp_listener(tcp_listener, socket_address_family)
Self::from_fd(fd, socket_address_family)
}

/// Create a `TcpSocket` from an existing socket.
///
/// The socket must be in non-blocking mode.
pub(crate) fn from_tcp_stream(
tcp_socket: cap_std::net::TcpStream,
pub(crate) fn from_fd(
fd: rustix::fd::OwnedFd,
family: SocketAddressFamily,
) -> io::Result<Self> {
let tcp_listener = TcpListener::from(rustix::fd::OwnedFd::from(tcp_socket));
Self::from_tcp_listener(tcp_listener, family)
}

pub(crate) fn from_tcp_listener(
tcp_listener: cap_std::net::TcpListener,
family: SocketAddressFamily,
) -> io::Result<Self> {
let fd = tcp_listener.into_raw_socketlike();
let std_stream = unsafe { std::net::TcpStream::from_raw_socketlike(fd) };
let stream = with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))?;
let stream = Self::setup_tokio_tcp_stream(fd)?;

Ok(Self {
inner: Arc::new(stream),
Expand All @@ -308,6 +298,12 @@ impl TcpSocket {
})
}

fn setup_tokio_tcp_stream(fd: rustix::fd::OwnedFd) -> io::Result<tokio::net::TcpStream> {
let std_stream =
unsafe { std::net::TcpStream::from_raw_socketlike(fd.into_raw_socketlike()) };
with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))
}

pub fn tcp_socket(&self) -> &tokio::net::TcpStream {
&self.inner
}
Expand Down
Loading