diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 201ea3eff305..356006675a5b 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -26,6 +26,7 @@
     is_torch_npu_available,
     logging,
 )
+from .utils.import_utils import is_tracing
 
 
 logger = logging.get_logger(__name__)
@@ -174,6 +175,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
     seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
     used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    # `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
+    # You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 
@@ -223,8 +226,8 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T
     """
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
-    # this might cause a graph break
+    # `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
+    # You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return (
@@ -346,11 +349,8 @@ def prepare_fa_kwargs_from_position_ids(position_ids):
     # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
     # for some models (e.g. qwen2-vl).
     max_length_q = cu_seq_lens_q.diff().max()
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+    # `.item()` is necessary to work with torch compile as the FA API requires base ints, not tensors.
+    # You might need to set `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
     max_length_q = max_length_q.item()
     max_length_k = max_length_q
 
@@ -401,8 +401,11 @@ def _is_packed_sequence(position_ids, batch_size):
     1. Position ids exist
     2. Flattened sequences only are supported
     3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
+
+    NOTE: We disable this check when torch compile or similar tracing tools are used, due to dynamic control flow
+    that we cannot avoid without losing control over the gradients, e.g. via `torch.cond`.
     """
-    if position_ids is None:
+    if is_tracing(position_ids) or position_ids is None:
         return False
 
     increasing_position_sequences = (
@@ -592,8 +595,10 @@ def _flash_attention_forward(
 
     # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
     # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
+    #         --> not compile friendly, will be ignored if torch compile is used
     # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-    #         use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+    #         use `flash_varlen_fn` knowing we already have all the necessary kwargs.
+    #         --> compile friendly, preferred option to use
     #
     # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
     # See #39121 for more information.
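
Not part of the patch, but as context for the `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` hint repeated in the new comments, here is a minimal sketch of what a caller needs so the `.item()` calls above trace under torch compile. `toy_unpad` is a hypothetical stand-in for `_get_unpad_data`; only the environment variable and the `torch._dynamo.config.capture_scalar_outputs` flag come from this diff.

    import torch

    # Either export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 before launching, or flip the
    # dynamo flag in code before the first compiled forward pass.
    torch._dynamo.config.capture_scalar_outputs = True


    def toy_unpad(attention_mask: torch.Tensor) -> int:
        # Mirrors the pattern in `_get_unpad_data`: the FA varlen API wants a plain int
        # for the max sequence length, hence `.item()` on a traced tensor.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        return seqlens_in_batch.max().item()


    compiled = torch.compile(toy_unpad, fullgraph=True)
    mask = torch.ones(2, 8, dtype=torch.int32)
    print(compiled(mask))  # 8, captured as a scalar output instead of graph-breaking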