Skip to content

Commit 6959d81

Browse files
committed
Specialize complex function dispatcher
This PR adds a variant of the complex dispatcher that applies when no overloads have kwargs and all overloads have 8 or fewer arguments.
1 parent dbe8a3c commit 6959d81

File tree

5 files changed

+202
-12
lines changed

5 files changed

+202
-12
lines changed

src/nb_func.cpp

Lines changed: 184 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ static PyObject *nb_func_vectorcall_simple_1(PyObject *, PyObject *const *,
3535
size_t, PyObject *) noexcept;
3636
static PyObject *nb_func_vectorcall_simple(PyObject *, PyObject *const *,
3737
size_t, PyObject *) noexcept;
38+
static PyObject *nb_func_vectorcall_modest(PyObject *, PyObject *const *,
39+
size_t, PyObject *) noexcept;
3840
static PyObject *nb_func_vectorcall_complex(PyObject *, PyObject *const *,
3941
size_t, PyObject *) noexcept;
4042
static uint32_t nb_func_render_signature(const func_data *f,
@@ -298,25 +300,27 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept {
298300
make_immortal((PyObject *) func);
299301

300302
// Check if the complex dispatch loop is needed
301-
bool complex_call = can_mutate_args || has_var_kwargs || has_var_args ||
303+
bool has_kwargs = has_var_kwargs;
304+
bool complex_call = can_mutate_args || has_var_args ||
302305
f->nargs > NB_MAXARGS_SIMPLE;
303-
304306
if (has_args) {
305307
for (size_t i = is_method; i < f->nargs; ++i) {
306308
arg_data &a = args_in[i - is_method];
307-
complex_call |= a.name != nullptr || a.value != nullptr ||
308-
a.flag != cast_flags::convert;
309+
has_kwargs |= a.name != nullptr;
310+
complex_call |= a.value != nullptr || a.flag != cast_flags::convert;
309311
}
310312
}
313+
complex_call |= has_kwargs;
311314

312315
uint32_t max_nargs = f->nargs;
313316

314317
const char *prev_doc = nullptr;
315318

316319
if (func_prev) {
317320
nb_func *nb_func_prev = (nb_func *) func_prev;
318-
complex_call |= nb_func_prev->complex_call;
319321
max_nargs = std::max(max_nargs, nb_func_prev->max_nargs);
322+
has_kwargs |= nb_func_prev->has_kwargs;
323+
complex_call |= nb_func_prev->complex_call;
320324

321325
func_data *cur = nb_func_data(func),
322326
*prev = nb_func_data(func_prev);
@@ -339,12 +343,15 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept {
339343
}
340344

341345
func->max_nargs = max_nargs;
346+
func->has_kwargs = has_kwargs;
342347
func->complex_call = complex_call;
343348

344-
345349
PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
346350
if (complex_call) {
347-
vectorcall = nb_func_vectorcall_complex;
351+
if (max_nargs <= NB_MAXARGS_SIMPLE && !has_kwargs)
352+
vectorcall = nb_func_vectorcall_modest;
353+
else
354+
vectorcall = nb_func_vectorcall_complex;
348355
} else {
349356
if (f->nargs == 0 && !prev_overloads)
350357
vectorcall = nb_func_vectorcall_simple_0;
@@ -636,7 +643,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
636643
cleanup_list cleanup(self_arg);
637644

638645
// Preallocate stack memory for function dispatch
639-
size_t max_nargs = ((nb_func *) self)->max_nargs;
646+
const size_t max_nargs = ((nb_func *) self)->max_nargs;
640647
PyObject **args = (PyObject **) alloca(max_nargs * sizeof(PyObject *));
641648
uint8_t *args_flags = (uint8_t *) alloca(max_nargs * sizeof(uint8_t));
642649
bool *kwarg_used = (bool *) alloca(nkwargs_in * sizeof(bool));
@@ -715,15 +722,15 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
715722

716723
// Number of C++ parameters eligible to be filled from individual
717724
// Python positional arguments
718-
size_t nargs_pos = f->nargs_pos;
725+
const size_t nargs_pos = f->nargs_pos;
719726

720727
// Number of C++ parameters in total, except for a possible trailing
721728
// nb::kwargs. All of these are eligible to be filled from individual
722729
// Python arguments (keyword always, positional until index nargs_pos)
723730
// except for a potential nb::args, which exists at index nargs_pos
724731
// if has_var_args is true. We'll skip that one in the individual-args
725732
// loop, and go back and fill it later with the unused positionals.
726-
size_t nargs_step1 = f->nargs - has_var_kwargs;
733+
const size_t nargs_step1 = f->nargs - has_var_kwargs;
727734

728735
if (nargs_in > nargs_pos && !has_var_args)
729736
continue; // Too many positional arguments given for this overload
@@ -827,6 +834,171 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
827834
continue;
828835
}
829836

837+
if (is_constructor)
838+
args_flags[0] |= (uint8_t) cast_flags::construct;
839+
840+
rv_policy policy = (rv_policy) (f->flags & 0b111);
841+
842+
try {
843+
result = nullptr;
844+
845+
// Found a suitable overload, let's try calling it
846+
result = f->impl((void *) f->capture, args, args_flags,
847+
policy, &cleanup);
848+
849+
if (NB_UNLIKELY(!result))
850+
error_handler = nb_func_error_noconvert;
851+
} catch (builtin_exception &e) {
852+
if (!set_builtin_exception_status(e))
853+
result = NB_NEXT_OVERLOAD;
854+
} catch (python_error &e) {
855+
e.restore();
856+
} catch (...) {
857+
nb_func_convert_cpp_exception();
858+
}
859+
860+
if (result != NB_NEXT_OVERLOAD) {
861+
if (is_constructor && result != nullptr) {
862+
nb_inst *self_arg_nb = (nb_inst *) self_arg;
863+
self_arg_nb->destruct = true;
864+
self_arg_nb->state = nb_inst::state_ready;
865+
if (NB_UNLIKELY(self_arg_nb->intrusive))
866+
nb_type_data(Py_TYPE(self_arg))
867+
->set_self_py(inst_ptr(self_arg_nb), self_arg);
868+
}
869+
870+
goto done;
871+
}
872+
}
873+
}
874+
875+
error_handler = nb_func_error_overload;
876+
877+
done:
878+
if (NB_UNLIKELY(cleanup.used()))
879+
cleanup.release();
880+
881+
if (NB_UNLIKELY(error_handler))
882+
result = error_handler(self, args_in, nargs_in, kwargs_in);
883+
884+
return result;
885+
}
886+
887+
/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments
888+
/// and with no more than NB_MAXARGS_SIMPLE arguments
889+
static PyObject *nb_func_vectorcall_modest(PyObject *self,
890+
PyObject *const *args_in,
891+
size_t nargsf,
892+
PyObject *kwargs_in) noexcept {
893+
const size_t count = (size_t) Py_SIZE(self),
894+
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);
895+
896+
func_data *fr = nb_func_data(self);
897+
898+
const bool is_method = fr->flags & (uint32_t) func_flags::is_method,
899+
is_constructor = fr->flags & (uint32_t) func_flags::is_constructor;
900+
901+
PyObject *result = nullptr,
902+
*self_arg = (is_method && nargs_in > 0) ? args_in[0] : nullptr;
903+
904+
// Handler routine that will be invoked in case of an error condition
905+
PyObject *(*error_handler)(PyObject *, PyObject *const *, size_t,
906+
PyObject *) noexcept = nullptr;
907+
908+
// Small array holding temporaries (implicit conversion/*args)
909+
cleanup_list cleanup(self_arg);
910+
911+
if (kwargs_in != nullptr) { // keyword arguments are unsupported
912+
error_handler = nb_func_error_overload;
913+
goto done;
914+
}
915+
916+
// Stack memory for function dispatch
917+
PyObject* args[NB_MAXARGS_SIMPLE];
918+
uint8_t args_flags[NB_MAXARGS_SIMPLE];
919+
920+
/* The logic below tries to find a suitable overload using two passes
921+
of the overload chain (or 1, if there are no overloads). The first pass
922+
is strict and permits no implicit conversions, while the second pass
923+
allows them.
924+
925+
The following is done per overload during a pass
926+
927+
1. Copy individual arguments, substituting missing entries using
928+
default argument values provided in the bindings, if available.
929+
930+
2. Any positional arguments still left get put into a tuple.
931+
932+
3. Pack everything into a vector; if we have nb::args, it becomes
933+
a tuple at the end of the positional arguments.
934+
935+
4. Call the function call dispatcher (func_data::impl)
936+
937+
If one of these fail, move on to the next overload and keep trying
938+
until we get a result other than NB_NEXT_OVERLOAD.
939+
*/
940+
941+
for (size_t pass = (count > 1) ? 0 : 1; pass < 2; ++pass) {
942+
for (size_t k = 0; k < count; ++k) {
943+
const func_data *f = fr + k;
944+
945+
const bool has_args = f->flags & (uint32_t) func_flags::has_args,
946+
has_var_args = f->flags & (uint32_t) func_flags::has_var_args;
947+
948+
// Number of C++ parameters eligible to be filled from individual
949+
// Python positional arguments
950+
const size_t nargs_pos = f->nargs_pos;
951+
952+
if (nargs_in > nargs_pos && !has_var_args)
953+
continue; // Too many positional arguments for this overload
954+
955+
if (nargs_in < nargs_pos && !has_args)
956+
continue; // Not enough positional arguments
957+
958+
// 1. Copy individual arguments, potentially substitute defaults
959+
size_t i = 0;
960+
for (; i < nargs_pos; ++i) {
961+
PyObject *arg = nullptr;
962+
uint8_t arg_flag = 1;
963+
964+
if (i < nargs_in)
965+
arg = args_in[i];
966+
967+
if (has_args) {
968+
const arg_data &ad = f->args[i];
969+
970+
if (!arg)
971+
arg = ad.value;
972+
arg_flag = ad.flag;
973+
}
974+
975+
if (!arg || (arg == Py_None && (arg_flag & cast_flags::accepts_none) == 0))
976+
break;
977+
978+
// Implicit conversion only active in the 2nd pass
979+
args_flags[i] = arg_flag & ~uint8_t(pass == 0);
980+
args[i] = arg;
981+
}
982+
983+
// Skip this overload if any arguments were unavailable
984+
if (i != nargs_pos)
985+
continue;
986+
987+
// Deal with remaining positional arguments
988+
if (has_var_args) {
989+
PyObject *tuple = PyTuple_New(
990+
nargs_in > nargs_pos ? (Py_ssize_t) (nargs_in - nargs_pos) : 0);
991+
992+
for (size_t j = nargs_pos; j < nargs_in; ++j) {
993+
PyObject *o = args_in[j];
994+
Py_INCREF(o);
995+
NB_TUPLE_SET_ITEM(tuple, j - nargs_pos, o);
996+
}
997+
998+
args[nargs_pos] = tuple;
999+
args_flags[nargs_pos] = 0;
1000+
cleanup.append(tuple);
1001+
}
8301002

8311003
if (is_constructor)
8321004
args_flags[0] |= (uint8_t) cast_flags::construct;
@@ -887,8 +1059,8 @@ static PyObject *nb_func_vectorcall_simple(PyObject *self,
8871059
uint8_t args_flags[NB_MAXARGS_SIMPLE];
8881060
func_data *fr = nb_func_data(self);
8891061

890-
const size_t count = (size_t) Py_SIZE(self),
891-
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);
1062+
const size_t count = (size_t) Py_SIZE(self),
1063+
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);
8921064

8931065
const bool is_method = fr->flags & (uint32_t) func_flags::is_method,
8941066
is_constructor = fr->flags & (uint32_t) func_flags::is_constructor;

src/nb_internals.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ struct nb_func {
100100
PyObject_VAR_HEAD
101101
PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
102102
uint32_t max_nargs; // maximum value of func_data::nargs for any overload
103+
bool has_kwargs; // whether any overload has keyword arguments
103104
bool complex_call;
104105
bool doc_uniform;
105106
};

tests/test_functions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ NB_MODULE(test_functions_ext, m) {
9595
// Simple binary function (via function pointer)
9696
auto test_02 = [](int up, int down) -> int { return up - down; };
9797
m.def("test_02", (int (*)(int, int)) test_02, "up"_a = 8, "down"_a = 1);
98+
m.def("test_02p", (int (*)(int, int)) test_02, nb::arg()=8, nb::arg()=1);
99+
m.def("test_02nc", (int (*)(int, int)) test_02, nb::arg().noconvert()=8,
100+
nb::arg().noconvert()=1);
98101

99102
// Simple binary function with capture object
100103
int i = 42;

tests/test_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def test01_capture():
2020
# Functions with and without capture object of different sizes
2121
assert t.test_01() is None
2222
assert t.test_02(5, 3) == 2
23+
assert t.test_02p(5, 3) == 2
24+
assert t.test_02nc(5, 3) == 2
2325
assert t.test_03(5, 3) == 44
2426
assert t.test_04() == 60
2527
assert t.test_simple(0, 1, 2, 3, 4, 5, 6, 7) == 14
@@ -29,6 +31,14 @@ def test02_default_args():
2931
# Default arguments
3032
assert t.test_02() == 7
3133
assert t.test_02(7) == 6
34+
assert t.test_02('17') == 16
35+
assert t.test_02p() == 7
36+
assert t.test_02p(7) == 6
37+
assert t.test_02p('17') == 16
38+
assert t.test_02nc() == 7
39+
assert t.test_02nc(7) == 6
40+
with pytest.raises(TypeError):
41+
t.test_02nc('17')
3242

3343

3444
def test03_kwargs():

tests/test_functions_ext.pyi.ref

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ def test_01() -> None: ...
77

88
def test_02(up: int = 8, down: int = 1) -> int: ...
99

10+
def test_02p(arg0: int = 8, arg1: int = 1) -> int: ...
11+
12+
def test_02nc(arg0: int = 8, arg1: int = 1) -> int: ...
13+
1014
def test_03(arg0: int, arg1: int, /) -> int: ...
1115

1216
def test_04() -> int: ...

0 commit comments

Comments
 (0)