src/transformers/modeling_flash_attention_utils.py (23 changes: 14 additions & 9 deletions)
@@ -26,6 +26,7 @@
is_torch_npu_available,
logging,
)
from .utils.import_utils import is_tracing


logger = logging.get_logger(__name__)
@@ -174,6 +175,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
# `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
# You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
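
For context, a minimal eager-mode sketch of what these lines compute on a toy mask, together with the dynamo flag mentioned in the comment above (the toy values are illustrative, not part of the PR):

    import torch
    import torch.nn.functional as F

    # Either export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 or set the flag in code before
    # compiling, so the `.item()` call is captured instead of breaking the graph.
    torch._dynamo.config.capture_scalar_outputs = True

    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
    max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3, a plain Python int
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)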

@@ -223,8 +226,8 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
# this might cause a graph break
# `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
# You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
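
If you want to check whether the `.item()` call actually breaks the graph in your setup, a sketch along these lines can help; `max_len_from_mask` is just an illustrative stand-in for the helper above:

    import torch

    def max_len_from_mask(attention_mask):
        seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
        # The flash attention varlen API wants a plain int, hence `.item()`.
        return seqlens.max().item()

    mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
    # torch._dynamo.explain reports graph breaks without running a full compile.
    print(torch._dynamo.explain(max_len_from_mask)(mask))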
@@ -346,11 +349,8 @@ def prepare_fa_kwargs_from_position_ids(position_ids):
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length_q = cu_seq_lens_q.diff().max()
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
# You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
max_length_q = max_length_q.item()
max_length_k = max_length_q
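
To make the comment about non-increasing position ids concrete, here is a toy sketch of a packed batch and the resulting `cu_seq_lens_q`; the boundary-recovery logic below is an illustrative assumption, not the exact upstream code:

    import torch
    import torch.nn.functional as F

    # Two sequences of lengths 3 and 2 packed into a single row.
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    # A new sequence starts wherever the position id does not increase.
    starts = torch.cat([torch.tensor([0]), (position_ids[0].diff() <= 0).nonzero().flatten() + 1])
    lengths = torch.diff(torch.cat([starts, torch.tensor([position_ids.shape[1]])]))
    cu_seq_lens_q = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    print(cu_seq_lens_q)                      # tensor([0, 3, 5], dtype=torch.int32)
    print(cu_seq_lens_q.diff().max().item())  # 3, the max_length_q handed to flash attention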

@@ -401,8 +401,11 @@ def _is_packed_sequence(position_ids, batch_size):
1. Position ids exist
2. Only flattened sequences are supported
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences

NOTE: We disable this feature when torch compile or similar tracing is used, because it requires dynamic control flow
that we cannot avoid without losing control over the gradients, e.g. via `torch.cond`.
"""
if position_ids is None:
if is_tracing(position_ids) or position_ids is None:
return False

increasing_position_sequences = (
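
As a small eager-mode illustration of the docstring's check (under tracing this path is skipped via `is_tracing`, since branching on the result would be data-dependent control flow); the tensors below are illustrative only:

    import torch

    packed = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])   # two sequences flattened into one row
    regular = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])  # one ordinary sequence

    def looks_packed(position_ids):
        # Position ids that reset somewhere mean the diff is not everywhere >= 0.
        return not bool((torch.diff(position_ids, dim=-1) >= 0).all())

    print(looks_packed(packed))   # True
    print(looks_packed(regular))  # False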
@@ -592,8 +595,10 @@ def _flash_attention_forward(

# We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
# Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
# --> not compile friendly, will be ignored if torch compile is used
# Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer it from position ids. It is safe to
#         use `flash_varlen_fn` knowing we already have all the necessary kwargs.
# --> compile friendly, preferred option to use
#
# NOTE: it is the user's responsibility to take care of flattening `position_ids` if that's needed by the model.
# See #39121 for more information.
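
For completeness, a hedged sketch of what "Case 2" looks like from the caller's side. The kwarg names (`cu_seq_lens_q`, `cu_seq_lens_k`, `max_length_q`, `max_length_k`) reflect my reading of the flash attention kwargs, and the model call is left commented out since `model` and the packed inputs are placeholders:

    import torch
    import torch.nn.functional as F

    lengths = torch.tensor([3, 2], dtype=torch.int32)  # per-sequence lengths in the packed batch
    cu_seq_lens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    fa_kwargs = {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": int(lengths.max()),  # plain ints, as the FA API expects
        "max_length_k": int(lengths.max()),
    }
    # outputs = model(input_ids=packed_input_ids, position_ids=packed_position_ids, **fa_kwargs)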