Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file descriptor support for poller #606

Merged
merged 3 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/active_poller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

#include <array>
#include <memory>
#include <cstring>

#if !defined(_WIN32)
#include <unistd.h>
#endif // !_WIN32

TEST_CASE("create destroy", "[active_poller]")
{
Expand Down Expand Up @@ -86,6 +91,85 @@ TEST_CASE("add handler", "[active_poller]")
active_poller.add(socket, zmq::event_flags::pollin, no_op_handler));
}

TEST_CASE("add fd handler", "[active_poller]")
{
int fd = 1;
zmq::active_poller_t active_poller;
CHECK_NOTHROW(
active_poller.add(fd, zmq::event_flags::pollin, no_op_handler));
}

TEST_CASE("remove fd handler", "[active_poller]")
{
int fd = 1;
zmq::active_poller_t active_poller;
CHECK_NOTHROW(
active_poller.add(fd, zmq::event_flags::pollin, no_op_handler));
CHECK_NOTHROW(
active_poller.remove(fd));
CHECK_THROWS_ZMQ_ERROR(EINVAL, active_poller.remove(100));
}

#if !defined(_WIN32)
// On Windows, these functions can only be used with WinSock sockets.

TEST_CASE("mixed socket and fd handlers", "[active_poller]")
{
int pipefd[2];
::pipe(pipefd);

zmq::context_t context;
constexpr char inprocSocketAddress[] = "inproc://mixed-handlers";
zmq::socket_t socket_rcv{context, zmq::socket_type::pair};
zmq::socket_t socket_snd{context, zmq::socket_type::pair};
socket_rcv.bind(inprocSocketAddress);
socket_snd.connect(inprocSocketAddress);

unsigned eventsFd = 0;
unsigned eventsSocket = 0;

constexpr char messageText[] = "message";
constexpr size_t messageSize = sizeof(messageText);

zmq::active_poller_t active_poller;
CHECK_NOTHROW(
active_poller.add(pipefd[0], zmq::event_flags::pollin, [&](zmq::event_flags flags) {
if (flags == zmq::event_flags::pollin)
{
char buffer[256];
CHECK(messageSize == ::read(pipefd[0], buffer, messageSize));
CHECK(0 == std::strcmp(buffer, messageText));
++eventsFd;
}
}));
CHECK_NOTHROW(
active_poller.add(socket_rcv, zmq::event_flags::pollin, [&](zmq::event_flags flags) {
if (flags == zmq::event_flags::pollin)
{
zmq::message_t msg;
CHECK(socket_rcv.recv(msg, zmq::recv_flags::dontwait).has_value());
CHECK(messageSize == msg.size());
CHECK(0 == std::strcmp(messageText, msg.data<const char>()));
++eventsSocket;
}
}));

// send/rcv socket pair
zmq::message_t msg{messageText, messageSize};
socket_snd.send(msg, zmq::send_flags::dontwait);
CHECK(1 == active_poller.wait(std::chrono::milliseconds{100}));
CHECK(0 == eventsFd);
CHECK(1 == eventsSocket);

// send/rcv pipe
::write(pipefd[1], messageText, messageSize);
CHECK(1 == active_poller.wait(std::chrono::milliseconds{100}));
CHECK(1 == eventsFd);
CHECK(1 == eventsSocket);
}

#endif // !_WIN32

TEST_CASE("add null handler fails", "[active_poller]")
{
zmq::context_t context;
Expand Down
16 changes: 16 additions & 0 deletions zmq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2681,6 +2681,13 @@ template<typename T = no_user_data> class poller_t
}
}

void remove(fd_t fd)
{
if (0 != zmq_poller_remove_fd(poller_ptr.get(), fd)) {
throw error_t();
}
}

void modify(zmq::socket_ref socket, event_flags events)
{
if (0
Expand All @@ -2690,6 +2697,15 @@ template<typename T = no_user_data> class poller_t
}
}

void modify(fd_t fd, event_flags events)
{
if (0
!= zmq_poller_modify_fd(poller_ptr.get(), fd,
static_cast<short>(events))) {
throw error_t();
}
}

size_t wait_all(std::vector<event_type> &poller_events,
const std::chrono::milliseconds timeout)
{
Expand Down
105 changes: 100 additions & 5 deletions zmq_addon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,65 @@
#include <limits>
#include <functional>
#include <unordered_map>
#endif

namespace zmq
{
// socket ref or native file descriptor for poller
class poller_ref_t
{
public:
enum RefType
{
RT_SOCKET,
RT_FD
};

poller_ref_t() : poller_ref_t(socket_ref{})
{}

poller_ref_t(const zmq::socket_ref& socket) : data{RT_SOCKET, socket, {}}
{}

poller_ref_t(zmq::fd_t fd) : data{RT_FD, {}, fd}
{}

size_t hash() const ZMQ_NOTHROW
{
std::size_t h = 0;
hash_combine(h, std::get<0>(data));
hash_combine(h, std::get<1>(data));
hash_combine(h, std::get<2>(data));
return h;
}

bool operator == (const poller_ref_t& o) const ZMQ_NOTHROW
{
return data == o.data;
}

private:
template <class T>
static void hash_combine(std::size_t& seed, const T& v) ZMQ_NOTHROW
{
std::hash<T> hasher;
seed ^= hasher(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
}

std::tuple<int, zmq::socket_ref, zmq::fd_t> data;

}; // class poller_ref_t

} // namespace zmq

// std::hash<> specialization for std::unordered_map
template <> struct std::hash<zmq::poller_ref_t>
{
size_t operator()(const zmq::poller_ref_t& ref) const ZMQ_NOTHROW
{
return ref.hash();
}
};
#endif // ZMQ_CPP11

namespace zmq
{
Expand Down Expand Up @@ -683,10 +741,12 @@ class active_poller_t

void add(zmq::socket_ref socket, event_flags events, handler_type handler)
{
const poller_ref_t ref{socket};

if (!handler)
throw std::invalid_argument("null handler in active_poller_t::add");
throw std::invalid_argument("null handler in active_poller_t::add (socket)");
auto ret = handlers.emplace(
socket, std::make_shared<handler_type>(std::move(handler)));
ref, std::make_shared<handler_type>(std::move(handler)));
if (!ret.second)
throw error_t(EINVAL); // already added
try {
Expand All @@ -695,7 +755,28 @@ class active_poller_t
}
catch (...) {
// rollback
handlers.erase(socket);
handlers.erase(ref);
throw;
}
}

void add(fd_t fd, event_flags events, handler_type handler)
{
const poller_ref_t ref{fd};

if (!handler)
throw std::invalid_argument("null handler in active_poller_t::add (fd)");
auto ret = handlers.emplace(
ref, std::make_shared<handler_type>(std::move(handler)));
if (!ret.second)
throw error_t(EINVAL); // already added
try {
base_poller.add(fd, events, ret.first->second.get());
need_rebuild = true;
}
catch (...) {
// rollback
handlers.erase(ref);
throw;
}
}
Expand All @@ -707,11 +788,23 @@ class active_poller_t
need_rebuild = true;
}

void remove(fd_t fd)
{
base_poller.remove(fd);
handlers.erase(fd);
need_rebuild = true;
}

void modify(zmq::socket_ref socket, event_flags events)
{
base_poller.modify(socket, events);
}

void modify(fd_t fd, event_flags events)
{
base_poller.modify(fd, events);
}

size_t wait(std::chrono::milliseconds timeout)
{
if (need_rebuild) {
Expand Down Expand Up @@ -741,7 +834,9 @@ class active_poller_t
bool need_rebuild{false};

poller_t<handler_type> base_poller{};
std::unordered_map<socket_ref, std::shared_ptr<handler_type>> handlers{};

std::unordered_map<zmq::poller_ref_t, std::shared_ptr<handler_type>> handlers{};

std::vector<decltype(base_poller)::event_type> poller_events{};
std::vector<std::shared_ptr<handler_type>> poller_handlers{};
}; // class active_poller_t
Expand Down
Loading