diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py index ba50af9c6d..c9caa1ffdd 100644 --- a/mlc_llm/relax_model/mixtral.py +++ b/mlc_llm/relax_model/mixtral.py @@ -350,14 +350,14 @@ def top2_softmax_func( for j in T.unroll(2): with T.block("cast"): vj = T.axis.remap("S", [j]) - local_top_k_f32[vj] = T.cast(local_top_k[j], "float32") + local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32") with T.block("max"): local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1]) for j in T.unroll(2): with T.block("output"): vj = T.axis.remap("S", [j]) out[vi, vj] = T.cast( - T.exp(local_top_k_f32[j] - local_top_k_max[0]) + T.exp(local_top_k_f32[vj] - local_top_k_max[0]) / ( T.exp(local_top_k_f32[0] - local_top_k_max[0]) + T.exp(local_top_k_f32[1] - local_top_k_max[0])