A bunch of CUDA errors appearing in the wild

#5
by floschne - opened

Hi, and thanks for sharing the model!

For some reason, I get spammed with CUDA device-side asserts when I try to run your sample code. Do you have any idea what could be the issue?

This is the assert message (I get about 1000 lines of it):
/opt/conda/conda-bld/pytorch_1724789560443/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [8201,0,0], thread: [31,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.

This is the code I'm using:

from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch
model_id = "google/paligemma2-3b-pt-896"

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)


prompt = "<image> What is this?"
model_inputs = (
    processor(text=prompt, images=image, return_tensors="pt")
    .to(torch.bfloat16)
    .to(model.device)
)

input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

These are my env specs:

  • Python 3.10
  • transformers 4.48.0.dev0
  • torch 2.4.1
  • CUDA 12.1
  • A100 80GB GPU

This is the full Python error:

RuntimeError                              Traceback (most recent call last)
Cell In[20], line 11
      8 input_len = model_inputs["input_ids"].shape[-1]
     10 with torch.inference_mode():
---> 11     generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
     12     generation = generation[0][input_len:]
     13     decoded = processor.decode(generation, skip_special_tokens=True)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/generation/utils.py:2254, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2246     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2247         input_ids=input_ids,
   2248         expand_size=generation_config.num_return_sequences,
   2249         is_encoder_decoder=self.config.is_encoder_decoder,
   2250         **model_kwargs,
   2251     )
   2253     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2254     result = self._sample(
   2255         input_ids,
   2256         logits_processor=prepared_logits_processor,
   2257         stopping_criteria=prepared_stopping_criteria,
   2258         generation_config=generation_config,
   2259         synced_gpus=synced_gpus,
   2260         streamer=streamer,
   2261         **model_kwargs,
   2262     )
   2264 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2265     # 11. prepare beam search scorer
   2266     beam_scorer = BeamSearchScorer(
   2267         batch_size=batch_size,
   2268         num_beams=generation_config.num_beams,
   (...)
   2273         max_length=generation_config.max_length,
   2274     )

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/generation/utils.py:3253, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3250 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3252 if is_prefill:
-> 3253     outputs = self(**model_inputs, return_dict=True)
   3254     is_prefill = False
   3255 else:

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
    525     labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
    527 causal_mask = self._update_causal_mask(
    528     attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
    529 )
--> 530 outputs = self.language_model(
    531     attention_mask=causal_mask,
    532     position_ids=position_ids,
    533     past_key_values=past_key_values,
    534     inputs_embeds=inputs_embeds,
    535     use_cache=use_cache,
    536     output_attentions=output_attentions,
    537     output_hidden_states=output_hidden_states,
    538     return_dict=return_dict,
    539     cache_position=cache_position,
    540     num_logits_to_keep=num_logits_to_keep,
    541 )
    543 logits = outputs.logits
    544 loss = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:846, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
    844 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    845 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 846 outputs = self.model(
    847     input_ids=input_ids,
    848     attention_mask=attention_mask,
    849     position_ids=position_ids,
    850     past_key_values=past_key_values,
    851     inputs_embeds=inputs_embeds,
    852     use_cache=use_cache,
    853     output_attentions=output_attentions,
    854     output_hidden_states=output_hidden_states,
    855     return_dict=return_dict,
    856     cache_position=cache_position,
    857 )
    859 hidden_states = outputs[0]
    860 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:634, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    622     layer_outputs = self._gradient_checkpointing_func(
    623         decoder_layer.__call__,
    624         hidden_states,
   (...)
    631         cache_position,
    632     )
    633 else:
--> 634     layer_outputs = decoder_layer(
    635         hidden_states,
    636         position_embeddings=position_embeddings,
    637         attention_mask=causal_mask,
    638         position_ids=position_ids,
    639         past_key_value=past_key_values,
    640         output_attentions=output_attentions,
    641         use_cache=use_cache,
    642         cache_position=cache_position,
    643     )
    645 hidden_states = layer_outputs[0]
    647 if output_attentions:

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    296 hidden_states = self.input_layernorm(hidden_states)
    298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
    300     hidden_states=hidden_states,
    301     position_embeddings=position_embeddings,
    302     attention_mask=attention_mask,
    303     position_ids=position_ids,
    304     past_key_value=past_key_value,
    305     output_attentions=output_attentions,
    306     use_cache=use_cache,
    307     cache_position=cache_position,
    308 )
    309 hidden_states = self.post_attention_layernorm(hidden_states)
    310 hidden_states = residual + hidden_states

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:236, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    233     else:
    234         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
--> 236 attn_output, attn_weights = attention_interface(
    237     self,
    238     query_states,
    239     key_states,
    240     value_states,
    241     attention_mask,
    242     dropout=self.attention_dropout if self.training else 0.0,
    243     scaling=self.scaling,
    244     sliding_window=self.sliding_window,
    245     softcap=self.attn_logit_softcapping,
    246     **kwargs,
    247 )
    249 attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    250 attn_output = self.o_proj(attn_output)

File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py:48, in sdpa_attention_forward(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs)
     45 if is_causal is None:
     46     is_causal = causal_mask is None and query.shape[2] > 1
---> 48 attn_output = torch.nn.functional.scaled_dot_product_attention(
     49     query,
     50     key,
     51     value,
     52     attn_mask=causal_mask,
     53     dropout_p=dropout,
     54     scale=scaling,
     55     is_causal=is_causal,
     56 )
     57 attn_output = attn_output.transpose(1, 2).contiguous()
     59 return attn_output, None

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Google org

Hi @floschne,

I have tested the code on a T4 GPU and an A100 40GB GPU without encountering any errors; the code executed successfully in both cases.

Possible reasons for the issue are as follows:

1. Precision issues: the A100 80GB GPU may be using mixed precision (e.g. FP16 or TF32) more aggressively, leading to unstable values in scaled dot product attention (see the debugging sketch after this list).

2. Large batch size and tensor overflow: the larger available memory of 80GB may allow the script to allocate larger tensors, leading to overflow or invalid operations.

3. CUDA kernel version mismatch: the CUDA kernel for scaled dot product attention may differ between versions, causing an assert failure due to mismatched tensor shapes or missing kernel support.
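
If you want to narrow this down on the A100 itself, a quick check along these lines may help. This is an untested sketch: it reuses `model` and `model_inputs` from the snippet above, and `config.text_config.vocab_size` is an assumption about the PaliGemma config layout.

import torch

# CUDA_LAUNCH_BLOCKING=1 is best set in the shell before starting Python (as the
# error message itself suggests) so the stack trace points at the failing kernel.

# Rule out aggressive TF32 use on the A100 (reason 1) by disabling TF32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# The IndexKernel assert usually means some index (for example a token id) falls
# outside the table it indexes into, so sanity-check the tokenized inputs.
input_ids = model_inputs["input_ids"]
vocab_size = model.config.text_config.vocab_size  # assumed config attribute
print("min id:", input_ids.min().item(), "max id:", input_ids.max().item(), "vocab size:", vocab_size)
assert 0 <= input_ids.min() and input_ids.max() < vocab_size, "token id out of range"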

If the issue persists, please try using a T4 GPU as the runtime type in Google Colab; this should resolve the problem. Please also refer to this gist file.
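
If switching to a T4 is not convenient, another quick test for reason 3 is to bypass the scaled dot product attention kernel entirely by loading the model with the eager attention implementation. A minimal sketch, assuming the attn_implementation argument is supported by your transformers version:

from transformers import PaliGemmaForConditionalGeneration
import torch

# Load with the pure-PyTorch ("eager") attention path to rule out the fused
# scaled_dot_product_attention kernel from reason 3 above.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-3b-pt-896",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # assumption: supported by this transformers version
).eval()

If the assert disappears with eager attention, that points to the SDPA kernel path rather than to the inputs.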

Thank you.
