Skip to content

Commit

Permalink
update prepare.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Oct 23, 2024
1 parent c920735 commit 84f8adf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 33 deletions.
35 changes: 5 additions & 30 deletions egs/librilight/SSL/local/extract_kmeans_from_hubert_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from silero_vad import get_speech_timestamps, load_silero_vad
from tqdm import tqdm

# Torch's multithreaded behavior needs to be disabled or
Expand Down Expand Up @@ -82,7 +81,7 @@ def get_args():
parser.add_argument(
"--kmeans-model-path",
type=str,
default="download/hubert_base_ls960_L9_km500.model",
default="download/hubert_base_ls960_L9_km500.bin",
)

parser.add_argument(
Expand All @@ -103,28 +102,19 @@ def get_args():


def extract_and_save_one_cuts(
raw_cuts_path, cuts_path, model, vad_model, apply_kmeans, do_normalize, device
raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device
):
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)

logging.info("Extracting kmeans")
cuts = []
for cut in tqdm(cut_set):
assert cut.sampling_rate == 16000, f"{cut.sampling_rate}"
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
audio = cut.load_audio()

if audio.shape[-1] > 64 * 16000:
timestamps = get_speech_timestamps(audio, vad_model)
offsets = [i["start"] for i in timestamps]
audios = [audio[:, i["start"] : i["end"]] for i in timestamps]
logging.info(f"Trim audio {cut.id} into {len(audios)} segments")
else:
offsets = [0]
audios = [audio]

seq = 0
for audio, offset in zip(audios, offsets):
offsets = 0
if True:
x = torch.from_numpy(audio).float().to(device)

with torch.no_grad():
Expand All @@ -141,24 +131,12 @@ def extract_and_save_one_cuts(

kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))

supervision_segment = fastcopy(
cut.supervisions[0],
id=f"{cut.id}-{seq}",
start=0.0,
duration=audio.shape[-1] / 16000,
)
cut_with_kmeans = fastcopy(
cut,
id=f"{cut.id}-{seq}",
start=cut.start + offset / 16000,
duration=audio.shape[-1] / 16000,
supervisions=[supervision_segment],
custom={"kmeans": kmeans},
)
cuts.append(cut_with_kmeans)

seq += 1

cuts = CutSet(cuts)

logging.info(f"Saving to {cuts_path}")
Expand All @@ -181,7 +159,6 @@ def extract_kmeans(args):

prefix = "librilight"

vad_model = load_silero_vad()
apply_kmeans = ApplyKmeans(args.kmeans_model_path)
model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[args.model_path]
Expand All @@ -204,7 +181,6 @@ def extract_kmeans(args):
raw_cuts_path,
cuts_path,
model,
vad_model,
apply_kmeans,
do_normalize,
device,
Expand Down Expand Up @@ -235,7 +211,6 @@ def extract_kmeans(args):
raw_cuts_path,
cuts_path,
model,
vad_model,
apply_kmeans,
do_normalize,
device,
Expand Down
12 changes: 9 additions & 3 deletions egs/librilight/SSL/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=15
nj=32
# run step 0 to step 4 by default
stage=0
stop_stage=4
Expand Down Expand Up @@ -58,13 +58,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
mkdir -p data/kmeans
if [ ! -f data/kmeans/.preprocess_complete ]; then
python3 ./local/preprocess_librilight.py
touch data/fbank/.preprocess_complete
touch data/kmeans/.preprocess_complete
fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split medium and large subset into pieces"
num_per_split=200000
num_per_split=2500
split_dir=data/kmeans/medium_split
if [ ! -f $split_dir/.split_completed ]; then
lhotse split-lazy ./data/kmeans/librilight_cuts_medium_raw.jsonl.gz $split_dir $num_per_split
Expand All @@ -79,6 +79,12 @@ fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Extract SSL target for librilight"
if [ ! -e download/hubert_base_ls960.pt ]; then
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt -P download
fi
if [ ! -e download/hubert_base_ls960_L9_km500.bin ]; then
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
fi
if [ ! -e data/kmeans/.extract_small.done ]; then
./local/extract_kmeans_from_hubert_base.py --subset small
touch data/kmeans/.extract_small.done
Expand Down

0 comments on commit 84f8adf

Please sign in to comment.