Skip to content

Commit

Permalink
Comply with issue #1149
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Jan 26, 2024
1 parent c606ef5 commit b9bbdfa
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 140 deletions.
19 changes: 9 additions & 10 deletions egs/aishell/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
import logging
from pathlib import Path

import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -106,10 +106,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand All @@ -136,10 +136,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/pruned_transducer_stateless3/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import logging
from pathlib import Path

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
Expand All @@ -57,8 +58,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -123,10 +123,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand All @@ -153,10 +153,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
params.datatang_prob = 0

logging.info(params)
Expand Down
21 changes: 9 additions & 12 deletions egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@
from pathlib import Path
from typing import Dict, Tuple

import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder2 import Decoder
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer

from icefall.checkpoint import (
Expand All @@ -65,8 +65,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool


def get_parser():
Expand Down Expand Up @@ -123,12 +122,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -404,9 +401,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/transducer_stateless/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
Expand All @@ -47,6 +47,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -56,8 +57,7 @@

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -92,10 +92,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -192,10 +192,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
18 changes: 9 additions & 9 deletions egs/aishell/ASR/transducer_stateless_modified-2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -56,7 +57,7 @@
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -99,10 +100,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -190,10 +191,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 9 additions & 10 deletions egs/aishell/ASR/transducer_stateless_modified/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from conformer import Conformer
Expand All @@ -55,8 +56,7 @@

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -99,10 +99,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -190,10 +190,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
20 changes: 10 additions & 10 deletions egs/aishell2/ASR/pruned_transducer_stateless5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char
--tokens ./data/lang_char/tokens.txt \
--epoch 25 \
--avg 5
Expand All @@ -48,6 +48,7 @@
import logging
from pathlib import Path

import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

Expand All @@ -57,8 +58,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -115,10 +115,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -154,10 +154,10 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
19 changes: 8 additions & 11 deletions egs/aishell4/ASR/pruned_transducer_stateless5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import logging
from pathlib import Path

import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model

Expand All @@ -57,8 +58,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -115,13 +115,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -157,9 +154,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
Loading

0 comments on commit b9bbdfa

Please sign in to comment.