# fix continuous batching issues, extend ut cases to xpu #41830
Base: `main`. Changes from all commits:
24c8781, e1c845b, 8ad4e9c, 22c6a87, 30e4ca0, 52840e3, fb98d11, 3a5bba5, a3bddf8, 3dab82a, d9ea091, 32f149f, aa785a2, 130387b
```diff
@@ -20,7 +20,15 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
 from transformers.generation.continuous_batching.continuous_api import build_attention_mask
-from transformers.testing_utils import Expectations, require_kernels, require_read_token, require_torch_gpu, slow
+from transformers.testing_utils import (
+    Expectations,
+    require_kernels,
+    require_read_token,
+    require_torch_accelerator,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 ALLOW_EXPECTED_OUTPUTS = True  # this is a debug flag when you want to measure deviation between CB and non-CB gen
```

**Collaborator** (on the `require_torch_gpu` import): Is `@require_torch_gpu` still used after those changes?

**Author:** @remi-or Paged-attention enablement on XPU in `kernels` is still in progress and should be ready soon; once it lands, we will move the remaining cases from CUDA-only to XPU. Until then, some paged-attention cases are still CUDA-only, so `require_torch_gpu` stays in the import.
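For context on the two helpers this hunk pulls in: `torch_device` is a string naming the backend the test suite detected (e.g. `"cuda"`, `"xpu"`, or `"cpu"`), and `@require_torch_accelerator` skips a test when no accelerator is present. A minimal sketch of how a device-agnostic test uses them (the test body is illustrative, not from this PR):

```python
import torch
from transformers.testing_utils import require_torch_accelerator, torch_device

@require_torch_accelerator  # skip unless some accelerator (CUDA, XPU, ...) is present
def test_device_agnostic_placement():
    # The same code path now works on any backend, not just CUDA.
    x = torch.ones(2, 2).to(torch_device)
    assert x.device.type == torch_device.split(":")[0]
```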
|
|
```diff
@@ -148,7 +156,7 @@ def _continuous_batching_parity(
 
         # Generation with continuous batching
         model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
         model.generation_config.use_cuda_graph = False
```
|
|
```diff
@@ -169,14 +177,14 @@ def _continuous_batching_parity(
         model = AutoModelForCausalLM.from_pretrained(
             model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
         )
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
         model.generation_config.use_cuda_graph = False
 
         for request_id, request in cb_outputs.items():
             # Generate without continuous batching
-            input_ids = torch.tensor([request.prompt_ids]).cuda()
+            input_ids = torch.tensor([request.prompt_ids]).to(torch_device)
             attention_mask = torch.ones_like(input_ids)
             outputs = model.generate(
                 input_ids, attention_mask=attention_mask, generation_config=model.generation_config
```
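The pattern repeated throughout these hunks is the actual portability fix: hard-coded `.cuda()` placement raises on machines without CUDA, while `.to(torch_device)` lands on whatever accelerator is present. A tiny before/after sketch (variable contents assumed):

```python
import torch
from transformers.testing_utils import torch_device

prompt_ids = [[1, 2, 3]]
# Before (CUDA-only): torch.tensor(prompt_ids).cuda() fails on XPU-only hosts.
# After (any accelerator):
input_ids = torch.tensor(prompt_ids).to(torch_device)
attention_mask = torch.ones_like(input_ids)
```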
|
|
```diff
@@ -208,8 +216,8 @@ def _continuous_batching_parity(
         )
 
     # Eager tests
+    @require_torch_accelerator
     @require_read_token
-    @require_torch_gpu
     @slow
     def test_continuous_batching_parity_llama_eager(self) -> None:
         expected_outputs = Expectations({
```
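`@require_torch_accelerator` is what lets these parity tests run on XPU as well as CUDA. A simplified sketch of the kind of gating such a decorator performs (an assumption for illustration; the real implementation in `transformers.testing_utils` is more involved):

```python
import unittest
from transformers.testing_utils import torch_device

def require_accelerator_sketch(test_case):
    # Skip the test whenever the detected device is plain CPU.
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)
```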
|
|
```diff
@@ -219,11 +227,15 @@ def test_continuous_batching_parity_llama_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The answer is not 3.5 bolts of blue fiber and 1.5 bolts of white fiber. The answer'",
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|eager", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_eager(self) -> None:
         expected_outputs = Expectations({
```
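The `("xpu", None)` key means "any XPU, regardless of version", while keys like `("cuda", (9, 0))` pin an expectation to a specific compute capability. An illustrative, simplified version of the lookup that `Expectations(...).get_expectation()` performs (assumed behavior, not the library's exact code):

```python
from typing import Any

def get_expectation(expectations: dict[tuple, Any], device: str, version: tuple | None) -> Any:
    # An exact (device, version) match wins; otherwise fall back to (device, None).
    if (device, version) in expectations:
        return expectations[(device, version)]
    return expectations.get((device, None))

# On an XPU host this resolves to the ("xpu", None) entry:
expected = get_expectation(
    {("cuda", (9, 0)): {"req_1": "cuda text"}, ("xpu", None): {"req_1": "xpu text"}},
    device="xpu",
    version=None,
)
assert expected == {"req_1": "xpu text"}
```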
|
|
```diff
@@ -233,53 +245,68 @@ def test_continuous_batching_parity_gemma_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
                 "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
-            }
+            },
+            ("xpu", None): {
+                "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
+                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
+                "req_2": "\n\n**$100,000**\n\n**Explanation:**\n\nHere's how to calculate the profit:\n\n1. **Calculate the total cost:** $80,00",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|eager", expected_outputs)
 
     # FIXME: set expected_outputs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_qwen_eager(self) -> None:
     #     expected_outputs = {}
     #     self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|eager", expected_outputs)
 
     # FIXME: OOMs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
     #     expected_outputs = Expectations({
     #         ("cuda", (9, 0)): {
     #             "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
     #             "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
-    #         }
+    #         },
+    #         ("xpu", None): {
+    #             "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
+    #             "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
+    #         },
    #     }).get_expectation()  # fmt: skip
    #     self._continuous_batching_parity("openai/gpt-oss-20b", "paged|eager", expected_outputs)
 
     # SDPA tests
     @require_read_token
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_llama_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("rocm", (9, 4)): {
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|sdpa", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
-            }
+            },
+            ("xpu", None): {
+                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|sdpa", expected_outputs)
 
     # FIXME: set expected_outputs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_qwen_sdpa(self) -> None:
     #     expected_outputs = {}
```
|
|
```diff
@@ -333,7 +360,7 @@ def test_attn_implementation(self) -> None:
         manager = model.init_continuous_batching()
         assert "paged|eager" == manager.model.config._attn_implementation
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
```
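These tests exercise the continuous-batching manager, whose attention implementations carry a `paged|` prefix, as the assertion above shows. A minimal illustration of the naming convention the tests rely on (the split itself is just for illustration):

```python
# "paged|<impl>" marks the paged-attention variant of a base implementation.
for impl in ("paged|eager", "paged|sdpa"):
    prefix, base = impl.split("|")
    assert prefix == "paged" and base in ("eager", "sdpa")
```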
|
|
```diff
@@ -365,7 +392,7 @@ def test_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_non_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
```
|
|
```diff
@@ -392,7 +419,7 @@ def test_non_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
```
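For reference, the lifecycle these three tests share, using only the calls visible in the diff; how requests are submitted between `init_continuous_batching()` and `stop()` is not shown in these hunks, so that step is left as a comment:

```python
from transformers import AutoModelForCausalLM
from transformers.testing_utils import torch_device

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto").to(torch_device).eval()

manager = model.init_continuous_batching()  # returns the continuous-batching manager
# ... submit streaming and/or non-streaming requests here (API not shown in this hunk) ...
manager.stop(block=True)  # block until the background generation loop has shut down
```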
|
|
**Reviewer:** Fine with this if @SunMarc can OK it; I'm not familiar with the different accelerators.