
[Performance] Long-context OOM on newer architectures #25965

@IkerJansa44


Describe the issue

With ONNX Runtime GenAI (INT4, CUDA), newer architectures that use grouped-query attention (GQA), such as Qwen3-1.7B-Instruct and Phi-4-mini-instruct, show a larger VRAM footprint and run out of memory (OOM) at a fixed long context. On the same hardware, Phi-3.5-mini-instruct (which uses standard multi-head attention) completes the same inference and still leaves more than 6 GB of VRAM free (see screenshots). This contradicts the KV-cache savings expected from GQA.

[Screenshot]

Phi-3.5-mini-instruct peak GPU utilization:
[Screenshot]
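For context on the GQA expectation, a rough KV-cache estimate is sketched below. It is a minimal sketch under stated assumptions: an fp16 cache (2 bytes per element, since INT4 here applies to the weights), batch size 1, a cache sized for the script's full `--max-length` of 40960 tokens, and layer/head counts recalled from the models' Hugging Face `config.json` files (verify them against the actual files). `AttentionShape` and `kv_cache_bytes` are hypothetical helpers, not part of ONNX Runtime GenAI.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionShape:
    name: str
    num_layers: int
    num_kv_heads: int  # equals num_attention_heads for plain multi-head attention
    head_dim: int


def kv_cache_bytes(shape: AttentionShape, seq_len: int,
                   batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """K and V caches for all layers: 2 x layers x batch x kv_heads x seq_len x head_dim."""
    return (2 * shape.num_layers * batch_size * shape.num_kv_heads
            * seq_len * shape.head_dim * bytes_per_elem)


# Assumed values (recalled from the Hugging Face config.json files; verify before use)
MODELS = [
    AttentionShape("Phi-3.5-mini-instruct", num_layers=32, num_kv_heads=32, head_dim=96),
    AttentionShape("Phi-4-mini-instruct", num_layers=32, num_kv_heads=8, head_dim=128),
    AttentionShape("Qwen3-1.7B", num_layers=28, num_kv_heads=8, head_dim=128),
]

if __name__ == "__main__":
    max_length = 40960  # the script's default --max-length
    for m in MODELS:
        gib = kv_cache_bytes(m, max_length) / 2**30
        print(f"{m.name:<24} {gib:6.2f} GiB")
```

Under these assumptions the estimate is 15 GiB for Phi-3.5-mini-instruct, 5 GiB for Phi-4-mini-instruct and about 4.4 GiB for Qwen3-1.7B. If the cache were the main consumer, the GQA models should fit more easily, which is exactly the expectation this report questions.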

To reproduce

1. Create a venv with the following pinned packages:

```text
certifi==2025.8.3
charset-normalizer==3.4.3
coloredlogs==15.0.1
filelock==3.19.1
flatbuffers==25.2.10
fsspec==2025.9.0
hf-xet==1.1.9
huggingface-hub==0.34.4
humanfriendly==10.0
idna==3.10
jinja2==3.1.6
markupsafe==3.0.2
ml-dtypes==0.5.3
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvtx-cu12==12.8.90
onnx==1.19.0
onnx-ir==0.1.7
onnxruntime==1.22.1
onnxruntime-genai-cuda==0.9.0
onnxruntime-gpu==1.22.0
packaging==25.0
protobuf==6.32.0
pyyaml==6.0.2
regex==2025.9.1
requests==2.32.5
safetensors==0.6.2
setuptools==80.9.0
sympy==1.14.0
tokenizers==0.22.0
torch==2.8.0
tqdm==4.67.1
transformers==4.56.0
triton==3.4.0
typing-extensions==4.15.0
urllib3==2.5.0
```
2. Build the INT4 CUDA models (same command for all three):

```bash
python -m onnxruntime_genai.models.builder -m Qwen/Qwen3-1.7B-Instruct       -o ./qwen3-1_7b_int4 -p int4 -e cuda
python -m onnxruntime_genai.models.builder -m microsoft/Phi-4-mini-instruct  -o ./phi4-mini_int4 -p int4 -e cuda
python -m onnxruntime_genai.models.builder -m microsoft/Phi-3.5-mini-instruct -o ./phi3.5-mini_int4 -p int4 -e cuda
```
3. Run this script, with a long text substituted for `<insert-text>`, against each of the three quantized models.

```python
import argparse
import os
import time

import onnxruntime_genai as og
from transformers import AutoTokenizer


class ONNXInferencer:

    def __init__(self, model_path: str, **kwargs):
        self.model_path = model_path

        self.model = og.Model(self.model_path)

        # Tokenizers: the Hugging Face one applies the chat template,
        # the ONNX Runtime GenAI one encodes the prompt and streams the output
        self.safetensors_tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.onnx_tokenizer = og.Tokenizer(self.model)

        # Collect generation/search options
        allowed = [
            "do_sample",
            "min_length",
            "top_p",
            "top_k",
            "temperature",
            "repetition_penalty",
        ]
        search_options = {name: kwargs[name] for name in allowed if name in kwargs}
        # max_length bounds the total sequence (prompt + generated tokens)
        search_options["max_length"] = kwargs.get("max_length", 1024)

        self._onnx_params = og.GeneratorParams(self.model)
        self._onnx_params.set_search_options(**search_options)

    def inference(
        self, prompt: list[dict], return_latency: bool = False
    ) -> str | tuple:
        """
        Perform inference using the ONNX model.
        Args:
            prompt (list[dict]): Chat messages ({"role": ..., "content": ...}) passed to the chat template.
            return_latency (bool, optional): If True, also return the time to first token and the tokens per second (TPS). Defaults to False.

        Returns:
            str | tuple: The generated response. If return_latency is True, a tuple
                (response, time_to_first_token_s, tps).
        """

        query = self.safetensors_tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        generator = og.Generator(self.model, self._onnx_params)
        input_tokens = self.onnx_tokenizer.encode(query)
        generator.append_tokens(input_tokens)
        tokenizer_stream = self.onnx_tokenizer.create_stream()
        response = ""
        new_tokens = []
        initial_time = time.time()
        first_token_timestamp = None
        time_first_token = None

        while not generator.is_done():
            generator.generate_next_token()
            new_token = generator.get_next_tokens()[0]
            new_tokens.append(new_token)
            # Time the first *token*, not the first non-empty text: the stream
            # decoder can return "" for tokens that only complete a later character
            if first_token_timestamp is None:
                first_token_timestamp = time.time()
                time_first_token = first_token_timestamp - initial_time
            response += tokenizer_stream.decode(new_token)

        if return_latency:
            if first_token_timestamp is None:
                # Nothing was generated
                return (response, None, 0.0)
            run_time = time.time() - first_token_timestamp
            # Decode throughput over the tokens generated after the first one
            tps = (len(new_tokens) - 1) / run_time if run_time > 0 else 0.0
            return (response, time_first_token, tps)

        return response


text = "<insert-text>"  # placeholder: paste the long input text here


def main():
    parser = argparse.ArgumentParser(description="ONNX OOM test inferencer")
    parser.add_argument(
        "--model-name",
        type=str,
        default="qwen3-1_7b_int4",
        help="Model directory name (last part of the path under --models-base-dir)",
    )
    parser.add_argument(
        "--models-base-dir",
        type=str,
        default="./",
        help="Base directory containing ONNX model folders",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=40960,
        help="Maximum total sequence length (prompt + generated tokens) passed to ONNXInferencer",
    )
    args = parser.parse_args()

    model_path = os.path.join(args.models_base_dir, args.model_name)
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    llm = ONNXInferencer(model_path, max_length=args.max_length)
    safetensors_tokenizer = llm.safetensors_tokenizer

    system_prompt = "You are a helpful, respectful and honest assistant"
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Summarize the following text:\n\nText: {text}"},
    ]
    print(f"\033[94mModel name: {args.model_name}\033[0m")
    print(
        f"\033[94mNumber of tokens (user message): {len(safetensors_tokenizer.encode(prompt[1]['content']))}\033[0m"
    )
    output = llm.inference(prompt=prompt)
    print(f"\n\nResponse:\n{output}")


if __name__ == "__main__":
    main()
```
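The screenshots could be complemented with numbers. The sketch below assumes `nvidia-smi` is on the `PATH`; it polls used GPU memory on a background thread and keeps the peak. `PeakMemorySampler` and `read_used_mib_nvidia_smi` are hypothetical helpers, not part of the script above. It measures device-wide usage (including other processes), and ONNX Runtime's CUDA arena may keep memory reserved after tensors are freed, so the peak is an upper bound on what the model itself needs.

```python
from __future__ import annotations

import subprocess
import threading
from typing import Callable


def read_used_mib_nvidia_smi(gpu_index: int = 0) -> int:
    """Currently used memory on one GPU, in MiB, as reported by nvidia-smi."""
    out = subprocess.run(
        ["nvidia-smi", f"--id={gpu_index}", "--query-gpu=memory.used",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return int(out.strip().splitlines()[0])


class PeakMemorySampler:
    """Polls a memory reader on a background thread and keeps the maximum value seen."""

    def __init__(self, read_fn: Callable[[], int] = read_used_mib_nvidia_smi,
                 interval_s: float = 0.1):
        self.read_fn = read_fn
        self.interval_s = interval_s
        self.peak = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while not self._stop.is_set():
            self.peak = max(self.peak, self.read_fn())
            self._stop.wait(self.interval_s)

    def __enter__(self) -> PeakMemorySampler:
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        # Runs on normal exit and when an exception (e.g. an OOM error) is raised
        self._stop.set()
        self._thread.join()
        # One final sample so very short runs still record a value
        self.peak = max(self.peak, self.read_fn())
```

To use it, wrap the call in `main()` as `with PeakMemorySampler() as sampler:` around `output = llm.inference(prompt=prompt)` and print `sampler.peak` (MiB) afterwards. If the OOM surfaces as a Python exception, the sampler is still stopped by the `with` block, so the peak can be printed from an `except` handler.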

Urgency

No response

Platform

Linux

OS Version

Ubuntu 24.04.2 LTS (kernel 6.8.0-79-generic)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.22.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response
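The venv above pins both `onnxruntime==1.22.1` and `onnxruntime-gpu==1.22.0`. Both wheels install the same `onnxruntime` import package, so with both present it is not obvious which build is actually loaded; whether that affects ONNX Runtime GenAI depends on how it locates the ONNX Runtime library, which is not verified here. A minimal sketch for recording the relevant versions and the available execution providers (the helper names are hypothetical):

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Distributions from the venv above whose versions matter for this report
PACKAGES = [
    "onnxruntime",
    "onnxruntime-gpu",
    "onnxruntime-genai-cuda",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cudnn-cu12",
]


def installed_versions(names: list[str]) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if absent."""
    found: dict[str, str | None] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def cpu_and_gpu_ort_both_installed(versions: dict[str, str | None]) -> bool:
    """True when both the CPU and GPU ONNX Runtime wheels are present."""
    return (versions.get("onnxruntime") is not None
            and versions.get("onnxruntime-gpu") is not None)


if __name__ == "__main__":
    versions = installed_versions(PACKAGES)
    for name, ver in versions.items():
        print(f"{name:<28} {ver or 'not installed'}")
    if cpu_and_gpu_ort_both_installed(versions):
        print("warning: both onnxruntime and onnxruntime-gpu are installed")
    try:
        import onnxruntime as ort
        print("ONNX Runtime providers:", ort.get_available_providers())
    except ImportError:
        print("onnxruntime is not importable")
```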

Model File

No response
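No model file is attached. A compact substitute is the relevant part of each model's `genai_config.json` written by the builder. The sketch below extracts the fields that determine attention memory, plus the vocabulary size, which differs considerably across these models (32064 for Phi-3.5-mini-instruct, 200064 for Phi-4-mini-instruct and 151936 for Qwen3-1.7B, per their Hugging Face configs). The key layout it reads (`model.decoder.*`, `model.context_length`, `model.vocab_size`, `search.max_length`, `search.past_present_share_buffer`) is an assumption based on builder-generated configs and should be checked against the actual files; `summarize_genai_config` is a hypothetical helper.

```python
import json
from pathlib import Path


def summarize_genai_config(model_dir: str) -> dict:
    """Return the memory-relevant fields of <model_dir>/genai_config.json.

    Missing keys come back as None instead of raising, since the exact
    layout is assumed rather than confirmed.
    """
    cfg = json.loads((Path(model_dir) / "genai_config.json").read_text())
    model = cfg.get("model", {})
    decoder = model.get("decoder", {})
    search = cfg.get("search", {})
    return {
        "num_hidden_layers": decoder.get("num_hidden_layers"),
        "num_attention_heads": decoder.get("num_attention_heads"),
        "num_key_value_heads": decoder.get("num_key_value_heads"),
        "head_size": decoder.get("head_size"),
        "vocab_size": model.get("vocab_size"),
        "context_length": model.get("context_length"),
        "max_length": search.get("max_length"),
        "past_present_share_buffer": search.get("past_present_share_buffer"),
    }


if __name__ == "__main__":
    # Output directories from the builder commands above
    for d in ["./qwen3-1_7b_int4", "./phi4-mini_int4", "./phi3.5-mini_int4"]:
        if (Path(d) / "genai_config.json").exists():
            print(d, summarize_genai_config(d))
```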

Is this a quantized model?

Yes

Assignees: No one assigned

Labels: model:transformer, performance, stale