diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 454b2cbcda..ef497a5301 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -245,34 +245,16 @@ class DifferentiableFIR : public torch::autograd::Function { } }; -torch::Tensor lfilter_core( +torch::Tensor differentiable_fir_apply( const torch::Tensor& waveform, - const torch::Tensor& a_coeffs, const torch::Tensor& b_coeffs) { - TORCH_CHECK(waveform.device() == a_coeffs.device()); - TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); - TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); - - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); - TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); - TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); - - int64_t n_order = b_coeffs.size(1); - - TORCH_INTERNAL_ASSERT(n_order > 0); - - auto filtered_waveform = DifferentiableFIR::apply( - waveform, - b_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); - - auto output = DifferentiableIIR::apply( - filtered_waveform, - a_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); - return output; + return DifferentiableFIR::apply(waveform, b_coeffs); +} + +torch::Tensor differentiable_iir_apply( + const torch::Tensor& waveform, + const torch::Tensor& a_coeffs_normalized) { + return DifferentiableIIR::apply(waveform, a_coeffs_normalized); } } // namespace @@ -286,8 +268,13 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY(torchaudio, m) { m.def( "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"); + m.def( + "torchaudio::_differentiable_fir_apply(Tensor waveform, Tensor b_coeffs) -> Tensor"); + m.def( + "torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor"); } TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) { - m.impl("torchaudio::_lfilter", lfilter_core); + m.impl("torchaudio::_differentiable_fir_apply", differentiable_fir_apply); + m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply); } diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 541c56c475..6d39621e05 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -933,6 +933,8 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T if _IS_TORCHAUDIO_EXT_AVAILABLE: _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop + _differentiable_fir_apply = torch.ops.torchaudio._differentiable_fir_apply + _differentiable_iir_apply = torch.ops.torchaudio._differentiable_iir_apply else: _lfilter_core_cpu_loop = _lfilter_core_generic_loop @@ -991,8 +993,26 @@ def _lfilter_core( return output +# TODO find a better name for this, possibly renaming the existing `_lfilter_core` +def _lfilter_core_in_python_calling_into_cpp_FIR_and_IIR( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, +) -> Tensor: + """Python implementation of lfilter_core using C++ DifferentiableFIR and DifferentiableIIR.""" + + # TODO here: input validation checks + + a0 = a_coeffs[:, 0:1] # Keep dimension for broadcasting + b_coeffs_normalized = b_coeffs / a0 + a_coeffs_normalized = a_coeffs / a0 + + filtered_waveform = _differentiable_fir_apply(waveform, b_coeffs_normalized) + return _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized) + + if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter = torch.ops.torchaudio._lfilter + _lfilter = _lfilter_core_in_python_calling_into_cpp_FIR_and_IIR else: _lfilter = _lfilter_core