From df4965c20a43162fbe6b0079a0387309dca39025 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 4 Sep 2025 12:57:52 +0000
Subject: [PATCH 1/5] fix do_sample=True in continuous batching

---
 examples/pytorch/continuous_batching.py              |  4 +++-
 .../generation/continuous_batching/continuous_api.py | 10 +++++++---
 src/transformers/integrations/flash_paged.py         |  2 +-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index b5ad94ed3f11..2c6c6212d11f 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -229,7 +229,9 @@ def batch_generate(
         use_cuda_graph=args.use_cuda_graph,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=False,
+        do_sample=True,
+        temperature=0.8,
+        top_p=1.0,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index b46880fa4ce1..2cbceeff75bf 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -624,10 +624,14 @@ def _process_logit(self, batch_data, logits):
     def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
         if self.do_sample:  # sample
             probs = nn.functional.softmax(probs, dim=-1)
-            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
+            # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
+            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
+            # Add batch dimension back to match argmax output
+            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1)
-        tokens = next_tokens.size(1)
+            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]
+
+        tokens = next_tokens.size(1)  # Get seq_len dimension
         batch_processor.output_ids[:, :tokens].copy_(next_tokens)
 
     def _run_generation_loop(self):
diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index 352bc82a1e40..00836beabe13 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -51,7 +51,7 @@ def paged_attention_forward(
     k, v = cache.update(k, v, module.layer_idx, **kwargs)
 
     sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
-    if implementation is not None:
+    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
         custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
         attn_output = flash_attn_varlen_func(

From 40a32ece5b955ef9e8bd8b1b93e75fb6c6ea37e0 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 5 Sep 2025 14:57:53 +0200
Subject: [PATCH 2/5] added test

---
 tests/generation/test_paged_attention.py | 59 +++++++++++++++++++++++-
 1 file changed, 58 insertions(+), 1 deletion(-)

diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py
index d3241e237466..a4e1f1869b98 100644
--- a/tests/generation/test_paged_attention.py
+++ b/tests/generation/test_paged_attention.py
@@ -78,9 +78,66 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
         )
 
         for i, req_id in enumerate(batch_outputs):
-            generated = 
+            generated = self.tokenizer.decode(batch_outputs[req_id].generated_tokens, skip_special_tokens=False).strip()
             expected = _EXPECTED_OUTPUTS[i].strip()
             self.assertTrue(
                 generated.startswith(expected),
                 msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
             )
+
+    @parameterized.expand(
+        [
+            ("eager_paged", 64, 128, 64),
+            ("sdpa_paged", 32, 256, 128),
+            ("paged_attention", 16, 512, 256),
+            ("flex_paged", 64, 128, 64),
+        ]
+    )
+    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
+        """Test batch generation with do_sample=True to verify sampling works correctly."""
+        self.model.config.attn_implementation = attn_impl
+
+        generation_config = GenerationConfig(
+            max_new_tokens=30,
+            do_sample=True,
+            top_k=50,
+            temperature=0.8,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=num_blocks,
+            block_size=block_size,
+            max_batch_tokens=max_batch_tokens,
+        )
+
+        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)  # Use fewer prompts for faster test
+        batch_inputs = list(tokenized["input_ids"])
+
+        start = time.time()
+        batch_outputs = self.model.generate_batch(
+            inputs=batch_inputs,
+            generation_config=generation_config,
+        )
+        end = time.time()
+        print(
+            f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
+        )
+
+        # With sampling enabled, we can't check exact outputs, but we should verify:
+        # 1. All requests completed successfully
+        # 2. Generated text is non-empty
+        # 3. Generated text is different from greedy (demonstrating sampling is working)
+        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
+        
+        for i, req_id in enumerate(batch_outputs):
+            generated = self.tokenizer.decode(batch_outputs[req_id].generated_tokens, skip_special_tokens=False).strip()
+            self.assertTrue(
+                len(generated) > 0,
+                msg=f"[{attn_impl}] Empty output for request {i}",
+            )
+            # Check that we got at least some tokens generated
+            generated_tokens = batch_outputs[req_id].generated_tokens
+            self.assertGreater(
+                len(generated_tokens), 0,
+                msg=f"[{attn_impl}] No tokens generated for request {i}",
+            )

From 0b6518ff08ff9cf62a77a565f7d75c1a238bc68e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 5 Sep 2025 15:10:43 +0200
Subject: [PATCH 3/5] fix top_p

---
 .../continuous_batching/continuous_api.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index 25c3e185ff83..1c63507abe93 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -631,7 +631,18 @@ def _process_logit(self, batch_data, logits):
         self.logit_processor.set_continuous_batching_context(
             batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
         )
-        return self.logit_processor(batch_data["input_ids"], logits)
+
+        # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
+        # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
+        batch_size, seq_len, vocab_size = logits.shape
+        logits_2d = logits.view(batch_size * seq_len, vocab_size)
+        input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
+
+        # Process with 2D tensors
+        processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)
+
+        # Reshape back to 3D
+        return processed_logits_2d.view(batch_size, seq_len, vocab_size)
 
     @traced(span_name="sampling")
     def _sample(self, batch_processor: ContinuousBatchProcessor, probs):

From c86227a56818ff696eff5eac74aa803d4ce7a9de Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 5 Sep 2025 15:11:07 +0200
Subject: [PATCH 4/5] test

---
 tests/generation/test_paged_attention.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py
index a4e1f1869b98..e7673f5f08cd 100644
--- a/tests/generation/test_paged_attention.py
+++ b/tests/generation/test_paged_attention.py
@@ -78,7 +78,9 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
         )
 
         for i, req_id in enumerate(batch_outputs):
-            generated = self.tokenizer.decode(batch_outputs[req_id].generated_tokens, skip_special_tokens=False).strip()
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
             expected = _EXPECTED_OUTPUTS[i].strip()
             self.assertTrue(
                 generated.startswith(expected),
@@ -101,6 +103,7 @@ def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, m
             max_new_tokens=30,
             do_sample=True,
             top_k=50,
+            top_p=0.9,
             temperature=0.8,
             eos_token_id=self.tokenizer.eos_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
@@ -128,9 +131,11 @@ def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, m
         # 2. Generated text is non-empty
         # 3. Generated text is different from greedy (demonstrating sampling is working)
         self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
-        
+
         for i, req_id in enumerate(batch_outputs):
-            generated = self.tokenizer.decode(batch_outputs[req_id].generated_tokens, skip_special_tokens=False).strip()
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
             self.assertTrue(
                 len(generated) > 0,
                 msg=f"[{attn_impl}] Empty output for request {i}",
@@ -138,6 +143,7 @@ def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, m
             # Check that we got at least some tokens generated
             generated_tokens = batch_outputs[req_id].generated_tokens
             self.assertGreater(
-                len(generated_tokens), 0,
+                len(generated_tokens),
+                0,
                 msg=f"[{attn_impl}] No tokens generated for request {i}",
             )

From 94eb97a328faca8914845ca6bc2e0c5d60a4c51e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 5 Sep 2025 16:51:02 +0200
Subject: [PATCH 5/5] Update examples/pytorch/continuous_batching.py

---
 examples/pytorch/continuous_batching.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 2c6c6212d11f..7196dc994204 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -231,7 +231,7 @@ def batch_generate(
         pad_token_id=tokenizer.pad_token_id,
         do_sample=True,
         temperature=0.8,
-        top_p=1.0,
+        top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
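
Below is a minimal, self-contained PyTorch sketch (not part of the patches) of the tensor-shape handling these commits rely on: flattening the [batch_size, seq_len, vocab_size] logits to 2D before the logit processors run (PATCH 3), and keeping sampled tokens in the same [1, seq_len] layout as greedy argmax (PATCH 1). The toy sizes and the division by 0.8 standing in for self.logit_processor are assumptions for illustration only.

    import torch
    from torch import nn

    # Toy logits shaped like continuous batching output: [batch_size, seq_len, vocab_size]
    logits = torch.randn(1, 5, 32)

    # PATCH 3 idea: flatten to 2D for the logit processors, then reshape back to 3D.
    # The division by 0.8 is a stand-in for self.logit_processor(input_ids_2d, logits_2d).
    batch_size, seq_len, vocab_size = logits.shape
    logits_2d = logits.view(batch_size * seq_len, vocab_size)
    processed = (logits_2d / 0.8).view(batch_size, seq_len, vocab_size)

    # PATCH 1 idea: sampled tokens must come out in the same [1, seq_len] shape as argmax.
    probs = nn.functional.softmax(processed, dim=-1)
    sampled = torch.multinomial(probs[0], num_samples=1).squeeze(-1).unsqueeze(0)  # [1, seq_len]
    greedy = torch.argmax(probs, dim=-1)  # [1, seq_len]
    assert sampled.shape == greedy.shape == (1, seq_len)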