From 1adf1e441d7ad49d5d4a96246a28aa9e12d6f967 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Tue, 21 May 2024 18:22:19 +0800
Subject: [PATCH] Removed unused ``k2`` dependencies from the AT recipe (#1633)

---
 egs/audioset/AT/zipformer/at_datamodule.py   |  8 ++++---
 egs/audioset/AT/zipformer/evaluate.py        | 25 ++++----------------
 egs/audioset/AT/zipformer/export-onnx.py     |  3 +--
 egs/audioset/AT/zipformer/jit_pretrained.py  |  1 -
 egs/audioset/AT/zipformer/model.py           |  8 ++-----
 egs/audioset/AT/zipformer/onnx_pretrained.py |  3 +--
 egs/audioset/AT/zipformer/train.py           |  5 ++--
 7 files changed, 15 insertions(+), 38 deletions(-)

diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py
index 66497c1ca6..ac8671fa61 100644
--- a/egs/audioset/AT/zipformer/at_datamodule.py
+++ b/egs/audioset/AT/zipformer/at_datamodule.py
@@ -373,9 +373,11 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = AudioTaggingDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else eval(self.args.input_strategy)(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else eval(self.args.input_strategy)()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py
index b52a284d04..0a1b8ea5fd 100644
--- a/egs/audioset/AT/zipformer/evaluate.py
+++ b/egs/audioset/AT/zipformer/evaluate.py
@@ -29,27 +29,18 @@
 """
 
 import argparse
-import csv
 import logging
-import math
-import os
-from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict
 
-import k2
-import numpy as np
-import sentencepiece as spm
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from at_datamodule import AudioSetATDatamodule
-from lhotse import load_manifest
 
 try:
     from sklearn.metrics import average_precision_score
-except Exception as ex:
-    raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
+except:
+    raise ImportError(f"Please run\n" "pip3 install -U scikit-learn")
 
 from train import add_model_arguments, get_model, get_params, str2multihot
 from icefall.checkpoint import (
@@ -58,15 +49,7 @@
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    make_pad_mask,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-    write_error_stats,
-)
+from icefall.utils import AttributeDict, setup_logger, str2bool
 
 
 def get_parser():
diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py
index 24b7717b45..2b0ec8b4b7 100755
--- a/egs/audioset/AT/zipformer/export-onnx.py
+++ b/egs/audioset/AT/zipformer/export-onnx.py
@@ -36,7 +36,6 @@
 from pathlib import Path
 from typing import Dict
 
-import k2
 import onnx
 import onnxoptimizer
 import torch
@@ -53,7 +52,7 @@
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import make_pad_mask, num_tokens, str2bool
+from icefall.utils import make_pad_mask, str2bool
 
 
 def get_parser():
diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py
index 403308fcfb..d376aa1486 100755
--- a/egs/audioset/AT/zipformer/jit_pretrained.py
+++ b/egs/audioset/AT/zipformer/jit_pretrained.py
@@ -50,7 +50,6 @@
 import math
 from typing import List
 
-import k2
 import kaldifeat
 import torch
 import torchaudio
diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py
index f189eac622..fb8e2dd855 100644
--- a/egs/audioset/AT/zipformer/model.py
+++ b/egs/audioset/AT/zipformer/model.py
@@ -14,17 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-import random
-from typing import List, Optional, Tuple
+from typing import Tuple
 
-import k2
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 
-from icefall.utils import AttributeDict, make_pad_mask
+from icefall.utils import make_pad_mask
 
 
 class AudioTaggingModel(nn.Module):
diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py
index 82fa3d45b6..8de60bbb5d 100755
--- a/egs/audioset/AT/zipformer/onnx_pretrained.py
+++ b/egs/audioset/AT/zipformer/onnx_pretrained.py
@@ -42,9 +42,8 @@
 import csv
 import logging
 import math
-from typing import List, Tuple
+from typing import List
 
-import k2
 import kaldifeat
 import onnxruntime as ort
 import torch
diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py
index 0e234c59f5..2d193030a8 100644
--- a/egs/audioset/AT/zipformer/train.py
+++ b/egs/audioset/AT/zipformer/train.py
@@ -41,7 +41,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -632,7 +631,7 @@ def compute_loss(
       model:
         The model for training. It is an instance of Zipformer in our case.
       batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
         for the content in it.
       is_training:
         True for training. False for validation. When it is True, this
@@ -1108,7 +1107,7 @@ def display_and_save_batch(
     Args:
       batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
         for the content in it.
       params:
         Parameters for training. See :func:`get_params`.