A bunch of CUDA errors appearing in the wild
Hi, and thanks for sharing the model!
For some reason, I get spammed with CUDA device-side asserts when I try to run your sample code. Do you have any idea what the issue could be?
This is the assert message (I get about 1000 lines of this):
/opt/conda/conda-bld/pytorch_1724789560443/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [8201,0,0], thread: [31,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
This is the code I'm using:
# Reproduction script for google/paligemma2-3b-pt-896, following the model's
# sample code: load the model and processor, caption one image, and print the
# decoded continuation.
#
# Debugging tip: CUDA device-side asserts are reported asynchronously, so the
# Python stack trace may point at the wrong call. Re-run with the environment
# variable CUDA_LAUNCH_BLOCKING=1 to get the stack trace of the kernel that
# actually failed.
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-pt-896"

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = "<image> What is this?"

# NOTE(review): `.to(torch.bfloat16)` is intended to cast only the
# floating-point inputs (e.g. pixel_values) to the model dtype; integer
# tensors such as input_ids must stay integer. Confirm this holds for the
# installed transformers version.
model_inputs = (
    processor(text=prompt, images=image, return_tensors="pt")
    .to(torch.bfloat16)
    .to(model.device)
)

# Prompt length (in tokens), used below to strip the prompt from the output.
input_len = model_inputs["input_ids"].shape[-1]

# Everything that touches the model's outputs runs inside inference_mode.
# The original paste lost this indentation, which made the snippet an
# IndentationError.
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # The output starts with the prompt tokens; keep only the new continuation.
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)
These are my env specs:
- Python 3.10
- transformers 4.48.0.dev0
- torch 2.4.1
- CUDA 12.1
- A100 80GB GPU
This is the full Python error:
Show
RuntimeError Traceback (most recent call last)
Cell In[20], line 11
8 input_len = model_inputs["input_ids"].shape[-1]
10 with torch.inference_mode():
---> 11 generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
12 generation = generation[0][input_len:]
13 decoded = processor.decode(generation, skip_special_tokens=True)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/generation/utils.py:2254, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2246 input_ids, model_kwargs = self._expand_inputs_for_generation(
2247 input_ids=input_ids,
2248 expand_size=generation_config.num_return_sequences,
2249 is_encoder_decoder=self.config.is_encoder_decoder,
2250 **model_kwargs,
2251 )
2253 # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2254 result = self._sample(
2255 input_ids,
2256 logits_processor=prepared_logits_processor,
2257 stopping_criteria=prepared_stopping_criteria,
2258 generation_config=generation_config,
2259 synced_gpus=synced_gpus,
2260 streamer=streamer,
2261 **model_kwargs,
2262 )
2264 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2265 # 11. prepare beam search scorer
2266 beam_scorer = BeamSearchScorer(
2267 batch_size=batch_size,
2268 num_beams=generation_config.num_beams,
(...)
2273 max_length=generation_config.max_length,
2274 )
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/generation/utils.py:3253, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3250 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
3252 if is_prefill:
-> 3253 outputs = self(**model_inputs, return_dict=True)
3254 is_prefill = False
3255 else:
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:846, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
844 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
845 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 846 outputs = self.model(
847 input_ids=input_ids,
848 attention_mask=attention_mask,
849 position_ids=position_ids,
850 past_key_values=past_key_values,
851 inputs_embeds=inputs_embeds,
852 use_cache=use_cache,
853 output_attentions=output_attentions,
854 output_hidden_states=output_hidden_states,
855 return_dict=return_dict,
856 cache_position=cache_position,
857 )
859 hidden_states = outputs[0]
860 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:634, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
622 layer_outputs = self._gradient_checkpointing_func(
623 decoder_layer.__call__,
624 hidden_states,
(...)
631 cache_position,
632 )
633 else:
--> 634 layer_outputs = decoder_layer(
635 hidden_states,
636 position_embeddings=position_embeddings,
637 attention_mask=causal_mask,
638 position_ids=position_ids,
639 past_key_value=past_key_values,
640 output_attentions=output_attentions,
641 use_cache=use_cache,
642 cache_position=cache_position,
643 )
645 hidden_states = layer_outputs[0]
647 if output_attentions:
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:236, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
233 else:
234 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
--> 236 attn_output, attn_weights = attention_interface(
237 self,
238 query_states,
239 key_states,
240 value_states,
241 attention_mask,
242 dropout=self.attention_dropout if self.training else 0.0,
243 scaling=self.scaling,
244 sliding_window=self.sliding_window,
245 softcap=self.attn_logit_softcapping,
246 **kwargs,
247 )
249 attn_output = attn_output.reshape(*input_shape, -1).contiguous()
250 attn_output = self.o_proj(attn_output)
File ~/miniforge3/envs/ckt/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py:48, in sdpa_attention_forward(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs)
45 if is_causal is None:
46 is_causal = causal_mask is None and query.shape[2] > 1
---> 48 attn_output = torch.nn.functional.scaled_dot_product_attention(
49 query,
50 key,
51 value,
52 attn_mask=causal_mask,
53 dropout_p=dropout,
54 scale=scaling,
55 is_causal=is_causal,
56 )
57 attn_output = attn_output.transpose(1, 2).contiguous()
59 return attn_output, None
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Hi @floschne ,
I tested the code on a T4 GPU and an A100 40GB GPU, and it ran successfully without any errors.
Possible causes of the issue include the following:
1. Precision issues: the A100 80GB GPU may be using mixed precision (e.g., FP16 or TF32) more aggressively, leading to unstable values in scaled dot-product attention.
2. Large batch size and tensor overflow: the larger 80GB of memory may allow the script to allocate more memory or process larger tensors, leading to tensor overflow or invalid operations.
3. CUDA kernel version mismatch: the CUDA kernel for scaled dot-product attention may differ between versions, causing an assertion failure due to mismatched tensor shapes or missing kernel support.
If the issue persists, please try selecting a T4 GPU as the runtime type in Google Colab, which should resolve the problem. Please also refer to this gist file.
Thank you.