Skip to content

Commit

Permalink
Add more tests and resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Mar 25, 2024
1 parent db0f065 commit 1e9f056
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 36 deletions.
11 changes: 6 additions & 5 deletions audinterface/core/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
from audinterface.core.typing import Timestamps


def identity(signal, sampling_rate):
r"""Default processing function."""
return signal


class Process:
r"""Processing interface.
Expand Down Expand Up @@ -156,11 +161,6 @@ def __init__(
if channels is not None:
channels = audeer.to_list(channels)

if process_func is None:

def process_func(signal, _):
return signal

if resample and sampling_rate is None:
raise ValueError("sampling_rate has to be provided for resample = True.")

Expand All @@ -169,6 +169,7 @@ def process_func(signal, _):
if win_dur is not None and hop_dur is None:
hop_dur = utils.to_timedelta(win_dur, sampling_rate) / 2

process_func = process_func or identity
signature = inspect.signature(process_func)
self._process_func_signature = dict(signature.parameters)
r"""Arguments present in processing function."""
Expand Down
11 changes: 6 additions & 5 deletions audinterface/core/process_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from audinterface.core import utils


def identity(signal, sampling_rate, starts, ends):
r"""Default processing function."""
return [signal[:, start:end] for start, end in zip(starts, ends)]


class ProcessWithContext:
r"""Alternate processing interface that provides signal context.
Expand Down Expand Up @@ -105,14 +110,10 @@ def __init__(
if channels is not None:
channels = audeer.to_list(channels)

if process_func is None:

def process_func(signal, _, starts, ends):
return [signal[:, start:end] for start, end in zip(starts, ends)]

if resample and sampling_rate is None:
raise ValueError("sampling_rate has to be provided for resample = True.")

process_func = process_func or identity
signature = inspect.signature(process_func)
self._process_func_signature = signature.parameters
r"""Arguments present in processing function."""
Expand Down
42 changes: 19 additions & 23 deletions audinterface/core/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,19 @@
from audinterface.core.typing import Timestamps


def create_process_func(
process_func: typing.Optional[typing.Callable[..., pd.MultiIndex]],
invert: bool,
) -> typing.Callable[..., pd.MultiIndex]:
r"""Create processing function."""
if process_func is None:

def process_func(signal, sr, **kwargs):
return utils.signal_index()

if invert:

def process_func_invert(signal, sr, **kwargs):
index = process_func(signal, sr, **kwargs)
dur = pd.to_timedelta(signal.shape[-1] / sr, unit="s")
index = index.sortlevel("start")[0]
index = merge_index(index)
index = invert_index(index, dur)
return index
def signal_index(signal, sampling_rate, **kwargs):
r"""Default segment function."""
return utils.signal_index()


return process_func_invert
else:
return process_func
def signal_index_invert(signal, sampling_rate, **kwargs):
r"""Default inverted segment function."""
index = process_func(signal, sr, **kwargs)
dur = pd.to_timedelta(signal.shape[-1] / sr, unit="s")
index = index.sortlevel("start")[0]
index = merge_index(index)
index = invert_index(index, dur)
return index


def invert_index(
Expand Down Expand Up @@ -216,8 +206,14 @@ def __init__(
# avoid cycling imports
from audinterface.core.process import Process

if process_func is None:
process_func = signal_index

if invert:
process_func = signal_index_invert

process = Process(
process_func=create_process_func(process_func, invert),
process_func=process_func,
process_func_args=process_func_args,
sampling_rate=sampling_rate,
resample=resample,
Expand Down
39 changes: 37 additions & 2 deletions tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,42 @@ def test_process_file(tmpdir, start, end, segment):
np.testing.assert_array_equal(y, y_expected)


def test_process_folder(tmpdir):
@pytest.mark.parametrize(
"num_files, num_workers, multiprocessing",
[
(
3,
1,
False,
),
(
3,
2,
False,
),
(
3,
2,
True,
),
(
3,
None,
False,
),
(
3,
1,
False,
),
],
)
def test_process_folder(
tmpdir,
num_files,
num_workers,
multiprocessing,
):
index = audinterface.utils.signal_index(0, 1)
feature_names = ["o1", "o2", "o3"]
feature = audinterface.Feature(
Expand All @@ -343,7 +378,7 @@ def test_process_folder(tmpdir):
)

path = str(tmpdir.mkdir("wav"))
files = [f"file{n}.wav" for n in range(3)]
files = [f"file{n}.wav" for n in range(num_files)]
files_abs = [os.path.join(path, file) for file in files]
for file in files_abs:
af.write(file, SIGNAL_2D, SAMPLING_RATE)
Expand Down
48 changes: 47 additions & 1 deletion tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
INDEX = audinterface.utils.signal_index(STARTS, ENDS)


def predefined_index(signal, sampling_rate):
return INDEX


@pytest.mark.parametrize(
"signal, sampling_rate, segment_func, result",
[
Expand Down Expand Up @@ -89,6 +93,10 @@ def test_file(tmpdir):
2,
False,
),
(
2,
True,
),
(
None,
False,
Expand All @@ -97,7 +105,7 @@ def test_file(tmpdir):
)
def test_folder(tmpdir, num_workers, multiprocessing):
segment = audinterface.Segment(
process_func=lambda s, sr: INDEX,
process_func=predefined_index,
sampling_rate=None,
resample=False,
num_workers=num_workers,
Expand All @@ -123,6 +131,44 @@ def test_folder(tmpdir, num_workers, multiprocessing):
pd.testing.assert_index_equal(index, audformat.filewise_index())


@pytest.mark.parametrize(
"num_workers, multiprocessing",
[
(
1,
False,
),
(
2,
False,
),
(
2,
True,
),
(
None,
False,
),
],
)
def test_folder_default_process_func(tmpdir, num_workers, multiprocessing):
segment = audinterface.Segment(
process_func=None,
sampling_rate=None,
resample=False,
num_workers=num_workers,
multiprocessing=multiprocessing,
verbose=False,
)
path = str(tmpdir.mkdir("wav"))
files = [os.path.join(path, f"file{n}.wav") for n in range(3)]
for file in files:
af.write(file, SIGNAL, SAMPLING_RATE)
result = segment.process_folder(path)
assert len(result) == 0


@pytest.mark.parametrize(
"num_workers, multiprocessing",
[
Expand Down

0 comments on commit 1e9f056

Please sign in to comment.