Remove autograd from FIR #3968

Open · wants to merge 2 commits into main
141 changes: 58 additions & 83 deletions src/libtorchaudio/lfilter.cpp
@@ -182,97 +182,66 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
}
};

class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);

namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));

ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];

int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);

auto dx = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];

namespace F = torch::nn::functional;
// FIR filter forward and backward functions (no autograd inheritance)
torch::Tensor fir_forward(
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);

if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}
auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));

return {dx, db};
}
};
return output;
}

torch::Tensor lfilter_core(
std::tuple<torch::Tensor, torch::Tensor> fir_backward(
const torch::Tensor& grad_output,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_order = b_coeffs.size(1);

TORCH_INTERNAL_ASSERT(n_order > 0);
auto dx = torch::Tensor();
auto db = torch::Tensor();

namespace F = torch::nn::functional;

// Compute gradient w.r.t. b_coeffs
if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
grad_output.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}

auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
// Compute gradient w.r.t. waveform
if (waveform.requires_grad()) {
dx = F::conv1d(
F::pad(grad_output, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}

auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
return std::make_tuple(dx, db);
}

torch::Tensor differentiable_iir_apply(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
return DifferentiableIIR::apply(waveform, a_coeffs_normalized);
}

} // namespace
@@ -285,9 +254,15 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {

TORCH_LIBRARY(torchaudio, m) {
m.def(
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
"torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor");
m.def(
"torchaudio::_fir_forward(Tensor waveform, Tensor b_coeffs) -> Tensor");
m.def(
"torchaudio::_fir_backward(Tensor grad_output, Tensor waveform, Tensor b_coeffs) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
m.impl("torchaudio::_lfilter", lfilter_core);
m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply);
m.impl("torchaudio::_fir_forward", fir_forward);
m.impl("torchaudio::_fir_backward", fir_backward);
}
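
For reference, fir_forward above is a grouped conv1d of the left-padded waveform with the time-reversed coefficients, and fir_backward implements the matching closed-form gradients. The following standalone PyTorch sketch (not part of the patch; fir_forward_ref and the small shapes are illustrative assumptions) checks those hand-written formulas against autograd:

# Illustrative check, not part of the patch: verify the fir_backward formulas
# against autograd on a reference implementation of fir_forward.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_channel, n_time, n_order = 2, 3, 16, 4
waveform = torch.randn(n_batch, n_channel, n_time, dtype=torch.double, requires_grad=True)
b_coeffs = torch.randn(n_channel, n_order, dtype=torch.double, requires_grad=True)

def fir_forward_ref(x, b):
    # Left-pad by (n_order - 1), then run one grouped conv1d so each channel is
    # filtered by its own time-reversed coefficients (mirrors the C++ fir_forward).
    padded = F.pad(x, (b.size(1) - 1, 0))
    return F.conv1d(padded, b.flip(1).unsqueeze(1), groups=b.size(0))

y = fir_forward_ref(waveform, b_coeffs)
grad_output = torch.randn_like(y)
dx_auto, db_auto = torch.autograd.grad(y, (waveform, b_coeffs), grad_output)

with torch.no_grad():
    # Gradient w.r.t. the waveform: correlate the right-padded output gradient
    # with the (un-flipped) coefficients, one group per channel.
    dx_manual = F.conv1d(
        F.pad(grad_output, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel)
    # Gradient w.r.t. the coefficients: batched correlation of the padded input
    # with the output gradient, summed over the batch and flipped back.
    db_manual = (
        F.conv1d(
            F.pad(waveform, (n_order - 1, 0)).reshape(1, n_batch * n_channel, -1),
            grad_output.reshape(n_batch * n_channel, 1, -1),
            groups=n_batch * n_channel)
        .reshape(n_batch, n_channel, -1)
        .sum(0)
        .flip(1))

print(torch.allclose(dx_auto, dx_manual), torch.allclose(db_auto, db_manual))

Keeping these kernels free of torch::autograd::Function turns them into plain ops that the Python-side autograd function in filtering.py can wrap.
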
46 changes: 45 additions & 1 deletion src/torchaudio/functional/filtering.py
@@ -933,6 +933,9 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T

if _IS_TORCHAUDIO_EXT_AVAILABLE:
    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
    _differentiable_iir_apply = torch.ops.torchaudio._differentiable_iir_apply
    _fir_forward = torch.ops.torchaudio._fir_forward
    _fir_backward = torch.ops.torchaudio._fir_backward
else:
    _lfilter_core_cpu_loop = _lfilter_core_generic_loop

@@ -991,8 +994,49 @@ def _lfilter_core(
return output


class _DifferentiableFIRFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, b_coeffs):
        ctx.save_for_backward(waveform, b_coeffs)

        output = _fir_forward(waveform, b_coeffs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved inputs
        waveform, b_coeffs = ctx.saved_tensors

        # Call C++ backward function
        grad_waveform, grad_b_coeffs = _fir_backward(grad_output, waveform, b_coeffs)

        return grad_waveform, grad_b_coeffs


def _lfilter_core_python(
    waveform: Tensor,
    a_coeffs: Tensor,
    b_coeffs: Tensor,
) -> Tensor:
    """Python implementation of lfilter_core using the C++ FIR forward/backward kernels and the C++ DifferentiableIIR."""

    # TODO here: input validation checks

    a0 = a_coeffs[:, 0:1]  # Keep dimension for broadcasting
    b_coeffs_normalized = b_coeffs / a0
    a_coeffs_normalized = a_coeffs / a0

    # Apply FIR filter using the Python autograd function
    filtered_waveform = _DifferentiableFIRFunction.apply(waveform, b_coeffs_normalized)

    # Apply IIR filter (still using C++ autograd)
    output = _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized)

    return output


if _IS_TORCHAUDIO_EXT_AVAILABLE:
    _lfilter = torch.ops.torchaudio._lfilter
    _lfilter = _lfilter_core_python
else:
    _lfilter = _lfilter_core

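
A quick end-to-end sanity check of the new split (Python autograd for the FIR stage, C++ autograd for the IIR stage) can be run through the public lfilter API. This is a hedged sketch that assumes torchaudio was built with the C++ extension registered; the shapes and coefficient values are arbitrary examples:

# Hedged usage sketch (not part of the patch): gradients should flow through the
# Python-side _DifferentiableFIRFunction and the C++ DifferentiableIIR.
import torch
import torchaudio.functional as AF

waveform = torch.randn(1, 2, 64, dtype=torch.double, requires_grad=True)
b_coeffs = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]], dtype=torch.double, requires_grad=True)
a_coeffs = torch.tensor([[1.0, 0.3, 0.1], [1.0, 0.2, 0.3]], dtype=torch.double, requires_grad=True)

# With 2-D coefficients and the default batching=True, the waveform is shaped
# (..., num_filters, time); clamp=False keeps the output unclamped so gradients
# are not zeroed by saturation.
out = AF.lfilter(waveform, a_coeffs, b_coeffs, clamp=False)
out.sum().backward()
print(waveform.grad.shape, a_coeffs.grad.shape, b_coeffs.grad.shape)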