93 changes: 93 additions & 0 deletions records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/README.md
@@ -0,0 +1,93 @@
# Record: 11L Tight SWA + VE128 + XSA4 + TTT (val_bpb: 1.1299)

**NEW SOTA** — beats the previous record of 1.1428 by 0.0129 bits/byte (3-seed mean)

## Results

| Seed | Steps | Post-SWA BPB | Quant BPB | Sliding Window BPB | Artifact Size (bytes) |
|------|-------|-------------|-----------|--------------------|----------------------:|
| 1337 | 5880 | 1.1462 | 1.1529 | **1.1291** | 15,787,610 |
| 7 | 5850 | 1.1478 | 1.1545 | **1.1309** | 15,659,426 |
| 99 | 6024 | 1.1465 | 1.1533 | **1.1296** | 15,688,657 |
| **Mean** | 5918 | 1.1468 | 1.1536 | **1.1299** | 15,711,898 |

All 3 seeds beat the previous SOTA (1.1428) by at least 0.0119 bits/byte. All artifacts < 16MB.

## Architecture

- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA)
- 3x MLP expansion with relu-squared activation
- Efficient Partial XSA on last 4 layers (GQA-aware, zero-alloc)
- Partial RoPE (16/64 dims) + NTK-aware scaling (sketched after this list)
- LN Scale Factor 1/sqrt(layer_idx+1)
- U-Net skip connections (5 encoder, 6 decoder)
- SmearGate + BigramHash (2048 buckets, dim=128)
- Shared Value Embedding (dim=128, layers 9,10) — 1 table, per-layer learned scales
- FlashAttention 3 (Hopper) with SDPA fallback
- Orthogonal init with proj scaling by 1/sqrt(2*num_layers)
- Logit softcap 30.0, tied embeddings
- 27M parameters
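
Since partial RoPE is the densest item above, here is a minimal sketch of rotating only the first 16 of 64 head dims with NTK-aware base scaling. The function name, the `ntk_alpha` parameter, and the exact scaling exponent are illustrative assumptions, not lifted from the record's `train_gpt.py`:

```python
import torch

def partial_rope(x: torch.Tensor, base: float = 10000.0,
                 rope_dims: int = 16, ntk_alpha: float = 1.0) -> torch.Tensor:
    """Rotate only the first `rope_dims` of each head dim; pass the rest through.

    x: (batch, heads, seq, head_dim). NTK-aware scaling stretches the
    frequency base so longer contexts reuse the trained frequency range.
    """
    B, H, T, D = x.shape
    half = rope_dims // 2
    # NTK-aware base stretch (common formulation; exponent is an assumption)
    base = base * ntk_alpha ** (rope_dims / (rope_dims - 2))
    inv_freq = 1.0 / base ** (torch.arange(half, device=x.device) / half)
    angles = torch.outer(torch.arange(T, device=x.device), inv_freq)  # (T, half)
    cos, sin = angles.cos(), angles.sin()
    x_rot, x_pass = x[..., :rope_dims], x[..., rope_dims:]
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```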

## Key Techniques

### Tight SWA
SWA checkpoint collection is restricted to LR scale < 0.2 (roughly the last 600 steps): one checkpoint every 50 steps, for a 12-checkpoint average. This eliminates the SWA quality penalty while keeping the quantization-friendly effect of weight averaging.
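
A minimal sketch of that collection rule, assuming the training loop exposes the current LR scale each step; the class name and hooks are illustrative, and a float-only state dict is assumed:

```python
from torch import nn

class TightSWA:
    """Average checkpoints collected only once the LR scale drops below a gate.

    With scale<0.2 reached ~600 steps before the wallclock cap and every=50,
    this accumulates the 12 checkpoints described above.
    """
    def __init__(self, scale_threshold: float = 0.2, every: int = 50):
        self.scale_threshold, self.every = scale_threshold, every
        self.sum_state, self.count = None, 0

    def maybe_collect(self, model: nn.Module, step: int, lr_scale: float):
        if lr_scale >= self.scale_threshold or step % self.every != 0:
            return
        state = {k: v.detach().float().clone()
                 for k, v in model.state_dict().items()}
        if self.sum_state is None:
            self.sum_state = state
        else:
            for k in self.sum_state:
                self.sum_state[k] += state[k]
        self.count += 1

    def apply(self, model: nn.Module):
        assert self.count > 0, "no checkpoints collected"
        # load_state_dict casts back to each parameter's dtype on copy
        model.load_state_dict({k: v / self.count
                               for k, v in self.sum_state.items()})
```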

### Test-Time Training (TTT)
Three epochs of continued training on already-evaluated validation tokens (SGD with momentum 0.9, lr=0.002, batches of 32 sequences), with the first 2 transformer blocks frozen. Runs after quantization and before the sliding-window eval, adding ~51s of eval time.
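
A sketch of the TTT loop under those hyperparameters; `model.blocks` and a forward that returns the mean cross-entropy loss are assumptions about the model's interface:

```python
import torch
from torch import nn

def test_time_train(model: nn.Module, val_batches: list,
                    epochs: int = 3, lr: float = 2e-3, freeze_blocks: int = 2):
    """Continue training on validation tokens that have already been scored.

    `val_batches` is a list of (inputs, targets) pairs (32 sequences per
    batch here), built only from already-evaluated tokens so the final
    sliding-window score sees no unfair information.
    """
    for block in model.blocks[:freeze_blocks]:   # freeze the first 2 blocks
        for p in block.parameters():
            p.requires_grad_(False)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.SGD(params, lr=lr, momentum=0.9)
    model.train()
    for _ in range(epochs):
        for inputs, targets in val_batches:
            loss = model(inputs, targets)        # assumed: returns mean CE loss
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
    model.eval()
```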

### Late QAT
STE int6 fake-quantization enabled when LR scale < 0.1 (during warmdown), teaching the model to be robust to quantization noise before SWA collection begins.
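
A minimal sketch of the straight-through int6 fake-quant step, applied to matrix weights once the LR scale falls below 0.1; the symmetric [-31, 31] range is an assumption about the int6 convention:

```python
import torch

def ste_int6_fake_quant(w: torch.Tensor) -> torch.Tensor:
    """Per-row symmetric int6 fake quantization with a straight-through estimator.

    The forward pass sees the quantize->dequantize round trip; the backward
    pass treats rounding as identity, so gradients reach the fp weights.
    """
    qmax = 31  # symmetric 6-bit range assumption: [-31, 31]
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax, qmax) * scale
    return w + (w_q - w).detach()  # STE: forward w_q, backward d(w)/dw = 1
```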

### Sliding Window Evaluation
Overlapping windows at stride=64 (context=2048), significantly improving BPB vs single-pass evaluation. ~100s eval time.
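
A sketch of the scoring loop: tokens are scored in stride-sized chunks, each conditioned on up to ctx-1 preceding tokens, which is why it beats a single non-overlapping pass. The logits-returning forward signature is an assumption, and converting the returned nats/token to BPB additionally needs the tokens-to-bytes ratio of the validation set:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, tokens: torch.Tensor,
                       ctx: int = 2048, stride: int = 64) -> float:
    """Mean per-token NLL where each token is scored with near-full context."""
    model.eval()
    N = tokens.numel()
    total_nll, total_count = 0.0, 0
    for chunk_start in range(1, N, stride):
        chunk_end = min(chunk_start + stride, N)
        ctx_start = max(0, chunk_end - ctx)
        window = tokens[ctx_start:chunk_end].unsqueeze(0)  # (1, <=ctx)
        logits = model(window)                             # (1, T, vocab), assumed
        n_new = chunk_end - chunk_start
        preds = logits[0, -(n_new + 1):-1]                 # predict last n_new tokens
        tgt = tokens[chunk_start:chunk_end]
        total_nll += F.cross_entropy(preds, tgt, reduction="sum").item()
        total_count += n_new
    return total_nll / total_count  # nats/token; BPB needs tokens-per-byte too
```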

## Training

- **Optimizer**: Muon (matrices, lr=0.025, momentum=0.99, warmup 0.92→0.99 over 1500 steps) + AdamW (embeddings lr=0.035, scalars lr=0.025); parameter grouping sketched after this list
- **Weight Decay**: 0.04 (both Muon and Adam)
- **Gradient Clip**: 0.3
- **Batch**: 786,432 tokens/step, seq_len=2048
- **Warmdown**: 3000 iters (wallclock-based, ~600s cap)
- **Tight SWA**: every 50 steps when scale<0.2 (12 checkpoints)
- **Late QAT**: STE int6 when LR scale<0.1
- ~5900 steps in 600s at ~101ms/step
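
A sketch of the parameter split and the Muon momentum warmup. Muon is not a torch built-in, so the import is hypothetical (the record's `train_gpt.py` is assumed to carry its own implementation), and the name-based embedding test is illustrative:

```python
import torch
from muon import Muon  # hypothetical import; not part of torch

def build_optimizers(model: torch.nn.Module):
    """Split params: 2D+ matrices -> Muon; embeddings and scalars -> AdamW."""
    matrix_params, embed_params, scalar_params = [], [], []
    for name, p in model.named_parameters():
        if "embed" in name:                 # illustrative name-based test
            embed_params.append(p)
        elif p.ndim >= 2:
            matrix_params.append(p)
        else:
            scalar_params.append(p)
    muon = Muon(matrix_params, lr=0.025, momentum=0.92, weight_decay=0.04)
    adamw = torch.optim.AdamW(
        [{"params": embed_params, "lr": 0.035},
         {"params": scalar_params, "lr": 0.025}],
        weight_decay=0.04)
    return muon, adamw

def muon_momentum(step: int, warmup_steps: int = 1500,
                  start: float = 0.92, end: float = 0.99) -> float:
    """Linear momentum warmup 0.92 -> 0.99 over the first 1500 steps."""
    frac = min(step / warmup_steps, 1.0)
    return start + (end - start) * frac
```

Each step, the scheduled value would be written back into `muon.param_groups[0]["momentum"]` before `step()`.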

## Quantization

- Int6 per-row for MLP + attention weights (sketched after this list)
- Int8 per-row for embeddings
- Control tensors in fp32
- zstd level 22 compression
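
A sketch of the per-row int6 quantization and the zstd stage. Bit-packing four 6-bit codes into three bytes is omitted here (the int8 container plus zstd level 22 recovers much of that redundancy); the `zstandard` package API is real, everything else is illustrative:

```python
import torch
import zstandard  # pip install zstandard

def quantize_int6_per_row(w: torch.Tensor):
    """Symmetric per-row int6: one fp scale per output row, codes in [-31, 31]."""
    qmax = 31
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    codes = (w / scale).round().clamp(-qmax, qmax).to(torch.int8)
    return codes, scale.squeeze(1)   # dequant: codes.float() * scale[:, None]

def compress_codes(codes: torch.Tensor) -> bytes:
    """zstd level 22 over the raw code bytes."""
    cctx = zstandard.ZstdCompressor(level=22)
    return cctx.compress(codes.cpu().numpy().tobytes())
```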

## Evaluation Pipeline

1. Train for 600s (wallclock cap)
2. Apply Tight SWA (12 checkpoint average)
3. Serialize + int6/zstd compress (verify artifact < 16MB)
4. TTT: 3 epochs on already-evaluated val tokens (~51s)
5. Sliding window eval at stride=64 (~100s)
6. Total eval time: ~155s (well under 10min limit)

## Reproduction

```bash
# 8xH100 (default config, all hyperparameters are baked in)
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
torchrun --standalone --nproc_per_node=8 \
records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/train_gpt.py

# To reproduce specific seeds:
SEED=1337 DATA_PATH=... TOKENIZER_PATH=... torchrun --standalone --nproc_per_node=8 train_gpt.py
SEED=7 DATA_PATH=... TOKENIZER_PATH=... torchrun --standalone --nproc_per_node=8 train_gpt.py
SEED=99 DATA_PATH=... TOKENIZER_PATH=... torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Hardware

8xH100 SXM (RunPod), PyTorch 2.10, CUDA 12.x

## Acknowledgments

Built on [PR #374](https://github.com/openai/parameter-golf/pull/374) by [@unnir](https://github.com/unnir) (v38: Tight SWA + VE128 + XSA4, val_bpb=1.1246). Added test-time training (TTT) and SDPA fallback.
@@ -0,0 +1,11 @@
{
  "author": "kasimte",
  "github_id": "kasimte",
  "model_name": "TightSWA_VE128_TTT",
  "description": "Record: 11L + Tight SWA + Shared VE128 + XSA4 + Partial RoPE + LN Scale + Late QAT + TTT (3-seed mean val_bpb=1.1299)",
  "val_loss": 1.90649894,
  "val_bpb": 1.12913905,
  "bytes_total": 15787610,
  "track": "10min_16mb",
  "seed": 1337
}
260 changes: 260 additions & 0 deletions records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/train.log
@@ -0,0 +1,260 @@
=== Seed 1337 ===
root@6da42ce110bf:/workspace/parameter-golf# RUN_ID=pathb_s1337 SEED=1337 \
> DATA_PATH=./data/datasets/fineweb10B_sp1024 \
> TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
> torchrun --standalone --nproc_per_node=8 \
> records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/train_gpt.py
W0322 11:56:44.641000 979 torch/distributed/run.py:803]
W0322 11:56:44.641000 979 torch/distributed/run.py:803] *****************************************
W0322 11:56:44.641000 979 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0322 11:56:44.641000 979 torch/distributed/run.py:803] *****************************************
logs/pathb_s1337.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_4 active_layers:[7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9279 val_bpb:4.1031 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9299 train_time:143ms step_avg:142.79ms
step:2/20000 train_loss:8.5641 train_time:225ms step_avg:112.44ms
step:3/20000 train_loss:7.8180 train_time:321ms step_avg:106.84ms
step:4/20000 train_loss:7.2357 train_time:417ms step_avg:104.32ms
step:5/20000 train_loss:7.0661 train_time:514ms step_avg:102.72ms
step:6/20000 train_loss:6.8295 train_time:610ms step_avg:101.65ms
step:7/20000 train_loss:6.7202 train_time:706ms step_avg:100.88ms
step:8/20000 train_loss:6.7472 train_time:802ms step_avg:100.27ms
step:9/20000 train_loss:6.4048 train_time:898ms step_avg:99.81ms
step:10/20000 train_loss:6.0804 train_time:995ms step_avg:99.46ms
step:500/20000 train_loss:2.3986 train_time:50116ms step_avg:100.23ms
step:1000/20000 train_loss:2.2760 train_time:100751ms step_avg:100.75ms
step:1500/20000 train_loss:2.2138 train_time:151518ms step_avg:101.01ms
step:2000/20000 train_loss:2.0554 train_time:202444ms step_avg:101.22ms
step:2500/20000 train_loss:2.1594 train_time:253399ms step_avg:101.36ms
step:3000/20000 train_loss:2.1543 train_time:304518ms step_avg:101.51ms
step:3500/20000 train_loss:2.1575 train_time:355644ms step_avg:101.61ms
step:4000/20000 train_loss:1.9469 train_time:406845ms step_avg:101.71ms
step:4000/20000 val_loss:2.0369 val_bpb:1.2064 train_time:406864ms step_avg:101.72ms
step:4500/20000 train_loss:2.0917 train_time:458046ms step_avg:101.79ms
step:5000/20000 train_loss:2.0632 train_time:509680ms step_avg:101.94ms
swa:start step:5300
step:5500/20000 train_loss:1.9698 train_time:560996ms step_avg:102.00ms
late_qat:enabled step:5583 scale:0.1000
step:5880/20000 val_loss:1.9352 val_bpb:1.1461 train_time:600055ms step_avg:102.05ms
stopping_early: wallclock_cap train_time:600055ms step:5880/20000
peak memory allocated: 20570 MiB reserved: 20682 MiB
swa:applying averaged 12 checkpoints
DIAGNOSTIC post_swa val_loss:1.9353 val_bpb:1.1462 eval_time:2311ms
Serialized model: 106178569 bytes
Code size: 69485 bytes
Serialized model int6+zstd: 15718125 bytes
Total submission size int6+zstd: 15787610 bytes
ttt_epoch:1/3 loss:1.9487 time:17.2s
ttt_epoch:2/3 loss:1.9483 time:34.2s
ttt_epoch:3/3 loss:1.9480 time:51.3s
ttt:done elapsed=51.3s
final_int6_roundtrip val_loss:1.9466 val_bpb:1.1529 eval_time:2273ms
final_int6_roundtrip_exact val_loss:1.94664617 val_bpb:1.15291351
final_int6_sliding_window val_loss:1.9065 val_bpb:1.1291 stride:64 eval_time:100076ms
final_int6_sliding_window_exact val_loss:1.90649894 val_bpb:1.12913905

=== Seed 7 ===
root@6da42ce110bf:/workspace/parameter-golf# RUN_ID=pathb_s7 SEED=7 \
> DATA_PATH=./data/datasets/fineweb10B_sp1024 \
> TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
> torchrun --standalone --nproc_per_node=8 \
> records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/train_gpt.py
W0322 12:26:52.565000 63459 torch/distributed/run.py:803]
W0322 12:26:52.565000 63459 torch/distributed/run.py:803] *****************************************
W0322 12:26:52.565000 63459 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0322 12:26:52.565000 63459 torch/distributed/run.py:803] *****************************************
logs/pathb_s7.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_4 active_layers:[7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:7
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9297 val_bpb:4.1041 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9305 train_time:141ms step_avg:140.91ms
step:2/20000 train_loss:8.8335 train_time:225ms step_avg:112.62ms
step:3/20000 train_loss:7.9593 train_time:321ms step_avg:107.03ms
step:4/20000 train_loss:7.1907 train_time:417ms step_avg:104.30ms
step:5/20000 train_loss:6.9718 train_time:514ms step_avg:102.72ms
step:6/20000 train_loss:6.9133 train_time:610ms step_avg:101.68ms
step:7/20000 train_loss:6.7581 train_time:707ms step_avg:100.94ms
step:8/20000 train_loss:6.6566 train_time:803ms step_avg:100.38ms
step:9/20000 train_loss:6.3992 train_time:899ms step_avg:99.92ms
step:10/20000 train_loss:6.0589 train_time:996ms step_avg:99.59ms
step:500/20000 train_loss:2.4041 train_time:50355ms step_avg:100.71ms
step:1000/20000 train_loss:2.2760 train_time:101473ms step_avg:101.47ms
step:1500/20000 train_loss:2.2188 train_time:152764ms step_avg:101.84ms
step:2000/20000 train_loss:2.0601 train_time:204054ms step_avg:102.03ms
step:2500/20000 train_loss:2.1623 train_time:255455ms step_avg:102.18ms
step:3000/20000 train_loss:2.1536 train_time:306846ms step_avg:102.28ms
step:3500/20000 train_loss:2.1592 train_time:358221ms step_avg:102.35ms
step:4000/20000 train_loss:1.9476 train_time:409656ms step_avg:102.41ms
step:4000/20000 val_loss:2.0373 val_bpb:1.2066 train_time:409676ms step_avg:102.42ms
step:4500/20000 train_loss:2.0920 train_time:461048ms step_avg:102.46ms
step:5000/20000 train_loss:2.0645 train_time:512503ms step_avg:102.50ms
swa:start step:5300
step:5500/20000 train_loss:1.9693 train_time:564070ms step_avg:102.56ms
late_qat:enabled step:5551 scale:0.0998
step:5850/20000 val_loss:1.9381 val_bpb:1.1479 train_time:600339ms step_avg:102.62ms
stopping_early: wallclock_cap train_time:600339ms step:5850/20000
peak memory allocated: 20566 MiB reserved: 20688 MiB
swa:applying averaged 12 checkpoints
DIAGNOSTIC post_swa val_loss:1.9380 val_bpb:1.1478 eval_time:2322ms
Serialized model: 106178569 bytes
Code size: 69485 bytes
Serialized model int6+zstd: 15589941 bytes
Total submission size int6+zstd: 15659426 bytes
ttt_epoch:1/3 loss:1.9516 time:17.2s
ttt_epoch:2/3 loss:1.9513 time:34.3s
ttt_epoch:3/3 loss:1.9510 time:51.3s
ttt:done elapsed=51.3s
final_int6_roundtrip val_loss:1.9493 val_bpb:1.1545 eval_time:2279ms
final_int6_roundtrip_exact val_loss:1.94929521 val_bpb:1.15448242
final_int6_sliding_window val_loss:1.9095 val_bpb:1.1309 stride:64 eval_time:85248ms
final_int6_sliding_window_exact val_loss:1.90949180 val_bpb:1.13091159
root@6da42ce110bf:/workspace/parameter-golf#

=== Seed 99 ===
root@151da2602b13:/workspace/parameter-golf# RUN_ID=pathb_s99b SEED=99 \
> DATA_PATH=./data/datasets/fineweb10B_sp1024 \
> TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
> torchrun --standalone --nproc_per_node=8 \
> records/track_10min_16mb/2026-03-22_TightSWA_VE128_TTT/train_gpt.py
W0322 18:54:15.252000 54971 torch/distributed/run.py:803]
W0322 18:54:15.252000 54971 torch/distributed/run.py:803] *****************************************
W0322 18:54:15.252000 54971 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0322 18:54:15.252000 54971 torch/distributed/run.py:803] *****************************************
logs/pathb_s99b.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_4 active_layers:[7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:99
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9280 val_bpb:4.1031 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9297 train_time:143ms step_avg:142.82ms
step:2/20000 train_loss:8.7051 train_time:223ms step_avg:111.74ms
step:3/20000 train_loss:7.9051 train_time:320ms step_avg:106.73ms
step:4/20000 train_loss:7.2488 train_time:417ms step_avg:104.19ms
step:5/20000 train_loss:7.0147 train_time:514ms step_avg:102.72ms
step:6/20000 train_loss:6.8572 train_time:610ms step_avg:101.72ms
step:7/20000 train_loss:6.7803 train_time:707ms step_avg:100.95ms
step:8/20000 train_loss:6.6588 train_time:803ms step_avg:100.38ms
step:9/20000 train_loss:6.3764 train_time:899ms step_avg:99.94ms
step:10/20000 train_loss:6.0845 train_time:996ms step_avg:99.57ms
step:500/20000 train_loss:2.4048 train_time:49533ms step_avg:99.07ms
step:1000/20000 train_loss:2.2688 train_time:99320ms step_avg:99.32ms
step:1500/20000 train_loss:2.2183 train_time:149107ms step_avg:99.40ms
step:2000/20000 train_loss:2.0586 train_time:198914ms step_avg:99.46ms
step:2500/20000 train_loss:2.1669 train_time:248721ms step_avg:99.49ms
step:3000/20000 train_loss:2.1537 train_time:298514ms step_avg:99.50ms
step:3500/20000 train_loss:2.1618 train_time:348368ms step_avg:99.53ms
step:4000/20000 train_loss:1.9507 train_time:398163ms step_avg:99.54ms
step:4000/20000 val_loss:2.0427 val_bpb:1.2098 train_time:398181ms step_avg:99.55ms
step:4500/20000 train_loss:2.0969 train_time:447954ms step_avg:99.55ms
step:5000/20000 train_loss:2.0717 train_time:497754ms step_avg:99.55ms
swa:start step:5450
step:5500/20000 train_loss:1.9791 train_time:547600ms step_avg:99.56ms
late_qat:enabled step:5725 scale:0.0998
step:6000/20000 train_loss:1.8989 train_time:597643ms step_avg:99.61ms
step:6024/20000 val_loss:1.9359 val_bpb:1.1466 train_time:600075ms step_avg:99.61ms
stopping_early: wallclock_cap train_time:600075ms step:6024/20000
peak memory allocated: 20566 MiB reserved: 20688 MiB
swa:applying averaged 12 checkpoints
DIAGNOSTIC post_swa val_loss:1.9358 val_bpb:1.1465 eval_time:2255ms
Serialized model: 106178569 bytes
Code size: 69485 bytes
Serialized model int6+zstd: 15619172 bytes
Total submission size int6+zstd: 15688657 bytes
ttt_epoch:1/3 loss:1.9493 time:17.2s
ttt_epoch:2/3 loss:1.9490 time:34.3s
ttt_epoch:3/3 loss:1.9487 time:51.3s
ttt:done elapsed=51.3s
final_int6_roundtrip val_loss:1.9473 val_bpb:1.1533 eval_time:2227ms
final_int6_roundtrip_exact val_loss:1.94729244 val_bpb:1.15329627
final_int6_sliding_window val_loss:1.9073 val_bpb:1.1296 stride:64 eval_time:98259ms
final_int6_sliding_window_exact val_loss:1.90729652 val_bpb:1.12961142
root@151da2602b13:/workspace/parameter-golf#