Skip to content

Write torch.ops.torchaudio._lfilter in Python #3967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 14 additions & 27 deletions src/libtorchaudio/lfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,34 +245,16 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
}
};

torch::Tensor lfilter_core(
torch::Tensor differentiable_fir_apply(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

int64_t n_order = b_coeffs.size(1);

TORCH_INTERNAL_ASSERT(n_order > 0);

auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));

auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
return DifferentiableFIR::apply(waveform, b_coeffs);
}

torch::Tensor differentiable_iir_apply(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
return DifferentiableIIR::apply(waveform, a_coeffs_normalized);
}

} // namespace
Expand All @@ -286,8 +268,13 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
TORCH_LIBRARY(torchaudio, m) {
m.def(
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
m.def(
"torchaudio::_differentiable_fir_apply(Tensor waveform, Tensor b_coeffs) -> Tensor");
m.def(
"torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
m.impl("torchaudio::_lfilter", lfilter_core);
m.impl("torchaudio::_differentiable_fir_apply", differentiable_fir_apply);
m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply);
}
22 changes: 21 additions & 1 deletion src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,8 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T

if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
_differentiable_fir_apply = torch.ops.torchaudio._differentiable_fir_apply
_differentiable_iir_apply = torch.ops.torchaudio._differentiable_iir_apply
else:
_lfilter_core_cpu_loop = _lfilter_core_generic_loop

Expand Down Expand Up @@ -991,8 +993,26 @@ def _lfilter_core(
return output


# TODO find a better name for this, possibly renaming the existing `_lfilter_core`
def _lfilter_core_in_python_calling_into_cpp_FIR_and_IIR(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
) -> Tensor:
"""Python implementation of lfilter_core using C++ DifferentiableFIR and DifferentiableIIR."""

# TODO here: input validation checks

a0 = a_coeffs[:, 0:1] # Keep dimension for broadcasting
b_coeffs_normalized = b_coeffs / a0
a_coeffs_normalized = a_coeffs / a0

filtered_waveform = _differentiable_fir_apply(waveform, b_coeffs_normalized)
return _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized)


if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter = torch.ops.torchaudio._lfilter
_lfilter = _lfilter_core_in_python_calling_into_cpp_FIR_and_IIR
else:
_lfilter = _lfilter_core

Expand Down
Loading