diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py
index 4316ac85ad3a..ba9e198055c5 100644
--- a/src/transformers/generation/continuous_batching/requests.py
+++ b/src/transformers/generation/continuous_batching/requests.py
@@ -19,6 +19,7 @@
 
 import torch
 
+from ...utils import is_torch_xpu_available
 from ...utils.logging import logging
 from ...utils.metrics import traced
@@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
         total_memory = torch.cuda.get_device_properties(device).total_memory
         reserved_memory = torch.cuda.memory_reserved(device)
         allocated_memory = torch.cuda.memory_allocated(device)
+    elif is_torch_xpu_available():
+        device = torch.device("xpu")
+        torch.xpu.empty_cache()
+        torch.xpu.synchronize()
+        total_memory = torch.xpu.get_device_properties(device).total_memory
+        reserved_memory = torch.xpu.memory_reserved(device)
+        allocated_memory = torch.xpu.memory_allocated(device)
     elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
         device = torch.device("mps")
         # MPS memory reporting (PyTorch 2.0+)
diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py
index 6a6ce1db17e7..6552e068aaaf 100644
--- a/src/transformers/integrations/mxfp4.py
+++ b/src/transformers/integrations/mxfp4.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
 
 
 if is_torch_available():
@@ -114,6 +114,9 @@ def convert_moe_packed_tensors(
     if not blocks.is_cuda and torch.cuda.is_available():
         blocks = blocks.cuda()
         scales = scales.cuda()
+    elif (blocks.device.type != "xpu") and is_torch_xpu_available():
+        blocks = blocks.to("xpu")
+        scales = scales.to("xpu")
 
     scales = scales.to(torch.int32) - 127  # TODO that's because 128=2**7
 
@@ -351,6 +354,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
                 dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
                 if target_device == "cpu" and torch.cuda.is_available():
                     torch.cuda.empty_cache()
+                elif target_device == "cpu" and is_torch_xpu_available():
+                    torch.xpu.empty_cache()
                 setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
                 delattr(module, blocks_attr)
                 delattr(module, scales_attr)
@@ -395,7 +400,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
     else:
         blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
     if getattr(target_device, "type", target_device) == "cpu":
-        target_device = "cuda"
+        target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
     blocks = blocks.to(target_device).contiguous()
     scales = scales.to(target_device).contiguous()
     with on_device(target_device):
diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py
index 1dd0a21c1a6e..1393623793fc 100644
--- a/tests/generation/test_continuous_batching.py
+++ b/tests/generation/test_continuous_batching.py
@@ -20,7 +20,15 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
 from transformers.generation.continuous_batching.continuous_api import build_attention_mask
-from transformers.testing_utils import Expectations, require_kernels, require_read_token, require_torch_gpu, slow
+from transformers.testing_utils import (
+    Expectations,
+    require_kernels,
+    require_read_token,
+    require_torch_accelerator,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 
 ALLOW_EXPECTED_OUTPUTS = True  # this is a debug flag when you want to measure deviation between CB and non-CB gen
@@ -148,7 +156,7 @@ def _continuous_batching_parity(
 
         # Generation with continuous batching
         model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
         model.generation_config.use_cuda_graph = False
@@ -169,14 +177,14 @@ def _continuous_batching_parity(
         model = AutoModelForCausalLM.from_pretrained(
             model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
         )
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
        model.generation_config.use_cuda_graph = False
         for request_id, request in cb_outputs.items():
             # Generate without continuous batching
-            input_ids = torch.tensor([request.prompt_ids]).cuda()
+            input_ids = torch.tensor([request.prompt_ids]).to(torch_device)
             attention_mask = torch.ones_like(input_ids)
             outputs = model.generate(
                 input_ids, attention_mask=attention_mask, generation_config=model.generation_config
             )
@@ -208,8 +216,8 @@
             )
 
     # Eager tests
+    @require_torch_accelerator
     @require_read_token
-    @require_torch_gpu
     @slow
     def test_continuous_batching_parity_llama_eager(self) -> None:
         expected_outputs = Expectations({
@@ -219,11 +227,15 @@ def test_continuous_batching_parity_llama_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The answer is not 3.5 bolts of blue fiber and 1.5 bolts of white fiber. The answer'",
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|eager", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_eager(self) -> None:
         expected_outputs = Expectations({
@@ -233,53 +245,68 @@ def test_continuous_batching_parity_gemma_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
                 "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
-            }
+            },
+            ("xpu", None): {
+                "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
+                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
+                "req_2": "\n\n**$100,000**\n\n**Explanation:**\n\nHere's how to calculate the profit:\n\n1. **Calculate the total cost:** $80,00",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|eager", expected_outputs)
 
     # FIXME: set expected_outputs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_qwen_eager(self) -> None:
     #     expected_outputs = {}
     #     self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|eager", expected_outputs)
 
     # FIXME: OOMs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
     #     expected_outputs = Expectations({
     #         ("cuda", (9, 0)): {
     #             "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
     #             "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
-    #         }
+    #         },
+    #         ("xpu", None): {
+    #             "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
+    #             "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
+    #         },
     #     }).get_expectation()  # fmt: skip
     #     self._continuous_batching_parity("openai/gpt-oss-20b", "paged|eager", expected_outputs)
 
     # SDPA tests
     @require_read_token
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_llama_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("rocm", (9, 4)): {
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|sdpa", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
-            }
+            },
+            ("xpu", None): {
+                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|sdpa", expected_outputs)
 
     # FIXME: set expected_outputs
-    # @require_torch_gpu
+    # @require_torch_accelerator
     # @slow
     # def test_continuous_batching_parity_qwen_sdpa(self) -> None:
     #     expected_outputs = {}
@@ -333,7 +360,7 @@ def test_attn_implementation(self) -> None:
         manager = model.init_continuous_batching()
         assert "paged|eager" == manager.model.config._attn_implementation
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
@@ -365,7 +392,7 @@ def test_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
    def test_non_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
@@ -392,7 +419,7 @@ def test_non_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
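For reviewers who want to sanity-check the device selection outside the test suite, here is a minimal, self-contained sketch (not part of the patch) of the accelerator-agnostic pattern the diff applies: probe whichever backend is present for its memory breakdown, and resolve a generic device type through `torch.accelerator` when the running PyTorch provides it. The helper name `device_memory_snapshot` is illustrative only; the sketch assumes a transformers install that exposes `is_torch_xpu_available` and a PyTorch build with CUDA or XPU support.

```python
import torch

from transformers.utils import is_torch_xpu_available


def device_memory_snapshot() -> tuple[torch.device, int, int, int]:
    """Return (device, total, reserved, allocated) memory for the active accelerator."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total = torch.cuda.get_device_properties(device).total_memory
        reserved = torch.cuda.memory_reserved(device)
        allocated = torch.cuda.memory_allocated(device)
    elif is_torch_xpu_available():
        # Same probe as the CUDA branch, routed through the torch.xpu namespace,
        # mirroring the new branch in get_device_and_memory_breakdown.
        device = torch.device("xpu")
        torch.xpu.empty_cache()
        torch.xpu.synchronize()
        total = torch.xpu.get_device_properties(device).total_memory
        reserved = torch.xpu.memory_reserved(device)
        allocated = torch.xpu.memory_allocated(device)
    else:
        # CPU-only fallback: no accelerator memory to report.
        device, total, reserved, allocated = torch.device("cpu"), 0, 0, 0
    return device, total, reserved, allocated


if __name__ == "__main__":
    # Generic device-type resolution, as used in load_and_swizzle_mxfp4: prefer
    # torch.accelerator when available (recent PyTorch), otherwise assume CUDA.
    # The extra None check here is only a safeguard for CPU-only machines.
    if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        target = torch.accelerator.current_accelerator().type
    else:
        target = "cuda"
    print("resolved accelerator:", target)
    print("memory snapshot:", device_memory_snapshot())
```

On a machine with an Intel GPU this prints an `xpu` snapshot without any CUDA-specific calls, which is the behavior the XPU expectations added to the tests rely on.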