Conversation

@weijinqian0
Collaborator

[Refactor] Remove redundant attention operator branches.

Reason:

  1. We replace the other attention ops with fused_infer_attention_score, except for the decode-only state.
  2. Clean up the code and remove 310P support.

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description so that reviewers and future developers can understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request refactors the attention operator branches by unifying prefill logic into a single _forward_prefill method and removing support for 310P devices. The changes simplify the control flow in the main forward method, making the code cleaner and easier to maintain. My review focuses on the correctness and clarity of this refactoring. I've identified one high-severity issue related to an unused parameter in the new _forward_prefill method, which should be addressed to improve code quality.
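
For context, the unified dispatch in forward() reduces to roughly the following sketch (mirroring the version of the file posted later in this thread):

    if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        output = self._forward_decode_only(query, attn_metadata, output)
    else:
        output = self._forward_prefill(query, key, value, attn_metadata, output)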

Comment on lines 319 to 326
def _forward_prefill(self,
                     query: torch.Tensor,
                     key: torch.Tensor,
                     value: torch.Tensor,
                     kv_cache: Tuple[torch.Tensor],
                     attn_metadata: AscendMetadata,
                     output: torch.Tensor,
                     num_tokens=0):
Contributor


high

The output parameter in the _forward_prefill method is unused. The value passed to it is immediately shadowed by the assignment output, _ = torch_npu.npu_fused_infer_attention_score(...) on line 357. This is misleading as it suggests an in-place operation which is not happening, and the pre-allocated tensor is wasted.

To avoid confusion and make the code cleaner, the output parameter should be removed from the function signature. The call site at line 592 should also be updated to no longer pass this parameter.

    def _forward_prefill(self,
                         query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         kv_cache: Tuple[torch.Tensor],
                         attn_metadata: AscendMetadata,
                         num_tokens=0):
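
The corresponding call-site change would then look roughly like this (a sketch only, since the surrounding code at that call site is not quoted in this comment):

    output = self._forward_prefill(query, key, value, kv_cache,
                                   attn_metadata, num_tokens)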

@weijinqian0
Collaborator Author

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)

from ..utils import weak_ref_tensors

class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True

@staticmethod
def get_name() -> str:
    return "ASCEND"

@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
    return AscendAttentionBackendImpl

@staticmethod
def get_metadata_cls() -> Type["AscendMetadata"]:
    return AscendMetadata

@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
    return AscendAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
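    # The leading dimension of size 2 stacks the key cache and the value cache,
    # which are accessed elsewhere as kv_cache[0] and kv_cache[1].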
    return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_bsh_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
    return (2, num_blocks, block_size, num_kv_heads * head_size)

@staticmethod
def swap_blocks(
    src_kv_cache: List[torch.Tensor],
    dst_kv_cache: List[torch.Tensor],
    src_to_dst: torch.Tensor,
) -> None:
    src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
    dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
    src_indices = src_to_dst[:, 0]
    dst_indices = src_to_dst[:, 1]

    dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
        dst_key_cache.device)
    dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
        dst_key_cache.device)

@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
) -> None:
    src_indices = src_to_dists[:, 0]
    dst_indices = src_to_dists[:, 1]

    for kv_cache in kv_caches:
        key_caches = kv_cache[0]
        value_caches = kv_cache[1]
        key_caches[dst_indices] = key_caches[src_indices]
        value_caches[dst_indices] = value_caches[src_indices]

@staticmethod
def get_supported_block_size() -> list[int]:
    return [128]

class AscendAttentionState(Enum):
PrefillNoCache = 0
PrefillCacheHit = 1
DecodeOnly = 2
ChunkedPrefill = 3
SpecDecoding = 4
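
# With this refactor, DecodeOnly is served by _forward_decode_only, while all
# other states (PrefillNoCache, PrefillCacheHit, ChunkedPrefill, SpecDecoding)
# fall through to the unified _forward_prefill path.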

@dataclass
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

# Number of tokens excluding padding.
num_actual_tokens: int = 0

# The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding).
# (batch_size,)
# TODO(Angazenn): The following parameters are quite redundant and
# contains similar information (such as seq_lens seq_lens_list). We
# should simplified these parameters once attention schema in vLLM-Ascend
# is unified.
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None  # type: ignore
actual_seq_lengths_q: List[int] = None  # type: ignore

query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None
# Maximum query length in the batch (None for decoding).
max_query_len: Optional[int] = None

# ********************** KV Cache Related Properties ********************* #
# Block addresses per sequence (Seq id -> list of physical block).
# (batch_size, max_blocks_per_seq)
block_tables: torch.Tensor = None

# The indices of the token slots that input tokens will be stored into.
# E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
# three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
# and 1st slot in block 1, respectively.
# (num_tokens,)
slot_mapping: torch.Tensor = None

# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False

class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
    AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[int] = 1

def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device
    self.max_num_blocks_per_req = cdiv(
        self.model_config.max_model_len,
        AscendAttentionBackend.get_supported_block_size()[0])
    self.speculative_config = vllm_config.speculative_config
    self.decode_threshold = 1
    if self.speculative_config:
        spec_token_num = self.speculative_config.num_speculative_tokens
        self.decode_threshold += spec_token_num
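        # Example: with 3 speculative tokens, decode_threshold = 1 + 3 = 4,
        # well within the TND layout limit of 16 asserted below.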
        assert self.decode_threshold <= 16, f"decode_threshold exceeded \
            npu_fused_infer_attention_score TND layout's limit of 16, \
            got {self.decode_threshold}"

def reorder_batch(self, input_batch,
                  scheduler_output: "SchedulerOutput") -> bool:
    return False

def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: AscendCommonAttentionMetadata,
    model: Optional[nn.Module] = None,
):
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                   num_reqs
                                                                   + 1]
    block_table = common_attn_metadata.block_table_tensor
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
    slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
    attn_mask = common_attn_metadata.attn_mask
    attn_state = common_attn_metadata.attn_state
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                   num_reqs
                                                                   + 1]

    if attn_state == AscendAttentionState.DecodeOnly and \
        common_attn_metadata.num_input_tokens > num_actual_tokens:
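        # The batch was padded (e.g. for ACL graph replay), so extend seq_lens,
        # block_tables and query_start_loc with dummy length-1 entries that
        # cover the padded tokens.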
        padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
        seq_lens = torch.cat([
            seq_lens,
            torch.ones(padded_num_tokens,
                       dtype=seq_lens.dtype,
                       device=seq_lens.device)
        ])
        block_table_padding = torch.zeros(
            (padded_num_tokens, ) + block_table.shape[1:],
            dtype=block_table.dtype,
            device=block_table.device)
        block_table = torch.cat([block_table, block_table_padding], dim=0)
        query_start_loc_cpu = torch.cat([
            query_start_loc_cpu,
            torch.arange(query_start_loc_cpu[-1] + 1,
                         query_start_loc_cpu[-1] + padded_num_tokens,
                         dtype=query_start_loc_cpu.dtype,
                         device=query_start_loc_cpu.device)
        ])

    query_start_loc = query_start_loc_cpu.to(self.device,
                                             non_blocking=True)

    attn_metadata = AscendMetadata(
        num_actual_tokens=num_actual_tokens,
        block_tables=block_table,
        query_start_loc=query_start_loc,
        query_lens=query_lens,
        seq_lens=seq_lens,
        seq_lens_list=seq_lens.tolist(),
        max_query_len=common_attn_metadata.max_query_len,
        actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
        slot_mapping=slot_mapping,
        attn_mask=attn_mask,
        attn_state=attn_state,
        enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
    return attn_metadata

def build_for_graph_capture(
    self,
    common_attn_metadata: AscendCommonAttentionMetadata,
    attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    model: Optional[nn.Module] = None,
):
    if attn_state == AscendAttentionState.DecodeOnly:
        attn_metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
    else:
        raise NotImplementedError(
            "Currently we only support building dummy metadata for DecodeOnly state"
        )

    attn_metadata.attn_state = attn_state
    return attn_metadata

class AscendAttentionBackendImpl(AttentionImpl):

def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.hidden_size = self.num_heads * self.head_size
    self.kv_cache_dtype = kv_cache_dtype
    self.sliding_window = sliding_window
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes,
                                    dtype=torch.float32,
                                    device="npu")
    self.alibi_slopes = alibi_slopes
    self.attn_type = attn_type

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.key_cache = None
    self.value_cache = None

def _forward_prefill(self,
                     query: torch.Tensor,
                     key: torch.Tensor,
                     value: torch.Tensor,
                     attn_metadata: AscendMetadata, output):
    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
        block_size = 128
        block_table = None
        actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
    elif attn_metadata.attn_state == \
            AscendAttentionState.PrefillCacheHit:
        batch_size = attn_metadata.query_lens.shape[0]
        block_table = attn_metadata.block_tables[:batch_size, :]
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
        key = self.key_cache.view(  # type: ignore
            num_block, block_size, -1)
        value = self.value_cache.view(  # type: ignore
            num_block, block_size, -1)
        actual_seq_lengths_kv = attn_metadata.seq_lens_list
    # chunked_prefill.
    else:
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
        key = self.key_cache.view(  # type: ignore
            num_block, block_size, -1)
        value = self.value_cache.view(  # type: ignore
            num_block, block_size, -1)
        block_table = attn_metadata.block_tables
        actual_seq_lengths_kv = attn_metadata.seq_lens_list

    num_tokens = attn_metadata.actual_seq_lengths_q[-1]
    query = query[:num_tokens]
    # TODO: Refactor this to step-level instead of layer-level

    # Run the fused attention kernel over the actual (unpadded) tokens.
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        query=query,
        key=key,
        value=value,
        atten_mask=attn_metadata.attn_mask,
        block_table=block_table,
        input_layout="TND",
        block_size=block_size,
        actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_key_value_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale=self.scale,
        sparse_mode=3,
    )

    attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
    output[:num_tokens] = attn_output[:num_tokens]
    return output

def _forward_decode_only(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
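    # Sliding-window decode goes through npu_fused_infer_attention_score with a
    # BSH layout and pre_tokens; all other decode-only batches use
    # _npu_paged_attention, optionally recorded into an ACL graph task group
    # while capturing.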
    if self.sliding_window is not None and attn_metadata.seq_lens.shape[
            0] == query.size(0):
        batch_size = attn_metadata.seq_lens.shape[0]
        block_size = 128
        query = query.view(batch_size, 1, self.num_heads * self.head_size)
        key = self.key_cache
        value = self.value_cache
        if self.key_cache is not None and self.value_cache is not None:
            block_size = self.key_cache.shape[1]
            key = self.key_cache.flatten(2, 3).contiguous()
            value = self.value_cache.flatten(2, 3).contiguous()

        output, _ = torch_npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BSH",
            block_size=block_size,
            pre_tokens=self.sliding_window,
            scale=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
            actual_seq_lengths_kv=attn_metadata.seq_lens)

        output = output.view(batch_size, self.num_heads, self.head_size)
    else:
        graph_params = get_graph_params()
        forward_context: ForwardContext = get_forward_context()
        num_tokens = query.shape[0]
        if forward_context.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=attn_metadata.block_tables,
                    context_lens=attn_metadata.seq_lens,
                    out=output)
                update_graph_params_workspaces(num_tokens,
                                               weak_ref_tensors(workspace))

            # Handle graph capturing mode
            stream = torch_npu.npu.current_stream()

            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            graph_params.attn_params[num_tokens].append((
                weak_ref_tensors(query),
                weak_ref_tensors(self.key_cache),
                weak_ref_tensors(self.value_cache),
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens,
                weak_ref_tensors(output),
            ))

            torch.npu.graph_task_group_begin(stream)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
                workspace=workspace)
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        else:
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output)
    return output

def _forward_encode(
    self,
    query,
    key,
    value,
    attn_metadata,
    output,
) -> torch.Tensor:
    cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
    output = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        sparse_mode=4,
        atten_mask=attn_metadata.attn_mask,
        pre_tockens=attn_metadata.max_query_len,
        next_tockens=attn_metadata.max_query_len,
        actual_seq_qlen=cum_seq_len,
        actual_seq_kvlen=cum_seq_len,
    )[0]
    return output


def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Tuple[torch.Tensor],
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
    trace_flag: bool = True,
) -> torch.Tensor:
    """Forward pass with Ascend attention.
    Args:
        query: shape = [batch_size, seq_len, num_heads * head_size]
        key: shape = [batch_size, seq_len, num_kv_heads * head_size]
        value: shape = [batch_size, seq_len, num_kv_heads * head_size]
        kv_cache: shape = [key_cache, value_cache]
                  key_cache = [num_blocks, block_size,
                               num_kv_heads, head_size]
                  value_cache = [num_blocks, block_size,
                                 num_kv_heads, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [batch_size * seq_len, num_heads, head_size]
    """
    if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
        raise NotImplementedError("Encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "AscendAttentionBackendImpl")
    assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0

    num_tokens = query.shape[0]
    use_kv_cache_int8 = len(
        kv_cache) > 0 and kv_cache[0].dtype == torch.int8

    if output is None:
        output = torch.empty(num_tokens,
                             self.num_heads,
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)
    if trace_flag:
        torch.ops.vllm.unified_ascend_attention_with_output(
            query=query,
            key=key,
            value=value,
            output=output,
            layer_name=layer.layer_name)
        return output.view(num_tokens, self.hidden_size)
    
    if attn_metadata is None:
        return output.view(num_tokens, self.hidden_size).fill_(0)
    # ori_output = output
    if hasattr(layer, 'quant_method') and use_kv_cache_int8:
        output = layer.quant_method.apply(layer, query, key, value,
                                          kv_cache, attn_metadata,
                                          self.attn_type, self.scale,
                                          output)
        return output.view(num_tokens, self.hidden_size)


    # View q k v to BSH.
    query = query.view(-1, self.num_heads, self.head_size)
    key = key.view(-1, self.num_kv_heads, self.head_size)
    value = value.view(-1, self.num_kv_heads, self.head_size)
    # TODO: Remove this contiguous in the future.
    value = value.contiguous()

    if self.attn_type == AttentionType.ENCODER_ONLY:
        ori_output = output
        output = self._forward_encode(query, key, value, attn_metadata, output)
        ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
        return ori_output.view(num_tokens, self.hidden_size)

    if len(kv_cache) > 1:
        if self.key_cache is None:
            self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
        slots = attn_metadata.slot_mapping
        num_actual_tokens = attn_metadata.num_actual_tokens
        torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
                                         value=value[:num_actual_tokens],
                                         key_cache=self.key_cache,
                                         value_cache=self.value_cache,
                                         slot_indices=slots)

    # V0-Style scheduler situation.
    if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        output = self._forward_decode_only(query, attn_metadata, output)
    else:
        output = self._forward_prefill(query, key, value, attn_metadata, output)

    return output.view(num_tokens, self.hidden_size)

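# Registered below as a custom op (torch.ops.vllm.unified_ascend_attention_with_output).
# It resolves the attention layer via the forward context and re-enters
# AscendAttentionBackendImpl.forward with trace_flag=False so that `output` is
# filled in place.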
def unified_ascend_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output,
                      trace_flag=False)
    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_ascend_attention_with_output",
    op_func=unified_ascend_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key="PrivateUse1",
)
