Skip to content

Commit 0620b7f

Browse files
authored
support automatic_speech_recognition pipeline (#934)
1 parent c746607 commit 0620b7f

19 files changed

+2283
-8
lines changed

.github/pylint.conf

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ disable=raw-checker-failed,
161161
too-few-public-methods,
162162
no-member,
163163
protected-access,
164-
abstract-method
164+
abstract-method,
165+
cyclic-import
165166

166167
# Enable the message, report, category or checker with the given id(s). You can
167168
# either give multiple identifier separated by comma (,) or put this option

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ The table below represents the current support in the library for each of those
154154
| Phi2 |||
155155
| Pop2piano |||
156156
| Qwen2 |||
157+
| Reformer |||
157158
| RegNet | Todo ||
158159
| RoBERTa |||
159160
| RWKV |||
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import mindspore
2+
from mindnlp.transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
3+
from datasets import load_dataset
4+
5+
6+
model_id = "openai/whisper-large-v3"
7+
8+
model = AutoModelForSpeechSeq2Seq.from_pretrained(
9+
model_id, ms_dtype=mindspore.float16, low_cpu_mem_usage=True, use_safetensors=True
10+
)
11+
12+
processor = AutoProcessor.from_pretrained(model_id)
13+
14+
pipe = pipeline(
15+
"automatic-speech-recognition",
16+
model=model,
17+
tokenizer=processor.tokenizer,
18+
feature_extractor=processor.feature_extractor,
19+
max_new_tokens=128,
20+
chunk_length_s=30,
21+
batch_size=16,
22+
return_timestamps=True,
23+
ms_dtype=mindspore.float16,
24+
)
25+
26+
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
27+
sample = dataset[0]["audio"]
28+
29+
result = pipe(sample)
30+
print(result["text"])

mindnlp/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
3535
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
36+
PROCESSOR_NAME = "processor_config.json"
3637

3738
DEFAULT_ROOT = os.path.join(os.getcwd(), ".mindnlp")
3839
# for modelscope models

mindnlp/transformers/generation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def _maybe_initialize_input_ids_for_generation(
581581
encoder_outputs = model_kwargs.get("encoder_outputs")
582582
if self.config.is_encoder_decoder and encoder_outputs is not None:
583583
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
584-
shape = encoder_outputs.last_hidden_state.size()[:-1]
584+
shape = encoder_outputs.last_hidden_state.shape[:-1]
585585
return ops.ones(shape, dtype=mindspore.int64) * -100
586586

587587
if bos_token_id is None:
@@ -609,7 +609,7 @@ def _maybe_initialize_input_ids_for_generation(
609609
encoder_outputs = model_kwargs.get("encoder_outputs")
610610
if self.config.is_encoder_decoder and encoder_outputs is not None:
611611
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
612-
shape = encoder_outputs.last_hidden_state.size()[:-1]
612+
shape = encoder_outputs.last_hidden_state.shape[:-1]
613613
return ops.ones(shape, dtype=mindspore.int64) * -100
614614

615615
if bos_token_id is None:
@@ -651,7 +651,7 @@ def _prepare_input_ids_for_generation(
651651
) -> mindspore.Tensor:
652652
if self.config.is_encoder_decoder and encoder_outputs is not None:
653653
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
654-
shape = encoder_outputs.last_hidden_state.size()[:-1]
654+
shape = encoder_outputs.last_hidden_state.shape[:-1]
655655
return ops.ones(shape, dtype=mindspore.float32) * -100
656656

657657
if bos_token_id is None:

mindnlp/transformers/models/auto/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
# ============================================================================
1516
""" Auto class."""
1617

1718
from mindnlp.utils import (
@@ -28,6 +29,9 @@
2829
)
2930

3031
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
32+
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
33+
from .image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
34+
from .processing_auto import PROCESSOR_MAPPING, AutoProcessor
3135

3236

3337
from .modeling_auto import (
@@ -106,6 +110,9 @@
106110
'AutoConfig',
107111
'TOKENIZER_MAPPING',
108112
'AutoTokenizer',
113+
"FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor",
114+
"IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor",
115+
"PROCESSOR_MAPPING", "AutoProcessor",
109116
'MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING',
110117
'MODEL_FOR_AUDIO_XVECTOR_MAPPING',
111118
'MODEL_FOR_BACKBONE_MAPPING',

mindnlp/transformers/models/auto/auto_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
# ============================================================================
1516
# pylint: disable=C0116
1617
# pylint: disable=C2801
1718
"""Factory function to build auto-model classes."""

mindnlp/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
# ============================================================================
1516
# pylint: disable=C0116
1617
# pylint: disable=C0103
1718
""" Auto Config class."""
@@ -67,6 +68,7 @@
6768
("starcoder2", "Starcoder2Config"),
6869
('t5', 'T5Config'),
6970
('wav2vec2', 'Wav2Vec2Config'),
71+
("whisper", "WhisperConfig"),
7072
('xlm-roberta', 'XLMRobertaConfig'),
7173
]
7274
)

0 commit comments

Comments
 (0)