Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7bea30d
modify compiler_hint path to tl.extra.cann.extension; tl.max add prop…
kevin-hongkai Apr 9, 2026
3189c5c
swa add enable_ubuf_saving to solve ub overflow
kevin-hongkai Apr 9, 2026
f814f12
Merge remote-tracking branch 'origin/master' into hongkai/switch-to-t…
kevin-hongkai Apr 10, 2026
954c4b7
normalization use casting_mode gemma to avoid F.rms_norm fp32 cast pr…
kevin-hongkai Apr 11, 2026
64f3733
Merge remote-tracking branch 'origin/master' into hongkai/switch-to-t…
kevin-hongkai Apr 14, 2026
da275e9
remove print
kevin-hongkai Apr 15, 2026
d705131
flash attention use need_mask to do if else
kevin-hongkai Apr 16, 2026
2d8b7ae
Revert "flash attention use need_mask to do if else"
kevin-hongkai Apr 17, 2026
e64820e
Merge remote-tracking branch 'origin/master' into hongkai/switch-to-t…
kevin-hongkai Apr 27, 2026
cb1fa9d
modify test_over_encoding to switch to triton-ascend
kevin-hongkai Apr 27, 2026
b720f9d
modify n_gram mask to solve random precision problem
kevin-hongkai Apr 28, 2026
8d799a2
switch to triton-ascend 3.2.1, use wget temporarily for checking CI …
kevin-hongkai May 7, 2026
bca288b
Merge branch 'master' into hongkai/switch-to-triton-ascend
kevin-hongkai May 7, 2026
c5d3cf5
switch to triton-ascend 3.2.1, use wget temporarily for checking CI …
kevin-hongkai May 7, 2026
1714995
switch to triton-ascend 3.2.1, use wget temporarily for checking CI …
kevin-hongkai May 8, 2026
4e22942
add print to debug CI
kevin-hongkai May 8, 2026
2945c96
modify ccec path
kevin-hongkai May 8, 2026
7c9b410
Merge remote-tracking branch 'origin/master' into hongkai/switch-to-t…
kevin-hongkai May 9, 2026
cb29734
add CI debug
kevin-hongkai May 9, 2026
14d36ec
add CI debug
kevin-hongkai May 9, 2026
f4a7c35
add CI debug
kevin-hongkai May 9, 2026
2578fdf
rollback ci and switch to cann8.5.0 image
kevin-hongkai May 15, 2026
79052b3
add triton-ascend on CI
kevin-hongkai May 15, 2026
704ca79
Merge remote-tracking branch 'origin/master' into hongkai/switch-to-t…
kevin-hongkai May 18, 2026
d4ea8aa
fix inder quant para error
kevin-hongkai May 18, 2026
d3dc6b1
fix perf test case
kevin-hongkai May 18, 2026
1569581
modify extract_slice _compute_vision_rope to adapter triton-ascend
kevin-hongkai May 18, 2026
4ab3469
add sync_solver=False to avoid groupgemm perf descend on triton-ascend
kevin-hongkai May 20, 2026
91d08a2
switch to byted-triton-x 3.2.1
kevin-hongkai May 25, 2026
3687e64
add --index-url to switch to byted-triton-x 3.2.1
kevin-hongkai May 25, 2026
be9d2fe
add --index-url to switch to byted-triton-x 3.2.1
kevin-hongkai May 25, 2026
3d00a17
Update mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py
kevin-hongkai May 25, 2026
42184ab
rollback to rmsnorm_fwd llama mode, triton-ascend 3.2.1 precision is OK
kevin-hongkai May 25, 2026
a2fd489
functions rms backward precision is not OK, change to gemma
kevin-hongkai May 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ascend_accuracy_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
name: Run npu tests
runs-on: NPU
container:
image: hub.byted.org/aicompiler/npu.ci.debian12:cann8.2.rc1_py3.11_th27
image: hub.byted.org/aicompiler/npu.ci.debian12:1.0.0.26
volumes:
- /etc/localtime:/etc/localtime
- /usr/local/Ascend/driver:/usr/local/Ascend/driver
Expand All @@ -56,6 +56,9 @@ jobs:
pip install -e .\[npu\] --index-url https://bytedpypi.byted.org/simple
python3 -m build .
pip install dist/*.whl --force-reinstall --no-deps
# Uninstall any Triton distributions already present before installing the pinned one.
pip uninstall -y byted-triton-x triton
# Remove the triton package directory itself (presumably leftover files from the
# image's preinstalled Triton -- confirm against the CI image).
rm -rf /usr/local/lib/python3.11/dist-packages/triton
# The requirement must be quoted: unquoted, the shell treats ">" as a stdout
# redirection to a file named "=3.2.1", so pip never sees the version constraint.
pip install "byted-triton-x>=3.2.1" --index-url https://bytedpypi.byted.org/simple
- name: test-Base
shell: bash
run: |
Expand Down Expand Up @@ -113,4 +116,4 @@ jobs:
shell: bash
run: |
echo "Running ttx graph tests..."
MOJO_RUN_MODE="COMPILE" python3 -m pytest -n 8 mojo_opset/tests/test_ttx_graph/
MOJO_RUN_MODE="COMPILE" python3 -m pytest -n 8 mojo_opset/tests/test_ttx_graph/
2 changes: 1 addition & 1 deletion mojo_opset/backends/ttx/functions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def forward(
) -> torch.Tensor:
# FIXME: Currently, MojoNormFunction base class does not define fields like 'offset', so they are hardcoded here temporarily.
offset = 0.0
casting_mode = "llama"
casting_mode = "gemma"
str_to_casting_mode = {"llama": 0, "gemma": 1, "none": -1}
casting_mode_int = str_to_casting_mode[casting_mode]

Expand Down
32 changes: 16 additions & 16 deletions mojo_opset/backends/ttx/kernels/npu/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def causal_conv1d_fwd_kernel(
# We keep intra loop load because preloading will cause ub overflow under certain tiling.
b_yi = tl.load(x + bos * D + yi_offset_0 * D + yi_offset_1, mask=mask, other=0.0).to(tl.float32)
if HAS_WEIGHT:
b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])
b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])

b_y += b_yi
elif i_t * BT >= W:
Expand All @@ -105,7 +105,7 @@ def causal_conv1d_fwd_kernel(
mask = (yi_offset_0 < T_len) & (yi_offset_1 < D) & (yi_offset_0 >= 0)
b_yi = tl.load(x + bos * D + yi_offset_0 * D + yi_offset_1, mask=mask, other=0.0).to(tl.float32)
if HAS_WEIGHT:
b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])
b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])
b_y += b_yi
else:
o_t = i_t * BT + tl.arange(0, BT)
Expand All @@ -123,7 +123,7 @@ def causal_conv1d_fwd_kernel(
)

if HAS_WEIGHT:
b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])
b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1])
b_y += b_yi

if HAS_BIAS:
Expand Down Expand Up @@ -246,19 +246,19 @@ def causal_conv1d_bwd_kernel(
b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32)

for i_w in tl.static_range(0, W):
b_dy_sub = tl.extract_slice(b_dy, [i_w, 0], [BT, BD], [1, 1])
b_dy_sub = tl.extra.cann.extension.extract_slice(b_dy, [i_w, 0], [BT, BD], [1, 1])

if ACTIVATION == "swish" or ACTIVATION == "silu":
b_y_sub = tl.extract_slice(b_y, [i_w, 0], [BT, BD], [1, 1])
b_y_sub = tl.extra.cann.extension.extract_slice(b_y, [i_w, 0], [BT, BD], [1, 1])
b_ys = tl.sigmoid(b_y_sub)
b_dy_sub = b_dy_sub * b_ys * (1 + b_y_sub * (1 - b_ys))

b_wdy = b_dy_sub
if HAS_WEIGHT:
b_wdy = b_wdy * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])
b_wdy = b_wdy * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])

b_dw_sub = tl.sum(b_dy_sub * b_x, 0) # [BT, BD] * [BT, BD] --> sum(0) = [BD]
b_dw = tl.insert_slice(b_dw, b_dw_sub[None, :], [W - i_w - 1, 0], [1, BD], [1, 1])
b_dw = tl.extra.cann.extension.insert_slice(b_dw, b_dw_sub[None, :], [W - i_w - 1, 0], [1, BD], [1, 1])

if HAS_BIAS and i_w == 0:
b_db += tl.sum(b_dy_sub, 0)
Expand All @@ -280,7 +280,7 @@ def causal_conv1d_bwd_kernel(
b_dy = b_dy * b_ys * (1 + b_y * (1 - b_ys))
b_wdy = b_dy
if HAS_WEIGHT:
b_wdy = b_wdy * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])
b_wdy = b_wdy * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])

b_dw = tl.sum(b_dy * b_x, 0)
tl.store(dw + i_tg * D * W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)
Expand Down Expand Up @@ -335,7 +335,7 @@ def causal_conv1d_bwd_kernel(
b_wdy = (
b_dy_shift
if not HAS_WEIGHT
else (b_dy_shift * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]))
else (b_dy_shift * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]))
)
b_dx += b_wdy

Expand Down Expand Up @@ -583,7 +583,7 @@ def causal_conv1d_states_fwd_kernel(
o_t = eos - BW + tl.arange(0, BW)
o_d = i_d * BD + tl.arange(0, BD)
o_w = W - BW + tl.arange(0, BW)
m_t = o_t >= tl.maximum(bos, eos - W)
m_t = o_t >= tl.maximum(bos, eos - W,propagate_nan=tl.PropagateNan.ALL)
m_d = o_d < D
m_w = (o_w >= 0) & (o_w < W)

Expand Down Expand Up @@ -694,7 +694,7 @@ def causal_conv1d_update_kernel_bdt_fwd(
mask_x = (offset0_x < dim)[:, None] & ((offset1_x >= 0) & (offset1_x < seq_len))[None, :]
block_off_x = bi * dim * seq_len + offset0_x[:, None] * seq_len + offset1_x[None, :]
x_b_tmp = tl.load(x_ptr + block_off_x, mask=mask_x, other=0)
x_b = tl.insert_slice(st_b, x_b_tmp, (0, width - 1), (D_CHK_SIZE, T_CHK_SIZE), (1, 1))
x_b = tl.extra.cann.extension.insert_slice(st_b, x_b_tmp, (0, width - 1), (D_CHK_SIZE, T_CHK_SIZE), (1, 1))
else:
offset0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE)
offset1 = ti * T_CHK_SIZE - (width - 1) + tl.arange(0, T_CHK_SIZE + width - 1)
Expand All @@ -715,15 +715,15 @@ def causal_conv1d_update_kernel_bdt_fwd(
# NOTE: To avoid using tl.maximum on a negative offset,
# we pre-compute a fixed head tile size (ST_STORE_HEAD_TILE_SIZE)
# that is used for the store in the negative-address case
x_new_h = tl.extract_slice(x_b, (-t_off, 0), (ST_STORE_HEAD_TILE_SIZE, D_CHK_SIZE), (1, 1))
x_new_h = tl.extra.cann.extension.extract_slice(x_b, (-t_off, 0), (ST_STORE_HEAD_TILE_SIZE, D_CHK_SIZE), (1, 1))
x_new_h = tl.trans(x_new_h, (1, 0))
nst_off_y0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE)[:, None]
nst_off_y1_h = tl.arange(0, ST_STORE_HEAD_TILE_SIZE)[None, :]
nst_mask_h = (nst_off_y0 < dim) & (nst_off_y1_h >= 0) & (nst_off_y1_h < state_len)
block_ptr_h = bi * dim * state_len + nst_off_y0 * state_len + nst_off_y1_h
tl.store(conv_state_update_ptr + block_ptr_h, x_new_h, mask=nst_mask_h)
else:
x_new_s = tl.extract_slice(x_b, (width - 1, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1))
x_new_s = tl.extra.cann.extension.extract_slice(x_b, (width - 1, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1))
x_new_s = tl.trans(x_new_s, (1, 0))
nst_off_y0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE)[:, None]
nst_off_y1 = width - 1 + t_off + tl.arange(0, T_CHK_SIZE)[None, :]
Expand All @@ -732,14 +732,14 @@ def causal_conv1d_update_kernel_bdt_fwd(
tl.store(conv_state_update_ptr + block_ptr, x_new_s, mask=nst_mask)

for owi in tl.range(0, width):
new_x = tl.extract_slice(x_b, (owi, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1))
w_chl_wi = tl.extract_slice(w, (owi, 0), (1, D_CHK_SIZE), (1, 1))
new_x = tl.extra.cann.extension.extract_slice(x_b, (owi, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1))
w_chl_wi = tl.extra.cann.extension.extract_slice(w, (owi, 0), (1, D_CHK_SIZE), (1, 1))
x_mul_chl_wi = new_x * w_chl_wi
out_block += x_mul_chl_wi
out_block = tl.trans(out_block, (1, 0))

if SILU_ACTIVATION:
out_block = out_block * tl.sigmoid(out_block)
out_block = (out_block * tl.sigmoid(out_block)).to(x_ptr.dtype.element_ty)
tl.store(
tl.make_block_ptr(
out_ptr,
Expand Down
2 changes: 1 addition & 1 deletion mojo_opset/backends/ttx/kernels/npu/diffution_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def micro_kernel_fwd(
block_s = tl.dot(block_q, block_k.T) * scale
if block_mask is not None:
block_s += block_mask
block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1))
block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1,propagate_nan=tl.PropagateNan.ALL),propagate_nan=tl.PropagateNan.ALL)
block_s = tl.exp(block_s - block_m_1[:, None])
block_l_1 = tl.exp(block_m - block_m_1) * block_l + tl.sum(block_s, axis=1)

Expand Down
6 changes: 3 additions & 3 deletions mojo_opset/backends/ttx/kernels/npu/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _sdpa_infer_single_block(
if mask is not None:
qk = tl.where(mask, qk, float("-inf")) # 32B # bool

m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max
m_ij = tl.maximum(m_i, tl.max(qk, 1, propagate_nan=tl.PropagateNan.ALL), propagate_nan=tl.PropagateNan.ALL) # Scaled max
qk = qk - m_ij[:, None] # Stabilize

# Softmax weights p = exp(qk)
Expand Down Expand Up @@ -423,8 +423,8 @@ def paged_decode_kernel(
qk *= softmax_scale
qk = tl.where(mask, qk, float("-inf"))

m_j = tl.max(qk, axis=0)
m_ij = tl.maximum(m_i, m_j)
m_j = tl.max(qk, axis=0,propagate_nan=tl.PropagateNan.ALL)
m_ij = tl.maximum(m_i, m_j,propagate_nan=tl.PropagateNan.ALL)
qk = qk - m_ij

p = tl.math.exp(qk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _cross_entropy_kernel(

if HAS_SOFTCAPPING:
X_block = softcap * tl.math.tanh(X_block / softcap)
block_max = tl.max(X_block)
block_max = tl.max(X_block,propagate_nan=tl.PropagateNan.ALL)
if label_smoothing > 0:
X_block2 = tl.load(
X_ptr + X_offsets,
Expand All @@ -107,7 +107,7 @@ def _cross_entropy_kernel(
scaled_x_sum += tl.sum(-eps * X_block2 * weight_block).to(tl.float32)
else:
scaled_x_sum += tl.sum(-eps * X_block2).to(tl.float32)
m_new = tl.maximum(m, block_max)
m_new = tl.maximum(m, block_max,propagate_nan=tl.PropagateNan.ALL)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new

Expand Down Expand Up @@ -312,15 +312,15 @@ def _cross_entropy_prime_kernel(
other=float("-inf"),
).cast(tl.float32)

block_max = tl.max(X_block, axis=0) # Use axis=0 for clarity, it's a 1D reduction
block_max = tl.max(X_block, axis=0,propagate_nan=tl.PropagateNan.ALL) # Use axis=0 for clarity, it's a 1D reduction
if label_smoothing > 0:
X_block2 = tl.load(
current_X_ptr + X_offsets,
mask=X_mask,
other=0.0,
).cast(tl.float32)
scaled_x_sum += tl.sum(-eps * X_block2, axis=0).to(tl.float32)
m_new = tl.maximum(m, block_max)
m_new = tl.maximum(m, block_max,propagate_nan=tl.PropagateNan.ALL)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new), axis=0)
m = m_new

Expand Down
4 changes: 2 additions & 2 deletions mojo_opset/backends/ttx/kernels/npu/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def gelu_tanh_approx(x):
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / π)
x_cubed = x * x * x
tanh_arg = sqrt_2_over_pi * (x + 0.044715 * x_cubed)
return 0.5 * x * (1 + tl.tanh(tanh_arg))
return 0.5 * x * (1 + tl.extra.cann.math.tanh(tanh_arg))


@triton.autotune(
Expand Down Expand Up @@ -137,7 +137,7 @@ def _gelu_bwd_kernel(
sqrt_2_over_pi = 0.7978845608028654
x_cubed = x_f32 * x_f32 * x_f32
tanh_arg = sqrt_2_over_pi * (x_f32 + 0.044715 * x_cubed)
tanh_result = tl.tanh(tanh_arg)
tanh_result = tl.extra.cann.math.tanh(tanh_arg)

term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
Expand Down
15 changes: 8 additions & 7 deletions mojo_opset/backends/ttx/kernels/npu/group_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def _m_grouped_matmul_bKmajor_kernel(
mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_K),
other=0.0,
)
tl.compile_hint(a, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(a, "dot_pad_only_k")
b = tl.load(
b_ptrs,
mask=msk_n[:, None] and (offs_k[None, :] < (K - k * BLOCK_K)),
other=0.0,
)
b = tl.trans(b)
tl.compile_hint(b, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k")
accumulator = tl.dot(a, b, acc=accumulator)

c = accumulator.to(C.dtype.element_ty)
Expand Down Expand Up @@ -178,13 +178,13 @@ def _m_grouped_matmul_bNmajor_kernel(
mask=msk_m[:, None] and (offs_ak[None, :] < K - k * BLOCK_K),
other=0.0,
)
tl.compile_hint(a, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(a, "dot_pad_only_k")
b = tl.load(
b_ptrs,
mask=(offs_bk[:, None] < (group_idx * K + K - k * BLOCK_K)) and msk_n[None, :],
other=0.0,
)
tl.compile_hint(b, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k")
accumulator = tl.dot(a, b, acc=accumulator)

c = accumulator.to(C.dtype.element_ty)
Expand Down Expand Up @@ -224,6 +224,7 @@ def m_grouped_matmul_impl(
strideBN,
strideBK,
multibuffer=True,
sync_solver=False,
)
return C

Expand Down Expand Up @@ -291,9 +292,9 @@ def _k_grouped_matmul_kernel(
b_ptrs = b_ptrs_base + kk * BLOCK_K * N
a = tl.load(a_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_m[None, :], other=0.0)
aa = tl.trans(a)
tl.compile_hint(aa, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(aa, "dot_pad_only_k")
b = tl.load(b_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_n[None, :], other=0.0)
tl.compile_hint(b, "dot_pad_only_k")
tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k")
accumulator = tl.dot(aa, b, acc=accumulator)

c = accumulator.to(C.dtype.element_ty)
Expand Down Expand Up @@ -323,5 +324,5 @@ def grid(META):
assert M % META["BLOCK_M"] == 0, "Only support when M is a multiple of BLOCK_M"
return (num_cores,)

_k_grouped_matmul_kernel[grid](A, B, C, size_per_group, num_groups, M, N, multibuffer=True)
_k_grouped_matmul_kernel[grid](A, B, C, size_per_group, num_groups, M, N, multibuffer=True,sync_solver=False)
return C
2 changes: 1 addition & 1 deletion mojo_opset/backends/ttx/kernels/npu/lightning_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def lightning_indexer_kernel(
)
q = tl.load(query_ptrs)

relu_qk = tl.maximum(tl.dot(q.to(k.dtype), tl.trans(k)), 0.0)
relu_qk = tl.maximum(tl.dot(q.to(k.dtype), tl.trans(k)), 0.0,propagate_nan=tl.PropagateNan.ALL)

query_scale_ptrs = (
query_scale_ptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def over_encoding_decode_kernel(
oe_carry = tl.full((MTP_STEP, BLOCK_SIZE_N,), ori_vocab_size, dtype=tl.int64)

history_ptr = oe_history + oe_history_stride_0 * bid
n_gram_offsets = tl.flip(tl.arange(0, MAX_N_GRAM))
n_gram_offsets = tl.extra.cann.extension.flip(tl.arange(0, MAX_N_GRAM))

history_id = tl.load(
history_ptr + (oe_history_dim_1 - n_gram_offsets - 1) * oe_history_stride_1
Expand All @@ -97,13 +97,13 @@ def over_encoding_decode_kernel(
# WARNING(liuyuan): tl.cat required the same shapes of lhs and rhs in triton-npu. WTF?
# history_id = tl.cat(history_id, __input_ids, can_reorder=True)
__tmp = tl.zeros((MTP_STEP + MAX_N_GRAM,), dtype=tl.int64)
__tmp = tl.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,))
__tmp = tl.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,))
__tmp = tl.extra.cann.extension.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,))
__tmp = tl.extra.cann.extension.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,))
history_id = __tmp

for i in tl.static_range(1, MAX_N_GRAM):
__cal_mask = n_grams >= (i + 1)
__history_ids = tl.extract_slice(
__history_ids = tl.extra.cann.extension.extract_slice(
history_id, (MAX_N_GRAM - i,), (MTP_STEP,), (1,)
)
__history_ids = (
Expand All @@ -120,7 +120,7 @@ def over_encoding_decode_kernel(
for ele_idx in (tl.static_range if BLOCK_BATCH_SIZE < 4 else tl.range)(
0, MTP_STEP * n_grams_size
):
__id = tl.get_element(n_gram_ids, (ele_idx,))
__id = tl.extra.cann.extension.get_element(n_gram_ids, (ele_idx,))
__embedding_nf4_dequant__(
ele_idx + bid * MTP_STEP * MAX_N_GRAM,
__id,
Expand Down Expand Up @@ -169,7 +169,7 @@ def over_encoding_decode_kernel(
for ele_idx in (tl.static_range if BLOCK_BATCH_SIZE < 4 else tl.range)(
0, MTP_STEP * n_grams_size
):
__id = tl.get_element(n_gram_ids, (ele_idx,))
__id = tl.extra.cann.extension.get_element(n_gram_ids, (ele_idx,))
__embedding_nf4_dequant__(
ele_idx + bid * MTP_STEP * n_grams_size,
__id,
Expand Down
Loading
Loading