Skip to content

Commit

Permalink
add decoding method of ctc-greedy-search in zipformer recipe (#1690)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozengwei authored Jul 14, 2024
1 parent 334beed commit d47c078
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 16 deletions.
57 changes: 41 additions & 16 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@
"""
Usage:
(1) ctc-decoding
(1) ctc-greedy-search
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--decoding-method ctc-greedy-search
(2) ctc-decoding
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -30,7 +39,7 @@
--max-duration 600 \
--decoding-method ctc-decoding
(2) 1best
(3) 1best
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -40,7 +49,7 @@
--hlg-scale 0.6 \
--decoding-method 1best
(3) nbest
(4) nbest
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -50,7 +59,7 @@
--hlg-scale 0.6 \
--decoding-method nbest
(4) nbest-rescoring
(5) nbest-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -62,7 +71,7 @@
--lm-dir data/lm \
--decoding-method nbest-rescoring
(5) whole-lattice-rescoring
(6) whole-lattice-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -74,7 +83,7 @@
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
(6) attention-decoder-rescoring-no-ngram
(7) attention-decoder-rescoring-no-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand All @@ -84,7 +93,7 @@
--max-duration 100 \
--decoding-method attention-decoder-rescoring-no-ngram
(7) attention-decoder-rescoring-with-ngram
(8) attention-decoder-rescoring-with-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
Expand Down Expand Up @@ -120,6 +129,7 @@
load_checkpoint,
)
from icefall.decode import (
ctc_greedy_search,
get_lattice,
nbest_decoding,
nbest_oracle,
Expand Down Expand Up @@ -220,26 +230,29 @@ def get_parser():
default="ctc-decoding",
help="""Decoding method.
Supported values are:
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (2) 1best. Extract the best path from the decoding lattice as the
- (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (3) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (3) nbest. Extract n paths from the decoding lattice; the path
- (4) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
- (5) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
- (6) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
you have trained an RNN LM using ./rnn_lm/train.py
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
- (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
- (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
lattice, rescore them with the attention decoder.
- (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder.
""",
)
Expand Down Expand Up @@ -381,6 +394,15 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
ctc_output = model.ctc_output(encoder_out) # (N, T, C)

if params.decoding_method == "ctc-greedy-search":
hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-greedy-search"
return {key: hyps}

supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
Expand Down Expand Up @@ -684,6 +706,7 @@ def main():
params.update(vars(args))

assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
"1best",
"nbest",
Expand Down Expand Up @@ -733,7 +756,9 @@ def main():
params.eos_id = 1
params.sos_id = 1

if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
if params.decoding_method in [
"ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
Expand Down
31 changes: 31 additions & 0 deletions icefall/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,3 +1473,34 @@ def rescore_with_rnn_lm(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa
ans[key] = best_path
return ans


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp


def ctc_greedy_search(
ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
) -> List[List[int]]:
"""CTC greedy search.
Args:
ctc_output: (batch, seq_len, vocab_size)
encoder_out_lens: (batch,)
Returns:
List[List[int]]: greedy search result
"""
batch = ctc_output.shape[0]
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps

0 comments on commit d47c078

Please sign in to comment.