Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement recipe for Fluent Speech Commands dataset #1469

Merged
merged 9 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions egs/fluent_speech_commands/SLU/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## Fluent Speech Commands recipe

This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances.

Dataset Paper link: <https://paperswithcode.com/dataset/fluent-speech-commands>

cd icefall/egs/fluent_speech_commands/
Training: python transducer/train.py
Decoding: python transducer/decode.py
136 changes: 136 additions & 0 deletions egs/fluent_speech_commands/SLU/local/compile_hlg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

"""
This script takes as input lang_dir and generates HLG from

- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt

Caution: We use a lexicon that contains disambiguation symbols

- G, the LM, built from data/lm/G.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)

return parser.parse_args()


def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

logging.info("Loading G.fst.txt")
with open(lang_dir / "G.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)

first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

L = k2.arc_sort(L)
G = k2.arc_sort(G)

logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")

logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")

logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")

LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))

logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)

logging.info("Removing disambiguation symbols on LG")

# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels

assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)

logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)

logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")

logging.info("Connecting LG")
HLG = k2.connect(HLG)

logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")

return HLG


def main():
args = get_args()
lang_dir = Path(args.lang_dir)

if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return

logging.info(f"Processing {lang_dir}")

HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

main()
97 changes: 97 additions & 0 deletions egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python3

"""
This file computes fbank features of the Fluent Speech Commands dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or it wastes a
# lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_slu(manifest_dir, fbanks_dir):
src_dir = Path(manifest_dir)
output_dir = Path(fbanks_dir)

# This dataset is rather small, so we use only one job
num_jobs = min(1, os.cpu_count())
num_mel_bins = 23

dataset_parts = (
"train",
"valid",
"test",
)
prefix = "slu"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None

assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)

extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
if cuts_file.is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 1, # use one job
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(cuts_file)


parser = argparse.ArgumentParser()
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
args = parser.parse_args()

logging.basicConfig(format=formatter, level=logging.INFO)

compute_fbank_slu(args.manifest_dir, args.fbanks_dir)
59 changes: 59 additions & 0 deletions egs/fluent_speech_commands/SLU/local/generate_lexicon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import argparse

import pandas
from tqdm import tqdm


def generate_lexicon(corpus_dir, lm_dir):
data = pandas.read_csv(
str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
)
vocab_transcript = set()
vocab_frames = set()
transcripts = data["transcription"].tolist()
frames = list(
i
for i in zip(
data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
)
)

for transcript in tqdm(transcripts):
for word in transcript.split():
vocab_transcript.add(word)

for frame in tqdm(frames):
for word in frame:
vocab_frames.add("_".join(word.split()))

with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + "\n")
lexicon_transcript_file.write("<s> 2" + "\n")
lexicon_transcript_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_transcript:
lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
id += 1

with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + "\n")
lexicon_frames_file.write("<s> 2" + "\n")
lexicon_frames_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_frames:
lexicon_frames_file.write(vocab + " " + str(id) + "\n")
id += 1


parser = argparse.ArgumentParser()
parser.add_argument("corpus_dir")
parser.add_argument("lm_dir")


def main():
args = parser.parse_args()

generate_lexicon(args.corpus_dir, args.lm_dir)


main()
Loading
Loading