Skip to content

Conversation

@jackiehimel
Copy link

@jackiehimel jackiehimel commented Oct 22, 2025

What does this PR do?

Adds SDPA and FlashAttention-2 support to LayoutLMv3 following the same pattern as other models. Fully backward compatible.

SDPA converts masks to boolean format. FA2 uses _upad_input for variable-length sequences and avoids redundant unpads. Both fall back gracefully when needed. FA2 is O(N) memory vs O(N²).

Fixes #35467

Changes

  • Added LayoutLMv3SdpaAttention using torch.nn.functional.scaled_dot_product_attention
  • Added LayoutLMv3FlashAttention2 with flash_attn_func / flash_attn_varlen_func
  • Both inherit from LayoutLMv3Attention
  • Fallback to standard attention when backends unavailable or output_attentions=True / relative position bias is used

Testing

  • 121 tests passed in test_modeling_layoutlmv3.py
  • Manually verified forward passes with/without attention masks

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@vasqu @ArthurZucker @Cyrilvallez

jackiehimel added 9 commits October 22, 2025 16:39
- Removed duplicate imports
- Fixed attention component access patterns
- Added missing _upad_input method for Flash Attention 2
- Corrected relative position bias handling
- Added proper fallbacks for unsupported features
- Remove unsupported head_mask parameter from super().forward() calls
- Fix attention mask shape handling for SDPA (convert 4D mask to boolean format)
- Maintain proper mask application in relative position bias fallback path
…ad_mask from super calls, and fixed attention mask format for SDPA.
- Fix attention mask conversion for SDPA (use >= 0 instead of > 0)
- Fix _upad_input to use actual unpadded tensors from unpad_input
- Optimize _upad_input to call unpad_input only once for self-attention
- Remove unused head_mask parameter from forward methods
Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is sadly outdated, I made a comment on what to do instead. I suspect that the tests also passed because the model didn't update their _supports_xxx flag. Without them these tests for FA, SDPA won't be run.

return outputs


class LayoutLMv3SdpaAttention(LayoutLMv3Attention):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry but this approach is outdated and we changed it to a unified class instead. An older Bert version should give a good idea over here

def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
if scaling is None:
scaling = query.size(-1) ** -0.5
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3))
# Relative positional embeddings
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
query_length, key_length = query.shape[2], key.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
attn_weights = attn_weights + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
# Scaling is shifted in case of embeddings being relative
attn_weights = attn_weights * scaling
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class BertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
self.is_causal = is_causal
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# get all proj
query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
if past_key_value is not None:
# decoder-only bert can have a simple dynamic cache for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer,
value_layer,
self.layer_idx,
{"cache_position": cache_position},
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
# only for relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_output, attn_weights

the relative positions are no longer in main (deprecated) but since they are used here, it's a better pointer

We will also need to update the mask creations to something along

attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

And you will need _supports_xxx flags in order to indicate that the model really supports these attention flavors, e.g.

_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback @vasqu! You're right - I've refactored to the unified class approach following the BERT pattern. I added the _supports_sdpa and _supports_flash_attn_2 flags to the config and updated the mask creation to use create_bidirectional_mask.

Just pushed the changes - let me know if there's anything else that needs adjusting. Thanks again!

jackiehimel and others added 8 commits October 24, 2025 20:20
- Replace separate SDPA and FlashAttention classes with unified LayoutLMv3Attention
- Add _supports_sdpa and _supports_flash_attn_2 config flags
- Add _attn_implementation parameter to config
- Implement runtime dispatch based on attention implementation
- Add proper fallback logic for unsupported features
- Fix linting and formatting issues
- All tests passing
- Add proper None checks for relative position bias
- Fix attention mask shape handling for SDPA
- Add fallback imports for flash_attn functions
- Ensure proper mask broadcasting for scaled_dot_product_attention
- All attention implementations now work correctly
@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: layoutlmv3

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check that you follow what other model implementations do for integrating other attentions variations. Bert is a good model to follow in this case, I also referenced an older version of Bert (with relative position support) which should help here. We like to avoid using custom paths and use our interface to handle specifics.

Currently, we likely still dont support the other attention variations as the flags are set incorrectly.

Comment on lines +109 to +111
# Support flags for attention implementations
_supports_sdpa = True
_supports_flash_attn_2 = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These flags go under the pretrained model class, e.g. in llama

_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True

self.num_channels = num_channels
self.patch_size = patch_size
self.classifier_dropout = classifier_dropout
self._attn_implementation = _attn_implementation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We handle this already, no need to modify anything in the config

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be removed?

Comment on lines +347 to +348
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
# Enhanced with unified attention implementation supporting eager, SDPA, and FlashAttention-2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better if we could change the upstream model instead. Or only copy relevant parts.

BUT, we also shouldn't modify this class by much. You should follow the structure of Bert for example:

  • Wrapper before the actual attention classes (without the additional cross attn logic for this model here)
    class BertAttention(nn.Module):
    def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
    super().__init__()
    self.is_cross_attention = is_cross_attention
    attention_class = BertCrossAttention if is_cross_attention else BertSelfAttention
    self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
    self.output = BertSelfOutput(config)
    def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
    attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
    attention_output, attn_weights = self.self(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    cache_position=cache_position,
    **kwargs,
    )
    attention_output = self.output(attention_output, hidden_states)
    return attention_output, attn_weights
  • The actual attention class
    def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
    ):
    if scaling is None:
    scaling = query.size(-1) ** -0.5
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3))
    # Relative positional embeddings
    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
    query_length, key_length = query.shape[2], key.shape[2]
    if use_cache:
    position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
    else:
    position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
    distance = position_ids_l - position_ids_r
    positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
    positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
    if module.position_embedding_type == "relative_key":
    relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
    attn_weights = attn_weights + relative_position_scores
    elif module.position_embedding_type == "relative_key_query":
    relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
    relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
    attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
    # Scaling is shifted in case of embeddings being relative
    attn_weights = attn_weights * scaling
    if attention_mask is not None and attention_mask.ndim == 4:
    attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
    class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
    raise ValueError(
    f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
    f"heads ({config.num_attention_heads})"
    )
    self.config = config
    self.num_attention_heads = config.num_attention_heads
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size
    self.scaling = self.attention_head_size**-0.5
    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)
    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    self.position_embedding_type = position_embedding_type or getattr(
    config, "position_embedding_type", "absolute"
    )
    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    self.max_position_embeddings = config.max_position_embeddings
    self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
    self.is_decoder = config.is_decoder
    self.is_causal = is_causal
    self.layer_idx = layer_idx
    def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.attention_head_size)
    # get all proj
    query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
    key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
    value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
    if past_key_value is not None:
    # decoder-only bert can have a simple dynamic cache for example
    current_past_key_value = past_key_value
    if isinstance(past_key_value, EncoderDecoderCache):
    current_past_key_value = past_key_value.self_attention_cache
    # save all key/value_layer to cache to be re-used for fast auto-regressive generation
    key_layer, value_layer = current_past_key_value.update(
    key_layer,
    value_layer,
    self.layer_idx,
    {"cache_position": cache_position},
    )
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
    if self.position_embedding_type != "absolute":
    raise ValueError(
    f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
    'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
    )
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
    attn_output, attn_weights = attention_interface(
    self,
    query_layer,
    key_layer,
    value_layer,
    attention_mask,
    dropout=0.0 if not self.training else self.dropout.p,
    scaling=self.scaling,
    # only for relevant for non-absolute positional embeddings
    use_cache=past_key_value is not None,
    **kwargs,
    )
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    return attn_output, attn_weights

Currently, the wrapper class is modified and a lot of custom paths are introduced. We use our own integrations with the interface here (linking the files as examples below as well)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

What needs to be handled:

  • the mask creation
  • adapting the new attention class (as per Bert for example)

self,
hidden_states,
attention_mask=None,
head_mask=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Head mask is deprecated

outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs

def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


return (attn_output,)

def _flash_attention_forward(self, hidden_states, attention_mask=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support SDPA & Flash Attention 2 for LayoutLMv3

2 participants