
Setup tokenized pile dataset#371

Merged
danbraunai-goodfire merged 5 commits into main from feature/custom-pile
Feb 9, 2026

Conversation

@danbraunai-goodfire (Collaborator) commented Feb 9, 2026

Description

Created tokenized pile. See here (and here). NOTE: I want to update these scripts at some point to not import spd.

Use these tokenized datasets in pretraining and in the spd pile config.

Process:

  • Load in n_ctx=513 for pretraining. With this we do [:-1] for the batch and [1:] for the labels
  • FOR NOW: For spd, just continue using n_ctx=512, and don't change how the SPD code handles it. Currently the spd code calculates ce losses using the fancy ignore_index=-100 thing. Once we sort out #363 (Make ComponentModel generic over batch and output types), I think we should make our dataloaders return batch and labels, and then we can set n_ctx for the dataset to 513 just like we do for pretrain.
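The n_ctx=513 slicing described in the first bullet can be sketched in plain Python (toy token IDs, names illustrative, not code from this PR):

```python
# One tokenized row loaded with n_ctx=513 (toy token IDs; here just 0..512).
seq = list(range(513))

inputs = seq[:-1]   # first 512 tokens: what the model reads
labels = seq[1:]    # 512 targets, shifted one position to the right

assert len(inputs) == len(labels) == 512
# labels[j] is the token that follows inputs[j] in the original row
# (token IDs are consecutive here, so that is simply inputs[j] + 1)
assert all(labels[j] == inputs[j] + 1 for j in range(512))
```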

Related Issue

N/A

Motivation and Context

If using the regular monology/pile-uncopyrighted, it's difficult to split into train/val in a way that's reproducible and that allows for streaming=True (which non-cluster users are going to need).

How Has This Been Tested?

No new tests.

Does this PR introduce a breaking change?

No

@ocg-goodfire (Collaborator)

@claude no full review pls but what do you think of the 513-length pattern? my intuition is that it's more typical to just serve 512 batches and use ignore indices for the offset problem. pls reference the torch CE loss functionality too (how it handles offsets)


claude bot commented Feb 9, 2026

Claude finished @ocg-goodfire's task.


Analyzing 513-length pattern vs typical approaches

  • Read changed files to understand the 513-length implementation
  • Analyze PyTorch CrossEntropyLoss ignore_index functionality
  • Compare 513-length approach vs typical 512 + ignore indices
  • Provide technical feedback on trade-offs

Analysis

You're absolutely right that the typical pattern is to use 512-length sequences with ignore indices. Here's my analysis of the trade-offs:

Current 513-length approach:

  • Pretraining: loads n_ctx=513, uses bat[:, :-1] for input (512 tokens), bat[:, 1:] for targets (512 tokens)
  • SPD: loads max_seq_len=512 directly since no autoregressive training needed
  • See spd/pretrain/train.py:452-453; the config sets n_ctx: 513

Typical 512 + ignore_index approach:

  • Load sequences of exactly 512 tokens
  • Use ignore_index=-100 in CrossEntropyLoss for the offset problem
  • PyTorch CE loss skips targets equal to ignore_index; those positions contribute nothing to the loss value or its gradient

Technical considerations:

  1. Memory efficiency: 513-length approach uses slightly more memory (0.2% increase) but avoids ignore index masking overhead

  2. Implementation clarity: Current approach is cleaner - the slicing [:-1] and [1:] is explicit and doesn't require ignore index handling

  3. PyTorch CE loss with ignore_index: torch.nn.CrossEntropyLoss(ignore_index=-100) would handle the offset by:

    • Computing loss only on non-ignored positions
    • Shifting targets left by one and setting the last target to -100 (its next token falls outside the 512-token window)
    • Mathematically equivalent but adds masking complexity
  4. Data pipeline consistency: Since you're creating a custom tokenized dataset anyway, the 513 approach gives you exact control over sequence boundaries
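For reference, the ignore_index semantics can be sketched in pure Python (a simplified stand-in for torch.nn.functional.cross_entropy, not the repo's code; logit values are made up):

```python
import math

def ce_loss(logits, targets, ignore_index=-100):
    # Mean cross-entropy over positions whose target != ignore_index,
    # mirroring the masking behavior of torch's cross_entropy.
    losses = []
    for row, t in zip(logits, targets):
        if t == ignore_index:
            continue  # ignored position: no loss contribution at all
        log_z = math.log(sum(math.exp(v) for v in row))
        losses.append(log_z - row[t])  # -log softmax(row)[t]
    return sum(losses) / len(losses)

logits = [[2.0, 0.5], [0.1, 1.5], [1.0, 1.0]]
targets = [-100, 1, 0]                       # first position masked out
masked = ce_loss(logits, targets)
unmasked = ce_loss(logits[1:], targets[1:])  # same positions, no masking
assert abs(masked - unmasked) < 1e-12        # ignored rows change nothing
```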

Recommendation: The 513-length pattern is actually quite reasonable here because:

  • Memory overhead is negligible (1 extra token per sequence)
  • Avoids ignore index complexity
  • Makes the input/target relationship explicit
  • Matches your custom dataset creation workflow

The typical 512+ignore approach would make sense if you were using standard datasets, but since you're tokenizing specifically for this use case, the 513 approach is cleaner.

@ocg-goodfire (Collaborator)

@claude ok great. just are you sure this'll work with torch CE? I thought the torch CE loss had some weird gross thing where it implicitly handles inputs being non-correctly offset. Basically I thought it worked when x = [logits(a), logits(b), logits(c)] and y = [a, b, c], which would make this case fail, right?


claude bot commented Feb 9, 2026

Claude finished @ocg-goodfire's task.


Analyzing PyTorch CrossEntropyLoss behavior with 513-length pattern

  • Read current training implementation to see how CE loss is used
  • Analyze PyTorch CE loss input/target alignment requirements
  • Determine if 513-length approach has alignment issues
  • Provide corrected analysis

You're absolutely correct to question this! Let me clarify the PyTorch CrossEntropyLoss behavior:

The key insight: PyTorch CE loss does NOT have any built-in autoregressive offset handling. It simply compares input[i] with target[i] position-wise.

Current 513-length implementation:

# Training (train.py:452-456)
x = bat[:, :-1]  # [batch, 512] - tokens 0-511 (input)
y = bat[:, 1:]   # [batch, 512] - tokens 1-512 (targets) 
logits, loss = model(x, y)

# Model forward (llama_simple_mlp.py)
logits = self.lm_head(x)  # [batch, 512, vocab] - predictions for each position
loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))

The alignment IS correct:

  • logits[i] = prediction after processing input tokens 0..i → predicts token i+1
  • targets[i] = actual token i+1 from original sequence
  • CE compares: prediction for token i+1 vs actual token i+1 ✓

Your concern about x = [logits(a), logits(b), logits(c)], y = [a, b, c] failing:
That would indeed fail, but that's not what's happening here. The correct pattern is:

  • x = [logits_after_a, logits_after_ab, logits_after_abc] (predictions for next token)
  • y = [b, c, d] (actual next tokens)

So the 513-length approach works perfectly - it's the clean way to get proper autoregressive alignment without needing ignore_index masking.
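A toy version of the alignment in question, with letters standing in for token IDs (illustrative, not code from the repo):

```python
# Sequence of 4 tokens, mirroring the 513 -> 512 pattern at small scale.
seq = ["a", "b", "c", "d"]

inputs = seq[:-1]   # ["a", "b", "c"]: what the model reads
labels = seq[1:]    # ["b", "c", "d"]: what each position should predict

# CE compares position i of the logits with position i of the targets, so
# the prediction made after reading inputs[:i + 1] is scored against
# labels[i], the true next token.
for i in range(len(inputs)):
    assert labels[i] == seq[i + 1]

# The failing pairing from the question would instead score logits(a)
# against "a" itself, i.e. identity targets rather than next-token targets:
bad_labels = seq[:-1]
assert bad_labels != labels
```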

@danbraunai-goodfire (Collaborator, Author)

Note that even if we did the 512 length with ignore_index for ce losses, we'd have to manually handle this for regular SPD things when we load in a 512 context and e.g. calculate the L0. We don't want to calculate on the final token because that token wasn't trained with a label (or at least, we'd get slightly weird results there). Same for whatever analysis we want to do on the full 512 context.
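The point about the final position can be sketched as follows (the per-token L0 values and the masking step are purely illustrative, not code from this PR):

```python
# Hypothetical per-position statistic over a short context (toy values).
per_token_l0 = [3.0, 5.0, 4.0, 2.0]

# When a 512-token context is loaded directly, the final position's output
# was never trained against a label (its target would be token 513), so
# aggregate statistics can drop that position explicitly.
trained = per_token_l0[:-1]
avg_l0 = sum(trained) / len(trained)
assert avg_l0 == 4.0
```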

@ocg-goodfire (Collaborator)

@claude gotcha, I think I'm getting confused with hf transformers trainers, which I believe do this. Anyway, thanks.

@danbraunai-goodfire I'm sceptical that the models would generalise to 1-token-too-long sequences, but definitely not willing to bet on that intuition so fair point that this is safer


claude bot commented Feb 9, 2026

Claude finished @ocg-goodfire's task.


Analysis Complete

You're exactly right about the HF transformers confusion! HuggingFace transformers do indeed handle the autoregressive offset automatically in their training loops, which can make the manual slicing approach here seem unusual at first glance.

Key clarifications from this discussion:

  1. HF vs Manual Implementation: HuggingFace trainers abstract away the [:-1] and [1:] slicing, but under the hood they're doing the same alignment. The manual approach here gives explicit control.

  2. Why 513-length is safer for SPD: As @danbraunai-goodfire noted, even if we used 512 + ignore_index for training, SPD analysis (L0 calculations, component importance, etc.) would still need to handle the final token specially since it wasn't trained with a target. The 513 → 512 approach makes this boundary explicit.

  3. Model generalization concern: @ocg-goodfire's point about models potentially not generalizing to "1-token-too-long sequences" is valid, but since the model only ever sees 512 tokens during training (after the [:-1] slice), there's no generalization issue.

The implementation looks solid - clean autoregressive alignment without the complexity of ignore masking, and explicit handling of sequence boundaries for both training and analysis phases.


Comment on lines +230 to +233
if dataset_config.n_ctx < tokenized_len:
col = dataset_config.column_name
n_ctx = dataset_config.n_ctx
torch_dataset = dataset.map(lambda x: {col: x[col][:n_ctx]}).with_format("torch")
Collaborator


we're just throwing out a bunch of data in this case, right? seems like instead maybe we should just assert dataset_config.n_ctx == tokenized_len. Do we imagine ever not doing that? and if we do then cropping examples doesn't seem the best way, or at least a potential footgun

Collaborator Author


Yeah in SPD we use max_seq_len=512, so we just throw out one token. We could show that it's within 1 of the tokenized_len. But it also might be useful for random sampling of the tokenized dataset.

@danbraunai-goodfire (Collaborator, Author)

@ocg-goodfire

@danbraunai-goodfire I'm sceptical that the models would generalise to 1-token-too-long sequences, but definitely not willing to bet on that intuition so fair point that this is safer

The changes here don't rely on the model generalising. Maybe you mean "that the models wouldn't generalise"? In which case, yeah I feel like they probably would, but I don't want to take a risk.

@danbraunai-goodfire danbraunai-goodfire merged commit b71c7b2 into main Feb 9, 2026
2 checks passed
