How to Train Residual LFQ #103

@ZetangForward

Hi, I want to train a Residual LFQ model for audio, and this is my core code:

import torch


def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    # padding_mask: bool tensor; positions where it is True are kept, the rest are zeroed
    x_target = x_target.to(x_pred.device)
    if padding_mask is None:
        # no mask given: treat every element as valid so mask_sum is always defined
        padding_mask = torch.ones_like(x_target, dtype=torch.bool)
    else:
        padding_mask = padding_mask.to(x_pred.device).unsqueeze(-1).expand_as(x_target)
    x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target))
    x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred))
    mask_sum = padding_mask.sum()

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        # only consider the residual where padding_mask is True
        masked_residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        # mean of the cfg.linf_k largest squared errors per sample
        values, _ = torch.topk(masked_residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        raise ValueError(f"Unknown loss_fn {loss_fn}")

    return loss


def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss

model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers
)

I use reconstruction_loss and commit_loss to jointly update the ResidualLFQ model.
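
As a sanity check on the masked 'l2' branch above, here is a minimal self-contained sketch. It assumes a boolean mask of shape (batch, time) with True at valid frames and features of shape (batch, time, dim). `masked_l2` is a hypothetical helper written only for this illustration; it mirrors the arithmetic of `_loss_fn` and is not part of any library.

```python
import torch


def masked_l2(x_target, x_pred, padding_mask):
    # Hypothetical helper mirroring the 'l2' branch of _loss_fn.
    # Assumption: padding_mask is bool, shape (batch, time), True = valid frame.
    mask = padding_mask.unsqueeze(-1).expand_as(x_target)
    diff = torch.where(mask, x_pred - x_target, torch.zeros_like(x_target))
    # Sum of squared errors over valid elements, divided by their count
    return (diff ** 2).sum() / mask.sum()


x_target = torch.zeros(2, 3, 4)
x_pred = torch.ones(2, 3, 4)
padding_mask = torch.tensor([[True, True, False],
                             [True, False, False]])
# Put large errors at the padded frames; they should not affect the loss
x_pred[~padding_mask] = 100.0

loss = masked_l2(x_target, x_pred, padding_mask)
print(loss.item())  # 1.0: every valid element has squared error 1
```

Dividing by the number of valid elements rather than the full tensor size keeps the loss scale independent of how much padding a batch contains.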

I have two questions:

  1. Is the reconstruction loss necessary?
  2. The commitment loss is sometimes negative (e.g., -0.02). Is this normal? Since I add commit_loss and reconstruction_loss together, it seems odd that one term is positive and the other is negative.
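
To make the second question concrete, here is a toy sketch of one way a quantizer's auxiliary loss could go negative. It rests on an assumption that is not confirmed in this thread: that the value returned as commit_loss also includes an entropy term of the form (mean per-sample entropy) minus `diversity_gamma` times (entropy of the batch-averaged code distribution). `toy_entropy_aux` and `diversity_gamma` are illustrative names in plain PyTorch, not the library's code.

```python
import torch


def entropy(p, eps=1e-9):
    # Shannon entropy along the last dimension
    return -(p * torch.log(p.clamp(min=eps))).sum(dim=-1)


def toy_entropy_aux(logits, diversity_gamma=1.0):
    # Hypothetical stand-in for an entropy regulariser; NOT the library's implementation.
    probs = logits.softmax(dim=-1)                 # (batch, num_codes)
    per_sample_entropy = entropy(probs).mean()     # small when each sample picks one code confidently
    codebook_entropy = entropy(probs.mean(dim=0))  # large when codes are used evenly across the batch
    return per_sample_entropy - diversity_gamma * codebook_entropy


# Four samples, each confidently assigned to a different one of four codes
logits = 10.0 * torch.eye(4)
aux = toy_entropy_aux(logits)
print(aux.item())  # negative for this input
```

Because entropy is concave, the entropy of the averaged distribution is never smaller than the average entropy, so with `diversity_gamma = 1.0` this toy term is never positive. If the returned loss does contain such a term, a small negative value would not by itself indicate a bug.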

I would appreciate any suggestions, @kashif @lucidrains. Thank you!
