
Commit 43ccd9f

Upgrade to Transformers v4.39.x (#686)
1 parent 42c1753 commit 43ccd9f

File tree

3 files changed: +10 -9 lines changed


hf_transformers

Submodule hf_transformers updated 623 files

setup.py

+1-1
@@ -60,7 +60,7 @@
     "sphinx-multiversion==0.2.4",
     "timeout-decorator",
     "torch>=1.10,!=1.12.0",
-    "transformers~=4.38.1",
+    "transformers~=4.39.3",
 ]


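For context (not part of the diff): `~=4.39.3` is PEP 440's compatible-release operator, so the new pin accepts any 4.39.x patch release from 4.39.3 upward and rejects 4.40 and later. A quick way to sanity-check the range is the `packaging` library, which implements the same specifier semantics pip applies; the snippet below is only an illustration of the pin, not code from this commit.

# Illustrative sketch: what the new "transformers~=4.39.3" pin accepts.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.39.3")  # equivalent to: >=4.39.3, ==4.39.*
print("4.39.3" in spec)  # True
print("4.39.10" in spec)  # True  (later patch releases stay in range)
print("4.40.0" in spec)  # False (minor bumps are excluded)
print("4.38.2" in spec)  # False (older than the floor)
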
src/adapters/models/llama/modeling_llama.py

+8-7
@@ -90,7 +90,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

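Only the comment changes in this hunk: the value actually forwarded to the cache is `cache_position`, not `position_ids`, so the wording is brought in line with the code. As a hedged illustration of the call pattern (not code from this commit), the sketch below feeds the same kind of `cache_kwargs` into a `DynamicCache` from transformers; a dynamic cache ignores the extra kwargs, while a static cache reads `cache_position` to know which slots to write.

# Illustrative sketch, assuming transformers>=4.39 and torch are installed.
import torch
from transformers import DynamicCache

bsz, num_heads, q_len, head_dim = 1, 4, 3, 8
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

past_key_value = DynamicCache()
# The real forward pass also passes the RoPE sin/cos tensors; None suffices here
# because DynamicCache does not consume cache_kwargs.
cache_kwargs = {"sin": None, "cos": None, "cache_position": torch.arange(q_len)}
key_states, value_states = past_key_value.update(
    key_states, value_states, layer_idx=0, cache_kwargs=cache_kwargs
)
print(key_states.shape)  # torch.Size([1, 4, 3, 8]) -- all keys cached so far
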
@@ -107,8 +107,7 @@ def forward(
         bsz = key_states.shape[0]

         if attention_mask is not None:  # no matter the length, we just slice it
-            if cache_position is not None:
-                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
@@ -184,7 +183,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -284,10 +283,11 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -302,8 +302,9 @@ def forward(
         bsz = key_states.shape[0]

         causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
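The mask changes in both attention variants mirror upstream transformers 4.39, where the 4D causal mask handed to attention (built by the model's `_update_causal_mask`) already covers only the current query positions; the per-call `cache_position` row indexing is therefore dropped and only the key dimension is sliced. The standalone sketch below (shapes invented for the example, not taken from the commit) shows the effect of that slicing.

# Illustrative sketch: keep every query row, trim the key dimension to the
# number of keys currently in the cache. Assumes only torch is installed.
import torch

bsz, num_heads, q_len, head_dim = 1, 4, 2, 8
kv_len, max_len = 6, 16

attention_mask = torch.zeros(bsz, 1, q_len, max_len)        # 4D additive causal mask from the model
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)  # keys after the cache update

causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
print(causal_mask.shape)  # torch.Size([1, 1, 2, 6])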
