Skip to content

Commit

Permalink
relax the conditions when the complex function dispatch loop is used
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Sep 20, 2024
1 parent f8f19eb commit f17c414
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
27 changes: 18 additions & 9 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ PyObject *nb_func_new(const void *in_) noexcept {
bool has_scope = f->flags & (uint32_t) func_flags::has_scope,
has_name = f->flags & (uint32_t) func_flags::has_name,
has_args = f->flags & (uint32_t) func_flags::has_args,
has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs,
has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args,
has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive,
has_doc = f->flags & (uint32_t) func_flags::has_doc,
has_signature = f->flags & (uint32_t) func_flags::has_signature,
Expand Down Expand Up @@ -289,13 +291,20 @@ PyObject *nb_func_new(const void *in_) noexcept {

maybe_make_immortal((PyObject *) func);

func->max_nargs = f->nargs;
func->complex_call = f->nargs_pos < f->nargs || has_args || has_keep_alive;
// Check if the complex dispatch loop is needed
bool complex_call = has_keep_alive || has_var_kwargs || has_var_args || f->nargs >= NB_MAXARGS_SIMPLE;
if (has_args) {
for (size_t i = 0; i < f->nargs; ++i) {
arg_data &a = f->args[i];
complex_call |= a.name != nullptr || a.value != nullptr ||
a.flag != cast_flags::convert;
}
}

uint32_t max_nargs = f->nargs;
if (func_prev) {
func->complex_call |= ((nb_func *) func_prev)->complex_call;
func->max_nargs = std::max(func->max_nargs,
((nb_func *) func_prev)->max_nargs);
complex_call |= ((nb_func *) func_prev)->complex_call;
max_nargs = std::max(max_nargs, ((nb_func *) func_prev)->max_nargs);

func_data *cur = nb_func_data(func),
*prev = nb_func_data(func_prev);
Expand All @@ -314,10 +323,10 @@ PyObject *nb_func_new(const void *in_) noexcept {
Py_CLEAR(func_prev);
}

func->complex_call |= func->max_nargs >= NB_MAXARGS_SIMPLE;

func->vectorcall = func->complex_call ? nb_func_vectorcall_complex
: nb_func_vectorcall_simple;
func->max_nargs = max_nargs;
func->complex_call = complex_call;
func->vectorcall = complex_call ? nb_func_vectorcall_complex
: nb_func_vectorcall_simple;

#if !defined(NB_FREE_THREADED)
// Register the function
Expand Down
2 changes: 1 addition & 1 deletion tests/test_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ NB_MODULE(test_thread_ext, m) {

m.def("inc_safe",
[](Counter &c) { c.inc(); },
"counter"_a.lock());
nb::arg().lock());

m.def("inc_global",
[](Counter &c) {
Expand Down

0 comments on commit f17c414

Please sign in to comment.