Skip to content

Commit 4cc7be1

Browse files
Added traversal tests
1 parent 4efa918 commit 4cc7be1

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

tests/custom_type_ext.cpp

+76
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include <drjit/python.h>
22
#include <drjit/autodiff.h>
33
#include <drjit/packet.h>
4+
#include <drjit/traversable_base.h>
5+
#include <nanobind/nanobind.h>
6+
#include <nanobind/trampoline.h>
47

58
namespace nb = nanobind;
69
namespace dr = drjit;
@@ -42,6 +45,49 @@ struct CustomHolder {
4245
Value m_value;
4346
};
4447

48+
class Object : public drjit::TraversableBase {
49+
DR_TRAVERSE_CB(drjit::TraversableBase);
50+
};
51+
52+
template <typename Value>
53+
class CustomBase : public Object{
54+
public:
55+
CustomBase() : Object() {}
56+
57+
virtual Value &value() = 0;
58+
59+
DR_TRAVERSE_CB(Object);
60+
};
61+
62+
template <typename Value>
63+
class PyCustomBase : public CustomBase<Value>{
64+
public:
65+
using Base = CustomBase<Value>;
66+
NB_TRAMPOLINE(Base, 1);
67+
68+
PyCustomBase() : Base() {}
69+
70+
Value &value() override { NB_OVERRIDE(value); }
71+
72+
DR_TRAMPOLINE_TRAVERSE_CB(Base);
73+
};
74+
75+
template <typename Value>
76+
class CustomA: public CustomBase<Value>{
77+
public:
78+
using Base = CustomBase<Value>;
79+
80+
CustomA() {}
81+
CustomA(const Value &v) : m_value(v) {}
82+
83+
Value &value() override { return m_value; }
84+
85+
private:
86+
Value m_value;
87+
88+
DR_TRAVERSE_CB(Base, m_value);
89+
};
90+
4591

4692
template <JitBackend Backend> void bind(nb::module_ &m) {
4793
dr::ArrayBinding b;
@@ -64,12 +110,42 @@ template <JitBackend Backend> void bind(nb::module_ &m) {
64110
.def(nb::init<Float>())
65111
.def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference);
66112

113+
using CustomBase = CustomBase<Float>;
114+
using PyCustomBase = PyCustomBase<Float>;
115+
using CustomA = CustomA<Float>;
116+
117+
auto object = nb::class_<Object>(
118+
m, "Object",
119+
nb::intrusive_ptr<Object>(
120+
[](Object *o, PyObject *po) noexcept { o->set_self_py(po); }));
121+
122+
auto base = nb::class_<CustomBase, Object, PyCustomBase>(m, "CustomBase")
123+
.def(nb::init())
124+
.def("value", &CustomBase::value);
125+
jit_log(LogLevel::Debug, "binding base");
126+
127+
drjit::bind_traverse(base);
128+
129+
auto a = nb::class_<CustomA>(m, "CustomA").def(nb::init<Float>());
130+
131+
drjit::bind_traverse(a);
132+
67133
m.def("cpp_make_opaque",
68134
[](CustomFloatHolder &holder) { dr::make_opaque(holder); }
69135
);
70136
}
71137

72138
NB_MODULE(custom_type_ext, m) {
139+
nb::intrusive_init(
140+
[](PyObject *o) noexcept {
141+
nb::gil_scoped_acquire guard;
142+
Py_INCREF(o);
143+
},
144+
[](PyObject *o) noexcept {
145+
nb::gil_scoped_acquire guard;
146+
Py_DECREF(o);
147+
});
148+
73149
#if defined(DRJIT_ENABLE_LLVM)
74150
nb::module_ llvm = m.def_submodule("llvm");
75151
bind<JitBackend::LLVM>(llvm);

tests/test_custom_type_ext.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import drjit as dr
22
import pytest
33

4+
dr.set_log_level(dr.LogLevel.Info)
45

56
def get_pkg(t):
67
with dr.detail.scoped_rtld_deepbind():
@@ -69,3 +70,65 @@ def test03_cpp_make_opaque(t):
6970

7071
pkg.cpp_make_opaque(holder)
7172
assert holder.value().state == dr.VarState.Evaluated
73+
74+
75+
@pytest.test_arrays("float32,-diff,shape=(*),jit")
76+
def test04_traverse_opaque(t):
77+
# Tests that it is possible to traverse an opaque C++ object
78+
pkg = get_pkg(t)
79+
print(f"{dir(pkg)=}")
80+
Float = t
81+
82+
v = dr.arange(Float, 10)
83+
84+
a = pkg.CustomA(v)
85+
assert dr.detail.collect_indices(a) == [v.index]
86+
87+
88+
@pytest.test_arrays("float32,-diff,shape=(*),jit")
89+
def test05_traverse_py(t):
90+
# Tests the implementation of `%raverse_py_cb_ro`,
91+
# used for traversal of PyTrees inside of C++ objects
92+
Float = t
93+
94+
v = dr.arange(Float, 10)
95+
96+
class PyClass:
97+
def __init__(self, v) -> None:
98+
self.v = v
99+
100+
c = PyClass(v)
101+
102+
result = []
103+
104+
def callback(index):
105+
result.append(index)
106+
107+
dr.detail.traverse_py_cb_ro(c, callback)
108+
109+
assert result == [v.index]
110+
111+
112+
@pytest.test_arrays("float32,-diff,shape=(*),jit")
113+
def test06_trampoline_traversal(t):
114+
# Tests that classes inhereting from trampoline classes are traversed
115+
# automatically
116+
pkg = get_pkg(t)
117+
print(f"{dir(pkg)=}")
118+
Float = t
119+
120+
v = dr.opaque(Float, 0, 3)
121+
122+
class B(pkg.CustomBase):
123+
def __init__(self, v) -> None:
124+
super().__init__()
125+
self.v = v
126+
127+
def value(self):
128+
return self.v
129+
130+
b = B(v)
131+
132+
b.value()
133+
134+
assert dr.detail.collect_indices(b) == [v.index]

0 commit comments

Comments
 (0)