diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 454b2cbcda..4e40c9b622 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -182,97 +182,66 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> { } }; -class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> { - public: - static torch::Tensor forward( - torch::autograd::AutogradContext* ctx, - const torch::Tensor& waveform, - const torch::Tensor& b_coeffs) { - int64_t n_order = b_coeffs.size(1); - int64_t n_channel = b_coeffs.size(0); - - namespace F = torch::nn::functional; - auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); - auto padded_waveform = - F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); - - auto output = F::conv1d( - padded_waveform, - b_coeff_flipped.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)); - - ctx->save_for_backward({waveform, b_coeffs, output}); - return output; - } - - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto x = saved[0]; - auto b_coeffs = saved[1]; - auto y = saved[2]; - - int64_t n_batch = x.size(0); - int64_t n_channel = x.size(1); - int64_t n_order = b_coeffs.size(1); - - auto dx = torch::Tensor(); - auto db = torch::Tensor(); - auto dy = grad_outputs[0]; - - namespace F = torch::nn::functional; +// FIR filter forward and backward functions (no autograd inheritance) +torch::Tensor fir_forward( + const torch::Tensor& waveform, + const torch::Tensor& b_coeffs) { + int64_t n_order = b_coeffs.size(1); + int64_t n_channel = b_coeffs.size(0); - if (b_coeffs.requires_grad()) { - db = F::conv1d( - F::pad(x, F::PadFuncOptions({n_order - 1, 0})) - .view({1, n_batch * n_channel, -1}), - dy.view({n_batch * n_channel, 1, -1}), - F::Conv1dFuncOptions().groups(n_batch * n_channel)) - .view({n_batch, n_channel, -1}) - .sum(0) - .flip(1); - } + namespace F = 
torch::nn::functional; + auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); + auto padded_waveform = + F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); - if (x.requires_grad()) { - dx = F::conv1d( - F::pad(dy, F::PadFuncOptions({0, n_order - 1})), - b_coeffs.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)); - } + auto output = F::conv1d( + padded_waveform, + b_coeff_flipped.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); - return {dx, db}; - } -}; + return output; +} -torch::Tensor lfilter_core( +std::tuple<torch::Tensor, torch::Tensor> fir_backward( + const torch::Tensor& grad_output, const torch::Tensor& waveform, - const torch::Tensor& a_coeffs, const torch::Tensor& b_coeffs) { - TORCH_CHECK(waveform.device() == a_coeffs.device()); - TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); - TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); - - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); - TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); - TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); - + int64_t n_batch = waveform.size(0); + int64_t n_channel = waveform.size(1); int64_t n_order = b_coeffs.size(1); - TORCH_INTERNAL_ASSERT(n_order > 0); + auto dx = torch::Tensor(); + auto db = torch::Tensor(); + + namespace F = torch::nn::functional; + + // Compute gradient w.r.t. b_coeffs + if (b_coeffs.requires_grad()) { + db = F::conv1d( + F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})) + .view({1, n_batch * n_channel, -1}), + grad_output.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); + } - auto filtered_waveform = DifferentiableFIR::apply( - waveform, - b_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); + // Compute gradient w.r.t. 
waveform + if (waveform.requires_grad()) { + dx = F::conv1d( + F::pad(grad_output, F::PadFuncOptions({0, n_order - 1})), + b_coeffs.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); + } - auto output = DifferentiableIIR::apply( - filtered_waveform, - a_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); - return output; + return std::make_tuple(dx, db); +} + +torch::Tensor differentiable_iir_apply( + const torch::Tensor& waveform, + const torch::Tensor& a_coeffs_normalized) { + return DifferentiableIIR::apply(waveform, a_coeffs_normalized); } } // namespace @@ -285,9 +254,15 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY(torchaudio, m) { m.def( - "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"); + "torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor"); + m.def( + "torchaudio::_fir_forward(Tensor waveform, Tensor b_coeffs) -> Tensor"); + m.def( + "torchaudio::_fir_backward(Tensor grad_output, Tensor waveform, Tensor b_coeffs) -> (Tensor, Tensor)"); } TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) { - m.impl("torchaudio::_lfilter", lfilter_core); + m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply); + m.impl("torchaudio::_fir_forward", fir_forward); + m.impl("torchaudio::_fir_backward", fir_backward); } diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 541c56c475..fbfe2f8a79 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -933,6 +933,9 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T if _IS_TORCHAUDIO_EXT_AVAILABLE: _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop + _differentiable_iir_apply = torch.ops.torchaudio._differentiable_iir_apply + _fir_forward = torch.ops.torchaudio._fir_forward + _fir_backward = torch.ops.torchaudio._fir_backward else: 
_lfilter_core_cpu_loop = _lfilter_core_generic_loop @@ -991,8 +994,49 @@ def _lfilter_core( return output +class _DifferentiableFIRFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, waveform, b_coeffs): + ctx.save_for_backward(waveform, b_coeffs) + + output = _fir_forward(waveform, b_coeffs) + return output + + @staticmethod + def backward(ctx, grad_output): + # Retrieve saved inputs + waveform, b_coeffs = ctx.saved_tensors + + # Call C++ backward function + grad_waveform, grad_b_coeffs = _fir_backward(grad_output, waveform, b_coeffs) + + return grad_waveform, grad_b_coeffs + + +def _lfilter_core_python( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, +) -> Tensor: + """Python implementation of lfilter_core using C++ DifferentiableFIR and DifferentiableIIR.""" + + # TODO here: input validation checks + + a0 = a_coeffs[:, 0:1] # Keep dimension for broadcasting + b_coeffs_normalized = b_coeffs / a0 + a_coeffs_normalized = a_coeffs / a0 + + # Apply FIR filter using Python autograd function + filtered_waveform = _DifferentiableFIRFunction.apply(waveform, b_coeffs_normalized) + + # Apply IIR filter (still using C++ autograd) + output = _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized) + + return output + + if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter = torch.ops.torchaudio._lfilter + _lfilter = _lfilter_core_python else: _lfilter = _lfilter_core