Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support free-threaded CPython with GIL disabled #5148

Merged
merged 10 commits into from
Jun 18, 2024
116 changes: 63 additions & 53 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,39 +205,40 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P

/// Cleanup the type-info for a pybind11-registered type.
extern "C" inline void pybind11_meta_dealloc(PyObject *obj) {
auto *type = (PyTypeObject *) obj;
auto &internals = get_internals();

// A pybind11-registered type will:
// 1) be found in internals.registered_types_py
// 2) have exactly one associated `detail::type_info`
auto found_type = internals.registered_types_py.find(type);
if (found_type != internals.registered_types_py.end() && found_type->second.size() == 1
&& found_type->second[0]->type == type) {

auto *tinfo = found_type->second[0];
auto tindex = std::type_index(*tinfo->cpptype);
internals.direct_conversions.erase(tindex);

if (tinfo->module_local) {
get_local_internals().registered_types_cpp.erase(tindex);
} else {
internals.registered_types_cpp.erase(tindex);
}
internals.registered_types_py.erase(tinfo->type);

// Actually just `std::erase_if`, but that's only available in C++20
auto &cache = internals.inactive_override_cache;
for (auto it = cache.begin(), last = cache.end(); it != last;) {
if (it->first == (PyObject *) tinfo->type) {
it = cache.erase(it);
with_internals([obj](internals &internals) {
auto *type = (PyTypeObject *) obj;

// A pybind11-registered type will:
// 1) be found in internals.registered_types_py
// 2) have exactly one associated `detail::type_info`
auto found_type = internals.registered_types_py.find(type);
if (found_type != internals.registered_types_py.end() && found_type->second.size() == 1
&& found_type->second[0]->type == type) {

auto *tinfo = found_type->second[0];
auto tindex = std::type_index(*tinfo->cpptype);
internals.direct_conversions.erase(tindex);

if (tinfo->module_local) {
get_local_internals().registered_types_cpp.erase(tindex);
} else {
++it;
internals.registered_types_cpp.erase(tindex);
}
internals.registered_types_py.erase(tinfo->type);

// Actually just `std::erase_if`, but that's only available in C++20
auto &cache = internals.inactive_override_cache;
for (auto it = cache.begin(), last = cache.end(); it != last;) {
if (it->first == (PyObject *) tinfo->type) {
it = cache.erase(it);
} else {
++it;
}
}
}

delete tinfo;
}
delete tinfo;
}
});

PyType_Type.tp_dealloc(obj);
}
Expand Down Expand Up @@ -310,19 +311,20 @@ inline void traverse_offset_bases(void *valueptr,
}

inline bool register_instance_impl(void *ptr, instance *self) {
get_internals().registered_instances.emplace(ptr, self);
with_instance_map(ptr, [&](instance_map &instances) { instances.emplace(ptr, self); });
return true; // unused, but gives the same signature as the deregister func
}
inline bool deregister_instance_impl(void *ptr, instance *self) {
auto &registered_instances = get_internals().registered_instances;
auto range = registered_instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
if (self == it->second) {
registered_instances.erase(it);
return true;
return with_instance_map(ptr, [&](instance_map &instances) {
auto range = instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
if (self == it->second) {
instances.erase(it);
return true;
}
}
}
return false;
return false;
});
}

inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
Expand Down Expand Up @@ -377,27 +379,32 @@ extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject
}

inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto *instance = reinterpret_cast<detail::instance *>(nurse);
instance->has_patients = true;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);

with_internals([&](internals &internals) { internals.patients[nurse].push_back(patient); });
}

inline void clear_patients(PyObject *self) {
auto *instance = reinterpret_cast<detail::instance *>(self);
auto &internals = get_internals();
auto pos = internals.patients.find(self);
std::vector<PyObject *> patients;

if (pos == internals.patients.end()) {
pybind11_fail("FATAL: Internal consistency check failed: Invalid clear_patients() call.");
}
with_internals([&](internals &internals) {
auto pos = internals.patients.find(self);

if (pos == internals.patients.end()) {
pybind11_fail(
"FATAL: Internal consistency check failed: Invalid clear_patients() call.");
}

// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// from the unordered_map first.
patients = std::move(pos->second);
internals.patients.erase(pos);
});

// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// from the unordered_map first.
auto patients = std::move(pos->second);
internals.patients.erase(pos);
instance->has_patients = false;
for (PyObject *&patient : patients) {
Py_CLEAR(patient);
Expand Down Expand Up @@ -662,10 +669,13 @@ inline PyObject *make_new_python_type(const type_record &rec) {

char *tp_doc = nullptr;
if (rec.doc && options::show_user_defined_docstrings()) {
/* Allocate memory for docstring (using PyObject_MALLOC, since
Python will free this later on) */
/* Allocate memory for docstring (Python will free this later on) */
size_t size = std::strlen(rec.doc) + 1;
#if PY_VERSION_HEX >= 0x030D0000
tp_doc = (char *) PyMem_MALLOC(size);
#else
tp_doc = (char *) PyObject_MALLOC(size);
#endif
std::memcpy((void *) tp_doc, rec.doc, size);
}

Expand Down
7 changes: 5 additions & 2 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ PYBIND11_WARNING_POP
});
}
\endrst */
#define PYBIND11_MODULE(name, variable) \
#define PYBIND11_MODULE(name, variable, ...) \
static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name) \
PYBIND11_MAYBE_UNUSED; \
PYBIND11_MAYBE_UNUSED \
Expand All @@ -473,7 +473,10 @@ PYBIND11_WARNING_POP
PYBIND11_CHECK_PYTHON_VERSION \
PYBIND11_ENSURE_INTERNALS_READY \
auto m = ::pybind11::module_::create_extension_module( \
PYBIND11_TOSTRING(name), nullptr, &PYBIND11_CONCAT(pybind11_module_def_, name)); \
PYBIND11_TOSTRING(name), \
nullptr, \
&PYBIND11_CONCAT(pybind11_module_def_, name), \
##__VA_ARGS__); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
Expand Down
141 changes: 125 additions & 16 deletions include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "../pytypes.h"

#include <exception>
#include <mutex>
#include <thread>

/// Tracks the `internals` and `type_info` ABI version independent of the main library version.
///
Expand Down Expand Up @@ -168,15 +170,33 @@ struct override_hash {
}
};

using instance_map = std::unordered_multimap<const void *, instance *>;

// ignore: structure was padded due to alignment specifier
PYBIND11_WARNING_DISABLE_MSVC(4324)
henryiii marked this conversation as resolved.
Show resolved Hide resolved

struct alignas(64) instance_map_shard {
henryiii marked this conversation as resolved.
Show resolved Hide resolved
std::mutex mutex;
instance_map registered_instances;
};

/// Internal data structure used to track registered instances and types.
/// Whenever binary incompatible changes are made to this structure,
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
struct internals {
#ifdef Py_GIL_DISABLED
std::mutex mutex;
henryiii marked this conversation as resolved.
Show resolved Hide resolved
#endif
// std::type_index -> pybind11's type information
type_map<type_info *> registered_types_cpp;
// PyTypeObject* -> base type_info(s)
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py;
std::unordered_multimap<const void *, instance *> registered_instances; // void * -> instance*
#ifdef Py_GIL_DISABLED
std::unique_ptr<instance_map_shard[]> instance_shards; // void * -> instance*
size_t instance_shards_mask;
#else
instance_map registered_instances; // void * -> instance*
henryiii marked this conversation as resolved.
Show resolved Hide resolved
#endif
std::unordered_set<std::pair<const PyObject *, const char *>, override_hash>
inactive_override_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
Expand Down Expand Up @@ -462,7 +482,8 @@ inline object get_python_state_dict() {
}

inline object get_internals_obj_from_state_dict(handle state_dict) {
return reinterpret_borrow<object>(dict_getitemstring(state_dict.ptr(), PYBIND11_INTERNALS_ID));
return reinterpret_steal<object>(
dict_getitemstringref(state_dict.ptr(), PYBIND11_INTERNALS_ID));
}

inline internals **get_internals_pp_from_capsule(handle obj) {
Expand All @@ -474,6 +495,20 @@ inline internals **get_internals_pp_from_capsule(handle obj) {
return static_cast<internals **>(raw_ptr);
}

inline uint64_t next_pow2(uint64_t x) {
// Round-up to the next power of two.
henryiii marked this conversation as resolved.
Show resolved Hide resolved
// See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
x--;
x |= (x >> 1);
x |= (x >> 2);
x |= (x >> 4);
x |= (x >> 8);
x |= (x >> 16);
x |= (x >> 32);
x++;
return x;
}

/// Return a reference to the current `internals` data
PYBIND11_NOINLINE internals &get_internals() {
auto **&internals_pp = get_internals_pp();
Expand Down Expand Up @@ -542,6 +577,14 @@ PYBIND11_NOINLINE internals &get_internals() {
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
#ifdef Py_GIL_DISABLED
size_t num_shards = (size_t) next_pow2(2 * std::thread::hardware_concurrency());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe

auto num_shards = static_cast<size_t>(...);

?

(The main motivation is to minimize C-style casts. Similarly in a few other places changed in this PR.)

Could it be worth adding a comment to explain why 2 * ...?

if (num_shards == 0) {
num_shards = 1;
}
internals_ptr->instance_shards.reset(new instance_map_shard[num_shards]);
internals_ptr->instance_shards_mask = num_shards - 1;
#endif // Py_GIL_DISABLED
}
return **internals_pp;
}
Expand Down Expand Up @@ -602,13 +645,75 @@ inline local_internals &get_local_internals() {
return *locals;
}

#ifdef Py_GIL_DISABLED
# define PYBIND11_LOCK_INTERNALS(internals) std::unique_lock<std::mutex> lock((internals).mutex)
#else
# define PYBIND11_LOCK_INTERNALS(internals)
#endif

template <typename F>
inline auto with_internals(const F &cb) -> decltype(cb(get_internals())) {
auto &internals = get_internals();
PYBIND11_LOCK_INTERNALS(internals);
return cb(internals);
}

inline uint64_t splitmix64(uint64_t z) {
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
henryiii marked this conversation as resolved.
Show resolved Hide resolved
z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
return z ^ (z >> 31);
}

template <typename F>
inline auto with_instance_map(const void *ptr,
const F &cb) -> decltype(cb(std::declval<instance_map &>())) {
auto &internals = get_internals();

#ifdef Py_GIL_DISABLED
// Hash address to compute shard, but ignore low bits. We'd like allocations
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion to move this comment to the splitmix64 function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is specific to this section of code: it's mostly about why we are ignoring the low bits (addr >> 20) when hashing the address.

// from the same thread/core to map to the same shard and allocations from
// other threads/cores to map to other shards. Using the high bits is a good
// heuristic because memory allocators often have a per-thread
// arena/superblock/segment from which smaller allocations are served.
auto addr = reinterpret_cast<uintptr_t>(ptr);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_cast seems to work here (with Linux gcc at least). (I see we're already using reinterpret_cast in a bunch of similar situations, but maybe that isn't ideal?)

Would using std::uintptr_t and std::uint64_t be slightly better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_cast here causes a compiler error for me: https://gcc.godbolt.org/z/bTnosGf61

I've updated the other casts and used the std namespace.

uint64_t hash = splitmix64((uint64_t) (addr >> 20));
rwgk marked this conversation as resolved.
Show resolved Hide resolved
size_t idx = (size_t) hash & internals.instance_shards_mask;

auto &shard = internals.instance_shards[idx];
std::unique_lock<std::mutex> lock(shard.mutex);
return cb(shard.registered_instances);
#else
(void) ptr;
return cb(internals.registered_instances);
#endif
}

inline size_t num_registered_instances() {
henryiii marked this conversation as resolved.
Show resolved Hide resolved
auto &internals = get_internals();
#ifdef Py_GIL_DISABLED
size_t count = 0;
for (size_t i = 0; i <= internals.instance_shards_mask; ++i) {
auto &shard = internals.instance_shards[i];
std::unique_lock<std::mutex> lock(shard.mutex);
count += shard.registered_instances.size();
}
return count;
#else
return internals.registered_instances.size();
#endif
}

/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args>
const char *c_str(Args &&...args) {
auto &strings = get_internals().static_strings;
// GCC 4.8 doesn't like parameter unpack within lambda capture, so use
// PYBIND11_LOCK_INTERNALS.
auto &internals = get_internals();
PYBIND11_LOCK_INTERNALS(internals);
auto &strings = internals.static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
}
Expand Down Expand Up @@ -638,30 +743,34 @@ PYBIND11_NAMESPACE_END(detail)
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
return detail::with_internals([&](detail::internals &internals) {
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
});
}

/// Set the shared data that can be later recovered by `get_shared_data()`.
PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
return detail::with_internals([&](detail::internals &internals) {
internals.shared_data[name] = data;
return data;
});
}

/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template <typename T>
T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
return *detail::with_internals([&](detail::internals &internals) {
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return ptr;
});
}

PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
Loading
Loading