1
1
#include < drjit/python.h>
2
2
#include < drjit/autodiff.h>
3
3
#include < drjit/packet.h>
4
+ #include < drjit/traversable_base.h>
5
+ #include < nanobind/nanobind.h>
6
+ #include < nanobind/trampoline.h>
4
7
5
8
namespace nb = nanobind;
6
9
namespace dr = drjit;
@@ -42,6 +45,49 @@ struct CustomHolder {
42
45
Value m_value;
43
46
};
44
47
48
+ class Object : public drjit ::TraversableBase {
49
+ DR_TRAVERSE_CB (drjit::TraversableBase);
50
+ };
51
+
52
+ template <typename Value>
53
+ class CustomBase : public Object {
54
+ public:
55
+ CustomBase () : Object() {}
56
+
57
+ virtual Value &value () = 0;
58
+
59
+ DR_TRAVERSE_CB (Object);
60
+ };
61
+
62
+ template <typename Value>
63
+ class PyCustomBase : public CustomBase <Value>{
64
+ public:
65
+ using Base = CustomBase<Value>;
66
+ NB_TRAMPOLINE (Base, 1 );
67
+
68
+ PyCustomBase () : Base() {}
69
+
70
+ Value &value () override { NB_OVERRIDE (value); }
71
+
72
+ DR_TRAMPOLINE_TRAVERSE_CB (Base);
73
+ };
74
+
75
+ template <typename Value>
76
+ class CustomA : public CustomBase <Value>{
77
+ public:
78
+ using Base = CustomBase<Value>;
79
+
80
+ CustomA () {}
81
+ CustomA (const Value &v) : m_value(v) {}
82
+
83
+ Value &value () override { return m_value; }
84
+
85
+ private:
86
+ Value m_value;
87
+
88
+ DR_TRAVERSE_CB (Base, m_value);
89
+ };
90
+
45
91
46
92
template <JitBackend Backend> void bind (nb::module_ &m) {
47
93
dr::ArrayBinding b;
@@ -64,12 +110,42 @@ template <JitBackend Backend> void bind(nb::module_ &m) {
64
110
.def (nb::init<Float>())
65
111
.def (" value" , &CustomFloatHolder::value, nanobind::rv_policy::reference);
66
112
113
+ using CustomBase = CustomBase<Float>;
114
+ using PyCustomBase = PyCustomBase<Float>;
115
+ using CustomA = CustomA<Float>;
116
+
117
+ auto object = nb::class_<Object>(
118
+ m, " Object" ,
119
+ nb::intrusive_ptr<Object>(
120
+ [](Object *o, PyObject *po) noexcept { o->set_self_py (po); }));
121
+
122
+ auto base = nb::class_<CustomBase, Object, PyCustomBase>(m, " CustomBase" )
123
+ .def (nb::init ())
124
+ .def (" value" , &CustomBase::value);
125
+ jit_log (LogLevel::Debug, " binding base" );
126
+
127
+ drjit::bind_traverse (base);
128
+
129
+ auto a = nb::class_<CustomA>(m, " CustomA" ).def (nb::init<Float>());
130
+
131
+ drjit::bind_traverse (a);
132
+
67
133
m.def (" cpp_make_opaque" ,
68
134
[](CustomFloatHolder &holder) { dr::make_opaque (holder); }
69
135
);
70
136
}
71
137
72
138
NB_MODULE (custom_type_ext, m) {
139
+ nb::intrusive_init (
140
+ [](PyObject *o) noexcept {
141
+ nb::gil_scoped_acquire guard;
142
+ Py_INCREF (o);
143
+ },
144
+ [](PyObject *o) noexcept {
145
+ nb::gil_scoped_acquire guard;
146
+ Py_DECREF (o);
147
+ });
148
+
73
149
#if defined(DRJIT_ENABLE_LLVM)
74
150
nb::module_ llvm = m.def_submodule (" llvm" );
75
151
bind<JitBackend::LLVM>(llvm);
0 commit comments