Add tutorial section: forward_backward_custom with a fixed target-token set at a single sequence position #725
Open
mvanhorn wants to merge 1 commit into
…en set at a single sequence position Fixes thinking-machines-lab#703
Summary

Add a new marimo cell pair (an explanation cell and a code cell) to `tutorials/202_loss_functions.py`, slotted after the DPO-style example (after line ~295) and before the "When to use each loss" table cell.

New explanation cell content:

New code cell (illustrative, not API-key-required - it can run on the same `training_client` already created in the tutorial; pick a token position that exists in the existing `sft_datum` or build a tiny new datum):

Plus a short comment line: "the shape mismatch you would see with `cross_entropy` (`target_tokens.shape[0]=K`, `token_count=T`) does not apply here - `forward_backward_custom` returns full per-position logprobs and lets you index where you want."

Update the "When to use each loss" table cell to add one row for the new pattern, e.g.:

| `forward_backward_custom` (logit-set at one position) | Few-class classification at a known output position; KL against a target distribution over a fixed token set | Sidesteps the "target_tokens must equal sequence length" constraint of built-in losses |

Reply on issue "Questions regrading a more flexible custom loss function" #703 with a one-paragraph comment when the PR is open, linking the PR and confirming this is the documented pattern. (No comment before the PR is ready.)

PR title:

docs(tutorials): add custom-loss example for scoring a fixed target token set at one position

PR body (concise, factual):
Why this matters

Sidong Zhang's issue describes a vision-language fine-tuning setup: at one specific output position, they want the model's distribution over a fixed 21-word emotion vocabulary, and they want to train against a ground-truth distribution with KL divergence. They tried `forward_backward` with the built-in `cross_entropy` loss and a 21-element `target_tokens`, which Tinker correctly rejected because for the built-in losses `target_tokens` must be 1:1 with sequence positions.

The correct answer is `forward_backward_custom`: it returns full per-position log-probabilities (shape `[T, vocab_size]`), and the user can index at the position they care about, gather the 21 target-token entries, renormalize, and compute KL against the target distribution.

This is exactly the use case `forward_backward_custom` exists for, but the existing tutorial (`tutorials/202_loss_functions.py`) only shows two custom-loss examples (logprob-squared and a DPO-style pairwise loss), neither of which demonstrates the "pick specific logits at a specific position" pattern. New users see "custom loss" but cannot tell that it supports single-position logit-set scoring.

Testing

- `uv run pytest tutorials/` (if there are tutorial tests) passes. If there are none, smoke-test that the marimo notebook executes top to bottom with `uv run marimo run tutorials/202_loss_functions.py --no-token` and a dummy `TINKER_API_KEY` set (verify only that the new cells render and parse, since the cell does not need the live model to type-check).
- `uv run ruff check tutorials/202_loss_functions.py` and `uv run ruff format --check tutorials/202_loss_functions.py` pass.
- `uv run pyright tutorials/202_loss_functions.py` passes - the new function annotations match the existing `logprob_squared_loss` example signature.
- Open `tutorials/202_loss_functions.py` in marimo and confirm the new cell appears between the DPO-style example and the "When to use each loss" cell, and that the table has the new row.

Fixes #703
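For reviewers, the arithmetic behind the pattern described under "Why this matters" (index one position, gather the target-token log-probabilities, renormalize, compute KL) can be sketched without any framework. This is a minimal illustration, not the tutorial cell itself: the helper names `log_softmax` and `kl_at_position` are hypothetical, the `[T, vocab_size]` layout is taken from the description above, and the KL direction (target relative to model) is an assumption. It makes no Tinker calls, and a real loss would run the same steps on tensors so that gradients can flow.

```python
import math


def log_softmax(values):
    """Numerically stable log-softmax over a plain list of floats."""
    m = max(values)
    lse = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - lse for v in values]


def kl_at_position(logprobs, position, target_token_ids, target_probs):
    """KL(target || model) over a fixed token set at one sequence position.

    logprobs: T rows, each holding vocab_size log-probabilities
              (layout assumed from the PR description).
    target_token_ids: the K vocabulary ids to score (hypothetical values below).
    target_probs: ground-truth distribution over those K tokens, summing to 1.
    """
    row = logprobs[position]                       # one position out of T
    picked = [row[t] for t in target_token_ids]    # gather the K entries
    model_logp = log_softmax(picked)               # renormalize over the K tokens
    # Terms with zero target probability contribute nothing to the sum.
    return sum(
        p * (math.log(p) - lq)
        for p, lq in zip(target_probs, model_logp)
        if p > 0.0
    )


# Toy data: T=2 positions, vocab_size=4, K=2 target tokens.
toy_logprobs = [
    log_softmax([0.0, 1.0, 2.0, 3.0]),
    log_softmax([1.0, 0.0, 0.0, 2.0]),
]
token_ids = [0, 3]
kl_uniform = kl_at_position(toy_logprobs, 1, token_ids, [0.5, 0.5])

# A target equal to the model's renormalized distribution gives zero KL.
matched = [math.exp(v) for v in log_softmax([toy_logprobs[1][t] for t in token_ids])]
kl_matched = kl_at_position(toy_logprobs, 1, token_ids, matched)
```

Renormalizing over only the K gathered entries means the loss ignores probability mass the model places on tokens outside the fixed set; whether that is desirable depends on the task.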
AI was used for assistance.