diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/README.md b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/README.md new file mode 100644 index 0000000000..b42cd0ad37 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/README.md @@ -0,0 +1,213 @@ +# Gravity Tokenizer + +**val_bpb: 1.0321** (3-seed mean, std 0.0011) | **15.6 MB** | 8×H100 SXM + +## Core Idea + +Every submission to this challenge has optimized the model. Nobody has optimized the tokenizer. + +At 1024 vocabulary tokens, every merge slot matters. Standard BPE allocates those slots by frequency. But frequency and structural importance are not the same thing. Some tokens are load-bearing: shatter them back to bytes and downstream loss spikes. Others are convenient shortcuts the model barely notices losing. + +The Gravity Tokenizer replaces 659 of 765 merge tokens with tokens selected by **ablation leverage** — the downstream loss increase when a token is removed from the vocabulary and its occurrences are decomposed to bytes. The vocabulary size stays exactly 1024. Only which tokens occupy the merge slots changes. + +This single change — vocabulary composition — accounts for the entire improvement. The model architecture is a vanilla transformer with no novel components. + +## 3-Seed Results + +| Seed | val_bpb | artifact_bytes | training_time | ms/step | valid | +|------|---------|---------------|---------------|---------|-------| +| 42 | 1.0310 | 15,629,267 | 590,898 ms | 53.72 | yes | +| 137 | 1.0321 | 15,625,195 | 590,980 ms | 53.73 | yes | +| 3 | 1.0331 | 15,625,147 | 591,082 ms | 53.73 | yes | +| **Mean** | **1.0321** | | | | | +| **Std** | **0.0011** | | | | | + +## Architecture + +Deliberately simple. The goal is to isolate the vocabulary effect. 
+ +| Component | Setting | +|-----------|---------| +| Layers | 12 | +| Dimension | 384 | +| Heads | 6 (2 KV heads, GQA) | +| MLP | 3× expansion (hidden=1152) | +| Activation | relu² | +| Sequence length | 2048 | +| Embeddings | Tied | +| Vocab size | 1024 (256 byte + 3 control + 765 merge) | +| Quantization | int8 + zlib | +| Parameters | ~16M | + +No SmearGate. No BigramHash. No XSA. No EMA. No TTT. No sliding window eval. No mixed-precision quantization. + +## The Gravity Scoring Pipeline + +### Step 1: Candidate Generation + +Extract the full BPE merge table (7,997 candidates). Filter to tokens with corpus frequency ≥ 1,000 (3,142 candidates). Remove 84 byte-level tokens. **3,058 scored candidates.** + +### Step 2: Ablation Leverage Scoring + +A frozen **GPT-2** reference model measures each candidate's structural importance. GPT-2 is used because it provides a tokenizer-independent measurement — its own BPE vocabulary is unrelated to the competition's 1024-token vocabulary, so leverage scores reflect language structure, not tokenizer artifacts. A GPT-2 vocabulary contamination check (Pearson correlation between leverage and GPT-2 vocab membership) confirmed no significant contamination. + +For each of the 3,058 candidates, across 100 contexts sampled from FineWeb training shards: + +1. **Find contexts:** Locate occurrences of the candidate's surface form in the decoded corpus text. Extract 512-character windows centered on each occurrence. + +2. **Build text pairs:** For each context, create an intact version (original text) and a shattered version where the target token's characters are space-separated (e.g., `"the"` becomes `"t h e"`). This forces the reference model to process the same content without the benefit of the atomic token. + +3. **Batched forward passes:** Tokenize both versions with GPT-2's tokenizer (with `return_offsets_mapping=True` for precise position tracking). Run batched inference (batch_size=32) on both versions. + +4. 
**Extract downstream loss:** Using the offset mapping, find the first token position *after* the target in both intact and shattered sequences. Compute mean per-token cross-entropy over a K=10 token downstream window in each version. The leverage is `mean_loss_shattered - mean_loss_intact`. + +5. **Early exit:** After scoring 30 contexts, if the candidate's mean leverage + 2 standard errors is below zero, skip the remaining 70 contexts (the token is clearly not load-bearing). + +6. **Aggregate:** Mean leverage across all scored contexts, with 95% confidence intervals and a breadth measure (entropy of the per-context leverage distribution). + +**Scoring formula:** +``` +score(t) = freq_norm(t)^(1-beta) * leverage_norm(t)^beta +``` + +At β=0.0 this recovers standard BPE. At β=1.0 (this submission), leverage dominates. Both frequency and leverage are log-scaled before normalization. Breadth was computed but dropped from the final score (std=0.15, no discriminative power). + +### Step 3: Vocabulary Construction + +Rank all 3,058 candidates by score. Select the top 765 as merge tokens. Build a SentencePiece Unigram model with byte fallback using the selected vocabulary. + +### Step 4: Retokenization + +Decode the original BPE-tokenized FineWeb shards to raw text, then re-encode with the gravity tokenizer. The validation set is the same FineWeb first-50k-document split used by all submissions, re-encoded. + +## What the Gravity Tokenizer Changes + +**Vocab diff vs BPE: 659 of 765 merge tokens replaced (86%).** + +Tokens removed by gravity scoring (examples): single characters with space prefixes, isolated uppercase letters, digits, punctuation fragments — tokens with high BPE frequency but zero structural importance. + +Tokens promoted by gravity scoring (examples): `every`(0.99), `under`(0.96), `first`(0.90), `take`(0.82), `help`(0.78), `may`(0.78) — common English words that BPE would not include at vocab=1024 but that the model structurally depends on. 
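A minimal sketch of the scoring formula (illustrative helper names, not the submission's pipeline code; assumes leverage values are positive after the early-exit filter removes non-load-bearing candidates):

```python
import math

def log_norm(values):
    # Log-scale, then min-max normalize to [0, 1].
    logs = [math.log(v) for v in values]
    lo, hi = min(logs), max(logs)
    return [(x - lo) / (hi - lo) if hi > lo else 0.0 for x in logs]

def gravity_scores(freqs, leverages, beta=1.0):
    # score(t) = freq_norm(t)^(1-beta) * leverage_norm(t)^beta
    f, g = log_norm(freqs), log_norm(leverages)
    return [fi ** (1.0 - beta) * gi ** beta for fi, gi in zip(f, g)]
```

At β=0.0 the leverage factor drops out and the ranking reduces to log-frequency; at β=1.0 only leverage matters.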
+ +**Compression ratio:** The gravity tokenizer achieves 1.05 bytes/token vs BPE's 2.45 bytes/token. This means more tokens per byte of text — the model must predict more tokens to cover the same content. The BPB metric penalizes this directly: `val_bpb = bits_per_token * tokens_per_byte`. The improvement is entirely in per-token prediction quality overcoming the worse compression ratio. + +## The Tokenizer as Ontology + +The gravity vocabulary is legible in a way BPE is not. You can read the model's structural commitments directly from how it tokenizes a sentence. + +Consider: "The water because caused the damage." + +| Word | Tokens | Structure | +|------|--------|-----------| +| water | `▁water` | Single crystal — atomic unit, full leverage | +| because | `▁because` | Single crystal — causal anchor | +| caused | `▁cause` + `<0x64>` | Partial crystal — morpheme preserved, suffix is byte | +| damage | `<0x64>` `am` `<0x61>` `<0x67>` `<0x65>` | Byte-gas with one fragment — no structural handle | +| the | `▁the` | Single crystal | + +The crystallized tokens are the load-bearing walls. The byte-gas is the fill. BPE hides this distinction — it tokenizes by frequency, so common fragments get tokens regardless of structural importance. A BPE tokenization tells you "this substring appears often." A gravity tokenization tells you "this is where the model's structural commitments are." + +This has measurable consequences for how the model allocates its computational depth. The residual velocity probe (see "Depth Efficiency Law" below) shows that crystallized tokens engage all 12 layers of the network productively, while byte-gas tokens waste 8 of 12 layers — thrashing early, going idle in the middle, and panicking at the final layer. + +The gravity vocabulary doesn't just compress better. It gives the transformer a skeleton to build on. BPE gives it dust. 
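The compression-ratio tradeoff above can be made concrete with back-of-envelope arithmetic (the per-token bit costs below are derived from the reported figures, not separately measured):

```python
# val_bpb = bits_per_token * tokens_per_byte
gravity_tpb = 1 / 1.05   # ~0.952 tokens per byte (gravity)
bpe_tpb     = 1 / 2.45   # ~0.408 tokens per byte (BPE)

val_bpb = 1.0321
gravity_bpt = val_bpb / gravity_tpb   # ~1.08 bits/token actually achieved
bpe_bpt_to_match = val_bpb / bpe_tpb  # ~2.53 bits/token a BPE-vocab model could spend
```

To reach the same val_bpb, the gravity model must predict each token to within ~1.08 bits where a BPE-vocab model could afford ~2.53 bits, which is why the win has to come entirely from per-token prediction quality.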
+ +For the full theoretical framework: [github.com/dcrow85/Avalanche](https://github.com/dcrow85/Avalanche) + +## Why It Works: The Depth Efficiency Law + +The vocabulary determines how much of the architecture each token can actually use. + +We measured the **residual velocity** — the L2 norm of the representation change at each layer — for every token across the full depth of the network. High velocity means the layer is doing useful work. Low velocity means the layer is idle. + +**High-leverage tokens use all 12 layers productively.** Their velocity rises smoothly from layer 0 through layer 11. Each layer builds on the last. The model spends its full depth processing these tokens. + +**Low-leverage byte-gas tokens waste most of the network.** They exhibit a U-shaped velocity profile: thrashing in early layers (the model attempts to assemble meaning from structureless bytes), going idle in middle layers (the model gives up), and panicking at the final layer (the model must produce a prediction from an unresolved representation). For these tokens, 8 of 12 layers do nothing useful. + +This finding survived a rigorous length-matched control (p = 0.00005, n = 168 single-token insertions). Two independent instruments agree: the external ablation measurement (how much does loss increase when this token is removed?) and the internal velocity measurement (how uniformly does the model process this token across layers?) are reading the same property. + +**The same physics holds at frontier scale.** We ran the identical probe on [Qwen 2.5-72B](https://github.com/dcrow85/Avalanche/blob/main/gravity-tokenizer/DEPTH_EFFICIENCY.md) — 80 layers, 151K vocabulary, 2x A100 GPUs, $3 of compute. The result: semantically rich tokens engage all 80 layers (panic ratio ~6). Single characters and fragments waste 60+ layers (panic ratio >10). The depth efficiency law scales with architecture depth. The waste just gets more expensive. 
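The residual velocity measurement itself is simple. A framework-agnostic sketch, where `hidden_states` is a hypothetical `(num_layers + 1, seq_len, dim)` array of residual-stream snapshots captured before the first block and after each block (e.g. via forward hooks):

```python
import numpy as np

def residual_velocity(hidden_states):
    """L2 norm of the representation change at each layer, per token.

    hidden_states: (num_layers + 1, seq_len, dim) residual-stream snapshots.
    Returns (num_layers, seq_len): velocity[l, t] = ||h[l+1, t] - h[l, t]||.
    """
    deltas = np.diff(hidden_states, axis=0)   # (num_layers, seq_len, dim)
    return np.linalg.norm(deltas, axis=-1)    # (num_layers, seq_len)
```

A token whose velocity rises smoothly across layers matches the crystallized profile; a U-shape (high early, near-zero mid, spike at the final layer) is the byte-gas signature.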
+ +A 12-layer model with a BPE vocabulary is effectively a 3-4 layer model for the byte-gas portions of its sequence. The gravity tokenizer doesn't add layers. It makes the existing layers usable. + +**What didn't survive the controls:** An initial hypothesis that high-leverage tokens would "lens" horizontal attention (bending information paths the way mass bends light) was tested and killed. A length-matched control showed that 56% of the observed deflection was a RoPE positional artifact — multi-token insertions increase subject-object distance, and RoPE attenuates attention over distance. A bidirectional control (RoBERTa) confirmed that directional asymmetries in causal models were architectural stress, not learned semantic geometry. The physics of the vocabulary is vertical (layer depth utilization), not horizontal (attention routing between positions). + +Full probe data, scripts, and the Qwen 72B results: [github.com/dcrow85/Avalanche](https://github.com/dcrow85/Avalanche/blob/main/gravity-tokenizer/DEPTH_EFFICIENCY.md) + +## Tokenizer Correctness + +The val_bpb calculation uses the competition's own `build_sentencepiece_luts()` and `eval_val()` functions with **zero modifications**. The byte-counting lookup tables are built from the SentencePiece model proto using the same code path as stock BPE. + +The gravity tokenizer's lower compression ratio (1.05 vs 2.45 bytes/token) results in a **higher** `tokens_per_byte` multiplier in the BPB formula. This penalizes the gravity tokenizer — any BPB improvement must come from genuinely better per-token prediction quality, not from gaming the metric. + +Detailed tokenizer correctness documentation: see `tokenizer_scrutiny_doc.md` in this submission. + +## Controlled Experiments (RTX 5080, matched conditions) + +The vocabulary effect was isolated through controlled A/B experiments before the competition run. 
All conditions use identical architecture (9L, 512d), identical training budget (matched on bytes seen, not steps), and differ only in vocabulary composition. + +| Condition | Steps | BPB | vs Step-Matched BPE | +|-----------|-------|-----|---------------------| +| BPE baseline | 2,000 | 1.4386 | -- | +| BPE control | 2,870 | 1.4011 | -- | +| Gravity β=0.3 (70 swaps) | 2,870 | 1.3845 | **-0.017** | +| Cold replication β=0.3 | 2,870 | 1.3821 | **-0.019** (confirms) | +| BPE control | 4,656 | 1.3649 | -- | +| Gravity β=1.0 (659 swaps) | 4,656 | 1.2262 | **-0.139** | + +The effect scales linearly: 9.4x more token swaps produced 8.2x more BPB improvement. + +## Negative Results + +Two experiments that did not work, both scientifically informative: + +**Bifurcation scoring (+0.021 worse):** Replacing absolute leverage with delta-leverage (the discrete derivative along the BPE merge tree) removed tokens that are individually redundant but collectively essential for compression. Language needs high-frequency connective tissue between semantic crystals. + +**Warm-start embeddings (+0.038 worse):** Initializing gravity token embeddings as spatial means of their BPE decomposition embeddings caused a catastrophic initial mismatch (val_bpb 27.5 vs 5.4 cold start). The crystallization plateau at ~2000 steps is a whole-model phase transition, not an embedding-local phenomenon. Trained embeddings are entangled with transformer weights. 
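For reference, the failed warm-start initialization amounts to the following (a sketch with hypothetical names: `bpe_emb` maps BPE token ids to trained embedding rows, `decomposition` is one gravity token's BPE decomposition):

```python
import numpy as np

def warm_start_embedding(bpe_emb, decomposition):
    # Initialize a gravity token's embedding as the spatial mean of the
    # trained embeddings of the BPE tokens it decomposes into. Plausible
    # on paper, but this is the approach behind the catastrophic 27.5
    # initial val_bpb reported above.
    return np.mean([bpe_emb[i] for i in decomposition], axis=0)
```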
+ +## Run Command + +```bash +# Setup (downloads stock FineWeb + retokenizes with gravity vocabulary) +bash setup.sh + +# Train (default seed=1337) +MODEL_DIM=384 NUM_LAYERS=12 NUM_HEADS=6 NUM_KV_HEADS=2 MLP_MULT=3 \ +TRAIN_SEQ_LEN=2048 VOCAB_SIZE=1024 \ +DATA_PATH=./data/datasets/fineweb_gravity_beta_1.0 \ +TOKENIZER_PATH=./data/tokenizers/gravity_beta_1.0.model \ +ITERATIONS=11000 WARMUP_STEPS=50 WARMDOWN_ITERS=2500 \ +MAX_WALLCLOCK_SECONDS=600 \ +SEED=1337 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py + +# With specific seed +SEED=42 MODEL_DIM=384 NUM_LAYERS=12 NUM_HEADS=6 NUM_KV_HEADS=2 MLP_MULT=3 \ +TRAIN_SEQ_LEN=2048 VOCAB_SIZE=1024 \ +DATA_PATH=./data/datasets/fineweb_gravity_beta_1.0 \ +TOKENIZER_PATH=./data/tokenizers/gravity_beta_1.0.model \ +ITERATIONS=11000 WARMUP_STEPS=50 WARMDOWN_ITERS=2500 \ +MAX_WALLCLOCK_SECONDS=600 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +All parameters are passed via environment variables. The `train_gpt.py` script is the unmodified competition baseline. + +## Reproducibility + +The gravity tokenizer can be rebuilt from scratch. The scoring pipeline (ablation leverage computation) requires ~4 hours on a single GPU. All other steps are deterministic and complete in minutes. Full pipeline: + +```bash +python scripts/generate_candidates_sp.py +python scripts/score_leverage.py +python scripts/build_vocabulary.py --beta 1.0 +python scripts/build_tokenizer.py \ + --vocabulary data/vocabularies/vocabulary_beta_1.0.json \ + --output data/tokenizers/gravity_beta_1.0.model \ + --corpus-sample data/corpus_sample.txt +python scripts/retokenize_corpus.py \ + --base-tokenizer data/tokenizers/fineweb_1024_bpe.model \ + --gravity-tokenizer data/tokenizers/gravity_beta_1.0.model \ + --data-dir data/datasets/fineweb10B_sp1024 \ + --output-dir data/datasets/fineweb_gravity_beta_1.0 +``` + +The training script is the competition's `train_gpt.py` with architecture parameters passed via environment variables. 
No code modifications to the training loop, evaluation function, or quantization pipeline. diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_beta_1.0.model b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_beta_1.0.model new file mode 100644 index 0000000000..4d7adb178d Binary files /dev/null and b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_beta_1.0.model differ diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_lensing_probe.py b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_lensing_probe.py new file mode 100644 index 0000000000..97ff8eb7e7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/gravity_lensing_probe.py @@ -0,0 +1,538 @@ +""" +Gravitational Lensing Probe — Does information flow bend around high-gravity tokens? + +Hypothesis: High-leverage tokens act as gravitational lenses in attention space. +Inserting a massive token between subject and object should collapse the direct +attention path and force information to route through the massive token. + +The falsifiable signature: A(object -> subject) drops when a high-gravity token +is inserted, while A(object -> massive) + A(massive -> subject) absorbs the mass. +Low-gravity control tokens should show significantly less deflection. 
+""" + +import io +import math +import os +import sys +import zlib +import json + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +import numpy as np + +import sentencepiece as spm + +SCRIPT_DIR = os.path.dirname(__file__) +GOLF_ROOT = os.path.join(SCRIPT_DIR, "..", "parameter-golf") +DATA_DIR = os.path.join(SCRIPT_DIR, "..", "data") + + +# ── Model (from generate.py, modified to capture attention weights) ── + +class CastedLinear(nn.Linear): + def __init__(self, in_f, out_f, bias=False): + super().__init__(in_f, out_f, bias=bias) + def forward(self, x): + return F.linear(x, self.weight.to(x.dtype), + self.bias.to(x.dtype) if self.bias is not None else None) + +class RMSNorm(nn.Module): + def forward(self, x): + return F.rms_norm(x, (x.size(-1),)) + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0): + super().__init__() + self.inv_freq = nn.Parameter( + 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)), + requires_grad=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if seq_len > self._seq_len_cached: + self._seq_len_cached = seq_len + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin): + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k 
= CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + # Storage for captured attention weights + self.last_attn_weights = None + + def forward(self, x, capture_attn=False): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + + if self.num_kv_heads < self.num_heads: + reps = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(reps, dim=1) + v = v.repeat_interleave(reps, dim=1) + + scale = 1.0 / math.sqrt(self.head_dim) + attn = (q @ k.transpose(-2, -1)) * scale + mask = torch.triu(torch.ones(seqlen, seqlen, device=x.device, dtype=torch.bool), diagonal=1) + attn = attn.masked_fill(mask[None, None], float('-inf')) + attn_weights = F.softmax(attn, dim=-1, dtype=torch.float32) + + if capture_attn: + self.last_attn_weights = attn_weights.detach().cpu() + + y = (attn_weights.to(x.dtype) @ v).transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.fc = CastedLinear(dim, mlp_mult * dim, bias=False) + self.proj = CastedLinear(mlp_mult * dim, dim, bias=False) + def forward(self, x): + return self.proj(torch.relu(self.fc(x)).square()) + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, 
qk_gain_init): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x, x0, capture_attn=False): + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x), capture_attn=capture_attn) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, logit_softcap, rope_base, qk_gain_init): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + + def forward_with_attention(self, input_ids): + """Forward pass that captures attention weights from all layers.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, capture_attn=True) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0, capture_attn=True) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + # Collect attention from all layers + all_attn = [] + for block in self.blocks: + all_attn.append(block.attn.last_attn_weights) + return logits, all_attn + + +def dequantize_state_dict_int8(obj): + out = {} + qmeta = obj.get("qmeta", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = obj.get("passthrough_orig_dtypes", {}).get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def load_model(checkpoint_path, tokenizer_path, num_layers=12, device="cuda"): + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + + model = GPT( + vocab_size=1024, num_layers=num_layers, model_dim=384, + num_heads=6, num_kv_heads=2, mlp_mult=3, + tie_embeddings=True, logit_softcap=30.0, + rope_base=10000.0, qk_gain_init=1.5, + ) + + with open(checkpoint_path, "rb") as f: + quant_blob = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob)), + map_location="cpu", weights_only=True) + state_dict = dequantize_state_dict_int8(quant_state) + model.load_state_dict(state_dict, strict=False) + model = 
model.to(device).eval() + return model, sp + + +def tokenize_and_locate(sp, text, target_tokens): + """ + Tokenize text and find token SPANS covering each target word. + Returns positions as (start_idx, end_idx) tuples — the range of token + indices that reconstruct the target surface form. + """ + ids = sp.encode(text) + pieces = [sp.id_to_piece(i) for i in ids] + + # Reconstruct character offsets for each token + # SentencePiece uses \u2581 for space + reconstructed = "" + token_char_spans = [] + for idx, piece in enumerate(pieces): + display = piece.replace('\u2581', ' ') + # Byte tokens like <0x6D> represent single bytes + if piece.startswith('<0x') and piece.endswith('>'): + byte_val = int(piece[3:-1], 16) + try: + display = bytes([byte_val]).decode('utf-8', errors='replace') + except: + display = '?' + start = len(reconstructed) + reconstructed += display + end = len(reconstructed) + token_char_spans.append((start, end)) + + positions = {} + for name, surface in target_tokens.items(): + # Find the surface form in the reconstructed text + # Try with and without leading space + for search in [' ' + surface, surface]: + char_idx = reconstructed.find(search) + if char_idx >= 0: + char_start = char_idx + char_end = char_idx + len(search) + # Find which tokens span this character range + tok_start = None + tok_end = None + for tidx, (cs, ce) in enumerate(token_char_spans): + if cs < char_end and ce > char_start: + if tok_start is None: + tok_start = tidx + tok_end = tidx + 1 + if tok_start is not None: + # Use the LAST token in the span as the representative position + # (that's where the full word meaning has been assembled) + positions[name] = { + "span": (tok_start, tok_end), + "repr": tok_end - 1, # last token = most assembled + "tokens": pieces[tok_start:tok_end], + } + break + + return ids, pieces, positions + + +@torch.no_grad() +def measure_attention(model, sp, text, target_tokens, device="cuda"): + """ + Run text through model, capture attention, and 
extract attention weights + between specified token positions. + """ + ids, pieces, positions = tokenize_and_locate(sp, text, target_tokens) + + missing = [k for k in target_tokens if k not in positions] + if missing: + return None, pieces, positions, f"Missing positions for: {missing}" + + # Reset rotary caches to handle variable sequence lengths + for block in model.blocks: + block.attn.rotary._seq_len_cached = 0 + + input_ids = torch.tensor([ids], dtype=torch.long, device=device) + logits, all_attn = model.forward_with_attention(input_ids) + + # all_attn is list of [1, num_heads, seq_len, seq_len] per layer + return all_attn, pieces, positions, None + + +def run_lensing_probe(model, sp, device="cuda"): + """Run the full gravitational lensing experiment.""" + + # Load leverage scores to identify high/low gravity tokens + scored_path = os.path.join(DATA_DIR, "candidates_scored.jsonl") + leverage_map = {} + if os.path.exists(scored_path): + with open(scored_path, "r", encoding="utf-8") as f: + for line in f: + c = json.loads(line) + if c.get("ablation_leverage", 0) != 0: + leverage_map[c["readable"]] = c["ablation_leverage"] + + print(f"Loaded {len(leverage_map)} leverage scores") + if leverage_map: + top = sorted(leverage_map.items(), key=lambda x: -x[1])[:10] + bottom = sorted(leverage_map.items(), key=lambda x: x[1])[:10] + print(f"\nHighest gravity tokens:") + for tok, lev in top: + print(f" {lev:+.3f} {tok!r}") + print(f"\nLowest gravity tokens:") + for tok, lev in bottom: + print(f" {lev:+.3f} {tok!r}") + + # === EXPERIMENT SENTENCES === + # Base: clean subject-verb-object with measurable attention path + # Massive: insert high-gravity token between S and O + # Control: insert low-gravity token between S and O + + experiments = [ + { + "name": "Negation lensing", + "base": "The government announced the policy", + "massive": "The government not announced the policy", + "control": "The government or announced the policy", + "subject": "government", + "object": 
"policy", + "lens": "not", + "ctrl": "or", + }, + { + "name": "Conditional lensing", + "base": "The system produced the result", + "massive": "The system if produced the result", + "control": "The system an produced the result", + "subject": "system", + "object": "result", + "lens": "if", + "ctrl": "an", + }, + { + "name": "Causal lensing", + "base": "The water caused the damage", + "massive": "The water because caused the damage", + "control": "The water so caused the damage", + "subject": "water", + "object": "damage", + "lens": "because", + "ctrl": "so", + }, + ] + + results = [] + + for exp in experiments: + print(f"\n{'='*70}") + print(f"EXPERIMENT: {exp['name']}") + print(f"{'='*70}") + + for condition, label in [ + ("base", "BASE (no insertion)"), + ("massive", f"MASSIVE ({exp['lens']})"), + ("control", f"CONTROL ({exp['ctrl']})") + ]: + text = exp[condition] + print(f"\n--- {label} ---") + print(f"Text: {text!r}") + + # Identify target tokens to locate + targets = {"subject": exp["subject"], "object": exp["object"]} + if condition == "massive": + targets["lens"] = exp["lens"] + elif condition == "control": + targets["lens"] = exp["ctrl"] + + all_attn, pieces, positions, err = measure_attention( + model, sp, text, targets, device + ) + + print(f"Tokens: {pieces}") + print(f"Positions: {positions}") + + if err: + print(f"ERROR: {err}") + continue + + subj_pos = positions["subject"]["repr"] + obj_pos = positions["object"]["repr"] + lens_info = positions.get("lens", None) + lens_pos = lens_info["repr"] if lens_info else None + + print(f"Subject '{exp['subject']}': tokens {positions['subject']['tokens']} -> repr pos {subj_pos}") + print(f"Object '{exp['object']}': tokens {positions['object']['tokens']} -> repr pos {obj_pos}") + if lens_info: + lens_word = exp.get('lens', exp.get('ctrl', '?')) + print(f"Lens '{lens_word}': tokens {lens_info['tokens']} -> repr pos {lens_pos}") + + # Attention from object -> subject (the direct geodesic) + # Average across all 
layers and heads + direct_attn_per_layer = [] + lens_attn_per_layer = [] # object -> lens, lens -> subject + + for layer_idx, attn in enumerate(all_attn): + # attn shape: [1, num_heads, seq_len, seq_len] + # attn[0, head, query_pos, key_pos] = how much query_pos attends to key_pos + a = attn[0] # [num_heads, seq_len, seq_len] + + # Direct: object attends to subject + direct = a[:, obj_pos, subj_pos].mean().item() + direct_attn_per_layer.append(direct) + + if lens_pos is not None: + # Lensed path: object -> lens + lens -> subject + obj_to_lens = a[:, obj_pos, lens_pos].mean().item() + lens_to_subj = a[:, lens_pos, subj_pos].mean().item() + lens_attn_per_layer.append((obj_to_lens, lens_to_subj)) + + mean_direct = np.mean(direct_attn_per_layer) + print(f"\nA(object->subject) mean across layers: {mean_direct:.4f}") + print(f"A(object->subject) per layer: {['%.4f' % x for x in direct_attn_per_layer]}") + + if lens_pos is not None and lens_attn_per_layer: + obj_lens = [x[0] for x in lens_attn_per_layer] + lens_subj = [x[1] for x in lens_attn_per_layer] + print(f"A(object->lens) mean: {np.mean(obj_lens):.4f}") + print(f"A(lens->subject) mean: {np.mean(lens_subj):.4f}") + print(f"Indirect path product mean: {np.mean([a*b for a,b in lens_attn_per_layer]):.6f}") + + result = { + "experiment": exp["name"], + "condition": condition, + "label": label, + "text": text, + "pieces": pieces, + "positions": positions, + "direct_attn_mean": mean_direct, + "direct_attn_per_layer": direct_attn_per_layer, + } + if lens_pos is not None and lens_attn_per_layer: + result["obj_to_lens_mean"] = np.mean(obj_lens) + result["lens_to_subj_mean"] = np.mean(lens_subj) + result["obj_to_lens_per_layer"] = obj_lens + result["lens_to_subj_per_layer"] = lens_subj + + # Gravity scores if available + if condition == "massive" and exp["lens"] in leverage_map: + result["lens_leverage"] = leverage_map[exp["lens"]] + print(f"Lens token gravity: {leverage_map[exp['lens']]:.3f}") + if condition == "control" and 
exp["ctrl"] in leverage_map: + result["lens_leverage"] = leverage_map[exp["ctrl"]] + print(f"Control token gravity: {leverage_map[exp['ctrl']]:.3f}") + + results.append(result) + + # Summary comparison + base_r = [r for r in results if r["experiment"] == exp["name"] and r["condition"] == "base"] + mass_r = [r for r in results if r["experiment"] == exp["name"] and r["condition"] == "massive"] + ctrl_r = [r for r in results if r["experiment"] == exp["name"] and r["condition"] == "control"] + + if base_r and mass_r and ctrl_r: + base_direct = base_r[0]["direct_attn_mean"] + mass_direct = mass_r[0]["direct_attn_mean"] + ctrl_direct = ctrl_r[0]["direct_attn_mean"] + + print(f"\n{'─'*50}") + print(f"LENSING SUMMARY: {exp['name']}") + print(f"{'─'*50}") + print(f"Direct A(obj->subj) BASE: {base_direct:.4f}") + print(f"Direct A(obj->subj) MASSIVE: {mass_direct:.4f} (delta: {mass_direct - base_direct:+.4f})") + print(f"Direct A(obj->subj) CONTROL: {ctrl_direct:.4f} (delta: {ctrl_direct - base_direct:+.4f})") + + if mass_direct < base_direct and mass_direct < ctrl_direct: + deflection_ratio = (base_direct - mass_direct) / max(base_direct - ctrl_direct, 1e-8) + print(f"\n** LENSING DETECTED ** Massive token deflects {deflection_ratio:.1f}x more than control") + elif mass_direct < base_direct: + print(f"\nDeflection present but not stronger than control.") + else: + print(f"\nNo deflection detected.") + + # Save results + output_path = os.path.join(DATA_DIR, "lensing_probe_results.json") + # Convert numpy types for JSON serialization + def convert(obj): + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return obj + + clean_results = json.loads(json.dumps(results, default=convert)) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(clean_results, f, indent=2, ensure_ascii=False) + print(f"\nResults saved to: {output_path}") + + return results + + +if 
__name__ == "__main__": + # Fix Windows console encoding + import codecs + sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, errors='replace') + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Default to the fully trained 12L seed 137 checkpoint + checkpoint = os.path.join(GOLF_ROOT, "logs", "gravity_12L_seed137.int8.ptz") + if not os.path.exists(checkpoint): + # Fall back to smoke test checkpoint + checkpoint = os.path.join(GOLF_ROOT, "final_model.int8.ptz") + num_layers = 13 + print("WARNING: Using 13L smoke test checkpoint (100 steps)") + else: + num_layers = 12 + + tokenizer = os.path.join(DATA_DIR, "tokenizers", "gravity_beta_1.0.model") + + print(f"Loading model ({num_layers}L)...") + model, sp = load_model(checkpoint, tokenizer, num_layers=num_layers, device=device) + print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}") + print(f"Device: {device}\n") + + results = run_lensing_probe(model, sp, device) diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/lensing_probe_results.json b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/lensing_probe_results.json new file mode 100644 index 0000000000..99b61c78a4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/lensing_probe_results.json @@ -0,0 +1,984 @@ +[ + { + "experiment": "Negation lensing", + "condition": "base", + "label": "BASE (no insertion)", + "text": "The government announced the policy", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x61>", + "<0x6E>", + "<0x6E>", + "<0x6F>", + "<0x75>", + "<0x6E>", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ], + "positions": { + "subject": { + "span": [ + 6, + 11 + ], + "repr": 10, + "tokens": 
[ + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>" + ] + }, + "object": { + "span": [ + 26, + 32 + ], + "repr": 31, + "tokens": [ + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ] + } + }, + "direct_attn_mean": 0.015122186294926602, + "direct_attn_per_layer": [ + 0.00043089943937957287, + 5.5245993280550465e-05, + 5.288392640068196e-05, + 0.037041742354631424, + 0.0012060640146955848, + 0.06661331653594971, + 0.0079692667350173, + 0.009655219502747059, + 0.019711574539542198, + 0.003436535596847534, + 0.011124531738460064, + 0.02416895516216755 + ] + }, + { + "experiment": "Negation lensing", + "condition": "massive", + "label": "MASSIVE (not)", + "text": "The government not announced the policy", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x6E>", + "<0x6F>", + "<0x74>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x61>", + "<0x6E>", + "<0x6E>", + "<0x6F>", + "<0x75>", + "<0x6E>", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ], + "positions": { + "subject": { + "span": [ + 6, + 11 + ], + "repr": 10, + "tokens": [ + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>" + ] + }, + "object": { + "span": [ + 32, + 38 + ], + "repr": 37, + "tokens": [ + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ] + }, + "lens": { + "span": [ + 14, + 17 + ], + "repr": 16, + "tokens": [ + "<0x6E>", + "<0x6F>", + "<0x74>" + ] + } + }, + "direct_attn_mean": 0.015094329682445581, + "direct_attn_per_layer": [ + 0.00021376127551775426, + 8.595315739512444e-05, + 7.203236600616947e-05, + 0.036622606217861176, + 0.0009260917431674898, + 0.06514961272478104, + 0.01029993686825037, + 0.008856580592691898, + 0.020316286012530327, + 0.004500532057136297, + 0.011290494352579117, + 0.022798068821430206 + 
], + "obj_to_lens_mean": 0.010393538135152388, + "lens_to_subj_mean": 0.06097297896243011, + "obj_to_lens_per_layer": [ + 0.000410925829783082, + 0.00014484314306173474, + 0.0001192315248772502, + 0.002107759937644005, + 0.0017183570889756083, + 0.006053080316632986, + 0.004514262080192566, + 0.008829547092318535, + 0.012348380871117115, + 0.0010772277601063251, + 0.03334233537316322, + 0.05405650660395622 + ], + "lens_to_subj_per_layer": [ + 0.0033661548513919115, + 0.021302374079823494, + 0.044286470860242844, + 0.06685604900121689, + 0.0015673968009650707, + 0.1540781557559967, + 0.13025078177452087, + 0.01759324222803116, + 0.050730545073747635, + 0.009366304613649845, + 0.07825236022472382, + 0.15402591228485107 + ], + "lens_leverage": 0.4994449830055237 + }, + { + "experiment": "Negation lensing", + "condition": "control", + "label": "CONTROL (or)", + "text": "The government or announced the policy", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x6F>", + "<0x72>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x61>", + "<0x6E>", + "<0x6E>", + "<0x6F>", + "<0x75>", + "<0x6E>", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ], + "positions": { + "subject": { + "span": [ + 6, + 11 + ], + "repr": 10, + "tokens": [ + "▁govern", + "<0x6D>", + "<0x65>", + "<0x6E>", + "<0x74>" + ] + }, + "object": { + "span": [ + 31, + 37 + ], + "repr": 36, + "tokens": [ + "<0x70>", + "<0x6F>", + "<0x6C>", + "<0x69>", + "<0x63>", + "<0x79>" + ] + }, + "lens": { + "span": [ + 14, + 16 + ], + "repr": 15, + "tokens": [ + "<0x6F>", + "<0x72>" + ] + } + }, + "direct_attn_mean": 0.01587346368554184, + "direct_attn_per_layer": [ + 0.000273035402642563, + 0.00017297372687608004, + 6.43234743620269e-05, + 0.03787382319569588, + 0.0010990200098603964, + 
0.06376150995492935, + 0.008282159455120564, + 0.008824133314192295, + 0.01660069078207016, + 0.008205784484744072, + 0.015001180581748486, + 0.030322929844260216 + ], + "obj_to_lens_mean": 0.02169098268010809, + "lens_to_subj_mean": 0.07981696611386724, + "obj_to_lens_per_layer": [ + 0.0008744990336708724, + 0.00011489063035696745, + 0.00037255522329360247, + 0.024879522621631622, + 0.02110409177839756, + 0.017770307138562202, + 0.020156443119049072, + 0.022705787792801857, + 0.027203276753425598, + 0.004043154884129763, + 0.06293142586946487, + 0.05813583731651306 + ], + "lens_to_subj_per_layer": [ + 0.007105917204171419, + 0.006914152298122644, + 0.12936672568321228, + 0.16284982860088348, + 0.0017436690395697951, + 0.12981651723384857, + 0.09091484546661377, + 0.08778300881385803, + 0.08032350242137909, + 0.06331507116556168, + 0.07532889395952225, + 0.12234146147966385 + ], + "lens_leverage": 0.47793734312057495 + }, + { + "experiment": "Conditional lensing", + "condition": "base", + "label": "BASE (no insertion)", + "text": "The system produced the result", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>", + "▁produ", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ], + "positions": { + "subject": { + "span": [ + 9, + 15 + ], + "repr": 14, + "tokens": [ + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>" + ] + }, + "object": { + "span": [ + 22, + 28 + ], + "repr": 27, + "tokens": [ + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ] + } + }, + "direct_attn_mean": 0.017299913059105165, + "direct_attn_per_layer": [ + 0.0006068542716093361, + 0.0002747593680396676, + 0.0023292822297662497, + 0.023024151101708412, + 0.002907824469730258, + 0.054897408932447433, + 0.008313638158142567, + 
0.017731403931975365, + 0.030789492651820183, + 0.00462228013202548, + 0.018410345539450645, + 0.04369151592254639 + ] + }, + { + "experiment": "Conditional lensing", + "condition": "massive", + "label": "MASSIVE (if)", + "text": "The system if produced the result", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x69>", + "<0x66>", + "▁produ", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ], + "positions": { + "subject": { + "span": [ + 9, + 15 + ], + "repr": 14, + "tokens": [ + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>" + ] + }, + "object": { + "span": [ + 27, + 33 + ], + "repr": 32, + "tokens": [ + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ] + }, + "lens": { + "span": [ + 18, + 20 + ], + "repr": 19, + "tokens": [ + "<0x69>", + "<0x66>" + ] + } + }, + "direct_attn_mean": 0.013066406292637112, + "direct_attn_per_layer": [ + 0.0003607298422139138, + 0.00017732994456309825, + 0.0009517939179204404, + 0.016804426908493042, + 0.0017424466786906123, + 0.04083998128771782, + 0.0066705686040222645, + 0.00996649730950594, + 0.019006645306944847, + 0.004940125625580549, + 0.015650030225515366, + 0.03968629986047745 + ], + "obj_to_lens_mean": 0.036627730228550114, + "lens_to_subj_mean": 0.04541809319440896, + "obj_to_lens_per_layer": [ + 0.0006344399298541248, + 0.0006960152531974018, + 0.07109031081199646, + 0.030542252585291862, + 0.02923300303518772, + 0.015696316957473755, + 0.01812676526606083, + 0.039688657969236374, + 0.05722779035568237, + 0.012694957666099072, + 0.08814310282468796, + 0.0757591500878334 + ], + "lens_to_subj_per_layer": [ + 0.012249965220689774, + 0.015576481819152832, + 0.01270817220211029, + 0.07744071632623672, + 
0.004510016646236181, + 0.09644815325737, + 0.07864350080490112, + 0.03455108031630516, + 0.06631912291049957, + 0.002945446642115712, + 0.08062685281038284, + 0.06299760937690735 + ], + "lens_leverage": 0.2885540497303009 + }, + { + "experiment": "Conditional lensing", + "condition": "control", + "label": "CONTROL (an)", + "text": "The system an produced the result", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x61>", + "<0x6E>", + "▁produ", + "<0x63>", + "ed", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ], + "positions": { + "subject": { + "span": [ + 9, + 15 + ], + "repr": 14, + "tokens": [ + "<0x73>", + "<0x79>", + "<0x73>", + "<0x74>", + "<0x65>", + "<0x6D>" + ] + }, + "object": { + "span": [ + 27, + 33 + ], + "repr": 32, + "tokens": [ + "<0x72>", + "<0x65>", + "<0x73>", + "<0x75>", + "<0x6C>", + "<0x74>" + ] + }, + "lens": { + "span": [ + 18, + 20 + ], + "repr": 19, + "tokens": [ + "<0x61>", + "<0x6E>" + ] + } + }, + "direct_attn_mean": 0.016426302935239317, + "direct_attn_per_layer": [ + 0.000361365033313632, + 0.00017832913727033883, + 0.0010651357006281614, + 0.01672285795211792, + 0.0029113793279975653, + 0.04111924394965172, + 0.007443075533956289, + 0.011330702342092991, + 0.022801773622632027, + 0.0047871884889900684, + 0.02815561182796955, + 0.060238972306251526 + ], + "obj_to_lens_mean": 0.005372950511324841, + "lens_to_subj_mean": 0.06961187755223364, + "obj_to_lens_per_layer": [ + 0.0011022401740774512, + 3.130460390821099e-05, + 0.0025288211181759834, + 0.005047538783401251, + 0.008179730735719204, + 0.018197597935795784, + 0.003926458302885294, + 0.006636775564402342, + 0.010736343450844288, + 0.0012400108389556408, + 0.003870969405397773, + 0.0029776152223348618 + ], + 
"lens_to_subj_per_layer": [ + 0.006149050313979387, + 0.008231346495449543, + 0.04968501254916191, + 0.14766137301921844, + 0.02053576521575451, + 0.14519092440605164, + 0.1528148055076599, + 0.04174092039465904, + 0.09412223100662231, + 0.01120499987155199, + 0.07347569614648819, + 0.08453040570020676 + ], + "lens_leverage": 0.5156482374668121 + }, + { + "experiment": "Causal lensing", + "condition": "base", + "label": "BASE (no insertion)", + "text": "The water caused the damage", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁water", + "▁cause", + "<0x64>", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ], + "positions": { + "subject": { + "span": [ + 6, + 7 + ], + "repr": 6, + "tokens": [ + "▁water" + ] + }, + "object": { + "span": [ + 13, + 18 + ], + "repr": 17, + "tokens": [ + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ] + } + }, + "direct_attn_mean": 0.01958591579386848, + "direct_attn_per_layer": [ + 0.001012675347737968, + 0.00021888363698963076, + 0.026914067566394806, + 0.04354044795036316, + 0.0029507491271942854, + 0.04756613448262215, + 0.012266930192708969, + 0.00934428721666336, + 0.017384560778737068, + 0.0028149804566055536, + 0.033293720334768295, + 0.03772355243563652 + ] + }, + { + "experiment": "Causal lensing", + "condition": "massive", + "label": "MASSIVE (because)", + "text": "The water because caused the damage", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁water", + "▁because", + "▁cause", + "<0x64>", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ], + "positions": { + "subject": { + "span": [ + 6, + 7 + ], + "repr": 6, + "tokens": [ + "▁water" + ] + }, + "object": { + "span": [ + 14, + 19 + ], + "repr": 18, + "tokens": [ + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ] + }, + "lens": { + "span": [ + 7, + 8 + ], + 
"repr": 7, + "tokens": [ + "▁because" + ] + } + }, + "direct_attn_mean": 0.01643753959857956, + "direct_attn_per_layer": [ + 0.000748742779251188, + 0.0001208566318382509, + 0.01715696044266224, + 0.030608484521508217, + 0.0021519071888178587, + 0.044984120875597, + 0.010929272510111332, + 0.008835564367473125, + 0.014929008670151234, + 0.0035370290279388428, + 0.03301990032196045, + 0.03022862784564495 + ], + "obj_to_lens_mean": 0.03497233257166954, + "lens_to_subj_mean": 0.10236597650994857, + "obj_to_lens_per_layer": [ + 0.0035891039296984673, + 4.54233777418267e-05, + 0.00964908953756094, + 0.03006984107196331, + 0.02040799893438816, + 0.014549835585057735, + 0.03587587922811508, + 0.06929012387990952, + 0.0486983060836792, + 0.03496307507157326, + 0.0858139768242836, + 0.06671533733606339 + ], + "lens_to_subj_per_layer": [ + 0.020735178142786026, + 0.10182452201843262, + 0.04919968545436859, + 0.17654752731323242, + 0.00943372119218111, + 0.22931033372879028, + 0.12143711000680923, + 0.06141812726855278, + 0.17781932651996613, + 0.009793675504624844, + 0.1180754229426384, + 0.15279708802700043 + ], + "lens_leverage": 0.7346546828746796 + }, + { + "experiment": "Causal lensing", + "condition": "control", + "label": "CONTROL (so)", + "text": "The water so caused the damage", + "pieces": [ + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x54>", + "<0x68>", + "<0x65>", + "▁water", + "▁so", + "▁cause", + "<0x64>", + "▁the", + "<0xE2>", + "<0x96>", + "<0x81>", + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ], + "positions": { + "subject": { + "span": [ + 6, + 7 + ], + "repr": 6, + "tokens": [ + "▁water" + ] + }, + "object": { + "span": [ + 14, + 19 + ], + "repr": 18, + "tokens": [ + "<0x64>", + "am", + "<0x61>", + "<0x67>", + "<0x65>" + ] + }, + "lens": { + "span": [ + 7, + 8 + ], + "repr": 7, + "tokens": [ + "▁so" + ] + } + }, + "direct_attn_mean": 0.016383881518777343, + "direct_attn_per_layer": [ + 0.0007508556009270251, + 0.00012135002907598391, + 
0.01664518006145954, + 0.03187965974211693, + 0.0025103448424488306, + 0.04657818749547005, + 0.01147399004548788, + 0.00902616698294878, + 0.016614453867077827, + 0.002210734412074089, + 0.027384890243411064, + 0.031410764902830124 + ], + "obj_to_lens_mean": 0.010304755402103183, + "lens_to_subj_mean": 0.2287989921751432, + "obj_to_lens_per_layer": [ + 0.0024258799385279417, + 7.869590626796708e-05, + 0.0035336241126060486, + 0.013310817070305347, + 0.03829064592719078, + 0.01072451937943697, + 0.008060471154749393, + 0.017663750797510147, + 0.012170073576271534, + 0.005185696762055159, + 0.005919232498854399, + 0.006293657701462507 + ], + "lens_to_subj_per_layer": [ + 0.08700758963823318, + 0.10224344581365585, + 0.2489047795534134, + 0.3630940914154053, + 0.023152783513069153, + 0.36620381474494934, + 0.2586941421031952, + 0.24749697744846344, + 0.43611860275268555, + 0.003194293240085244, + 0.3856511116027832, + 0.22382627427577972 + ], + "lens_leverage": 0.786241979598999 + } +] \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/retokenize_corpus.py b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/retokenize_corpus.py new file mode 100644 index 0000000000..ecf9fedfd7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/retokenize_corpus.py @@ -0,0 +1,150 @@ +""" +Re-tokenize FineWeb corpus with a gravity tokenizer. + +Takes the raw text (decoded from existing shards) and re-encodes it using +a gravity-weighted SentencePiece model. Outputs binary shard files in the +same format the Parameter Golf training script expects. 
+
+Usage:
+    python scripts/retokenize_corpus.py \
+        --base-tokenizer ./parameter-golf/data/tokenizers/fineweb_1024_bpe.model \
+        --gravity-tokenizer data/tokenizers/gravity_beta_0.3.model \
+        --data-dir ./parameter-golf/data/datasets/fineweb10B_sp1024 \
+        --output-dir ./parameter-golf/data/datasets/fineweb_gravity_beta_0.3 \
+        --max-shards 10
+"""
+
+import argparse
+import struct
+import sys
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+from tqdm import tqdm
+
+
+SHARD_MAGIC = 20240520
+SHARD_VERSION = 1
+HEADER_INTS = 256
+SHARD_SIZE = 100_000_000  # tokens per shard (same as parameter-golf default)
+
+
+def load_shard_tokens(path: Path) -> np.ndarray:
+    """Load uint16 tokens from a binary shard file (256-int32 header + payload)."""
+    header = np.fromfile(path, dtype=np.int32, count=HEADER_INTS)
+    assert header[0] == SHARD_MAGIC, f"bad magic number in {path}"
+    assert header[1] == SHARD_VERSION, f"unsupported shard version in {path}"
+    num_tokens = int(header[2])
+    with open(path, "rb") as f:
+        f.seek(HEADER_INTS * 4)
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype=np.uint16)
+    return tokens
+
+
+def write_shard(path: Path, tokens: np.ndarray) -> None:
+    """Write uint16 tokens behind the standard 256-int32 header."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    header = np.zeros(HEADER_INTS, dtype=np.int32)
+    header[0] = SHARD_MAGIC
+    header[1] = SHARD_VERSION
+    header[2] = len(tokens)
+    with open(path, "wb") as f:
+        f.write(header.tobytes())
+        f.write(tokens.astype(np.uint16).tobytes())
+
+
+def decode_shard_to_text(shard_path: Path, sp: spm.SentencePieceProcessor) -> str:
+    """Decode a tokenized shard back to raw text."""
+    token_ids = load_shard_tokens(shard_path)
+    # Decode in chunks to bound peak memory
+    chunk_size = 100_000
+    text_parts = []
+    for i in range(0, len(token_ids), chunk_size):
+        chunk = token_ids[i:i + chunk_size].tolist()
+        text_parts.append(sp.decode(chunk))
+    return "".join(text_parts)
+
+
+def encode_text_to_tokens(text: str, sp: spm.SentencePieceProcessor) -> np.ndarray:
+    """Encode text to token IDs using SentencePiece."""
+    token_ids = sp.encode(text)
+    return np.array(token_ids, dtype=np.uint16)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Re-tokenize corpus")
+    parser.add_argument("--base-tokenizer", type=str, required=True,
+                        help="Path to base SentencePiece model (for decoding)")
+    parser.add_argument("--gravity-tokenizer", type=str, required=True,
+                        help="Path to gravity SentencePiece model (for encoding)")
+    parser.add_argument("--data-dir", type=str, required=True,
+                        help="Directory with original tokenized shards")
+    parser.add_argument("--output-dir", type=str, required=True,
+                        help="Output directory for re-tokenized shards")
+    parser.add_argument("--max-shards", type=int, default=0,
+                        help="Max shards to process (0=all)")
+    args = 
parser.parse_args() + + sp_base = spm.SentencePieceProcessor(model_file=args.base_tokenizer) + sp_gravity = spm.SentencePieceProcessor(model_file=args.gravity_tokenizer) + + print(f"Base tokenizer: vocab={sp_base.vocab_size()}") + print(f"Gravity tokenizer: vocab={sp_gravity.vocab_size()}") + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + + # Process training shards + train_shards = sorted(data_dir.glob("fineweb_train_*.bin")) + val_shards = sorted(data_dir.glob("fineweb_val_*.bin")) + + if args.max_shards > 0: + train_shards = train_shards[:args.max_shards] + + print(f"\nProcessing {len(train_shards)} training shards + {len(val_shards)} val shards") + + total_base_tokens = 0 + total_gravity_tokens = 0 + + for shard_path in tqdm(train_shards + val_shards, desc="Re-tokenizing"): + print(f"\n {shard_path.name}:") + + # Decode to text + text = decode_shard_to_text(shard_path, sp_base) + base_tokens_count = len(load_shard_tokens(shard_path)) + + # Re-encode with gravity tokenizer + gravity_tokens = encode_text_to_tokens(text, sp_gravity) + gravity_tokens_count = len(gravity_tokens) + + total_base_tokens += base_tokens_count + total_gravity_tokens += gravity_tokens_count + + ratio = gravity_tokens_count / base_tokens_count if base_tokens_count > 0 else 0 + print(f" Base tokens: {base_tokens_count:,}") + print(f" Gravity tokens: {gravity_tokens_count:,} ({ratio:.2f}x)") + + # Check that all token IDs fit in uint16 + if gravity_tokens.max() >= 65536: + print(f" WARNING: Token IDs exceed uint16 range!") + + # Write output shard + output_path = output_dir / shard_path.name + write_shard(output_path, gravity_tokens) + + # Summary + overall_ratio = total_gravity_tokens / total_base_tokens if total_base_tokens > 0 else 0 + print(f"\n{'='*60}") + print(f"Re-tokenization complete") + print(f" Total base tokens: {total_base_tokens:,}") + print(f" Total gravity tokens: {total_gravity_tokens:,}") + print(f" Sequence length ratio: {overall_ratio:.3f}x") + 
print(f" (>1.0 means gravity produces longer sequences)") + print(f" Output directory: {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/setup.sh b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/setup.sh new file mode 100644 index 0000000000..704e4f5a47 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/setup.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# ------------------------------------------------------------------------------- +# Gravity Tokenizer — Setup Script +# Downloads stock FineWeb, then retokenizes with the gravity vocabulary. +# ------------------------------------------------------------------------------- + +set -e + +echo "----------------------------------------------" +echo " Gravity Tokenizer — Setup" +echo "----------------------------------------------" + +# ------------------------------------------------------------------------------- +# 1. Dependencies +# ------------------------------------------------------------------------------- +echo "" +echo "[1/3] Checking dependencies..." + +pip install --quiet sentencepiece numpy huggingface-hub tqdm + +echo " Done." + +# ------------------------------------------------------------------------------- +# 2. Download stock BPE FineWeb (competition data) +# ------------------------------------------------------------------------------- +echo "" +echo "[2/3] Downloading stock FineWeb (sp1024)..." + +STOCK_DIR=./data/datasets/fineweb10B_sp1024 +STOCK_TOK=./data/tokenizers/fineweb_1024_bpe.model + +if [ -d "$STOCK_DIR" ] && ls "$STOCK_DIR"/fineweb_train_*.bin 1>/dev/null 2>&1; then + TRAIN_COUNT=$(ls "$STOCK_DIR"/fineweb_train_*.bin | wc -l) + echo " Found $TRAIN_COUNT existing train shards — skipping download." +else + python3 data/cached_challenge_fineweb.py --variant sp1024 + echo " Downloaded." 
+fi + +# ------------------------------------------------------------------------------- +# 3. Retokenize with gravity vocabulary +# ------------------------------------------------------------------------------- +echo "" +echo "[3/3] Retokenizing with gravity tokenizer..." + +GRAVITY_DIR=./data/datasets/fineweb_gravity_beta_1.0 +GRAVITY_TOK=./data/tokenizers/gravity_beta_1.0.model + +# Copy gravity tokenizer model into expected location +mkdir -p ./data/tokenizers +if [ ! -f "$GRAVITY_TOK" ]; then + # Download from HuggingFace if not bundled + if [ -f "$(dirname "$0")/gravity_beta_1.0.model" ]; then + cp "$(dirname "$0")/gravity_beta_1.0.model" "$GRAVITY_TOK" + else + huggingface-cli download dcrow85/gravity-tokenizer-fineweb \ + tokenizers/gravity_beta_1.0.model \ + --repo-type dataset \ + --local-dir ./data + fi +fi + +if [ -d "$GRAVITY_DIR" ] && ls "$GRAVITY_DIR"/fineweb_val_*.bin 1>/dev/null 2>&1; then + TRAIN_COUNT=$(ls "$GRAVITY_DIR"/fineweb_train_*.bin 2>/dev/null | wc -l) + echo " Found $TRAIN_COUNT existing gravity shards — skipping retokenization." +else + python3 retokenize_corpus.py \ + --base-tokenizer "$STOCK_TOK" \ + --gravity-tokenizer "$GRAVITY_TOK" \ + --data-dir "$STOCK_DIR" \ + --output-dir "$GRAVITY_DIR" + echo " Retokenization complete." 
+fi + +# ------------------------------------------------------------------------------- +# Verification +# ------------------------------------------------------------------------------- +echo "" +echo "----------------------------------------------" +echo " Verification" +echo "----------------------------------------------" + +python3 -c " +import os, glob, numpy as np, sentencepiece as spm + +train = sorted(glob.glob('$GRAVITY_DIR/fineweb_train_*.bin')) +val = sorted(glob.glob('$GRAVITY_DIR/fineweb_val_*.bin')) +print(f'Gravity train shards : {len(train)}') +print(f'Gravity val shards : {len(val)}') + +total_tokens = 0 +for f in train + val: + header = np.fromfile(f, dtype=' Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, 
device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    # Nesterov look-ahead when enabled; plain momentum otherwise.
+                    g = g.add(buf, alpha=momentum) if nesterov else buf
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the two d_model * d_vocab matrices can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own, and we score validation on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need lookup tables that count how many bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
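The conversion described in the comment block above can be made concrete with a small standalone sketch (hypothetical names; not part of the training script): a mean token cross-entropy in nats becomes bits per token, and multiplying by the tokens-per-byte ratio of the tokenizer yields bits per byte.

```python
import math

def bits_per_byte(mean_nll_nats: float, n_tokens: int, n_bytes: int) -> float:
    # bits/token = nll / ln(2); times tokens/byte gives bits/byte.
    # A vocabulary that covers the same bytes with fewer tokens lowers BPB
    # even at identical per-token loss, which is what makes the metric
    # comparable across tokenizers.
    return (mean_nll_nats / math.log(2.0)) * (n_tokens / n_bytes)

# Example: a mean loss of ln(2) nats (1 bit/token) over 1,000 tokens
# that decode to 4,000 bytes gives 0.25 bits per byte.
print(bits_per_byte(math.log(2.0), 1_000, 4_000))  # → 0.25
```

The `eval_val` function below computes exactly this product from globally summed token and byte counts; the byte counts come from the SentencePiece lookup tables rather than re-decoding text.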
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
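The export idea described above reads, in miniature, as follows. This is an illustrative sketch only — the actual exporter below adds percentile clipping, fp16 scale storage, and passthrough handling for small tensors; `w` here is just a stand-in weight matrix:

```python
import zlib

import torch

torch.manual_seed(0)
w = torch.randn(384, 1152)                       # stand-in for one fp32 weight matrix

# Per-row symmetric int8: one scale per output row, values snapped to [-127, 127].
scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1.0 / 127.0)
q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
blob = zlib.compress(q.numpy().tobytes(), level=9)

# Decompress/dequantize for evaluation; rounding error is at most half a step per row.
w_hat = q.float() * scale[:, None]
max_err = (w - w_hat).abs().max().item()
assert max_err <= 0.5 * scale.max().item() + 1e-6
assert len(blob) < 4 * w.numel()                 # far smaller than the raw fp32 bytes
```

In the real pipeline the scale comes from a high quantile rather than the row max, which tightens the step size at the cost of clipping a handful of outliers.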
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # FineWeb .bin shard (llm.c convention): a 256-int32 little-endian header
+    # carrying the token count at index 2, followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Cycles through the shard files matching `pattern` as one endless contiguous token stream.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 
0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
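The `val_bpb` figure logged by the round-trip check is the challenge metric: nat-valued token cross-entropy converted to bits per byte via bpb = (loss / ln 2) · (tokens / bytes). A worked example with illustrative counts — `byte_count` is chosen to give the ~1.28 UTF-8 bytes per token that the seed-137 log implies for this 1024-token vocabulary:

```python
import math

# Illustrative inputs: mean token cross-entropy in nats plus the token/byte
# totals that eval_val accumulates (numbers echo the step-1000 eval in the log).
val_loss = 1.0826
token_count = 1_000_000.0
byte_count = 1_284_700.0

bits_per_token = val_loss / math.log(2.0)        # ≈ 1.5619 bits/token
val_bpb = bits_per_token * (token_count / byte_count)
print(f"{val_bpb:.4f}")                          # ≈ 1.2157 bits/byte
```

Because byte counts come from the tokenizer-independent UTF-8 text, bpb stays comparable across submissions even when the vocabulary (and hence the per-token loss scale) changes.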
+ + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed137.log b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed137.log new file mode 100644 index 0000000000..daf3f12a2e --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed137.log @@ -0,0 +1,1252 @@ +W0325 16:42:23.668000 35526 torch/distributed/run.py:803] +W0325 16:42:23.668000 35526 torch/distributed/run.py:803] ***************************************** +W0325 16:42:23.668000 35526 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 16:42:23.668000 35526 torch/distributed/run.py:803] ***************************************** +logs/gravity_12L_seed137.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:137 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9027 val_bpb:7.7519 train_time:0ms step_avg:0.02ms +step:1/11000 train_loss:6.9015 train_time:45ms step_avg:45.15ms +step:2/11000 train_loss:10.0140 train_time:93ms step_avg:46.33ms +step:3/11000 train_loss:6.0481 train_time:146ms step_avg:48.78ms +step:4/11000 
train_loss:5.5802 train_time:199ms step_avg:49.84ms +step:5/11000 train_loss:4.7164 train_time:251ms step_avg:50.21ms +step:6/11000 train_loss:3.9876 train_time:305ms step_avg:50.75ms +step:7/11000 train_loss:3.6995 train_time:358ms step_avg:51.16ms +step:8/11000 train_loss:3.4643 train_time:412ms step_avg:51.46ms +step:9/11000 train_loss:3.2888 train_time:465ms step_avg:51.67ms +step:10/11000 train_loss:3.3499 train_time:519ms step_avg:51.87ms +step:100/11000 train_loss:1.6871 train_time:5334ms step_avg:53.34ms +step:200/11000 train_loss:1.4298 train_time:10679ms step_avg:53.40ms +step:300/11000 train_loss:1.3170 train_time:16024ms step_avg:53.41ms +step:400/11000 train_loss:1.1453 train_time:21365ms step_avg:53.41ms +step:500/11000 train_loss:1.1720 train_time:26847ms step_avg:53.69ms +step:600/11000 train_loss:1.1508 train_time:32190ms step_avg:53.65ms +step:700/11000 train_loss:1.0746 train_time:37539ms step_avg:53.63ms +step:800/11000 train_loss:1.0287 train_time:42891ms step_avg:53.61ms +step:900/11000 train_loss:1.1210 train_time:48420ms step_avg:53.80ms +step:1000/11000 train_loss:1.0061 train_time:53778ms step_avg:53.78ms +step:1000/11000 val_loss:1.0826 val_bpb:1.2157 train_time:53797ms step_avg:53.80ms +step:1100/11000 train_loss:0.9692 train_time:59133ms step_avg:53.76ms +step:1200/11000 train_loss:0.9277 train_time:64487ms step_avg:53.74ms +step:1300/11000 train_loss:1.0662 train_time:69835ms step_avg:53.72ms +step:1400/11000 train_loss:0.9869 train_time:75339ms step_avg:53.81ms +step:1500/11000 train_loss:1.0225 train_time:80692ms step_avg:53.79ms +step:1600/11000 train_loss:1.0480 train_time:86042ms step_avg:53.78ms +step:1700/11000 train_loss:1.0223 train_time:91398ms step_avg:53.76ms +step:1800/11000 train_loss:1.0955 train_time:96890ms step_avg:53.83ms +step:1900/11000 train_loss:1.0207 train_time:102240ms step_avg:53.81ms +step:2000/11000 train_loss:1.0712 train_time:107594ms step_avg:53.80ms +step:2000/11000 val_loss:1.0218 val_bpb:1.1475 
train_time:107612ms step_avg:53.81ms +step:2100/11000 train_loss:1.0524 train_time:112943ms step_avg:53.78ms +step:2200/11000 train_loss:1.0018 train_time:118287ms step_avg:53.77ms +step:2300/11000 train_loss:0.9577 train_time:123770ms step_avg:53.81ms +step:2400/11000 train_loss:1.0102 train_time:129117ms step_avg:53.80ms +step:2500/11000 train_loss:1.0177 train_time:134459ms step_avg:53.78ms +step:2600/11000 train_loss:0.8945 train_time:139805ms step_avg:53.77ms +step:2700/11000 train_loss:1.0580 train_time:145283ms step_avg:53.81ms +step:2800/11000 train_loss:1.0015 train_time:150628ms step_avg:53.80ms +step:2900/11000 train_loss:1.0141 train_time:155974ms step_avg:53.78ms +step:3000/11000 train_loss:0.8697 train_time:161318ms step_avg:53.77ms +step:3000/11000 val_loss:0.9966 val_bpb:1.1192 train_time:161336ms step_avg:53.78ms +step:3100/11000 train_loss:1.0070 train_time:166659ms step_avg:53.76ms +step:3200/11000 train_loss:1.0017 train_time:172136ms step_avg:53.79ms +step:3300/11000 train_loss:0.9322 train_time:177474ms step_avg:53.78ms +step:3400/11000 train_loss:0.9904 train_time:182809ms step_avg:53.77ms +step:3500/11000 train_loss:0.9565 train_time:188146ms step_avg:53.76ms +step:3600/11000 train_loss:1.0013 train_time:193619ms step_avg:53.78ms +step:3700/11000 train_loss:0.8850 train_time:198954ms step_avg:53.77ms +step:3800/11000 train_loss:0.9380 train_time:204293ms step_avg:53.76ms +step:3900/11000 train_loss:1.0311 train_time:209630ms step_avg:53.75ms +step:4000/11000 train_loss:0.9945 train_time:214964ms step_avg:53.74ms +step:4000/11000 val_loss:0.9818 val_bpb:1.1026 train_time:214983ms step_avg:53.75ms +step:4100/11000 train_loss:0.9948 train_time:220450ms step_avg:53.77ms +step:4200/11000 train_loss:1.0060 train_time:225789ms step_avg:53.76ms +step:4300/11000 train_loss:0.9372 train_time:231124ms step_avg:53.75ms +step:4400/11000 train_loss:0.9707 train_time:236458ms step_avg:53.74ms +step:4500/11000 train_loss:0.9610 train_time:241946ms 
step_avg:53.77ms +step:4600/11000 train_loss:1.0711 train_time:247284ms step_avg:53.76ms +step:4700/11000 train_loss:0.9345 train_time:252619ms step_avg:53.75ms +step:4800/11000 train_loss:0.9821 train_time:257960ms step_avg:53.74ms +step:4900/11000 train_loss:0.9275 train_time:263293ms step_avg:53.73ms +step:5000/11000 train_loss:0.9554 train_time:268767ms step_avg:53.75ms +step:5000/11000 val_loss:0.9691 val_bpb:1.0884 train_time:268786ms step_avg:53.76ms +step:5100/11000 train_loss:0.9425 train_time:274102ms step_avg:53.75ms +step:5200/11000 train_loss:0.9844 train_time:279439ms step_avg:53.74ms +step:5300/11000 train_loss:0.9411 train_time:284778ms step_avg:53.73ms +step:5400/11000 train_loss:0.9712 train_time:290257ms step_avg:53.75ms +step:5500/11000 train_loss:0.9717 train_time:295594ms step_avg:53.74ms +step:5600/11000 train_loss:0.9696 train_time:300928ms step_avg:53.74ms +step:5700/11000 train_loss:0.9404 train_time:306269ms step_avg:53.73ms +step:5800/11000 train_loss:1.1172 train_time:311606ms step_avg:53.73ms +step:5900/11000 train_loss:0.9888 train_time:317095ms step_avg:53.74ms +step:6000/11000 train_loss:0.9612 train_time:322430ms step_avg:53.74ms +step:6000/11000 val_loss:0.9630 val_bpb:1.0815 train_time:322449ms step_avg:53.74ms +step:6100/11000 train_loss:0.9633 train_time:327772ms step_avg:53.73ms +step:6200/11000 train_loss:0.9357 train_time:333107ms step_avg:53.73ms +step:6300/11000 train_loss:1.0104 train_time:338574ms step_avg:53.74ms +step:6400/11000 train_loss:0.9340 train_time:343911ms step_avg:53.74ms +step:6500/11000 train_loss:0.8732 train_time:349246ms step_avg:53.73ms +step:6600/11000 train_loss:0.9453 train_time:354582ms step_avg:53.72ms +step:6700/11000 train_loss:0.9060 train_time:360063ms step_avg:53.74ms +step:6800/11000 train_loss:0.9707 train_time:365396ms step_avg:53.73ms +step:6900/11000 train_loss:0.9686 train_time:370732ms step_avg:53.73ms +step:7000/11000 train_loss:0.9265 train_time:376069ms step_avg:53.72ms 
+step:7000/11000 val_loss:0.9557 val_bpb:1.0733 train_time:376087ms step_avg:53.73ms +step:7100/11000 train_loss:0.9525 train_time:381515ms step_avg:53.73ms +step:7200/11000 train_loss:0.9774 train_time:386996ms step_avg:53.75ms +step:7300/11000 train_loss:0.9158 train_time:392335ms step_avg:53.74ms +step:7400/11000 train_loss:0.9200 train_time:397671ms step_avg:53.74ms +step:7500/11000 train_loss:0.9443 train_time:403009ms step_avg:53.73ms +step:7600/11000 train_loss:1.0086 train_time:408495ms step_avg:53.75ms +step:7700/11000 train_loss:0.9273 train_time:413830ms step_avg:53.74ms +step:7800/11000 train_loss:0.9577 train_time:419168ms step_avg:53.74ms +step:7900/11000 train_loss:0.9130 train_time:424517ms step_avg:53.74ms +step:8000/11000 train_loss:0.9957 train_time:429849ms step_avg:53.73ms +step:8000/11000 val_loss:0.9516 val_bpb:1.0687 train_time:429867ms step_avg:53.73ms +step:8100/11000 train_loss:0.9742 train_time:435333ms step_avg:53.74ms +step:8200/11000 train_loss:0.8971 train_time:440668ms step_avg:53.74ms +step:8300/11000 train_loss:0.8773 train_time:446007ms step_avg:53.74ms +step:8400/11000 train_loss:0.9756 train_time:451346ms step_avg:53.73ms +step:8500/11000 train_loss:0.9300 train_time:456828ms step_avg:53.74ms +step:8600/11000 train_loss:0.9641 train_time:462177ms step_avg:53.74ms +step:8700/11000 train_loss:0.8957 train_time:467515ms step_avg:53.74ms +step:8800/11000 train_loss:0.9576 train_time:472855ms step_avg:53.73ms +step:8900/11000 train_loss:0.9415 train_time:478191ms step_avg:53.73ms +step:9000/11000 train_loss:0.9448 train_time:483678ms step_avg:53.74ms +step:9000/11000 val_loss:0.9475 val_bpb:1.0641 train_time:483696ms step_avg:53.74ms +step:9100/11000 train_loss:0.8925 train_time:489016ms step_avg:53.74ms +step:9200/11000 train_loss:0.9727 train_time:494355ms step_avg:53.73ms +step:9300/11000 train_loss:0.9196 train_time:499692ms step_avg:53.73ms +step:9400/11000 train_loss:0.9820 train_time:505165ms step_avg:53.74ms +step:9500/11000 
train_loss:0.8846 train_time:510499ms step_avg:53.74ms +step:9600/11000 train_loss:1.1223 train_time:515835ms step_avg:53.73ms +step:9700/11000 train_loss:0.9306 train_time:521170ms step_avg:53.73ms +step:9800/11000 train_loss:0.8593 train_time:526509ms step_avg:53.73ms +step:9900/11000 train_loss:0.8755 train_time:531987ms step_avg:53.74ms +step:10000/11000 train_loss:0.9427 train_time:537321ms step_avg:53.73ms +step:10000/11000 val_loss:0.9298 val_bpb:1.0442 train_time:537339ms step_avg:53.73ms +step:10100/11000 train_loss:0.9607 train_time:542662ms step_avg:53.73ms +step:10200/11000 train_loss:0.9080 train_time:547999ms step_avg:53.73ms +step:10300/11000 train_loss:0.9505 train_time:553479ms step_avg:53.74ms +step:10400/11000 train_loss:0.9323 train_time:558817ms step_avg:53.73ms +step:10500/11000 train_loss:0.8425 train_time:564152ms step_avg:53.73ms +step:10600/11000 train_loss:0.8268 train_time:569487ms step_avg:53.73ms +step:10700/11000 train_loss:1.0262 train_time:574822ms step_avg:53.72ms +step:10800/11000 train_loss:0.9165 train_time:580288ms step_avg:53.73ms +step:10900/11000 train_loss:0.8800 train_time:585630ms step_avg:53.73ms +step:11000/11000 train_loss:0.9361 train_time:590962ms step_avg:53.72ms +step:11000/11000 val_loss:0.9171 val_bpb:1.0299 train_time:590980ms step_avg:53.73ms +peak memory allocated: 11113 MiB reserved: 11382 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15577509 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15625195 bytes +final_int8_zlib_roundtrip val_loss:0.9190 val_bpb:1.0321 eval_time:4004ms +final_int8_zlib_roundtrip_exact val_loss:0.91900288 val_bpb:1.03206516 + + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in 
range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumes the standard fineweb shard layout: 256 little-endian int32 header
+    # words (token count at index 2) followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.memmap(file, dtype=np.uint16, mode="r", offset=header_bytes)[:num_tokens]
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Streams tokens from the shard files matching a glob pattern, wrapping back
+    # to the first shard after the last.
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(pattern).parent.glob(Path(pattern).name))
+        if not self.files:
+            raise FileNotFoundError(f"no shards match {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[0])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 
0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
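The serialization block above saves the raw state dict, int8-quantizes it, zlib-compresses the result, and then reloads the artifact from disk to validate the roundtrip. A minimal self-contained sketch of that save → compress → decompress → load cycle (the helper name `demo_roundtrip` is illustrative, not from the script; `level=9` matches the script's compression setting):

```python
import io
import zlib

import torch


def demo_roundtrip(state_dict: dict) -> tuple[dict, int, int]:
    # Serialize with torch.save, compress with zlib, then invert both steps.
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    raw = buf.getvalue()
    blob = zlib.compress(raw, level=9)  # level=9, as in the artifact pipeline above
    restored = torch.load(io.BytesIO(zlib.decompress(blob)), map_location="cpu")
    return restored, len(raw), len(blob)
```

Because zlib is lossless, the reloaded tensors are bit-identical to what was saved; any accuracy hit in the real pipeline comes from the int8 quantization step, not from the compression.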
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 16:42:36 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 36C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 34C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C 
P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 35C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:137 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9027 val_bpb:7.7519 train_time:0ms step_avg:0.02ms +step:1/11000 train_loss:6.9015 train_time:45ms step_avg:45.15ms 
+step:2/11000 train_loss:10.0140 train_time:93ms step_avg:46.33ms +step:3/11000 train_loss:6.0481 train_time:146ms step_avg:48.78ms +step:4/11000 train_loss:5.5802 train_time:199ms step_avg:49.84ms +step:5/11000 train_loss:4.7164 train_time:251ms step_avg:50.21ms +step:6/11000 train_loss:3.9876 train_time:305ms step_avg:50.75ms +step:7/11000 train_loss:3.6995 train_time:358ms step_avg:51.16ms +step:8/11000 train_loss:3.4643 train_time:412ms step_avg:51.46ms +step:9/11000 train_loss:3.2888 train_time:465ms step_avg:51.67ms +step:10/11000 train_loss:3.3499 train_time:519ms step_avg:51.87ms +step:100/11000 train_loss:1.6871 train_time:5334ms step_avg:53.34ms +step:200/11000 train_loss:1.4298 train_time:10679ms step_avg:53.40ms +step:300/11000 train_loss:1.3170 train_time:16024ms step_avg:53.41ms +step:400/11000 train_loss:1.1453 train_time:21365ms step_avg:53.41ms +step:500/11000 train_loss:1.1720 train_time:26847ms step_avg:53.69ms +step:600/11000 train_loss:1.1508 train_time:32190ms step_avg:53.65ms +step:700/11000 train_loss:1.0746 train_time:37539ms step_avg:53.63ms +step:800/11000 train_loss:1.0287 train_time:42891ms step_avg:53.61ms +step:900/11000 train_loss:1.1210 train_time:48420ms step_avg:53.80ms +step:1000/11000 train_loss:1.0061 train_time:53778ms step_avg:53.78ms +step:1000/11000 val_loss:1.0826 val_bpb:1.2157 train_time:53797ms step_avg:53.80ms +step:1100/11000 train_loss:0.9692 train_time:59133ms step_avg:53.76ms +step:1200/11000 train_loss:0.9277 train_time:64487ms step_avg:53.74ms +step:1300/11000 train_loss:1.0662 train_time:69835ms step_avg:53.72ms +step:1400/11000 train_loss:0.9869 train_time:75339ms step_avg:53.81ms +step:1500/11000 train_loss:1.0225 train_time:80692ms step_avg:53.79ms +step:1600/11000 train_loss:1.0480 train_time:86042ms step_avg:53.78ms +step:1700/11000 train_loss:1.0223 train_time:91398ms step_avg:53.76ms +step:1800/11000 train_loss:1.0955 train_time:96890ms step_avg:53.83ms +step:1900/11000 train_loss:1.0207 
train_time:102240ms step_avg:53.81ms +step:2000/11000 train_loss:1.0712 train_time:107594ms step_avg:53.80ms +step:2000/11000 val_loss:1.0218 val_bpb:1.1475 train_time:107612ms step_avg:53.81ms +step:2100/11000 train_loss:1.0524 train_time:112943ms step_avg:53.78ms +step:2200/11000 train_loss:1.0018 train_time:118287ms step_avg:53.77ms +step:2300/11000 train_loss:0.9577 train_time:123770ms step_avg:53.81ms +step:2400/11000 train_loss:1.0102 train_time:129117ms step_avg:53.80ms +step:2500/11000 train_loss:1.0177 train_time:134459ms step_avg:53.78ms +step:2600/11000 train_loss:0.8945 train_time:139805ms step_avg:53.77ms +step:2700/11000 train_loss:1.0580 train_time:145283ms step_avg:53.81ms +step:2800/11000 train_loss:1.0015 train_time:150628ms step_avg:53.80ms +step:2900/11000 train_loss:1.0141 train_time:155974ms step_avg:53.78ms +step:3000/11000 train_loss:0.8697 train_time:161318ms step_avg:53.77ms +step:3000/11000 val_loss:0.9966 val_bpb:1.1192 train_time:161336ms step_avg:53.78ms +step:3100/11000 train_loss:1.0070 train_time:166659ms step_avg:53.76ms +step:3200/11000 train_loss:1.0017 train_time:172136ms step_avg:53.79ms +step:3300/11000 train_loss:0.9322 train_time:177474ms step_avg:53.78ms +step:3400/11000 train_loss:0.9904 train_time:182809ms step_avg:53.77ms +step:3500/11000 train_loss:0.9565 train_time:188146ms step_avg:53.76ms +step:3600/11000 train_loss:1.0013 train_time:193619ms step_avg:53.78ms +step:3700/11000 train_loss:0.8850 train_time:198954ms step_avg:53.77ms +step:3800/11000 train_loss:0.9380 train_time:204293ms step_avg:53.76ms +step:3900/11000 train_loss:1.0311 train_time:209630ms step_avg:53.75ms +step:4000/11000 train_loss:0.9945 train_time:214964ms step_avg:53.74ms +step:4000/11000 val_loss:0.9818 val_bpb:1.1026 train_time:214983ms step_avg:53.75ms +step:4100/11000 train_loss:0.9948 train_time:220450ms step_avg:53.77ms +step:4200/11000 train_loss:1.0060 train_time:225789ms step_avg:53.76ms +step:4300/11000 train_loss:0.9372 
train_time:231124ms step_avg:53.75ms +step:4400/11000 train_loss:0.9707 train_time:236458ms step_avg:53.74ms +step:4500/11000 train_loss:0.9610 train_time:241946ms step_avg:53.77ms +step:4600/11000 train_loss:1.0711 train_time:247284ms step_avg:53.76ms +step:4700/11000 train_loss:0.9345 train_time:252619ms step_avg:53.75ms +step:4800/11000 train_loss:0.9821 train_time:257960ms step_avg:53.74ms +step:4900/11000 train_loss:0.9275 train_time:263293ms step_avg:53.73ms +step:5000/11000 train_loss:0.9554 train_time:268767ms step_avg:53.75ms +step:5000/11000 val_loss:0.9691 val_bpb:1.0884 train_time:268786ms step_avg:53.76ms +step:5100/11000 train_loss:0.9425 train_time:274102ms step_avg:53.75ms +step:5200/11000 train_loss:0.9844 train_time:279439ms step_avg:53.74ms +step:5300/11000 train_loss:0.9411 train_time:284778ms step_avg:53.73ms +step:5400/11000 train_loss:0.9712 train_time:290257ms step_avg:53.75ms +step:5500/11000 train_loss:0.9717 train_time:295594ms step_avg:53.74ms +step:5600/11000 train_loss:0.9696 train_time:300928ms step_avg:53.74ms +step:5700/11000 train_loss:0.9404 train_time:306269ms step_avg:53.73ms +step:5800/11000 train_loss:1.1172 train_time:311606ms step_avg:53.73ms +step:5900/11000 train_loss:0.9888 train_time:317095ms step_avg:53.74ms +step:6000/11000 train_loss:0.9612 train_time:322430ms step_avg:53.74ms +step:6000/11000 val_loss:0.9630 val_bpb:1.0815 train_time:322449ms step_avg:53.74ms +step:6100/11000 train_loss:0.9633 train_time:327772ms step_avg:53.73ms +step:6200/11000 train_loss:0.9357 train_time:333107ms step_avg:53.73ms +step:6300/11000 train_loss:1.0104 train_time:338574ms step_avg:53.74ms +step:6400/11000 train_loss:0.9340 train_time:343911ms step_avg:53.74ms +step:6500/11000 train_loss:0.8732 train_time:349246ms step_avg:53.73ms +step:6600/11000 train_loss:0.9453 train_time:354582ms step_avg:53.72ms +step:6700/11000 train_loss:0.9060 train_time:360063ms step_avg:53.74ms +step:6800/11000 train_loss:0.9707 train_time:365396ms 
step_avg:53.73ms +step:6900/11000 train_loss:0.9686 train_time:370732ms step_avg:53.73ms +step:7000/11000 train_loss:0.9265 train_time:376069ms step_avg:53.72ms +step:7000/11000 val_loss:0.9557 val_bpb:1.0733 train_time:376087ms step_avg:53.73ms +step:7100/11000 train_loss:0.9525 train_time:381515ms step_avg:53.73ms +step:7200/11000 train_loss:0.9774 train_time:386996ms step_avg:53.75ms +step:7300/11000 train_loss:0.9158 train_time:392335ms step_avg:53.74ms +step:7400/11000 train_loss:0.9200 train_time:397671ms step_avg:53.74ms +step:7500/11000 train_loss:0.9443 train_time:403009ms step_avg:53.73ms +step:7600/11000 train_loss:1.0086 train_time:408495ms step_avg:53.75ms +step:7700/11000 train_loss:0.9273 train_time:413830ms step_avg:53.74ms +step:7800/11000 train_loss:0.9577 train_time:419168ms step_avg:53.74ms +step:7900/11000 train_loss:0.9130 train_time:424517ms step_avg:53.74ms +step:8000/11000 train_loss:0.9957 train_time:429849ms step_avg:53.73ms +step:8000/11000 val_loss:0.9516 val_bpb:1.0687 train_time:429867ms step_avg:53.73ms +step:8100/11000 train_loss:0.9742 train_time:435333ms step_avg:53.74ms +step:8200/11000 train_loss:0.8971 train_time:440668ms step_avg:53.74ms +step:8300/11000 train_loss:0.8773 train_time:446007ms step_avg:53.74ms +step:8400/11000 train_loss:0.9756 train_time:451346ms step_avg:53.73ms +step:8500/11000 train_loss:0.9300 train_time:456828ms step_avg:53.74ms +step:8600/11000 train_loss:0.9641 train_time:462177ms step_avg:53.74ms +step:8700/11000 train_loss:0.8957 train_time:467515ms step_avg:53.74ms +step:8800/11000 train_loss:0.9576 train_time:472855ms step_avg:53.73ms +step:8900/11000 train_loss:0.9415 train_time:478191ms step_avg:53.73ms +step:9000/11000 train_loss:0.9448 train_time:483678ms step_avg:53.74ms +step:9000/11000 val_loss:0.9475 val_bpb:1.0641 train_time:483696ms step_avg:53.74ms +step:9100/11000 train_loss:0.8925 train_time:489016ms step_avg:53.74ms +step:9200/11000 train_loss:0.9727 train_time:494355ms step_avg:53.73ms 
+step:9300/11000 train_loss:0.9196 train_time:499692ms step_avg:53.73ms +step:9400/11000 train_loss:0.9820 train_time:505165ms step_avg:53.74ms +step:9500/11000 train_loss:0.8846 train_time:510499ms step_avg:53.74ms +step:9600/11000 train_loss:1.1223 train_time:515835ms step_avg:53.73ms +step:9700/11000 train_loss:0.9306 train_time:521170ms step_avg:53.73ms +step:9800/11000 train_loss:0.8593 train_time:526509ms step_avg:53.73ms +step:9900/11000 train_loss:0.8755 train_time:531987ms step_avg:53.74ms +step:10000/11000 train_loss:0.9427 train_time:537321ms step_avg:53.73ms +step:10000/11000 val_loss:0.9298 val_bpb:1.0442 train_time:537339ms step_avg:53.73ms +step:10100/11000 train_loss:0.9607 train_time:542662ms step_avg:53.73ms +step:10200/11000 train_loss:0.9080 train_time:547999ms step_avg:53.73ms +step:10300/11000 train_loss:0.9505 train_time:553479ms step_avg:53.74ms +step:10400/11000 train_loss:0.9323 train_time:558817ms step_avg:53.73ms +step:10500/11000 train_loss:0.8425 train_time:564152ms step_avg:53.73ms +step:10600/11000 train_loss:0.8268 train_time:569487ms step_avg:53.73ms +step:10700/11000 train_loss:1.0262 train_time:574822ms step_avg:53.72ms +step:10800/11000 train_loss:0.9165 train_time:580288ms step_avg:53.73ms +step:10900/11000 train_loss:0.8800 train_time:585630ms step_avg:53.73ms +step:11000/11000 train_loss:0.9361 train_time:590962ms step_avg:53.72ms +step:11000/11000 val_loss:0.9171 val_bpb:1.0299 train_time:590980ms step_avg:53.73ms +peak memory allocated: 11113 MiB reserved: 11382 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15577509 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15625195 bytes +final_int8_zlib_roundtrip val_loss:0.9190 val_bpb:1.0321 eval_time:4004ms +final_int8_zlib_roundtrip_exact val_loss:0.91900288 val_bpb:1.03206516 + diff --git 
a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed3.log b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed3.log new file mode 100644 index 0000000000..9bb61debe1 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed3.log @@ -0,0 +1,1251 @@ +W0325 17:50:19.707000 38160 torch/distributed/run.py:803] +W0325 17:50:19.707000 38160 torch/distributed/run.py:803] ***************************************** +W0325 17:50:19.707000 38160 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 17:50:19.707000 38160 torch/distributed/run.py:803] ***************************************** +logs/gravity_12L_seed3.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:3 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9138 val_bpb:7.7643 train_time:0ms step_avg:0.03ms +step:1/11000 train_loss:6.9128 train_time:47ms step_avg:46.60ms +step:2/11000 train_loss:10.0901 train_time:95ms step_avg:47.55ms +step:3/11000 train_loss:6.0162 train_time:148ms step_avg:49.43ms +step:4/11000 train_loss:5.3484 train_time:201ms step_avg:50.32ms +step:5/11000 
train_loss:4.5925 train_time:255ms step_avg:51.01ms +step:6/11000 train_loss:3.9333 train_time:309ms step_avg:51.47ms +step:7/11000 train_loss:3.6710 train_time:365ms step_avg:52.09ms +step:8/11000 train_loss:3.4846 train_time:419ms step_avg:52.39ms +step:9/11000 train_loss:3.3972 train_time:474ms step_avg:52.62ms +step:10/11000 train_loss:3.3632 train_time:524ms step_avg:52.41ms +step:100/11000 train_loss:1.6581 train_time:5350ms step_avg:53.50ms +step:200/11000 train_loss:1.4248 train_time:10699ms step_avg:53.49ms +step:300/11000 train_loss:1.2938 train_time:16048ms step_avg:53.49ms +step:400/11000 train_loss:1.1488 train_time:21395ms step_avg:53.49ms +step:500/11000 train_loss:1.1716 train_time:26871ms step_avg:53.74ms +step:600/11000 train_loss:1.1536 train_time:32216ms step_avg:53.69ms +step:700/11000 train_loss:1.0812 train_time:37561ms step_avg:53.66ms +step:800/11000 train_loss:1.0326 train_time:42915ms step_avg:53.64ms +step:900/11000 train_loss:1.1281 train_time:48387ms step_avg:53.76ms +step:1000/11000 train_loss:1.0081 train_time:53737ms step_avg:53.74ms +step:1000/11000 val_loss:1.0846 val_bpb:1.2180 train_time:53755ms step_avg:53.75ms +step:1100/11000 train_loss:0.9719 train_time:59087ms step_avg:53.72ms +step:1200/11000 train_loss:0.9293 train_time:64430ms step_avg:53.69ms +step:1300/11000 train_loss:1.0667 train_time:69779ms step_avg:53.68ms +step:1400/11000 train_loss:0.9867 train_time:75266ms step_avg:53.76ms +step:1500/11000 train_loss:1.0260 train_time:80616ms step_avg:53.74ms +step:1600/11000 train_loss:1.0494 train_time:85965ms step_avg:53.73ms +step:1700/11000 train_loss:1.0281 train_time:91312ms step_avg:53.71ms +step:1800/11000 train_loss:1.0998 train_time:96789ms step_avg:53.77ms +step:1900/11000 train_loss:1.0214 train_time:102136ms step_avg:53.76ms +step:2000/11000 train_loss:1.0733 train_time:107481ms step_avg:53.74ms +step:2000/11000 val_loss:1.0240 val_bpb:1.1499 train_time:107498ms step_avg:53.75ms +step:2100/11000 train_loss:1.0557 
train_time:112827ms step_avg:53.73ms +step:2200/11000 train_loss:1.0020 train_time:118176ms step_avg:53.72ms +step:2300/11000 train_loss:0.9574 train_time:123664ms step_avg:53.77ms +step:2400/11000 train_loss:1.0131 train_time:129007ms step_avg:53.75ms +step:2500/11000 train_loss:1.0222 train_time:134347ms step_avg:53.74ms +step:2600/11000 train_loss:0.8984 train_time:139689ms step_avg:53.73ms +step:2700/11000 train_loss:1.0605 train_time:145164ms step_avg:53.76ms +step:2800/11000 train_loss:1.0038 train_time:150506ms step_avg:53.75ms +step:2900/11000 train_loss:1.0156 train_time:155848ms step_avg:53.74ms +step:3000/11000 train_loss:0.8702 train_time:161187ms step_avg:53.73ms +step:3000/11000 val_loss:0.9986 val_bpb:1.1215 train_time:161205ms step_avg:53.74ms +step:3100/11000 train_loss:1.0093 train_time:166530ms step_avg:53.72ms +step:3200/11000 train_loss:1.0002 train_time:172021ms step_avg:53.76ms +step:3300/11000 train_loss:0.9371 train_time:177362ms step_avg:53.75ms +step:3400/11000 train_loss:0.9917 train_time:182702ms step_avg:53.74ms +step:3500/11000 train_loss:0.9571 train_time:188043ms step_avg:53.73ms +step:3600/11000 train_loss:1.0061 train_time:193525ms step_avg:53.76ms +step:3700/11000 train_loss:0.8877 train_time:198872ms step_avg:53.75ms +step:3800/11000 train_loss:0.9357 train_time:204213ms step_avg:53.74ms +step:3900/11000 train_loss:1.0379 train_time:209552ms step_avg:53.73ms +step:4000/11000 train_loss:0.9968 train_time:214893ms step_avg:53.72ms +step:4000/11000 val_loss:0.9836 val_bpb:1.1046 train_time:214911ms step_avg:53.73ms +step:4100/11000 train_loss:0.9965 train_time:220374ms step_avg:53.75ms +step:4200/11000 train_loss:1.0058 train_time:225713ms step_avg:53.74ms +step:4300/11000 train_loss:0.9384 train_time:231056ms step_avg:53.73ms +step:4400/11000 train_loss:0.9732 train_time:236400ms step_avg:53.73ms +step:4500/11000 train_loss:0.9623 train_time:241889ms step_avg:53.75ms +step:4600/11000 train_loss:1.0722 train_time:247229ms 
step_avg:53.75ms +step:4700/11000 train_loss:0.9358 train_time:252567ms step_avg:53.74ms +step:4800/11000 train_loss:0.9835 train_time:257905ms step_avg:53.73ms +step:4900/11000 train_loss:0.9285 train_time:263246ms step_avg:53.72ms +step:5000/11000 train_loss:0.9534 train_time:268729ms step_avg:53.75ms +step:5000/11000 val_loss:0.9706 val_bpb:1.0900 train_time:268749ms step_avg:53.75ms +step:5100/11000 train_loss:0.9435 train_time:274070ms step_avg:53.74ms +step:5200/11000 train_loss:0.9865 train_time:279415ms step_avg:53.73ms +step:5300/11000 train_loss:0.9410 train_time:284753ms step_avg:53.73ms +step:5400/11000 train_loss:0.9731 train_time:290233ms step_avg:53.75ms +step:5500/11000 train_loss:0.9721 train_time:295572ms step_avg:53.74ms +step:5600/11000 train_loss:0.9702 train_time:300911ms step_avg:53.73ms +step:5700/11000 train_loss:0.9445 train_time:306252ms step_avg:53.73ms +step:5800/11000 train_loss:1.1205 train_time:311593ms step_avg:53.72ms +step:5900/11000 train_loss:0.9899 train_time:317076ms step_avg:53.74ms +step:6000/11000 train_loss:0.9668 train_time:322414ms step_avg:53.74ms +step:6000/11000 val_loss:0.9645 val_bpb:1.0831 train_time:322432ms step_avg:53.74ms +step:6100/11000 train_loss:0.9652 train_time:327753ms step_avg:53.73ms +step:6200/11000 train_loss:0.9352 train_time:333091ms step_avg:53.72ms +step:6300/11000 train_loss:1.0119 train_time:338568ms step_avg:53.74ms +step:6400/11000 train_loss:0.9357 train_time:343908ms step_avg:53.74ms +step:6500/11000 train_loss:0.8723 train_time:349250ms step_avg:53.73ms +step:6600/11000 train_loss:0.9471 train_time:354587ms step_avg:53.73ms +step:6700/11000 train_loss:0.9093 train_time:360074ms step_avg:53.74ms +step:6800/11000 train_loss:0.9716 train_time:365410ms step_avg:53.74ms +step:6900/11000 train_loss:0.9681 train_time:370748ms step_avg:53.73ms +step:7000/11000 train_loss:0.9290 train_time:376089ms step_avg:53.73ms +step:7000/11000 val_loss:0.9569 val_bpb:1.0746 train_time:376106ms step_avg:53.73ms 
+step:7100/11000 train_loss:0.9515 train_time:381428ms step_avg:53.72ms +step:7200/11000 train_loss:0.9765 train_time:386904ms step_avg:53.74ms +step:7300/11000 train_loss:0.9165 train_time:392248ms step_avg:53.73ms +step:7400/11000 train_loss:0.9227 train_time:397590ms step_avg:53.73ms +step:7500/11000 train_loss:0.9453 train_time:402926ms step_avg:53.72ms +step:7600/11000 train_loss:1.0107 train_time:408408ms step_avg:53.74ms +step:7700/11000 train_loss:0.9311 train_time:413746ms step_avg:53.73ms +step:7800/11000 train_loss:0.9566 train_time:419081ms step_avg:53.73ms +step:7900/11000 train_loss:0.9141 train_time:424419ms step_avg:53.72ms +step:8000/11000 train_loss:0.9966 train_time:429758ms step_avg:53.72ms +step:8000/11000 val_loss:0.9527 val_bpb:1.0699 train_time:429775ms step_avg:53.72ms +step:8100/11000 train_loss:0.9746 train_time:435243ms step_avg:53.73ms +step:8200/11000 train_loss:0.8956 train_time:440579ms step_avg:53.73ms +step:8300/11000 train_loss:0.8775 train_time:445917ms step_avg:53.72ms +step:8400/11000 train_loss:0.9745 train_time:451260ms step_avg:53.72ms +step:8500/11000 train_loss:0.9331 train_time:456736ms step_avg:53.73ms +step:8600/11000 train_loss:0.9675 train_time:462076ms step_avg:53.73ms +step:8700/11000 train_loss:0.8975 train_time:467414ms step_avg:53.73ms +step:8800/11000 train_loss:0.9560 train_time:472753ms step_avg:53.72ms +step:8900/11000 train_loss:0.9402 train_time:478090ms step_avg:53.72ms +step:9000/11000 train_loss:0.9445 train_time:483574ms step_avg:53.73ms +step:9000/11000 val_loss:0.9491 val_bpb:1.0658 train_time:483593ms step_avg:53.73ms +step:9100/11000 train_loss:0.8941 train_time:488913ms step_avg:53.73ms +step:9200/11000 train_loss:0.9722 train_time:494276ms step_avg:53.73ms +step:9300/11000 train_loss:0.9234 train_time:499695ms step_avg:53.73ms +step:9400/11000 train_loss:0.9815 train_time:505180ms step_avg:53.74ms +step:9500/11000 train_loss:0.8862 train_time:510516ms step_avg:53.74ms +step:9600/11000 
train_loss:1.1241 train_time:515858ms step_avg:53.74ms +step:9700/11000 train_loss:0.9322 train_time:521206ms step_avg:53.73ms +step:9800/11000 train_loss:0.8576 train_time:526546ms step_avg:53.73ms +step:9900/11000 train_loss:0.8772 train_time:532037ms step_avg:53.74ms +step:10000/11000 train_loss:0.9450 train_time:537380ms step_avg:53.74ms +step:10000/11000 val_loss:0.9307 val_bpb:1.0452 train_time:537394ms step_avg:53.74ms +step:10100/11000 train_loss:0.9620 train_time:542717ms step_avg:53.73ms +step:10200/11000 train_loss:0.9077 train_time:548058ms step_avg:53.73ms +step:10300/11000 train_loss:0.9518 train_time:553545ms step_avg:53.74ms +step:10400/11000 train_loss:0.9364 train_time:558886ms step_avg:53.74ms +step:10500/11000 train_loss:0.8425 train_time:564228ms step_avg:53.74ms +step:10600/11000 train_loss:0.8269 train_time:569567ms step_avg:53.73ms +step:10700/11000 train_loss:1.0273 train_time:574905ms step_avg:53.73ms +step:10800/11000 train_loss:0.9178 train_time:580383ms step_avg:53.74ms +step:10900/11000 train_loss:0.8812 train_time:585728ms step_avg:53.74ms +step:11000/11000 train_loss:0.9374 train_time:591064ms step_avg:53.73ms +step:11000/11000 val_loss:0.9181 val_bpb:1.0311 train_time:591082ms step_avg:53.73ms +peak memory allocated: 11113 MiB reserved: 11382 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15577461 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15625147 bytes +final_int8_zlib_roundtrip val_loss:0.9199 val_bpb:1.0331 eval_time:3999ms +final_int8_zlib_roundtrip_exact val_loss:0.91991714 val_bpb:1.03309190 +4) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + 
local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export the model at its full training precision (bf16 and fp32).
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing to int8 and zlib-compressing.
+# At evaluation time we decompress, dequantize, and run in higher precision again, now comfortably under the size limit.
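As a minimal, pure-Python sketch of the symmetric per-row idea (illustrative only: it scales each row by its max-abs value, whereas the exporter below clips at a high percentile first, and `quantize_row` is not part of the script):

```python
# Illustrative sketch, not part of the submission script: quantize one weight row
# to int8 with a single shared scale, then dequantize and check the error bound.
def quantize_row(row: list[float], qmax: int = 127) -> tuple[list[int], float]:
    clip = max(abs(v) for v in row) or 1.0  # guard against a zero scale for all-zero rows
    scale = clip / qmax                     # one float scale per row
    q = [max(-qmax, min(qmax, round(v / scale))) for v in row]
    return q, scale

row = [0.5, -1.27, 0.02]
q, scale = quantize_row(row)        # q == [50, -127, 2], scale ~ 0.01
deq = [v * scale for v in q]        # dequantized row, within scale/2 of the original
assert all(abs(a - b) <= scale / 2 for a, b in zip(row, deq))
```

Each row pays one stored scale of overhead, and the reconstruction error per element is at most half a quantization step, which is why per-row scales beat a single tensor-wide scale for weight matrices with uneven row ranges.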
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return torch.from_numpy(tokens.astype(np.uint16))
+
+
+class TokenStream:
+    # Presents a sorted glob of shard files as one continuous, wrapping token stream.
+    def __init__(self, pattern: str):
+        path = Path(pattern)
+        self.files = sorted(path.parent.glob(path.name))
+        if not self.files:
+            raise FileNotFoundError(f"no token shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 
0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
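For clarity, the artifact round trip above (quantized state dict, `torch.save`, `zlib.compress(level=9)`, file on disk, then decompress, `torch.load`, dequantize) can be exercised on a toy payload. This standalone sketch substitutes `pickle` for `torch.save`/`torch.load` so it runs without torch; the dict keys mirror the exporter's format, but the values are toy stand-ins:

```python
# Standalone sketch of the compression round trip used for final_model.int8.ptz.
# pickle stands in for torch.save/torch.load; the payload values are illustrative.
import pickle
import zlib

payload = {
    "__quant_format__": "int8_clean_per_row_v1",
    "quantized": {"w": bytes(range(64)) * 8},  # toy stand-in for an int8 weight payload
    "scales": {"w": 0.0123},
}
raw = pickle.dumps(payload)
blob = zlib.compress(raw, level=9)              # bytes written to disk
restored = pickle.loads(zlib.decompress(blob))  # bytes read back at evaluation time
assert restored == payload                      # zlib is lossless
```

Since the zlib stage is lossless, the only lossy step in the export pipeline is the int8 rounding itself, which is the 1.0311 -> 1.0331 val_bpb gap the roundtrip log reports.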
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 17:50:32 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 37C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 36C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C 
P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 32C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:3 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9138 val_bpb:7.7643 train_time:0ms step_avg:0.03ms +step:1/11000 train_loss:6.9128 train_time:47ms step_avg:46.60ms 
+step:2/11000 train_loss:10.0901 train_time:95ms step_avg:47.55ms +step:3/11000 train_loss:6.0162 train_time:148ms step_avg:49.43ms +step:4/11000 train_loss:5.3484 train_time:201ms step_avg:50.32ms +step:5/11000 train_loss:4.5925 train_time:255ms step_avg:51.01ms +step:6/11000 train_loss:3.9333 train_time:309ms step_avg:51.47ms +step:7/11000 train_loss:3.6710 train_time:365ms step_avg:52.09ms +step:8/11000 train_loss:3.4846 train_time:419ms step_avg:52.39ms +step:9/11000 train_loss:3.3972 train_time:474ms step_avg:52.62ms +step:10/11000 train_loss:3.3632 train_time:524ms step_avg:52.41ms +step:100/11000 train_loss:1.6581 train_time:5350ms step_avg:53.50ms +step:200/11000 train_loss:1.4248 train_time:10699ms step_avg:53.49ms +step:300/11000 train_loss:1.2938 train_time:16048ms step_avg:53.49ms +step:400/11000 train_loss:1.1488 train_time:21395ms step_avg:53.49ms +step:500/11000 train_loss:1.1716 train_time:26871ms step_avg:53.74ms +step:600/11000 train_loss:1.1536 train_time:32216ms step_avg:53.69ms +step:700/11000 train_loss:1.0812 train_time:37561ms step_avg:53.66ms +step:800/11000 train_loss:1.0326 train_time:42915ms step_avg:53.64ms +step:900/11000 train_loss:1.1281 train_time:48387ms step_avg:53.76ms +step:1000/11000 train_loss:1.0081 train_time:53737ms step_avg:53.74ms +step:1000/11000 val_loss:1.0846 val_bpb:1.2180 train_time:53755ms step_avg:53.75ms +step:1100/11000 train_loss:0.9719 train_time:59087ms step_avg:53.72ms +step:1200/11000 train_loss:0.9293 train_time:64430ms step_avg:53.69ms +step:1300/11000 train_loss:1.0667 train_time:69779ms step_avg:53.68ms +step:1400/11000 train_loss:0.9867 train_time:75266ms step_avg:53.76ms +step:1500/11000 train_loss:1.0260 train_time:80616ms step_avg:53.74ms +step:1600/11000 train_loss:1.0494 train_time:85965ms step_avg:53.73ms +step:1700/11000 train_loss:1.0281 train_time:91312ms step_avg:53.71ms +step:1800/11000 train_loss:1.0998 train_time:96789ms step_avg:53.77ms +step:1900/11000 train_loss:1.0214 
train_time:102136ms step_avg:53.76ms +step:2000/11000 train_loss:1.0733 train_time:107481ms step_avg:53.74ms +step:2000/11000 val_loss:1.0240 val_bpb:1.1499 train_time:107498ms step_avg:53.75ms +step:2100/11000 train_loss:1.0557 train_time:112827ms step_avg:53.73ms +step:2200/11000 train_loss:1.0020 train_time:118176ms step_avg:53.72ms +step:2300/11000 train_loss:0.9574 train_time:123664ms step_avg:53.77ms +step:2400/11000 train_loss:1.0131 train_time:129007ms step_avg:53.75ms +step:2500/11000 train_loss:1.0222 train_time:134347ms step_avg:53.74ms +step:2600/11000 train_loss:0.8984 train_time:139689ms step_avg:53.73ms +step:2700/11000 train_loss:1.0605 train_time:145164ms step_avg:53.76ms +step:2800/11000 train_loss:1.0038 train_time:150506ms step_avg:53.75ms +step:2900/11000 train_loss:1.0156 train_time:155848ms step_avg:53.74ms +step:3000/11000 train_loss:0.8702 train_time:161187ms step_avg:53.73ms +step:3000/11000 val_loss:0.9986 val_bpb:1.1215 train_time:161205ms step_avg:53.74ms +step:3100/11000 train_loss:1.0093 train_time:166530ms step_avg:53.72ms +step:3200/11000 train_loss:1.0002 train_time:172021ms step_avg:53.76ms +step:3300/11000 train_loss:0.9371 train_time:177362ms step_avg:53.75ms +step:3400/11000 train_loss:0.9917 train_time:182702ms step_avg:53.74ms +step:3500/11000 train_loss:0.9571 train_time:188043ms step_avg:53.73ms +step:3600/11000 train_loss:1.0061 train_time:193525ms step_avg:53.76ms +step:3700/11000 train_loss:0.8877 train_time:198872ms step_avg:53.75ms +step:3800/11000 train_loss:0.9357 train_time:204213ms step_avg:53.74ms +step:3900/11000 train_loss:1.0379 train_time:209552ms step_avg:53.73ms +step:4000/11000 train_loss:0.9968 train_time:214893ms step_avg:53.72ms +step:4000/11000 val_loss:0.9836 val_bpb:1.1046 train_time:214911ms step_avg:53.73ms +step:4100/11000 train_loss:0.9965 train_time:220374ms step_avg:53.75ms +step:4200/11000 train_loss:1.0058 train_time:225713ms step_avg:53.74ms +step:4300/11000 train_loss:0.9384 
train_time:231056ms step_avg:53.73ms +step:4400/11000 train_loss:0.9732 train_time:236400ms step_avg:53.73ms +step:4500/11000 train_loss:0.9623 train_time:241889ms step_avg:53.75ms +step:4600/11000 train_loss:1.0722 train_time:247229ms step_avg:53.75ms +step:4700/11000 train_loss:0.9358 train_time:252567ms step_avg:53.74ms +step:4800/11000 train_loss:0.9835 train_time:257905ms step_avg:53.73ms +step:4900/11000 train_loss:0.9285 train_time:263246ms step_avg:53.72ms +step:5000/11000 train_loss:0.9534 train_time:268729ms step_avg:53.75ms +step:5000/11000 val_loss:0.9706 val_bpb:1.0900 train_time:268749ms step_avg:53.75ms +step:5100/11000 train_loss:0.9435 train_time:274070ms step_avg:53.74ms +step:5200/11000 train_loss:0.9865 train_time:279415ms step_avg:53.73ms +step:5300/11000 train_loss:0.9410 train_time:284753ms step_avg:53.73ms +step:5400/11000 train_loss:0.9731 train_time:290233ms step_avg:53.75ms +step:5500/11000 train_loss:0.9721 train_time:295572ms step_avg:53.74ms +step:5600/11000 train_loss:0.9702 train_time:300911ms step_avg:53.73ms +step:5700/11000 train_loss:0.9445 train_time:306252ms step_avg:53.73ms +step:5800/11000 train_loss:1.1205 train_time:311593ms step_avg:53.72ms +step:5900/11000 train_loss:0.9899 train_time:317076ms step_avg:53.74ms +step:6000/11000 train_loss:0.9668 train_time:322414ms step_avg:53.74ms +step:6000/11000 val_loss:0.9645 val_bpb:1.0831 train_time:322432ms step_avg:53.74ms +step:6100/11000 train_loss:0.9652 train_time:327753ms step_avg:53.73ms +step:6200/11000 train_loss:0.9352 train_time:333091ms step_avg:53.72ms +step:6300/11000 train_loss:1.0119 train_time:338568ms step_avg:53.74ms +step:6400/11000 train_loss:0.9357 train_time:343908ms step_avg:53.74ms +step:6500/11000 train_loss:0.8723 train_time:349250ms step_avg:53.73ms +step:6600/11000 train_loss:0.9471 train_time:354587ms step_avg:53.73ms +step:6700/11000 train_loss:0.9093 train_time:360074ms step_avg:53.74ms +step:6800/11000 train_loss:0.9716 train_time:365410ms 
step_avg:53.74ms +step:6900/11000 train_loss:0.9681 train_time:370748ms step_avg:53.73ms +step:7000/11000 train_loss:0.9290 train_time:376089ms step_avg:53.73ms +step:7000/11000 val_loss:0.9569 val_bpb:1.0746 train_time:376106ms step_avg:53.73ms +step:7100/11000 train_loss:0.9515 train_time:381428ms step_avg:53.72ms +step:7200/11000 train_loss:0.9765 train_time:386904ms step_avg:53.74ms +step:7300/11000 train_loss:0.9165 train_time:392248ms step_avg:53.73ms +step:7400/11000 train_loss:0.9227 train_time:397590ms step_avg:53.73ms +step:7500/11000 train_loss:0.9453 train_time:402926ms step_avg:53.72ms +step:7600/11000 train_loss:1.0107 train_time:408408ms step_avg:53.74ms +step:7700/11000 train_loss:0.9311 train_time:413746ms step_avg:53.73ms +step:7800/11000 train_loss:0.9566 train_time:419081ms step_avg:53.73ms +step:7900/11000 train_loss:0.9141 train_time:424419ms step_avg:53.72ms +step:8000/11000 train_loss:0.9966 train_time:429758ms step_avg:53.72ms +step:8000/11000 val_loss:0.9527 val_bpb:1.0699 train_time:429775ms step_avg:53.72ms +step:8100/11000 train_loss:0.9746 train_time:435243ms step_avg:53.73ms +step:8200/11000 train_loss:0.8956 train_time:440579ms step_avg:53.73ms +step:8300/11000 train_loss:0.8775 train_time:445917ms step_avg:53.72ms +step:8400/11000 train_loss:0.9745 train_time:451260ms step_avg:53.72ms +step:8500/11000 train_loss:0.9331 train_time:456736ms step_avg:53.73ms +step:8600/11000 train_loss:0.9675 train_time:462076ms step_avg:53.73ms +step:8700/11000 train_loss:0.8975 train_time:467414ms step_avg:53.73ms +step:8800/11000 train_loss:0.9560 train_time:472753ms step_avg:53.72ms +step:8900/11000 train_loss:0.9402 train_time:478090ms step_avg:53.72ms +step:9000/11000 train_loss:0.9445 train_time:483574ms step_avg:53.73ms +step:9000/11000 val_loss:0.9491 val_bpb:1.0658 train_time:483593ms step_avg:53.73ms +step:9100/11000 train_loss:0.8941 train_time:488913ms step_avg:53.73ms +step:9200/11000 train_loss:0.9722 train_time:494276ms step_avg:53.73ms 
+step:9300/11000 train_loss:0.9234 train_time:499695ms step_avg:53.73ms +step:9400/11000 train_loss:0.9815 train_time:505180ms step_avg:53.74ms +step:9500/11000 train_loss:0.8862 train_time:510516ms step_avg:53.74ms +step:9600/11000 train_loss:1.1241 train_time:515858ms step_avg:53.74ms +step:9700/11000 train_loss:0.9322 train_time:521206ms step_avg:53.73ms +step:9800/11000 train_loss:0.8576 train_time:526546ms step_avg:53.73ms +step:9900/11000 train_loss:0.8772 train_time:532037ms step_avg:53.74ms +step:10000/11000 train_loss:0.9450 train_time:537380ms step_avg:53.74ms +step:10000/11000 val_loss:0.9307 val_bpb:1.0452 train_time:537394ms step_avg:53.74ms +step:10100/11000 train_loss:0.9620 train_time:542717ms step_avg:53.73ms +step:10200/11000 train_loss:0.9077 train_time:548058ms step_avg:53.73ms +step:10300/11000 train_loss:0.9518 train_time:553545ms step_avg:53.74ms +step:10400/11000 train_loss:0.9364 train_time:558886ms step_avg:53.74ms +step:10500/11000 train_loss:0.8425 train_time:564228ms step_avg:53.74ms +step:10600/11000 train_loss:0.8269 train_time:569567ms step_avg:53.73ms +step:10700/11000 train_loss:1.0273 train_time:574905ms step_avg:53.73ms +step:10800/11000 train_loss:0.9178 train_time:580383ms step_avg:53.74ms +step:10900/11000 train_loss:0.8812 train_time:585728ms step_avg:53.74ms +step:11000/11000 train_loss:0.9374 train_time:591064ms step_avg:53.73ms +step:11000/11000 val_loss:0.9181 val_bpb:1.0311 train_time:591082ms step_avg:53.73ms +peak memory allocated: 11113 MiB reserved: 11382 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15577461 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15625147 bytes +final_int8_zlib_roundtrip val_loss:0.9199 val_bpb:1.0331 eval_time:3999ms +final_int8_zlib_roundtrip_exact val_loss:0.91991714 val_bpb:1.03309190 diff --git 
a/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed42.log b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed42.log new file mode 100644 index 0000000000..a56ca651e3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_GravityTokenizer_AblationLeverage/train_seed42.log @@ -0,0 +1,1252 @@ +W0325 16:28:49.646000 1295 torch/distributed/run.py:803] +W0325 16:28:49.646000 1295 torch/distributed/run.py:803] ***************************************** +W0325 16:28:49.646000 1295 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 16:28:49.646000 1295 torch/distributed/run.py:803] ***************************************** +logs/gravity_12L_seed42.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:42 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9259 val_bpb:7.7779 train_time:0ms step_avg:0.02ms +step:1/11000 train_loss:6.9240 train_time:46ms step_avg:46.40ms +step:2/11000 train_loss:10.5258 train_time:96ms step_avg:48.10ms +step:3/11000 train_loss:6.0231 train_time:153ms step_avg:51.03ms +step:4/11000 train_loss:5.4000 train_time:206ms step_avg:51.60ms +step:5/11000 
train_loss:4.6453 train_time:260ms step_avg:52.08ms +step:6/11000 train_loss:3.9035 train_time:314ms step_avg:52.33ms +step:7/11000 train_loss:3.7564 train_time:367ms step_avg:52.46ms +step:8/11000 train_loss:3.5644 train_time:421ms step_avg:52.63ms +step:9/11000 train_loss:3.6143 train_time:474ms step_avg:52.68ms +step:10/11000 train_loss:3.2778 train_time:528ms step_avg:52.82ms +step:100/11000 train_loss:1.6411 train_time:5348ms step_avg:53.48ms +step:200/11000 train_loss:1.4209 train_time:10697ms step_avg:53.48ms +step:300/11000 train_loss:1.2945 train_time:16045ms step_avg:53.48ms +step:400/11000 train_loss:1.1435 train_time:21385ms step_avg:53.46ms +step:500/11000 train_loss:1.1755 train_time:26862ms step_avg:53.72ms +step:600/11000 train_loss:1.1559 train_time:32206ms step_avg:53.68ms +step:700/11000 train_loss:1.0774 train_time:37552ms step_avg:53.65ms +step:800/11000 train_loss:1.0312 train_time:42901ms step_avg:53.63ms +step:900/11000 train_loss:1.1268 train_time:48381ms step_avg:53.76ms +step:1000/11000 train_loss:1.0070 train_time:53737ms step_avg:53.74ms +step:1000/11000 val_loss:1.0844 val_bpb:1.2178 train_time:53754ms step_avg:53.75ms +step:1100/11000 train_loss:0.9714 train_time:59093ms step_avg:53.72ms +step:1200/11000 train_loss:0.9298 train_time:64447ms step_avg:53.71ms +step:1300/11000 train_loss:1.0697 train_time:69806ms step_avg:53.70ms +step:1400/11000 train_loss:0.9878 train_time:75298ms step_avg:53.78ms +step:1500/11000 train_loss:1.0250 train_time:80653ms step_avg:53.77ms +step:1600/11000 train_loss:1.0495 train_time:86009ms step_avg:53.76ms +step:1700/11000 train_loss:1.0279 train_time:91360ms step_avg:53.74ms +step:1800/11000 train_loss:1.0972 train_time:96844ms step_avg:53.80ms +step:1900/11000 train_loss:1.0215 train_time:102197ms step_avg:53.79ms +step:2000/11000 train_loss:1.0722 train_time:107551ms step_avg:53.78ms +step:2000/11000 val_loss:1.0231 val_bpb:1.1490 train_time:107568ms step_avg:53.78ms +step:2100/11000 train_loss:1.0539 
train_time:112902ms step_avg:53.76ms +step:2200/11000 train_loss:1.0020 train_time:118252ms step_avg:53.75ms +step:2300/11000 train_loss:0.9561 train_time:123738ms step_avg:53.80ms +step:2400/11000 train_loss:1.0105 train_time:129083ms step_avg:53.78ms +step:2500/11000 train_loss:1.0190 train_time:134427ms step_avg:53.77ms +step:2600/11000 train_loss:0.8951 train_time:139769ms step_avg:53.76ms +step:2700/11000 train_loss:1.0603 train_time:145246ms step_avg:53.79ms +step:2800/11000 train_loss:1.0028 train_time:150586ms step_avg:53.78ms +step:2900/11000 train_loss:1.0138 train_time:155927ms step_avg:53.77ms +step:3000/11000 train_loss:0.8721 train_time:161266ms step_avg:53.76ms +step:3000/11000 val_loss:0.9973 val_bpb:1.1200 train_time:161286ms step_avg:53.76ms +step:3100/11000 train_loss:1.0092 train_time:166613ms step_avg:53.75ms +step:3200/11000 train_loss:0.9997 train_time:172089ms step_avg:53.78ms +step:3300/11000 train_loss:0.9334 train_time:177430ms step_avg:53.77ms +step:3400/11000 train_loss:0.9906 train_time:182776ms step_avg:53.76ms +step:3500/11000 train_loss:0.9578 train_time:188115ms step_avg:53.75ms +step:3600/11000 train_loss:1.0021 train_time:193592ms step_avg:53.78ms +step:3700/11000 train_loss:0.8843 train_time:198932ms step_avg:53.77ms +step:3800/11000 train_loss:0.9367 train_time:204268ms step_avg:53.75ms +step:3900/11000 train_loss:1.0339 train_time:209605ms step_avg:53.74ms +step:4000/11000 train_loss:0.9930 train_time:214939ms step_avg:53.73ms +step:4000/11000 val_loss:0.9821 val_bpb:1.1029 train_time:214957ms step_avg:53.74ms +step:4100/11000 train_loss:0.9943 train_time:220420ms step_avg:53.76ms +step:4200/11000 train_loss:1.0044 train_time:225755ms step_avg:53.75ms +step:4300/11000 train_loss:0.9394 train_time:231087ms step_avg:53.74ms +step:4400/11000 train_loss:0.9729 train_time:236422ms step_avg:53.73ms +step:4500/11000 train_loss:0.9594 train_time:241896ms step_avg:53.75ms +step:4600/11000 train_loss:1.0720 train_time:247235ms 
step_avg:53.75ms +step:4700/11000 train_loss:0.9321 train_time:252572ms step_avg:53.74ms +step:4800/11000 train_loss:0.9814 train_time:257908ms step_avg:53.73ms +step:4900/11000 train_loss:0.9225 train_time:263243ms step_avg:53.72ms +step:5000/11000 train_loss:0.9534 train_time:268717ms step_avg:53.74ms +step:5000/11000 val_loss:0.9693 val_bpb:1.0885 train_time:268737ms step_avg:53.75ms +step:5100/11000 train_loss:0.9399 train_time:274054ms step_avg:53.74ms +step:5200/11000 train_loss:0.9841 train_time:279394ms step_avg:53.73ms +step:5300/11000 train_loss:0.9404 train_time:284729ms step_avg:53.72ms +step:5400/11000 train_loss:0.9731 train_time:290201ms step_avg:53.74ms +step:5500/11000 train_loss:0.9706 train_time:295535ms step_avg:53.73ms +step:5600/11000 train_loss:0.9722 train_time:300868ms step_avg:53.73ms +step:5700/11000 train_loss:0.9434 train_time:306201ms step_avg:53.72ms +step:5800/11000 train_loss:1.1211 train_time:311540ms step_avg:53.71ms +step:5900/11000 train_loss:0.9865 train_time:317013ms step_avg:53.73ms +step:6000/11000 train_loss:0.9644 train_time:322349ms step_avg:53.72ms +step:6000/11000 val_loss:0.9626 val_bpb:1.0810 train_time:322368ms step_avg:53.73ms +step:6100/11000 train_loss:0.9644 train_time:327689ms step_avg:53.72ms +step:6200/11000 train_loss:0.9342 train_time:333022ms step_avg:53.71ms +step:6300/11000 train_loss:1.0098 train_time:338495ms step_avg:53.73ms +step:6400/11000 train_loss:0.9337 train_time:343830ms step_avg:53.72ms +step:6500/11000 train_loss:0.8727 train_time:349165ms step_avg:53.72ms +step:6600/11000 train_loss:0.9451 train_time:354498ms step_avg:53.71ms +step:6700/11000 train_loss:0.9064 train_time:359973ms step_avg:53.73ms +step:6800/11000 train_loss:0.9712 train_time:365312ms step_avg:53.72ms +step:6900/11000 train_loss:0.9663 train_time:370646ms step_avg:53.72ms +step:7000/11000 train_loss:0.9257 train_time:375980ms step_avg:53.71ms +step:7000/11000 val_loss:0.9553 val_bpb:1.0729 train_time:375999ms step_avg:53.71ms 
+step:7100/11000 train_loss:0.9511 train_time:381320ms step_avg:53.71ms +step:7200/11000 train_loss:0.9738 train_time:386792ms step_avg:53.72ms +step:7300/11000 train_loss:0.9171 train_time:392127ms step_avg:53.72ms +step:7400/11000 train_loss:0.9189 train_time:397467ms step_avg:53.71ms +step:7500/11000 train_loss:0.9419 train_time:402801ms step_avg:53.71ms +step:7600/11000 train_loss:1.0085 train_time:408274ms step_avg:53.72ms +step:7700/11000 train_loss:0.9256 train_time:413609ms step_avg:53.72ms +step:7800/11000 train_loss:0.9556 train_time:418943ms step_avg:53.71ms +step:7900/11000 train_loss:0.9126 train_time:424277ms step_avg:53.71ms +step:8000/11000 train_loss:0.9950 train_time:429612ms step_avg:53.70ms +step:8000/11000 val_loss:0.9510 val_bpb:1.0680 train_time:429629ms step_avg:53.70ms +step:8100/11000 train_loss:0.9744 train_time:435090ms step_avg:53.71ms +step:8200/11000 train_loss:0.8932 train_time:440424ms step_avg:53.71ms +step:8300/11000 train_loss:0.8758 train_time:445767ms step_avg:53.71ms +step:8400/11000 train_loss:0.9758 train_time:451110ms step_avg:53.70ms +step:8500/11000 train_loss:0.9277 train_time:456589ms step_avg:53.72ms +step:8600/11000 train_loss:0.9651 train_time:461929ms step_avg:53.71ms +step:8700/11000 train_loss:0.8956 train_time:467267ms step_avg:53.71ms +step:8800/11000 train_loss:0.9560 train_time:472604ms step_avg:53.70ms +step:8900/11000 train_loss:0.9386 train_time:477948ms step_avg:53.70ms +step:9000/11000 train_loss:0.9437 train_time:483428ms step_avg:53.71ms +step:9000/11000 val_loss:0.9474 val_bpb:1.0640 train_time:483448ms step_avg:53.72ms +step:9100/11000 train_loss:0.8927 train_time:488770ms step_avg:53.71ms +step:9200/11000 train_loss:0.9721 train_time:494108ms step_avg:53.71ms +step:9300/11000 train_loss:0.9184 train_time:499443ms step_avg:53.70ms +step:9400/11000 train_loss:0.9799 train_time:504924ms step_avg:53.72ms +step:9500/11000 train_loss:0.8851 train_time:510262ms step_avg:53.71ms +step:9600/11000 
train_loss:1.1251 train_time:515599ms step_avg:53.71ms +step:9700/11000 train_loss:0.9325 train_time:521044ms step_avg:53.72ms +step:9800/11000 train_loss:0.8561 train_time:526384ms step_avg:53.71ms +step:9900/11000 train_loss:0.8751 train_time:531865ms step_avg:53.72ms +step:10000/11000 train_loss:0.9433 train_time:537204ms step_avg:53.72ms +step:10000/11000 val_loss:0.9292 val_bpb:1.0436 train_time:537221ms step_avg:53.72ms +step:10100/11000 train_loss:0.9588 train_time:542541ms step_avg:53.72ms +step:10200/11000 train_loss:0.9084 train_time:547886ms step_avg:53.71ms +step:10300/11000 train_loss:0.9516 train_time:553364ms step_avg:53.72ms +step:10400/11000 train_loss:0.9323 train_time:558701ms step_avg:53.72ms +step:10500/11000 train_loss:0.8440 train_time:564036ms step_avg:53.72ms +step:10600/11000 train_loss:0.8255 train_time:569373ms step_avg:53.71ms +step:10700/11000 train_loss:1.0255 train_time:574713ms step_avg:53.71ms +step:10800/11000 train_loss:0.9161 train_time:580195ms step_avg:53.72ms +step:10900/11000 train_loss:0.8788 train_time:585542ms step_avg:53.72ms +step:11000/11000 train_loss:0.9357 train_time:590880ms step_avg:53.72ms +step:11000/11000 val_loss:0.9165 val_bpb:1.0293 train_time:590898ms step_avg:53.72ms +peak memory allocated: 11110 MiB reserved: 11384 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15581581 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15629267 bytes +final_int8_zlib_roundtrip val_loss:0.9181 val_bpb:1.0310 eval_time:3997ms +final_int8_zlib_roundtrip_exact val_loss:0.91807808 val_bpb:1.03102658 +) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + 
local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
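The export scheme the comment above describes — per-row int8 with a max-abs-derived scale, then zlib — can be sketched standalone. This is a minimal NumPy sketch, not the exact exporter: `quantize_per_row`/`dequantize_per_row` are illustrative names, and the script's 99.99984th-percentile clip before scaling is omitted here for brevity.

```python
import zlib
import numpy as np

def quantize_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One int8 scale per output row (the real exporter first clips each row
    # at the 99.99984th |value| percentile; skipped in this sketch).
    scale = np.abs(w).max(axis=1) / 127.0
    scale = np.maximum(scale, 1e-8)  # avoid divide-by-zero on all-zero rows
    q = np.clip(np.round(w / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale.astype(np.float16)

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Broadcast each row's scale back across its columns.
    return q.astype(np.float32) * scale.astype(np.float32)[:, None]

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(384, 1152)).astype(np.float32)  # one MLP-sized matrix
q, s = quantize_per_row(w)
payload = zlib.compress(q.tobytes(), level=9)  # int8 payload is what zlib shrinks
w_hat = dequantize_per_row(q, s)
max_err = float(np.abs(w - w_hat).max())
print(f"compressed {len(payload)} bytes vs fp32 {w.nbytes} bytes, max err {max_err:.2e}")
```

Round-trip error per entry is bounded by half a quantization step plus the fp16 scale-cast error, i.e. well under `row_max / 127` — the small bf16-vs-dequantized gap visible in the `final_int8_zlib_roundtrip` lines of the logs.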
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + 
master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + 
log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, 
args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 
0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 16:29:02 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C 
P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 32C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/gravity_beta_1.0.model +train_loader:dataset:fineweb_gravity_beta_1.0 train_shards:39 +val_loader:shards pattern=./data/datasets/fineweb_gravity_beta_1.0/fineweb_val_*.bin tokens:144967680 +model_params:15749448 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:6 num_kv_heads:2 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:11000 warmup_steps:50 max_wallclock_seconds:600.000 +seed:42 +warmup_step:10/50 +warmup_step:20/50 +warmup_step:30/50 +warmup_step:40/50 +warmup_step:50/50 +step:0/11000 val_loss:6.9259 val_bpb:7.7779 train_time:0ms step_avg:0.02ms +step:1/11000 train_loss:6.9240 train_time:46ms step_avg:46.40ms 
+step:2/11000 train_loss:10.5258 train_time:96ms step_avg:48.10ms +step:3/11000 train_loss:6.0231 train_time:153ms step_avg:51.03ms +step:4/11000 train_loss:5.4000 train_time:206ms step_avg:51.60ms +step:5/11000 train_loss:4.6453 train_time:260ms step_avg:52.08ms +step:6/11000 train_loss:3.9035 train_time:314ms step_avg:52.33ms +step:7/11000 train_loss:3.7564 train_time:367ms step_avg:52.46ms +step:8/11000 train_loss:3.5644 train_time:421ms step_avg:52.63ms +step:9/11000 train_loss:3.6143 train_time:474ms step_avg:52.68ms +step:10/11000 train_loss:3.2778 train_time:528ms step_avg:52.82ms +step:100/11000 train_loss:1.6411 train_time:5348ms step_avg:53.48ms +step:200/11000 train_loss:1.4209 train_time:10697ms step_avg:53.48ms +step:300/11000 train_loss:1.2945 train_time:16045ms step_avg:53.48ms +step:400/11000 train_loss:1.1435 train_time:21385ms step_avg:53.46ms +step:500/11000 train_loss:1.1755 train_time:26862ms step_avg:53.72ms +step:600/11000 train_loss:1.1559 train_time:32206ms step_avg:53.68ms +step:700/11000 train_loss:1.0774 train_time:37552ms step_avg:53.65ms +step:800/11000 train_loss:1.0312 train_time:42901ms step_avg:53.63ms +step:900/11000 train_loss:1.1268 train_time:48381ms step_avg:53.76ms +step:1000/11000 train_loss:1.0070 train_time:53737ms step_avg:53.74ms +step:1000/11000 val_loss:1.0844 val_bpb:1.2178 train_time:53754ms step_avg:53.75ms +step:1100/11000 train_loss:0.9714 train_time:59093ms step_avg:53.72ms +step:1200/11000 train_loss:0.9298 train_time:64447ms step_avg:53.71ms +step:1300/11000 train_loss:1.0697 train_time:69806ms step_avg:53.70ms +step:1400/11000 train_loss:0.9878 train_time:75298ms step_avg:53.78ms +step:1500/11000 train_loss:1.0250 train_time:80653ms step_avg:53.77ms +step:1600/11000 train_loss:1.0495 train_time:86009ms step_avg:53.76ms +step:1700/11000 train_loss:1.0279 train_time:91360ms step_avg:53.74ms +step:1800/11000 train_loss:1.0972 train_time:96844ms step_avg:53.80ms +step:1900/11000 train_loss:1.0215 
train_time:102197ms step_avg:53.79ms +step:2000/11000 train_loss:1.0722 train_time:107551ms step_avg:53.78ms +step:2000/11000 val_loss:1.0231 val_bpb:1.1490 train_time:107568ms step_avg:53.78ms +step:2100/11000 train_loss:1.0539 train_time:112902ms step_avg:53.76ms +step:2200/11000 train_loss:1.0020 train_time:118252ms step_avg:53.75ms +step:2300/11000 train_loss:0.9561 train_time:123738ms step_avg:53.80ms +step:2400/11000 train_loss:1.0105 train_time:129083ms step_avg:53.78ms +step:2500/11000 train_loss:1.0190 train_time:134427ms step_avg:53.77ms +step:2600/11000 train_loss:0.8951 train_time:139769ms step_avg:53.76ms +step:2700/11000 train_loss:1.0603 train_time:145246ms step_avg:53.79ms +step:2800/11000 train_loss:1.0028 train_time:150586ms step_avg:53.78ms +step:2900/11000 train_loss:1.0138 train_time:155927ms step_avg:53.77ms +step:3000/11000 train_loss:0.8721 train_time:161266ms step_avg:53.76ms +step:3000/11000 val_loss:0.9973 val_bpb:1.1200 train_time:161286ms step_avg:53.76ms +step:3100/11000 train_loss:1.0092 train_time:166613ms step_avg:53.75ms +step:3200/11000 train_loss:0.9997 train_time:172089ms step_avg:53.78ms +step:3300/11000 train_loss:0.9334 train_time:177430ms step_avg:53.77ms +step:3400/11000 train_loss:0.9906 train_time:182776ms step_avg:53.76ms +step:3500/11000 train_loss:0.9578 train_time:188115ms step_avg:53.75ms +step:3600/11000 train_loss:1.0021 train_time:193592ms step_avg:53.78ms +step:3700/11000 train_loss:0.8843 train_time:198932ms step_avg:53.77ms +step:3800/11000 train_loss:0.9367 train_time:204268ms step_avg:53.75ms +step:3900/11000 train_loss:1.0339 train_time:209605ms step_avg:53.74ms +step:4000/11000 train_loss:0.9930 train_time:214939ms step_avg:53.73ms +step:4000/11000 val_loss:0.9821 val_bpb:1.1029 train_time:214957ms step_avg:53.74ms +step:4100/11000 train_loss:0.9943 train_time:220420ms step_avg:53.76ms +step:4200/11000 train_loss:1.0044 train_time:225755ms step_avg:53.75ms +step:4300/11000 train_loss:0.9394 
train_time:231087ms step_avg:53.74ms +step:4400/11000 train_loss:0.9729 train_time:236422ms step_avg:53.73ms +step:4500/11000 train_loss:0.9594 train_time:241896ms step_avg:53.75ms +step:4600/11000 train_loss:1.0720 train_time:247235ms step_avg:53.75ms +step:4700/11000 train_loss:0.9321 train_time:252572ms step_avg:53.74ms +step:4800/11000 train_loss:0.9814 train_time:257908ms step_avg:53.73ms +step:4900/11000 train_loss:0.9225 train_time:263243ms step_avg:53.72ms +step:5000/11000 train_loss:0.9534 train_time:268717ms step_avg:53.74ms +step:5000/11000 val_loss:0.9693 val_bpb:1.0885 train_time:268737ms step_avg:53.75ms +step:5100/11000 train_loss:0.9399 train_time:274054ms step_avg:53.74ms +step:5200/11000 train_loss:0.9841 train_time:279394ms step_avg:53.73ms +step:5300/11000 train_loss:0.9404 train_time:284729ms step_avg:53.72ms +step:5400/11000 train_loss:0.9731 train_time:290201ms step_avg:53.74ms +step:5500/11000 train_loss:0.9706 train_time:295535ms step_avg:53.73ms +step:5600/11000 train_loss:0.9722 train_time:300868ms step_avg:53.73ms +step:5700/11000 train_loss:0.9434 train_time:306201ms step_avg:53.72ms +step:5800/11000 train_loss:1.1211 train_time:311540ms step_avg:53.71ms +step:5900/11000 train_loss:0.9865 train_time:317013ms step_avg:53.73ms +step:6000/11000 train_loss:0.9644 train_time:322349ms step_avg:53.72ms +step:6000/11000 val_loss:0.9626 val_bpb:1.0810 train_time:322368ms step_avg:53.73ms +step:6100/11000 train_loss:0.9644 train_time:327689ms step_avg:53.72ms +step:6200/11000 train_loss:0.9342 train_time:333022ms step_avg:53.71ms +step:6300/11000 train_loss:1.0098 train_time:338495ms step_avg:53.73ms +step:6400/11000 train_loss:0.9337 train_time:343830ms step_avg:53.72ms +step:6500/11000 train_loss:0.8727 train_time:349165ms step_avg:53.72ms +step:6600/11000 train_loss:0.9451 train_time:354498ms step_avg:53.71ms +step:6700/11000 train_loss:0.9064 train_time:359973ms step_avg:53.73ms +step:6800/11000 train_loss:0.9712 train_time:365312ms 
step_avg:53.72ms +step:6900/11000 train_loss:0.9663 train_time:370646ms step_avg:53.72ms +step:7000/11000 train_loss:0.9257 train_time:375980ms step_avg:53.71ms +step:7000/11000 val_loss:0.9553 val_bpb:1.0729 train_time:375999ms step_avg:53.71ms +step:7100/11000 train_loss:0.9511 train_time:381320ms step_avg:53.71ms +step:7200/11000 train_loss:0.9738 train_time:386792ms step_avg:53.72ms +step:7300/11000 train_loss:0.9171 train_time:392127ms step_avg:53.72ms +step:7400/11000 train_loss:0.9189 train_time:397467ms step_avg:53.71ms +step:7500/11000 train_loss:0.9419 train_time:402801ms step_avg:53.71ms +step:7600/11000 train_loss:1.0085 train_time:408274ms step_avg:53.72ms +step:7700/11000 train_loss:0.9256 train_time:413609ms step_avg:53.72ms +step:7800/11000 train_loss:0.9556 train_time:418943ms step_avg:53.71ms +step:7900/11000 train_loss:0.9126 train_time:424277ms step_avg:53.71ms +step:8000/11000 train_loss:0.9950 train_time:429612ms step_avg:53.70ms +step:8000/11000 val_loss:0.9510 val_bpb:1.0680 train_time:429629ms step_avg:53.70ms +step:8100/11000 train_loss:0.9744 train_time:435090ms step_avg:53.71ms +step:8200/11000 train_loss:0.8932 train_time:440424ms step_avg:53.71ms +step:8300/11000 train_loss:0.8758 train_time:445767ms step_avg:53.71ms +step:8400/11000 train_loss:0.9758 train_time:451110ms step_avg:53.70ms +step:8500/11000 train_loss:0.9277 train_time:456589ms step_avg:53.72ms +step:8600/11000 train_loss:0.9651 train_time:461929ms step_avg:53.71ms +step:8700/11000 train_loss:0.8956 train_time:467267ms step_avg:53.71ms +step:8800/11000 train_loss:0.9560 train_time:472604ms step_avg:53.70ms +step:8900/11000 train_loss:0.9386 train_time:477948ms step_avg:53.70ms +step:9000/11000 train_loss:0.9437 train_time:483428ms step_avg:53.71ms +step:9000/11000 val_loss:0.9474 val_bpb:1.0640 train_time:483448ms step_avg:53.72ms +step:9100/11000 train_loss:0.8927 train_time:488770ms step_avg:53.71ms +step:9200/11000 train_loss:0.9721 train_time:494108ms step_avg:53.71ms 
+step:9300/11000 train_loss:0.9184 train_time:499443ms step_avg:53.70ms +step:9400/11000 train_loss:0.9799 train_time:504924ms step_avg:53.72ms +step:9500/11000 train_loss:0.8851 train_time:510262ms step_avg:53.71ms +step:9600/11000 train_loss:1.1251 train_time:515599ms step_avg:53.71ms +step:9700/11000 train_loss:0.9325 train_time:521044ms step_avg:53.72ms +step:9800/11000 train_loss:0.8561 train_time:526384ms step_avg:53.71ms +step:9900/11000 train_loss:0.8751 train_time:531865ms step_avg:53.72ms +step:10000/11000 train_loss:0.9433 train_time:537204ms step_avg:53.72ms +step:10000/11000 val_loss:0.9292 val_bpb:1.0436 train_time:537221ms step_avg:53.72ms +step:10100/11000 train_loss:0.9588 train_time:542541ms step_avg:53.72ms +step:10200/11000 train_loss:0.9084 train_time:547886ms step_avg:53.71ms +step:10300/11000 train_loss:0.9516 train_time:553364ms step_avg:53.72ms +step:10400/11000 train_loss:0.9323 train_time:558701ms step_avg:53.72ms +step:10500/11000 train_loss:0.8440 train_time:564036ms step_avg:53.72ms +step:10600/11000 train_loss:0.8255 train_time:569373ms step_avg:53.71ms +step:10700/11000 train_loss:1.0255 train_time:574713ms step_avg:53.71ms +step:10800/11000 train_loss:0.9161 train_time:580195ms step_avg:53.72ms +step:10900/11000 train_loss:0.8788 train_time:585542ms step_avg:53.72ms +step:11000/11000 train_loss:0.9357 train_time:590880ms step_avg:53.72ms +step:11000/11000 val_loss:0.9165 val_bpb:1.0293 train_time:590898ms step_avg:53.72ms +peak memory allocated: 11110 MiB reserved: 11384 MiB +Serialized model: 62256267 bytes +Code size: 47686 bytes +Total submission size: 62303953 bytes +Serialized model int8+zlib: 15581581 bytes (payload:17048864 raw_torch:17101673 payload_ratio:3.65x) +Total submission size int8+zlib: 15629267 bytes +final_int8_zlib_roundtrip val_loss:0.9181 val_bpb:1.0310 eval_time:3997ms +final_int8_zlib_roundtrip_exact val_loss:0.91807808 val_bpb:1.03102658 +
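
## Artifact Format Note

The round-trip numbers above come from the script's own `quantize_state_dict_int8` / `dequantize_state_dict_int8` helpers. As a minimal sketch of why the int8+zlib round-trip degrades the model so little — assuming plain per-tensor symmetric scaling, which is an illustration rather than the submission's exact scheme (the real helpers also keep small control tensors in fp32 passthrough) — the reconstruction error is bounded by half a quantization step per element:

```python
# Illustrative sketch of per-tensor symmetric int8 + zlib compression.
# NOTE: an assumption-laden stand-in, not the submission's quantize_state_dict_int8.
import zlib

import numpy as np


def quantize_int8(w: np.ndarray) -> tuple[bytes, float]:
    # Map [-max|w|, +max|w|] linearly onto [-127, 127]; guard against all-zero tensors.
    scale = max(float(np.abs(w).max()) / 127.0, 1e-12)
    q = np.clip(np.rint(w / scale), -127, 127).astype(np.int8)
    return zlib.compress(q.tobytes(), level=9), scale


def dequantize_int8(blob: bytes, scale: float, shape: tuple[int, ...]) -> np.ndarray:
    q = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(shape)
    return q.astype(np.float32) * scale


rng = np.random.default_rng(42)
w = (rng.standard_normal((384, 1152)) * 0.02).astype(np.float32)  # an MLP-sized weight
blob, scale = quantize_int8(w)
w_hat = dequantize_int8(blob, scale, w.shape)
# Rounding error is at most half a quantization step per element.
assert float(np.abs(w - w_hat).max()) <= scale / 2 + 1e-8
assert len(blob) < w.nbytes  # int8 payload + zlib sits far below the fp32 size
```

At ~16M parameters this is what takes the checkpoint from 62,256,267 raw bytes to the 15,581,581-byte artifact while moving val_bpb only from 1.0293 to 1.0310 in the seed-42 log above.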