
Conversation


@taoyao1221 commented Nov 26, 2025

What this PR does / why we need it?

This PR adapts the gpt-oss model, implementing the sliding-window attention (SWA) and attention-sink features that gpt-oss requires on top of the torch_npu.npu_fused_infer_attention_score_v2 fusion operator.

Does this PR introduce any user-facing change?

  1. Modified the attention code to support the SWA and Sink features required by gpt-oss.
  2. Modified the MoE code to add support for bias and the swigluoai activation.

How was this patch tested?

The eager mode and Piecewise ACLGraph mode were tested and verified using evalscope. With the torch_npu.npu_fused_infer_attention_score_v2 fusion operator and graph mode, a simple performance comparison against the small-operator mode was run at 10 concurrent requests, input length 128, output length 128:

  • TTFT: 810 ms (small-operator mode) → 228 ms (FIA eager mode) → 217 ms (Piecewise ACLGraph mode)
  • TPOT: 512 ms (small-operator mode) → 79 ms (FIA eager mode) → 20 ms (Piecewise ACLGraph mode)

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.


@gemini-code-assist bot left a comment


Code Review

This pull request adds support for the gpt-oss model by implementing SWA and Sink features in the attention mechanism and adding bias and swigluoai support in the MoE layers. The changes look good overall, but I have identified a few critical and high-severity issues.

My main concerns are:

  • A hardcoded mask size in the attention implementation, which limits model compatibility.
  • Magic numbers used in the new swiglu_oai activation function.
  • Duplicated code for determining sparse_mode in attention methods.

Addressing these points will improve the code's robustness, readability, and maintainability.

Comment on lines +564 to +570
    if self.sliding_window is not None:
        mask = torch.ones(2048, 2048, dtype=torch.bool)
        triu_mask = torch.triu(mask, diagonal=1).to("npu")
        tril_mask = torch.tril(mask, -self.sliding_window).to("npu")
        self.share_mask_tril_spase = triu_mask + tril_mask
    else:
        self.share_mask_tril_spase = ~torch.tril(torch.ones((2048, 2048), device='npu', dtype=torch.bool))

critical

The size of the attention mask share_mask_tril_spase is hardcoded to 2048x2048. This will cause runtime errors or incorrect behavior for models with a max_model_len greater than 2048. This value should be dynamic and based on the model's configuration.

I'd suggest passing max_model_len to the __init__ method and using it to create the mask. This will make the implementation more robust. I've used self.max_model_len in the suggestion, assuming it can be made available on the class instance.

Suggested change
    -if self.sliding_window is not None:
    -    mask = torch.ones(2048, 2048, dtype=torch.bool)
    -    triu_mask = torch.triu(mask, diagonal=1).to("npu")
    -    tril_mask = torch.tril(mask, -self.sliding_window).to("npu")
    -    self.share_mask_tril_spase = triu_mask + tril_mask
    -else:
    -    self.share_mask_tril_spase = ~torch.tril(torch.ones((2048, 2048), device='npu', dtype=torch.bool))
    +if self.sliding_window is not None:
    +    mask = torch.ones(self.max_model_len, self.max_model_len, dtype=torch.bool)
    +    triu_mask = torch.triu(mask, diagonal=1).to("npu")
    +    tril_mask = torch.tril(mask, -self.sliding_window).to("npu")
    +    self.share_mask_tril_spase = triu_mask + tril_mask
    +else:
    +    self.share_mask_tril_spase = ~torch.tril(torch.ones((self.max_model_len, self.max_model_len), device='npu', dtype=torch.bool))
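
For reference, a minimal CPU-runnable sketch of the same mask construction parameterized by max_model_len; the helper name build_swa_mask is illustrative and not part of the PR:

    from typing import Optional

    import torch

    def build_swa_mask(max_model_len: int, sliding_window: Optional[int]) -> torch.Tensor:
        # True marks positions to mask out.
        mask = torch.ones(max_model_len, max_model_len, dtype=torch.bool)
        if sliding_window is not None:
            # Mask future tokens plus tokens older than the sliding window.
            return torch.triu(mask, diagonal=1) | torch.tril(mask, -sliding_window)
        # Plain causal mask: everything strictly above the diagonal.
        return torch.triu(mask, diagonal=1)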

Comment on lines +726 to +729
    if self.sliding_window is not None:
        sparse_mode = 4
    else:
        sparse_mode = 3

high

This logic for determining sparse_mode is duplicated in _forward_prefill_cache_hit (lines 776-779) and _forward_decode_only (lines 851-854). To improve maintainability, consider calculating this value once in the __init__ method and storing it as a member variable, for example self.sparse_mode_with_sinks. This would avoid code repetition and make future changes easier.
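
A sketch of the de-duplicated selection; the attribute name sparse_mode_with_sinks comes from the suggestion above, and the standalone helper is illustrative:

    from typing import Optional

    def select_sparse_mode(sliding_window: Optional[int]) -> int:
        # Mode 4 when a sliding window is configured, otherwise mode 3 (causal).
        return 4 if sliding_window is not None else 3

    # In __init__, compute once and reuse across the three forward paths:
    # self.sparse_mode_with_sinks = select_sparse_mode(self.sliding_window)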

Comment on lines +219 to +227
    def swiglu_oai(gate_up):
        alpha = 1.702
        limit = 7.0
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        gated_output = (up + 1) * glu
        return gated_output

high

The swiglu_oai function uses magic numbers 1.702 and 7.0. These should be defined as constants with descriptive names to improve readability and maintainability. I've replaced them with SWIGLU_OAI_ALPHA and SWIGLU_OAI_CLAMP_LIMIT in the suggestion. It would be even better to define these constants at the module level.

    def swiglu_oai(gate_up):
        # For readability and maintainability, it's better to define
        # these magic numbers as module-level constants.
        SWIGLU_OAI_ALPHA = 1.702
        SWIGLU_OAI_CLAMP_LIMIT = 7.0
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=SWIGLU_OAI_CLAMP_LIMIT)
        up = up.clamp(min=-SWIGLU_OAI_CLAMP_LIMIT, max=SWIGLU_OAI_CLAMP_LIMIT)
        glu = gate * torch.sigmoid(gate * SWIGLU_OAI_ALPHA)
        gated_output = (up + 1) * glu
        return gated_output
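
Following the reviewer's note, a sketch of the module-level-constant variant; behavior is unchanged and the constant names are taken from the suggestion above:

    import torch

    # Module-level constants, per the review note.
    SWIGLU_OAI_ALPHA = 1.702
    SWIGLU_OAI_CLAMP_LIMIT = 7.0

    def swiglu_oai(gate_up: torch.Tensor) -> torch.Tensor:
        # Interleaved layout: even channels are the gate, odd channels the up-projection.
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(max=SWIGLU_OAI_CLAMP_LIMIT)
        up = up.clamp(min=-SWIGLU_OAI_CLAMP_LIMIT, max=SWIGLU_OAI_CLAMP_LIMIT)
        # SiLU-style gating with alpha, then the (up + 1) affine term.
        return (up + 1) * gate * torch.sigmoid(gate * SWIGLU_OAI_ALPHA)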

@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.
