Skip to content

Commit

Permalink
add --skip-scoring to zipformer/ctc_decode.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelVesely84 committed Aug 12, 2024
1 parent 9de503f commit bc87117
Showing 1 changed file with 61 additions and 29 deletions.
90 changes: 61 additions & 29 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,13 @@ def get_parser():
""",
)

parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -455,7 +462,7 @@ def decode_one_batch(
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
return {key: hyps} # note: returns words

if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram(
Expand Down Expand Up @@ -492,27 +499,27 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
return {key: hyps}

if params.decoding_method in ["1best", "nbest"]:
if params.decoding_method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
key = "no-rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa

hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
return {key: hyps} # note: returns words (hyps were mapped through word_table above)

assert params.decoding_method in [
"nbest-rescoring",
Expand Down Expand Up @@ -646,7 +653,27 @@ def decode_dataset(
return results


def save_results(
def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Save the text produced by ASR, without scoring it.

    One transcript file is written per decoding key, under
    ``params.res_dir``.

    Args:
      params:
        Decoding configuration. ``params.res_dir`` is the output directory
        and ``params.suffix`` identifies the decoding run in the filename.
      test_set_name:
        Name of the test set; becomes part of the output filename.
      results_dict:
        Maps a decoding-method key to a list of
        ``(utt_id, ref_word_list, hyp_word_list)`` tuples
        (per the type annotation; presumably produced by
        ``decode_dataset`` — confirm against the caller).
    """
    for key, results in results_dict.items():
        # `key` must appear in the filename: without it, every iteration of
        # this loop would overwrite the same file and only the transcripts
        # of the last decoding key would survive.
        recogs_filename = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )

        # Sort for a deterministic, reproducible output order.
        results = sorted(results)
        store_transcripts(filename=recogs_filename, texts=results)

        logging.info(f"The transcripts are stored in {recogs_filename}")


def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
Expand All @@ -661,32 +688,30 @@ def save_results(

test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))

logging.info(f"Wrote detailed error stats to {errs_filename}")

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"

with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
print(f"{key}\t{val}", file=fd)

s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)

Expand Down Expand Up @@ -719,9 +744,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method

if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

if params.causal:
assert (
Expand All @@ -730,11 +755,11 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"

if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += "_use-averaged-model"

setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
Expand Down Expand Up @@ -940,12 +965,19 @@ def main():
G=G,
)

save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

logging.info("Done!")


Expand Down

0 comments on commit bc87117

Please sign in to comment.