Fix label alignment bug in finetuning #278

Open
matevosashot wants to merge 1 commit into QwenLM:main from matevosashot:main

@matevosashot

Fix label alignment bug in finetuning

Fixes label alignment issues in sft_12hz.py and modeling_qwen3_tts.py caused by incorrect interaction with ForCausalLMLoss from transformers.

Context

ForCausalLMLoss has two modes: passing labels automatically shifts them left by one (token n predicts n+1), while passing shift_labels uses them as-is.
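
The two modes can be illustrated without tensors. The following is a minimal sketch that mirrors the pad-and-slice logic on plain Python lists; `resolve_targets` is a hypothetical helper name used only for illustration, not part of transformers.

```python
IGNORE_INDEX = -100  # same default value as the ignore_index argument


def resolve_targets(labels=None, shift_labels=None, ignore_index=IGNORE_INDEX):
    """Return the per-position targets the loss would compare logits against.

    Hypothetical list-based stand-in for the label handling in ForCausalLMLoss.
    """
    if shift_labels is None:
        # Mode 1: pad one ignore_index on the right, then drop the first label,
        # so position n is scored against token n + 1.
        padded = list(labels) + [ignore_index]
        return padded[1:]
    # Mode 2: the caller has already aligned targets with logits; use them as-is.
    return list(shift_labels)


tokens = [10, 11, 12, 13]  # made-up token ids
auto_shifted = resolve_targets(labels=tokens)
passthrough = resolve_targets(shift_labels=tokens)
print(auto_shifted)  # [11, 12, 13, -100]
print(passthrough)   # [10, 11, 12, 13]
```

In both modes the target list keeps the same length as the input, so the number of logit positions never changes; only the alignment between logits and targets does.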

sft_12hz.py

The old code manually shifted inputs and labels before passing them to the talker (inputs_embeds[:, :-1], labels[:, 1:]). This is not only unnecessary but harmful: ForCausalLMLoss already applies the left shift internally, so the labels ended up shifted twice. The fix passes the full, unshifted tensors and adjusts the hidden-state slicing accordingly. It also adds the missing text_projection call on the text embeddings.
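
To see how the manual shift changes the training targets, the sketch below traces which target each input position is scored against. It uses plain lists and a hypothetical `internal_shift` helper standing in for the loss function's label handling; the token values are made up so that token i has value i.

```python
IGNORE_INDEX = -100


def internal_shift(labels, ignore_index=IGNORE_INDEX):
    # Stand-in for what ForCausalLMLoss does when it receives `labels`:
    # pad on the right with ignore_index, then drop the first element.
    return (list(labels) + [ignore_index])[1:]


sequence = [0, 1, 2, 3, 4]  # made-up token ids

# Old behaviour: shift by hand, then the loss shifts again.
old_inputs = sequence[:-1]
old_targets = internal_shift(sequence[1:])

# Fixed behaviour: pass the full sequence and let the loss shift once.
new_inputs = sequence
new_targets = internal_shift(sequence)

old_pairs = list(zip(old_inputs, old_targets))
new_pairs = list(zip(new_inputs, new_targets))
print(old_pairs)  # [(0, 2), (1, 3), (2, 4), (3, -100)]
print(new_pairs)  # [(0, 1), (1, 2), (2, 3), (3, 4), (4, -100)]
```

With the double shift, each position is trained to predict the token two steps ahead; with a single shift it predicts the next token, as intended.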

modeling_qwen3_tts.py

The subtalker outputs 15 codes per position. Passing them via labels causes ForCausalLMLoss to shift them left by one: the first code is dropped and the last slot is padded with ignore_index, leaving only 14 valid targets, each offset by one against the 15 logit outputs. The fix passes the subtalker labels via shift_labels instead, bypassing the automatic shift.
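
The count mismatch can be checked with a small sketch. It assumes, as the description states, 15 logit slots and 15 code labels that are already aligned one-to-one; the code values and the `internal_shift` helper are illustrative only.

```python
IGNORE_INDEX = -100
NUM_CODES = 15  # subtalker codes per position, per the description above


def internal_shift(labels, ignore_index=IGNORE_INDEX):
    # Stand-in for the automatic shift applied when `labels` is passed.
    return (list(labels) + [ignore_index])[1:]


codes = list(range(100, 100 + NUM_CODES))  # made-up code ids, one per logit slot

via_labels = internal_shift(codes)  # what the loss sees with labels=codes
via_shift_labels = list(codes)      # what the loss sees with shift_labels=codes

valid_via_labels = sum(t != IGNORE_INDEX for t in via_labels)
aligned_via_labels = sum(t == c for t, c in zip(via_labels, codes))
aligned_via_shift_labels = sum(t == c for t, c in zip(via_shift_labels, codes))

print(valid_via_labels)          # 14
print(aligned_via_labels)        # 0
print(aligned_via_shift_labels)  # 15
```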


ForCausalLMLoss implementation:

def ForCausalLMLoss(
    logits, labels, vocab_size,
    num_items_in_batch=None, ignore_index=-100,
    shift_labels=None, **kwargs,
) -> torch.Tensor:
    logits = logits.float()
    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss

chopchop-jiahao added a commit to chopchop-jiahao/Qwen3-TTS that referenced this pull request Apr 24, 2026
…t_12hz.py

Bug 1 — Label double-shift (issue QwenLM#179, pr QwenLM#278):
  sft_12hz.py manually shifted inputs/labels before passing to talker,
  but ForCausalLMLoss shifts internally too. Double-shift trained on wrong
  targets, causing audio to speed up each epoch. Fixed by passing full tensors.

Bug 2 — TensorBoard crash on accelerate >=1.12.0 (issue QwenLM#286, pr QwenLM#295):
  Accelerator(log_with="tensorboard") requires project_dir. Added project_dir="./logs".

Bug 3 — flash_attention_2 ImportError on startup:
  flash-attn is not installed in the default Docker image and fails to build.
  Replaced with PyTorch built-in sdpa, which is equivalent on A100.

Bug 4 — 0.6B tensor dimension mismatch (RuntimeError):
  sft_12hz.py skipped model.talker.text_projection for both model sizes.
  For 1.7B this was hidden (both embeddings 2048-dim, addition worked by accident).
  For 0.6B, text_embedding=2048 and codec_embedding=1024 → RuntimeError on addition.
  Fix: call text_projection (2048→1024 for 0.6B, 2048→2048 for 1.7B) before masking.
  Mask is applied after projection to prevent linear_fc1 bias contaminating zero-vectors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
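
The ordering constraint in Bug 4 (mask after projection) follows from how a linear layer with a bias treats zero vectors. The sketch below uses a single toy linear layer with made-up weights as a hypothetical stand-in for `text_projection`; the real module is not reproduced here, and the dimensions are shrunk for readability.

```python
def linear(x, weight, bias):
    # y = W x + b for a single vector, written out with plain lists.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def apply_mask(x, keep):
    # Zero the whole vector when the position is masked out.
    return list(x) if keep else [0.0] * len(x)


# Toy 4 -> 2 projection with made-up parameters (hypothetical stand-in).
WEIGHT = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
BIAS = [0.5, -0.5]

embedding = [2.0, 3.0, 4.0, 5.0]
keep = False  # this position should contribute nothing

mask_then_project = linear(apply_mask(embedding, keep), WEIGHT, BIAS)
project_then_mask = apply_mask(linear(embedding, WEIGHT, BIAS), keep)

print(mask_then_project)  # [0.5, -0.5]: the bias leaks into a masked position
print(project_then_mask)  # [0.0, 0.0]
```

Masking first feeds a zero vector into the layer, which returns its bias rather than zero; masking last guarantees the masked position stays zero.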