diff --git a/.github/workflows/ascend_accuracy_ci.yml b/.github/workflows/ascend_accuracy_ci.yml index 5cb966c78..9cbe51a26 100644 --- a/.github/workflows/ascend_accuracy_ci.yml +++ b/.github/workflows/ascend_accuracy_ci.yml @@ -31,7 +31,7 @@ jobs: name: Run npu tests runs-on: NPU container: - image: hub.byted.org/aicompiler/npu.ci.debian12:cann8.2.rc1_py3.11_th27 + image: hub.byted.org/aicompiler/npu.ci.debian12:1.0.0.26 volumes: - /etc/localtime:/etc/localtime - /usr/local/Ascend/driver:/usr/local/Ascend/driver @@ -56,6 +56,9 @@ jobs: pip install -e .\[npu\] --index-url https://bytedpypi.byted.org/simple python3 -m build . pip install dist/*.whl --force-reinstall --no-deps + pip uninstall -y byted-triton-x triton + rm -rf /usr/local/lib/python3.11/dist-packages/triton + pip install byted-triton-x>=3.2.1 --index-url https://bytedpypi.byted.org/simple - name: test-Base shell: bash run: | @@ -113,4 +116,4 @@ jobs: shell: bash run: | echo "Running ttx graph tests..." - MOJO_RUN_MODE="COMPILE" python3 -m pytest -n 8 mojo_opset/tests/test_ttx_graph/ + MOJO_RUN_MODE="COMPILE" python3 -m pytest -n 8 mojo_opset/tests/test_ttx_graph/ \ No newline at end of file diff --git a/mojo_opset/backends/ttx/functions/normalization.py b/mojo_opset/backends/ttx/functions/normalization.py index a0103b6a1..c2852937e 100644 --- a/mojo_opset/backends/ttx/functions/normalization.py +++ b/mojo_opset/backends/ttx/functions/normalization.py @@ -23,7 +23,7 @@ def forward( ) -> torch.Tensor: # FIXME: Currently, MojoNormFunction base class does not define fields like 'offset', so they are hardcoded here temporarily. offset = 0.0 - casting_mode = "llama" + casting_mode = "gemma" str_to_casting_mode = {"llama": 0, "gemma": 1, "none": -1} casting_mode_int = str_to_casting_mode[casting_mode] diff --git a/mojo_opset/backends/ttx/kernels/npu/convolution.py b/mojo_opset/backends/ttx/kernels/npu/convolution.py index c7b99e801..fee0fa7f6 100644 --- a/mojo_opset/backends/ttx/kernels/npu/convolution.py +++ b/mojo_opset/backends/ttx/kernels/npu/convolution.py @@ -96,7 +96,7 @@ def causal_conv1d_fwd_kernel( # We keep intra loop load because preloading will cause ub overflow under certain tiling. b_yi = tl.load(x + bos * D + yi_offset_0 * D + yi_offset_1, mask=mask, other=0.0).to(tl.float32) if HAS_WEIGHT: - b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) + b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) b_y += b_yi elif i_t * BT >= W: @@ -105,7 +105,7 @@ def causal_conv1d_fwd_kernel( mask = (yi_offset_0 < T_len) & (yi_offset_1 < D) & (yi_offset_0 >= 0) b_yi = tl.load(x + bos * D + yi_offset_0 * D + yi_offset_1, mask=mask, other=0.0).to(tl.float32) if HAS_WEIGHT: - b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) + b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) b_y += b_yi else: o_t = i_t * BT + tl.arange(0, BT) @@ -123,7 +123,7 @@ def causal_conv1d_fwd_kernel( ) if HAS_WEIGHT: - b_yi *= tl.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) + b_yi *= tl.extra.cann.extension.extract_slice(b_w, [i_w + W - 1, 0], [1, BD], [1, 1]) b_y += b_yi if HAS_BIAS: @@ -246,19 +246,19 @@ def causal_conv1d_bwd_kernel( b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32) for i_w in tl.static_range(0, W): - b_dy_sub = tl.extract_slice(b_dy, [i_w, 0], [BT, BD], [1, 1]) + b_dy_sub = tl.extra.cann.extension.extract_slice(b_dy, [i_w, 0], [BT, BD], [1, 1]) if ACTIVATION == "swish" or ACTIVATION == "silu": - b_y_sub = tl.extract_slice(b_y, [i_w, 0], [BT, BD], [1, 1]) + b_y_sub = tl.extra.cann.extension.extract_slice(b_y, [i_w, 0], [BT, BD], [1, 1]) b_ys = tl.sigmoid(b_y_sub) b_dy_sub = b_dy_sub * b_ys * (1 + b_y_sub * (1 - b_ys)) b_wdy = b_dy_sub if HAS_WEIGHT: - b_wdy = b_wdy * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]) + b_wdy = b_wdy * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]) b_dw_sub = tl.sum(b_dy_sub * b_x, 0) # [BT, BD] * [BT, BD] --> sum(0) = [BD] - b_dw = tl.insert_slice(b_dw, b_dw_sub[None, :], [W - i_w - 1, 0], [1, BD], [1, 1]) + b_dw = tl.extra.cann.extension.insert_slice(b_dw, b_dw_sub[None, :], [W - i_w - 1, 0], [1, BD], [1, 1]) if HAS_BIAS and i_w == 0: b_db += tl.sum(b_dy_sub, 0) @@ -280,7 +280,7 @@ def causal_conv1d_bwd_kernel( b_dy = b_dy * b_ys * (1 + b_y * (1 - b_ys)) b_wdy = b_dy if HAS_WEIGHT: - b_wdy = b_wdy * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]) + b_wdy = b_wdy * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1]) b_dw = tl.sum(b_dy * b_x, 0) tl.store(dw + i_tg * D * W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d) @@ -335,7 +335,7 @@ def causal_conv1d_bwd_kernel( b_wdy = ( b_dy_shift if not HAS_WEIGHT - else (b_dy_shift * tl.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])) + else (b_dy_shift * tl.extra.cann.extension.extract_slice(b_w, [W - i_w - 1, 0], [1, BD], [1, 1])) ) b_dx += b_wdy @@ -583,7 +583,7 @@ def causal_conv1d_states_fwd_kernel( o_t = eos - BW + tl.arange(0, BW) o_d = i_d * BD + tl.arange(0, BD) o_w = W - BW + tl.arange(0, BW) - m_t = o_t >= tl.maximum(bos, eos - W) + m_t = o_t >= tl.maximum(bos, eos - W,propagate_nan=tl.PropagateNan.ALL) m_d = o_d < D m_w = (o_w >= 0) & (o_w < W) @@ -694,7 +694,7 @@ def causal_conv1d_update_kernel_bdt_fwd( mask_x = (offset0_x < dim)[:, None] & ((offset1_x >= 0) & (offset1_x < seq_len))[None, :] block_off_x = bi * dim * seq_len + offset0_x[:, None] * seq_len + offset1_x[None, :] x_b_tmp = tl.load(x_ptr + block_off_x, mask=mask_x, other=0) - x_b = tl.insert_slice(st_b, x_b_tmp, (0, width - 1), (D_CHK_SIZE, T_CHK_SIZE), (1, 1)) + x_b = tl.extra.cann.extension.insert_slice(st_b, x_b_tmp, (0, width - 1), (D_CHK_SIZE, T_CHK_SIZE), (1, 1)) else: offset0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE) offset1 = ti * T_CHK_SIZE - (width - 1) + tl.arange(0, T_CHK_SIZE + width - 1) @@ -715,7 +715,7 @@ def causal_conv1d_update_kernel_bdt_fwd( # NOTE: In order to avoid use tl.maximum for negative offset, # we pre-compute a fix head tile size (ST_STORE_HEAD_TILE_SIZE) # to store the scene of negative address - x_new_h = tl.extract_slice(x_b, (-t_off, 0), (ST_STORE_HEAD_TILE_SIZE, D_CHK_SIZE), (1, 1)) + x_new_h = tl.extra.cann.extension.extract_slice(x_b, (-t_off, 0), (ST_STORE_HEAD_TILE_SIZE, D_CHK_SIZE), (1, 1)) x_new_h = tl.trans(x_new_h, (1, 0)) nst_off_y0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE)[:, None] nst_off_y1_h = tl.arange(0, ST_STORE_HEAD_TILE_SIZE)[None, :] @@ -723,7 +723,7 @@ def causal_conv1d_update_kernel_bdt_fwd( block_ptr_h = bi * dim * state_len + nst_off_y0 * state_len + nst_off_y1_h tl.store(conv_state_update_ptr + block_ptr_h, x_new_h, mask=nst_mask_h) else: - x_new_s = tl.extract_slice(x_b, (width - 1, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1)) + x_new_s = tl.extra.cann.extension.extract_slice(x_b, (width - 1, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1)) x_new_s = tl.trans(x_new_s, (1, 0)) nst_off_y0 = di * D_CHK_SIZE + tl.arange(0, D_CHK_SIZE)[:, None] nst_off_y1 = width - 1 + t_off + tl.arange(0, T_CHK_SIZE)[None, :] @@ -732,14 +732,14 @@ def causal_conv1d_update_kernel_bdt_fwd( tl.store(conv_state_update_ptr + block_ptr, x_new_s, mask=nst_mask) for owi in tl.range(0, width): - new_x = tl.extract_slice(x_b, (owi, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1)) - w_chl_wi = tl.extract_slice(w, (owi, 0), (1, D_CHK_SIZE), (1, 1)) + new_x = tl.extra.cann.extension.extract_slice(x_b, (owi, 0), (T_CHK_SIZE, D_CHK_SIZE), (1, 1)) + w_chl_wi = tl.extra.cann.extension.extract_slice(w, (owi, 0), (1, D_CHK_SIZE), (1, 1)) x_mul_chl_wi = new_x * w_chl_wi out_block += x_mul_chl_wi out_block = tl.trans(out_block, (1, 0)) if SILU_ACTIVATION: - out_block = out_block * tl.sigmoid(out_block) + out_block = (out_block * tl.sigmoid(out_block)).to(x_ptr.dtype.element_ty) tl.store( tl.make_block_ptr( out_ptr, diff --git a/mojo_opset/backends/ttx/kernels/npu/diffution_attention.py b/mojo_opset/backends/ttx/kernels/npu/diffution_attention.py index a4493381a..545adce0e 100644 --- a/mojo_opset/backends/ttx/kernels/npu/diffution_attention.py +++ b/mojo_opset/backends/ttx/kernels/npu/diffution_attention.py @@ -68,7 +68,7 @@ def micro_kernel_fwd( block_s = tl.dot(block_q, block_k.T) * scale if block_mask is not None: block_s += block_mask - block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1)) + block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1,propagate_nan=tl.PropagateNan.ALL),propagate_nan=tl.PropagateNan.ALL) block_s = tl.exp(block_s - block_m_1[:, None]) block_l_1 = tl.exp(block_m - block_m_1) * block_l + tl.sum(block_s, axis=1) diff --git a/mojo_opset/backends/ttx/kernels/npu/flash_attention.py b/mojo_opset/backends/ttx/kernels/npu/flash_attention.py index 1292e2963..3f41eafa5 100644 --- a/mojo_opset/backends/ttx/kernels/npu/flash_attention.py +++ b/mojo_opset/backends/ttx/kernels/npu/flash_attention.py @@ -50,7 +50,7 @@ def _sdpa_infer_single_block( if mask is not None: qk = tl.where(mask, qk, float("-inf")) # 32B # bool - m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + m_ij = tl.maximum(m_i, tl.max(qk, 1, propagate_nan=tl.PropagateNan.ALL), propagate_nan=tl.PropagateNan.ALL) # Scaled max qk = qk - m_ij[:, None] # Stabilize # Softmax weights p = exp(qk) @@ -423,8 +423,8 @@ def paged_decode_kernel( qk *= softmax_scale qk = tl.where(mask, qk, float("-inf")) - m_j = tl.max(qk, axis=0) - m_ij = tl.maximum(m_i, m_j) + m_j = tl.max(qk, axis=0,propagate_nan=tl.PropagateNan.ALL) + m_ij = tl.maximum(m_i, m_j,propagate_nan=tl.PropagateNan.ALL) qk = qk - m_ij p = tl.math.exp(qk) diff --git a/mojo_opset/backends/ttx/kernels/npu/fused_linear_cross_entropy.py b/mojo_opset/backends/ttx/kernels/npu/fused_linear_cross_entropy.py index 90b9839a0..dc284456a 100644 --- a/mojo_opset/backends/ttx/kernels/npu/fused_linear_cross_entropy.py +++ b/mojo_opset/backends/ttx/kernels/npu/fused_linear_cross_entropy.py @@ -94,7 +94,7 @@ def _cross_entropy_kernel( if HAS_SOFTCAPPING: X_block = softcap * tl.math.tanh(X_block / softcap) - block_max = tl.max(X_block) + block_max = tl.max(X_block,propagate_nan=tl.PropagateNan.ALL) if label_smoothing > 0: X_block2 = tl.load( X_ptr + X_offsets, @@ -107,7 +107,7 @@ def _cross_entropy_kernel( scaled_x_sum += tl.sum(-eps * X_block2 * weight_block).to(tl.float32) else: scaled_x_sum += tl.sum(-eps * X_block2).to(tl.float32) - m_new = tl.maximum(m, block_max) + m_new = tl.maximum(m, block_max,propagate_nan=tl.PropagateNan.ALL) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -312,7 +312,7 @@ def _cross_entropy_prime_kernel( other=float("-inf"), ).cast(tl.float32) - block_max = tl.max(X_block, axis=0) # Use axis=0 for clarity, it's a 1D reduction + block_max = tl.max(X_block, axis=0,propagate_nan=tl.PropagateNan.ALL) # Use axis=0 for clarity, it's a 1D reduction if label_smoothing > 0: X_block2 = tl.load( current_X_ptr + X_offsets, @@ -320,7 +320,7 @@ def _cross_entropy_prime_kernel( other=0.0, ).cast(tl.float32) scaled_x_sum += tl.sum(-eps * X_block2, axis=0).to(tl.float32) - m_new = tl.maximum(m, block_max) + m_new = tl.maximum(m, block_max,propagate_nan=tl.PropagateNan.ALL) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new), axis=0) m = m_new diff --git a/mojo_opset/backends/ttx/kernels/npu/gelu.py b/mojo_opset/backends/ttx/kernels/npu/gelu.py index b272f0ff7..851f20c62 100644 --- a/mojo_opset/backends/ttx/kernels/npu/gelu.py +++ b/mojo_opset/backends/ttx/kernels/npu/gelu.py @@ -28,7 +28,7 @@ def gelu_tanh_approx(x): sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / π) x_cubed = x * x * x tanh_arg = sqrt_2_over_pi * (x + 0.044715 * x_cubed) - return 0.5 * x * (1 + tl.tanh(tanh_arg)) + return 0.5 * x * (1 + tl.extra.cann.math.tanh(tanh_arg)) @triton.autotune( @@ -137,7 +137,7 @@ def _gelu_bwd_kernel( sqrt_2_over_pi = 0.7978845608028654 x_cubed = x_f32 * x_f32 * x_f32 tanh_arg = sqrt_2_over_pi * (x_f32 + 0.044715 * x_cubed) - tanh_result = tl.tanh(tanh_arg) + tanh_result = tl.extra.cann.math.tanh(tanh_arg) term1 = 0.5 * (1 + tanh_result) tanh_sq = tanh_result * tanh_result diff --git a/mojo_opset/backends/ttx/kernels/npu/group_gemm.py b/mojo_opset/backends/ttx/kernels/npu/group_gemm.py index 67ffc0b08..fd87d2af3 100644 --- a/mojo_opset/backends/ttx/kernels/npu/group_gemm.py +++ b/mojo_opset/backends/ttx/kernels/npu/group_gemm.py @@ -103,14 +103,14 @@ def _m_grouped_matmul_bKmajor_kernel( mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_K), other=0.0, ) - tl.compile_hint(a, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(a, "dot_pad_only_k") b = tl.load( b_ptrs, mask=msk_n[:, None] and (offs_k[None, :] < (K - k * BLOCK_K)), other=0.0, ) b = tl.trans(b) - tl.compile_hint(b, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k") accumulator = tl.dot(a, b, acc=accumulator) c = accumulator.to(C.dtype.element_ty) @@ -178,13 +178,13 @@ def _m_grouped_matmul_bNmajor_kernel( mask=msk_m[:, None] and (offs_ak[None, :] < K - k * BLOCK_K), other=0.0, ) - tl.compile_hint(a, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(a, "dot_pad_only_k") b = tl.load( b_ptrs, mask=(offs_bk[:, None] < (group_idx * K + K - k * BLOCK_K)) and msk_n[None, :], other=0.0, ) - tl.compile_hint(b, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k") accumulator = tl.dot(a, b, acc=accumulator) c = accumulator.to(C.dtype.element_ty) @@ -224,6 +224,7 @@ def m_grouped_matmul_impl( strideBN, strideBK, multibuffer=True, + sync_solver=False, ) return C @@ -291,9 +292,9 @@ def _k_grouped_matmul_kernel( b_ptrs = b_ptrs_base + kk * BLOCK_K * N a = tl.load(a_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_m[None, :], other=0.0) aa = tl.trans(a) - tl.compile_hint(aa, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(aa, "dot_pad_only_k") b = tl.load(b_ptrs, mask=(offs_k[:, None] < group_end - kk * BLOCK_K) and msk_n[None, :], other=0.0) - tl.compile_hint(b, "dot_pad_only_k") + tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k") accumulator = tl.dot(aa, b, acc=accumulator) c = accumulator.to(C.dtype.element_ty) @@ -323,5 +324,5 @@ def grid(META): assert M % META["BLOCK_M"] == 0, "Only support when M is a multiple of BLOCK_M" return (num_cores,) - _k_grouped_matmul_kernel[grid](A, B, C, size_per_group, num_groups, M, N, multibuffer=True) + _k_grouped_matmul_kernel[grid](A, B, C, size_per_group, num_groups, M, N, multibuffer=True,sync_solver=False) return C diff --git a/mojo_opset/backends/ttx/kernels/npu/lightning_indexer.py b/mojo_opset/backends/ttx/kernels/npu/lightning_indexer.py index bc9a9acda..64901ce99 100644 --- a/mojo_opset/backends/ttx/kernels/npu/lightning_indexer.py +++ b/mojo_opset/backends/ttx/kernels/npu/lightning_indexer.py @@ -154,7 +154,7 @@ def lightning_indexer_kernel( ) q = tl.load(query_ptrs) - relu_qk = tl.maximum(tl.dot(q.to(k.dtype), tl.trans(k)), 0.0) + relu_qk = tl.maximum(tl.dot(q.to(k.dtype), tl.trans(k)), 0.0,propagate_nan=tl.PropagateNan.ALL) query_scale_ptrs = ( query_scale_ptr diff --git a/mojo_opset/backends/ttx/kernels/npu/over_encoding/fused_over_encoding.py b/mojo_opset/backends/ttx/kernels/npu/over_encoding/fused_over_encoding.py index 048bf46df..8b0f75df8 100644 --- a/mojo_opset/backends/ttx/kernels/npu/over_encoding/fused_over_encoding.py +++ b/mojo_opset/backends/ttx/kernels/npu/over_encoding/fused_over_encoding.py @@ -88,7 +88,7 @@ def over_encoding_decode_kernel( oe_carry = tl.full((MTP_STEP, BLOCK_SIZE_N,), ori_vocab_size, dtype=tl.int64) history_ptr = oe_history + oe_history_stride_0 * bid - n_gram_offsets = tl.flip(tl.arange(0, MAX_N_GRAM)) + n_gram_offsets = tl.extra.cann.extension.flip(tl.arange(0, MAX_N_GRAM)) history_id = tl.load( history_ptr + (oe_history_dim_1 - n_gram_offsets - 1) * oe_history_stride_1 @@ -97,13 +97,13 @@ def over_encoding_decode_kernel( # WARNING(liuyuan): tl.cat required the same shapes of lhs and rhs in triton-npu. WTF? # history_id = tl.cat(history_id, __input_ids, can_reorder=True) __tmp = tl.zeros((MTP_STEP + MAX_N_GRAM,), dtype=tl.int64) - __tmp = tl.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,)) - __tmp = tl.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,)) + __tmp = tl.extra.cann.extension.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,)) + __tmp = tl.extra.cann.extension.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,)) history_id = __tmp for i in tl.static_range(1, MAX_N_GRAM): __cal_mask = n_grams >= (i + 1) - __history_ids = tl.extract_slice( + __history_ids = tl.extra.cann.extension.extract_slice( history_id, (MAX_N_GRAM - i,), (MTP_STEP,), (1,) ) __history_ids = ( @@ -120,7 +120,7 @@ def over_encoding_decode_kernel( for ele_idx in (tl.static_range if BLOCK_BATCH_SIZE < 4 else tl.range)( 0, MTP_STEP * n_grams_size ): - __id = tl.get_element(n_gram_ids, (ele_idx,)) + __id = tl.extra.cann.extension.get_element(n_gram_ids, (ele_idx,)) __embedding_nf4_dequant__( ele_idx + bid * MTP_STEP * MAX_N_GRAM, __id, @@ -169,7 +169,7 @@ def over_encoding_decode_kernel( for ele_idx in (tl.static_range if BLOCK_BATCH_SIZE < 4 else tl.range)( 0, MTP_STEP * n_grams_size ): - __id = tl.get_element(n_gram_ids, (ele_idx,)) + __id = tl.extra.cann.extension.get_element(n_gram_ids, (ele_idx,)) __embedding_nf4_dequant__( ele_idx + bid * MTP_STEP * n_grams_size, __id, diff --git a/mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py b/mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py index 04f6f3fb3..f1acda39b 100644 --- a/mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py +++ b/mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py @@ -164,6 +164,7 @@ def n_gram_decode_kernel( block_offsets = tl.arange(0, BLOCK_SIZE_N) block_mask = block_offsets < n_grams_size + block_mask_1 = block_offsets[None,:] < n_grams_size oe_vocab_sizes = tl.load( oe_vocab_sizes + block_offsets, mask=block_mask, other=1 @@ -185,7 +186,6 @@ def n_gram_decode_kernel( if MTP_STEP > 1: input_indices = tl.arange(0, MTP_STEP) __input_ids = tl.load(input_ids + bid * MTP_STEP + input_indices).to(tl.int64) - tl.device_print("__input_ids:", __input_ids) n_gram_ids = tl.view(__input_ids, (MTP_STEP, 1)) n_gram_ids = tl.broadcast_to(n_gram_ids, (MTP_STEP, BLOCK_SIZE_N)) @@ -193,7 +193,7 @@ def n_gram_decode_kernel( oe_carry = tl.full((MTP_STEP, BLOCK_SIZE_N,), vocab_size, dtype=tl.int64) history_ptr = oe_history + oe_history_stride_0 * bid - n_gram_offsets = tl.flip(tl.arange(0, MAX_N_GRAM)) + n_gram_offsets = tl.extra.cann.extension.flip(tl.arange(0, MAX_N_GRAM)) history_id = tl.load( history_ptr + (oe_history_dim_1 - n_gram_offsets - 1) * oe_history_stride_1 @@ -202,13 +202,13 @@ def n_gram_decode_kernel( # WARNING(liuyuan): tl.cat required the same shapes of lhs and rhs in triton-npu. WTF? # history_id = tl.cat(history_id, __input_ids, can_reorder=True) __tmp = tl.zeros((MTP_STEP + MAX_N_GRAM,), dtype=tl.int64) - __tmp = tl.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,)) - __tmp = tl.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,)) + __tmp = tl.extra.cann.extension.insert_slice(__tmp, history_id, (0,), (MAX_N_GRAM,), (1,)) + __tmp = tl.extra.cann.extension.insert_slice(__tmp, __input_ids, (MAX_N_GRAM,), (MTP_STEP,), (1,)) history_id = __tmp for i in tl.static_range(1, MAX_N_GRAM): __cal_mask = n_grams >= (i + 1) - __history_ids = tl.extract_slice( + __history_ids = tl.extra.cann.extension.extract_slice( history_id, (MAX_N_GRAM - i,), (MTP_STEP,), (1,) ) __history_ids = ( @@ -226,7 +226,7 @@ def n_gram_decode_kernel( output_ids + bid * output_stride_0 + tl.arange(0, MTP_STEP)[:, None] * output_stride_1 + block_offsets[None, :] * output_stride_2, # output_ids + bid * output_stride_0 + output_offsets, n_gram_ids.to(output_ids.dtype.element_ty), - mask=tl.view(block_mask, (1, BLOCK_SIZE_N)), + mask=block_mask_1 ) else: diff --git a/mojo_opset/backends/ttx/kernels/npu/quant.py b/mojo_opset/backends/ttx/kernels/npu/quant.py index 56a14cd1c..b346c33df 100644 --- a/mojo_opset/backends/ttx/kernels/npu/quant.py +++ b/mojo_opset/backends/ttx/kernels/npu/quant.py @@ -119,7 +119,7 @@ def scale_dynamic_quant_kernel( current_max = tl.max(tl.abs(scaled_vals), axis=1) - max_abs_accumulator = tl.maximum(max_abs_accumulator, current_max) + max_abs_accumulator = tl.maximum(max_abs_accumulator, current_max,propagate_nan=tl.PropagateNan.ALL) final_max_abs = max_abs_accumulator current_quant_scale = final_max_abs / 127.0 @@ -151,6 +151,6 @@ def scale_dynamic_quant_kernel( scaled_vals = input_vals quant_vals = scaled_vals / current_quant_scale[:, None] quant_vals = tl.where(quant_vals < 0, quant_vals - 0.5, quant_vals + 0.5) - quant_vals_int8 = tl.cast(quant_vals, dtype=tl.int8, overflow_mode="saturate") + quant_vals_int8 = tl.cast(quant_vals, dtype=tl.int8) tl.store(output_ptr, quant_vals_int8, mask=block_mask) diff --git a/mojo_opset/backends/ttx/kernels/npu/rope.py b/mojo_opset/backends/ttx/kernels/npu/rope.py index 1cf06520d..0d7446c2c 100644 --- a/mojo_opset/backends/ttx/kernels/npu/rope.py +++ b/mojo_opset/backends/ttx/kernels/npu/rope.py @@ -77,8 +77,8 @@ def _compute_rope( TOKEN_BLOCK_SIZE: tl.constexpr, inverse: tl.constexpr, ): - x1 = tl.extract_slice(x, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) - x2 = tl.extract_slice(x, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x1 = tl.extra.cann.extension.extract_slice(x, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x2 = tl.extra.cann.extension.extract_slice(x, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) if inverse: roped_x1 = x1 * cos_tile + x2 * sin_tile @@ -87,8 +87,8 @@ def _compute_rope( roped_x1 = x1 * cos_tile - x2 * sin_tile roped_x2 = x2 * cos_tile + x1 * sin_tile - x = tl.insert_slice(x, roped_x1, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) - x = tl.insert_slice( + x = tl.extra.cann.extension.insert_slice(x, roped_x1, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x = tl.extra.cann.extension.insert_slice( x, roped_x2, [0, 0, half_rope_dim], diff --git a/mojo_opset/backends/ttx/kernels/npu/sample.py b/mojo_opset/backends/ttx/kernels/npu/sample.py index 698e78304..aba19e2b7 100644 --- a/mojo_opset/backends/ttx/kernels/npu/sample.py +++ b/mojo_opset/backends/ttx/kernels/npu/sample.py @@ -282,10 +282,10 @@ def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr): left = core.reshape(left, x.shape) right = core.reshape(right, x.shape) - left_idx = core.broadcast_to(tl.max(y_idx * (1 - mask), 1)[:, None, :], shape).to( + left_idx = core.broadcast_to(tl.max(y_idx * (1 - mask), 1,propagate_nan=tl.PropagateNan.ALL)[:, None, :], shape).to( ids.dtype ) - right_idx = core.broadcast_to(tl.max(y_idx * mask, 1)[:, None, :], shape).to( + right_idx = core.broadcast_to(tl.max(y_idx * mask, 1,propagate_nan=tl.PropagateNan.ALL)[:, None, :], shape).to( ids.dtype ) left_idx = core.reshape(left_idx, ids.shape) @@ -633,14 +633,14 @@ def _top_p_sample_kernel( logits = tl.load(row_logits_ptr + offsets * stride_logits_k) - logits_max = tl.max(logits, 0) + logits_max = tl.max(logits, 0,propagate_nan=tl.PropagateNan.ALL) numerator = tl.exp(logits - logits_max) probs = numerator / tl.sum(numerator, 0) cum_probs = tl.cumsum(probs, 0) to_remove = (cum_probs - probs) > top_p to_remove = tl.where(offsets < min_tokens_to_keep, False, to_remove) filtered_logits = tl.where(to_remove, filter_value, logits) - f_logits_max = tl.max(filtered_logits, 0) + f_logits_max = tl.max(filtered_logits, 0,propagate_nan=tl.PropagateNan.ALL) f_numerator = tl.exp(filtered_logits - f_logits_max) f_probs = f_numerator / tl.sum(f_numerator, 0) @@ -787,14 +787,14 @@ def _top_p_filter_kernel( logits = tl.load(row_logits_ptr + offsets * stride_logits_k) - logits_max = tl.max(logits, 0) + logits_max = tl.max(logits, 0, propagate_nan=tl.PropagateNan.ALL) numerator = tl.exp(logits - logits_max) probs = numerator / tl.sum(numerator, 0) cum_probs = tl.cumsum(probs, 0) to_remove = (cum_probs - probs) > top_p to_remove = tl.where(offsets < min_tokens_to_keep, False, to_remove) filtered_logits = tl.where(to_remove, filter_value, logits) - f_logits_max = tl.max(filtered_logits, 0) + f_logits_max = tl.max(filtered_logits, 0, propagate_nan=tl.PropagateNan.ALL) f_numerator = tl.exp(filtered_logits - f_logits_max) f_probs = f_numerator / tl.sum(f_numerator, 0) diff --git a/mojo_opset/backends/ttx/kernels/npu/sdpa.py b/mojo_opset/backends/ttx/kernels/npu/sdpa.py index 9abf2d80b..0666cc621 100644 --- a/mojo_opset/backends/ttx/kernels/npu/sdpa.py +++ b/mojo_opset/backends/ttx/kernels/npu/sdpa.py @@ -65,8 +65,8 @@ def _sdpa_infer_inner( # qk += (1 - mask.to(tl.float32)) * (-1e6) # qk = tl.where(mask, qk, float("-inf")) qk = tl.where(mask, qk, -1e6) - - m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + m_ij = tl.maximum(m_i, tl.max(qk, 1, propagate_nan=tl.PropagateNan.ALL), propagate_nan=tl.PropagateNan.ALL) + #m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max qk = qk - m_ij[:, None] # Stabilize # Softmax weights p = exp(qk) @@ -82,7 +82,7 @@ def _sdpa_infer_inner( # -- Update output accumulator -- acc_ptr = acc_ptr * alpha[:, None] acc_ptr = tl.dot(p_cast, v, acc_ptr) - tl.compile_hint(acc_ptr, "tile_cube_loop", 2) + tl.extra.cann.extension.compile_hint(acc_ptr, "hivm.tile_mix_cube_num",2) m_i = m_ij # Update current block max # Advance V and K block pointers to next BLOCK_N range @@ -408,7 +408,8 @@ def kernel_sdpa_fwd( block_s = tl.dot(block_q, tl.trans(block_k)) * scale block_s -= (1.0 - block_mask.to(HIGH_TYPE)) * 1e6 - block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1)) + block_m_1 = tl.maximum(block_m, tl.max(block_s, axis=1, + propagate_nan=tl.PropagateNan.ALL),propagate_nan=tl.PropagateNan.ALL) block_s = tl.exp(block_s - block_m_1[:, None]) block_l_1 = tl.exp(block_m - block_m_1) * block_l + tl.sum(block_s, axis=1) block_o = tl.exp(block_m - block_m_1)[:, None] * block_o + tl.dot(block_s.to(LOW_TYPE), block_v).to( @@ -802,7 +803,6 @@ def sdpa_infer_impl( if scale is None: scale = 1.0 - o = torch.empty_like(q) extra_kern_args = {} diff --git a/mojo_opset/backends/ttx/kernels/npu/swa.py b/mojo_opset/backends/ttx/kernels/npu/swa.py index 1c9cfe928..817f36015 100644 --- a/mojo_opset/backends/ttx/kernels/npu/swa.py +++ b/mojo_opset/backends/ttx/kernels/npu/swa.py @@ -216,7 +216,7 @@ def _sdpa_acc_fwd_MxN( if mask is not None and mask is not True: qk = tl.where(mask, qk, -1e6) # 32B # bool - m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + m_ij = tl.maximum(m_i, tl.max(qk, 1,propagate_nan=tl.PropagateNan.ALL),propagate_nan=tl.PropagateNan.ALL) # Scaled max qk = qk - m_ij[:, None] # Stabilize # Softmax weights p = exp(qk) @@ -630,6 +630,7 @@ def swa_infer_impl( BLOCK_M, BLOCK_N, BLOCK_D, + enable_ubuf_saving=True, ) return o @@ -1023,6 +1024,7 @@ def swa_paged_prefill_impl( BLOCK_N, BLOCK_D, page_size, + enable_ubuf_saving=True ) return o @@ -1055,7 +1057,7 @@ def _sdpa_acc_fwd_1xN( if mask is not None and mask is not True: qk = tl.where(mask, qk, float("-inf")) # 32B # bool - m_ij = tl.maximum(m_i, tl.max(qk, 0)) # Scaled max + m_ij = tl.maximum(m_i, tl.max(qk, 0,propagate_nan=tl.PropagateNan.ALL),propagate_nan=tl.PropagateNan.ALL) # Scaled max qk = qk - m_ij # Stabilize # Softmax weights p = exp(qk) diff --git a/mojo_opset/backends/ttx/kernels/npu/vision_rope.py b/mojo_opset/backends/ttx/kernels/npu/vision_rope.py index 9b8493d02..1a8e66e94 100644 --- a/mojo_opset/backends/ttx/kernels/npu/vision_rope.py +++ b/mojo_opset/backends/ttx/kernels/npu/vision_rope.py @@ -96,16 +96,16 @@ def _compute_vision_rope( so both cos/sin halves are numerically identical and we only need one half. """ # Split x into halves: [TOKEN_BLOCK_SIZE, head_num, half_rope_dim] - x1 = tl.extract_slice(x, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) - x2 = tl.extract_slice(x, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x1 = tl.extra.cann.extension.extract_slice(x, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x2 = tl.extra.cann.extension.extract_slice(x, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) # out_half1 = x1*c - x2*s (broadcasts c/s across heads) # out_half2 = x2*c + x1*s roped_x1 = x1 * cos_half_tile - x2 * sin_half_tile roped_x2 = x2 * cos_half_tile + x1 * sin_half_tile - x = tl.insert_slice(x, roped_x1, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) - x = tl.insert_slice(x, roped_x2, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x = tl.extra.cann.extension.insert_slice(x, roped_x1, [0, 0, 0], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) + x = tl.extra.cann.extension.insert_slice(x, roped_x2, [0, 0, half_rope_dim], [TOKEN_BLOCK_SIZE, head_num, half_rope_dim], [1, 1, 1]) return x diff --git a/mojo_opset/core/functions/normalization.py b/mojo_opset/core/functions/normalization.py index 14a810456..70a21c06f 100644 --- a/mojo_opset/core/functions/normalization.py +++ b/mojo_opset/core/functions/normalization.py @@ -36,7 +36,6 @@ def forward( ctx.save_for_backward(input, weight) ctx.normalized_shape = normalized_shape ctx.eps = eps - return y @staticmethod diff --git a/mojo_opset/experimental/operators/indexer.py b/mojo_opset/experimental/operators/indexer.py index d3edf1598..35451f7a1 100644 --- a/mojo_opset/experimental/operators/indexer.py +++ b/mojo_opset/experimental/operators/indexer.py @@ -89,8 +89,9 @@ def forward( q = self.activation(q) k = self.activation(k) - q_quant, q_scale = self.quant(q, None) - k_quant, k_scale = self.quant(k, None) + q_quant, q_scale = self.quant(q) + k_quant, k_scale = self.quant(k) + q_scale = q_scale.squeeze(-1) if k_scale.dim() == 3: k_scale = k_scale.amax(dim=-1) diff --git a/mojo_opset/tests/accuracy/functions/test_normalization.py b/mojo_opset/tests/accuracy/functions/test_normalization.py index 67b458d1b..4b2a33e94 100644 --- a/mojo_opset/tests/accuracy/functions/test_normalization.py +++ b/mojo_opset/tests/accuracy/functions/test_normalization.py @@ -39,7 +39,6 @@ def test_rmsnorm_forward_backward_diff(x, weight): ctx = MockFunctionCtx() y = MojoRMSNormFunction.forward(ctx, x, weight, 1e-6) - ctx_ref = MockFunctionCtx() y_ref = MojoRMSNormFunction._registry.get("torch").forward(ctx_ref, x, weight, 1e-6) assert_close(y, y_ref) diff --git a/mojo_opset/tests/perf/test_indexer.py b/mojo_opset/tests/perf/test_indexer.py index c26485a7c..e23050e74 100644 --- a/mojo_opset/tests/perf/test_indexer.py +++ b/mojo_opset/tests/perf/test_indexer.py @@ -29,7 +29,4 @@ @bypass_not_implemented def test_lightning_index(query, query_scale, key, key_scale): indexer = MojoLightningIndexer() - indexer_ref = indexer._registry.get("torch")() - - perf(lambda: indexer_ref(query, query_scale, key, key_scale)) perf(lambda: indexer(query, query_scale, key, key_scale)) diff --git a/mojo_opset/tests/perf/test_linear.py b/mojo_opset/tests/perf/test_linear.py index 6478d6b65..78de1ab0a 100644 --- a/mojo_opset/tests/perf/test_linear.py +++ b/mojo_opset/tests/perf/test_linear.py @@ -19,7 +19,7 @@ def generate_random_list(length, total_sum): diff = total_sum - sum(lst) lst[-1] += diff - return torch.Tensor(lst).to(torch.int64) + return torch.Tensor(lst).to(torch.int32) @pytest.mark.parametrize( diff --git a/mojo_opset/tests/perf/test_store_lowrank.py b/mojo_opset/tests/perf/test_store_lowrank.py index 63e14e603..39371645f 100644 --- a/mojo_opset/tests/perf/test_store_lowrank.py +++ b/mojo_opset/tests/perf/test_store_lowrank.py @@ -1,9 +1,11 @@ import pytest import torch -from mojo_opset.experimental import MojoStoreLowrank -from mojo_opset.tests.utils import auto_switch_platform from mojo_opset.tests.utils import bypass_not_implemented +from mojo_opset.tests.utils import auto_switch_platform +from mojo_opset.utils.platform import get_torch_device + +from mojo_opset.experimental import MojoStoreLowrank kv_lens = [1, 24, 1024, 2048, 4096, 8192, 13312] slot_mappings = [torch.randperm(kv_len) for kv_len in kv_lens] @@ -20,22 +22,26 @@ @pytest.mark.parametrize( - "label_cache, key_lr, block_idxs, token_idxs, token_num", + "shape_label_cache, shape_key_lr", [ - ( - torch.zeros(size=shape0, dtype=torch.bfloat16), - torch.randn(size=(slot_mapping.shape[0], *shape1), dtype=torch.bfloat16), - slot_mapping // 512, - slot_mapping % 512, - slot_mapping.shape[0], - ) - for shape0, shape1 in zip(shapes_label_cache, shapes_key_lr) - for slot_mapping in slot_mappings + ((256, 1, 512, 128), (1, 128)), + ((256, 8, 512, 128), (8, 128)), ], ) +@pytest.mark.parametrize( + "kv_len", + [1024, 2048, 4096, 8192, 13312], +) @auto_switch_platform(set_perf=True) @bypass_not_implemented -def test_store_lowrank(label_cache, key_lr, block_idxs, token_idxs, token_num): - store_lowrank = MojoStoreLowrank() +def test_store_lowrank(shape_label_cache, shape_key_lr, kv_len): + device = get_torch_device() + slot_mapping = torch.randperm(kv_len,device=device) + label_cache = torch.zeros(size=shape_label_cache, dtype=torch.bfloat16,device=device) + key_lr = torch.randn(size=(slot_mapping.shape[0], *shape_key_lr), dtype=torch.bfloat16,device=device) + block_idxs = (slot_mapping // 512).to(torch.int32) + token_idxs = (slot_mapping % 512).to(torch.int32) + token_num = slot_mapping.shape[0] - perf(lambda: store_lowrank(label_cache, key_lr, block_idxs, token_idxs, token_num)) + store_lowrank = MojoStoreLowrank() + perf(lambda: store_lowrank(label_cache, key_lr, block_idxs, token_idxs, token_num)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8e2a7990f..bda7c280c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,13 +13,13 @@ authors = [{name="SeedXPU", email='yuminghui.exp@bytedance.com'}] dependencies = [ "einops", "pydantic", - "pytest >= 8.4.0", + "pytest >= 8.3.2", "byted-xpu-graph @ https://luban-source.byted.org/repository/scm/Seed.Foundation.xpu_graph_1.0.0.20.tar.gz" ] [project.optional-dependencies] npu = [ - "byted-triton-x >= 3.2.0.post24" + "byted-triton-x>=3.2.1" ] mlu = [ "triton >= 3.2.0"