Merge japanese-to-english multilingual branch (#1860)
* add streaming support to reazonresearch

* update README for streaming

* Update RESULTS.md

* add onnx decode

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Fangjun Kuang <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: zr_jin <[email protected]>
5 people authored Feb 3, 2025
1 parent dd5d7e3 commit 0855b03
Showing 54 changed files with 7,048 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/style_check.yml
@@ -69,7 +69,7 @@ jobs:
        working-directory: ${{github.workspace}}
        run: |
          black --check --diff .
      - name: Run isort
        shell: bash
        working-directory: ${{github.workspace}}
17 changes: 17 additions & 0 deletions egs/multi_ja_en/ASR/README.md
@@ -0,0 +1,17 @@
# Introduction

A bilingual Japanese-English ASR model that uses ReazonSpeech, developed by the ReazonSpeech team.

**ReazonSpeech** is an open-source dataset of diverse, natural Japanese speech collected from terrestrial television streams, totaling more than 35,000 hours of audio.


# Included Training Sets

1. LibriSpeech (English)
2. ReazonSpeech (Japanese)

|Dataset| Number of hours| URL|
|---|---:|---|
|LibriSpeech|960|https://www.openslr.org/12/|
|ReazonSpeech (all)|35,000|https://huggingface.co/datasets/reazon-research/reazonspeech|
|**TOTAL**|35,960|---|
53 changes: 53 additions & 0 deletions egs/multi_ja_en/ASR/RESULTS.md
@@ -0,0 +1,53 @@
## Results

### Zipformer

#### Non-streaming

The training command is:

```shell
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 600
```
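
The commit message also mentions streaming support. Assuming the standard icefall Zipformer interface (the `--causal`, `--chunk-size`, and `--left-context-frames` flags also appear in the ONNX export command below), a streaming model could presumably be trained along these lines — a sketch under those assumptions, not a command from the original README:

```shell
# Hypothetical streaming variant; the --causal flag and the
# exp-dir name are assumptions, not taken from this RESULTS.md.
./zipformer/train.py \
  --bilingual 1 \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-streaming \
  --causal 1 \
  --max-duration 600
```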

The decoding command is:

```shell
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
```
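
The tables below also report `modified_beam_search` and `fast_beam_search` results. Assuming the usual icefall `decode.py` interface, those rows would be produced by swapping the decoding method, e.g. (a sketch; `--beam-size 4` is the common icefall default, not a value stated here):

```shell
# Assumed variant of the command above.
./zipformer/decode.py \
  --epoch 28 \
  --avg 15 \
  --exp-dir ./zipformer/exp \
  --max-duration 600 \
  --decoding-method modified_beam_search \
  --beam-size 4
```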

To export the model to ONNX:

```shell
./zipformer/export-onnx.py \
  --tokens data/lang_bbpe_2000/tokens.txt \
  --use-averaged-model 0 \
  --epoch 35 \
  --avg 1 \
  --exp-dir zipformer/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal False \
  --chunk-size "16,32,64,-1" \
  --left-context-frames "64,128,256,-1" \
  --fp16 True
```
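
The commit message notes that ONNX decoding was also added. Assuming it follows the `onnx_pretrained.py` convention used by other icefall Zipformer recipes — the script name and model filenames below are assumptions, not confirmed by this page — inference with the exported model might look like:

```shell
# Hypothetical invocation; script and filenames assumed from
# other icefall recipes, not taken from this commit.
./zipformer/onnx_pretrained.py \
  --encoder-model-filename zipformer/exp/encoder-epoch-35-avg-1.onnx \
  --decoder-model-filename zipformer/exp/decoder-epoch-35-avg-1.onnx \
  --joiner-model-filename zipformer/exp/joiner-epoch-35-avg-1.onnx \
  --tokens data/lang_bbpe_2000/tokens.txt \
  test.wav
```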
Zipformer Word Error Rates (WERs, %) are listed below:

| Decoding method      | ReazonSpeech dev | ReazonSpeech test | LibriSpeech test-clean | LibriSpeech test-other |
|----------------------|-----------------:|------------------:|-----------------------:|-----------------------:|
| greedy_search        | 5.9              | 4.07              | 3.46                   | 8.35                   |
| modified_beam_search | 4.87             | 3.61              | 3.28                   | 8.07                   |
| fast_beam_search     | 41.04            | 36.59             | 16.14                  | 22.0                   |


Character Error Rates (CERs) for Japanese are listed below:

| Decoding Method | In-Distribution CER | JSUT | CommonVoice | TEDx |
| :------------------: | :-----------------: | :--: | :---------: | :---: |
| greedy search | 12.56 | 6.93 | 9.75 | 9.67 |
| modified beam search | 11.59 | 6.97 | 9.55 | 9.51 |

The pre-trained model can be found at: https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/main

146 changes: 146 additions & 0 deletions egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import os
from pathlib import Path
from typing import List, Tuple

import torch

# fmt: off
from lhotse import (  # See the following for why LilcomChunkyWriter is preferred:
    # https://github.com/k2-fsa/icefall/pull/404
    # https://github.com/lhotse-speech/lhotse/pull/527
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    RecordingSet,
    SupervisionSet,
)

# fmt: on

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when main() is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

RNG_SEED = 42
concat_params = {"gap": 1.0, "maxlen": 10.0}


def make_cutset_blueprints(
    manifest_dir: Path,
) -> List[Tuple[str, CutSet]]:
    cut_sets = []

    # Create test dataset
    logging.info("Creating test cuts.")
    cut_sets.append(
        (
            "test",
            CutSet.from_manifests(
                recordings=RecordingSet.from_file(
                    manifest_dir / "reazonspeech_recordings_test.jsonl.gz"
                ),
                supervisions=SupervisionSet.from_file(
                    manifest_dir / "reazonspeech_supervisions_test.jsonl.gz"
                ),
            ),
        )
    )

    # Create dev dataset
    logging.info("Creating dev cuts.")
    cut_sets.append(
        (
            "dev",
            CutSet.from_manifests(
                recordings=RecordingSet.from_file(
                    manifest_dir / "reazonspeech_recordings_dev.jsonl.gz"
                ),
                supervisions=SupervisionSet.from_file(
                    manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz"
                ),
            ),
        )
    )

    # Create train dataset
    logging.info("Creating train cuts.")
    cut_sets.append(
        (
            "train",
            CutSet.from_manifests(
                recordings=RecordingSet.from_file(
                    manifest_dir / "reazonspeech_recordings_train.jsonl.gz"
                ),
                supervisions=SupervisionSet.from_file(
                    manifest_dir / "reazonspeech_supervisions_train.jsonl.gz"
                ),
            ),
        )
    )
    return cut_sets


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-m", "--manifest-dir", type=Path)
    return parser.parse_args()


def main():
    args = get_args()

    extractor = Fbank(FbankConfig(num_mel_bins=80))
    num_jobs = min(16, os.cpu_count())

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    if (args.manifest_dir / ".reazonspeech-fbank.done").exists():
        logging.info(
            "Previous fbank computed for ReazonSpeech found. "
            f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank."
        )
        return
    else:
        cut_sets = make_cutset_blueprints(args.manifest_dir)
        for part, cut_set in cut_sets:
            logging.info(f"Processing {part}")
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                num_jobs=num_jobs,
                storage_path=(args.manifest_dir / f"feats_{part}").as_posix(),
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz")

        logging.info("All fbank computed for ReazonSpeech.")
        (args.manifest_dir / ".reazonspeech-fbank.done").touch()


if __name__ == "__main__":
    main()
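
A minimal invocation of this script, assuming the ReazonSpeech manifests live in `data/manifests` (a path assumed from the layout of other icefall recipes, not stated on this page):

```shell
# --manifest-dir must contain the
# reazonspeech_{recordings,supervisions}_{train,dev,test}.jsonl.gz files.
python local/compute_fbank_reazonspeech.py --manifest-dir data/manifests
```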
58 changes: 58 additions & 0 deletions egs/multi_ja_en/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

from lhotse import CutSet, load_manifest

ARGPARSE_DESCRIPTION = """
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
for removing short and long utterances during training.
See the function `remove_short_and_long_utt()` in
pruned_transducer_stateless5/train.py for usage.
"""


def get_parser():
    parser = argparse.ArgumentParser(
        description=ARGPARSE_DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests")

    return parser.parse_args()


def main():
    args = get_parser()

    for part in ["train", "dev"]:
        path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz"
        cuts: CutSet = load_manifest(path)

        print("\n---------------------------------\n")
        print(path.name + ":")
        cuts.describe()


if __name__ == "__main__":
    main()
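
A sketch of how this might be run after the fbank step above, assuming the same manifest directory (an assumption, not stated on this page):

```shell
# Reads the cuts written by compute_fbank_reazonspeech.py.
python local/display_manifest_statistics.py --manifest-dir data/manifests
```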
1 change: 1 addition & 0 deletions egs/multi_ja_en/ASR/local/prepare_char.py
66 changes: 66 additions & 0 deletions egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script tokenizes the training transcript by CJK characters
# and saves the result to transcript_chars.txt, which is used
# to train the BPE model later.

import argparse
import re
from pathlib import Path

from tqdm.auto import tqdm

from icefall.utils import tokenize_by_ja_char


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Output directory.
        The generated transcript_chars.txt is saved to this directory.
        """,
    )

    parser.add_argument(
        "--text",
        type=str,
        help="Training transcript.",
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    text = Path(args.text)

    assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!"

    transcript_path = lang_dir / "transcript_chars.txt"

    with open(text, "r", encoding="utf-8") as fin:
        with open(transcript_path, "w+", encoding="utf-8") as fout:
            for line in tqdm(fin):
                fout.write(tokenize_by_ja_char(line) + "\n")


if __name__ == "__main__":
    main()
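
A hedged usage sketch: the lang dir matches the `data/lang_bbpe_2000` path used by the ONNX export command in RESULTS.md, while the transcript filename is an assumption for illustration:

```shell
# --text filename is hypothetical; point it at the training transcript.
python local/prepare_for_bpe_model.py \
  --lang-dir data/lang_bbpe_2000 \
  --text data/lang_bbpe_2000/transcript_words.txt
```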
1 change: 1 addition & 0 deletions egs/multi_ja_en/ASR/local/prepare_lang.py