-
Notifications
You must be signed in to change notification settings - Fork 309
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement recipe for Fluent Speech Commands dataset
Signed-off-by: Xinyuan Li <[email protected]>
- Loading branch information
Xinyuan Li
committed
Jan 19, 2024
1 parent
bbb03f7
commit d305c7c
Showing
35 changed files
with
6,698 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
## Fluent Speech Commands recipe | ||
|
||
This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances. | ||
|
||
Dataset Paper link: <https://paperswithcode.com/dataset/fluent-speech-commands> | ||
|
||
cd icefall/egs/fluent_speech_commands/ | ||
Training: python transducer/train.py | ||
Decoding: python transducer/decode.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" | ||
This script takes as input lang_dir and generates HLG from | ||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt | ||
- L, the lexicon, built from lang_dir/L_disambig.pt | ||
Caution: We use a lexicon that contains disambiguation symbols | ||
- G, the LM, built from data/lm/G.fst.txt | ||
The generated HLG is saved in $lang_dir/HLG.pt | ||
""" | ||
import argparse | ||
import logging | ||
from pathlib import Path | ||
|
||
import k2 | ||
import torch | ||
|
||
from icefall.lexicon import Lexicon | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--lang-dir", | ||
type=str, | ||
help="""Input and output directory. | ||
""", | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def compile_HLG(lang_dir: str) -> k2.Fsa: | ||
""" | ||
Args: | ||
lang_dir: | ||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000. | ||
Return: | ||
An FSA representing HLG. | ||
""" | ||
lexicon = Lexicon(lang_dir) | ||
max_token_id = max(lexicon.tokens) | ||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") | ||
H = k2.ctc_topo(max_token_id) | ||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) | ||
|
||
logging.info("Loading G.fst.txt") | ||
with open(lang_dir / "G.fst.txt") as f: | ||
G = k2.Fsa.from_openfst(f.read(), acceptor=False) | ||
|
||
first_token_disambig_id = lexicon.token_table["#0"] | ||
first_word_disambig_id = lexicon.word_table["#0"] | ||
|
||
L = k2.arc_sort(L) | ||
G = k2.arc_sort(G) | ||
|
||
logging.info("Intersecting L and G") | ||
LG = k2.compose(L, G) | ||
logging.info(f"LG shape: {LG.shape}") | ||
|
||
logging.info("Connecting LG") | ||
LG = k2.connect(LG) | ||
logging.info(f"LG shape after k2.connect: {LG.shape}") | ||
|
||
logging.info(type(LG.aux_labels)) | ||
logging.info("Determinizing LG") | ||
|
||
LG = k2.determinize(LG) | ||
logging.info(type(LG.aux_labels)) | ||
|
||
logging.info("Connecting LG after k2.determinize") | ||
LG = k2.connect(LG) | ||
|
||
logging.info("Removing disambiguation symbols on LG") | ||
|
||
# LG.labels[LG.labels >= first_token_disambig_id] = 0 | ||
# see https://github.com/k2-fsa/k2/pull/1140 | ||
labels = LG.labels | ||
labels[labels >= first_token_disambig_id] = 0 | ||
LG.labels = labels | ||
|
||
assert isinstance(LG.aux_labels, k2.RaggedTensor) | ||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 | ||
|
||
LG = k2.remove_epsilon(LG) | ||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") | ||
|
||
LG = k2.connect(LG) | ||
LG.aux_labels = LG.aux_labels.remove_values_eq(0) | ||
|
||
logging.info("Arc sorting LG") | ||
LG = k2.arc_sort(LG) | ||
|
||
logging.info("Composing H and LG") | ||
# CAUTION: The name of the inner_labels is fixed | ||
# to `tokens`. If you want to change it, please | ||
# also change other places in icefall that are using | ||
# it. | ||
HLG = k2.compose(H, LG, inner_labels="tokens") | ||
|
||
logging.info("Connecting LG") | ||
HLG = k2.connect(HLG) | ||
|
||
logging.info("Arc sorting LG") | ||
HLG = k2.arc_sort(HLG) | ||
logging.info(f"HLG.shape: {HLG.shape}") | ||
|
||
return HLG | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
lang_dir = Path(args.lang_dir) | ||
|
||
if (lang_dir / "HLG.pt").is_file(): | ||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping") | ||
return | ||
|
||
logging.info(f"Processing {lang_dir}") | ||
|
||
HLG = compile_HLG(lang_dir) | ||
logging.info(f"Saving HLG.pt to {lang_dir}") | ||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") | ||
|
||
|
||
if __name__ == "__main__": | ||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
|
||
logging.basicConfig(format=formatter, level=logging.INFO) | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" | ||
This file computes fbank features of the Fluent Speech Commands dataset. | ||
It looks for manifests in the directory data/manifests. | ||
The generated fbank features are saved in data/fbank. | ||
""" | ||
|
||
import logging | ||
import os, argparse | ||
from pathlib import Path | ||
|
||
import torch | ||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter | ||
from lhotse.recipes.utils import read_manifests_if_cached | ||
|
||
from icefall.utils import get_executor | ||
|
||
# Torch's multithreaded behavior needs to be disabled or it wastes a | ||
# lot of CPU and slow things down. | ||
# Do this outside of main() in case it needs to take effect | ||
# even when we are not invoking the main (e.g. when spawning subprocesses). | ||
torch.set_num_threads(1) | ||
torch.set_num_interop_threads(1) | ||
|
||
|
||
def compute_fbank_slu(manifest_dir, fbanks_dir): | ||
src_dir = Path(manifest_dir) | ||
output_dir = Path(fbanks_dir) | ||
|
||
# This dataset is rather small, so we use only one job | ||
num_jobs = min(1, os.cpu_count()) | ||
num_mel_bins = 23 | ||
|
||
dataset_parts = ( | ||
"train", | ||
"valid", | ||
"test", | ||
) | ||
prefix = "slu" | ||
suffix = "jsonl.gz" | ||
manifests = read_manifests_if_cached( | ||
dataset_parts=dataset_parts, | ||
output_dir=src_dir, | ||
prefix=prefix, | ||
suffix=suffix, | ||
) | ||
assert manifests is not None | ||
|
||
assert len(manifests) == len(dataset_parts), ( | ||
len(manifests), | ||
len(dataset_parts), | ||
list(manifests.keys()), | ||
dataset_parts, | ||
) | ||
|
||
extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins)) | ||
|
||
with get_executor() as ex: # Initialize the executor only once. | ||
for partition, m in manifests.items(): | ||
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" | ||
if cuts_file.is_file(): | ||
logging.info(f"{partition} already exists - skipping.") | ||
continue | ||
logging.info(f"Processing {partition}") | ||
cut_set = CutSet.from_manifests( | ||
recordings=m["recordings"], | ||
supervisions=m["supervisions"], | ||
) | ||
if "train" in partition: | ||
cut_set = ( | ||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) | ||
) | ||
cut_set = cut_set.compute_and_store_features( | ||
extractor=extractor, | ||
storage_path=f"{output_dir}/{prefix}_feats_{partition}", | ||
# when an executor is specified, make more partitions | ||
num_jobs=num_jobs if ex is None else 1, # use one job | ||
executor=ex, | ||
storage_type=LilcomChunkyWriter, | ||
) | ||
cut_set.to_file(cuts_file) | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('manifest_dir') | ||
parser.add_argument('fbanks_dir') | ||
|
||
if __name__ == "__main__": | ||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
args = parser.parse_args() | ||
|
||
logging.basicConfig(format=formatter, level=logging.INFO) | ||
|
||
compute_fbank_slu(args.manifest_dir, args.fbanks_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pandas, argparse | ||
from tqdm import tqdm | ||
|
||
def generate_lexicon(corpus_dir, lm_dir): | ||
data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0) | ||
vocab_transcript = set() | ||
vocab_frames = set() | ||
transcripts = data['transcription'].tolist() | ||
frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist())) | ||
|
||
for transcript in tqdm(transcripts): | ||
for word in transcript.split(): | ||
vocab_transcript.add(word) | ||
|
||
for frame in tqdm(frames): | ||
for word in frame: | ||
vocab_frames.add('_'.join(word.split())) | ||
|
||
with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file: | ||
lexicon_transcript_file.write("<UNK> 1" + '\n') | ||
lexicon_transcript_file.write("<s> 2" + '\n') | ||
lexicon_transcript_file.write("</s> 0" + '\n') | ||
id = 3 | ||
for vocab in vocab_transcript: | ||
lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n') | ||
id += 1 | ||
|
||
with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file: | ||
lexicon_frames_file.write("<UNK> 1" + '\n') | ||
lexicon_frames_file.write("<s> 2" + '\n') | ||
lexicon_frames_file.write("</s> 0" + '\n') | ||
id = 3 | ||
for vocab in vocab_frames: | ||
lexicon_frames_file.write(vocab + ' ' + str(id) + '\n') | ||
id += 1 | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('corpus_dir') | ||
parser.add_argument('lm_dir') | ||
|
||
def main(): | ||
args = parser.parse_args() | ||
|
||
generate_lexicon(args.corpus_dir, args.lm_dir) | ||
|
||
main() |
Oops, something went wrong.