
Device mismatch error when evaluating converted PT2E quantized ViT model on CUDA/MPS #16250

@ofirgo


🐛 Describe the bug

After a quantized model is converted with convert_pt2e(), evaluating it on a CUDA or MPS device fails with a device mismatch error. The error indicates that some tensors remain on the CPU while others are on the target device (CUDA/MPS), and it is raised during dequantization operations.

Error Message
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The error occurs in a torch.ops.aten.mul.Tensor operation between dequantized tensors.
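To narrow down which tensors were left behind, a small diagnostic can list everything in the converted module that is not on the target device. This is a minimal sketch, not part of torch, torchao, or executorch: `tensors_off_device` is a hypothetical helper name, and it only inspects registered parameters, registered buffers, and plain tensor attributes. It would not detect a device that is hard-coded as an argument of a graph node.

```python
import torch
from torch import nn


def tensors_off_device(module: nn.Module, device: str) -> list[tuple[str, str]]:
    """Return (name, device) pairs for tensors not on the given device type."""
    want = torch.device(device).type
    found = []
    # Registered parameters and buffers are what Module.to() moves.
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.device.type != want:
            found.append((name, str(t.device)))
    # Plain tensor attributes are not moved by Module.to(), so check them too.
    for mod_name, sub in module.named_modules():
        for attr, value in vars(sub).items():
            if isinstance(value, torch.Tensor) and value.device.type != want:
                full = f"{mod_name}.{attr}" if mod_name else attr
                found.append((full, str(value.device)))
    return found
```

Calling `tensors_off_device(quantized_graph_model, device)` right after `quantized_graph_model.to(device)` should show whether any stored tensor stayed on the CPU. An empty result would suggest the mismatch originates inside the graph rather than in the module's stored tensors.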

Steps to Reproduce

import torch
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer import EthosUQuantizer, get_symmetric_quantization_config
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchvision.models import vit_b_16, ViT_B_16_Weights

device = 'cuda'  # or 'mps'
batch_size = 4
example_input = torch.randn(batch_size, 3, 224, 224)

weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
model.eval()

exported_program = torch.export.export(model, (example_input,))
graph_model = exported_program.module(check_guards=False)

# Configure quantizer and prepare model
compile_spec = EthosUCompileSpec(
    target="ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
)

quantizer = EthosUQuantizer(compile_spec)
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

prepared_model = prepare_pt2e(graph_model, quantizer)

# Calibrate (simplified)
with torch.no_grad():
    prepared_model(example_input)

# Convert to quantized model
quantized_graph_model = convert_pt2e(prepared_model, fold_quantize=True)
quantized_graph_model._exported_training = False
quantized_graph_model.to(device)

# Error occurs here when trying to run on CUDA/MPS
test_input = torch.randn(batch_size, 3, 224, 224).to(device)
output = quantized_graph_model(test_input)  # RuntimeError!

Current Workaround
A helper function, move_after_dequant_to_device(), is currently required. It inserts an aten._to_copy node after every dequantize node in the graph so that each dequantized output is copied to the target device:

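import torch
from torch import fx

# NOTE: DEQUANTS is not defined in this report. It is assumed to be a
# collection of the dequantize op targets that appear in the converted graph.

def move_after_dequant_to_device(gm: fx.GraphModule, device="mps"):
    dev = torch.device(device)
    g = gm.graph
    for n in list(g.nodes):
        if n.op == "call_function" and n.target in DEQUANTS:
            # Insert a device copy directly after the dequantize node.
            with g.inserting_after(n):
                to_n = g.call_function(
                    torch.ops.aten._to_copy.default,
                    args=(n,),
                    kwargs=dict(dtype=None, layout=None, device=dev,
                                pin_memory=False, non_blocking=False, memory_format=None),
                )
            # Reroute all users of n to the copy. This also rewrites the copy's
            # own input, so restore it to n afterwards.
            n.replace_all_uses_with(to_n)
            to_n.replace_input_with(to_n, n)
    gm.recompile()
    return gm

quantized_graph_model = move_after_dequant_to_device(quantized_graph_model, device)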
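The graph rewrite used by the workaround above (insert a node after a match, reroute the users, then restore the new node's own input) can be checked in isolation on a toy module. The sketch below is an illustration under assumptions: `Toy` and `insert_copy_after` are hypothetical names, `torch.relu` stands in for a dequantize op, and the copy targets the CPU so that it runs without an accelerator.

```python
import torch
from torch import fx, nn


class Toy(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


def insert_copy_after(gm: fx.GraphModule, target, device: str = "cpu") -> fx.GraphModule:
    g = gm.graph
    for n in list(g.nodes):
        if n.op == "call_function" and n.target == target:
            with g.inserting_after(n):
                to_n = g.call_function(
                    torch.ops.aten._to_copy.default,
                    args=(n,),
                    kwargs={"device": torch.device(device)},
                )
            # Reroute every user of n to the copy...
            n.replace_all_uses_with(to_n)
            # ...which also rewrote the copy's own input, so point it back at n.
            to_n.replace_input_with(to_n, n)
    g.lint()  # checks that the rewritten graph is still well-formed
    gm.recompile()
    return gm


gm = insert_copy_after(fx.symbolic_trace(Toy()), torch.relu)
```

Printing `gm.graph` after the rewrite should show the `_to_copy` node sitting between the `relu` node and the addition, which is the same shape the workaround produces around each dequantize node.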

Versions

PyTorch version: 2.9.1+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.39

Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.12.55-74.119.amzn2023.x86_64-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 580.105.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.14.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
Stepping: 7
BogoMIPS: 5999.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB (48 instances)
L1i cache: 1.5 MiB (48 instances)
L2 cache: 48 MiB (48 instances)
L3 cache: 71.5 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Indirect target selection: Mitigation; Aligned branch/return thunks
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected

Versions of relevant libraries:
[pip3] executorch==1.0.1
[pip3] intel-openmp==2021.4.0
[pip3] mkl==2021.1.1
[pip3] mkl-devel==2021.1.1
[pip3] mkl-include==2021.1.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.5
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.14.1
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] onnx==1.18.0
[pip3] onnx-ir==0.1.11
[pip3] onnxscript==0.5.4
[pip3] optree==0.17.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] pytorch-triton==3.4.0+gitc817b9b6
[pip3] tbb==2021.13.1
[pip3] torch==2.9.1+cu128
[pip3] torch_tensorrt==2.9.0a0
[pip3] torchao==0.14.0+git
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.24.1+cu128
[pip3] triton==3.5.1

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Metadata

Labels

module: mps (Issues related to Apple's MPS delegation and code under backends/apple/mps/)
partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)
