# Record: 11L XSA-all + LeakyReLU(0.5)^2 + VR + GA + 7-gram cache (val_bpb=1.0337)

**3-seed mean val_bpb = 1.0337** (std=0.0010) | **~15.99 MB** | No TTT

## Summary

A non-TTT submission combining XSA on all 11 layers with a LeakyReLU(0.5)^2 activation, Value Residual, Gated Attention, and a 7-gram backward-looking eval cache (fixed mixing weight alpha=0.40). It reaches 1.0337 mean BPB across 3 seeds on 8xH100 SXM within the 600s wallclock budget.

## 3-Seed Results (8xH100 SXM, 600s wallclock)

| Seed     | Steps | Quant        | Size (bytes) | Sliding BPB (stride=64) |
|----------|-------|--------------|--------------|-------------------------|
| 1337     | 5581  | int6 zstd-16 | 15,990,221   | 1.0329                  |
| 42       | 5591  | int6 zstd-17 | 15,982,903   | 1.0334                  |
| 7        | ~5589 | int6 zstd-16 | 15,992,378   | 1.0349                  |
| **Mean** |       |              |              | **1.0337**              |
| **Std**  |       |              |              | **0.0010**              |

## Architecture

- 11 transformer layers, 512-dim, 8 heads / 4 KV heads (GQA), 3x MLP expansion
- **LeakyReLU(0.5)^2**: `leaky_relu(x, 0.5).square()` replaces ReLU^2, preserving gradient flow for negative inputs at zero extra cost (see the sketch after this list).
- **XSA on all 11 layers**: Exclusive Self-Attention removes the self-position bias in every layer.
- **Value Residual (VR)**: layer 0's V output is mixed into subsequent layers via learned sigmoid gates.
- **Gated Attention (GA)**: per-head sigmoid gates on the attention output.
- SmearGate + OrthoInit, BigramHash(4096), U-Net skip connections
- Partial RoPE (16/64 dims), LN Scale, EMA(0.997)
- Int6 per-row quantization + zstd compression
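
To make the activation and gating ideas concrete, here is a minimal PyTorch sketch. The names (`leaky_relu_squared`, `ValueResidualMix`, `GatedAttentionOutput`), gate placements, and tensor shapes are illustrative assumptions, not the submission's actual code:

```python
import torch
import torch.nn.functional as F

def leaky_relu_squared(x: torch.Tensor, slope: float = 0.5) -> torch.Tensor:
    """LeakyReLU(0.5)^2: matches ReLU^2 for x >= 0, but a negative input maps
    to (0.5*x)^2 = 0.25*x^2, so its gradient does not vanish below zero."""
    return F.leaky_relu(x, negative_slope=slope).square()

class ValueResidualMix(torch.nn.Module):
    """Value Residual: blend this layer's V with layer 0's V through a learned
    sigmoid gate (after arXiv:2410.17897; the scalar gate is an assumption)."""
    def __init__(self) -> None:
        super().__init__()
        self.gate = torch.nn.Parameter(torch.zeros(1))  # sigmoid(0) = 0.5 at init

    def forward(self, v: torch.Tensor, v0: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate)
        return g * v + (1.0 - g) * v0

class GatedAttentionOutput(torch.nn.Module):
    """Gated Attention: per-head sigmoid gates on the attention output, here
    computed from the layer input x (the gate's input is an assumption)."""
    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.to_gate = torch.nn.Linear(dim, num_heads)

    def forward(self, attn_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # attn_out: (B, T, H, Dh); x: (B, T, dim)
        g = torch.sigmoid(self.to_gate(x)).unsqueeze(-1)  # (B, T, H, 1)
        return attn_out * g
```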

## 7-gram Backward-Looking Eval Cache

During sliding-window evaluation, a token-level n-gram cache adjusts the model's next-token predictions using observed n-gram statistics from previously scored tokens.

### How it works

1. As evaluation proceeds left-to-right through the validation set, completed (already-scored) tokens are added to an n-gram frequency table.
2. For each new position, the cache looks up all n-gram contexts (orders 1 through 7) ending at the current position using only backward (already-scored) context.
3. The n-gram distribution is mixed with the model's softmax output: `p_final = (1 - alpha) * p_model + alpha * p_ngram`, with a fixed alpha=0.40.
4. The mixed distribution is used to compute the loss for that position (a minimal sketch follows below).
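
The following Python sketch illustrates steps 1-4 under stated assumptions: it stores exact context tuples where the logs indicate the real implementation hashes contexts into 4,194,304 buckets, and the names `NGramCache` and `mixed_nll` are hypothetical. Defaults mirror the logged settings (alpha=0.4, order=7, min_count=2):

```python
from collections import defaultdict

import torch

class NGramCache:
    """Backward-looking n-gram cache (illustrative sketch, not the submission code)."""

    def __init__(self, vocab_size, max_order=7, alpha=0.40, min_count=2):
        self.vocab_size = vocab_size
        self.max_order = max_order
        self.alpha = alpha
        self.min_count = min_count
        # counts[context_tuple][next_token] -> frequency over already-scored tokens
        self.counts = defaultdict(lambda: defaultdict(int))

    def predict(self, history):
        """Distribution over the next token from the longest matching context
        (orders max_order down to 1), or None if no context has enough counts."""
        for order in range(self.max_order, 0, -1):
            if len(history) < order:
                continue
            nxt = self.counts.get(tuple(history[-order:]))
            if nxt is None:
                continue
            total = sum(nxt.values())
            if total >= self.min_count:
                p = torch.zeros(self.vocab_size)
                for tok, c in nxt.items():
                    p[tok] = c / total
                return p
        return None

    def update(self, history, token):
        """Add a just-scored token under every context order. Called only AFTER
        the token has been scored, so the table never contains future tokens."""
        for order in range(1, self.max_order + 1):
            if len(history) >= order:
                self.counts[tuple(history[-order:])][token] += 1

def mixed_nll(p_model, cache, history, target):
    """Step 3: p_final = (1 - alpha) * p_model + alpha * p_ngram, fixed alpha."""
    p_ngram = cache.predict(history)
    if p_ngram is None:
        return -torch.log(p_model[target])
    p = (1.0 - cache.alpha) * p_model + cache.alpha * p_ngram
    return -torch.log(p[target])
```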

### Compliance notes

- **Score-first**: Each token is scored by the model *before* it enters the n-gram table. The cache only uses tokens that have already been scored and never looks ahead (made concrete in the loop after this list).
- **Fixed alpha**: The mixing weight alpha=0.40 is a fixed hyperparameter baked into the submission code, not tuned per-sample or per-position at eval time.
- **No oracle selection**: There is no selection among multiple cache configurations at eval time. The same alpha and order are used for every token.
- **Deterministic**: Given the same model weights and validation data, the cache produces identical results regardless of hardware or random seeds.
- **No additional parameters**: The n-gram cache adds zero learned parameters. It is a purely statistical post-processing step built from the evaluation data stream.
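
To make the score-first ordering concrete, here is a hypothetical evaluation loop built on the sketch above (`model_probs` is an assumed stand-in for the model's per-position softmax):

```python
def evaluate_stream(model_probs, tokens, cache):
    """Average NLL over a token stream. Each token is scored before it is
    inserted into the n-gram table, so the cache never sees future tokens."""
    history, total_nll = [], 0.0
    for i, target in enumerate(tokens):
        p_model = model_probs(i)                # model softmax at position i
        total_nll += mixed_nll(p_model, cache, history, target).item()  # score first...
        cache.update(history, target)           # ...then insert into the table
        history.append(target)
    return total_nll / len(tokens)
```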

## Training Config

```bash
ITERATIONS=20000  # wallclock-capped at ~5589 steps
WARMDOWN_ITERS=3000 MAX_WALLCLOCK_SECONDS=600
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500
XSA_LAST_N=11 LEAKY_RELU=1 TTT_ENABLED=0 CANON_LAST_N=0 SWA_ENABLED=0
# N-gram cache (eval-time only):
NGRAM_CACHE=1 NGRAM_ALPHA=0.40 NGRAM_ORDER=7
```

## Reproduction

```bash
# Download data
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80

# Train (seed 1337, 8xH100)
SEED=1337 XSA_LAST_N=11 LEAKY_RELU=1 WARMDOWN_ITERS=3000 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
TTT_ENABLED=0 CANON_LAST_N=0 SWA_ENABLED=0 MAX_WALLCLOCK_SECONDS=600 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Credits

- Base architecture: modded-nanogpt
- XSA-all: PR #609
- LeakyReLU^2: PRs #493 and #518
- Value Residual: PR #413 (arXiv:2410.17897)
- Gated Attention: arXiv:2505.06708 (NeurIPS 2025)
W0325 12:18:43.307000 123010 torch/distributed/run.py:803]
W0325 12:18:43.307000 123010 torch/distributed/run.py:803] *****************************************
W0325 12:18:43.307000 123010 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 12:18:43.307000 123010 torch/distributed/run.py:803] *****************************************
logs/3a9335ff-3719-4f94-9065-60e14e09cd93.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:27137223
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
world_size:8 grad_accum_steps:1
sdp_backends:fa3=True cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9286 val_bpb:4.1035 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9308 train_time:150ms step_avg:149.93ms
step:2/20000 train_loss:8.6115 train_time:249ms step_avg:124.30ms
step:3/20000 train_loss:7.7112 train_time:347ms step_avg:115.73ms
step:4/20000 train_loss:7.2918 train_time:445ms step_avg:111.32ms
step:5/20000 train_loss:7.0145 train_time:543ms step_avg:108.67ms
step:6/20000 train_loss:6.8948 train_time:642ms step_avg:107.08ms
step:7/20000 train_loss:6.7785 train_time:743ms step_avg:106.13ms
step:8/20000 train_loss:6.6045 train_time:844ms step_avg:105.49ms
step:9/20000 train_loss:6.2824 train_time:946ms step_avg:105.12ms
step:10/20000 train_loss:5.9656 train_time:1049ms step_avg:104.88ms
step:200/20000 train_loss:2.3438 train_time:21425ms step_avg:107.12ms
step:400/20000 train_loss:2.4001 train_time:42952ms step_avg:107.38ms
step:600/20000 train_loss:2.3155 train_time:64425ms step_avg:107.38ms
step:800/20000 train_loss:2.2162 train_time:85980ms step_avg:107.48ms
step:1000/20000 train_loss:2.2571 train_time:107469ms step_avg:107.47ms
step:1000/20000 val_loss:2.2043 val_bpb:1.3055 train_time:107474ms step_avg:107.47ms
step:1200/20000 train_loss:2.3272 train_time:128927ms step_avg:107.44ms
step:1400/20000 train_loss:2.1643 train_time:150525ms step_avg:107.52ms
step:1600/20000 train_loss:2.0548 train_time:171979ms step_avg:107.49ms
step:1800/20000 train_loss:2.1294 train_time:193545ms step_avg:107.52ms
step:2000/20000 train_loss:2.0444 train_time:215054ms step_avg:107.53ms
step:2000/20000 val_loss:2.1097 val_bpb:1.2495 train_time:215059ms step_avg:107.53ms
step:2200/20000 train_loss:2.1160 train_time:236523ms step_avg:107.51ms
step:2400/20000 train_loss:2.0474 train_time:258010ms step_avg:107.50ms
step:2600/20000 train_loss:2.0909 train_time:279580ms step_avg:107.53ms
step:2800/20000 train_loss:2.1317 train_time:301196ms step_avg:107.57ms
step:3000/20000 train_loss:2.1334 train_time:322657ms step_avg:107.55ms
step:3000/20000 val_loss:2.0613 val_bpb:1.2208 train_time:322662ms step_avg:107.55ms
step:3200/20000 train_loss:2.1364 train_time:344170ms step_avg:107.55ms
step:3400/20000 train_loss:1.9821 train_time:365669ms step_avg:107.55ms
step:3600/20000 train_loss:2.0512 train_time:387227ms step_avg:107.56ms
step:3800/20000 train_loss:2.0236 train_time:408716ms step_avg:107.56ms
step:4000/20000 train_loss:1.9222 train_time:430243ms step_avg:107.56ms
step:4000/20000 val_loss:2.0147 val_bpb:1.1932 train_time:430248ms step_avg:107.56ms
step:4200/20000 train_loss:2.0924 train_time:451730ms step_avg:107.55ms
step:4400/20000 train_loss:1.9749 train_time:473151ms step_avg:107.53ms
step:4600/20000 train_loss:1.7853 train_time:494689ms step_avg:107.54ms
step:4800/20000 train_loss:2.3668 train_time:516188ms step_avg:107.54ms
step:5000/20000 train_loss:2.0407 train_time:537726ms step_avg:107.55ms
step:5000/20000 val_loss:1.9603 val_bpb:1.1610 train_time:537731ms step_avg:107.55ms
step:5200/20000 train_loss:1.9764 train_time:559127ms step_avg:107.52ms
step:5400/20000 train_loss:1.9832 train_time:580637ms step_avg:107.53ms
step:5581/20000 val_loss:1.9307 val_bpb:1.1435 train_time:600065ms step_avg:107.52ms
stopping_early: wallclock_cap train_time:600065ms step:5581/20000
peak memory allocated: 22472 MiB reserved: 22518 MiB
ema:applying EMA weights
Serialized model: 106498817 bytes
Code size: 85725 bytes
quant_try int6 zstd-16: 15904496 bytes (limit 15914275)
Serialized model quant+zstd-16: 15904496 bytes
Total submission size: 15990221 bytes
final_int6_roundtrip val_loss:1.9407 val_bpb:1.1494 eval_time:7599ms
final_int6_roundtrip_exact val_loss:1.94066127 val_bpb:1.14936892
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
final_int6_sliding_window val_loss:1.7439 val_bpb:1.0329 stride:64 eval_time:103164ms
final_int6_sliding_window_exact val_loss:1.74394155 val_bpb:1.03286315
W0325 12:32:14.396000 125116 torch/distributed/run.py:803]
W0325 12:32:14.396000 125116 torch/distributed/run.py:803] *****************************************
W0325 12:32:14.396000 125116 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 12:32:14.396000 125116 torch/distributed/run.py:803] *****************************************
logs/42973fb6-f9ef-468c-98ee-18c461676e70.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:27137223
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
world_size:8 grad_accum_steps:1
sdp_backends:fa3=True cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:42
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9310 val_bpb:4.1049 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9318 train_time:145ms step_avg:145.17ms
step:2/20000 train_loss:8.7601 train_time:245ms step_avg:122.63ms
step:3/20000 train_loss:7.7913 train_time:343ms step_avg:114.44ms
step:4/20000 train_loss:7.2740 train_time:441ms step_avg:110.35ms
step:5/20000 train_loss:7.0823 train_time:539ms step_avg:107.81ms
step:6/20000 train_loss:6.9496 train_time:638ms step_avg:106.30ms
step:7/20000 train_loss:6.8480 train_time:738ms step_avg:105.36ms
step:8/20000 train_loss:6.6873 train_time:838ms step_avg:104.77ms
step:9/20000 train_loss:6.3447 train_time:939ms step_avg:104.38ms
step:10/20000 train_loss:6.0018 train_time:1042ms step_avg:104.18ms
step:200/20000 train_loss:2.3493 train_time:21326ms step_avg:106.63ms
step:400/20000 train_loss:2.3822 train_time:42800ms step_avg:107.00ms
step:600/20000 train_loss:2.3075 train_time:64225ms step_avg:107.04ms
step:800/20000 train_loss:2.2140 train_time:85701ms step_avg:107.13ms
step:1000/20000 train_loss:2.2526 train_time:107132ms step_avg:107.13ms
step:1000/20000 val_loss:2.2004 val_bpb:1.3032 train_time:107137ms step_avg:107.14ms
step:1200/20000 train_loss:2.3254 train_time:128542ms step_avg:107.12ms
step:1400/20000 train_loss:2.1636 train_time:150093ms step_avg:107.21ms
step:1600/20000 train_loss:2.0535 train_time:171500ms step_avg:107.19ms
step:1800/20000 train_loss:2.1268 train_time:193015ms step_avg:107.23ms
step:2000/20000 train_loss:2.0457 train_time:214446ms step_avg:107.22ms
step:2000/20000 val_loss:2.1088 val_bpb:1.2490 train_time:214451ms step_avg:107.23ms
step:2200/20000 train_loss:2.1183 train_time:235901ms step_avg:107.23ms
step:2400/20000 train_loss:2.0457 train_time:257374ms step_avg:107.24ms
step:2600/20000 train_loss:2.0899 train_time:278896ms step_avg:107.27ms
step:2800/20000 train_loss:2.1323 train_time:300470ms step_avg:107.31ms
step:3000/20000 train_loss:2.1337 train_time:321874ms step_avg:107.29ms
step:3000/20000 val_loss:2.0626 val_bpb:1.2216 train_time:321879ms step_avg:107.29ms
step:3200/20000 train_loss:2.1432 train_time:343310ms step_avg:107.28ms
step:3400/20000 train_loss:1.9860 train_time:364794ms step_avg:107.29ms
step:3600/20000 train_loss:2.0540 train_time:386315ms step_avg:107.31ms
step:3800/20000 train_loss:2.0255 train_time:407826ms step_avg:107.32ms
step:4000/20000 train_loss:1.9260 train_time:429334ms step_avg:107.33ms
step:4000/20000 val_loss:2.0163 val_bpb:1.1941 train_time:429338ms step_avg:107.33ms
step:4200/20000 train_loss:2.0982 train_time:450837ms step_avg:107.34ms
step:4400/20000 train_loss:1.9800 train_time:472247ms step_avg:107.33ms
step:4600/20000 train_loss:1.7885 train_time:493748ms step_avg:107.34ms
step:4800/20000 train_loss:2.3701 train_time:515194ms step_avg:107.33ms
step:5000/20000 train_loss:2.0430 train_time:536685ms step_avg:107.34ms
step:5000/20000 val_loss:1.9619 val_bpb:1.1619 train_time:536689ms step_avg:107.34ms
step:5200/20000 train_loss:1.9777 train_time:558050ms step_avg:107.32ms
step:5400/20000 train_loss:1.9827 train_time:579542ms step_avg:107.32ms
step:5591/20000 val_loss:1.9316 val_bpb:1.1440 train_time:600036ms step_avg:107.32ms
stopping_early: wallclock_cap train_time:600036ms step:5591/20000
peak memory allocated: 22472 MiB reserved: 22518 MiB
ema:applying EMA weights
Serialized model: 106498817 bytes
Code size: 85725 bytes
quant_try int6 zstd-16: 15917753 bytes (limit 15914275)
quant_try int6 zstd-1: 15956755 bytes (limit 15914275)
quant_try int6 zstd-17: 15897178 bytes (limit 15914275)
Serialized model quant+zstd-17: 15897178 bytes
Total submission size: 15982903 bytes
final_int6_roundtrip val_loss:1.9420 val_bpb:1.1502 eval_time:7656ms
final_int6_roundtrip_exact val_loss:1.94200534 val_bpb:1.15016495
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
ngram_cache:enabled alpha=0.4 min_count=2 order=7 buckets=4194304
final_int6_sliding_window val_loss:1.7449 val_bpb:1.0334 stride:64 eval_time:102235ms
final_int6_sliding_window_exact val_loss:1.74491299 val_bpb:1.03343849