@@ -104,19 +104,15 @@ static std::vector<VisitedKey> visit_later;
104
104
// ====================================================================
105
105
106
106
// Don't perform scatters, whose output buffer is found to be unreferenced
107
- bool jitc_var_maybe_suppress_scatter (uint32_t index, Variable *v, uint32_t depth) {
107
+ bool jitc_elide_scatter (uint32_t index, const Variable *v) {
108
+ if ((VarKind) v->kind != VarKind::Scatter)
109
+ return false ;
108
110
Variable *target = jitc_var (v->dep [0 ]);
109
111
Variable *target_ptr = jitc_var (target->dep [3 ]);
110
- if (target_ptr->ref_count != 0 || depth != 0 )
111
- return false ;
112
-
113
112
jitc_log (Debug, " jit_eval(): eliding scatter r%u, whose output is unreferenced." , index );
114
- if (callable_depth == 0 )
115
- jitc_var_dec_ref (index , v);
116
- return true ;
113
+ return target_ptr->ref_count == 0 ;
117
114
}
118
115
119
-
120
116
// / Recursively traverse the computation graph to find variables needed by a computation
121
117
static void jitc_var_traverse (uint32_t size, uint32_t index, uint32_t depth = 0 ) {
122
118
if (!visited.emplace (size, index , depth).second )
@@ -125,7 +121,7 @@ static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0)
125
121
Variable *v = jitc_var (index );
126
122
switch ((VarKind) v->kind ) {
127
123
case VarKind::Scatter:
128
- if (jitc_var_maybe_suppress_scatter (index , v, depth ))
124
+ if (jitc_elide_scatter (index , v))
129
125
return ;
130
126
break ;
131
127
@@ -690,8 +686,14 @@ void jitc_eval_impl(ThreadState *ts) {
690
686
691
687
ts->scheduled .clear ();
692
688
693
- for (uint32_t index : ts->side_effects )
694
- jitc_var_traverse (jitc_var (index )->size , index );
689
+ for (uint32_t index : ts->side_effects ) {
690
+ Variable *v = jitc_var (index );
691
+
692
+ if (jitc_elide_scatter (index , v))
693
+ jitc_var_dec_ref (index );
694
+ else
695
+ jitc_var_traverse (v->size , index );
696
+ }
695
697
696
698
ts->side_effects .clear ();
697
699
@@ -747,7 +749,6 @@ void jitc_eval_impl(ThreadState *ts) {
747
749
748
750
for (ScheduledGroup &group : schedule_groups) {
749
751
jitc_assemble (ts, group);
750
-
751
752
jitc_run (ts, group);
752
753
}
753
754
0 commit comments