@@ -90,7 +90,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

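The comment fix above matches what the static cache actually consumes: an update keyed by `cache_position` (absolute slot indices into a preallocated buffer), not `position_ids`. The toy cache below is a minimal sketch of that idea only; the class, its buffer layout, and the shapes are illustrative assumptions, not the real `StaticCache` in transformers.

```python
import torch


class ToyStaticCache:
    """Toy stand-in for a static KV cache; not the transformers StaticCache."""

    def __init__(self, num_layers, batch, heads, max_len, head_dim, dtype=torch.float32):
        shape = (batch, heads, max_len, head_dim)
        self.key_cache = [torch.zeros(shape, dtype=dtype) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, dtype=dtype) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx, cache_kwargs):
        # cache_position holds the absolute slot indices of the new tokens in the
        # fixed-size buffer, e.g. tensor([5]) when writing the sixth token.
        cache_position = cache_kwargs["cache_position"]
        self.key_cache[layer_idx][:, :, cache_position] = key_states
        self.value_cache[layer_idx][:, :, cache_position] = value_states
        # Return the full buffers; attention later trims the mask to their length.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


cache = ToyStaticCache(num_layers=1, batch=1, heads=2, max_len=16, head_dim=4)
k = torch.randn(1, 2, 1, 4)
v = torch.randn(1, 2, 1, 4)
keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([5])})
assert keys.shape == (1, 2, 16, 4)
```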
@@ -107,8 +107,7 @@ def forward(
         bsz = key_states.shape[0]

         if attention_mask is not None:  # no matter the length, we just slice it
-            if cache_position is not None:
-                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
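The removed lines indexed the query axis of the 4D mask with `cache_position`; the replacement only slices the key axis. The snippet below shows the shape reasoning, assuming (as the new code does) that the mask handed to the layer already has one row per current query token; all sizes here are made up for illustration.

```python
import torch

batch, q_len, max_cache_len, kv_len, head_dim = 1, 1, 16, 6, 8
min_val = torch.finfo(torch.float32).min

# 4D additive mask of shape (batch, 1, q_len, max_cache_len): one row per
# current query token, one column per slot in the (static) cache.
attention_mask = torch.full((batch, 1, q_len, max_cache_len), min_val)
attention_mask[..., :kv_len] = 0.0  # slots this query may attend to

key_states = torch.randn(batch, 1, kv_len, head_dim)

# Old behaviour: pick query rows via cache_position. New behaviour: the query
# axis already has exactly q_len rows, so only the key axis needs trimming.
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
assert causal_mask.shape == (batch, 1, q_len, kv_len)
```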
@@ -184,7 +183,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -284,10 +283,11 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

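The new comment documents the `getattr` fallback: when a static cache has been attached to the attention module as an instance attribute, it takes precedence over the `past_key_value` argument. A stripped-down illustration of that precedence follows; the class is hypothetical, and how the attribute gets attached is outside this diff.

```python
class DemoAttention:
    """Hypothetical stand-in showing only the getattr precedence."""

    def forward(self, past_key_value=None):
        # If a static cache was attached as self.past_key_value, prefer it.
        past_key_value = getattr(self, "past_key_value", past_key_value)
        return past_key_value


layer = DemoAttention()
assert layer.forward(past_key_value="dynamic-cache") == "dynamic-cache"
layer.past_key_value = "static-cache"  # attached elsewhere, outside this diff
assert layer.forward(past_key_value="dynamic-cache") == "static-cache"
```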
@@ -302,8 +302,9 @@ def forward(
         bsz = key_states.shape[0]

         causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
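For context on the torch==2.1.2 note, the sliced `causal_mask` is what later feeds `torch.nn.functional.scaled_dot_product_attention`, and forcing contiguous inputs is the usual workaround for the memory-efficient backend bug with a custom attn_mask. The snippet below is a standalone illustration under those assumptions, not the exact code that follows this hunk.

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 4, 1, 6, 8
query_states = torch.randn(batch, heads, q_len, head_dim)
key_states = torch.randn(batch, heads, kv_len, head_dim)
value_states = torch.randn(batch, heads, kv_len, head_dim)
causal_mask = torch.zeros(batch, 1, q_len, kv_len)  # already sliced to kv_len

# Contiguous inputs avoid the torch==2.1.2 memory-efficient backend bug with a
# custom attn_mask (pytorch issue 112577).
attn_output = F.scaled_dot_product_attention(
    query_states.contiguous(),
    key_states.contiguous(),
    value_states.contiguous(),
    attn_mask=causal_mask,
)
assert attn_output.shape == (batch, heads, q_len, head_dim)
```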