Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions spd/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@


class DatasetConfig(BaseConfig):
name: str = "lennart-finke/SimpleStories"
is_tokenized: bool = True
hf_tokenizer_path: str | None = None
streaming: bool = False
split: str = "train"
n_ctx: int = 1024
name: str
is_tokenized: bool
hf_tokenizer_path: str | None
streaming: bool
split: str
n_ctx: int
"""Must be model n_ctx + 1 to provide room for next-token label indexing."""
seed: int | None = None
column_name: str = "input_ids"
column_name: str
"""The name of the column in the dataset that contains the data (tokenized or non-tokenized).
Typically 'input_ids' for datasets stored with e2e_sae/scripts/upload_hf_dataset.py, or "tokens"
for datasets tokenized in TransformerLens (e.g. NeelNanda/pile-10k)."""
Expand Down Expand Up @@ -223,9 +224,14 @@ def create_data_loader(
assert isinstance(sample, Tensor) and sample.ndim == 1, (
f"Expected the dataset to be tokenized. Got type {type(sample)}"
)
assert len(sample) == dataset_config.n_ctx, (
f"n_ctx ({dataset_config.n_ctx}) does not match the tokenized length ({len(sample)})."
tokenized_len = len(sample)
assert dataset_config.n_ctx <= tokenized_len, (
f"n_ctx ({dataset_config.n_ctx}) is larger than the tokenized length ({tokenized_len})."
)
if dataset_config.n_ctx < tokenized_len:
col = dataset_config.column_name
n_ctx = dataset_config.n_ctx
torch_dataset = dataset.map(lambda x: {col: x[col][:n_ctx]}).with_format("torch")
Comment on lines +231 to +234
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're just throwing out a bunch of data in this case, right? It seems like we should instead just assert `dataset_config.n_ctx == tokenized_len`. Do we imagine ever not doing that? If we do, cropping examples doesn't seem like the best way, or is at least a potential footgun.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, in SPD we use max_seq_len=512, so we just throw out one token. We could assert that n_ctx is within 1 of tokenized_len. But allowing a smaller n_ctx might also be useful for random sampling of the tokenized dataset.

else:
to_lower = "SimpleStories" in dataset_config.name
torch_dataset = tokenize_and_concatenate(
Expand Down
16 changes: 8 additions & 8 deletions spd/experiments/lm/pile_llama_simple_mlp-2L.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
pretrained_model_path: null
pretrained_model_name: goodfire/spd/runs/ivdw6l06
pretrained_model_output_attr: idx_0
tokenizer_name: gpt2
tokenizer_name: EleutherAI/gpt-neox-20b
task_config:
task_name: lm
max_seq_len: 512
max_seq_len: 512 # Temporary. Later we will do n_ctx=513 for the dataset and streaming=false
buffer_size: 1000
dataset_name: monology/pile-uncopyrighted
column_name: text
train_data_split: train[:10000000]
eval_data_split: train[-100000:]
dataset_name: danbraunai/pile-uncopyrighted-tok
column_name: input_ids
train_data_split: train
eval_data_split: val
shuffle_each_epoch: true
is_tokenized: false
streaming: false
is_tokenized: true
streaming: true # Temporary. Later we will do n_ctx=513 for the dataset and streaming=false
7 changes: 7 additions & 0 deletions spd/pretrain/CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ spd-pretrain --config_path ... --n_gpus 4
- **SimpleStories**: `SimpleStories/test-SimpleStories-gpt2-1.25M` (vocab size: 4019)
- **OpenWebText**: `gpt2` (vocab size: 50257)
- **Pile**: `EleutherAI/gpt-neox-20b` (vocab size: 50277)

## Dataset n_ctx vs Model n_ctx

The dataset `n_ctx` must be **model n_ctx + 1**. During training, sequences are split into
input `[:, :-1]` and target `[:, 1:]` for next-token prediction, so the extra token provides
room for label indexing. For example, if the model has `n_ctx: 512`, the dataset should have
`n_ctx: 513`.

## Key Files

- `train.py` - Main training loop with DDP support
Expand Down
4 changes: 2 additions & 2 deletions spd/pretrain/configs/gpt2_simple-2L.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ train_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model block_size + 1 for next-token label indexing
seed: 0
column_name: story

Expand All @@ -37,6 +37,6 @@ val_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: test
streaming: false
n_ctx: 512
n_ctx: 513 # model block_size + 1 for next-token label indexing
seed: 0
column_name: story
4 changes: 2 additions & 2 deletions spd/pretrain/configs/owt_llama_simple_mlp-12L-768.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ train_dataset_config:
hf_tokenizer_path: gpt2
split: train[:10000000]
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text

Expand All @@ -41,6 +41,6 @@ val_dataset_config:
hf_tokenizer_path: gpt2
split: train[-100000:]
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
21 changes: 11 additions & 10 deletions spd/pretrain/configs/pile_llama_simple_mlp-12L-768.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,22 @@ model:
flash_attention: false

train_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[:100000] # Dataset has 177M examples
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

val_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[-1000000:]
split: val
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

46 changes: 46 additions & 0 deletions spd/pretrain/configs/pile_llama_simple_mlp-1L-128.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Pretraining config: small LlamaSimpleMLP (n_embd=128) on the pre-tokenized
# Pile dataset (danbraunai/pile-uncopyrighted-tok, EleutherAI/gpt-neox-20b tokenizer).
wandb_project: spd
dtype: bfloat16
batch_size: 1024
num_iterations: 100_000
warmup_iters: 600
learning_rate: 3e-4
learning_rate_decay_frac: 0.1
weight_decay: 0.1
grad_clip: 1.0
val_loss_every: 1000
val_max_steps: 20
sample_every: 1000
intermediate_checkpoints: false

model:
  model_type: LlamaSimpleMLP
  block_size: 512
  # Must cover every token id emitted by the dataset's tokenizer
  # (EleutherAI/gpt-neox-20b); matches the other Pile configs. The previous
  # value, 4019, is the SimpleStories tokenizer vocabulary.
  vocab_size: 50277
  # NOTE(review): the filename says "1L" but this builds 2 layers -- confirm
  # which one is intended and rename the file or change this value.
  n_layer: 2
  n_head: 4
  n_embd: 128
  n_intermediate: 512 # 128 * 4
  rotary_dim: 32 # 128 // 4
  n_ctx: 512
  n_key_value_heads: 2
  flash_attention: false

train_dataset_config:
  name: danbraunai/pile-uncopyrighted-tok
  is_tokenized: true
  hf_tokenizer_path: EleutherAI/gpt-neox-20b
  split: train
  streaming: false
  n_ctx: 513 # model n_ctx + 1 for next-token label indexing
  seed: 0
  column_name: input_ids

val_dataset_config:
  name: danbraunai/pile-uncopyrighted-tok
  is_tokenized: true
  hf_tokenizer_path: EleutherAI/gpt-neox-20b
  split: val
  streaming: false
  n_ctx: 513 # model n_ctx + 1 for next-token label indexing
  seed: 0
  column_name: input_ids
20 changes: 10 additions & 10 deletions spd/pretrain/configs/pile_llama_simple_mlp-2L-2048.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ model:
vocab_size: 50277

train_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[:100000] # Dataset has 177M examples
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

val_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[-100000:]
split: val
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids
20 changes: 10 additions & 10 deletions spd/pretrain/configs/pile_llama_simple_mlp-2L-768.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ model:
vocab_size: 50277

train_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[:-100000] # Dataset has 177M examples
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

val_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[-100000:]
split: val
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids
21 changes: 11 additions & 10 deletions spd/pretrain/configs/pile_llama_simple_mlp-4L-768.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,22 @@ model:
vocab_size: 50277

train_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[:100000] # Dataset has 177M examples
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

val_dataset_config:
name: monology/pile-uncopyrighted
is_tokenized: false
name: danbraunai/pile-uncopyrighted-tok
is_tokenized: true
hf_tokenizer_path: EleutherAI/gpt-neox-20b
split: train[-100000:]
split: val
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: text
column_name: input_ids

4 changes: 2 additions & 2 deletions spd/pretrain/configs/ss_llama_simple_mlp-2L-128.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ train_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: story

Expand All @@ -41,6 +41,6 @@ val_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: test
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: story
4 changes: 2 additions & 2 deletions spd/pretrain/configs/ss_llama_simple_mlp-4L-192.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ train_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: train
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: story

Expand All @@ -41,6 +41,6 @@ val_dataset_config:
hf_tokenizer_path: SimpleStories/test-SimpleStories-gpt2-1.25M
split: test
streaming: false
n_ctx: 512
n_ctx: 513 # model n_ctx + 1 for next-token label indexing
seed: 0
column_name: story
14 changes: 3 additions & 11 deletions spd/pretrain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,7 @@ def main(config_path_or_obj: Path | str | Config | None = None) -> None:
load_dotenv(override=True)
config = load_config(config_path_or_obj, config_model=Config)

T = config.train_dataset_config.n_ctx # Training sequence length (positions to train on)

# Load n_ctx+1 tokens so we can train on n_ctx positions (need extra token for labels)
train_dataset_config = config.train_dataset_config.model_copy(
update={"n_ctx": config.train_dataset_config.n_ctx + 1}
)
val_dataset_config = config.val_dataset_config.model_copy(
update={"n_ctx": config.val_dataset_config.n_ctx + 1}
)
T = config.train_dataset_config.n_ctx - 1 # Training sequence length (positions to train on)

# set up DDP (distributed data parallel). torchrun sets this env variable
ddp = int(os.environ.get("RANK", -1)) != -1
Expand Down Expand Up @@ -304,7 +296,7 @@ def main(config_path_or_obj: Path | str | Config | None = None) -> None:
model = cast(nn.Module, torch.compile(model)) # type: ignore[reportArgumentType]

train_loader, train_tokenizer = create_data_loader(
dataset_config=train_dataset_config,
dataset_config=config.train_dataset_config,
batch_size=B,
buffer_size=1000,
global_seed=0,
Expand All @@ -313,7 +305,7 @@ def main(config_path_or_obj: Path | str | Config | None = None) -> None:
train_iter = iter(train_loader)

val_loader, _ = create_data_loader(
dataset_config=val_dataset_config,
dataset_config=config.val_dataset_config,
batch_size=B,
buffer_size=1000,
global_seed=0,
Expand Down