Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace socket2 calls with rustix #146

Merged
merged 7 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ futures-io = { version = "0.3.28", default-features = false, features = ["std"]
futures-lite = { version = "1.11.0", default-features = false }
parking = "2.0.0"
polling = "3.0.0"
rustix = { version = "0.38.2", default-features = false, features = ["std", "fs"] }
rustix = { version = "0.38.2", default-features = false, features = ["fs", "net", "std"] }
slab = "0.4.2"
socket2 = { version = "0.5.3", features = ["all"] }
tracing = { version = "0.1.37", default-features = false }
waker-fn = "1.1.0"

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.48.0", features = ["Win32_Foundation"] }

[dev-dependencies]
async-channel = "1"
async-net = "1"
Expand Down
158 changes: 124 additions & 34 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, R
use futures_io::{AsyncRead, AsyncWrite};
use futures_lite::stream::{self, Stream};
use futures_lite::{future, pin, ready};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};

use rustix::io as rio;
use rustix::net as rn;

use crate::reactor::{Reactor, Registration, Source};

Expand Down Expand Up @@ -656,23 +658,7 @@ impl<T: AsFd> Async<T> {
pub fn new(io: T) -> io::Result<Async<T>> {
// Put the file descriptor in non-blocking mode.
let fd = io.as_fd();
cfg_if::cfg_if! {
// ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
// for now, as with the standard library, because it seems to behave
// differently depending on the platform.
// https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
// https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
// https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
if #[cfg(target_os = "linux")] {
rustix::io::ioctl_fionbio(fd, true)?;
} else {
let previous = rustix::fs::fcntl_getfl(fd)?;
let new = previous | rustix::fs::OFlags::NONBLOCK;
if new != previous {
rustix::fs::fcntl_setfl(fd, new)?;
}
}
}
set_nonblocking(fd)?;

// SAFETY: It is impossible to drop the I/O source while it is registered through
// this type.
Expand Down Expand Up @@ -1487,10 +1473,15 @@ impl Async<TcpStream> {
/// # std::io::Result::Ok(()) });
/// ```
pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
// Begin async connect.
// Figure out how to handle this address.
let addr = addr.into();
let domain = Domain::for_address(addr);
let socket = connect(addr.into(), domain, Some(Protocol::TCP))?;
let (domain, sock_addr) = match addr {
SocketAddr::V4(v4) => (rn::AddressFamily::INET, rn::SocketAddrAny::V4(v4)),
SocketAddr::V6(v6) => (rn::AddressFamily::INET6, rn::SocketAddrAny::V6(v6)),
};

// Begin async connect.
let socket = connect(sock_addr, domain, Some(rn::ipproto::TCP))?;
let stream = Async::new(TcpStream::from(socket))?;

// The stream becomes writable when connected.
Expand Down Expand Up @@ -1819,7 +1810,11 @@ impl Async<UnixStream> {
/// ```
pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
// Begin async connect.
let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?;
let socket = connect(
rn::SocketAddrUnix::new(path.as_ref())?.into(),
rn::AddressFamily::UNIX,
None,
)?;
let stream = Async::new(UnixStream::from(socket))?;

// The stream becomes writable when connected.
Expand Down Expand Up @@ -2029,8 +2024,14 @@ async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()>
.await
}

fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Result<Socket> {
let sock_type = Type::STREAM;
fn connect(
addr: rn::SocketAddrAny,
domain: rn::AddressFamily,
protocol: Option<rn::Protocol>,
) -> io::Result<rustix::fd::OwnedFd> {
#[cfg(windows)]
use rustix::fd::AsFd;

#[cfg(any(
target_os = "android",
target_os = "dragonfly",
Expand All @@ -2041,10 +2042,13 @@ fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Re
target_os = "netbsd",
target_os = "openbsd"
notgull marked this conversation as resolved.
Show resolved Hide resolved
))]
// If we can, set nonblocking at socket creation for unix
let sock_type = sock_type.nonblocking();
// This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos
let socket = Socket::new(domain, sock_type, protocol)?;
let socket = rn::socket_with(
domain,
rn::SocketType::STREAM,
rn::SocketFlags::CLOEXEC | rn::SocketFlags::NONBLOCK,
protocol,
)?;

#[cfg(not(any(
target_os = "android",
target_os = "dragonfly",
Expand All @@ -2055,14 +2059,100 @@ fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Re
target_os = "netbsd",
target_os = "openbsd"
)))]
// If the current platform doesn't support nonblocking at creation, enable it after creation
socket.set_nonblocking(true)?;
match socket.connect(&addr) {
let socket = {
#[cfg(not(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "watchos",
target_os = "espidf",
windows,
)))]
let flags = rn::SocketFlags::CLOEXEC;
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "watchos",
target_os = "espidf",
windows,
))]
notgull marked this conversation as resolved.
Show resolved Hide resolved
let flags = rn::SocketFlags::empty();

// Create the socket.
let socket = rn::socket_with(domain, rn::SocketType::STREAM, flags, protocol)?;

// Set cloexec if necessary.
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "watchos",
))]
rio::fcntl_setfd(&socket, rio::fcntl_getfd(&socket)? | rio::FdFlags::CLOEXEC)?;

// Set non-blocking mode.
set_nonblocking(socket.as_fd())?;

socket
};

// Set nosigpipe if necessary.
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "watchos",
target_os = "freebsd"
))]
Comment on lines +2100 to +2107
Copy link
Collaborator

@taiki-e taiki-e Sep 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, NetBSD also supports nosigpipe (OpenBSD and DragonflyBSD do not).

Also, do you know why socket2 sets nosigpipe only on Apple targets, not also on FreeBSD/NetBSD?
https://github.com/rust-lang/socket2/blob/66ed6b055a4352a26ed25e8d01981fb5b555f3d8/src/socket.rs#L798
Maybe because MSG_NOSIGNAL has already been passed?
https://github.com/rust-lang/rust/blob/fece303aeb1d2b845745bc3c067475619670a868/library/std/src/sys_common/net.rs#L38

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I would bet that it's because FreeBSD/NetBSD did not have it at the time.

Copy link
Collaborator

@taiki-e taiki-e Oct 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened bytecodealliance/rustix#870 for NetBSD and DragonFly BSD.

rn::sockopt::set_socket_nosigpipe(&socket, true)?;

// Set the handle information to HANDLE_FLAG_INHERIT.
#[cfg(windows)]
unsafe {
if windows_sys::Win32::Foundation::SetHandleInformation(
socket.as_raw_socket() as _,
windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
) == 0
{
return Err(io::Error::last_os_error());
}
}

#[allow(unreachable_patterns)]
match rn::connect_any(&socket, &addr) {
Ok(_) => {}
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
Err(err) => return Err(err),
Err(rio::Errno::INPROGRESS) => {}
Err(rio::Errno::AGAIN) | Err(rio::Errno::WOULDBLOCK) => {}
Err(err) => return Err(err.into()),
}
Ok(socket)
}

#[inline]
fn set_nonblocking(
#[cfg(unix)] fd: BorrowedFd<'_>,
#[cfg(windows)] fd: BorrowedSocket<'_>,
) -> io::Result<()> {
cfg_if::cfg_if! {
// ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
// for now, as with the standard library, because it seems to behave
// differently depending on the platform.
// https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
// https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
// https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
if #[cfg(any(windows, target_os = "linux"))] {
rustix::io::ioctl_fionbio(fd, true)?;
} else {
let previous = rustix::fs::fcntl_getfl(fd)?;
let new = previous | rustix::fs::OFlags::NONBLOCK;
if new != previous {
rustix::fs::fcntl_setfl(fd, new)?;
}
}
}

Ok(())
}
3 changes: 2 additions & 1 deletion tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ fn udp_connect() -> io::Result<()> {
})
}

#[cfg(unix)]
// This test is broken for now on OpenBSD: https://github.com/rust-lang/rust/issues/116523
#[cfg(all(unix, not(target_os = "openbsd")))]
#[test]
fn uds_connect() -> io::Result<()> {
future::block_on(async {
Expand Down
Loading