Feature request
Repro:
import torch
from transformers import AutoModelForCausalLM
device = "cuda"
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to(device).eval()
model = torch.compile(model, fullgraph=True)
# dummy inputs; we only want logits
bsz, seqlen = 1, 128
inp = torch.randint(0, model.config.vocab_size, (bsz, seqlen), device=device)
with torch.inference_mode():
    model(input_ids=inp)
Output:
Traceback (most recent call last):
  File "/home/ryanguo99/repos/verl/run.py", line 20, in <module>
    model(input_ids=inp)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 895, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.
  Developer debug context: attempted to jump with TensorVariable()
 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
   File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/utils/generic.py", line 918, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 459, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 395, in forward
    hidden_states = decoder_layer(
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 294, in forward
    hidden_states, _ = self.self_attn(
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 252, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
    attn_output = _flash_attention_forward(
  File "/home/ryanguo99/.conda/envs/verl-nightly/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 632, in _flash_attention_forward
    elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
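For reference, here is a minimal standalone sketch of the kind of data-dependent Python branch that triggers this graph break, together with the torch.cond rewrite the hint suggests. This is illustration only (the function names are made up, and it is not the transformers code path above):

import torch

def branchy(x):
    # A plain Python `if` on a tensor value: Dynamo cannot trace this under
    # fullgraph=True and raises the "Data-dependent branching" error above.
    if x.sum() > 0:
        return x * 2
    return x - 1

def cond_based(x):
    # torch.cond keeps both branches in the graph and selects one at runtime,
    # so fullgraph compilation can succeed.
    return torch.cond(x.sum() > 0, lambda t: t * 2, lambda t: t - 1, (x,))

compiled = torch.compile(cond_based, fullgraph=True)
print(compiled(torch.randn(4)))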
Motivation
In LLM RL frameworks that use transformers for training models (e.g., verl), attn_implementation="flash_attention_2" is typically used, because the default SDPA backend cannot route to flash attention under variable sequence lengths (illustrated below).
This graph break therefore hurts the performance of the compiled model.
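As a rough standalone illustration of that SDPA limitation (assuming a CUDA device with the flash kernel available; this is not verl or transformers code): restricting SDPA to its flash backend while passing an explicit attention mask, which padded variable-length batches require, leaves no eligible kernel.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
# An explicit attention mask is how padded variable-length batches are
# expressed to SDPA; the flash backend only supports is_causal, not a mask.
mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool).tril()
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    # Expected to raise a RuntimeError (no eligible kernel), because the
    # flash backend cannot take an explicit attn_mask.
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)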
Your contribution
.