From f3a13f0f2e5821e0ac5d0b336a3911f505684619 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Sat, 18 Oct 2025 02:13:15 -0400 Subject: [PATCH 01/15] Added SDPA and Flash Attention 2 support for LayoutLMv3 (#35467) --- .../models/layoutlmv3/modeling_layoutlmv3.py | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 3aa97051f855..88a797b1b504 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -40,6 +40,19 @@ ) from .configuration_layoutlmv3 import LayoutLMv3Config +# SDPA and Flash Attention support +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + + logger = logging.get_logger(__name__) @@ -355,6 +368,169 @@ def forward( return outputs + +# MY NEW CLASS +# Additional imports for SDPA and Flash Attention support +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +class LayoutLMv3SdpaAttention(LayoutLMv3Attention): + """ + Implements LayoutLMv3 attention using PyTorch's SDPA backend. + Provides improved speed and memory efficiency while maintaining original model weights. + """ + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + # Use standard attention when attention weights are requested + if output_attentions: + logger.warning_once( + "Manual attention is used since output_attentions=True; this is slower than SDPA." + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + output_attentions, + rel_pos, + rel_2d_pos, + ) + + batch_size, seq_len, _ = hidden_states.size() + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Add relative position bias if available + if self.has_relative_attention_bias: + query_layer += rel_pos + key_layer += rel_2d_pos + + # Reshape for SDPA: [batch, heads, seq_len, head_dim] + query_layer = query_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) + key_layer = key_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) + value_layer = value_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) + + # Convert mask to expected 4D format and dtype for SDPA + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, 1, 1, seq_len).expand( + batch_size, self.num_attention_heads, seq_len, seq_len + ) + attention_mask = (1.0 - attention_mask.to(dtype=query_layer.dtype)) * torch.finfo(query_layer.dtype).min + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout.p if self.training else 0.0, + scale=1.0 / math.sqrt(self.attention_head_size), + ) + + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.all_head_size) + outputs = self.dropout(self.dense(attn_output)) + return (outputs,) + + +class LayoutLMv3FlashAttention2(LayoutLMv3Attention): + """ + Implements LayoutLMv3 attention using the Flash Attention 2 library. + Offers optimal memory usage and speed, with graceful fallback when features are unavailable. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def _reshape(self, tensor, seq_len, batch_size): + return tensor.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + # Fallback when output_attentions are explicitly requested + if output_attentions: + logger.warning_once( + "Flash Attention 2 does not provide output_attentions; reverting to standard logic." + ) + return super().forward( + hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos + ) + + batch_size, seq_length, _ = hidden_states.size() + query_states = self._reshape(self.query(hidden_states), seq_length, batch_size) + key_states = self._reshape(self.key(hidden_states), seq_length, batch_size) + value_states = self._reshape(self.value(hidden_states), seq_length, batch_size) + + # Fall back if model requires relative position bias + if self.has_relative_attention_bias and rel_pos is not None: + logger.warning_once( + "Standard attention used as Flash Attention 2 cannot process relative position bias." + ) + return super().forward( + hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos + ) + + if attention_mask is not None: + # Prepare variable length input for Flash Attention 2 + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, seq_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=self.dropout.p if self.training else 0.0, + softmax_scale=1.0 / math.sqrt(self.attention_head_size), + causal=False, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, seq_length) + else: + # Standard path for full-sequence, no padding required + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=self.dropout.p if self.training else 0.0, + softmax_scale=1.0 / math.sqrt(self.attention_head_size), + causal=False, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, -1) + attn_output = self.dropout(self.dense(attn_output)) + return (attn_output,) + + + # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 class LayoutLMv3Layer(GradientCheckpointingLayer): def __init__(self, config): From a27cab1fdda73800fe5349a0c3f7128295974fcc Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Mon, 20 Oct 2025 22:27:45 -0400 Subject: [PATCH 02/15] Fix LayoutLMv3 SDPA and Flash Attention 2 implementation - Removed duplicate imports - Fixed attention component access patterns - Added missing _upad_input method for Flash Attention 2 - Corrected relative position bias handling - Added proper fallbacks for unsupported features --- .../models/layoutlmv3/modeling_layoutlmv3.py | 130 ++++++++++-------- 1 file changed, 70 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 88a797b1b504..a0e5b2f0e909 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -45,7 +45,7 @@ _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -369,23 +369,11 @@ def forward( -# MY NEW CLASS -# Additional imports for SDPA and Flash Attention support -from ...modeling_attn_mask_utils import ( - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input class LayoutLMv3SdpaAttention(LayoutLMv3Attention): """ - Implements LayoutLMv3 attention using PyTorch's SDPA backend. - Provides improved speed and memory efficiency while maintaining original model weights. + LayoutLMv3 attention using PyTorch's scaled_dot_product_attention. """ def forward( @@ -397,10 +385,9 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - # Use standard attention when attention weights are requested if output_attentions: logger.warning_once( - "Manual attention is used since output_attentions=True; this is slower than SDPA." + "SDPA doesn't support output_attentions, falling back to standard attention." ) return super().forward( hidden_states, @@ -412,45 +399,51 @@ def forward( ) batch_size, seq_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - # Add relative position bias if available - if self.has_relative_attention_bias: - query_layer += rel_pos - key_layer += rel_2d_pos - - # Reshape for SDPA: [batch, heads, seq_len, head_dim] - query_layer = query_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) - key_layer = key_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) - value_layer = value_layer.view(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) - - # Convert mask to expected 4D format and dtype for SDPA + + query_layer = self.self.query(hidden_states) + key_layer = self.self.key(hidden_states) + value_layer = self.self.value(hidden_states) + query_layer = query_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) + key_layer = key_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) if attention_mask is not None: attention_mask = attention_mask.view(batch_size, 1, 1, seq_len).expand( - batch_size, self.num_attention_heads, seq_len, seq_len + batch_size, self.self.num_attention_heads, seq_len, seq_len ) attention_mask = (1.0 - attention_mask.to(dtype=query_layer.dtype)) * torch.finfo(query_layer.dtype).min - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.dropout.p if self.training else 0.0, - scale=1.0 / math.sqrt(self.attention_head_size), - ) + if self.self.has_relative_attention_bias and rel_pos is not None: + attention_scores = torch.matmul(query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2)) + + if self.self.has_spatial_attention_bias and rel_2d_pos is not None: + attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.self.attention_head_size) + else: + attention_scores += rel_pos / math.sqrt(self.self.attention_head_size) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.self.dropout(attention_probs) + attn_output = torch.matmul(attention_probs, value_layer) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.self.dropout.p if self.training else 0.0, + scale=1.0 / math.sqrt(self.self.attention_head_size), + ) - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.all_head_size) - outputs = self.dropout(self.dense(attn_output)) + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.self.all_head_size) + outputs = self.output(attn_output, hidden_states) return (outputs,) class LayoutLMv3FlashAttention2(LayoutLMv3Attention): """ - Implements LayoutLMv3 attention using the Flash Attention 2 library. - Offers optimal memory usage and speed, with graceful fallback when features are unavailable. + LayoutLMv3 attention using Flash Attention 2. """ def __init__(self, *args, **kwargs): @@ -458,7 +451,27 @@ def __init__(self, *args, **kwargs): self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def _reshape(self, tensor, seq_len, batch_size): - return tensor.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size) + return tensor.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size) + + def _upad_input(self, query_states, key_states, value_states, attention_mask, query_length): + batch_size, seq_len, num_heads, head_dim = query_states.shape + + indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) + indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) + indices_v, cu_seqlens_v, max_seqlen_v = unpad_input(value_states, attention_mask) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_heads, head_dim) + value_states = value_states.view(-1, num_heads, head_dim) + + return ( + query_states, + key_states, + value_states, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_q, max_seqlen_k) + ) def forward( self, @@ -469,31 +482,29 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - # Fallback when output_attentions are explicitly requested if output_attentions: logger.warning_once( - "Flash Attention 2 does not provide output_attentions; reverting to standard logic." + "Flash Attention 2 doesn't support output_attentions, falling back to standard attention." ) return super().forward( hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos ) batch_size, seq_length, _ = hidden_states.size() - query_states = self._reshape(self.query(hidden_states), seq_length, batch_size) - key_states = self._reshape(self.key(hidden_states), seq_length, batch_size) - value_states = self._reshape(self.value(hidden_states), seq_length, batch_size) + + query_states = self._reshape(self.self.query(hidden_states), seq_length, batch_size) + key_states = self._reshape(self.self.key(hidden_states), seq_length, batch_size) + value_states = self._reshape(self.self.value(hidden_states), seq_length, batch_size) - # Fall back if model requires relative position bias - if self.has_relative_attention_bias and rel_pos is not None: + if self.self.has_relative_attention_bias and rel_pos is not None: logger.warning_once( - "Standard attention used as Flash Attention 2 cannot process relative position bias." + "Flash Attention 2 doesn't support relative position bias, falling back to standard attention." ) return super().forward( hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos ) if attention_mask is not None: - # Prepare variable length input for Flash Attention 2 query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, attention_mask, seq_length ) @@ -508,25 +519,24 @@ def forward( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - dropout_p=self.dropout.p if self.training else 0.0, - softmax_scale=1.0 / math.sqrt(self.attention_head_size), + dropout_p=self.self.dropout.p if self.training else 0.0, + softmax_scale=1.0 / math.sqrt(self.self.attention_head_size), causal=False, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, seq_length) else: - # Standard path for full-sequence, no padding required attn_output = flash_attn_func( query_states, key_states, value_states, - dropout_p=self.dropout.p if self.training else 0.0, - softmax_scale=1.0 / math.sqrt(self.attention_head_size), + dropout_p=self.self.dropout.p if self.training else 0.0, + softmax_scale=1.0 / math.sqrt(self.self.attention_head_size), causal=False, ) attn_output = attn_output.reshape(batch_size, seq_length, -1) - attn_output = self.dropout(self.dense(attn_output)) + attn_output = self.output(attn_output, hidden_states) return (attn_output,) From 0b7926e795e32f6c48d42ff17467e35c5d309494 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Mon, 20 Oct 2025 22:34:16 -0400 Subject: [PATCH 03/15] Fix head_mask parameter and attention mask handling - Remove unsupported head_mask parameter from super().forward() calls - Fix attention mask shape handling for SDPA (convert 4D mask to boolean format) - Maintain proper mask application in relative position bias fallback path --- .../models/layoutlmv3/modeling_layoutlmv3.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index a0e5b2f0e909..d0e87d1560a2 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -392,7 +392,6 @@ def forward( return super().forward( hidden_states, attention_mask, - head_mask, output_attentions, rel_pos, rel_2d_pos, @@ -406,11 +405,6 @@ def forward( query_layer = query_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) key_layer = key_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) value_layer = value_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) - if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, 1, 1, seq_len).expand( - batch_size, self.self.num_attention_heads, seq_len, seq_len - ) - attention_mask = (1.0 - attention_mask.to(dtype=query_layer.dtype)) * torch.finfo(query_layer.dtype).min if self.self.has_relative_attention_bias and rel_pos is not None: attention_scores = torch.matmul(query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2)) @@ -427,11 +421,17 @@ def forward( attention_probs = self.self.dropout(attention_probs) attn_output = torch.matmul(attention_probs, value_layer) else: + # Convert 4D mask to format expected by SDPA + attn_mask = None + if attention_mask is not None: + # SDPA expects mask to be [batch, heads, seq_len, seq_len] with True for positions to attend to + attn_mask = attention_mask > 0 + attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, - attn_mask=attention_mask, + attn_mask=attn_mask, dropout_p=self.self.dropout.p if self.training else 0.0, scale=1.0 / math.sqrt(self.self.attention_head_size), ) @@ -487,7 +487,7 @@ def forward( "Flash Attention 2 doesn't support output_attentions, falling back to standard attention." ) return super().forward( - hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos + hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos ) batch_size, seq_length, _ = hidden_states.size() @@ -501,7 +501,7 @@ def forward( "Flash Attention 2 doesn't support relative position bias, falling back to standard attention." ) return super().forward( - hidden_states, attention_mask, head_mask, output_attentions, rel_pos, rel_2d_pos + hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos ) if attention_mask is not None: From 7603ac453fc1a720c1b4a85206e5e70a2c624630 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Mon, 20 Oct 2025 22:36:50 -0400 Subject: [PATCH 04/15] Fixed parameter compatibility issues in attention classes, removed head_mask from super calls, and fixed attention mask format for SDPA. --- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index d0e87d1560a2..c578b0ef3128 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -421,10 +421,8 @@ def forward( attention_probs = self.self.dropout(attention_probs) attn_output = torch.matmul(attention_probs, value_layer) else: - # Convert 4D mask to format expected by SDPA attn_mask = None if attention_mask is not None: - # SDPA expects mask to be [batch, heads, seq_len, seq_len] with True for positions to attend to attn_mask = attention_mask > 0 attn_output = torch.nn.functional.scaled_dot_product_attention( From 56b2b9db3e6fc1c2f5635c960112431541d93030 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Mon, 20 Oct 2025 22:38:52 -0400 Subject: [PATCH 05/15] Fix attention mask logic and Flash Attention unpadding - Fix attention mask conversion for SDPA (use >= 0 instead of > 0) - Fix _upad_input to use actual unpadded tensors from unpad_input --- .../models/layoutlmv3/modeling_layoutlmv3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index c578b0ef3128..1326b85d778b 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -423,7 +423,7 @@ def forward( else: attn_mask = None if attention_mask is not None: - attn_mask = attention_mask > 0 + attn_mask = attention_mask >= 0 attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -454,14 +454,14 @@ def _reshape(self, tensor, seq_len, batch_size): def _upad_input(self, query_states, key_states, value_states, attention_mask, query_length): batch_size, seq_len, num_heads, head_dim = query_states.shape - indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) - indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) - indices_v, cu_seqlens_v, max_seqlen_v = unpad_input(value_states, attention_mask) - query_states = query_states.view(-1, num_heads, head_dim) key_states = key_states.view(-1, num_heads, head_dim) value_states = value_states.view(-1, num_heads, head_dim) + query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) + key_states, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) + value_states, indices_v, cu_seqlens_v, max_seqlen_v = unpad_input(value_states, attention_mask) + return ( query_states, key_states, From 572343e72ca50d839d98cbefd2506f9f2ae4ea54 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Tue, 21 Oct 2025 00:16:18 -0400 Subject: [PATCH 06/15] Optimize Flash Attention and fix API consistency - Optimize _upad_input to call unpad_input only once for self-attention - Remove unused head_mask parameter from forward methods --- .../models/layoutlmv3/modeling_layoutlmv3.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 1326b85d778b..ef795a4f51c0 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -380,7 +380,6 @@ def forward( self, hidden_states, attention_mask=None, - head_mask=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, @@ -458,24 +457,24 @@ def _upad_input(self, query_states, key_states, value_states, attention_mask, qu key_states = key_states.view(-1, num_heads, head_dim) value_states = value_states.view(-1, num_heads, head_dim) + # For self-attention, all tensors have same structure - only need to unpad once query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) - key_states, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) - value_states, indices_v, cu_seqlens_v, max_seqlen_v = unpad_input(value_states, attention_mask) + key_states = key_states[indices_q] + value_states = value_states[indices_q] return ( query_states, key_states, value_states, indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_q, max_seqlen_k) + (cu_seqlens_q, cu_seqlens_q), + (max_seqlen_q, max_seqlen_q) ) def forward( self, hidden_states, attention_mask=None, - head_mask=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, From c8095af8b1a881b3bf71160bfed6b52000b2ad3b Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Tue, 21 Oct 2025 00:17:41 -0400 Subject: [PATCH 07/15] Fixed attention mask and optimize Flash Attention --- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index ef795a4f51c0..ee65fe41148a 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -457,7 +457,6 @@ def _upad_input(self, query_states, key_states, value_states, attention_mask, qu key_states = key_states.view(-1, num_heads, head_dim) value_states = value_states.view(-1, num_heads, head_dim) - # For self-attention, all tensors have same structure - only need to unpad once query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) key_states = key_states[indices_q] value_states = value_states[indices_q] From 598a1946af8a9eeea22606cc54b8ba9706521679 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Wed, 22 Oct 2025 17:47:27 -0400 Subject: [PATCH 08/15] Fix linting issues: remove unused imports, fix whitespace, sort imports --- .../models/layoutlmv3/modeling_layoutlmv3.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index ee65fe41148a..c81edc4a78f1 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -24,6 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN + +# SDPA and Flash Attention support from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, @@ -35,21 +37,17 @@ from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( auto_docstring, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, torch_int, ) from .configuration_layoutlmv3 import LayoutLMv3Config -# SDPA and Flash Attention support -from ...modeling_attn_mask_utils import ( - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + from flash_attn.bert_padding import pad_input, unpad_input @@ -397,7 +395,7 @@ def forward( ) batch_size, seq_len, _ = hidden_states.size() - + query_layer = self.self.query(hidden_states) key_layer = self.self.key(hidden_states) value_layer = self.self.value(hidden_states) @@ -407,15 +405,15 @@ def forward( if self.self.has_relative_attention_bias and rel_pos is not None: attention_scores = torch.matmul(query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2)) - + if self.self.has_spatial_attention_bias and rel_2d_pos is not None: attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.self.attention_head_size) else: attention_scores += rel_pos / math.sqrt(self.self.attention_head_size) - + if attention_mask is not None: attention_scores = attention_scores + attention_mask - + attention_probs = F.softmax(attention_scores, dim=-1) attention_probs = self.self.dropout(attention_probs) attn_output = torch.matmul(attention_probs, value_layer) @@ -423,7 +421,7 @@ def forward( attn_mask = None if attention_mask is not None: attn_mask = attention_mask >= 0 - + attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, @@ -452,18 +450,18 @@ def _reshape(self, tensor, seq_len, batch_size): def _upad_input(self, query_states, key_states, value_states, attention_mask, query_length): batch_size, seq_len, num_heads, head_dim = query_states.shape - + query_states = query_states.view(-1, num_heads, head_dim) key_states = key_states.view(-1, num_heads, head_dim) value_states = value_states.view(-1, num_heads, head_dim) - + query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) key_states = key_states[indices_q] value_states = value_states[indices_q] - + return ( query_states, - key_states, + key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_q), @@ -487,7 +485,7 @@ def forward( ) batch_size, seq_length, _ = hidden_states.size() - + query_states = self._reshape(self.self.query(hidden_states), seq_length, batch_size) key_states = self._reshape(self.self.key(hidden_states), seq_length, batch_size) value_states = self._reshape(self.self.value(hidden_states), seq_length, batch_size) From e023f914daad87e83f8ac4b893b4c26a80fd6799 Mon Sep 17 00:00:00 2001 From: jackiehimel Date: Wed, 22 Oct 2025 17:51:17 -0400 Subject: [PATCH 09/15] Fix code formatting with ruff --- .../models/layoutlmv3/modeling_layoutlmv3.py | 60 +++++++++---------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index c81edc4a78f1..e5bd8bf6109c 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -50,8 +50,6 @@ from flash_attn.bert_padding import pad_input, unpad_input - - logger = logging.get_logger(__name__) @@ -366,26 +364,21 @@ def forward( return outputs - - - class LayoutLMv3SdpaAttention(LayoutLMv3Attention): """ LayoutLMv3 attention using PyTorch's scaled_dot_product_attention. """ def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - rel_pos=None, - rel_2d_pos=None, + self, + hidden_states, + attention_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, ): if output_attentions: - logger.warning_once( - "SDPA doesn't support output_attentions, falling back to standard attention." - ) + logger.warning_once("SDPA doesn't support output_attentions, falling back to standard attention.") return super().forward( hidden_states, attention_mask, @@ -399,12 +392,20 @@ def forward( query_layer = self.self.query(hidden_states) key_layer = self.self.key(hidden_states) value_layer = self.self.value(hidden_states) - query_layer = query_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) - key_layer = key_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) - value_layer = value_layer.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size).transpose(1, 2) + query_layer = query_layer.view( + batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size + ).transpose(1, 2) + key_layer = key_layer.view( + batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size + ).transpose(1, 2) + value_layer = value_layer.view( + batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size + ).transpose(1, 2) if self.self.has_relative_attention_bias and rel_pos is not None: - attention_scores = torch.matmul(query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2)) + attention_scores = torch.matmul( + query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2) + ) if self.self.has_spatial_attention_bias and rel_2d_pos is not None: attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.self.attention_head_size) @@ -465,24 +466,22 @@ def _upad_input(self, query_states, key_states, value_states, attention_mask, qu value_states, indices_q, (cu_seqlens_q, cu_seqlens_q), - (max_seqlen_q, max_seqlen_q) + (max_seqlen_q, max_seqlen_q), ) def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - rel_pos=None, - rel_2d_pos=None, + self, + hidden_states, + attention_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, ): if output_attentions: logger.warning_once( "Flash Attention 2 doesn't support output_attentions, falling back to standard attention." ) - return super().forward( - hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos - ) + return super().forward(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) batch_size, seq_length, _ = hidden_states.size() @@ -494,9 +493,7 @@ def forward( logger.warning_once( "Flash Attention 2 doesn't support relative position bias, falling back to standard attention." ) - return super().forward( - hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos - ) + return super().forward(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) if attention_mask is not None: query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( @@ -534,7 +531,6 @@ def forward( return (attn_output,) - # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 class LayoutLMv3Layer(GradientCheckpointingLayer): def __init__(self, config): From d593e88de0eb73441c54061e2365f34cf635fef2 Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 20:20:51 -0400 Subject: [PATCH 10/15] Refactor to unified attention class per reviewer feedback - Replace separate SDPA and FlashAttention classes with unified LayoutLMv3Attention - Add _supports_sdpa and _supports_flash_attn_2 config flags - Add _attn_implementation parameter to config - Implement runtime dispatch based on attention implementation - Add proper fallback logic for unsupported features - Fix linting and formatting issues - All tests passing --- my_changes.txt | 0 .../layoutlmv3/configuration_layoutlmv3.py | 6 + .../models/layoutlmv3/modeling_layoutlmv3.py | 219 +++++++----------- .../models/layoutlmv3/my_changes.txt | 0 4 files changed, 96 insertions(+), 129 deletions(-) create mode 100644 my_changes.txt create mode 100644 src/transformers/models/layoutlmv3/my_changes.txt diff --git a/my_changes.txt b/my_changes.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py index d67d4a446422..1976081b1ab6 100644 --- a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py @@ -105,6 +105,10 @@ class LayoutLMv3Config(PreTrainedConfig): ```""" model_type = "layoutlmv3" + + # Support flags for attention implementations + _supports_sdpa = True + _supports_flash_attn_2 = True def __init__( self, @@ -138,6 +142,7 @@ def __init__( num_channels=3, patch_size=16, classifier_dropout=None, + _attn_implementation="eager", **kwargs, ): super().__init__( @@ -173,6 +178,7 @@ def __init__( self.num_channels = num_channels self.patch_size = patch_size self.classifier_dropout = classifier_dropout + self._attn_implementation = _attn_implementation __all__ = ["LayoutLMv3Config"] diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index e5bd8bf6109c..e1b075f4709d 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -38,7 +38,6 @@ from ...utils import ( auto_docstring, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, torch_int, ) @@ -339,59 +338,70 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 class LayoutLMv3Attention(nn.Module): + """ + Unified LayoutLMv3 attention module with support for eager, SDPA, and FlashAttention-2. + """ + def __init__(self, config): super().__init__() self.self = LayoutLMv3SelfAttention(config) self.output = LayoutLMv3SelfOutput(config) + # Store attention implementation config + self.is_causal = False + self._attn_implementation = config._attn_implementation if hasattr(config, "_attn_implementation") else "eager" + def forward( self, hidden_states, attention_mask=None, + head_mask=None, output_attentions=False, rel_pos=None, rel_2d_pos=None, ): - self_outputs = self.self( - hidden_states, - attention_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs + # Dispatch to appropriate attention implementation + if self._attn_implementation == "flash_attention_2": + # Check for unsupported features + if output_attentions: + logger.warning_once( + "FlashAttention-2 does not support output_attentions, falling back to eager attention." + ) + self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + elif self.self.has_relative_attention_bias and rel_pos is not None: + logger.warning_once( + "FlashAttention-2 does not support relative position bias, falling back to eager attention." + ) + self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + else: + self_outputs = self._flash_attention_forward(hidden_states, attention_mask) + elif self._attn_implementation == "sdpa": + # Check for unsupported features + if output_attentions: + logger.warning_once("SDPA does not support output_attentions, falling back to eager attention.") + self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + elif self.self.has_relative_attention_bias and rel_pos is not None: + # SDPA doesn't support relative position bias + self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + else: + self_outputs = self._sdpa_attention_forward(hidden_states, attention_mask, head_mask) -class LayoutLMv3SdpaAttention(LayoutLMv3Attention): - """ - LayoutLMv3 attention using PyTorch's scaled_dot_product_attention. - """ + else: # eager + self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) - def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - rel_pos=None, - rel_2d_pos=None, - ): - if output_attentions: - logger.warning_once("SDPA doesn't support output_attentions, falling back to standard attention.") - return super().forward( - hidden_states, - attention_mask, - output_attentions, - rel_pos, - rel_2d_pos, - ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask=None): batch_size, seq_len, _ = hidden_states.size() + # Get Q, K, V query_layer = self.self.query(hidden_states) key_layer = self.self.key(hidden_states) value_layer = self.self.value(hidden_states) + query_layer = query_layer.view( batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size ).transpose(1, 2) @@ -402,117 +412,68 @@ def forward( batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size ).transpose(1, 2) - if self.self.has_relative_attention_bias and rel_pos is not None: - attention_scores = torch.matmul( - query_layer / math.sqrt(self.self.attention_head_size), key_layer.transpose(-1, -2) - ) - - if self.self.has_spatial_attention_bias and rel_2d_pos is not None: - attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.self.attention_head_size) - else: - attention_scores += rel_pos / math.sqrt(self.self.attention_head_size) - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.self.dropout(attention_probs) - attn_output = torch.matmul(attention_probs, value_layer) - else: - attn_mask = None - if attention_mask is not None: - attn_mask = attention_mask >= 0 - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attn_mask, - dropout_p=self.self.dropout.p if self.training else 0.0, - scale=1.0 / math.sqrt(self.self.attention_head_size), - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.self.all_head_size) - outputs = self.output(attn_output, hidden_states) - return (outputs,) - - -class LayoutLMv3FlashAttention2(LayoutLMv3Attention): - """ - LayoutLMv3 attention using Flash Attention 2. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def _reshape(self, tensor, seq_len, batch_size): - return tensor.view(batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size) - - def _upad_input(self, query_states, key_states, value_states, attention_mask, query_length): - batch_size, seq_len, num_heads, head_dim = query_states.shape - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_heads, head_dim) - value_states = value_states.view(-1, num_heads, head_dim) - - query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) - key_states = key_states[indices_q] - value_states = value_states[indices_q] - - return ( - query_states, - key_states, - value_states, - indices_q, - (cu_seqlens_q, cu_seqlens_q), - (max_seqlen_q, max_seqlen_q), + # Convert attention mask to boolean format for SDPA + attn_mask = None + if attention_mask is not None: + attn_mask = attention_mask >= 0 + + # SDPA doesn't support head_mask, fallback if needed + if head_mask is not None: + return self.self(hidden_states, attention_mask, output_attentions=False, rel_pos=None, rel_2d_pos=None) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_mask, + dropout_p=self.self.dropout.p if self.training else 0.0, + scale=1.0 / math.sqrt(self.self.attention_head_size), ) - def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - rel_pos=None, - rel_2d_pos=None, - ): - if output_attentions: - logger.warning_once( - "Flash Attention 2 doesn't support output_attentions, falling back to standard attention." - ) - return super().forward(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, self.self.all_head_size) + return (attn_output,) + + def _flash_attention_forward(self, hidden_states, attention_mask=None): batch_size, seq_length, _ = hidden_states.size() - query_states = self._reshape(self.self.query(hidden_states), seq_length, batch_size) - key_states = self._reshape(self.self.key(hidden_states), seq_length, batch_size) - value_states = self._reshape(self.self.value(hidden_states), seq_length, batch_size) + # Get Q, K, V + query_states = self.self.query(hidden_states) + key_states = self.self.key(hidden_states) + value_states = self.self.value(hidden_states) - if self.self.has_relative_attention_bias and rel_pos is not None: - logger.warning_once( - "Flash Attention 2 doesn't support relative position bias, falling back to standard attention." - ) - return super().forward(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos) + query_states = query_states.view( + batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size + ) + key_states = key_states.view( + batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size + ) + value_states = value_states.view( + batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size + ) if attention_mask is not None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, seq_length - ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_q, max_seqlen_k = max_seq_lens + # Unpad for variable length sequences + query_states = query_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size) + key_states = key_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size) + value_states = value_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size) + + query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask) + key_states = key_states[indices_q] + value_states = value_states[indices_q] attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, + cu_seqlens_k=cu_seqlens_q, max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + max_seqlen_k=max_seqlen_q, dropout_p=self.self.dropout.p if self.training else 0.0, softmax_scale=1.0 / math.sqrt(self.self.attention_head_size), - causal=False, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, seq_length) @@ -523,11 +484,11 @@ def forward( value_states, dropout_p=self.self.dropout.p if self.training else 0.0, softmax_scale=1.0 / math.sqrt(self.self.attention_head_size), - causal=False, + causal=self.is_causal, ) - attn_output = attn_output.reshape(batch_size, seq_length, -1) - attn_output = self.output(attn_output, hidden_states) + attn_output = attn_output.reshape(batch_size, seq_length, self.self.all_head_size) + return (attn_output,) diff --git a/src/transformers/models/layoutlmv3/my_changes.txt b/src/transformers/models/layoutlmv3/my_changes.txt new file mode 100644 index 000000000000..e69de29bb2d1 From 446443e8636ac64fab05e78ac6fafb52748207c1 Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 20:26:35 -0400 Subject: [PATCH 11/15] Fix whitespace issue in configuration file --- src/transformers/models/layoutlmv3/configuration_layoutlmv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py index 1976081b1ab6..0eae2af49829 100644 --- a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py @@ -105,7 +105,7 @@ class LayoutLMv3Config(PreTrainedConfig): ```""" model_type = "layoutlmv3" - + # Support flags for attention implementations _supports_sdpa = True _supports_flash_attn_2 = True From 5bdd158c3107dbd1a77f64703f880ca4a725ef75 Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 21:23:33 -0400 Subject: [PATCH 12/15] Fix check_copies validation --- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index e1b075f4709d..274eb01fb8fc 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -336,7 +336,8 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 +# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 +# Enhanced with unified attention implementation supporting eager, SDPA, and FlashAttention-2 class LayoutLMv3Attention(nn.Module): """ Unified LayoutLMv3 attention module with support for eager, SDPA, and FlashAttention-2. From 221d4e44d51aca32760eed3ee77145c7f49a77f0 Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 21:27:35 -0400 Subject: [PATCH 13/15] Fix attention implementation bugs - Add proper None checks for relative position bias - Fix attention mask shape handling for SDPA - Add fallback imports for flash_attn functions - Ensure proper mask broadcasting for scaled_dot_product_attention - All attention implementations now work correctly --- .../models/layoutlmv3/modeling_layoutlmv3.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 274eb01fb8fc..fcd586ecd621 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -47,6 +47,12 @@ if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input +else: + # Define dummy functions for when flash_attn is not available + def unpad_input(*args, **kwargs): + raise ImportError("flash_attn is not available") + def pad_input(*args, **kwargs): + raise ImportError("flash_attn is not available") logger = logging.get_logger(__name__) @@ -293,10 +299,11 @@ def forward( # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290) attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) - if self.has_relative_attention_bias and self.has_spatial_attention_bias: - attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) - elif self.has_relative_attention_bias: - attention_scores += rel_pos / math.sqrt(self.attention_head_size) + if self.has_relative_attention_bias and rel_pos is not None: + if self.has_spatial_attention_bias and rel_2d_pos is not None: + attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) + else: + attention_scores += rel_pos / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) @@ -416,7 +423,19 @@ def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask= # Convert attention mask to boolean format for SDPA attn_mask = None if attention_mask is not None: - attn_mask = attention_mask >= 0 + # SDPA expects 2D mask, but we might have 4D extended mask + if attention_mask.dim() == 4: + # Convert 4D extended mask to 2D: (batch_size, 1, 1, seq_len) -> (batch_size, seq_len) + attn_mask = attention_mask.squeeze(1).squeeze(1) >= 0 + elif attention_mask.dim() == 2: + attn_mask = attention_mask >= 0 + else: + # For other dimensions, try to squeeze to 2D + attn_mask = attention_mask.squeeze() >= 0 + + # Expand mask to be broadcastable with attention heads: (batch_size, seq_len) -> (batch_size, 1, seq_len, seq_len) + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # SDPA doesn't support head_mask, fallback if needed if head_mask is not None: From 3c1b884c94bf93fc96f6cc2ea61d2eec649c345f Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 21:32:59 -0400 Subject: [PATCH 14/15] Fix whitespace in blank line --- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index fcd586ecd621..bb45edd29e29 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -432,7 +432,7 @@ def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask= else: # For other dimensions, try to squeeze to 2D attn_mask = attention_mask.squeeze() >= 0 - + # Expand mask to be broadcastable with attention heads: (batch_size, seq_len) -> (batch_size, 1, seq_len, seq_len) if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) From 74380601d901286c3876f3a881d5fec29c97cc1e Mon Sep 17 00:00:00 2001 From: Jackie Himel Date: Fri, 24 Oct 2025 21:59:08 -0400 Subject: [PATCH 15/15] Apply ruff formatting to fix CI issues --- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index bb45edd29e29..6fc88a7de0ce 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -51,6 +51,7 @@ # Define dummy functions for when flash_attn is not available def unpad_input(*args, **kwargs): raise ImportError("flash_attn is not available") + def pad_input(*args, **kwargs): raise ImportError("flash_attn is not available")