feat: introduce AudioFeatureTransformer and audio_service #30
base: main
Changes from 9 commits: a726e95, f81bf1a, c736c7b, 6ab2e2b, 5d16808, dd1f417, 6682541, 3ca5e55, 32e4081, a676fad
New file: AudioFeatureTransformerPipeline
@@ -0,0 +1,11 @@
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer

class AudioFeatureTransformerPipeline(AudioFeatureTransformer):
    def __init__(self, transformers: list[AudioFeatureTransformer]):
        self.transformers = transformers

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        for transformer in self.transformers:
            audios = transformer.transform_into_features(audios)
        return audios
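For reviewers, a rough sketch of how the pipeline could be wired up from the transformers added below. The ordering, the `audio_processor` object, and the `window_length`/`step_size` values are placeholders, not taken from this PR:

# One possible composition (hypothetical values): pad the collection first,
# then apply the overlapped (sliding-window) transformer.
pipeline = AudioFeatureTransformerPipeline(
    transformers=[
        PaddedAudioFeatureTransformer(audio_processor=audio_processor),
        OverlappedAudioFeatureTransformer(
            audio_processor=audio_processor, window_length=4, step_size=2
        ),
    ]
)
features = pipeline.transform_into_features(audios)  # audios: AudioCollection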
New file: OverlappedAudioFeatureTransformer
@@ -0,0 +1,27 @@
from src.spira_training.shared.core.audio_service import create_slices_from_audio, concatenate_audios
from src.spira_training.shared.core.models.audio import Audio
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer
from src.spira_training.shared.core.audio_processor import AudioProcessor


class OverlappedAudioFeatureTransformer(AudioFeatureTransformer):
    def __init__(self, audio_processor: AudioProcessor, window_length: int, step_size: int):
        self.audio_processor = audio_processor
        self.window_length = window_length
        self.step_size = step_size

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        return self._overlap_audio_collection(audios)

    def _overlap_audio_collection(self, audios: AudioCollection) -> AudioCollection:
        return AudioCollection(
            [self._overlap_audio(audio) for audio in audios]
        )

    def _overlap_audio(self, audio: Audio) -> Audio:
        audio_slices = create_slices_from_audio(audio, self.window_length, self.step_size)
        processed_audios = self.audio_processor.process_audios(audio_slices)

        return concatenate_audios(processed_audios)
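To make the windowing arithmetic concrete (assumed numbers; the unit is whole seconds, as defined by Audio.__len__ further down):

# With window_length=4 and step_size=2 on a 10-second audio, the start indices
# are 0, 2, 4, 6, 8, giving the second-ranges [0:4], [2:6], [4:8], [6:10], [8:10].
# Each slice is processed by the AudioProcessor and the results are concatenated
# back into a single Audio by concatenate_audios.
slices = create_slices_from_audio(audio, window_length=4, step_size=2)  # audio: a 10-second Audio (placeholder)
assert len(slices) == 5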
New file: PaddedAudioFeatureTransformer
@@ -0,0 +1,14 @@
from src.spira_training.shared.core.audio_processor import AudioProcessor
from src.spira_training.shared.core.audio_service import add_padding_to_audio_collection
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer


class PaddedAudioFeatureTransformer(AudioFeatureTransformer):
    def __init__(self, audio_processor: AudioProcessor):
        self.audio_processor = audio_processor

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        processed_audios = self.audio_processor.process_audios(audios)

        return add_padding_to_audio_collection(processed_audios)
Changed file: AudioProcessor (audio_processor)
@@ -1,30 +1,24 @@
-from spira_training.shared.core.models.wav import Wav
-from src.spira_training.shared.adapters.pytorch.model_trainer.interfaces.pytorch_audio_factory import (
-    PytorchTensorFactory,
-)
 from src.spira_training.shared.core.models.audio_collection import AudioCollection
 from src.spira_training.shared.core.models.audio import Audio
+from src.spira_training.shared.core.models.generated_audio_collection import GeneratedAudioCollection
 from src.spira_training.shared.ports.feature_transformer import FeatureTransformer


 class AudioProcessor:
     def __init__(
         self,
         feature_transformer: FeatureTransformer,
-        pytorch_tensor_factory: PytorchTensorFactory,
     ):
         self.feature_transformer = feature_transformer
-        self.pytorch_tensor_factory = pytorch_tensor_factory

     def process_audio(self, audio: Audio) -> Audio:
-        pytorch_tensor = self.pytorch_tensor_factory.create_tensor_from_audio(audio)
-        feature_wav = self.feature_transformer.transform(Wav(pytorch_tensor))
+        feature_wav = self.feature_transformer.transform(audio.wav)
         transposed_feature_wav = feature_wav.transpose(1, 2)
         reshaped_feature_wav = transposed_feature_wav.reshape(
             transposed_feature_wav.shape[1:]
         )
         return Audio(wav=reshaped_feature_wav, sample_rate=audio.sample_rate)

-    def process_audios(self, audios: AudioCollection) -> AudioCollection:
+    def process_audios(self, audios: AudioCollection | GeneratedAudioCollection) -> AudioCollection | GeneratedAudioCollection:
         audio_list = [self.process_audio(audio) for audio in audios]
-        return AudioCollection(audios=audio_list, hop_length=audios.hop_length)
+        return audios.copy_using(audios=audio_list)
Review comment: It's not a class, right? Why not?
Author reply: I wanted to keep it simple and reduce the overhead of managing instance variables and dependencies that making it a class would require.
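For comparison, the class-based alternative being discussed would look roughly like the sketch below. This is illustrative only, not code from this PR; it assumes the slicing parameters would become constructor state of a hypothetical AudioService wrapper around the audio_service functions introduced below:

class AudioService:
    # Hypothetical wrapper: same operations as the module-level functions,
    # but with window_length/step_size carried as instance state.
    def __init__(self, window_length: int, step_size: int):
        self.window_length = window_length
        self.step_size = step_size

    def create_slices_from_audio(self, audio: Audio) -> GeneratedAudioCollection:
        return create_slices_from_audio(audio, self.window_length, self.step_size)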
New file: audio_service
@@ -0,0 +1,57 @@
from src.spira_training.shared.core.models.audio import Audio
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.core.models.generated_audio_collection import GeneratedAudioCollection
from src.spira_training.shared.core.models.wav import create_empty_wav, concatenate_wavs

def create_slices_from_audio(audio: Audio, window_length: int, step_size: int) -> GeneratedAudioCollection:
    slices = []
    slice_index = 0

    while slice_index < len(audio):
        start = slice_index
        end = slice_index + window_length

        if end > len(audio):
            end = len(audio)

        slices.append(_create_slice(audio, start, end))
        slice_index += step_size

    return GeneratedAudioCollection(generated_audios=slices)

def _create_slice(audio: Audio, start_index: int, end_index: int) -> Audio:
    if start_index < 0 or end_index < 0 or start_index >= end_index:
        raise ValueError(f"Invalid range [{start_index}:{end_index}]")

    return Audio(
        wav=audio.wav.slice(
            # Audios are indexed in sample_rate chunks
            start_index=start_index * audio.sample_rate,
            end_index=end_index * audio.sample_rate,
        ),
        # The slice keeps the source audio's sample rate
        sample_rate=audio.sample_rate,
    )

def concatenate_audios(audios: AudioCollection | GeneratedAudioCollection) -> Audio:
    if len(audios) == 0:
        return Audio(wav=create_empty_wav(), sample_rate=0)

    if _check_audios_have_different_sample_rate(audios):
        raise ValueError("Sample rates are not equal")

    wav_list = [audio.wav for audio in audios]
    concatenated_wav = concatenate_wavs(wav_list)

    return Audio(concatenated_wav, sample_rate=audios[0].sample_rate)

def _check_audios_have_different_sample_rate(audios: GeneratedAudioCollection) -> bool:
    sample_rates = {audio.sample_rate for audio in audios}
    return len(sample_rates) > 1


def add_padding_to_audio_collection(audios: AudioCollection) -> AudioCollection:
    max_audio_length = audios.get_max_audio_length()

    return AudioCollection(
        [audio.add_padding(max_audio_length) for audio in audios],
        audios.hop_length
    )
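A brief illustration of the service's contract (placeholder variable names, not part of the diff):

merged = concatenate_audios(slices)  # slices from create_slices_from_audio; raises ValueError on mixed sample rates
empty = concatenate_audios(GeneratedAudioCollection(generated_audios=[]))  # empty input yields an empty Audio with sample_rate=0
padded = add_padding_to_audio_collection(audio_collection)  # every Audio padded to audio_collection.get_max_audio_length()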
Changed file: Audio (models/audio) — the listing below shows the resulting file
@@ -1,7 +1,15 @@
import math

from src.spira_training.shared.core.models.wav import Wav


class Audio:
    def __init__(self, wav: Wav = None, sample_rate: int = 0):
        self.wav = wav
        self.sample_rate = sample_rate

    def __len__(self):
        return math.ceil(len(self.wav) / self.sample_rate)

    def add_padding(self, max_audio_length):
        return Audio(wav=self.wav.add_padding(max_audio_length), sample_rate=self.sample_rate)
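Worth noting for reviewers: __len__ reports the length in whole seconds (ceiling of sample count over sample rate), which is the unit the second-based slicing in audio_service relies on. With assumed numbers:

audio = Audio(wav=some_wav, sample_rate=44_100)  # some_wav: a placeholder Wav of 110_250 samples
assert len(audio) == 3  # math.ceil(110_250 / 44_100) == 3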
New file: GeneratedAudio
@@ -0,0 +1,6 @@
from src.spira_training.shared.core.models.wav import Wav


class GeneratedAudio:
    def __init__(self, wav: Wav):
        self.wav = wav
New file: GeneratedAudioCollection
@@ -0,0 +1,12 @@
from src.spira_training.shared.core.models.audio import Audio


class GeneratedAudioCollection:
    def __init__(self, generated_audios: list['Audio']):
        self.generated_audios = generated_audios

    def __len__(self) -> int:
        return len(self.generated_audios)

    def copy_using(self, audios: list['Audio']) -> 'GeneratedAudioCollection':
        return GeneratedAudioCollection(audios)
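One observation: the call sites in this PR iterate over the collection (AudioProcessor.process_audios does `for audio in audios`) and index into it (concatenate_audios uses `audios[0]`), neither of which this class defines. A sketch of the dunder methods that would support that, assuming they are not provided elsewhere:

    def __iter__(self):
        # Allows `for audio in collection`, as process_audios expects.
        return iter(self.generated_audios)

    def __getitem__(self, index: int) -> 'Audio':
        # Allows `collection[0]`, as concatenate_audios expects.
        return self.generated_audios[index]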
New file: AudioFeatureTransformer port
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

from src.spira_training.shared.core.models.audio_collection import AudioCollection


class AudioFeatureTransformer(ABC):
    @abstractmethod
    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        pass
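A minimal way to satisfy the port, for example as a stand-in for tests (a sketch, not part of this PR):

class IdentityAudioFeatureTransformer(AudioFeatureTransformer):
    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        # No-op implementation: hand the collection back unchanged.
        return audios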
Review comment: nice