# LeakyReLU² + XSA-all + Full GPTQ + SLOT

**val_bpb: 0.9354** (mean over 3 seeds: 1337→0.9349, 42→0.9325, 7→0.9388)

## Architecture

- 11 transformer layers, dim=512, 8 heads, 4 KV heads (GQA)
- LeakyReLU(0.5)² MLP with 3x expansion (sketched after this list)
- RoPE, RMSNorm, tied embeddings (vocab=1024), logit softcapping (30.0)
- U-Net skip connections with learned skip weights
- SmearGate + BigramHash embedding augmentation
- XSA (cross-sequence attention) on all 11 layers
- QK-Gain init = 4.0
- ~27M parameters
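
As a concrete reading of the MLP bullet above, here is a minimal sketch of the squared-LeakyReLU block. Module and dimension names are illustrative assumptions; only the 0.5 negative slope, the squaring, and the 3x expansion come from the list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SquaredLeakyReLUMLP(nn.Module):
    """MLP with LeakyReLU(0.5)**2 activation and 3x expansion (512 -> 1536 -> 512)."""

    def __init__(self, dim: int = 512, expansion: int = 3):
        super().__init__()
        self.up = nn.Linear(dim, expansion * dim, bias=False)
        self.down = nn.Linear(expansion * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LeakyReLU with negative slope 0.5, then squared; squaring makes all
        # activations non-negative and pushes small values toward zero (sparsity).
        h = F.leaky_relu(self.up(x), negative_slope=0.5)
        return self.down(h * h)
```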

## Training

- Muon optimizer for matrix params, Adam for scalars/embeddings
- EMA (decay=0.997) + Tight SWA (every 50 steps from step 4600)
- Late QAT (int6 quantization-aware training, threshold 0.15; see the sketch after this list)
- Full GPTQ: Hessian-based int6 quantization with Cholesky error compensation (32 calibration batches on EMA model)
- Compression: zstd-22
- Training time: ~600s on 8xH100, 5257 steps at ~114ms/step (wallclock-capped)
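
The late-QAT bullet presumably fake-quantizes weights to the int6 grid during the final training phase. A minimal straight-through-estimator sketch under that assumption; the symmetric grid, per-tensor scale, and function name are assumptions, and the exact role of the 0.15 threshold is not shown here.

```python
import torch

def fake_quant_int6(w: torch.Tensor) -> torch.Tensor:
    """Straight-through int6 fake quantization for QAT (a sketch).

    Forward pass sees the int6-rounded weights; backward pass lets
    gradients flow to the full-precision weights unchanged (STE).
    """
    qmax = 31  # symmetric int6 grid in [-31, 31] (assumption)
    scale = w.abs().max().clamp(min=1e-8) / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale
    return w + (w_q - w).detach()  # forward = w_q, grad w.r.t. w is identity
```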

## Evaluation — SLOT (Softmax Logit Optimization at Test-time)

Based on [arXiv:2505.12392v2](https://arxiv.org/abs/2505.12392v2).

- **Sliding window eval** at stride=64, seq_len=2048 (baseline)
- **SLOT optimization** per batch (see the sketch after this list):
1. Extract frozen hidden states from last layer (`forward_hidden`) under `torch.no_grad()`
2. Detach projection weights (tied embedding)
3. Optimize per-sample additive delta `[bsz, 1, 512]` + per-sample logit bias `[bsz, 1, 1024]`
4. **16 AdamW steps** with cosine LR schedule (0.008 → 0.0008)
5. **Scored-position mask**: only positions contributing to final BPB (last `stride` tokens per non-first window) are included in the SLOT optimization loss
  6. Logits computed as: `softcap * tanh(((H + delta) @ W_out^T + logit_bias) / softcap)` with `softcap = 30.0`
7. Final scoring with optimized delta + logit bias under `torch.no_grad()`
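
A sketch of this loop under stated assumptions: `slot_optimize` and `scored_mask` are hypothetical names, `H` comes from the frozen `forward_hidden` pass, and the shapes, step count, LR schedule, and softcap value are taken from the list above.

```python
import math
import torch
import torch.nn.functional as F

def slot_optimize(H, W_out, targets, scored_mask,
                  steps=16, lr=0.008, lr_min=0.0008, softcap=30.0):
    """Optimize per-sample delta and logit bias on frozen hidden states.

    H:           [bsz, T, 512] frozen last-layer hidden states (from no_grad pass)
    W_out:       [1024, 512] detached tied-embedding projection
    targets:     [bsz, T] next-token ids
    scored_mask: [bsz, T] True only at positions that count toward BPB
    """
    bsz, _, dim = H.shape
    delta = torch.zeros(bsz, 1, dim, device=H.device, requires_grad=True)
    logit_bias = torch.zeros(bsz, 1, W_out.shape[0], device=H.device, requires_grad=True)
    opt = torch.optim.AdamW([delta, logit_bias], lr=lr)

    for step in range(steps):
        # cosine schedule from lr down to lr_min
        t = step / max(steps - 1, 1)
        for g in opt.param_groups:
            g["lr"] = lr_min + 0.5 * (lr - lr_min) * (1 + math.cos(math.pi * t))
        # softcapped logits with per-sample delta and bias (model weights frozen)
        logits = softcap * torch.tanh(((H + delta) @ W_out.T + logit_bias) / softcap)
        # loss only over scored positions (last `stride` tokens per non-first window)
        loss = F.cross_entropy(logits[scored_mask], targets[scored_mask])
        opt.zero_grad()
        loss.backward()
        opt.step()

    return delta.detach(), logit_bias.detach()
```

Final scoring (step 7) would then recompute the same softcapped logits with the returned `delta` and `logit_bias` under `torch.no_grad()`.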

### Legality

- Model weights are **completely frozen** during SLOT — only delta and logit_bias are optimized
- Hidden states extracted under `torch.no_grad()` — no gradient flows through the model
- Standard autoregressive cross-entropy loss preserves causality
- Optimization uses only tokens within each sliding window (no future information; mask sketched after this list)
- `torch.compile` on `forward_hidden` for throughput
- SLOT eval time: ~311s per run (within 10-min eval budget)
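
To make the window-local scoring concrete, one way the scored-position mask could be built (the helper is hypothetical; stride=64 and seq_len=2048 are from the eval description):

```python
import torch

def scored_positions(window_index: int, seq_len: int = 2048, stride: int = 64) -> torch.Tensor:
    """Boolean mask of positions contributing to final BPB for one window.

    The first window scores every position; each subsequent window scores
    only its last `stride` tokens (earlier positions were already scored
    by previous windows).
    """
    mask = torch.zeros(seq_len, dtype=torch.bool)
    if window_index == 0:
        mask[:] = True
    else:
        mask[-stride:] = True
    return mask
```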

### No illegal techniques
- ❌ No n-gram cache
- ❌ No two-pass rescoring
- ❌ No eval-time access to training data
- ❌ No oracle/hindsight selection

## Results

| Seed | Sliding BPB | SLOT BPB | Artifact Size |
|------|-------------|----------|---------------|
| 1337 | 1.1264 | 0.9349 | 15,890,549 |
| 42 | 1.1264 | 0.9325 | 15,830,408 |
| 7 | 1.1261 | 0.9388 | 15,810,068 |
| **Mean** | **1.1263** | **0.9354** | |

Beats the merged SOTA (1.1147) by 0.179 BPB. All artifacts are under 16,000,000 bytes.
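
For reference, BPB follows from token-level cross-entropy as `bpb = loss * (n_tokens / n_bytes) / ln(2)`; plugging in the sliding-window numbers (loss 1.9019 → 1.1264 BPB) implies roughly 0.41 tokens per byte for this 1024-entry BPE vocabulary.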

## Reproduction

```bash
SEED=1337 GPTQ_CALIB_BATCHES=32 SLOT_ENABLED=1 SLOT_STEPS=16 \
SLOT_LR=0.008 SLOT_LR_MIN=0.0008 \
torchrun --nproc_per_node=8 train_gpt.py
```

## Key Techniques

1. **LeakyReLU(0.5)²**: Leaky variant (negative slope 0.5) with squaring for sparsity
2. **XSA-all**: Cross-sequence attention on all 11 layers
3. **QK-Gain 4.0**: Sharpened attention maps via learned per-head gain initialized at 4.0
4. **Full GPTQ**: Hessian-based int6 quantization with actorder and Cholesky error compensation (sketched below)
5. **SLOT**: Per-sample delta + logit bias optimization at eval time with scored-position masking
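
A compact sketch of the GPTQ column loop under stated simplifications: per-row symmetric int6 scales, no act-order reordering, no group-wise scales; the function name and `percdamp` value are illustrative.

```python
import torch

@torch.no_grad()
def gptq_int6(W, H, percdamp=0.01, bits=6):
    """Quantize one weight matrix column-by-column with Hessian-based
    error compensation (GPTQ sketch). W: [rows, cols]; H: [cols, cols],
    accumulated from calibration activations (~ 2 * X @ X.T).
    """
    rows, cols = W.shape
    W = W.clone().float()
    qmax = 2 ** (bits - 1) - 1                                # int6 -> [-31, 31]
    scale = (W.abs().amax(dim=1) / qmax).clamp(min=1e-12)     # per-row scale (assumption)
    # Dampen the Hessian, then take the upper Cholesky factor of its inverse;
    # its entries give the error-propagation coefficients.
    H = H + percdamp * H.diagonal().mean() * torch.eye(cols, device=W.device)
    Hinv = torch.linalg.cholesky(
        torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)
    Q = torch.zeros_like(W)
    for i in range(cols):
        w = W[:, i]
        q = torch.clamp(torch.round(w / scale), -qmax, qmax)
        Q[:, i] = q
        # Spread this column's quantization error onto not-yet-quantized columns.
        err = (w - q * scale) / Hinv[i, i]
        W[:, i + 1:] -= err.unsqueeze(1) * Hinv[i, i + 1:].unsqueeze(0)
    return Q, scale  # dequantize as Q * scale.unsqueeze(1)
```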

---

Training log (seed 1337):

W0402 20:01:15.223000 138574105780864 torch/distributed/run.py:779]
W0402 20:01:15.223000 138574105780864 torch/distributed/run.py:779] *****************************************
W0402 20:01:15.223000 138574105780864 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0402 20:01:15.223000 138574105780864 torch/distributed/run.py:779] *****************************************
logs/da6528f0-05ef-4487-baf5-963e9c56d90b.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:47
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9271 val_bpb:4.1026 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9300 train_time:160ms step_avg:159.82ms
step:2/20000 train_loss:8.4173 train_time:268ms step_avg:134.16ms
step:3/20000 train_loss:7.5580 train_time:381ms step_avg:126.92ms
step:4/20000 train_loss:8.2496 train_time:493ms step_avg:123.37ms
step:5/20000 train_loss:8.4692 train_time:606ms step_avg:121.16ms
step:6/20000 train_loss:8.2788 train_time:718ms step_avg:119.74ms
step:7/20000 train_loss:7.7299 train_time:831ms step_avg:118.70ms
step:8/20000 train_loss:7.0643 train_time:943ms step_avg:117.91ms
step:9/20000 train_loss:6.5543 train_time:1056ms step_avg:117.29ms
step:10/20000 train_loss:6.1896 train_time:1168ms step_avg:116.80ms
step:500/20000 train_loss:2.3865 train_time:56914ms step_avg:113.83ms
step:1000/20000 train_loss:2.2564 train_time:114106ms step_avg:114.11ms
step:1500/20000 train_loss:2.2062 train_time:171222ms step_avg:114.15ms
step:2000/20000 train_loss:2.0430 train_time:228281ms step_avg:114.14ms
step:2500/20000 train_loss:2.1364 train_time:285313ms step_avg:114.13ms
step:3000/20000 train_loss:2.1209 train_time:342400ms step_avg:114.13ms
step:3500/20000 train_loss:2.1220 train_time:399364ms step_avg:114.10ms
step:4000/20000 train_loss:1.9112 train_time:456336ms step_avg:114.08ms
step:4000/20000 val_loss:2.0008 val_bpb:1.1850 train_time:456341ms step_avg:114.09ms
step:4500/20000 train_loss:2.0546 train_time:513303ms step_avg:114.07ms
swa:start step:4600
late_qat:enabled step:4734 scale:0.1500
step:5000/20000 train_loss:2.0264 train_time:570567ms step_avg:114.11ms
step:5257/20000 val_loss:1.9369 val_bpb:1.1471 train_time:600045ms step_avg:114.14ms
stopping_early: wallclock_cap train_time:600045ms step:5257/20000
peak memory allocated: 27940 MiB reserved: 29072 MiB
ema:applying EMA weights
gptq:calibrating batches=32
gptq:done layers=68 time=6822ms
DIAGNOSTIC post_ema val_loss:1.9357 val_bpb:1.1464 eval_time:2559ms
Serialized model: 106178100 bytes
Code size: 73198 bytes
Serialized model int6+zstd: 15817351 bytes
Total submission size int6+zstd: 15890549 bytes
Total submission size int8+zlib: 15890549 bytes
/workspace/repo/train_gpt.py:1425: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
quant_state = torch.load(
final_int6_roundtrip val_loss:1.9417 val_bpb:1.1500 eval_time:52915ms
final_int6_roundtrip_exact val_loss:1.94174466 val_bpb:1.15001056
final_int6_sliding_window val_loss:1.9019 val_bpb:1.1264 stride:64 eval_time:119359ms
final_int6_sliding_window_exact val_loss:1.90194141 val_bpb:1.12643982
final_slot val_loss:1.5786 val_bpb:0.9349 steps:16 lr:0.008 time:311364ms
final_slot_exact val_loss:1.57861646 val_bpb:0.93494806