records/track_10min_16mb/2026-04-16_SP8192_4LayerRecurrence/README.md
@@ -0,0 +1,157 @@
# SP8192 + 4-Layer Depth Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT

**val_bpb = pending** (results to be collected on 8xA100s) | **~16MB** | 8xA100

## Summary

Extends the current SOTA (PR #1509, val_bpb=1.0810) by widening the depth recurrence from 3 looped layers to 4 (`LOOP_END=6`). All other hyperparameters and techniques are carried forward unchanged.

The only code change from SOTA:

```python
# Before (SOTA, PR #1509)
loop_end=int(os.environ.get('LOOP_END', 5))
qk_gain_init=float(os.environ.get('QK_GAIN_INIT', 5.))

# After (this PR)
loop_end=int(os.environ.get('LOOP_END', 6))
qk_gain_init=float(os.environ.get('QK_GAIN_INIT', 5.25)) # absorbs the known-good default
```

## Architecture

### Virtual Layer Sequence

**SOTA (loop_end=5)** — 17 virtual layers, 8 U-Net skips:
```
Encoder: [0,1,2 | 3,4,5 | 3,4]       pre | loop pass 1 | loop pass 2 (start)
Decoder: [5 | 3,4,5 | 6,7,8,9,10]    loop pass 2 (rest) | loop pass 3 | post (parallel from 7)
```

**This PR (loop_end=6)** — 19 virtual layers, 9 U-Net skips:
```
Encoder: [0,1,2 | 3,4,5,6 | 3,4]     pre | loop pass 1 | loop pass 2 (start)
Decoder: [5,6 | 3,4,5,6 | 7,8,9,10]  loop pass 2 (rest) | loop pass 3 | post (parallel)
```

Layer 6 is promoted from the non-recurring post-loop section into the recurrence core. It now executes 3 times (like layers 3, 4, 5) instead of once.
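
The virtual sequence is built mechanically: pre-loop layers, then `num_loops + 1` passes over the looped segment `[loop_start..loop_end]`, then the post-loop layers, with the result split down the middle into encoder and decoder halves. A minimal sketch of that construction (it mirrors the logic in `test_architecture.py` below; the helper name and defaults are illustrative, not necessarily what `train_gpt.py` uses):

```python
def virtual_sequence(num_layers=11, loop_start=3, loop_end=6, num_loops=2):
    """Pre-loop layers, (num_loops + 1) passes over the loop segment, post-loop layers."""
    loop_seg = list(range(loop_start, loop_end + 1))
    seq = list(range(loop_start))                 # pre-loop: [0, 1, 2]
    for _ in range(num_loops + 1):                # 3 total passes over [loop_start..loop_end]
        seq += loop_seg
    seq += list(range(loop_end + 1, num_layers))  # post-loop: [loop_end+1 .. 10]
    num_enc = len(seq) // 2
    return seq[:num_enc], seq[num_enc:]           # encoder half, decoder half

enc, dec = virtual_sequence(loop_end=6)
# enc = [0, 1, 2, 3, 4, 5, 6, 3, 4]
# dec = [5, 6, 3, 4, 5, 6, 7, 8, 9, 10]   -> 19 virtual layers, 9 skips
```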

### U-Net Skip Connections

With 9 encoder and 10 decoder steps, there are 9 skip connections (encoder[i] feeds decoder[8-i]):

| Skip | Encoder step | Decoder step |
|------|-------------|-------------|
| 0 | L0 (1st pass) | L10 |
| 1 | L1 (1st pass) | L9 |
| 2 | L2 (1st pass) | L8 |
| 3 | L3 (1st pass) | L7 (parallel) |
| 4 | L4 (1st pass) | L6 (3rd loop pass) |
| 5 | L5 (1st pass) | L5 (3rd loop pass) |
| 6 | **L6 (1st pass)** | **L4 (3rd loop pass)** |
| 7 | L3 (2nd pass) | L3 (3rd loop pass) |
| 8 | L4 (2nd pass) | L6 (2nd loop pass) |

Skip 6 is new in this PR. It connects the first pass through L6 (shallow context) to the third pass through L4 (deep re-processing), providing a residual shortcut that was absent in the 3-layer config.
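
The merge itself is a per-channel gated interpolation between the scaled encoder activation and the current decoder stream, as in the smoke test at the bottom of this PR; pulled out here for clarity (names follow `test_architecture.py`, not necessarily `train_gpt.py`):

```python
import torch

def apply_skip(x, enc_act, skip_weight, skip_gate):
    """Gated U-Net skip merge (as in GPTRecurrent.forward of test_architecture.py).

    x          : current decoder activation, shape (B, T, D)
    enc_act    : stored encoder activation,  shape (B, T, D)
    skip_weight, skip_gate : per-channel vectors, shape (D,)
    """
    sk = skip_weight.to(x.dtype) * enc_act
    g = torch.sigmoid(skip_gate.to(x.dtype))  # gates init to 0, so g starts at 0.5
    return torch.lerp(sk, x, g)               # g -> 0 keeps the skip, g -> 1 keeps the decoder stream
```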

### Parallel Residuals

Unchanged: layers 7, 8, 9, 10 use GPT-J-style parallel attention+MLP. These all appear in the decoder's post-loop section and are not part of the recurrence.
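
The only difference from a standard block is where the MLP reads its input; a minimal two-function sketch (the real implementation is `Block.forward` in `test_architecture.py` below, with per-branch norms and scales that this sketch omits):

```python
def sequential_block(x, attn, mlp, norm):
    x = x + attn(norm(x))        # attention first...
    return x + mlp(norm(x))      # ...then the MLP reads the updated stream

def parallel_block(x, attn, mlp, norm):
    h = norm(x)
    return x + attn(h) + mlp(h)  # GPT-J style: attention and MLP read the same input
```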

## Motivation: Compute Budget Equivalence

The 4-layer loop costs more per step (19 vs 17 virtual layers per forward pass), so fewer steps fit in the time budget, but the total layer-step budget is held essentially constant:

| Config | Virtual layers | Est. steps | Layer-steps |
|--------|---------------|-----------|-------------|
| SOTA | 17 | ~4,550 | ~77,350 |
| This PR | 19 | ~4,071 | ~77,349 |
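
A quick arithmetic check of the budget match (step counts are the table's estimates, so the totals are approximate):

```python
for name, virtual_layers, steps in [("SOTA", 17, 4550), ("This PR", 19, 4071)]:
    print(f"{name:8s} {virtual_layers} x {steps} = {virtual_layers * steps:,} layer-steps")
# SOTA     17 x 4550 = 77,350 layer-steps
# This PR  19 x 4071 = 77,349 layer-steps
```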

The prior depth-recurrence progression shows monotonic improvement with more virtual depth:

| Submission | Looped layers | Virtual layers | val_bpb (no TTT) |
|-----------|--------------|---------------|-----------------|
| PR #1260 (2-layer loop) | [4,5] | ~13 | 1.0979 |
| PR #1394 (3-layer loop) | [3,4,5] | 17 | 1.0856 |
| This PR (4-layer loop) | [3,4,5,6] | 19 | pending |

Each expansion has improved val_bpb without increasing the compute budget. The working hypothesis is that, at a fixed layer-step budget, more virtual depth per optimizer step is more sample-efficient than more optimizer steps at shallower depth.

## Local Verification

`test_architecture.py` in this directory validates the 4-layer config on CPU (no CUDA, no flash_attn needed):

```
$ python test_architecture.py
============================================================
Config: SOTA (loop_end=5)
encoder: [0, 1, 2, 3, 4, 5, 3, 4]
decoder: [5, 3, 4, 5, 6, 7, 8, 9, 10]
virtual_layers=17, skips=8
forward pass: OK shape=torch.Size([2, 32, 256])
gradient flow: OK loss=5.5465
all looped blocks have clean gradients: OK

============================================================
Config: This PR (loop_end=6)
encoder: [0, 1, 2, 3, 4, 5, 6, 3, 4]
decoder: [5, 6, 3, 4, 5, 6, 7, 8, 9, 10]
virtual_layers=19, skips=9
forward pass: OK shape=torch.Size([2, 32, 256])
gradient flow: OK loss=5.5435
all looped blocks have clean gradients: OK

All architecture tests PASSED.
```

## Full Technique Stack (carried from SOTA)

1. **SP8192** tokenizer (kevclark/parameter-golf HuggingFace dataset)
2. **4-Layer Depth Recurrence** — layers [3,4,5,6], 3 total passes, activated at `frac=0.35`
3. **Parallel Residuals** — from layer 7, GPT-J style (attention and MLP read same input)
4. **QK-Gain 5.25** — learnable per-head query scaling (now the script default)
5. **MuonEq-R** — row-normalized Muon with Newton-Schulz 5 steps
6. **Legal Score-First TTT** — SGD lr=0.005, momentum=0.9, 3 epochs/32K chunk, cosine LR decay (sketched after this list)
7. **Full-Hessian GPTQ SDClip** — int6 matrices (k=12.85), int8 embeddings (k=20.0)
8. **Byte-shuffle + Brotli-11** compression
9. **LZMA code wrapper** — ~16.6KB self-extracting code
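
A rough sketch of the score-first TTT loop implied by item 6's hyperparameters: each 32K-token eval chunk is scored with the current weights and only then used for a few SGD epochs, so adaptation never leaks into the score of the chunk it was trained on. This is an illustrative outline under those assumptions, not the `train_gpt.py` implementation; `model` and `chunks` are hypothetical stand-ins:

```python
import math
import torch
import torch.nn.functional as F

def score_first_ttt(model, chunks, base_lr=0.005, momentum=0.9, epochs=3):
    """Illustrative score-first TTT: score each chunk, then adapt on it with SGD."""
    opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=momentum)
    losses = []
    for inputs, targets in chunks:           # e.g. one 32K-token chunk per element
        with torch.no_grad():                # score FIRST, with the current weights
            logits = model(inputs)
            losses.append(F.cross_entropy(logits.flatten(0, 1), targets.flatten()).item())
        for step in range(epochs):           # then 3 epochs of SGD on the same chunk
            lr = 0.5 * base_lr * (1.0 + math.cos(math.pi * step / epochs))  # cosine decay
            for group in opt.param_groups:
                group["lr"] = lr
            loss = F.cross_entropy(model(inputs).flatten(0, 1), targets.flatten())
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
    return losses  # per-chunk losses in nats; bpb needs a bytes-per-token conversion
```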

## Reproduction

```bash
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192

SEED=42 TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

No additional env vars needed. `QK_GAIN_INIT=5.25` and `LOOP_END=6` are now script defaults.
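
For an A/B against the SOTA baseline, the previous values can presumably still be supplied through the same environment variables shown in the diff at the top of this README (untested here; assumes the `os.environ.get` fallbacks are unchanged):

```bash
# Reproduce the SOTA (PR #1509) configuration with this script
LOOP_END=5 QK_GAIN_INIT=5.0 SEED=42 TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```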

## Expected Behavior

- Training: ~4,071 steps in ~588s (vs ~4,550 for SOTA), ~11.8% more compute per step
- Artifact: ~15.99 MB (essentially unchanged — one extra skip_weight/skip_gate row adds ~4KB uncompressed)
- Eval: identical sliding window + TTT pipeline, same timing budget

## Compliance

Identical to SOTA (PR #1509). Score-first TTT, no SLOT, no n-gram cache, no pre-quant TTT, no ETLB.

## Files

- `train_gpt.py` — LZMA-compressed production script (2-line diff from SOTA)
- `train_gpt_human.py` — human-readable version of the same code
- `test_architecture.py` — CPU smoke test, no dependencies beyond PyTorch

## Credits

- **@clarkkev** — SP8192 + GPTQ + SDClip + MuonEq-R + depth recurrence base (PR #1394)
- **@dexhunter** — 3-layer depth recurrence (PR #1331, #1437), legal TTT on SP8192 (PR #1413)
- **@Robby955** — Parallel residuals on SP8192 (PR #1412)
- **@msisovic** — Parallel residuals concept (PR #1204)
- **@X-Abhishek-X** — Hyperparameter tuning: WD=0.095, MLR=0.022, EMA=0.9965 (PR #1445, #1471)
- **@abaybektursun** — Score-first TTT framework (PR #549)
@@ -0,0 +1,7 @@
{
"name": "SP8192 + 4-Layer Depth Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT",
"github_id": "tashapais",
"val_bpb": null,
"date": "2026-04-16",
"notes": "Extends SOTA 3-layer depth recurrence to 4 layers (loop_end=6). Results pending on 8xA100s."
}
@@ -0,0 +1,183 @@
"""
Smoke-test for the 4-layer depth recurrence architecture.

Runs on CPU with standard PyTorch (no flash_attn, no CUDA, no data).
Verifies that:
1. Virtual layer sequences match the expected encoder/decoder paths
2. Forward pass produces correct output shapes
3. Gradients flow cleanly through all blocks (including looped blocks)
4. SOTA (loop_end=5) and this submission (loop_end=6) both pass

Usage:
python test_architecture.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def forward(self, x):
        return F.rms_norm(x, (x.size(-1),))


class CastedLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.to(x.dtype))


class MLP(nn.Module):
    def __init__(self, dim, mlp_mult):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)

    def forward(self, x):
        return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())


class Attention(nn.Module):
    def __init__(self, dim, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        kv_dim = num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.q_gain = nn.Parameter(torch.ones(num_heads))  # QK-gain: learnable per-head query scale

    def forward(self, x):
        B, T, D = x.shape
        hd = self.head_dim
        q = F.rms_norm(self.c_q(x).reshape(B, T, self.num_heads, hd), (hd,))
        k = F.rms_norm(self.c_k(x).reshape(B, T, self.num_kv_heads, hd), (hd,))
        v = self.c_v(x).reshape(B, T, self.num_kv_heads, hd)
        q = q * self.q_gain[None, None, :, None]
        grp = self.num_heads // self.num_kv_heads  # GQA: expand kv heads to match query heads
        k = k.repeat_interleave(grp, dim=2)
        v = v.repeat_interleave(grp, dim=2)
        y = F.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
            is_causal=True,
        ).permute(0, 2, 1, 3).reshape(B, T, D)
        return self.proj(y)


class Block(nn.Module):
    def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, layer_idx):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = Attention(dim, num_heads, num_kv_heads)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim))
        self.mlp_scale = nn.Parameter(torch.ones(dim))
        self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float())
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1)
        self.parallel = False

    def forward(self, x, x0):
        mix = self.resid_mix.to(dtype=x.dtype)
        x_in = mix[0] * x + mix[1] * x0  # learned mix of the running stream and the embedding stream
        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor)
        if self.parallel:
            # GPT-J-style parallel residual: attention and MLP both read x_in
            mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor)
            return x_in + self.attn_scale.to(x_in.dtype) * attn_out + self.mlp_scale.to(x_in.dtype) * mlp_out
        x2 = x_in + self.attn_scale.to(x_in.dtype) * attn_out
        return x2 + self.mlp_scale.to(x2.dtype) * self.mlp(self.mlp_norm(x2) * self.ln_scale_factor)


class GPTRecurrent(nn.Module):
    def __init__(self, vocab_size=256, num_layers=11, dim=64,
                 num_heads=8, num_kv_heads=4, mlp_mult=4,
                 loop_start=3, loop_end=5, num_loops=2,
                 parallel_residual_start=7):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, num_kv_heads, mlp_mult, i)
            for i in range(num_layers)
        ])
        for i in range(num_layers):
            self.blocks[i].parallel = (i >= parallel_residual_start)

        # Virtual layer sequence: pre-loop, (num_loops + 1) passes over the loop segment, post-loop
        loop_seg = list(range(loop_start, loop_end + 1))
        all_idx = list(range(loop_start))
        for _ in range(num_loops + 1):
            all_idx.extend(loop_seg)
        all_idx.extend(range(loop_end + 1, num_layers))

        num_enc = len(all_idx) // 2
        self.encoder_indices = all_idx[:num_enc]
        self.decoder_indices = all_idx[num_enc:]
        self.num_skip = min(len(self.encoder_indices), len(self.decoder_indices))

        self.skip_weights = nn.Parameter(torch.ones(self.num_skip, dim))
        self.skip_gates = nn.Parameter(torch.zeros(self.num_skip, dim))
        self.tok_emb = nn.Embedding(vocab_size, dim)
        nn.init.normal_(self.tok_emb.weight, std=0.005)

    def forward(self, input_ids):
        x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,))
        x0 = x
        skips = []
        for i in self.encoder_indices:
            x = self.blocks[i](x, x0)
            skips.append(x)
        for skip_idx, i in enumerate(self.decoder_indices):
            if skip_idx < self.num_skip and skips:
                # Gated U-Net skip: merge the matching encoder activation (LIFO)
                sk = self.skip_weights[skip_idx].to(x.dtype) * skips.pop()
                g = torch.sigmoid(self.skip_gates[skip_idx].to(x.dtype))
                x = torch.lerp(sk, x, g)
            x = self.blocks[i](x, x0)
        return F.linear(x, self.tok_emb.weight)  # tied output head


def run_test(name, loop_end, expected_virtual_layers, expected_skips):
    print(f"\n{'='*60}")
    print(f"Config: {name}")

    model = GPTRecurrent(loop_end=loop_end)
    vl = len(model.encoder_indices) + len(model.decoder_indices)
    assert vl == expected_virtual_layers, f"Expected {expected_virtual_layers} virtual layers, got {vl}"
    assert model.num_skip == expected_skips, f"Expected {expected_skips} skips, got {model.num_skip}"

    print(f" encoder: {model.encoder_indices}")
    print(f" decoder: {model.decoder_indices}")
    print(f" virtual_layers={vl}, skips={model.num_skip}")

    # Forward pass
    B, T = 2, 32
    ids = torch.randint(0, 256, (B, T))
    logits = model(ids)
    assert logits.shape == (B, T, 256), f"Bad output shape: {logits.shape}"
    assert not logits.isnan().any(), "NaN in forward pass"
    print(f" forward pass: OK shape={logits.shape}")

    # Gradient flow
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, 256), ids[:, 1:].reshape(-1))
    loss.backward()
    no_grad = [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]
    has_nan = [n for n, p in model.named_parameters() if p.grad is not None and p.grad.isnan().any()]
    assert not no_grad, f"Missing gradients: {no_grad}"
    assert not has_nan, f"NaN gradients: {has_nan}"
    print(f" gradient flow: OK loss={loss.item():.4f}")

    # Verify blocks in the looped section (executed more than once) receive gradients
    all_steps = model.encoder_indices + model.decoder_indices
    looped = {i for i in all_steps if all_steps.count(i) > 1}
    for block_idx in sorted(looped):
        for p in model.blocks[block_idx].parameters():
            assert p.grad is not None and not p.grad.isnan().any()
    print(f" all looped blocks have clean gradients: OK")


if __name__ == "__main__":
    run_test("SOTA (loop_end=5)", loop_end=5, expected_virtual_layers=17, expected_skips=8)
    run_test("This PR (loop_end=6)", loop_end=6, expected_virtual_layers=19, expected_skips=9)
    print(f"\n{'='*60}")
    print("All architecture tests PASSED.")