Skip to content

Commit 122b673

Browse files
committed
Scatter elision cleanup
Dr.Jit can elide scatter operations when their result can no longer be referenced by any other operations. The logic to do so, and when reference count decreases are needed, was dispersed throughout ``eval.cpp``. This commit simplifies the underlying code.
1 parent 38c4a3c commit 122b673

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

src/eval.cpp

+13-12
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,15 @@ static std::vector<VisitedKey> visit_later;
104104
// ====================================================================
105105

106106
// Don't perform scatters, whose output buffer is found to be unreferenced
107-
bool jitc_var_maybe_suppress_scatter(uint32_t index, Variable *v, uint32_t depth) {
107+
bool jitc_elide_scatter(uint32_t index, const Variable *v) {
108+
if ((VarKind) v->kind != VarKind::Scatter)
109+
return false;
108110
Variable *target = jitc_var(v->dep[0]);
109111
Variable *target_ptr = jitc_var(target->dep[3]);
110-
if (target_ptr->ref_count != 0 || depth != 0)
111-
return false;
112-
113112
jitc_log(Debug, "jit_eval(): eliding scatter r%u, whose output is unreferenced.", index);
114-
if (callable_depth == 0)
115-
jitc_var_dec_ref(index, v);
116-
return true;
113+
return target_ptr->ref_count == 0;
117114
}
118115

119-
120116
/// Recursively traverse the computation graph to find variables needed by a computation
121117
static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0) {
122118
if (!visited.emplace(size, index, depth).second)
@@ -125,7 +121,7 @@ static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0)
125121
Variable *v = jitc_var(index);
126122
switch ((VarKind) v->kind) {
127123
case VarKind::Scatter:
128-
if (jitc_var_maybe_suppress_scatter(index, v, depth))
124+
if (jitc_elide_scatter(index, v))
129125
return;
130126
break;
131127

@@ -690,8 +686,14 @@ void jitc_eval_impl(ThreadState *ts) {
690686

691687
ts->scheduled.clear();
692688

693-
for (uint32_t index: ts->side_effects)
694-
jitc_var_traverse(jitc_var(index)->size, index);
689+
for (uint32_t index: ts->side_effects) {
690+
Variable *v = jitc_var(index);
691+
692+
if (jitc_elide_scatter(index, v))
693+
jitc_var_dec_ref(index);
694+
else
695+
jitc_var_traverse(v->size, index);
696+
}
695697

696698
ts->side_effects.clear();
697699

@@ -747,7 +749,6 @@ void jitc_eval_impl(ThreadState *ts) {
747749

748750
for (ScheduledGroup &group : schedule_groups) {
749751
jitc_assemble(ts, group);
750-
751752
jitc_run(ts, group);
752753
}
753754

src/llvm_eval.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@ static void jitc_llvm_render(Variable *v) {
957957
}
958958
}
959959
break;
960+
960961
case VarKind::Scatter:
961962
if (v->literal)
962963
jitc_llvm_render_scatter_reduce(v, a0, a1, a2, a3);

0 commit comments

Comments
 (0)