Port autograd parts of lfilter to python #3954

Open
wants to merge 5 commits into main
Changes from all commits
2 changes: 1 addition & 1 deletion .github/scripts/unittest-linux/run_test.sh
@@ -30,5 +30,5 @@ fi

 (
 cd test
-pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs"
+pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript_consistency"
 )
195 changes: 12 additions & 183 deletions src/libtorchaudio/lfilter.cpp
@@ -100,194 +100,23 @@ void lfilter_core_generic_loop(
}
}

Member:

We've got lfilter_core_generic_loop left-over above in this file. Is it still needed?

Collaborator (Author):
Good point. That path was used before when you had tensors on a CUDA device but torchaudio hadn't been compiled with CUDA support. We should probably keep it around. I think we can register it with the dispatcher as a CompositeExplicitAutograd kernel, which, as I understand it, acts as a catch-all when no other dispatch key kicks in.
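
For illustration, a minimal sketch (shapes made up; assuming the extension is built so the op is registered) of what that catch-all means from the Python side: the call goes through the dispatcher, so tensors on a device with no dedicated kernel would resolve to the CompositeExplicitAutograd registration of lfilter_core_generic_loop.

```python
import torch

# In-place IIR loop: writes the recursion into the zero-padded output buffer.
x = torch.randn(1, 2, 16)                      # (batch, channel, time), already FIR-filtered
a = torch.tensor([[1.0, -0.5, 0.2, 0.1]] * 2)  # normalized so a[:, 0] == 1
a_flipped = a.flip(1).contiguous()             # kernel expects flipped coefficients
out = torch.zeros(1, 2, 16 + 4 - 1)            # n_sample + n_order - 1
torch.ops.torchaudio._lfilter_core_loop(x, a_flipped, out)
```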

class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1;

auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();

auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_batch, n_channel, n_sample_padded}, options);

if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
} else if (device.is_cuda()) {
#ifdef USE_CUDA
cuda_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
#else
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
#endif
} else {
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
}

auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});

ctx->save_for_backward({waveform, a_coeffs_normalized, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto a_coeffs_normalized = saved[1];
auto y = saved[2];

int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);

auto dx = torch::Tensor();
auto da = torch::Tensor();
auto dy = grad_outputs[0];

namespace F = torch::nn::functional;

auto tmp =
DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
.flip(2);

if (x.requires_grad()) {
dx = tmp;
}

if (a_coeffs_normalized.requires_grad()) {
da = -torch::matmul(
tmp.transpose(0, 1).reshape({n_channel, 1, -1}),
F::pad(y, F::PadFuncOptions({n_order - 1, 0}))
.unfold(2, n_order, 1)
.transpose(0, 1)
.reshape({n_channel, -1, n_order}))
.squeeze(1)
.flip(1);
}
return {dx, da};
}
};

class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);

namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));

ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];

int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);

auto dx = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];

namespace F = torch::nn::functional;

if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}

if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}

return {dx, db};
}
};

torch::Tensor lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

int64_t n_order = b_coeffs.size(1);

TORCH_INTERNAL_ASSERT(n_order > 0);

auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
  auto output = DifferentiableIIR::apply(
      filtered_waveform,
      a_coeffs /
          a_coeffs.index(
              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
  return output;
}

} // namespace

// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}

TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
  m.impl("torchaudio::_lfilter", lfilter_core);
}

TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}

#ifdef USE_CUDA
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
}
#endif

TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
  m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
}
132 changes: 68 additions & 64 deletions src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import torch
from torch import Tensor
import torch.nn.functional as F

from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE

@@ -932,70 +933,74 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T


if _IS_TORCHAUDIO_EXT_AVAILABLE:
    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
else:
    _lfilter_core_cpu_loop = _lfilter_core_generic_loop


def _lfilter_core(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
) -> Tensor:

if a_coeffs.size() != b_coeffs.size():
raise ValueError(
"Expected coeffs to be the same size."
f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
)
if waveform.ndim != 3:
raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
if not (waveform.device == a_coeffs.device == b_coeffs.device):
raise ValueError(
"Expected waveform and coeffs to be on the same device."
f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
f"b_coeffs device: {b_coeffs.device}"
)

n_batch, n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(1)
if n_order <= 0:
raise ValueError(f"Expected n_order to be positive. Found: {n_order}")

# Pad the input and create output

padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
padded_output_waveform = torch.zeros_like(padded_waveform)

# Set up the coefficients matrix
# Flip coefficients' order
a_coeffs_flipped = a_coeffs.flip(1)
b_coeffs_flipped = b_coeffs.flip(1)

# calculate windowed_input_signal in parallel using convolution
input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)

input_signal_windows.div_(a_coeffs[:, :1])
a_coeffs_flipped.div_(a_coeffs[:, :1])

if (
input_signal_windows.device == torch.device("cpu")
and a_coeffs_flipped.device == torch.device("cpu")
and padded_output_waveform.device == torch.device("cpu")
):
_lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
else:
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)

output = padded_output_waveform[:, :, n_order - 1 :]
return output


if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter = torch.ops.torchaudio._lfilter
else:
_lfilter = _lfilter_core

if _IS_TORCHAUDIO_EXT_AVAILABLE:
    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
else:
    _lfilter_core_loop = _lfilter_core_generic_loop


class DifferentiableFIR(torch.autograd.Function):
@staticmethod
def forward(ctx, waveform, b_coeffs):
n_order = b_coeffs.size(1)
n_channel = b_coeffs.size(0)
b_coeff_flipped = b_coeffs.flip(1).contiguous()
padded_waveform = F.pad(waveform, (n_order - 1, 0))
output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
ctx.save_for_backward(waveform, b_coeffs, output)
return output

@staticmethod
def backward(ctx, dy):
x, b_coeffs, y = ctx.saved_tensors
n_batch = x.size(0)
n_channel = x.size(1)
n_order = b_coeffs.size(1)
        db = None
        if b_coeffs.requires_grad:
            db = (
                F.conv1d(
                    F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
                    dy.view(n_batch * n_channel, 1, -1),
                    groups=n_batch * n_channel,
                )
                .view(n_batch, n_channel, -1)
                .sum(0)
                .flip(1)
            )

        dx = None
        if x.requires_grad:
            dx = F.conv1d(
                F.pad(dy, (0, n_order - 1)),
                b_coeffs.unsqueeze(1),
                groups=n_channel,
            )

        return dx, db
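
As a tiny numeric illustration of the trick DifferentiableFIR relies on (a sketch with made-up sizes): a grouped conv1d over a left-padded input with time-flipped kernels equals the per-channel FIR sum y[n] = sum_k b[k] * x[n - k].

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 10)  # (batch, channel, time)
b = torch.randn(2, 3)      # (channel, n_order)

# Grouped convolution with flipped taps, as in DifferentiableFIR.forward.
y = F.conv1d(F.pad(x, (2, 0)), b.flip(1).unsqueeze(1), groups=2)

# Direct evaluation of the FIR sum for comparison.
y_ref = sum(b[:, k].view(1, 2, 1) * F.pad(x, (k, 0))[..., :10] for k in range(3))
torch.testing.assert_close(y, y_ref)
```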

class DifferentiableIIR(torch.autograd.Function):
@staticmethod
def forward(ctx, waveform, a_coeffs_normalized):
n_batch, n_channel, n_sample = waveform.shape
n_order = a_coeffs_normalized.size(1)
n_sample_padded = n_sample + n_order - 1

        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()
        padded_output_waveform = torch.zeros(
            n_batch, n_channel, n_sample_padded, device=waveform.device, dtype=waveform.dtype
        )
        _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
        output = padded_output_waveform[:, :, n_order - 1 :]
ctx.save_for_backward(waveform, a_coeffs_normalized, output)
return output

@staticmethod
def backward(ctx, dy):
x, a_coeffs_normalized, y = ctx.saved_tensors
n_channel = x.size(1)
n_order = a_coeffs_normalized.size(1)
        tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2)
        dx = tmp if x.requires_grad else None

        da = None
        if a_coeffs_normalized.requires_grad:
            da = (
                -(
                    tmp.transpose(0, 1).reshape(n_channel, 1, -1)
                    @ F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0, 1).reshape(n_channel, -1, n_order)
                )
                .squeeze(1)
                .flip(1)
            )

        return dx, da
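
For reference, a sketch of the math behind these two methods (assuming coefficients normalized so that $a_0 = 1$, which `_lfilter` below enforces). The core loop evaluates the recursion

$$y[n] = x[n] - \sum_{k=1}^{K-1} a_k \, y[n-k].$$

Filtering by $1/A(z)$ is linear, so its adjoint is the same filter applied to the time-reversed gradient; that is the `tmp` tensor:

$$\tilde{g} = \operatorname{rev}\Big(\operatorname{IIR}\big(\operatorname{rev}(\partial\mathcal{L}/\partial y),\, a\big)\Big), \qquad \frac{\partial\mathcal{L}}{\partial a_k} = -\sum_{n} \tilde{g}[n]\, y[n-k],$$

and the second expression is what the unfold-and-matmul line computes.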

def _lfilter(waveform, a_coeffs, b_coeffs):
n_order = b_coeffs.size(1)
filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
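
A quick way to sanity-check the ported backward passes (a sketch, not part of this diff; small double-precision inputs keep gradcheck reliable):

```python
import torch
from torchaudio.functional.filtering import _lfilter  # private helper from this file

# Compares the custom backward of DifferentiableFIR/DifferentiableIIR
# against numerical gradients.
torch.manual_seed(0)
waveform = torch.randn(2, 3, 9, dtype=torch.float64, requires_grad=True)
a_coeffs = torch.tensor([[1.0, -0.2, 0.1]] * 3, dtype=torch.float64, requires_grad=True)
b_coeffs = torch.tensor([[0.4, 0.2, 0.9]] * 3, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(_lfilter, (waveform, a_coeffs, b_coeffs))
```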

def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
@@ -1066,7 +1071,6 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =

return output
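
As a usage sketch of the unchanged public entry point (illustrative coefficient values, not from this diff):

```python
import torch
from torchaudio.functional import lfilter

waveform = torch.randn(2, 8000)                    # (channel, time)
b_coeffs = torch.tensor([0.2929, 0.5858, 0.2929])  # assumed biquad numerator
a_coeffs = torch.tensor([1.0, 0.0, 0.1716])        # assumed biquad denominator
filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False)
```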


def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
