Skip to content

Commit 3d80462

Browse files
Added base value test for custom_type_ext traversal
1 parent d0b9871 commit 3d80462

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

tests/custom_type_ext.cpp

+13-9
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@ class Object : public drjit::TraversableBase {
5151

5252
template <typename Value>
5353
class CustomBase : public Object{
54+
Value m_base_value;
55+
5456
public:
55-
CustomBase() : Object() {}
57+
CustomBase(const Value &base_value) : Object(), m_base_value(base_value) {}
5658

59+
Value &base_value() { return m_base_value; }
5760
virtual Value &value() = 0;
5861

59-
DR_TRAVERSE_CB(Object);
62+
DR_TRAVERSE_CB(Object, m_base_value);
6063
};
6164

6265
template <typename Value>
@@ -65,7 +68,7 @@ class PyCustomBase : public CustomBase<Value>{
6568
using Base = CustomBase<Value>;
6669
NB_TRAMPOLINE(Base, 1);
6770

68-
PyCustomBase() : Base() {}
71+
PyCustomBase(const Value &base_value) : Base(base_value) {}
6972

7073
Value &value() override { NB_OVERRIDE_PURE(value); }
7174

@@ -77,8 +80,7 @@ class CustomA: public CustomBase<Value>{
7780
public:
7881
using Base = CustomBase<Value>;
7982

80-
CustomA() {}
81-
CustomA(const Value &v) : m_value(v) {}
83+
CustomA(const Value &value, const Value &base_value) : Base(base_value), m_value(value) {}
8284

8385
Value &value() override { return m_value; }
8486

@@ -119,13 +121,15 @@ template <JitBackend Backend> void bind(nb::module_ &m) {
119121
nb::intrusive_ptr<Object>(
120122
[](Object *o, PyObject *po) noexcept { o->set_self_py(po); }));
121123

122-
auto base = nb::class_<CustomBase, Object, PyCustomBase>(m, "CustomBase")
123-
.def(nb::init())
124-
.def("value", nb::overload_cast<>(&CustomBase::value));
124+
auto base =
125+
nb::class_<CustomBase, Object, PyCustomBase>(m, "CustomBase")
126+
.def(nb::init<Float>())
127+
.def("value", nb::overload_cast<>(&CustomBase::value))
128+
.def("base_value", nb::overload_cast<>(&CustomBase::base_value));
125129

126130
drjit::bind_traverse(base);
127131

128-
auto a = nb::class_<CustomA>(m, "CustomA").def(nb::init<Float>());
132+
auto a = nb::class_<CustomA>(m, "CustomA").def(nb::init<Float, Float>());
129133

130134
drjit::bind_traverse(a);
131135

tests/test_custom_type_ext.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,11 @@ def test04_traverse_opaque(t):
7878
pkg = get_pkg(t)
7979
Float = t
8080

81-
v = dr.arange(Float, 10)
81+
value = dr.arange(Float, 10)
82+
base_value = dr.arange(Float, 10)
8283

83-
a = pkg.CustomA(v)
84-
assert dr.detail.collect_indices(a) == [v.index]
84+
a = pkg.CustomA(value, base_value)
85+
assert dr.detail.collect_indices(a) == [base_value.index, value.index]
8586

8687

8788
@pytest.test_arrays("float32,-diff,shape=(*),jit")
@@ -119,16 +120,17 @@ def test06_trampoline_traversal(t):
119120
pkg = get_pkg(t)
120121
Float = t
121122

122-
v = dr.opaque(Float, 0, 3)
123+
value = dr.opaque(Float, 0, 3)
124+
base_value = dr.opaque(Float, 1, 3)
123125

124126
class B(pkg.CustomBase):
125-
def __init__(self, v) -> None:
126-
super().__init__()
127-
self.v = v
127+
def __init__(self, value, base_value) -> None:
128+
super().__init__(base_value)
129+
self._value = value
128130

129131
def value(self):
130-
return self.v
132+
return self._value
131133

132-
b = B(v)
134+
b = B(value, base_value)
133135

134-
assert dr.detail.collect_indices(b) == [v.index]
136+
assert dr.detail.collect_indices(b) == [base_value.index, value.index]

0 commit comments

Comments
 (0)