feat: introduce AudioFeatureTransformer and audio_service #30

Draft
wants to merge 10 commits into base: main
@@ -10,7 +10,6 @@
 from tests.fakes.fake_audios_repository import FakeAudiosRepository
 from tests.fakes.fake_path_validator import FakePathValidator
 from tests.fakes.fake_file_reader import FakeFileReader
-from tests.fakes.fake_pytorch_audio_factory import FakePytorchTensorFactory


 async def main():
@@ -22,7 +21,6 @@ async def main():
     audios_repository = FakeAudiosRepository()
     file_reader = FakeFileReader()
     path_validator = FakePathValidator()
-    pytorch_audio_factory = FakePytorchTensorFactory()

     service = FeatureEngineeringService(
         config=config,
@@ -31,7 +29,6 @@
         audios_repository=audios_repository,
         file_reader=file_reader,
         path_validator=path_validator,
-        pytorch_audio_factory=pytorch_audio_factory,
     )

     # TODO - Get the bucket name to save the dataset
Contributor: nice
@@ -0,0 +1,11 @@
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer

class AudioFeatureTransformerPipeline(AudioFeatureTransformer):
    def __init__(self, transformers: list[AudioFeatureTransformer]):
        self.transformers = transformers

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        for transformer in self.transformers:
            audios = transformer.transform_into_features(audios)
        return audios
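For orientation, a minimal wiring sketch. This is illustrative, not taken from the PR's call sites: the constructor arguments, `config`, and `audios` are assumed, and imports of the two concrete transformers are omitted because their module paths are not shown in this diff. Since each transformer's output feeds the next, the list order is significant.

audio_processor = create_audio_processor(config.audio_processor)

pipeline = AudioFeatureTransformerPipeline(
    transformers=[
        # Windowed processing first, then pad everything to a common length
        OverlappedAudioFeatureTransformer(audio_processor, window_length=4, step_size=2),
        PaddedAudioFeatureTransformer(audio_processor),
    ]
)
features = pipeline.transform_into_features(audios)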
@@ -0,0 +1,27 @@
from src.spira_training.shared.core.audio_service import create_slices_from_audio, concatenate_audios
from src.spira_training.shared.core.models.audio import Audio
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer
from src.spira_training.shared.core.audio_processor import AudioProcessor


class OverlappedAudioFeatureTransformer(AudioFeatureTransformer):
    def __init__(self, audio_processor: AudioProcessor, window_length: int, step_size: int):
        self.audio_processor = audio_processor
        self.window_length = window_length
        self.step_size = step_size

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        return self._overlap_audio_collection(audios)

    def _overlap_audio_collection(self, audios: AudioCollection) -> AudioCollection:
        # copy_using preserves hop_length; a bare AudioCollection(...) would miss it
        return audios.copy_using(
            [self._overlap_audio(audio) for audio in audios]
        )

    def _overlap_audio(self, audio: Audio) -> Audio:
        audio_slices = create_slices_from_audio(audio, self.window_length, self.step_size)
        processed_audios = self.audio_processor.process_audios(audio_slices)

        return concatenate_audios(processed_audios)
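Windows overlap whenever step_size < window_length. A sketch of the behaviour, with illustrative numbers (`audio_processor` and `audios` are assumed to exist):

# With window_length=4 and step_size=2, a 10-second audio is sliced into
# windows [0:4], [2:6], [4:8], [6:10], [8:10]; each window is processed
# separately and the processed windows are concatenated, so overlapped
# regions appear twice in the output.
transformer = OverlappedAudioFeatureTransformer(
    audio_processor, window_length=4, step_size=2
)
overlapped = transformer.transform_into_features(audios)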
@@ -0,0 +1,14 @@
from src.spira_training.shared.core.audio_processor import AudioProcessor
from src.spira_training.shared.core.audio_service import add_padding_to_audio_collection
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.ports.audio_feature_transformer import AudioFeatureTransformer


class PaddedAudioFeatureTransformer(AudioFeatureTransformer):
    def __init__(self, audio_processor: AudioProcessor):
        self.audio_processor = audio_processor

    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        processed_audios = self.audio_processor.process_audios(audios)

        return add_padding_to_audio_collection(processed_audios)
14 changes: 4 additions & 10 deletions src/spira_training/shared/core/audio_processor.py
@@ -1,30 +1,24 @@
-from spira_training.shared.core.models.wav import Wav
-from src.spira_training.shared.adapters.pytorch.model_trainer.interfaces.pytorch_audio_factory import (
-    PytorchTensorFactory,
-)
 from src.spira_training.shared.core.models.audio_collection import AudioCollection
 from src.spira_training.shared.core.models.audio import Audio
+from src.spira_training.shared.core.models.generated_audio_collection import GeneratedAudioCollection
 from src.spira_training.shared.ports.feature_transformer import FeatureTransformer


 class AudioProcessor:
     def __init__(
         self,
         feature_transformer: FeatureTransformer,
-        pytorch_tensor_factory: PytorchTensorFactory,
     ):
         self.feature_transformer = feature_transformer
-        self.pytorch_tensor_factory = pytorch_tensor_factory

     def process_audio(self, audio: Audio) -> Audio:
-        pytorch_tensor = self.pytorch_tensor_factory.create_tensor_from_audio(audio)
-        feature_wav = self.feature_transformer.transform(Wav(pytorch_tensor))
+        feature_wav = self.feature_transformer.transform(audio.wav)
         transposed_feature_wav = feature_wav.transpose(1, 2)
         reshaped_feature_wav = transposed_feature_wav.reshape(
             transposed_feature_wav.shape[1:]
         )
         return Audio(wav=reshaped_feature_wav, sample_rate=audio.sample_rate)

-    def process_audios(self, audios: AudioCollection) -> AudioCollection:
+    def process_audios(self, audios: AudioCollection | GeneratedAudioCollection) -> AudioCollection | GeneratedAudioCollection:
         audio_list = [self.process_audio(audio) for audio in audios]
-        return AudioCollection(audios=audio_list, hop_length=audios.hop_length)
+        return audios.copy_using(audios=audio_list)
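A quick sketch of what the `copy_using` change buys: `process_audios` now returns the same collection type it receives, keeping `hop_length` intact for plain collections. `make_audio()` below is a hypothetical test helper, not part of this PR:

collection = AudioCollection(audios=[make_audio()], hop_length=160)
processed = audio_processor.process_audios(collection)
assert isinstance(processed, AudioCollection)
assert processed.hop_length == collection.hop_length  # hop_length survives the copy

generated = GeneratedAudioCollection(generated_audios=[make_audio()])
assert isinstance(audio_processor.process_audios(generated), GeneratedAudioCollection)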
6 changes: 1 addition & 5 deletions src/spira_training/shared/core/audio_processor_factory.py
@@ -1,6 +1,3 @@
-from src.spira_training.shared.adapters.pytorch.model_trainer.interfaces.pytorch_audio_factory import (
-    PytorchTensorFactory,
-)
 from src.spira_training.apps.feature_engineering.configs.audio_processor_config import (
     AudioProcessorType,
     AudioProcessorConfig,
@@ -19,13 +16,12 @@


 def create_audio_processor(
-    config: AudioProcessorConfig, pytorch_tensor_factory: PytorchTensorFactory
+    config: AudioProcessorConfig
 ) -> AudioProcessor:
     feature_transformer = create_feature_transformer(config)

     return AudioProcessor(
         feature_transformer=feature_transformer,
-        pytorch_tensor_factory=pytorch_tensor_factory,
     )
57 changes: 57 additions & 0 deletions src/spira_training/shared/core/audio_service.py
Contributor: It's not a class, right? Why not?

Contributor (Author): I wanted to keep it simple and reduce the overhead of managing instance variables and dependencies that making it a class would require.
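To make that trade-off concrete, the class-based alternative would look roughly like this sketch (not code from this PR), carrying configuration as instance state that every call site would have to construct first:

class AudioSlicer:
    def __init__(self, window_length: int, step_size: int):
        # State that the module-level functions receive as plain arguments instead
        self.window_length = window_length
        self.step_size = step_size

    def create_slices_from_audio(self, audio: Audio) -> GeneratedAudioCollection:
        ...

Since these functions are stateless (inputs in, values out), module-level functions seem like a defensible choice here.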

@@ -0,0 +1,57 @@
from src.spira_training.shared.core.models.audio import Audio
from src.spira_training.shared.core.models.audio_collection import AudioCollection
from src.spira_training.shared.core.models.generated_audio_collection import GeneratedAudioCollection
from src.spira_training.shared.core.models.wav import create_empty_wav, concatenate_wavs

def create_slices_from_audio(audio: Audio, window_length: int, step_size: int) -> GeneratedAudioCollection:
    slices = []
    slice_index = 0

    while slice_index < len(audio):
        start = slice_index
        end = slice_index + window_length

        if end > len(audio):
            end = len(audio)

        slices.append(_create_slice(audio, start, end))
        slice_index += step_size

    return GeneratedAudioCollection(generated_audios=slices)

def _create_slice(audio: Audio, start_index: int, end_index: int) -> Audio:
    if start_index < 0 or end_index < 0 or start_index >= end_index:
        raise ValueError(f"Invalid range [{start_index}:{end_index}]")

    return Audio(
        wav=audio.wav.slice(
            # Audios are indexed in sample_rate chunks
            start_index=start_index * audio.sample_rate,
            end_index=end_index * audio.sample_rate,
        ),
        sample_rate=audio.sample_rate,
    )

def concatenate_audios(audios: AudioCollection | GeneratedAudioCollection) -> Audio:
    if len(audios) == 0:
        return Audio(wav=create_empty_wav(), sample_rate=0)

    if _check_audios_have_different_sample_rate(audios):
        raise ValueError("Sample rates are not equal")

    wav_list = [audio.wav for audio in audios]
    concatenated_wav = concatenate_wavs(wav_list)

    return Audio(wav=concatenated_wav, sample_rate=audios[0].sample_rate)

def _check_audios_have_different_sample_rate(audios: AudioCollection | GeneratedAudioCollection) -> bool:
    sample_rates = {audio.sample_rate for audio in audios}
    return len(sample_rates) > 1


def add_padding_to_audio_collection(audios: AudioCollection) -> AudioCollection:
    max_audio_length = audios.get_max_audio_length()

    return AudioCollection(
        [audio.add_padding(max_audio_length) for audio in audios],
        audios.hop_length
    )
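A worked example of the windowing arithmetic; a runnable sketch in which `FakeWav` is a stand-in for the real `Wav` (which wraps a torch tensor):

from src.spira_training.shared.core.audio_service import create_slices_from_audio
from src.spira_training.shared.core.models.audio import Audio


class FakeWav:
    def __init__(self, num_samples: int):
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def slice(self, start_index: int, end_index: int) -> "FakeWav":
        return FakeWav(end_index - start_index)


# 10 seconds at 16 kHz; window_length=4 and step_size=2 yield the windows
# [0:4], [2:6], [4:8], [6:10], [8:10] (the last one is truncated at the end)
audio = Audio(wav=FakeWav(10 * 16_000), sample_rate=16_000)
slices = create_slices_from_audio(audio, window_length=4, step_size=2)
assert len(slices) == 5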
8 changes: 8 additions & 0 deletions src/spira_training/shared/core/models/audio.py
@@ -1,7 +1,15 @@
+import math
+
 from src.spira_training.shared.core.models.wav import Wav


 class Audio:
     def __init__(self, wav: Wav = None, sample_rate: int = 0):
         self.wav = wav
         self.sample_rate = sample_rate
+
+    def __len__(self):
+        return math.ceil(len(self.wav) / self.sample_rate)
+
+    def add_padding(self, max_audio_length):
+        return Audio(wav=self.wav.add_padding(max_audio_length), sample_rate=self.sample_rate)
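Note the unit: `len(audio)` is the duration in whole seconds (rounded up), not a sample count, which is what lets `create_slices_from_audio` treat `window_length` and `step_size` as offsets in seconds. For example:

import math

# A 2.5-second clip at 44.1 kHz holds 110_250 samples, so len(audio) == 3
assert math.ceil(110_250 / 44_100) == 3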
31 changes: 29 additions & 2 deletions src/spira_training/shared/core/models/audio_collection.py
@@ -1,4 +1,4 @@
-from typing import Iterator
+from typing import Iterator, Optional

 from src.spira_training.shared.core.models.audio import Audio

@@ -7,6 +7,33 @@ class AudioCollection:
     def __init__(self, audios: list[Audio], hop_length: int):
         self.audios = audios
         self.hop_length = hop_length
+        self._min_audio_length: Optional[int] = None
+        self._max_audio_length: Optional[int] = None

     def __iter__(self) -> Iterator[Audio]:
         return iter(self.audios)
+
+    def __len__(self) -> int:
+        return len(self.audios)
+
+    def copy_using(self, audios: list[Audio]) -> 'AudioCollection':
+        return AudioCollection(audios, self.hop_length)
+
+    def get_max_audio_length(self) -> int:
+        if self._max_audio_length is None:
+            self._calculate_min_max_audio_length()
+        return self._max_audio_length
+
+    def get_min_audio_length(self) -> int:
+        if self._min_audio_length is None:
+            self._calculate_min_max_audio_length()
+        return self._min_audio_length
+
+    def _calculate_min_max_audio_length(self):
+        audio_lengths = [self._calculate_audio_length(audio) for audio in self.audios]
+        self._min_audio_length = min(audio_lengths)
+        self._max_audio_length = max(audio_lengths)
+
+    def _calculate_audio_length(self, audio: Audio) -> int:
+        return int((audio.wav.shape[1] / self.hop_length) + 1)
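The lengths here are frame counts derived from `wav.shape[1]` and `hop_length`, computed lazily on first access and cached; note the cache is never invalidated, so mutating `self.audios` afterwards would serve stale values. A usage sketch (`audio_a` and `audio_b` are placeholder `Audio` instances):

collection = AudioCollection(audios=[audio_a, audio_b], hop_length=160)
longest = collection.get_max_audio_length()   # first call computes min and max
assert collection.get_min_audio_length() <= longest  # served from the cache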
6 changes: 6 additions & 0 deletions src/spira_training/shared/core/models/generated_audio.py
@@ -0,0 +1,6 @@
from src.spira_training.shared.core.models.wav import Wav


class GeneratedAudio:
    def __init__(self, wav: Wav):
        self.wav = wav
@@ -0,0 +1,22 @@
from typing import Iterator

from src.spira_training.shared.core.models.audio import Audio


class GeneratedAudioCollection:
    def __init__(self, generated_audios: list[Audio]):
        self.generated_audios = generated_audios

    def __iter__(self) -> Iterator[Audio]:
        # Iteration is needed by AudioProcessor.process_audios and concatenate_audios
        return iter(self.generated_audios)

    def __len__(self) -> int:
        return len(self.generated_audios)

    def __getitem__(self, index: int) -> Audio:
        # concatenate_audios reads audios[0].sample_rate
        return self.generated_audios[index]

    def copy_using(self, audios: list[Audio]) -> 'GeneratedAudioCollection':
        return GeneratedAudioCollection(audios)
7 changes: 6 additions & 1 deletion src/spira_training/shared/core/models/wav.py
@@ -34,8 +34,13 @@ def concatenate(self, wav: 'Wav') -> 'Wav':
     def __getattr__(self, name):
         return getattr(self.tensor, name)


 def concatenate_wavs(wavs: List[Wav]) -> Wav:
+    if not wavs:
+        # Return an empty Wav rather than None so the annotated return type holds
+        return create_empty_wav()
+
     return reduce(lambda acc, wav: acc.concatenate(wav), wavs)
+
+
+def create_empty_wav() -> Wav:
+    return Wav(torch.empty(0))
@@ -1,8 +1,5 @@
 from pathlib import Path

-from src.spira_training.shared.adapters.pytorch.model_trainer.interfaces.pytorch_audio_factory import (
-    PytorchTensorFactory,
-)

 from src.spira_training.apps.feature_engineering.configs.feature_engineering_config import (
     FeatureEngineeringConfig,
@@ -27,9 +24,7 @@ def __init__(
         audios_repository: AudiosRepository,
         file_reader: FileReader,
         path_validator: PathValidator,
-        pytorch_audio_factory: PytorchTensorFactory,
     ):
-        self.pytorch_audio_factory = pytorch_audio_factory
         self.config = config
         self.randomizer = randomizer
         self.dataset_repository = dataset_repository
@@ -40,9 +35,7 @@
     async def execute(self, save_dataset_path: Path) -> None:
         patients_inputs, controls_inputs, noises = self._load_data()

-        audio_processor = create_audio_processor(
-            self.config.audio_processor, self.pytorch_audio_factory
-        )
+        audio_processor = create_audio_processor(self.config.audio_processor)

         dataset = self._generate_dataset()
9 changes: 9 additions & 0 deletions src/spira_training/shared/ports/audio_feature_transformer.py
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

from src.spira_training.shared.core.models.audio_collection import AudioCollection


class AudioFeatureTransformer(ABC):
    @abstractmethod
    def transform_into_features(self, audios: AudioCollection) -> AudioCollection:
        pass
5 changes: 1 addition & 4 deletions tests/core/test_audio_processor_factory.py
@@ -10,7 +10,6 @@
     create_audio_processor,
 )
 from tests.fakes.fake_feature_engineering_config import make_audio_processor_config
-from tests.fakes.fake_pytorch_audio_factory import FakePytorchTensorFactory


 @pytest.mark.asyncio
@@ -19,9 +18,7 @@ async def test_create_audio_processor_mfcc():
     audio_processor_config = make_audio_processor_config(AudioProcessorType.MFCC)

     # Act
-    audio_processor = create_audio_processor(
-        audio_processor_config, FakePytorchTensorFactory()
-    )
+    audio_processor = create_audio_processor(audio_processor_config)

     # Assert
     assert isinstance(audio_processor.feature_transformer, MFCCFeatureTransformer)