8 changes: 8 additions & 0 deletions src/transformers/generation/continuous_batching/requests.py
@@ -19,6 +19,7 @@

import torch

from ...utils import is_torch_xpu_available
from ...utils.logging import logging
from ...utils.metrics import traced

@@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
total_memory = torch.cuda.get_device_properties(device).total_memory
reserved_memory = torch.cuda.memory_reserved(device)
allocated_memory = torch.cuda.memory_allocated(device)
elif is_torch_xpu_available():
device = torch.device("xpu")
torch.xpu.empty_cache()
torch.xpu.synchronize()
total_memory = torch.xpu.get_device_properties(device).total_memory
reserved_memory = torch.xpu.memory_reserved(device)
allocated_memory = torch.xpu.memory_allocated(device)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
# MPS memory reporting (PyTorch 2.0+)
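
For reference, a minimal usage sketch of the helper extended above. The import path follows the file shown here, and the headroom arithmetic is illustrative rather than the library's actual cache-sizing logic:

from transformers.generation.continuous_batching.requests import get_device_and_memory_breakdown

# Hypothetical caller: estimate accelerator headroom before sizing a paged KV cache.
device, total, reserved, allocated = get_device_and_memory_breakdown()
free_bytes = total - max(reserved, allocated)
print(f"{device}: {free_bytes / 1e9:.2f} GB free of {total / 1e9:.2f} GB total")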
9 changes: 7 additions & 2 deletions src/transformers/integrations/mxfp4.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging


if is_torch_available():
@@ -114,6 +114,9 @@ def convert_moe_packed_tensors(
if not blocks.is_cuda and torch.cuda.is_available():
blocks = blocks.cuda()
scales = scales.cuda()
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
blocks = blocks.to("xpu")
scales = scales.to("xpu")

scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7

@@ -351,6 +354,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
if target_device == "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif target_device == "cpu" and is_torch_xpu_available():
torch.xpu.empty_cache()
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
delattr(module, blocks_attr)
delattr(module, scales_attr)
@@ -395,7 +400,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
else:
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
if getattr(target_device, "type", target_device) == "cpu":
target_device = "cuda"
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
Collaborator: if @SunMarc can OK this; I'm not familiar with the different accelerators.
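
To illustrate the expression being discussed: torch.accelerator.current_accelerator() (available in recent PyTorch releases) returns the active accelerator as a torch.device, so .type yields "cuda", "xpu", "mps", and so on. A minimal sketch of the fallback, with the helper name and the extra None check being my own additions:

import torch

def pick_dequant_device() -> str:
    # Prefer the detected accelerator type; keep the previous hard-coded
    # "cuda" on PyTorch versions without torch.accelerator.
    if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        return torch.accelerator.current_accelerator().type
    return "cuda"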

blocks = blocks.to(target_device).contiguous()
scales = scales.to(target_device).contiguous()
with on_device(target_device):
65 changes: 46 additions & 19 deletions tests/generation/test_continuous_batching.py
@@ -20,7 +20,15 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
from transformers.generation.continuous_batching.continuous_api import build_attention_mask
from transformers.testing_utils import Expectations, require_kernels, require_read_token, require_torch_gpu, slow
from transformers.testing_utils import (
Expectations,
require_kernels,
require_read_token,
require_torch_accelerator,
require_torch_gpu,
Collaborator: Is @require_torch_gpu still used after those changes?

Contributor (author): @remi-or paged attention enablement on XPU in the kernels library is still in progress and should be ready soon. Once it lands, we will move the remaining cases from CUDA-only to XPU; for now, some paged attention cases are still CUDA-only, so @require_torch_gpu is still needed for those.

slow,
torch_device,
)


ALLOW_EXPECTED_OUTPUTS = True # this is a debug flag when you want to measure deviation between CB and non-CB gen
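
The recurring pattern in the hunks below is to swap hard-coded .cuda() calls and @require_torch_gpu for torch_device and @require_torch_accelerator, so the tests run on CUDA, XPU, and other supported backends. A minimal sketch of that pattern; the test body and model id are illustrative only:

import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_accelerator, torch_device

@require_torch_accelerator  # skipped when no supported accelerator is present
def test_forward_on_any_accelerator():
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype="auto")
    model = model.to(torch_device).eval()  # torch_device instead of a hard-coded .cuda()
    input_ids = torch.tensor([[1, 2, 3]], device=torch_device)
    with torch.no_grad():
        _ = model(input_ids)
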
@@ -148,7 +156,7 @@ def _continuous_batching_parity(

# Generation with continuous batching
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
model = model.cuda().eval()
model = model.to(torch_device).eval()
model.generation_config.max_new_tokens = 40
model.generation_config.do_sample = False
model.generation_config.use_cuda_graph = False
@@ -169,14 +177,14 @@
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
)
model = model.cuda().eval()
model = model.to(torch_device).eval()
model.generation_config.max_new_tokens = 40
model.generation_config.do_sample = False
model.generation_config.use_cuda_graph = False

for request_id, request in cb_outputs.items():
# Generate without continuous batching
input_ids = torch.tensor([request.prompt_ids]).cuda()
input_ids = torch.tensor([request.prompt_ids]).to(torch_device)
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(
input_ids, attention_mask=attention_mask, generation_config=model.generation_config
@@ -208,8 +216,8 @@ def _continuous_batching_parity(
)

# Eager tests
@require_torch_accelerator
@require_read_token
@require_torch_gpu
@slow
def test_continuous_batching_parity_llama_eager(self) -> None:
expected_outputs = Expectations({
@@ -219,11 +227,15 @@ def test_continuous_batching_parity_llama_eager(self) -> None:
("cuda", (9, 0)): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
}
},
("xpu", None): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The answer is not 3.5 bolts of blue fiber and 1.5 bolts of white fiber. The answer'",
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|eager", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_gemma_eager(self) -> None:
expected_outputs = Expectations({
@@ -233,53 +245,68 @@ def test_continuous_batching_parity_gemma_eager(self) -> None:
("cuda", (9, 0)): {
"req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
"req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
}
},
("xpu", None): {
"req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
"req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
"req_2": "\n\n**$100,000**\n\n**Explanation:**\n\nHere's how to calculate the profit:\n\n1. **Calculate the total cost:** $80,00",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("google/gemma-2-2b-it", "paged|eager", expected_outputs)

# FIXME: set expected_outputs
# @require_torch_gpu
# @require_torch_accelerator
# @slow
# def test_continuous_batching_parity_qwen_eager(self) -> None:
# expected_outputs = {}
# self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|eager", expected_outputs)

# FIXME: OOMs
# @require_torch_gpu
# @require_torch_accelerator
# @slow
# def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
# expected_outputs = Expectations({
# ("cuda", (9, 0)): {
# "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
# "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
# }
# },
# ("xpu", None): {
# "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
# "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
# },
# }).get_expectation() # fmt: skip
# self._continuous_batching_parity("openai/gpt-oss-20b", "paged|eager", expected_outputs)

# SDPA tests
@require_read_token
@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_llama_sdpa(self) -> None:
expected_outputs = Expectations({
("rocm", (9, 4)): {
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
}
},
("xpu", None): {
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|sdpa", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_gemma_sdpa(self) -> None:
expected_outputs = Expectations({
("cuda", (9, 0)): {
"req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
}
},
("xpu", None): {
"req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("google/gemma-2-2b-it", "paged|sdpa", expected_outputs)

# FIXME: set expected_outputs
# @require_torch_gpu
# @require_torch_accelerator
# @slow
# def test_continuous_batching_parity_qwen_sdpa(self) -> None:
# expected_outputs = {}
@@ -333,7 +360,7 @@ def test_attn_implementation(self) -> None:
manager = model.init_continuous_batching()
assert "paged|eager" == manager.model.config._attn_implementation

@require_torch_gpu
@require_torch_accelerator
def test_streaming_request(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3
@@ -365,7 +392,7 @@ def test_streaming_request(self) -> None:

manager.stop(block=True)

@require_torch_gpu
@require_torch_accelerator
def test_non_streaming_request(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3
@@ -392,7 +419,7 @@ def test_non_streaming_request(self) -> None:

manager.stop(block=True)

@require_torch_gpu
@require_torch_accelerator
def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3