Skip to content

Commit

Permalink
Fix invalid use of dataflow var in sampler output (#2003)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored Mar 22, 2024
1 parent 8405cb1 commit 64badb5
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/mlc_llm/compiler_pass/attach_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr):
sorted_indices,
primfunc_name_hint="take_sorted_probs",
)
gv = bb.emit_func_output([sorted_values, sorted_indices])
output = (sorted_values, sorted_indices)
bb.emit_output(output)
gv = bb.emit_func_output(output)
return gv


Expand Down Expand Up @@ -201,6 +203,7 @@ def full(var_result: T.handle, value: T.int32):
sinfo_args=sample_indices.struct_info, # pylint: disable=no-member
)
)
bb.emit_output(result)
gv = bb.emit_func_output(result)
return gv

Expand Down Expand Up @@ -270,5 +273,6 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument
],
)
)
bb.emit_output(taken_probs_indices)
gv = bb.emit_func_output(taken_probs_indices)
return gv

0 comments on commit 64badb5

Please sign in to comment.