
Commit 3f7bda4

[Continuous Batching] fix do_sample=True in continuous batching (#40692)

* fix do_sample=True in continuous batching
* added test
* fix top_p
* test
* Update examples/pytorch/continuous_batching.py
1 parent bb45d36 commit 3f7bda4

File tree: 4 files changed (+87, -7 lines)


examples/pytorch/continuous_batching.py

Lines changed: 3 additions & 1 deletion

@@ -229,7 +229,9 @@ def batch_generate(
         use_cuda_graph=args.use_cuda_graph,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=False,
+        do_sample=True,
+        temperature=0.8,
+        top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
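
For reference, a minimal sketch (plain PyTorch, not the library's continuous-batching code) of what these three options control: temperature rescales the logits before the softmax, top_p keeps only the smallest set of tokens whose cumulative probability reaches 0.9, and do_sample=True draws from that filtered distribution instead of taking the argmax.

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 0.8, top_p: float = 0.9) -> int:
        # Temperature < 1 sharpens the distribution, > 1 flattens it.
        probs = torch.softmax(logits / temperature, dim=-1)

        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p, then renormalize.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        keep = cumulative - sorted_probs < top_p  # the top token is always kept
        filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
        filtered = filtered / filtered.sum()

        # do_sample=True: draw from the distribution instead of taking argmax.
        choice = torch.multinomial(filtered, num_samples=1)
        return sorted_idx[choice].item()

    logits = torch.tensor([2.0, 1.5, 0.3, -1.0, -2.0])  # toy 5-token vocabulary
    print(sample_next_token(logits))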

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 19 additions & 4 deletions

@@ -631,16 +631,31 @@ def _process_logit(self, batch_data, logits):
             self.logit_processor.set_continuous_batching_context(
                 batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
             )
-        return self.logit_processor(batch_data["input_ids"], logits)
+
+        # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
+        # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
+        batch_size, seq_len, vocab_size = logits.shape
+        logits_2d = logits.view(batch_size * seq_len, vocab_size)
+        input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
+
+        # Process with 2D tensors
+        processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)
+
+        # Reshape back to 3D
+        return processed_logits_2d.view(batch_size, seq_len, vocab_size)
 
     @traced(span_name="sampling")
     def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
         if self.do_sample:  # sample
             probs = nn.functional.softmax(probs, dim=-1)
-            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
+            # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
+            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
+            # Add batch dimension back to match argmax output
+            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1)
-        tokens = next_tokens.size(1)
+            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]
+
+        tokens = next_tokens.size(1)  # Get seq_len dimension
         batch_processor.output_ids[:, :tokens].copy_(next_tokens)
 
     def _run_generation_loop(self):
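
The shape handling above is easiest to see with dummy tensors: stock logits processors operate on 2D [batch_size, vocab_size] inputs, while the continuous-batching scheduler hands _process_logit a 3D [batch_size, seq_len, vocab_size] tensor, and torch.multinomial over probs[0] returns [seq_len, 1], which needs a squeeze plus a restored batch dimension to match the argmax branch. A small illustrative walk-through, where the division by 0.8 merely stands in for the real logit processor:

    import torch

    batch_size, seq_len, vocab_size = 1, 4, 10
    logits = torch.randn(batch_size, seq_len, vocab_size)

    # Flatten to the 2D layout logits processors expect, then restore 3D.
    logits_2d = logits.view(batch_size * seq_len, vocab_size)          # [4, 10]
    processed_2d = logits_2d / 0.8                                     # stand-in for self.logit_processor(...)
    processed = processed_2d.view(batch_size, seq_len, vocab_size)     # [1, 4, 10]

    # Sampling path: multinomial over probs[0] yields [seq_len, 1], so drop the
    # sample dimension and add the batch dimension back.
    probs = torch.softmax(processed, dim=-1)
    next_tokens = torch.multinomial(probs[0], num_samples=1)           # [4, 1]
    next_tokens = next_tokens.squeeze(-1).unsqueeze(0)                 # [1, 4]

    # Greedy path already has the right shape.
    assert next_tokens.shape == torch.argmax(probs, dim=-1).shape      # both [1, 4]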

src/transformers/integrations/flash_paged.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def paged_attention_forward(
     k, v = cache.update(k, v, module.layer_idx, **kwargs)
 
     sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
-    if implementation is not None:
+    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
         custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
         attn_output = flash_attn_varlen_func(
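
The added hasattr check is a defensive guard: the varlen flash-attention fast path is only taken when the resolved backend object actually exposes a flash_attn_varlen_func attribute. A hedged sketch of the same pattern in isolation, with made-up names (resolve_varlen_kernel, fallback) that are not part of the library:

    from typing import Any, Callable, Optional

    def resolve_varlen_kernel(implementation: Optional[Any], fallback: Callable) -> Callable:
        # Only use the backend's varlen kernel if the symbol is actually present;
        # otherwise return the fallback (hypothetical, for illustration only).
        if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
            return implementation.flash_attn_varlen_func
        return fallback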

tests/generation/test_paged_attention.py

Lines changed: 64 additions & 1 deletion

@@ -78,9 +78,72 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
         )
 
         for i, req_id in enumerate(batch_outputs):
-            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
             expected = _EXPECTED_OUTPUTS[i].strip()
             self.assertTrue(
                 generated.startswith(expected),
                 msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
             )
+
+    @parameterized.expand(
+        [
+            ("eager_paged", 64, 128, 64),
+            ("sdpa_paged", 32, 256, 128),
+            ("paged_attention", 16, 512, 256),
+            ("flex_paged", 64, 128, 64),
+        ]
+    )
+    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
+        """Test batch generation with do_sample=True to verify sampling works correctly."""
+        self.model.config.attn_implementation = attn_impl
+
+        generation_config = GenerationConfig(
+            max_new_tokens=30,
+            do_sample=True,
+            top_k=50,
+            top_p=0.9,
+            temperature=0.8,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=num_blocks,
+            block_size=block_size,
+            max_batch_tokens=max_batch_tokens,
+        )
+
+        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
+        batch_inputs = list(tokenized["input_ids"])
+
+        start = time.time()
+        batch_outputs = self.model.generate_batch(
+            inputs=batch_inputs,
+            generation_config=generation_config,
+        )
+        end = time.time()
+        print(
+            f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
+        )
+
+        # With sampling enabled, we can't check exact outputs, but we should verify:
+        # 1. All requests completed successfully
+        # 2. Generated text is non-empty
+        # 3. At least some tokens were generated for every request
+        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
+
+        for i, req_id in enumerate(batch_outputs):
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
+            self.assertTrue(
+                len(generated) > 0,
+                msg=f"[{attn_impl}] Empty output for request {i}",
+            )
+            # Check that we got at least some tokens generated
+            generated_tokens = batch_outputs[req_id].generated_tokens
+            self.assertGreater(
+                len(generated_tokens),
+                0,
+                msg=f"[{attn_impl}] No tokens generated for request {i}",
+            )
