Skip to content

Commit eb26a2e

Browse files
Added symbolic width for freezing means
1 parent 2a5431b commit eb26a2e

File tree

8 files changed

+91
-0
lines changed

8 files changed

+91
-0
lines changed

include/drjit-core/array.h

+5
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@ template <JitBackend Backend_, typename Value_> struct JitArray {
253253
return jit_var_size(m_index);
254254
}
255255

256+
auto symbolic_width() {
257+
using UInt32 = JitArray<Backend_, uint32_t>;
258+
return UInt32::steal(jit_var_symbolic_width(m_index));
259+
}
260+
256261
void resize(size_t size) {
257262
uint32_t index = jit_var_resize(m_index, size);
258263
jit_var_dec_ref(m_index);

include/drjit-core/jit.h

+2
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,8 @@ extern JIT_EXPORT uint32_t jit_var_data(uint32_t index, void **ptr_out);
11761176
/// Query the size of a given variable
11771177
extern JIT_EXPORT size_t jit_var_size(uint32_t index);
11781178

1179+
/// Create a UInt32 variable holding the size of a given variable and return its index (0 if \c index is 0)
extern JIT_EXPORT uint32_t jit_var_symbolic_width(uint32_t index);
1180+
11791181
/// Query the type of a given variable
11801182
extern JIT_EXPORT JIT_ENUM VarType jit_var_type(uint32_t index);
11811183

src/api.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,24 @@ size_t jit_var_size(uint32_t index) {
610610
return (size_t) jitc_var(index)->size;
611611
}
612612

613+
uint32_t jit_var_symbolic_width(uint32_t index){
614+
if(index == 0)
615+
return 0;
616+
617+
lock_guard guard(state.lock);
618+
619+
Variable *var = jitc_var(index);
620+
uint32_t var_size = var->size;
621+
622+
uint32_t width_index = jitc_var_literal(
623+
(JitBackend) var->backend, VarType::UInt32, &var_size, 1, true);
624+
625+
ThreadState *ts = thread_state(var->backend);
626+
ts->notify_symbolic_width(index, width_index);
627+
628+
return width_index;
629+
}
630+
613631
VarState jit_var_state(uint32_t index) {
614632
if (index == 0)
615633
return VarState::Invalid;

src/init.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,4 @@ void ThreadState::reset_state() {
768768
}
769769
void ThreadState::notify_free(const void *) { }
770770
void ThreadState::notify_expand(uint32_t) { }
771+
/// Default implementation of the symbolic width notification: does nothing.
void ThreadState::notify_symbolic_width(uint32_t, uint32_t) { }

src/internal.h

+2
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,8 @@ struct ThreadState : public ThreadStateBase {
720720
virtual void reduce_expanded(VarType vt, ReduceOp op, void *data,
721721
uint32_t exp, uint32_t size) = 0;
722722

723+
/// Notify the \c ThreadState that the variable \c width_index was created to
/// hold the size of the variable \c index.
virtual void notify_symbolic_width(uint32_t index, uint32_t width_index);
724+
723725
/// Notify the \c ThreadState that \c jitc_free has been called on a pointer.
724726
/// This is required for kernel freezing.
725727
virtual void notify_free(const void *ptr);

src/record_ts.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ int Recording::replay(const uint32_t *replay_inputs, uint32_t *replay_outputs) {
373373
if (!replay_aggregate(op))
374374
return false;
375375
break;
376+
case OpType::SymbolicWidth:
377+
if(!replay_symbolic_width(op))
378+
return false;
379+
break;
376380
case OpType::Free: {
377381
ProfilerPhase profiler2(pr_free);
378382

@@ -471,6 +475,41 @@ void RecordThreadState::barrier() {
471475
return m_internal->barrier();
472476
}
473477

478+
void RecordThreadState::notify_symbolic_width(uint32_t index,
479+
uint32_t width_index) {
480+
if (!paused()) {
481+
uint32_t start = m_recording.dependencies.size();
482+
Variable *v1 = jitc_var(index);
483+
Variable *v2 = jitc_var(width_index);
484+
add_in_param(v1->data, (VarType) v1->type);
485+
add_out_param(v2->data, VarType::UInt32);
486+
uint32_t end = m_recording.dependencies.size();
487+
488+
Operation op;
489+
op.type = OpType::SymbolicWidth;
490+
op.dependency_range = std::pair(start, end);
491+
m_recording.operations.push_back(op);
492+
}
493+
}
494+
495+
/// Replay a \c SymbolicWidth operation: allocate the single-entry output and,
/// unless this is a dry run, copy the current size of the input replay
/// variable into it. Always returns true.
int Recording::replay_symbolic_width(Operation &op) {
    uint32_t first = op.dependency_range.first;

    // The recorder adds the input dependency first, followed by the output
    AccessInfo src_info = dependencies[first];
    AccessInfo dst_info = dependencies[first + 1];

    ReplayVariable &src = replay_variables[src_info.slot];
    ReplayVariable &dst = replay_variables[dst_info.slot];

    dst.alloc(backend, 1, dst_info.vtype);

    uint32_t width = src.size(src_info.vtype);
    if (!dry_run)
        jitc_memcpy(backend, dst.data, &width, sizeof(uint32_t));

    return true;
}
512+
474513
/**
475514
* This function is called every time a pointer is freed using \ref
476515
* jitc_free. It records the operation and removes the mapping from that

src/record_ts.h

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ enum class OpType {
3838
BlockPrefixReduce,
3939
ReduceDot,
4040
Aggregate,
41+
SymbolicWidth,
4142
Free,
4243
Count,
4344
};
@@ -296,6 +297,8 @@ struct Recording {
296297

297298
int replay_aggregate(Operation &op);
298299

300+
/// Replays a \c SymbolicWidth operation, which writes the size of its input
/// variable into its single-entry output variable.
int replay_symbolic_width(Operation &op);
301+
299302
/// This function is called after recording and checks that the recording is
300303
/// valid i.e. that no variables where left uninitialized.
301304
void validate();
@@ -428,6 +431,8 @@ struct RecordThreadState : ThreadState {
428431
void reduce_expanded(VarType vt, ReduceOp reduce_op, void *data,
429432
uint32_t exp, uint32_t size) override;
430433

434+
/// Records a \c SymbolicWidth operation for the variable \c index and the
/// variable \c width_index that holds its size (unless recording is paused).
void notify_symbolic_width(uint32_t index, uint32_t width_index) override;
435+
431436
/**
432437
* This function is called every time a pointer is freed using \ref
433438
* jitc_free. It records the operation and removes the mapping from that

tests/record.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,22 @@ TEST_LLVM(10_scatter) {
248248
jit_assert(all(eq(y, arange<UInt32>(10 + i) + 1)));
249249
}
250250
}
251+
252+
// Checks that a frozen function dividing by symbolic_width() produces the
// same result as the unfrozen function for inputs of several sizes.
TEST_BOTH(11_symbolic_width) {
    auto func = [](UInt32 x) {
        auto normalized = block_prefix_sum(x, x.size());
        normalized = normalized / x.symbolic_width();
        return normalized;
    };

    FrozenFunction frozen(Backend, func);

    // Call with a different input size in every iteration
    for (uint32_t extra = 0; extra < 4; extra++) {
        auto x = arange<UInt32>(10 + extra);

        auto res = frozen(x);
        auto ref = func(x);

        jit_assert(all(eq(res, ref)));
    }
}

0 commit comments

Comments
 (0)