
Conversation

@drslark (Contributor) commented on Nov 26, 2025

What this PR does / why we need it?

This PR adapts Qwen3-Next to vLLM v0.11.2.

Running the following program:

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
]

sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=128)
llm = LLM(model="/home/model/Qwen3-Next-80B-A3B-Instruct",
          additional_config={"ascend_scheduler_config": {"enabled":True, "enable_chunked_prefill":False}},
          tensor_parallel_size=4,
          enforce_eager=True,
          distributed_executor_backend="mp",
          gpu_memory_utilization=0.7,
          max_model_len=4096)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Before this change, the output was:

Prompt: 'Hello, my name is', Generated text: ' 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111'

With this change, the output is:

Prompt: 'Hello, my name is', Generated text: ' <PRESIDIO_ANONYMIZED_PERSON>. I am a 20-year-old male from the United States. I am currently studying computer science at the University of California, Berkeley. I am interested in machine learning and artificial intelligence. I am also interested in the ethical implications of AI and how it can be used to improve society. I am currently working on a project that involves using machine learning to predict the spread of infectious diseases. I am also interested in the use of AI in healthcare and how it can be used to improve patient outcomes. I am passionate about using technology to make a positive impact on the world. I am'

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

@github-actions bot commented

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot left a comment

Code Review

This pull request adapts the Qwen3-Next model to vLLM v0.11.2. The changes include enabling a previously skipped test and fixing a variable shadowing bug in the model implementation. The fix for the shadowing bug is correct and prevents a runtime error. I have identified an additional potential critical issue: an IndexError could occur if batch_size is zero, which would crash the server. I've provided a comment with a code suggestion to make the code more robust against this edge case.

Comment on lines +708 to 719
            tar_dtype = temp_core_attn_out[0].dtype
            tar_device = temp_core_attn_out[0].device
            tar_shape = list(temp_core_attn_out[0].shape)
            tar_shape[1] = non_spec_query_start_loc[-1]
            core_attn_out_non_spec = torch.empty(tar_shape,
                                                 dtype=tar_dtype,
                                                 device=tar_device)
            for b_idx in range(batch_size):
-               cur_core_attn_out = core_attn_out[b_idx]
+               cur_core_attn_out = temp_core_attn_out[b_idx]
                start, end = non_spec_query_start_loc[
                    b_idx], non_spec_query_start_loc[b_idx + 1]
                core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
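
For readers skimming the diff above, here is a minimal, self-contained sketch of what the corrected loop does; the variable names mirror the diff, but the shapes are illustrative assumptions rather than the model's real dimensions. Each per-request output in temp_core_attn_out is copied into the slice of the packed tensor delimited by consecutive entries of non_spec_query_start_loc, and the one-line fix makes the loop read from the per-request list rather than the similarly named core_attn_out variable.

import torch

# Illustrative sizes only: two requests with 3 and 5 non-spec tokens,
# 4 value heads, head dim 8.
temp_core_attn_out = [torch.randn(1, 3, 4, 8), torch.randn(1, 5, 4, 8)]
non_spec_query_start_loc = torch.tensor([0, 3, 8])
batch_size = len(temp_core_attn_out)

tar_shape = list(temp_core_attn_out[0].shape)
tar_shape[1] = int(non_spec_query_start_loc[-1])
core_attn_out_non_spec = torch.empty(tar_shape,
                                     dtype=temp_core_attn_out[0].dtype,
                                     device=temp_core_attn_out[0].device)
for b_idx in range(batch_size):
    # Read from the per-request list; indexing the similarly named
    # core_attn_out here is the line the diff replaces.
    cur_core_attn_out = temp_core_attn_out[b_idx]
    start = int(non_spec_query_start_loc[b_idx])
    end = int(non_spec_query_start_loc[b_idx + 1])
    core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out

print(core_attn_out_non_spec.shape)  # torch.Size([1, 8, 4, 8])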

critical

There is a potential IndexError here. If batch_size is 0, temp_core_attn_out will be an empty list, and accessing temp_core_attn_out[0] at line 708 will raise an exception. While it seems unlikely for batch_size to be 0 when num_prefills > 0, it's safer to guard against this to prevent a server crash.

Additionally, torch.cat(last_recurrent_state, dim=0) at line 720 will also fail if last_recurrent_state is an empty list (when batch_size is 0).

I suggest wrapping this block and line 720 in a check for batch_size > 0 and handling the batch_size == 0 case separately by creating empty tensors for core_attn_out_non_spec and last_recurrent_state.

Here is a suggested implementation for lines 708-719. Please note that line 720 should also be moved inside the if batch_size > 0: block.

            if batch_size > 0:
                tar_dtype = temp_core_attn_out[0].dtype
                tar_device = temp_core_attn_out[0].device
                tar_shape = list(temp_core_attn_out[0].shape)
                tar_shape[1] = non_spec_query_start_loc[-1]
                core_attn_out_non_spec = torch.empty(tar_shape,
                                                     dtype=tar_dtype,
                                                     device=tar_device)
                for b_idx in range(batch_size):
                    cur_core_attn_out = temp_core_attn_out[b_idx]
                    start, end = non_spec_query_start_loc[
                        b_idx], non_spec_query_start_loc[b_idx + 1]
                    core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
            else:
                num_v_heads = self.num_v_heads // self.tp_size
                core_attn_out_non_spec = torch.empty(
                    (1, 0, num_v_heads, self.head_v_dim),
                    dtype=ssm_state.dtype,
                    device=ssm_state.device
                )
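
As a standalone illustration of the guard pattern suggested above for last_recurrent_state: torch.cat cannot be called on an empty list, so the empty-batch path has to return an explicitly shaped empty tensor instead. The helper name cat_or_empty and the shapes below are hypothetical, chosen only to show the pattern.

import torch

def cat_or_empty(tensors, dim, empty_shape, dtype, device):
    # torch.cat fails on an empty sequence, so fall back to an explicitly
    # shaped empty tensor when there is nothing to concatenate.
    if tensors:
        return torch.cat(tensors, dim=dim)
    return torch.empty(empty_shape, dtype=dtype, device=device)

# Non-empty case concatenates as usual; empty case returns a (0, 4) tensor.
print(cat_or_empty([torch.ones(2, 4), torch.ones(3, 4)], 0, (0, 4),
                   torch.float32, "cpu").shape)  # torch.Size([5, 4])
print(cat_or_empty([], 0, (0, 4), torch.float32, "cpu").shape)  # torch.Size([0, 4])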

@Yikun added the ready (read for review) and ready-for-test (start test by label for PR) labels on Nov 26, 2025.