Skip to content

Commit

Permalink
Removed unused k2 dependencies from the AT recipe (#1633)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr authored May 21, 2024
1 parent 0df406c commit 1adf1e4
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 38 deletions.
8 changes: 5 additions & 3 deletions egs/audioset/AT/zipformer/at_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,11 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = AudioTaggingDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
Expand Down
25 changes: 4 additions & 21 deletions egs/audioset/AT/zipformer/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,18 @@
"""

import argparse
import csv
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict

import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from at_datamodule import AudioSetATDatamodule
from lhotse import load_manifest

try:
from sklearn.metrics import average_precision_score
except Exception as ex:
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
except:
raise ImportError(f"Please run\n" "pip3 install -U scikit-learn")
from train import add_model_arguments, get_model, get_params, str2multihot

from icefall.checkpoint import (
Expand All @@ -58,15 +49,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.utils import AttributeDict, setup_logger, str2bool


def get_parser():
Expand Down
3 changes: 1 addition & 2 deletions egs/audioset/AT/zipformer/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from pathlib import Path
from typing import Dict

import k2
import onnx
import onnxoptimizer
import torch
Expand All @@ -53,7 +52,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
from icefall.utils import make_pad_mask, str2bool


def get_parser():
Expand Down
1 change: 0 additions & 1 deletion egs/audioset/AT/zipformer/jit_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
Expand Down
8 changes: 2 additions & 6 deletions egs/audioset/AT/zipformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from typing import List, Optional, Tuple
from typing import Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface

from icefall.utils import AttributeDict, make_pad_mask
from icefall.utils import make_pad_mask


class AudioTaggingModel(nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions egs/audioset/AT/zipformer/onnx_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@
import csv
import logging
import math
from typing import List, Tuple
from typing import List

import k2
import kaldifeat
import onnxruntime as ort
import torch
Expand Down
5 changes: 2 additions & 3 deletions egs/audioset/AT/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
Expand Down Expand Up @@ -632,7 +631,7 @@ def compute_loss(
model:
The model for training. It is an instance of Zipformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
Expand Down Expand Up @@ -1108,7 +1107,7 @@ def display_and_save_batch(
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
Expand Down

0 comments on commit 1adf1e4

Please sign in to comment.