Skip to content

Commit

Permalink
fixed a CI test for wenetspeech (#1476)
Browse files Browse the repository at this point in the history
* Comply with issue #1149

#1149
  • Loading branch information
JinZr authored Jan 26, 2024
1 parent 1c30847 commit 37b975c
Show file tree
Hide file tree
Showing 17 changed files with 150 additions and 169 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ log "Test exporting to ONNX format"

./pruned_transducer_stateless2/export-onnx.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1

log "Export to torchscript model"

./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1

./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit-trace 1
Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
import logging
from pathlib import Path

import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -106,10 +106,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand All @@ -136,10 +136,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/pruned_transducer_stateless3/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import logging
from pathlib import Path

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
Expand All @@ -57,8 +58,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -123,10 +123,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand All @@ -153,10 +153,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
params.datatang_prob = 0

logging.info(params)
Expand Down
21 changes: 9 additions & 12 deletions egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@
from pathlib import Path
from typing import Dict, Tuple

import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder2 import Decoder
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer

from icefall.checkpoint import (
Expand All @@ -65,8 +65,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool


def get_parser():
Expand Down Expand Up @@ -123,12 +122,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -404,9 +401,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/transducer_stateless/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
Expand All @@ -47,6 +47,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -56,8 +57,7 @@

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -92,10 +92,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -192,10 +192,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
18 changes: 9 additions & 9 deletions egs/aishell/ASR/transducer_stateless_modified-2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -56,7 +57,7 @@
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -99,10 +100,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -190,10 +191,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/transducer_stateless_modified/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -55,8 +56,7 @@

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -99,10 +99,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -190,10 +190,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
20 changes: 10 additions & 10 deletions egs/aishell2/ASR/pruned_transducer_stateless5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char
--tokens ./data/lang_char/tokens.txt \
--epoch 25 \
--avg 5
Expand All @@ -48,6 +48,7 @@
import logging
from pathlib import Path

import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

Expand All @@ -57,8 +58,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -115,10 +115,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -154,10 +154,10 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
Loading

0 comments on commit 37b975c

Please sign in to comment.