
Commit 0d95e45

Upgrade transformers to v4.51.3 (#3685)
1 parent 8901f20 commit 0d95e45

File tree

20 files changed: +1412 -1522 lines

docs/tutorials/features/fast_bert.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ Currently `ipex.fast_bert` API is only well optimized for training. For inferenc
 
 ### Prerequisite
 
-- Transformers 4.6.0 ~ 4.48.0
+- Transformers 4.6.0 ~ 4.51.3
 
 ### Usage Example
 

examples/cpu/features/fast_bert/README.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 Currently `ipex.fast_bert` API is only well optimized for training. For inference, it ensures functionality, while to get peak perf, please use `ipex.optimize` API + torchscript.
 
 # Prerequisite:
-Transformers 4.6.0 ~ 4.48.0
+Transformers 4.6.0 ~ 4.51.3
 
 # Usage Example:
 Training:
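Both docs describe the same API. For orientation, a minimal training-style sketch against the `fast_bert` signature visible in `fused_bert.py` below; the model choice, learning rate, bfloat16 dtype, and the returned-tuple convention are illustrative assumptions, not taken from this commit:

import torch
import intel_extension_for_pytorch as ipex
from transformers import BertModel

# Illustrative setup; only fast_bert's signature comes from this repo:
# fast_bert(model, dtype=torch.float, optimizer=None, unpad=False).
model = BertModel.from_pretrained("bert-base-uncased")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# fast_bert validates the installed transformers version against the
# supported range (4.6.0 ~ 4.51.3 after this commit) before applying its
# TPP-fused BERT path. Return convention assumed to mirror ipex.optimize.
model, optimizer = ipex.fast_bert(model, dtype=torch.bfloat16, optimizer=optimizer)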

examples/cpu/llm/fine-tuning/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -6,6 +6,6 @@ black[jupyter]
 datasets
 fire
 peft
-transformers==4.48.0
+transformers==4.51.3
 gradio
 sentencepiece
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-transformers==4.48.0
+transformers==4.51.3
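Multiple requirement files in this commit pin the same release. A quick runtime check (a minimal standard-library sketch, not part of the commit) confirms an environment actually picked up the new pin:

from importlib.metadata import version

# Minimal sketch: assert the active environment matches this commit's pin.
installed = version("transformers")
assert installed == "4.51.3", f"expected transformers 4.51.3, found {installed}"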

intel_extension_for_pytorch/cpu/tpp/fused_bert.py

Lines changed: 3 additions & 2 deletions
@@ -1250,7 +1250,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
     # tpp bert optimization depends on the transformers repo to implementate the related module
     installed_pkg = {dist.metadata["Name"].lower() for dist in distributions()}
     min_version = "4.6.0"
-    max_version = "4.48.0"
+    max_version = "4.51.3"
     if "transformers" not in installed_pkg:
         raise RuntimeError(
             "Please installed the transformers with version: between {} and {}".format(
@@ -1275,9 +1275,10 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
         position_ids_persistent = True
     PT_OPTIMIZER_TO_TPP_OPTIMIZER = {
         torch.optim.AdamW: AdamW,
-        transformers.optimization.AdamW: AdamW,
         torch.optim.SGD: SGD,
     }
+    if hasattr(transformers.optimization, "AdamW"):
+        PT_OPTIMIZER_TO_TPP_OPTIMIZER[transformers.optimization.AdamW] = AdamW
     if dtype not in (
         torch.float,
         torch.bfloat16,
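The second hunk is the behavioral part of the upgrade: newer transformers releases no longer ship their own `AdamW` (long deprecated in favor of `torch.optim.AdamW`), so referencing `transformers.optimization.AdamW` unconditionally would raise on 4.51.x. A standalone sketch of the failure mode the `hasattr` guard avoids:

import transformers

# Sketch: an eager attribute lookup raises AttributeError on transformers
# releases that dropped their AdamW; hasattr() degrades gracefully across
# the whole supported 4.6.0 ~ 4.51.3 range.
if hasattr(transformers.optimization, "AdamW"):
    legacy_adamw = transformers.optimization.AdamW
else:
    legacy_adamw = None  # callers fall back to torch.optim.AdamW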
Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-from .beam_search import _beam_search
+from .beam_search import _beam_search, _beam_search_legacy
 from .greedy_search import _greedy_search
 from .sample import _sample
-from .beam_sample import _beam_sample
+from .beam_sample import _beam_sample, _beam_sample_legacy
 from .utils import whisper_generate
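The paired `_legacy` re-exports suggest the package now keeps pre-upgrade beam-search and beam-sample implementations alongside versions adapted to the 4.51 generation API. How callers choose between them is not shown in this excerpt; a speculative dispatch sketch, where the cutoff version is a pure assumption:

from packaging import version
import transformers

# Speculative: route to the _legacy implementations on older transformers.
# The 4.51.0 cutoff is an illustrative assumption, not from this commit.
USE_LEGACY = version.parse(transformers.__version__) < version.parse("4.51.0")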
