Skip to content

Commit 831784f

Browse files
committed
Throw TypeError when subclasses forget to call __init__
- Based on pybind/pybind11#2152 - Fixes #1210
1 parent d145ec5 commit 831784f

File tree

2 files changed

+66
-7
lines changed

2 files changed

+66
-7
lines changed

src/nb_type.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,68 @@ static int nb_type_init(PyObject *self, PyObject *args, PyObject *kwds) {
498498
return 0;
499499
}
500500

501+
/// metaclass `__call__` function that is used to create all nanobind objects
502+
static PyObject *nb_type_call(PyObject *type, PyObject *args, PyObject *kwargs) {
503+
504+
// use the default metaclass call to create/initialize the object
505+
#if defined(Py_LIMITED_API)
506+
PyObject *self = ((ternaryfunc)PyType_GetSlot(&PyType_Type, Py_tp_call))(type, args, kwargs);
507+
#else
508+
PyObject *self = PyType_Type.tp_call(type, args, kwargs);
509+
#endif
510+
if (self == nullptr) {
511+
return nullptr;
512+
}
513+
514+
// This must be a nanobind instance
515+
nb_inst *inst = (nb_inst *) self;
516+
517+
// Walk through the MRO and check if all nanobind base classes are initialized
518+
PyTypeObject *tp = (PyTypeObject *) type;
519+
#if defined(Py_LIMITED_API)
520+
// In limited API, we must use __mro__ attribute instead of tp->tp_mro
521+
PyObject *mro = PyObject_GetAttrString((PyObject *) tp, "__mro__");
522+
if (!mro) {
523+
Py_DECREF(self);
524+
return nullptr;
525+
}
526+
#else
527+
PyObject *mro = tp->tp_mro;
528+
if (!mro) {
529+
Py_DECREF(self);
530+
PyErr_SetString(PyExc_TypeError, "MRO is not available");
531+
return nullptr;
532+
}
533+
#endif
534+
535+
Py_ssize_t n = NB_TUPLE_GET_SIZE(mro);
536+
for (Py_ssize_t i = 0; i < n; ++i) {
537+
PyTypeObject *base = (PyTypeObject *) NB_TUPLE_GET_ITEM(mro, i);
538+
539+
// Check if this base is a nanobind type
540+
if (!nb_type_check((PyObject *) base))
541+
continue;
542+
543+
// For each nanobind base, check if it's initialized
544+
if (inst->state == nb_inst::state_uninitialized) {
545+
const type_data *t = nb_type_data(Py_TYPE(self));
546+
PyErr_Format(PyExc_TypeError,
547+
"%.200s.__init__() must be called when overriding __init__",
548+
t->name);
549+
#if defined(Py_LIMITED_API)
550+
Py_DECREF(mro);
551+
#endif
552+
Py_DECREF(self);
553+
return nullptr;
554+
}
555+
}
556+
557+
#if defined(Py_LIMITED_API)
558+
Py_DECREF(mro);
559+
#endif
560+
return self;
561+
}
562+
501563
/// Special case to handle 'Class.property = value' assignments
502564
int nb_type_setattro(PyObject* obj, PyObject* name, PyObject* value) {
503565
nb_internals *int_p = internals;
@@ -861,6 +923,7 @@ static PyTypeObject *nb_type_tp(size_t supplement) noexcept {
861923
{ Py_tp_dealloc, (void *) nb_type_dealloc },
862924
{ Py_tp_setattro, (void *) nb_type_setattro },
863925
{ Py_tp_init, (void *) nb_type_init },
926+
{ Py_tp_call, (void *) nb_type_call },
864927
#if defined(Py_LIMITED_API)
865928
{ Py_tp_members, (void *) members },
866929
#endif

tests/test_classes.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,9 @@ def name(self):
256256
def what(self):
257257
return "b"
258258

259-
with pytest.warns(
260-
RuntimeWarning,
261-
match="nanobind: attempted to access an uninitialized instance of type",
262-
):
263-
with pytest.raises(TypeError) as excinfo:
264-
t.go(Incomplete2())
265-
assert "incompatible function arguments" in str(excinfo.value)
259+
with pytest.raises(TypeError) as excinfo:
260+
Incomplete2()
261+
assert "__init__() must be called when overriding __init__" in str(excinfo.value)
266262

267263

268264
def test12_large_pointers():

0 commit comments

Comments
 (0)