[Model] Update Mixtral to have well-formed TIR
Inside a `T.block`, loop variables may not be used directly. They must
be accessed through the corresponding `T.axis.remap` output.
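
As a rough illustration of the rule above, the sketch below uses Python's standard `ast` module to flag loop variables that are read inside a `with T.block(...)` body anywhere other than in a `T.axis.*` binding. This is not TVM's own well-formedness check. The helper name `loop_vars_used_in_blocks` and the matching heuristics are assumptions made for illustration, and the lint only looks at syntax (TVMScript parses as ordinary Python, so TVM is not needed to run it).

```python
import ast


def loop_vars_used_in_blocks(source):
    """Return (name, line) pairs where an enclosing loop variable is read
    inside a `with T.block(...)` body outside of a `T.axis.*` binding.

    Hypothetical helper: a purely syntactic approximation of the rule."""
    findings = []

    def is_block(node):
        # Matches `with <anything>.block(...):`
        return isinstance(node, ast.With) and any(
            isinstance(item.context_expr, ast.Call)
            and isinstance(item.context_expr.func, ast.Attribute)
            and item.context_expr.func.attr == "block"
            for item in node.items
        )

    def is_axis_binding(node):
        # Matches `<anything>.axis.<kind>(...)`, e.g. T.axis.remap("S", [j])
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Attribute)
            and node.func.value.attr == "axis"
        )

    def visit(node, banned, pending):
        # banned:  loop variables defined outside the current block
        # pending: loop variables defined since the last block was entered
        if is_axis_binding(node):
            return  # loop variables are expected here
        if isinstance(node, ast.For):
            names = {n.id for n in ast.walk(node.target) if isinstance(n, ast.Name)}
            visit(node.iter, banned, pending)
            for child in node.body + node.orelse:
                visit(child, banned, pending | names)
            return
        if is_block(node):
            for child in node.body:
                visit(child, banned | pending, frozenset())
            return
        if (
            isinstance(node, ast.Name)
            and isinstance(node.ctx, ast.Load)
            and node.id in banned
        ):
            findings.append((node.id, node.lineno))
        for child in ast.iter_child_nodes(node):
            visit(child, banned, pending)

    visit(ast.parse(source), frozenset(), frozenset())
    return findings


# The "cast" block from this commit, before and after the change.
OLD = '''
for j in T.unroll(2):
    with T.block("cast"):
        vj = T.axis.remap("S", [j])
        local_top_k_f32[vj] = T.cast(local_top_k[j], "float32")
'''
NEW = OLD.replace("local_top_k[j]", "local_top_k[vj]")

print(loop_vars_used_in_blocks(OLD))  # reports the direct read of `j`
print(loop_vars_used_in_blocks(NEW))  # reports nothing
```

Because the check is syntactic, it cannot tell whether `T` really refers to TVMScript, and it does not verify that each block iterator is bound to the right loop variable.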
Lunderberg committed Mar 1, 2024
1 parent 4b59cfa commit e653630
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mlc_llm/relax_model/mixtral.py
@@ -350,14 +350,14 @@ def top2_softmax_func(
 for j in T.unroll(2):
     with T.block("cast"):
         vj = T.axis.remap("S", [j])
-        local_top_k_f32[vj] = T.cast(local_top_k[j], "float32")
+        local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
 with T.block("max"):
     local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
 for j in T.unroll(2):
     with T.block("output"):
         vj = T.axis.remap("S", [j])
         out[vi, vj] = T.cast(
-            T.exp(local_top_k_f32[j] - local_top_k_max[0])
+            T.exp(local_top_k_f32[vj] - local_top_k_max[0])
             / (
                 T.exp(local_top_k_f32[0] - local_top_k_max[0])
                 + T.exp(local_top_k_f32[1] - local_top_k_max[0])
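
For readers less familiar with TIR, the arithmetic in the "max" and "output" blocks is a two-element softmax with the maximum subtracted before exponentiation. The following is a minimal plain-Python sketch of that arithmetic only. The function name `top2_softmax` is an assumption, and the final `T.cast` is omitted because the hunk is truncated before its target dtype is shown.

```python
import math


def top2_softmax(a, b):
    """Softmax over two values, mirroring the structure of the hunk above.

    Hypothetical helper for illustration; not part of mixtral.py."""
    m = max(a, b)              # corresponds to local_top_k_max[0]
    exp_a = math.exp(a - m)    # shifted exponent, always <= 1.0
    exp_b = math.exp(b - m)
    denom = exp_a + exp_b
    return exp_a / denom, exp_b / denom


p0, p1 = top2_softmax(1000.0, 999.0)
print(p0 + p1)
```

Subtracting the maximum leaves the result mathematically unchanged but keeps every exponent at or below zero, so inputs such as `1000.0` do not overflow `math.exp`.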
