Hi, I want to train a Residual LFQ model for audio, and this is my core code:
```python
import torch


def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(-1).expand_as(x_target)
        x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target)).to(x_pred.device)
        x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred)).to(x_pred.device)
        mask_sum = padding_mask.sum()
    else:
        mask_sum = x_target.numel()  # no mask: average over every element

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        # only keep the residuals at valid (non-padded) positions
        masked_residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        values, _ = torch.topk(masked_residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        assert False, f"Unknown loss_fn {loss_fn}"
    return loss
```
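For context, here is a toy example of how I expect the masking to behave (the shapes and the stand-in `_Cfg` below are made up just for illustration; the mask is `True` on valid, non-padded frames):

```python
import torch

class _Cfg:
    linf_k = 4  # only used by the 'linf' branch

x_target = torch.randn(2, 10, 8)                    # (batch, seq, dim)
x_pred = torch.randn(2, 10, 8)
padding_mask = torch.zeros(2, 10, dtype=torch.bool)
padding_mask[:, :6] = True                          # first 6 frames are valid

# sums squared error over valid elements, then divides by the valid-element count
loss = _loss_fn('l2', x_target, x_pred, _Cfg(), padding_mask)
print(loss)  # non-negative scalar
```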
```python
def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss
```
```python
model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers
)
```
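For completeness, the shapes I am working with look roughly like this (dummy values in place of my config; the commented shapes follow the README example for ResidualLFQ, so please correct me if I misread them):

```python
import torch
from vector_quantize_pytorch import ResidualLFQ

lfq = ResidualLFQ(
    dim = 256,             # stand-in for config.lfq.dim
    codebook_size = 256,   # must be a power of two for LFQ
    num_quantizers = 8     # stand-in for config.lfq.num_quantizers
)

x = torch.randn(1, 1024, 256)                 # (batch, seq, dim)
quantized, indices, commit_loss = lfq(x)
# quantized: (1, 1024, 256), indices: (1, 1024, 8), commit_loss: (8,) (one per quantizer)
recon = lfq.get_output_from_indices(indices)  # (1, 1024, 256)
```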
I use `reconstruction_loss` and `commit_loss` to jointly update the ResidualLFQ model. I wonder about two things:

- Is the reconstruction loss necessary?
- Sometimes the commitment loss is negative, e.g. -0.02; is this normal? Since I add `commit_loss` and `reconstruction_loss` together, it seems odd that one loss is positive and the other is negative ...
I hope to get some suggestions @kashif @lucidrains. Thank you!