Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env bash
# Parameter Golf — Phase 0: env sanity + clone + deps + data + fork baseline.
# Run on the RunPod 1xH100 pod (pytorch:1.0.2-cu1281-torch280-ubuntu2404).
# Usage: bash phase0.sh 2>&1 | tee phase0.log

set -euo pipefail
cd /workspace
echo "=== PHASE 0 START ==="
date -u +"%Y-%m-%dT%H:%M:%SZ"

# 1. env sanity -------------------------------------------------------------
# Print interpreter/torch/triton versions and confirm the GPU + bf16 support.
echo "--- env ---"
python - <<'PY'
import sys, torch, triton
print("python", sys.version.split()[0])
print("torch", torch.__version__, "cuda", torch.version.cuda, "triton", triton.__version__)
print("gpu", torch.cuda.get_device_name(0), "bf16", torch.cuda.is_bf16_supported())
print("sm", torch.cuda.get_device_capability(0))
PY
df -h /workspace | tail -1

# 2. clone repo -------------------------------------------------------------
echo "--- clone ---"
# Idempotent: skip the clone when a previous run already fetched the repo.
if [[ ! -d parameter-golf ]]; then
  git clone --depth 1 https://github.com/openai/parameter-golf.git
fi
cd parameter-golf
git log -1 --oneline

# 3. install deps -----------------------------------------------------------
echo "--- pip ---"
pip install -q --no-input brotli sentencepiece huggingface-hub datasets tqdm 2>&1 | tail -3
python -c "import brotli, sentencepiece, huggingface_hub; print('brotli', brotli.__version__); print('sp', sentencepiece.__version__); print('hf_hub', huggingface_hub.__version__)"

# 4. download SP8192 smoke subset (2 train shards + full val) ---------------
echo "--- data ---"
# Plain `python` for consistency with every other invocation in this script.
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \
  python data/cached_challenge_fineweb.py --variant sp8192 --train-shards 2 2>&1 | tail -15
# `sed -n '1,8p'` instead of `head -8`: head exiting early can SIGPIPE ls
# (exit 141), which would abort the script under `set -o pipefail`; sed
# drains its input, and a genuine ls failure still trips pipefail.
ls -lh data/datasets/fineweb10B_sp8192/ | sed -n '1,8p'
ls -lh data/tokenizers/ | sed -n '1,8p'

# 5. fork bigbag's top record as our working baseline -----------------------
echo "--- fork baseline ---"
mkdir -p /workspace/work
cp -v records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/train_gpt.py \
  /workspace/work/train_gpt_baseline.py
wc -l /workspace/work/train_gpt_baseline.py
md5sum /workspace/work/train_gpt_baseline.py

echo "=== PHASE 0 DONE ==="
date -u +"%Y-%m-%dT%H:%M:%SZ"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env bash
# Phase 0b: unpack bigbag's LZMA-compressed train_gpt.py into readable source.
# Expects unpack.py at /workspace/unpack.py (uploaded alongside this script).
set -euo pipefail

readonly REC=records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT
cd /workspace/parameter-golf

echo "--- unpack baseline ---"
# Create the output directory here too, so this script does not silently
# depend on phase0.sh having run first.
mkdir -p /workspace/work
python /workspace/unpack.py "$REC/train_gpt.py" /workspace/work/train_gpt_baseline.py

echo "--- readable? first 30 lines ---"
head -30 /workspace/work/train_gpt_baseline.py

echo "--- stats ---"
wc -l /workspace/work/train_gpt_baseline.py
md5sum /workspace/work/train_gpt_baseline.py

echo "--- does it at least import cleanly? ---"
# guard against accidental top-level training code: parse only, never exec
python - <<'PY'
import ast, pathlib
src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text()
tree = ast.parse(src)
defs = [n.name for n in tree.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))]
print(f"top-level defs: {len(defs)}")
print("sample:", defs[:20])
PY

echo "=== PHASE 0b DONE ==="
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env bash
# Phase 1a: check FA3 availability + show the hot classes (Block, CausalSelfAttention, MLP, GPT).
set -euo pipefail
cd /workspace/parameter-golf

echo "--- is FA3 importable? ---"
# `|| true`: this step is diagnostic only — a hard interpreter crash on
# import (e.g. ABI mismatch) must not abort the whole phase.
python - <<'PY' || true
try:
    import flash_attn_interface as fa3
    print("FA3 OK", getattr(fa3, "__version__", "?"), "module:", fa3.__file__)
    from flash_attn_interface import flash_attn_func
    print("flash_attn_func sig:", flash_attn_func.__doc__[:200] if flash_attn_func.__doc__ else "(no doc)")
except Exception as e:
    print("FA3 MISSING:", type(e).__name__, e)
PY

echo
echo "--- installed flash-attn-ish packages ---"
pip list 2>/dev/null | grep -i -E 'flash|attn' || echo "(none)"

echo
echo "--- extract the hot classes for inspection ---"
# Dump the full source of each attention/MLP/block class by AST line span.
python - <<'PY'
import ast, pathlib
src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text()
tree = ast.parse(src)
wanted = {"CausalSelfAttention", "MLP", "Block", "GPT"}
lines = src.splitlines()
for node in tree.body:
    if isinstance(node, ast.ClassDef) and node.name in wanted:
        start, end = node.lineno - 1, node.end_lineno
        body = "\n".join(lines[start:end])
        print(f"\n=== class {node.name} @ lines {node.lineno}-{node.end_lineno} ({end-start} lines) ===")
        print(body)
PY

echo
echo "--- what does the forward pass look like? show GPT.forward ---"
python - <<'PY'
import ast, pathlib
src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text()
tree = ast.parse(src)
lines = src.splitlines()
for node in tree.body:
    if isinstance(node, ast.ClassDef) and node.name == "GPT":
        for m in node.body:
            if isinstance(m, ast.FunctionDef) and m.name == "forward":
                start, end = m.lineno - 1, m.end_lineno
                print("\n".join(lines[start:end]))
PY

echo "=== PHASE 1a DONE ==="
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env bash
# Phase 1b: FA3->SDPA patch + RMSNorm inspection + Block microbench (eager & compiled).
set -euo pipefail
cd /workspace

echo "=== PHASE 1b: patch FA3 + microbench Block ==="
date -u +"%Y-%m-%dT%H:%M:%SZ"

# 1. Patch the baseline: FA3 import -> SDPA shim ----------------------------
# Replaces the single FA3 import with a pure-torch function of the same
# name, so the baseline runs on pods without flash-attn 3 installed.
python - <<'PY'
import re, pathlib

src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text()

# FA3's flash_attn_func takes (B, T, H, D); SDPA wants (B, H, T, D) —
# hence the transposes. enable_gqa handles num_kv_heads < num_heads.
shim = (
    "# --- patched: FA3 -> SDPA shim ---\n"
    "def flash_attn_3_func(q, k, v, causal=False):\n"
    "    import torch.nn.functional as _F\n"
    "    gqa = q.size(-2) != k.size(-2)\n"
    "    y = _F.scaled_dot_product_attention(\n"
    "        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),\n"
    "        is_causal=causal, enable_gqa=gqa,\n"
    "    )\n"
    "    return y.transpose(1, 2)\n"
)

pat = r'from flash_attn_interface import flash_attn_func as flash_attn_3_func\n'
src_new, n = re.subn(pat, shim, src, count=1)
assert n == 1, f"Expected exactly 1 FA3 import, got {n}"
pathlib.Path("/workspace/work/train_gpt_patched.py").write_text(src_new)
print(f"Patched: {len(src)} -> {len(src_new)} bytes, lines: {src_new.count(chr(10))+1}")
PY

# 2. RMSNorm class (confirm parameter-free) ---------------------------------
echo
echo "--- RMSNorm definition ---"
python - <<'PY'
import ast, pathlib

src = pathlib.Path("/workspace/work/train_gpt_patched.py").read_text()
tree = ast.parse(src)
lines = src.splitlines()
for node in tree.body:
    if isinstance(node, ast.ClassDef) and node.name == "RMSNorm":
        print("\n".join(lines[node.lineno-1:node.end_lineno]))
        # count nn.Parameter usages inside the class
        n_params = sum(1 for n in ast.walk(node)
                       if isinstance(n, ast.Attribute) and n.attr == "Parameter")
        print(f"\n[RMSNorm has {n_params} nn.Parameter usage(s) in-class]")
PY

# 3. Microbenchmark the parallel-residual block -----------------------------
echo
echo "--- Block microbench (eager & compiled) ---"
python - <<'PY'
import time, warnings, pathlib, torch
warnings.filterwarnings("ignore")

# Exec patched source in a non-main namespace so main() doesn't auto-run
src = pathlib.Path("/workspace/work/train_gpt_patched.py").read_text()
ns = {"__name__": "pg_patched"}
try:
    exec(compile(src, "train_gpt_patched.py", "exec"), ns)
except SystemExit:
    pass  # just in case main() has a sys.exit
Block = ns["Block"]

device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(0)

# Bigbag's hyperparams for the hot-path shapes
B, T, D = 8, 2048, 512
H, KVH = 8, 4
MLP_MULT = 4.0

def build_block(parallel, use_xsa, layer_idx=7):
    """Construct one Block with the record's hparams and runtime toggles set."""
    blk = Block(
        dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=MLP_MULT,
        rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T,
        layer_idx=layer_idx, ln_scale=True,
    ).to(device).to(dtype)
    # Restore fp32 for the scalar/control params (baseline trick)
    for p in blk.parameters():
        if p.ndim < 2:
            p.data = p.data.float()
    blk.parallel = parallel
    blk.attn.use_xsa = use_xsa
    if hasattr(blk.attn, "rope_dims"):
        # rope_dims=16 per hparams (partial RoPE). Rotary was built with rope_dims=0 in
        # __init__; baseline re-creates it in GPT.__init__ when rope_dims>0. We redo here.
        Rotary = ns["Rotary"]
        blk.attn.rope_dims = 16
        blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device)
    return blk

def bench(blk_fn, n_warmup=10, n_iter=50):
    """Mean forward+backward latency (ms/iter) over n_iter timed iterations."""
    x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True)
    x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True)
    # warmup (also triggers compile if applicable)
    for _ in range(n_warmup):
        y = blk_fn(x, x0)
        y.sum().backward()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        y = blk_fn(x, x0)
        y.sum().backward()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / n_iter

print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH} MLP_MULT={MLP_MULT} dtype={dtype}")
print(f"{'variant':<48} {'ms/iter':>10}")
print("-" * 60)

for parallel in (False, True):
    for xsa in (False, True):
        blk = build_block(parallel, xsa)
        ms_eager = bench(blk)
        print(f" eager parallel={parallel!s:<5} xsa={xsa!s:<5} {ms_eager:>10.3f}")

# Compiled variant only for the realistic target config
target = build_block(parallel=True, use_xsa=True)
target_compiled = torch.compile(target, fullgraph=True, dynamic=False, mode="max-autotune-no-cudagraphs")
ms_compiled = bench(target_compiled, n_warmup=20, n_iter=50)
print(f" compiled parallel=True xsa=True {ms_compiled:>10.3f}")

# 4. Profile the compiled target block
print()
print("--- torch.profiler: compiled parallel+xsa block, 20 fwd+bwd iters ---")
x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True)
x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True)
for _ in range(5):
    y = target_compiled(x, x0); y.sum().backward()
torch.cuda.synchronize()
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
    record_shapes=False,
) as prof:
    for _ in range(20):
        y = target_compiled(x, x0); y.sum().backward()
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=25))
PY

echo "=== PHASE 1b DONE ==="
date -u +"%Y-%m-%dT%H:%M:%SZ"
Loading