# Record: Split-LR + N-gram Agreement + Full Hessian GPTQ + Brotli

**val_bpb: 1.1078** (3-seed mean, std 0.0009) | **1.8752 nats** | **~15.86 MB** | 8xH100 SXM, 600s train + 449s eval

Built on [PR #1179](https://github.com/openai/parameter-golf/pull/1179) by @dexhunter (training) and [PR #1145](https://github.com/openai/parameter-golf/pull/1145) by @AnirudhRahul (n-gram agreement evaluation).

## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128)

| Seed | Steps | ms/step | Sliding BPB | **N-gram BPB** | Artifact (bytes) |
|------|-------|---------|-------------|----------------|----------|
| 1337 | ~6780 | 88.0 | 1.1110 | **1.1083** | 15,853,466 |
| 42 | ~6780 | 88.0 | 1.1095 | **1.1068** | 15,857,705 |
| 2025 | ~6780 | 88.0 | 1.1112 | **1.1085** | 15,846,914 |
| **Mean** | | | **1.1106** | **1.1078** | |

SOTA (PR #1019, 3-seed mean): **1.8822 nats**. This run: **1.8752 nats**. Delta: **-0.00697 nats**. Clears the 0.005-nat threshold.

### Timing Budget

| Phase | Time |
|-------|------|
| Training (wallclock cap) | ~591s |
| GPTQ calibration (reserved) | ~7s |
| Post-EMA eval | ~2s |
| Int6 roundtrip eval | ~7s |
| Sliding window eval (stride=64) | ~78s |
| **N-gram agreement eval** | **~449s** |
| **Total eval** | **~536s** |

## What's New vs PR #1019

### Training improvements (from PR #1179)
1. **Split-LR** — different learning rates for early (0.025) vs late (0.030) layers
2. **BigramHash(2816x160)** — wider projection (160 vs 112), fewer buckets
3. **Sigmoid-gated U-Net** — learnable gates on encoder-decoder skip connections
4. **Soft-round QAT** — temperature-controlled rounding (alpha 1->16) replacing STE
5. **Brotli-11 + byte-shuffle** — saves ~400KB vs LZMA (a compression sketch follows this list)
6. **Coprime-stride data loader** — better data shuffling and coverage
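
As a concrete reference for item 5, here is a minimal sketch of the compress side, assuming a plain byte-plane shuffle over the serialized quantized weights followed by Brotli at quality 11. The helper name `byte_shuffle` and the 2-byte element width are illustrative, not the exact `train_gpt.py` implementation (whose inverse, `_byte_unshuffle`, appears in the included eval script).

```python
import io

import brotli
import numpy as np
import torch


def byte_shuffle(blob: bytes, itemsize: int = 2) -> bytes:
    """Group same-significance bytes together so the compressor sees longer runs.
    Hypothetical stand-in for the repo's shuffle helper; element width is assumed."""
    n = len(blob) // itemsize * itemsize
    planes = np.frombuffer(blob[:n], dtype=np.uint8).reshape(-1, itemsize).T
    return planes.tobytes() + blob[n:]          # leftover tail bytes stay as-is


# Serialize an (already quantized) state dict, shuffle its bytes, then Brotli-11.
state = {"w": (torch.randn(4096) * 16).to(torch.int16)}   # toy payload
buf = io.BytesIO()
torch.save(state, buf)
raw = buf.getvalue()
compressed = brotli.compress(byte_shuffle(raw), quality=11)
print(f"{len(raw)} -> {len(compressed)} bytes")
```

Decompression is the mirror image used when loading the checkpoint: `brotli.decompress` followed by the unshuffle.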

### Evaluation improvement (from PR #1145)
7. **Online n-gram agreement** — 3 causal experts (token n-gram, within-word, word-start) with agreement boosting. Adjusts LLM probabilities via properly normalized exponential tilting. Contributes **-0.0028 BPB**.

## N-gram Agreement: How It Works

Three online n-gram experts predict the next token using only already-scored (past) tokens:
- **Token n-gram** (16-gram context, hash table): predicts based on raw token patterns (a toy sketch follows this list)
- **Within-word continuation**: predicts next subword within the current word
- **Word-start hints**: predicts first token of next word based on previous word context
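
A toy version of the first expert is sketched below to make the predict-then-update (causal) loop concrete; the real expert lives in `online_ngram_state.c` and uses a 16-gram context with a native hash table, so the order, hashing, and confidence score here are simplified assumptions.

```python
from collections import Counter, defaultdict


class ToyTokenNgramExpert:
    """Hashed n-gram counts over already-scored tokens (simplified sketch)."""

    def __init__(self, order: int = 4, buckets: int = 1 << 20):
        self.order = order
        self.buckets = buckets
        self.table: dict[int, Counter] = defaultdict(Counter)

    def _key(self, context: list[int]) -> int:
        return hash(tuple(context[-self.order:])) % self.buckets

    def predict(self, context: list[int]):
        counts = self.table.get(self._key(context))
        if not counts:
            return None
        token, n = counts.most_common(1)[0]
        return token, n / sum(counts.values())   # hinted token + empirical confidence

    def update(self, context: list[int], target: int) -> None:
        self.table[self._key(context)][target] += 1


expert = ToyTokenNgramExpert(order=2)
context: list[int] = []
for target in [5, 7, 5, 7, 5, 7]:
    hint = expert.predict(context)   # causal: predict BEFORE seeing the target token
    expert.update(context, target)   # state only ever contains already-scored tokens
    context.append(target)
    print(hint)
```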

For each position, the expert with the highest expected gain is selected. When two or more experts agree on the same token, the boost is increased. The LLM's predicted probability is adjusted via exponential tilting:

```
p_adjusted = (scale * p_true) / (1 - p_hint + scale * p_hint)
```

This produces a properly normalized distribution (sums to 1.0). The adjustment is:
- **Causal**: each expert predicts BEFORE updating its state with the target token
- **Score-first**: runs under `torch.inference_mode()`, no model parameters modified
- **Properly normalized**: exponential tilting with correct partition function
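
A minimal sketch of that adjustment, under the reading that the hinted token's probability is scaled by `scale` and the whole distribution is divided by the partition function `1 - p_hint + scale * p_hint` (function and variable names are illustrative, not the PR's code):

```python
import torch


def tilt_toward_hint(probs: torch.Tensor, hint_id: int, scale: float) -> torch.Tensor:
    """Exponentially tilt a categorical distribution toward one hinted token.

    probs is a (vocab,) tensor summing to 1; scale > 1 boosts the hint (larger
    when several experts agree on it), scale == 1 leaves probs unchanged.
    """
    p_hint = probs[hint_id]
    partition = 1.0 - p_hint + scale * p_hint      # normalizer from the formula above
    tilted = probs / partition                     # every other token shrinks slightly
    tilted[hint_id] = scale * p_hint / partition   # the hinted token is boosted
    return tilted


probs = torch.tensor([0.6, 0.3, 0.1])
adj = tilt_toward_hint(probs, hint_id=1, scale=3.0)
print(adj, adj.sum())   # tensor([0.3750, 0.5625, 0.0625]) -- still sums to 1.0
```

When the scored target happens to be the hinted token, `adj[target]` reduces to the `scale * p_true / (1 - p_hint + scale * p_hint)` expression above; with `scale = 1` the distribution is untouched.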

## Legality

- Standard F.cross_entropy for training
- N-gram agreement: causal, score-first, properly normalized (exponential tilting)
- No training on validation data
- No SLOT, no multi-epoch TTT
- GPTQ calibration within training budget
- Artifact < 16,000,000 bytes (all seeds)
- Training <= 600s, eval <= 600s (all seeds)

## Architecture

| Component | Setting |
|-----------|---------|
| Layers | 11 (512d, 8 GQA heads, 4 KV heads) |
| MLP | 3x (1536) with LeakyReLU(0.5)^2 |
| Attention | XSA on all 11 layers |
| BigramHash | 2816 x dim=160 |
| Split-LR | early=0.025, late=0.030, bank_split=5 |
| Skip connections | Sigmoid-gated U-Net |
| QAT | Soft-round (alpha ramp 1->16; sketch after this table) |
| RoPE | Partial (16/64 dims) |
| LN Scale | 1/sqrt(layer+1) |
| VE128 | Layers 9-10 |
| SmearGate | Position-mixing gate |
| Weight avg | EMA(0.997) + SWA(every 50) |
| Quantization | Full Hessian GPTQ int6 |
| Compression | Brotli quality=11 + byte-shuffle |
| Optimizer | Parallel Muon + Parameter Banking |
| Eval | Online n-gram agreement (token 16-gram + within-word + word-start) |
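
For the QAT row, a common soft-rounding form (in the style of Agustsson & Theis) is sketched below: a differentiable surrogate for `round()` whose sharpness `alpha` is ramped from 1 to 16 over training, so gradients flow without a straight-through estimator. The exact parameterization and ramp schedule used in the PR are not shown here, so treat this as an assumption.

```python
import math

import torch


def soft_round(x: torch.Tensor, alpha: float) -> torch.Tensor:
    """Differentiable rounding surrogate (one standard form, assumed here):
    alpha -> 0 approaches the identity, alpha -> inf approaches hard round()."""
    floor = torch.floor(x)
    r = x - floor - 0.5
    return floor + 0.5 * torch.tanh(alpha * r) / math.tanh(alpha / 2.0) + 0.5


x = torch.linspace(-1.0, 1.0, 9, requires_grad=True)
for alpha in (1.0, 4.0, 16.0):            # alpha ramp during training: 1 -> 16
    y = soft_round(x, alpha)
    y.sum().backward()                    # gradients flow, unlike a hard round + STE
    print(alpha, y.detach(), x.grad.abs().mean().item())
    x.grad = None
```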

## Run Command

```bash
# Training (3 seeds)
pip install brotli
for SEED in 1337 42 2025; do
  BIGRAM_DIM=160 SEED=$SEED \
    torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee train_seed${SEED}.log
  cp final_model.int6.ptz checkpoints/final_model_seed${SEED}.int6.ptz
done

# N-gram agreement evaluation (per seed)
gcc -O3 -march=native -shared -fPIC -o libonline_ngram_state.so online_ngram_state.c
for SEED in 1337 42 2025; do
  BIGRAM_DIM=160 CHECKPOINT=checkpoints/final_model_seed${SEED}.int6.ptz \
    torchrun --standalone --nproc_per_node=8 eval_ngram_on_checkpoint.py
done
```

## Credits

- **Training scaffold**: [PR #1179](https://github.com/openai/parameter-golf/pull/1179) by @dexhunter (built on PR #1019 by @abaybektursun)
- **N-gram agreement eval**: [PR #1145](https://github.com/openai/parameter-golf/pull/1145) by @AnirudhRahul
- **Full Hessian GPTQ**: [PR #535](https://github.com/openai/parameter-golf/pull/535) by @raahilshah
- **XSA-all**: [PR #478](https://github.com/openai/parameter-golf/pull/478) by @gowtham0992

## Included Files

- `train_gpt.py` — training + quantization + sliding window eval
- `online_best_agree_eval.py` — n-gram agreement evaluation
- `online_ngram_state.c` — native n-gram hash table (compiled at eval time)
- `eval_ngram_on_checkpoint.py` — helper to run n-gram eval on saved checkpoints
- `train_seed{1337,42,2025}.log` — training logs
- `submission_ngram_seed{1337,42,2025}.log` — n-gram eval logs
- `submission.json` — leaderboard metadata
#!/usr/bin/env python3
"""Evaluate n-gram agreement on a saved int6 checkpoint."""
from __future__ import annotations
import io
import os
import sys
import time

import brotli
import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist

# Add current dir to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_gpt import (
    GPT,
    CastedLinear,
    Hyperparameters,
    _byte_unshuffle,
    _unbank_state_dict,
    _rebank_state_dict,
    build_sentencepiece_luts,
    dequantize_mixed_int6,
    load_validation_tokens,
    restore_low_dim_params_to_fp32,
)
from online_best_agree_eval import eval_val_sliding_online_best_agree


def main():
    args = Hyperparameters()
    args.bigram_dim = int(os.environ.get("BIGRAM_DIM", "160"))

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master = rank == 0

    def log0(msg, console=True):
        if master and console:
            print(msg, flush=True)

    # Load tokenizer
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )

    # Load int6 checkpoint
    ptz_path = os.environ.get("CHECKPOINT", "final_model.int6.ptz")
    log0(f"Loading checkpoint: {ptz_path}")
    with open(ptz_path, "rb") as f:
        quant_blob = f.read()
    quant_state = torch.load(
        io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob))),
        map_location="cpu",
    )

    # Build model
    eval_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        bigram_vocab_size=args.bigram_vocab_size,
        bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,
        rope_dims=args.rope_dims,
        ln_scale=args.ln_scale,
        ve_enabled=args.ve_enabled,
        ve_dim=args.ve_dim,
        ve_layers=args.ve_layers,
        neg_slope=args.negative_slope,
    ).to(device).bfloat16()
    eval_model.qo_bank.data = eval_model.qo_bank.data.float()
    eval_model.kv_bank.data = eval_model.kv_bank.data.float()
    eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float()
    eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float()
    for m in eval_model.modules():
        if isinstance(m, CastedLinear):
            m.float()
    restore_low_dim_params_to_fp32(eval_model)

    # Dequantize and load weights
    template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()}
    unbanked_template = _unbank_state_dict(template_sd, args.num_layers)
    deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_template)
    deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd)
    eval_model.load_state_dict(deq_state, strict=True)
    eval_model.eval()

    log0("Model loaded, running n-gram agreement eval...")
    t0 = time.perf_counter()
    _, best_bpb, timings = eval_val_sliding_online_best_agree(
        args=args,
        base_model=eval_model,
        rank=rank,
        world_size=world_size,
        device=device,
        val_tokens=val_tokens,
        base_bytes_lut=base_bytes_lut,
        has_leading_space_lut=has_leading_space_lut,
        is_boundary_token_lut=is_boundary_token_lut,
        stride=args.eval_stride,
        batch_seqs=32,
        eval_seq_len=args.train_seq_len,
        log0=log0,
    )
    elapsed = time.perf_counter() - t0
    log0(f"n-gram agreement BPB: {best_bpb:.8f} (elapsed: {elapsed:.1f}s)")
    log0(f"LLM-only BPB: {timings['llm_bpb']:.8f}")
    log0(f"Gain: {timings['gain_bpb']:.8f}")

    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()