Add tutorial section: forward_backward_custom with a fixed target-token set at a single sequence position #725

Open
mvanhorn wants to merge 1 commit into thinking-machines-lab:main from mvanhorn:fix/703-tinker-cookbook-custom-loss-logit-set-targets
Conversation

@mvanhorn

Summary

Add a new marimo cell pair (an explanation cell and a code cell) to tutorials/202_loss_functions.py, slotted after the DPO-style example (after line ~295) and before the "When to use each loss" table cell.

  1. New explanation cell content:

    ## Custom loss: scoring a fixed target-token set at one sequence position
    
    The built-in losses (`cross_entropy`, `importance_sampling`, `ppo`, `cispo`) require `target_tokens` to be the same length as the input sequence, with one target per position. They are designed for next-token prediction across an entire sequence.
    
    If you instead want to score only a *small fixed vocabulary* (e.g. 21 emotion-word tokens) at *one specific output position*, the problem has a different shape, and `forward_backward_custom` is the tool for it. The logprobs tensor it gives you has shape `[T, vocab_size]` (full log-probabilities over the vocabulary at each input position), so you can:
    
    1. Index into the position you care about.
    2. Gather the log-probabilities at the target-token IDs you care about.
    3. Renormalize and compute whatever loss you want (KL, JSD, etc.) against your ground-truth distribution.
    
    The target-token IDs and target probabilities ride along in `Datum.loss_fn_inputs` for that datum. (Tinker's `forward_backward_custom` accepts arbitrary entries in `loss_fn_inputs` as long as the values are `TensorData` or JSON-serializable.)
  2. New code cell (illustrative; it needs no additional API key, since it can run on the same training_client already created in the tutorial. Pick a token position that exists in the existing sft_datum, or build a tiny new datum):

    def kl_over_target_set(data, logprobs_list):
        """KL divergence between the model's distribution over a fixed target
        token set at one position and a ground-truth distribution.
    
        Each datum carries (all as tensor-valued entries read with `.to_torch()`):
          - loss_fn_inputs["score_position"]: single-element int tensor, the sequence position to score at.
          - loss_fn_inputs["target_token_ids"]: 1D long tensor of K target token IDs.
          - loss_fn_inputs["target_probs"]: 1D float tensor of K ground-truth probs (summing to 1).
        """
        total_loss = torch.tensor(0.0)
        for datum, logprobs in zip(data, logprobs_list):
            pos = int(datum.loss_fn_inputs["score_position"].to_torch().item())
            target_ids = datum.loss_fn_inputs["target_token_ids"].to_torch().long()
            target_probs = datum.loss_fn_inputs["target_probs"].to_torch()
    
            # logprobs[pos] has shape [vocab_size]. Gather the K target slots,
            # then renormalize them so they form a distribution over the target set.
            target_logprobs = logprobs[pos][target_ids]               # [K]
            target_logprobs = target_logprobs - target_logprobs.logsumexp(0)
    
            # KL(target || model) over the K-slot distribution.
            kl = (target_probs * (target_probs.clamp_min(1e-12).log() - target_logprobs)).sum()
            total_loss = total_loss + kl
        return total_loss, {"kl_over_target_set": total_loss.item()}

    Plus a short comment line: "the shape mismatch you would see with cross_entropy (target_tokens.shape[0]=K, token_count=T) does not apply here - forward_backward_custom returns full per-position logprobs and lets you index where you want."

  3. Update the "When to use each loss" table cell to add one row for the new pattern, e.g.:

    | forward_backward_custom (logit-set at one position) | Few-class classification at a known output position; KL against a target distribution over a fixed token set | Sidesteps the "target_tokens must equal sequence length" constraint of built-in losses |

  4. Reply on issue #703 ("Questions regrading a more flexible custom loss function") with a one-paragraph comment once the PR is open, linking the PR and confirming this is the documented pattern. (No comment before the PR is ready.)

  5. PR title: docs(tutorials): add custom-loss example for scoring a fixed target token set at one position

    PR body (concise, factual):

    Addresses #703 ("Questions regrading a more flexible custom loss function").

    Adds a marimo cell to tutorials/202_loss_functions.py that shows how to use forward_backward_custom to score a fixed set of K target tokens at one specific output position - the use case in the issue, where a 21-token emotion vocabulary is scored against a ground-truth distribution with KL divergence.

    This is the pattern the built-in losses cannot express (they require target_tokens to match the sequence length), and the existing custom-loss examples in the tutorial only show full-sequence patterns (logprob-squared and DPO-style pairwise).

Why this matters

Sidong Zhang's issue describes a vision-language fine-tuning setup: at one specific output position, they want the model's distribution over a fixed 21-word emotion vocabulary, and want to train against a ground-truth distribution with KL divergence. They tried forward_backward with the built-in cross_entropy loss and a 21-element target_tokens, which Tinker correctly rejected because for the built-in losses target_tokens must be 1:1 with sequence positions:

input sequence and target_tokens at index 0 must agree on sequence length:
target_tokens.shape[0]=21 token_count=3103

The correct answer is forward_backward_custom: it returns full per-position log-probabilities (shape [T, vocab_size]), so the user can index into the position they care about, gather the 21 target-token log-probabilities, renormalize them over that set, and compute KL against the target distribution.
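The index, gather, renormalize, and KL steps can be sketched in plain PyTorch without touching Tinker. This is a minimal sketch that assumes, as described above, a per-datum tensor of shape [T, vocab_size]; the tensor here is synthetic, and the sizes, position, token IDs, and the helper name `kl_to_target` are hypothetical values chosen only for illustration.

```python
import torch

torch.manual_seed(0)

# Toy sizes (hypothetical). The issue reports token_count=3103 and K=21.
T, vocab_size, K = 8, 100, 21
pos = 5                                   # hypothetical position to score
target_ids = torch.arange(10, 10 + K)     # hypothetical target-token IDs

# Synthetic stand-in for the [T, vocab_size] log-probability tensor.
logprobs = torch.log_softmax(torch.randn(T, vocab_size), dim=-1)


def kl_to_target(logprobs, pos, target_ids, target_probs):
    """KL(target || model) over the K target tokens at one position."""
    picked = logprobs[pos][target_ids]            # index the row, gather K slots
    model_logq = picked - picked.logsumexp(0)     # renormalize over the K slots
    log_p = target_probs.clamp_min(1e-12).log()   # guard against log(0)
    return (target_probs * (log_p - model_logq)).sum(), model_logq


# A random ground-truth distribution over the K targets.
target_probs = torch.softmax(torch.randn(K), dim=0)
kl, model_logq = kl_to_target(logprobs, pos, target_ids, target_probs)

# Sanity check: scoring the model's own renormalized distribution gives ~0.
kl_self, _ = kl_to_target(logprobs, pos, target_ids, model_logq.exp())

print(f"KL vs. random target: {kl.item():.4f}")
print(f"KL vs. itself:        {kl_self.item():.2e}")
```

Subtracting `logsumexp` is what turns the K gathered log-probabilities into a proper distribution over the target set; without it the KL would also penalize probability mass the model places on tokens outside the set.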

This is exactly the use case forward_backward_custom exists for, but the existing tutorial (tutorials/202_loss_functions.py) only shows two custom-loss examples (logprob-squared and a DPO-style pairwise loss), neither of which demonstrates the "pick specific logits at a specific position" pattern. New users see "custom loss" but cannot tell that it supports single-position logit-set scoring.

Testing

  • uv run pytest tutorials/ passes, if there are tutorial tests. If there are none, smoke-test that the marimo notebook executes top to bottom with uv run marimo run tutorials/202_loss_functions.py --no-token and a dummy TINKER_API_KEY set. With a dummy key, verify only that the new cells render and parse, since the cell does not need the live model to type-check.
  • uv run ruff check tutorials/202_loss_functions.py and uv run ruff format --check tutorials/202_loss_functions.py pass.
  • uv run pyright tutorials/202_loss_functions.py passes; the new function's signature and annotations should match those of the existing logprob_squared_loss example.
  • Manual: open tutorials/202_loss_functions.py in marimo and confirm the new cell appears between the DPO-style example and the "When to use each loss" cell, and that the table has the new row.
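As a possible extra offline check, the loss function from the Summary can be exercised without a live model. The sketch below copies `kl_over_target_set` from the code cell and feeds it hypothetical stand-ins: `FakeTensorData` is not a Tinker type, it only mimics the `.to_torch()` method the loss relies on, and the datum is a plain `SimpleNamespace`. It assumes the `[T, vocab_size]` log-probability shape described in this PR and checks that gradients reach only the K target slots at the scored position.

```python
from types import SimpleNamespace

import torch


class FakeTensorData:
    """Hypothetical stand-in: exposes only the .to_torch() method used below."""

    def __init__(self, values, dtype):
        self._tensor = torch.tensor(values, dtype=dtype)

    def to_torch(self):
        return self._tensor


def kl_over_target_set(data, logprobs_list):
    # Copied from the code cell in the Summary.
    total_loss = torch.tensor(0.0)
    for datum, logprobs in zip(data, logprobs_list):
        pos = int(datum.loss_fn_inputs["score_position"].to_torch().item())
        target_ids = datum.loss_fn_inputs["target_token_ids"].to_torch().long()
        target_probs = datum.loss_fn_inputs["target_probs"].to_torch()
        target_logprobs = logprobs[pos][target_ids]
        target_logprobs = target_logprobs - target_logprobs.logsumexp(0)
        kl = (target_probs * (target_probs.clamp_min(1e-12).log() - target_logprobs)).sum()
        total_loss = total_loss + kl
    return total_loss, {"kl_over_target_set": total_loss.item()}


# Toy datum: score position 3 against three hypothetical token IDs.
datum = SimpleNamespace(
    loss_fn_inputs={
        "score_position": FakeTensorData(3, torch.long),
        "target_token_ids": FakeTensorData([4, 9, 17], torch.long),
        "target_probs": FakeTensorData([0.2, 0.5, 0.3], torch.float32),
    }
)

torch.manual_seed(0)
T, vocab_size = 6, 50
raw = torch.randn(T, vocab_size, requires_grad=True)
logprobs = torch.log_softmax(raw, dim=-1)   # synthetic [T, vocab_size] log-probs
logprobs.retain_grad()                      # keep the gradient on this non-leaf tensor

loss, metrics = kl_over_target_set([datum], [logprobs])
loss.backward()

# (row, column) pairs of log-prob entries that received a nonzero gradient.
touched = logprobs.grad.nonzero().tolist()
print(metrics, touched)
```

If the gradient touched any other row or column, the indexing in the loss would be scoring more than the intended position and token set.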

Fixes #703

AI was used for assistance.

Development

Successfully merging this pull request may close these issues.

Questions regrading a more flexible custom loss function