diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/README.md b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/README.md new file mode 100644 index 0000000000..c404b79f1a --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/README.md @@ -0,0 +1,284 @@ +# Non-record: Systems-Fusion Ceiling Investigation on the SP8192 Stack + H-Net Tokenization Proposal + +**Author**: [@diaslmb](https://github.com/diaslmb) • **Track**: `track_non_record_16mb` • **Hardware**: 1× H100 80GB SXM (RunPod) • **Date**: 2026-04-14 + +## tl;dr + +Two hypothesized systems speedups were tested on top of PR #1493 (bigbag, SP8192 + 3-layer depth recurrence + parallel residuals + QK-Gain 5.25 + legal TTT, val_bpb 1.0810): + +1. **Custom Triton kernel replacing `_xsa_efficient`** (XSA decomposes into `bmm + sum + elementwise` under Inductor; a single Triton kernel reads v once, normalizes, and writes `y - ·v̂` for every GQA head in one pass). +2. **Fused QKV projection** (stacking `c_q / c_k / c_v` into one `(D + 2·D_kv, D)` GEMM). + +**200-step training pilot (1× H100, matched seed, `torch.compile(mode="max-autotune-no-cudagraphs")`)**: + +| variant | step_avg (ms) | steady-state step (100→200, ms) | tok/s | train_loss at 200 | +|---|---:|---:|---:|---:| +| baseline (3-linear, torch-XSA) | 1020 | 1140 | 769,855 | 3.7650 | +| +fused QKV | 1020 | 1140 | 772,679 | 3.7752 | +| +fused QKV + Triton-XSA | **1080** | **1260** | 734,418 | 3.7738 | + +**Neither helps. Triton-XSA regresses 6 %** because the `torch.autograd.Function` wrapper creates a graph break that prevents Inductor from fusing around the kernel, and the fusion-barrier cost exceeds the kernel's isolated-forward advantage. + +The broader finding: **at the D=512 scale of this contest, Inductor + max-autotune + FA3 is very close to the block-level kernel-fusion ceiling**. Future systems PRs should either (a) ship Triton kernels via `torch.library.custom_op` with `register_fake` so Inductor can fuse around them, or (b) target cross-op patterns Inductor cannot identify — cross-layer, cross-component, or operator-semantics outside Inductor's pattern-matcher. Simple op-local fusion is already taken. + +This PR also carries **H-Net Milestone 1: hierarchical byte-level stack with a fixed chunker**, run end-to-end as a pilot and showing clear signs of life. val_bpb drops from **4.49 → 2.51** as we fix architecture and scale training data (factor-of-14 data increase, single skip-connection fix responsible for 1.25 bpb of the improvement). Grant-funded M2 adds the learned chunker (the unclaimed piece of the Requests-for-PRs entry). + +--- + +## Part 1 — Systems investigation + +### 1.1 Target and method + +**Baseline**: PR #1493's `train_gpt.py` (LZMA-compressed self-extracting source, ~16.6 KB packed). Architecture: 11 layers × D=512 × 8 heads / 4 KV (GQA) × MLP_MULT=4, SP8192 tokenizer, FA3 attention, XSA (vector rejection) on all layers, parallel residuals on layers 7+, 3-layer loop on layers 3–5 enabled at frac=0.35, Muon + AdamW optimizers, EMA 0.9965, GPTQ+SDClip export, Brotli-11. + +**Two interventions**: +- A Triton forward + backward kernel for the XSA op (`_xsa_efficient`), wrapped in `torch.autograd.Function`. +- A fused QKV linear (three `CastedLinear` attributes replaced by one `c_qkv`; forward method rebound to split/reshape the stacked output). 
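Both interventions are easiest to see in plain torch. The sketch below is illustrative only: `xsa_reference` spells out the vector-rejection op that `_xsa_efficient` implements (y − (y·v̂)·v̂ per query head, with v̂ the normalized value vector of its GQA group; the contiguous-group head mapping is an assumption), and `stack_qkv_weights` shows the weight layout behind the fused `(D + 2·D_kv, D)` GEMM. Neither function is the baseline's exact code; the shipped implementations live in `xsa_triton.py` and `qkv_fuse.py`.

```python
import torch
import torch.nn.functional as F

def xsa_reference(y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # y: (B, T, H, D) attention output, v: (B, T, H_kv, D) values (GQA).
    # Vector rejection: remove from y its component along the normalized value
    # vector of the matching GQA group, i.e. y - (y . v_hat) * v_hat.
    B, T, H, D = y.shape
    h_kv = v.size(2)
    v_hat = F.normalize(v, dim=-1)                       # v / ||v||
    v_hat = v_hat.repeat_interleave(H // h_kv, dim=2)    # assumed GQA head mapping
    proj = (y * v_hat).sum(dim=-1, keepdim=True)         # y . v_hat, per head
    return y - proj * v_hat

def stack_qkv_weights(w_q: torch.Tensor, w_k: torch.Tensor, w_v: torch.Tensor) -> torch.Tensor:
    # w_q: (D, D); w_k, w_v: (D_kv, D). Stacking gives one (D + 2*D_kv, D) weight,
    # so q/k/v come out of a single GEMM and are recovered by splitting the output.
    return torch.cat([w_q, w_k, w_v], dim=0)
```

The fused-QKV path is numerically identical at init but not training-equivalent under Muon; see §1.6.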
+ +**Protocol**: +- FA3 available on both sides of the comparison (installed via the `cu128_torch280` wheel at `windreamer.github.io/flash-attention3-wheels/`, which is the only bundle that ABI-matched our torch 2.8.0 + CUDA 12.8 image). +- Block-level microbench: `(B, T, D, H, KVH) = (8, 2048, 512, 8, 4)`, bf16, `torch.compile(..., mode="max-autotune-no-cudagraphs")`. Timed with CUDA events over 300 samples per variant after a 30-step warm-up + 2-second GEMM thermal prime. +- End-to-end training pilot: bigbag's full recipe, `ITERATIONS=200`, `SEED=42`, `WARMUP_STEPS=10`, `VAL_LOSS_EVERY=0`, baseline vs. patched variants with everything else held constant. + +### 1.2 Kernel numerics (forward + backward) + +The XSA kernel was validated across 4 shape configs and both dtypes. Max elementwise error on forward and gradients: + +| dtype | shape (B,T,H,Hkv,D) | fwd max err | grad_y max err | grad_v max err | +|---|---|---:|---:|---:| +| fp32 | (8, 2048, 8, 4, 64) | 6.0e-7 | 5.1e-7 | 8.3e-7 | +| fp32 | (4, 1024, 4, 4, 64) | 4.8e-7 | 4.2e-7 | 4.8e-7 | +| bf16 | (8, 2048, 8, 4, 64) | 3.1e-2 | 3.1e-2 | 2.3e-2 | +| bf16 | (2, 512, 16, 4, 64) | 3.1e-2 | 1.6e-2 | 2.3e-2 | + +bf16 errors are at the representable-precision floor for a length-64 dot product; fp32 errors are at single-precision float noise. Numerical parity is solid. + +QKV fusion: **0.000e+00** max forward error vs. the 3-linear path (weights stacked exactly, splits at matching offsets). + +### 1.3 Isolated-op microbench (XSA, fwd + bwd, bf16, B=8, T=2048, H=8, KVH=4, D=64) + +Measured with CUDA events; numbers below from phase 2b after adding the Triton backward kernel: + +| impl | fwd μs | fwd + bwd μs | +|---|---:|---:| +| torch eager | 153 | 751 | +| torch compiled (Inductor, max-autotune) | 106 | 342 | +| Triton (ours) | **94** | 745 | + +The Triton kernel **wins the forward by ~12 %** against Inductor's decomposition. It loses fwd+bwd (745 vs 342) — even with a hand-rolled Triton backward, in isolation Inductor is extremely efficient on this op because every primitive (normalize / sum / elementwise) maps to a well-tuned Inductor template. + +### 1.4 Block-level measurement + +`torch.compile(mode="max-autotune-no-cudagraphs")` applied to the full `Block` forward + backward, 300 samples per variant, CUDA-event timing, GPU clock stabilized by 2-second thermal prime, `torch._dynamo.reset()` between variants to avoid the recompile-limit trap: + +| backend | XSA | QKV | p50 ms/iter | min ms/iter | +|---|---|---|---:|---:| +| SDPA | torch | 3-linear | 3.08 | 1.87 | +| SDPA | torch | fused | 3.87 | **1.62** | +| SDPA | Triton | 3-linear | 1.80 | 1.77 | +| SDPA | Triton | fused | 2.92 | 1.72 | +| FA3 | torch | 3-linear | 2.87 | 1.94 | +| FA3 | torch | fused | 2.08 | 1.66 | +| FA3 | Triton | 3-linear | 3.60 | 2.36 | +| FA3 | Triton | fused | 4.72 | 2.43 | + +Distribution is bimodal on several rows. The GPU's SM boost clock oscillates under short-burst microbenchmark load even with thermal priming; `min ms/iter` is the most reliable proxy for steady-state cost. **Triton XSA wins when paired with 3-linear QKV and SDPA**, but loses in combination with FA3 and with fused-QKV. This is the signal that graph-break cost matters more than kernel quality at this scale. + +### 1.5 Training-pilot (the actual arbiter) + +200-step pilots at matched seed 42, all hyperparameters as in PR #1493, `ITERATIONS=200`, `VAL_LOSS_EVERY=0`, `WARMUP_STEPS=10`, `MAX_WALLCLOCK_SECONDS=0`, on one H100 80GB SXM. 
Sustained workload keeps the GPU at turbo for the duration — noise is much lower than the microbench. + +Training logs (step-by-step, cumulative `train_time` and `tok/s`): + +``` +# baseline (no patches) +100/200 train_loss: 4.5012 train_time: 1.5m tok/s: 885388 +120/200 train_loss: 4.2465 train_time: 1.9m tok/s: 843300 +140/200 train_loss: 4.0483 train_time: 2.3m tok/s: 815321 +160/200 train_loss: 3.8626 train_time: 2.6m tok/s: 795856 +180/200 train_loss: 3.8147 train_time: 3.0m tok/s: 781165 +200/200 train_loss: 3.7650 train_time: 3.4m tok/s: 769855 + +# +fused QKV (c_qkv replacing c_q/c_k/c_v, identical weights at init) +100/200 train_loss: 4.5012 train_time: 1.5m tok/s: 885388 [matches baseline until layer-loop kicks in] +180/200 train_loss: 3.8258 train_time: 3.0m tok/s: 783607 +200/200 train_loss: 3.7752 train_time: 3.4m tok/s: 772679 + +# +fused QKV + Triton XSA +180/200 train_loss: 3.8252 train_time: 3.2m tok/s: 745316 +200/200 train_loss: 3.7738 train_time: 3.6m tok/s: 734418 +``` + +**Δ vs baseline** (200-step `train_time`): +- +fused QKV: **0 ms / step** (no change; Inductor already does the equivalent fusion when it sees the 3 linears share input `x`). +- +fused QKV + Triton-XSA: **+60 ms / step → −5 % throughput**. Graph-break overhead from the `autograd.Function` exceeds the kernel-forward advantage. + +Steady-state (steps 100 → 200, after the 3-layer loop has kicked in at frac=0.35): baseline 1140 ms/step, fused-QKV 1140 ms/step, full bundle 1260 ms/step. Same conclusion. + +`train_loss` diverges by **~0.01 nats across variants** at step 200. This is within bf16 step-to-step noise but is a *real* effect for fused-QKV — see §1.6 for the Muon-equivalence subtlety. + +### 1.6 Why fused-QKV is not a "free" systems change + +Under Muon, the 2-D weight matrices of `c_q`, `c_k`, `c_v` are each orthogonalized independently via the Newton-Schulz-5 iteration on their gradients. Fusing them into a single `c_qkv` of shape `(D + 2·D_kv, D)` and applying Muon to that stacked weight runs the Newton-Schulz polynomial on the *joint* gradient matrix, which orthogonalizes its spectrum differently. The forward output is bit-identical at init (we verified `0.000e+00` max elementwise error) but training trajectories diverge — hence the 0.01-nat `train_loss` gap. A correct "systems-only" fused QKV would need to either (a) re-split the stacked weight before each Muon step, or (b) derive a Muon variant that respects a Kronecker / block-diagonal structure. Neither is addressed in this PR. + +### 1.7 What would rescue these interventions + +- **`torch.library.custom_op` + `register_fake`** for the Triton XSA instead of `autograd.Function`. This registers the kernel as a leaf op Inductor understands, so Inductor can continue fusing around it across `block` boundaries. Pre-registered to try in a follow-up if the dev grant comes through. +- **Stacked-gradient Muon** for fused QKV (or equivalent post-hoc split of the fused weight's NS iteration). +- Moving out of the "per-op fusion" corner entirely, toward patterns Inductor cannot identify — e.g., a fused **Muon Newton-Schulz** kernel for the optimizer (5 matmuls + polynomial, one launch), or a **byte-level tokenization** redesign that changes what Inductor has to compile in the first place. The latter is the focus of Part 2 of this PR. + +### 1.8 Scope limitations and reproducibility + +- All experiments on 1× H100 SXM with `world_size=1` and `grad_accum_steps=8` (global batch 786 432 tokens matches the 8× H100 target). 
+- FA3 wheel pinned to `cu128_torch280`. SDPA fallback shim verified numerically equivalent up to softmax-order effects. +- Block-microbench noise: high enough that p50 is unreliable; minima are the trustworthy figure. Training-pilot numbers are reliable because the GPU is at sustained turbo for minutes. +- `main()` segfaults in the post-quantization final validation pass on 1× H100 for all three pilots. The segfault is after training + log output and does not affect the measurements. We have not root-caused it but it looks like a torch 2.8 vs. torch 2.9.1 (bigbag's target) mismatch in `base_model.load_state_dict(dequantize(...))` — possibly related to the GPT module being re-constructed after quantization. + +--- + +## Part 2 — H-Net Milestone 1 pilot (done) + M2–M4 proposal (grant-funded) + +The repo's Requests-for-PRs list in `README.md` still has **H-net tokenization** as an unchecked entry. This PR carries a working implementation of Milestone 1 (hierarchical byte-level stack with a fixed chunker) and proposes M2–M4 (learned chunker, full 16 MB recipe, ablations) as the target for an OpenAI dev grant. + +### 2.1 M1 pilot results + +Four runs on 1× H100 SXM, all at BYTE_SEQ_LEN=4096, CHUNK_STRIDE=4, BATCH_SIZE=8, AdamW (LR=1.5e-4, WD=0.01, bf16, `torch.compile(mode="default")`). All runs produced by the same code path (`hnet_m1/train_hnet_m1.py`) differing only in whether the byte-encoder→byte-decoder skip connection is present and how many training steps are executed. + +| run | steps | tokens (train) | skip | final train_loss | **val_bpb** | wallclock | tok/s | +|------------------------------|-----:|-------------:|:----:|---:|---:|---:|---:| +| `hnet_m1_pilot` | 300 | 10 M | ✗ | 3.13 | 4.49 | 11 s | 942 k | +| `hnet_m1_long` | 1 500 | 49 M | ✗ | 3.04 | 4.40 | 54 s | 919 k | +| `hnet_m1_skip` | 1 500 | 49 M | ✓ | 2.16 | 3.15 | 55 s | 904 k | +| **`hnet_m1_final`** | **4 500** | **147 M** | ✓ | **1.76** | **2.51** | **2.6 min** | 950 k | + +Random-byte baseline: `ln(256)/ln(2) = 8.00 bpb`. SP8192 baseline (`bigbag` PR #1493 full training, 4 550 steps × 786 K tokens ≈ 3.6 B tokens): 1.0810 bpb. + +Two decisive findings in the pilot: + +1. **Byte-encoder → byte-decoder skip connection is load-bearing.** Without it, the decoder has no per-byte fine-grained information — every 4 adjacent bytes share the same upsampled main-network output and must differentiate themselves from just that shared vector. Adding one `x_dec = x_dec + x_enc_final` line dropped val_bpb from **4.40 → 3.15** at matched training data (−1.25 bpb, 28% relative). +2. **Loss continues to decrease.** From 49 M → 147 M tokens (3× more data) val_bpb dropped 3.15 → 2.51 (−0.64 bpb). No plateau observed through step 4500. The curve suggests additional training data at the grant-funded scale would push further; a rough extrapolation at the observed decay rate projects ~1.5–1.7 bpb at 3.6 B tokens, which is within range of SP8192 baselines even with a fixed chunker. Learned chunking (M2) is the expected path to close the remaining gap. 
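Finding 1 in code form: a minimal sketch of the decoder-side path (tensor names, the helper itself, and the nearest-neighbour upsampling are assumptions; the real module is `hnet_m1/hnet_m1.py`), showing why the decoder is starved without the skip and what the one-line fix does.

```python
import torch

def decode_with_skip(main_out: torch.Tensor,
                     x_enc_final: torch.Tensor,
                     stride: int = 4,
                     use_skip: bool = True) -> torch.Tensor:
    # main_out:    (B, T_bytes // stride, D) main-network outputs already projected
    #              back to the byte-decoder width (main_to_dec).
    # x_enc_final: (B, T_bytes, D) final byte-encoder hidden states.
    # Nearest-neighbour upsampling copies each chunk vector to its `stride` bytes,
    # so without the skip every 4 adjacent byte positions see an identical input.
    x_dec = main_out.repeat_interleave(stride, dim=1)
    if use_skip:
        # The load-bearing one-liner from finding 1: re-inject per-byte detail.
        x_dec = x_dec + x_enc_final
    return x_dec  # fed to the byte decoder, then the 256-way byte head
```

The `stride=4` default mirrors CHUNK_STRIDE=4 used in all four M1 runs above.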
+ +**Model size**: 33.9 M parameters total, dominated by the main network at D=512 × 11 layers: + +| component | params | +|-------------------------------------------|----------------:| +| byte_emb (256 × 256) | 65 536 | +| byte_encoder (2 blocks × D=256) | 1 181 712 | +| enc_to_main projection | 131 072 | +| **main_blocks (11 × D=512)** | **31 742 040** | +| main_to_dec projection | 131 072 | +| byte_decoder (1 block × D=256) | 590 856 | +| final_norm | 0 | +| byte_head (256 × 256) | 65 536 | +| **total** | **33 907 824** | + +At 16 MB int6 quantization the main-network budget is roughly 20 M params; M3 will need to trim either main depth (11 → 7–8) or main width (512 → 384). The SP8192 → byte-level move frees ~2 MB of quantized-tokenizer budget that currently lives in the artifact, mitigating some of the main-network shrinkage. + +### 2.2 Why H-Net fits parameter-golf + +**H-Net** (Hwang et al., _Dynamic Chunking for End-to-End Hierarchical Sequence Modeling_, arXiv:2507.07955, Jul 2025) trains a byte-level language model end-to-end with a *learned* chunker — two linear projections compute cosine similarity between adjacent encoder outputs, producing a boundary probability `p_t` per position. Boundaries become the compressed representation passed to a main network; the routing decision is made differentiable via a straight-through estimator + EMA smoothing of chunk representations. A compression-ratio loss (with α = 0.03) regularizes the chunker toward a target rate. + +Three structural reasons this fits parameter-golf: + +1. **Tokenizer model becomes free code.** The baseline bundles `fineweb_8192_bpe.model` (363 KB) into the 16 MB artifact as a frozen SentencePiece file. An H-Net chunker is *~130 K parameters* (two `(D, D_k)` linear projections at the encoder width) which live in the standard quantized-weights payload; the tokenization policy ships as a few tens of lines of code. At D=256 and int6 quant, total chunker cost is <100 KB of the artifact, vs. 363 KB for SP8192. That's ~270 KB of budget freed for the main network. +2. **Compute-adaptive inference.** At eval, H-Net only activates the main network at predicted boundaries, so each step can choose how much compute to spend on a byte. Paper reports 3.5–4× effective compression on English (matching BPE rates) without the hard-coded tokenizer. For parameter-golf's sliding-window eval, this maps directly to a per-step compute ratio we can tune. +3. **Stronger data efficiency at weak-tokenization substructures.** The paper shows ~4× better data efficiency on DNA / code / non-Latin languages vs. BPE. FineWeb is web text with heavy code fragments — a regime where BPE arguably leaves compression on the table. + +### 2.3 Minimal viable design (used in the M1 pilot above) + +Detailed spec in `hnet_scope.md`. Outline: + +- **Byte encoder**: 2 thin layers × D=256 with partial RoPE + GQA (like current baseline, 1/4 width). +- **Dynamic chunker**: Wq, Wk ∈ ℝ^{256×256}, cosine-similarity boundary predictor + EMA smoothing + STE. ~130 K params. +- **Compressed representation**: select encoder outputs at boundary positions (paper's "direct vector selection" — beats mean/max/attn in their ablations). +- **Main network**: 7 layers × D=512 like the current baseline, operating on the compressed chunk stream. Inherits bigbag's stack (parallel residuals, depth recurrence, GPTQ+SDClip, Muon). +- **Byte decoder / upsampler**: 1 layer to project main-network outputs back to per-byte logits. 
+- **Losses**: byte-level autoregressive CE + α=0.03 compression-ratio regularizer targeting 3.5× compression (≈ SP8192 effective rate on FineWeb). +- **Sliding-window eval** at byte granularity, scoring per-byte NLL. + +Parameter budget (rough): +- Byte encoder: 2 × (4d² ≈ 256K + 2 × 256·1024 MLP ≈ 512K) ≈ 1.5 M params. +- Chunker: 130 K. +- Main network: 7 × 12d² ≈ 22 M params (similar to current main backbone). +- Byte decoder: 1 layer × similar-to-encoder ≈ 0.8 M. +- Byte embedding: 256 × 256 ≈ 66 K. (vs. 8192 × 256 ≈ 2.1 M for SP8192.) +- **Total ≈ 25–28 M params** — fits in the 16 MB int6 budget with headroom for the upsampler. + +### 2.4 Risk + +- **Byte-level context length bump.** Byte sequences are ~3.7× longer than SP8192 tokens (measured in our preprocessing: 3.73 bytes per SP8192 token on the FineWeb train shards). At BYTE_SEQ_LEN=4096 the main network sees 1024 chunks per sample — less context than bigbag's 2048-token baseline. For M3 we scale BYTE_SEQ_LEN to 8192 so the main network sees 2048 chunks at stride-4, matching baseline context, at the cost of 2× byte-encoder compute. +- **Two-stage joint optimization in M2.** M1 bypassed this risk with a fixed chunker. M2 adds Wq/Wk boundary projections + EMA smoothing + STE + ratio loss — all from the H-Net paper. Paper reports no collapse at 680 M params; at our 34 M (pilot) or 20-25 M (post-trim for M3) we may need α warmup or boundary-init tricks. Our M2 plan includes an explicit early-abort criterion if the chunker collapses. +- **Grad through straight-through estimator** can be noisy. Mitigated by the EMA smoothing path through chunk representations. + +### 2.5 Grant-funded milestone plan (M1 done, M2–M4 open) + +_Each milestone is gated — if signs of life fail at the gate, the grant pivots._ + +**Milestone 1 — DONE (self-funded, ~$2 of our quickstart credits)** +- Byte encoder + static chunker + main network + skip-to-decoder. +- Validated signs of life: val_bpb 4.49 → 2.51 over 300 → 4500 training steps. Skip connection responsible for 1.25 bpb of that gap. All data in §2.1. + +**Milestone 2 (≈ $120 grant GPU)**: Replace the static stride-4 chunker with the H-Net learned chunker (Wq + Wk cosine-similarity boundary predictor + EMA smoothing + STE + ratio loss targeting r≈3.5). Verify non-degenerate boundaries emerge (F ∈ [0.2, 0.4]) and val_bpb drops below M1's 2.51. Expected outcome: meaningfully below 2.0 bpb by matched-data-budget, demonstrating that content-aware chunking beats fixed-stride at this scale. + +**Milestone 3 (≈ $200 grant GPU)**: Full 16 MB submission. Trim the main network (11L → 7–8L or D=512 → 384) to fit the int6 budget, carry over bigbag's SP8192-stack fusions (parallel residuals, depth recurrence, MuonEq-R, GPTQ+SDClip, Brotli-11). 3-seed mean on 8× H100 SXM. Target: land as a creative non-record submission at < 1.5 bpb, or (aspirationally) as a record if the learned chunker compensates for the smaller main-network budget. + +**Milestone 4 (≈ $180 grant GPU)**: Ablations — compression ratio sweep (r ∈ {2, 3, 3.5, 4, 5}), byte encoder depth, chunker variants (cosine-sim vs small MLP). Published as an update to this PR. + +Total grant-funded spend **≈ $500** — matches the OpenAI dev grant amount. Explicit abort criteria in `hnet_scope.md`. + +--- + +## Artifacts in this PR + +| file | purpose | +|---|---| +| `README.md` | this writeup | +| `xsa_triton.py` | forward + backward Triton kernels for XSA, autograd.Function wrapper, torch reference. 
Numerical parity verified | +| `qkv_fuse.py` | QKV weight-stacking patch for `CausalSelfAttention`. Applies at instance level via monkey-patch. Muon-equivalence caveat documented inline | +| `phase3_run.py` | training-pilot wrapper. `exec()`s bigbag's baseline, intercepts `GPT.__init__` to apply patches post-construction, runs `main()` | +| `bench_scripts/phase1b.sh` | FA3 fallback patch + block microbench at target shape | +| `bench_scripts/phase2a.sh` | FA3 install probe + XSA kernel correctness + isolated microbench | +| `bench_scripts/phase2b.sh` | Block-level bench, FA3 baseline vs FA3 + Triton XSA | +| `bench_scripts/phase2c.sh` | Expanded grid over backend × XSA-impl × QKV-impl | +| `bench_scripts/phase2d.sh` | Drift-controlled grid with pre-compile + dual measurement | +| `bench_scripts/phase2e.sh` | `dynamo.reset()` between variants + thermal prime + CUDA-event timing + 300-sample distribution | +| `bench_scripts/phase3.sh` | Three 200-step training pilots | +| `bench_scripts/phase3b.sh` | Segfault-tolerant rerun + headline summary | +| `hnet_scope.md` | H-Net M2–M4 implementation sketch and milestone budget | +| `hnet_m1/hnet_m1.py` | HNetM1 model factory: byte_emb + encoder + fixed stride-4 chunker + main network + upsampler + byte-encoder-to-decoder skip + byte_decoder + head | +| `hnet_m1/make_byte_shards.py` | decodes SP8192 cached shards back to UTF-8 bytes (observed 3.73 bytes/SP8192-token on FineWeb) and writes byte shards in the baseline's on-disk format | +| `hnet_m1/train_hnet_m1.py` | M1 pilot training loop: AdamW, cosine warmdown, bf16+compile, per-byte val_bpb at end. All four logs in §2.1 came from this script | +| `hnet_m1/phase4.sh` | M1 pilot orchestration (preprocess once, then train) | + +## Reproducing + +On RunPod `pytorch:1.0.2-cu1281-torch280-ubuntu2404` (or equivalent torch 2.8 + CUDA 12.8), one-time bootstrap + both parts takes ~30 minutes and ≲ $2 of 1×H100 time. + +```bash +# one-time bootstrap: repo + deps + SP8192 data + FA3 wheel + unpack baseline +bash bootstrap.sh # ~5 min + +# Part 1 benchmarks (investigation of the systems-fusion ceiling) +bash bench_scripts/phase1b.sh # SDPA patch + block microbench +bash bench_scripts/phase2e.sh # robust drift-controlled benchmark (all 8 variants) +bash bench_scripts/phase3.sh # 3 × 200-step training pilots (≈30 min) + +# Part 2 M1 pilot (H-Net with fixed stride-4 chunker) +bash hnet_m1/phase4.sh # preprocess byte shards once, then 300-step default pilot + +# Reproduce the specific Part 2 runs in §2.1: +ITERATIONS=300 LR=3e-4 RUN_ID=hnet_m1_pilot bash hnet_m1/phase4.sh # no skip (earlier code) +ITERATIONS=1500 LR=1.5e-4 RUN_ID=hnet_m1_long bash hnet_m1/phase4.sh # no skip, 1500 steps +ITERATIONS=1500 LR=1.5e-4 RUN_ID=hnet_m1_skip bash hnet_m1/phase4.sh # with skip (current code) +ITERATIONS=4500 LR=1.5e-4 WARMDOWN_FRAC=0.2 RUN_ID=hnet_m1_final bash hnet_m1/phase4.sh +``` + +Logs land in `/workspace/logs/${RUN_ID}.txt` with step-level train_loss and tok/s, final per-byte `val_nll` and `val_bpb`. + +## Credits and prior art + +- Baseline: PR #1493 (@bigbag) — SP8192 + 3-layer recurrence + parallel residuals + QK-Gain 5.25 + legal TTT stack. +- FlashAttention-3 binaries from [`windreamer.github.io/flash-attention3-wheels`](https://windreamer.github.io/flash-attention3-wheels/) (cu128_torch280). +- H-Net: Hwang et al., _Dynamic Chunking for End-to-End Hierarchical Sequence Modeling_, arXiv:2507.07955. 
+- Scaling-law context: Kaplan et al., _Scaling Laws for Neural Language Models_, arXiv:2001.08361 — the L(N) framing of the parameter-golf challenge. +- Prior systems-PR precedent: PR #1105 (@abaybektursun) Fused MLP (Triton+CUTLASS EVT); PR #1447 (@shram86) FlashMuon. diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0.sh new file mode 100644 index 0000000000..2afd02d3da --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# Parameter Golf — Phase 0: env sanity + clone + deps + data + fork baseline. +# Run on the RunPod 1xH100 pod (pytorch:1.0.2-cu1281-torch280-ubuntu2404). +# Usage: bash phase0.sh 2>&1 | tee phase0.log + +set -euo pipefail +cd /workspace +echo "=== PHASE 0 START ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# 1. env sanity ------------------------------------------------------------- +echo "--- env ---" +python - <<'PY' +import sys, torch, triton +print("python", sys.version.split()[0]) +print("torch", torch.__version__, "cuda", torch.version.cuda, "triton", triton.__version__) +print("gpu", torch.cuda.get_device_name(0), "bf16", torch.cuda.is_bf16_supported()) +print("sm", torch.cuda.get_device_capability(0)) +PY +df -h /workspace | tail -1 + +# 2. clone repo ------------------------------------------------------------- +echo "--- clone ---" +if [ ! -d parameter-golf ]; then + git clone --depth 1 https://github.com/openai/parameter-golf.git +fi +cd parameter-golf +git log -1 --oneline + +# 3. install deps ----------------------------------------------------------- +echo "--- pip ---" +pip install -q --no-input brotli sentencepiece huggingface-hub datasets tqdm 2>&1 | tail -3 +python -c "import brotli, sentencepiece, huggingface_hub; print('brotli', brotli.__version__); print('sp', sentencepiece.__version__); print('hf_hub', huggingface_hub.__version__)" + +# 4. download SP8192 smoke subset (2 train shards + full val) --------------- +echo "--- data ---" +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 2 2>&1 | tail -15 +ls -lh data/datasets/fineweb10B_sp8192/ | head -8 +ls -lh data/tokenizers/ | head -8 + +# 5. fork bigbag's top record as our working baseline ----------------------- +echo "--- fork baseline ---" +mkdir -p /workspace/work +cp -v records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/train_gpt.py \ + /workspace/work/train_gpt_baseline.py +wc -l /workspace/work/train_gpt_baseline.py +md5sum /workspace/work/train_gpt_baseline.py + +echo "=== PHASE 0 DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0b.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0b.sh new file mode 100644 index 0000000000..7539820014 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase0b.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Phase 0b: unpack bigbag's LZMA-compressed train_gpt.py into readable source. +# Expects unpack.py in the current directory (uploaded alongside this script). 
+set -euo pipefail + +REC=records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT +cd /workspace/parameter-golf + +echo "--- unpack baseline ---" +python /workspace/unpack.py "$REC/train_gpt.py" /workspace/work/train_gpt_baseline.py + +echo "--- readable? first 30 lines ---" +head -30 /workspace/work/train_gpt_baseline.py + +echo "--- stats ---" +wc -l /workspace/work/train_gpt_baseline.py +md5sum /workspace/work/train_gpt_baseline.py + +echo "--- does it at least import cleanly? ---" +# guard against accidental top-level training code +python - <<'PY' +import ast, pathlib +src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text() +tree = ast.parse(src) +defs = [n.name for n in tree.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))] +print(f"top-level defs: {len(defs)}") +print("sample:", defs[:20]) +PY + +echo "=== PHASE 0b DONE ===" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1a.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1a.sh new file mode 100644 index 0000000000..1c2566ae4a --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1a.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Phase 1a: check FA3 availability + show the hot classes (Block, CausalSelfAttention, MLP, GPT). +set -euo pipefail +cd /workspace/parameter-golf + +echo "--- is FA3 importable? ---" +python - <<'PY' || true +try: + import flash_attn_interface as fa3 + print("FA3 OK", getattr(fa3, "__version__", "?"), "module:", fa3.__file__) + from flash_attn_interface import flash_attn_func + print("flash_attn_func sig:", flash_attn_func.__doc__[:200] if flash_attn_func.__doc__ else "(no doc)") +except Exception as e: + print("FA3 MISSING:", type(e).__name__, e) +PY + +echo +echo "--- installed flash-attn-ish packages ---" +pip list 2>/dev/null | grep -i -E 'flash|attn' || echo "(none)" + +echo +echo "--- extract the hot classes for inspection ---" +python - <<'PY' +import ast, pathlib, textwrap +src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text() +tree = ast.parse(src) +wanted = {"CausalSelfAttention", "MLP", "Block", "GPT"} +lines = src.splitlines() +for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name in wanted: + start, end = node.lineno - 1, node.end_lineno + body = "\n".join(lines[start:end]) + print(f"\n=== class {node.name} @ lines {node.lineno}-{node.end_lineno} ({end-start} lines) ===") + print(body) +PY + +echo +echo "--- what does the forward pass look like? 
show GPT.forward ---" +python - <<'PY' +import ast, pathlib +src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text() +tree = ast.parse(src) +lines = src.splitlines() +for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == "GPT": + for m in node.body: + if isinstance(m, ast.FunctionDef) and m.name == "forward": + start, end = m.lineno - 1, m.end_lineno + print("\n".join(lines[start:end])) +PY + +echo "=== PHASE 1a DONE ===" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1b.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1b.sh new file mode 100644 index 0000000000..3fae964e88 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase1b.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +# Phase 1b: FA3->SDPA patch + RMSNorm inspection + Block microbench (eager & compiled). +set -euo pipefail +cd /workspace + +echo "=== PHASE 1b: patch FA3 + microbench Block ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# 1. Patch the baseline: FA3 import -> SDPA shim ---------------------------- +python - <<'PY' +import re, pathlib +src = pathlib.Path("/workspace/work/train_gpt_baseline.py").read_text() + +shim = ( + "# --- patched: FA3 -> SDPA shim ---\n" + "def flash_attn_3_func(q, k, v, causal=False):\n" + " import torch.nn.functional as _F\n" + " gqa = q.size(-2) != k.size(-2)\n" + " y = _F.scaled_dot_product_attention(\n" + " q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),\n" + " is_causal=causal, enable_gqa=gqa,\n" + " )\n" + " return y.transpose(1, 2)\n" +) + +pat = r'from flash_attn_interface import flash_attn_func as flash_attn_3_func\n' +src_new, n = re.subn(pat, shim, src, count=1) +assert n == 1, f"Expected exactly 1 FA3 import, got {n}" +pathlib.Path("/workspace/work/train_gpt_patched.py").write_text(src_new) +print(f"Patched: {len(src)} -> {len(src_new)} bytes, lines: {src_new.count(chr(10))+1}") +PY + +# 2. RMSNorm class (confirm parameter-free) --------------------------------- +echo +echo "--- RMSNorm definition ---" +python - <<'PY' +import ast, pathlib +src = pathlib.Path("/workspace/work/train_gpt_patched.py").read_text() +tree = ast.parse(src) +lines = src.splitlines() +for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == "RMSNorm": + print("\n".join(lines[node.lineno-1:node.end_lineno])) + # count nn.Parameter usages inside the class + n_params = sum(1 for n in ast.walk(node) + if isinstance(n, ast.Attribute) and n.attr == "Parameter") + print(f"\n[RMSNorm has {n_params} nn.Parameter usage(s) in-class]") +PY + +# 3. 
Microbenchmark the parallel-residual block ----------------------------- +echo +echo "--- Block microbench (eager & compiled) ---" +python - <<'PY' +import os, time, warnings, torch +warnings.filterwarnings("ignore") + +# Exec patched source in a non-main namespace so main() doesn't auto-run +src = open("/workspace/work/train_gpt_patched.py").read() +ns = {"__name__": "pg_patched"} +try: + exec(compile(src, "train_gpt_patched.py", "exec"), ns) +except SystemExit: + pass # just in case main() has a sys.exit +Block = ns["Block"] + +device = torch.device("cuda") +dtype = torch.bfloat16 +torch.manual_seed(0) + +# Bigbag's hyperparams for the hot-path shapes +B, T, D = 8, 2048, 512 +H, KVH = 8, 4 +MLP_MULT = 4.0 + +def build_block(parallel, use_xsa, layer_idx=7): + blk = Block( + dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=MLP_MULT, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + layer_idx=layer_idx, ln_scale=True, + ).to(device).to(dtype) + # Restore fp32 for the scalar/control params (baseline trick) + for p in blk.parameters(): + if p.ndim < 2: + p.data = p.data.float() + blk.parallel = parallel + blk.attn.use_xsa = use_xsa + if hasattr(blk.attn, "rope_dims"): + # rope_dims=16 per hparams (partial RoPE). Rotary was built with rope_dims=0 in + # __init__; baseline re-creates it in GPT.__init__ when rope_dims>0. We redo here. + Rotary = ns["Rotary"] + blk.attn.rope_dims = 16 + blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) + return blk + +def bench(blk_fn, n_warmup=10, n_iter=50, label=""): + x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + # warmup (also triggers compile if applicable) + for _ in range(n_warmup): + y = blk_fn(x, x0) + y.sum().backward() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + y = blk_fn(x, x0) + y.sum().backward() + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 / n_iter + +print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH} MLP_MULT={MLP_MULT} dtype={dtype}") +print(f"{'variant':<48} {'ms/iter':>10}") +print("-" * 60) + +for parallel in (False, True): + for xsa in (False, True): + blk = build_block(parallel, xsa) + ms_eager = bench(blk, label=f"eager p={parallel} xsa={xsa}") + print(f" eager parallel={parallel!s:<5} xsa={xsa!s:<5} {ms_eager:>10.3f}") + +# Compiled variant only for the realistic target config +target = build_block(parallel=True, use_xsa=True) +target_compiled = torch.compile(target, fullgraph=True, dynamic=False, mode="max-autotune-no-cudagraphs") +ms_compiled = bench(target_compiled, n_warmup=20, n_iter=50) +print(f" compiled parallel=True xsa=True {ms_compiled:>10.3f}") + +# 4. 
Profile the compiled target block +print() +print("--- torch.profiler: compiled parallel+xsa block, 20 fwd+bwd iters ---") +x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) +x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) +for _ in range(5): + y = target_compiled(x, x0); y.sum().backward() +torch.cuda.synchronize() +with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU], + record_shapes=False, +) as prof: + for _ in range(20): + y = target_compiled(x, x0); y.sum().backward() + torch.cuda.synchronize() +print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=25)) +PY + +echo "=== PHASE 1b DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2a.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2a.sh new file mode 100644 index 0000000000..245c62af3d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2a.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash +# Phase 2a: +# 1. Try to install FA3 wheel (cu128, various torch versions). +# 2. Verify the XSA Triton kernel: numerical correctness fwd+bwd vs torch reference. +# 3. Microbench XSA kernel vs torch.compile'd reference. +# +# Needs /workspace/xsa_triton.py uploaded alongside this script. +set -euo pipefail +cd /workspace +mkdir -p work +cp -f xsa_triton.py work/xsa_triton.py + +echo "=== PHASE 2a ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# 1. FA3 install attempts --------------------------------------------------- +echo +echo "--- FA3 install attempts ---" +FA3_OK=0 +for TORCH_TAG in cu128_torch280 cu128_torch281 cu128_torch290 cu128_torch291; do + URL="https://windreamer.github.io/flash-attention3-wheels/${TORCH_TAG}/" + echo " trying $URL" + if pip install --quiet --no-deps flash_attn_3 --find-links "$URL" 2>&1 | tail -3; then + if python -c "import flash_attn_interface" 2>/dev/null; then + echo " FA3 installed OK from $TORCH_TAG" + FA3_OK=1 + break + else + pip uninstall -y flash_attn_3 2>/dev/null || true + fi + fi +done +if [ "$FA3_OK" = "0" ]; then + echo " FA3 install FAILED; continuing with SDPA fallback" +fi + +python - <<'PY' +try: + import flash_attn_interface as fa3 + print("FA3 available:", getattr(fa3, "__version__", "?"), fa3.__file__) +except Exception as e: + print("FA3 NOT available:", type(e).__name__, e) +PY + +# 2. 
Numerical correctness -------------------------------------------------- +echo +echo "--- XSA numerical check (fwd + bwd) ---" +python - <<'PY' +import sys, torch +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton, xsa_torch + +torch.manual_seed(0) +device = torch.device("cuda") + +# Multiple shape configs +configs = [ + # (B, T, H, Hkv, D) + (2, 128, 8, 4, 64), # tiny + (8, 2048, 8, 4, 64), # realistic (bigbag) + (4, 1024, 4, 4, 64), # no GQA (group=1) + (2, 512, 16, 4, 64), # group=4 +] + +for dtype in (torch.float32, torch.bfloat16): + print(f"\ndtype = {dtype}") + print(f"{'shape':<30} {'fwd_max':>12} {'gy_max':>12} {'gv_max':>12}") + for B, T, H, Hkv, D in configs: + y = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(B, T, Hkv, D, device=device, dtype=dtype, requires_grad=True) + y2 = y.detach().clone().requires_grad_() + v2 = v.detach().clone().requires_grad_() + + out_t = xsa_triton(y, v) + out_r = xsa_torch(y2, v2) + fwd_err = (out_t - out_r).abs().max().item() + + grad_out = torch.randn_like(out_t) + out_t.backward(grad_out) + out_r.backward(grad_out) + gy_err = (y.grad - y2.grad).abs().max().item() + gv_err = (v.grad - v2.grad).abs().max().item() + print(f" B={B:<2} T={T:<5} H={H:<2} Hkv={Hkv:<2} D={D:<3} {fwd_err:>12.3e} {gy_err:>12.3e} {gv_err:>12.3e}") +PY + +# 3. Isolated microbench ---------------------------------------------------- +echo +echo "--- XSA microbench (isolated op, B=8 T=2048 H=8 Hkv=4 D=64, bf16) ---" +python - <<'PY' +import sys, time, torch +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton, xsa_torch + +device = torch.device("cuda") +dtype = torch.bfloat16 +B, T, H, Hkv, D = 8, 2048, 8, 4, 64 + +y = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True) +v = torch.randn(B, T, Hkv, D, device=device, dtype=dtype, requires_grad=True) + +# torch reference (eager) +def run_torch(fwd_only=False): + y_ = y.detach().clone().requires_grad_() + v_ = v.detach().clone().requires_grad_() + out = xsa_torch(y_, v_) + if fwd_only: + return + out.sum().backward() + +# torch reference (compiled) +xsa_torch_c = torch.compile(xsa_torch, fullgraph=True, dynamic=False) +def run_torch_compiled(fwd_only=False): + y_ = y.detach().clone().requires_grad_() + v_ = v.detach().clone().requires_grad_() + out = xsa_torch_c(y_, v_) + if fwd_only: + return + out.sum().backward() + +def run_triton(fwd_only=False): + y_ = y.detach().clone().requires_grad_() + v_ = v.detach().clone().requires_grad_() + out = xsa_triton(y_, v_) + if fwd_only: + return + out.sum().backward() + +def bench(fn, n_warmup=20, n_iter=200, label=""): + for _ in range(n_warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1e6 / n_iter # us per call + +print(f"{'impl':<24} {'fwd us':>10} {'fwd+bwd us':>12}") +print("-" * 50) +for label, fn in [("torch_eager", run_torch), + ("torch_compiled", run_torch_compiled), + ("triton", run_triton)]: + us_fwd = bench(lambda: fn(fwd_only=True), label=label + "_fwd") + us_full = bench(fn, label=label + "_full") + print(f"{label:<24} {us_fwd:>10.1f} {us_full:>12.1f}") +PY + +echo "=== PHASE 2a DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2b.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2b.sh new file mode 100644 index 
0000000000..e6788006d5 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2b.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash +# Phase 2b: +# (a) Refresh xsa_triton.py (now has Triton backward). +# (b) Re-run numerical + isolated microbench. +# (c) Block-level bench: FA3 baseline (torch XSA) vs FA3 + Triton XSA, both compiled. +# +# Needs refreshed /workspace/xsa_triton.py uploaded alongside this script. +set -euo pipefail +cd /workspace +cp -f xsa_triton.py work/xsa_triton.py + +echo "=== PHASE 2b ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# (a) quick sanity: FA3 still importable? +python - <<'PY' +import flash_attn_interface as fa3 +print("FA3:", fa3.__file__) +PY + +# (b) numerics + isolated bench --------------------------------------------- +echo +echo "--- XSA numerics (fp32 / bf16) ---" +python - <<'PY' +import sys, torch +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton, xsa_torch + +torch.manual_seed(0) +device = torch.device("cuda") +for dtype in (torch.float32, torch.bfloat16): + print(f"dtype={dtype}") + for B, T, H, Hkv, D in [(2,128,8,4,64),(8,2048,8,4,64),(4,1024,4,4,64),(2,512,16,4,64)]: + y = torch.randn(B,T,H,D,device=device,dtype=dtype,requires_grad=True) + v = torch.randn(B,T,Hkv,D,device=device,dtype=dtype,requires_grad=True) + y2 = y.detach().clone().requires_grad_(); v2 = v.detach().clone().requires_grad_() + out_t = xsa_triton(y, v); out_r = xsa_torch(y2, v2) + fwd_err = (out_t - out_r).abs().max().item() + go = torch.randn_like(out_t) + out_t.backward(go); out_r.backward(go) + gy_err = (y.grad - y2.grad).abs().max().item() + gv_err = (v.grad - v2.grad).abs().max().item() + print(f" B={B:<2} T={T:<5} H={H:<2} Hkv={Hkv:<2} D={D}: fwd={fwd_err:.3e} gy={gy_err:.3e} gv={gv_err:.3e}") +PY + +echo +echo "--- XSA isolated microbench (B=8 T=2048 H=8 Hkv=4 D=64 bf16) ---" +python - <<'PY' +import sys, time, torch +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton, xsa_torch + +device = torch.device("cuda"); dtype = torch.bfloat16 +B,T,H,Hkv,D = 8,2048,8,4,64 +y = torch.randn(B,T,H,D,device=device,dtype=dtype,requires_grad=True) +v = torch.randn(B,T,Hkv,D,device=device,dtype=dtype,requires_grad=True) + +def make_bench(fn_fwd): + def run(fwd_only=False): + y_ = y.detach().clone().requires_grad_() + v_ = v.detach().clone().requires_grad_() + out = fn_fwd(y_, v_) + if not fwd_only: + out.sum().backward() + return run + +run_eager = make_bench(xsa_torch) +xsa_torch_c = torch.compile(xsa_torch, fullgraph=True, dynamic=False) +run_compiled = make_bench(xsa_torch_c) +run_triton = make_bench(xsa_triton) + +def bench(fn, n_warmup=30, n_iter=300): + for _ in range(n_warmup): fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): fn() + torch.cuda.synchronize() + return (time.perf_counter()-t0)*1e6/n_iter + +print(f"{'impl':<22} {'fwd us':>10} {'fwd+bwd us':>12}") +print("-"*48) +for lbl, fn in [("torch_eager",run_eager),("torch_compiled",run_compiled),("triton",run_triton)]: + f = bench(lambda: fn(fwd_only=True)) + t = bench(fn) + print(f"{lbl:<22} {f:>10.1f} {t:>12.1f}") +PY + +# (c) Block-level microbench (FA3 baseline vs FA3+Triton XSA) --------------- +echo +echo "--- Block microbench: FA3 baseline vs FA3 + Triton XSA ---" +python - <<'PY' +import os, sys, time, types, torch, warnings +warnings.filterwarnings("ignore") +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton as xsa_triton_fn + +# Load the FA3-enabled baseline (now that FA3 
is importable) +src = open("/workspace/work/train_gpt_baseline.py").read() +ns = {"__name__": "pg_baseline"} +exec(compile(src, "train_gpt_baseline.py", "exec"), ns) +Block, Rotary = ns["Block"], ns["Rotary"] + +device = torch.device("cuda"); dtype = torch.bfloat16 +torch.manual_seed(0) +B, T, D = 8, 2048, 512 +H, KVH = 8, 4 + +def build_block(use_triton_xsa, layer_idx=7): + blk = Block( + dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=4.0, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + layer_idx=layer_idx, ln_scale=True, + ).to(device).to(dtype) + for p in blk.parameters(): + if p.ndim < 2: p.data = p.data.float() + blk.parallel = True + blk.attn.use_xsa = True + blk.attn.rope_dims = 16 + blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) + if use_triton_xsa: + def _xsa_triton_method(self, y, v): return xsa_triton_fn(y, v) + blk.attn._xsa_efficient = types.MethodType(_xsa_triton_method, blk.attn) + return blk + +def bench_block(blk_callable, n_warmup=20, n_iter=100): + x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + for _ in range(n_warmup): + y = blk_callable(x, x0); y.sum().backward() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + y = blk_callable(x, x0); y.sum().backward() + torch.cuda.synchronize() + return (time.perf_counter()-t0) * 1000.0 / n_iter + +print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH} dtype={dtype}") +print(f"{'variant':<42} {'ms/iter':>10}") +print("-"*56) + +for label, use_triton, use_compile in [ + ("eager FA3 torch-XSA", False, False), + ("eager FA3 Triton-XSA", True, False), + ("compiled FA3 torch-XSA", False, True), + ("compiled FA3 Triton-XSA", True, True), +]: + blk = build_block(use_triton_xsa=use_triton) + fn = torch.compile(blk, dynamic=False, mode="max-autotune-no-cudagraphs") if use_compile else blk + try: + ms = bench_block(fn) + print(f" {label:<40} {ms:>10.3f}") + except Exception as e: + print(f" {label:<40} FAIL: {type(e).__name__}: {str(e)[:80]}") +PY + +echo "=== PHASE 2b DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2c.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2c.sh new file mode 100644 index 0000000000..71d8487e33 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2c.sh @@ -0,0 +1,225 @@ +#!/usr/bin/env bash +# Phase 2c: +# (a) Inspect FA3 module + standalone FA3-vs-SDPA attention bench. +# (b) Add QKV fusion (stack c_q/c_k/c_v weights into single GEMM) via a helper. +# (c) Block-level bench across the full cross product: +# {eager, compiled} x {FA3, SDPA} x {torch-XSA, Triton-XSA} x {3-linear, fused-QKV} +# +# Depends on /workspace/work/xsa_triton.py (already there) and FA3 installed. 
+set -euo pipefail +cd /workspace +echo "=== PHASE 2c ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# (a) FA3 module inspection + attention backend bench ---------------------- +echo +echo "--- FA3 module inspection ---" +python - <<'PY' +import flash_attn_interface as fa3, inspect, pathlib +print("file:", fa3.__file__) +print("attrs:", [a for a in dir(fa3) if not a.startswith('_')][:25]) +# show full source (it's a Python wrapper; only a few KB) +src = pathlib.Path(fa3.__file__).read_text() +print(f"source bytes: {len(src)}") +print("--- source (first 80 lines) ---") +print("\n".join(src.splitlines()[:80])) +PY + +echo +echo "--- FA3 vs SDPA attention bench (B=8 T=2048 H=8 Hkv=4 D=64 bf16) ---" +python - <<'PY' +import time, torch, torch.nn.functional as F +from flash_attn_interface import flash_attn_func as fa3 + +device = torch.device("cuda"); dtype = torch.bfloat16 +B, T, H, Hkv, D = 8, 2048, 8, 4, 64 +q = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True) +k = torch.randn(B, T, Hkv, D, device=device, dtype=dtype, requires_grad=True) +v = torch.randn(B, T, Hkv, D, device=device, dtype=dtype, requires_grad=True) + +def fa3_fwd(fwd_only=False): + q.grad = k.grad = v.grad = None + out = fa3(q, k, v, causal=True) + if not fwd_only: + out.sum().backward() + +def sdpa_fwd(fwd_only=False): + q.grad = k.grad = v.grad = None + out = F.scaled_dot_product_attention( + q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), + is_causal=True, enable_gqa=True, + ).transpose(1,2) + if not fwd_only: + out.sum().backward() + +def bench(fn, n_warmup=30, n_iter=300): + for _ in range(n_warmup): fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1e6 / n_iter + +print(f"{'backend':<12} {'fwd us':>10} {'fwd+bwd us':>12}") +print("-"*36) +for lbl, fn in [("FA3", fa3_fwd), ("SDPA", sdpa_fwd)]: + f = bench(lambda: fn(fwd_only=True)) + t = bench(fn) + print(f"{lbl:<12} {f:>10.1f} {t:>12.1f}") +PY + +# (b) QKV fusion helper + (c) block benchmark ------------------------------ +echo +echo "--- Block bench: all variants ---" +python - <<'PY' +import os, sys, time, types, torch, torch.nn.functional as F, warnings +warnings.filterwarnings("ignore") +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton as xsa_triton_fn + +# Load FA3 baseline module +src = open("/workspace/work/train_gpt_baseline.py").read() +ns = {"__name__": "pg_baseline"} +exec(compile(src, "train_gpt_baseline.py", "exec"), ns) +Block = ns["Block"] +Rotary = ns["Rotary"] +CastedLinear = ns["CastedLinear"] +apply_rotary_emb = ns["apply_rotary_emb"] +fa3_func = ns["flash_attn_3_func"] # imported from flash_attn_interface + +device = torch.device("cuda"); dtype = torch.bfloat16 +torch.manual_seed(0) +B, T, D = 8, 2048, 512 +H, KVH = 8, 4 + +# ----- backend switches ---------------------------------------------------- +def sdpa_shim(q, k, v, causal=False): + gqa = q.size(-2) != k.size(-2) + return F.scaled_dot_product_attention( + q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), + is_causal=causal, enable_gqa=gqa, + ).transpose(1,2) + +# ----- fused-QKV forward -------------------------------------------------- +def fused_qkv_forward(self, x): + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x) + q_dim = self.num_heads * self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k.reshape(bsz, 
seqlen, self.num_kv_heads, self.head_dim) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = self._attn_backend(q, k, v, causal=True) + if self.use_xsa: y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, y.size(-2) * y.size(-1)) + return self.proj(y) + +def fuse_qkv_into_attn(attn): + dim = attn.c_q.weight.size(1) + q_dim = attn.num_heads * attn.head_dim + kv_dim = attn.num_kv_heads * attn.head_dim + out_dim = q_dim + 2 * kv_dim + c_qkv = CastedLinear(dim, out_dim, bias=False) + c_qkv = c_qkv.to(attn.c_q.weight.device).to(attn.c_q.weight.dtype) + with torch.no_grad(): + c_qkv.weight.copy_(torch.cat([attn.c_q.weight, + attn.c_k.weight, + attn.c_v.weight], dim=0)) + attn.c_qkv = c_qkv + del attn.c_q, attn.c_k, attn.c_v + attn.forward = types.MethodType(fused_qkv_forward, attn) + +# ----- build block --------------------------------------------------------- +def build_block(use_triton_xsa, fused_qkv, backend="FA3"): + blk = Block( + dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=4.0, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + layer_idx=7, ln_scale=True, + ).to(device).to(dtype) + for p in blk.parameters(): + if p.ndim < 2: p.data = p.data.float() + blk.parallel = True + blk.attn.use_xsa = True + blk.attn.rope_dims = 16 + blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) + # bind chosen attention backend on the module + if backend == "FA3": + blk.attn._attn_backend = staticmethod(fa3_func) + else: + blk.attn._attn_backend = staticmethod(sdpa_shim) + # also override the original attn.forward to route through _attn_backend for the + # 3-linear path (baseline uses flash_attn_3_func directly; we shim it in place) + if not fused_qkv: + # patch the existing forward to use _attn_backend instead of flash_attn_3_func + orig_forward = blk.attn.forward + # we need a new forward that uses self._attn_backend in place of flash_attn_3_func + def three_linear_forward(self, x): + bsz, seqlen, dim_ = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = self._attn_backend(q, k, v, causal=True) + if self.use_xsa: y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim_); return self.proj(y) + blk.attn.forward = types.MethodType(three_linear_forward, blk.attn) + else: + fuse_qkv_into_attn(blk.attn) + if use_triton_xsa: + def _xsa_tr(self, y, v): return xsa_triton_fn(y, v) + blk.attn._xsa_efficient = types.MethodType(_xsa_tr, blk.attn) + return blk + +# ----- bench wrapper ------------------------------------------------------- +def bench_block(callable_, n_warmup=20, n_iter=80): + x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + for _ in range(n_warmup): + x.grad = None; x0.grad = None + y = 
callable_(x, x0); y.sum().backward() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + x.grad = None; x0.grad = None + y = callable_(x, x0); y.sum().backward() + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 / n_iter + +# ----- grid ---------------------------------------------------------------- +grid = [] +for backend in ("FA3", "SDPA"): + for triton_xsa in (False, True): + for fused in (False, True): + for compile_ in (False, True): + grid.append((backend, triton_xsa, fused, compile_)) + +print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH} dtype={dtype}") +print(f"{'backend':<5} {'xsa':<8} {'qkv':<9} {'mode':<9} {'ms/iter':>10}") +print("-" * 50) +for backend, triton_xsa, fused, compile_ in grid: + blk = build_block(use_triton_xsa=triton_xsa, fused_qkv=fused, backend=backend) + callable_ = torch.compile(blk, dynamic=False, mode="max-autotune-no-cudagraphs") if compile_ else blk + try: + ms = bench_block(callable_) + xsa_lbl = "triton" if triton_xsa else "torch" + qkv_lbl = "fused" if fused else "3-lin" + mode_lbl = "compiled" if compile_ else "eager" + print(f"{backend:<5} {xsa_lbl:<8} {qkv_lbl:<9} {mode_lbl:<9} {ms:>10.3f}") + except Exception as e: + print(f"{backend} xsa={triton_xsa} fused={fused} compile={compile_}: FAIL {type(e).__name__}: {str(e)[:80]}") +PY + +echo "=== PHASE 2c DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2d.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2d.sh new file mode 100644 index 0000000000..41c36e61a4 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2d.sh @@ -0,0 +1,240 @@ +#!/usr/bin/env bash +# Phase 2d: +# - Fix FA3 staticmethod bug (bind backend via closure, not staticmethod). +# - Verify QKV-fusion numerical equivalence. +# - Robust benchmark: compile ALL variants first, warm each, then measure +# each TWICE in a fixed order to detect drift. Focus on compiled numbers. 
+set -euo pipefail +cd /workspace +echo "=== PHASE 2d ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +python - <<'PY' +import os, sys, time, types, warnings, torch, torch.nn.functional as F +warnings.filterwarnings("ignore") +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton as xsa_triton_fn + +# Load FA3 baseline module +src = open("/workspace/work/train_gpt_baseline.py").read() +ns = {"__name__": "pg_baseline"} +exec(compile(src, "train_gpt_baseline.py", "exec"), ns) +Block = ns["Block"] +Rotary = ns["Rotary"] +CastedLinear = ns["CastedLinear"] +apply_rotary_emb = ns["apply_rotary_emb"] +fa3_func = ns["flash_attn_3_func"] + +device = torch.device("cuda"); dtype = torch.bfloat16 +torch.manual_seed(0) + +B, T, D = 8, 2048, 512 +H, KVH = 8, 4 + +# -------------------------------------------------------------------------- +# Backend functions (plain callables, no staticmethod) +# -------------------------------------------------------------------------- +def sdpa_backend(q, k, v, causal=False): + gqa = q.size(-2) != k.size(-2) + return F.scaled_dot_product_attention( + q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), + is_causal=causal, enable_gqa=gqa, + ).transpose(1,2) + +def fa3_backend(q, k, v, causal=False): + return fa3_func(q, k, v, causal=causal) + +# -------------------------------------------------------------------------- +# Forwards: closures capture backend & xsa functions (avoids staticmethod issue) +# -------------------------------------------------------------------------- +def make_three_linear_forward(attn_backend, use_triton_xsa): + def forward(self, x): + bsz, seqlen, dim_ = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn_backend(q, k, v, causal=True) + if self.use_xsa: + y = xsa_triton_fn(y, v) if use_triton_xsa else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim_); return self.proj(y) + return forward + +def make_fused_qkv_forward(attn_backend, use_triton_xsa): + def forward(self, x): + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x) + q_dim = self.num_heads * self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn_backend(q, k, v, causal=True) + if self.use_xsa: + y = xsa_triton_fn(y, v) if use_triton_xsa else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, y.size(-2) * y.size(-1)) + return self.proj(y) + return forward + +def fuse_qkv_weights(attn): + dim = attn.c_q.weight.size(1) + q_dim = attn.num_heads * attn.head_dim + kv_dim = attn.num_kv_heads * attn.head_dim + out_dim = q_dim + 2 * kv_dim + c_qkv = CastedLinear(dim, out_dim, 
bias=False) + c_qkv = c_qkv.to(attn.c_q.weight.device).to(attn.c_q.weight.dtype) + with torch.no_grad(): + c_qkv.weight.copy_(torch.cat([attn.c_q.weight, + attn.c_k.weight, + attn.c_v.weight], dim=0)) + attn.c_qkv = c_qkv + del attn.c_q, attn.c_k, attn.c_v + +# -------------------------------------------------------------------------- +# Numerical equivalence check: fused-QKV vs 3-linear with identical weights +# -------------------------------------------------------------------------- +print("--- QKV fusion numerical check ---") +torch.manual_seed(0) +blk3 = Block(dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=4.0, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + layer_idx=7, ln_scale=True).to(device).to(dtype) +blk3.parallel = True; blk3.attn.use_xsa = True +blk3.attn.rope_dims = 16 +blk3.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) +blk3.attn.forward = types.MethodType( + make_three_linear_forward(sdpa_backend, use_triton_xsa=False), blk3.attn) + +# Mirror blk3 into a fused-QKV copy +import copy as _copy +blkF = _copy.deepcopy(blk3) +fuse_qkv_weights(blkF.attn) +blkF.attn.forward = types.MethodType( + make_fused_qkv_forward(sdpa_backend, use_triton_xsa=False), blkF.attn) + +x = torch.randn(2, 128, D, device=device, dtype=dtype) +x0 = torch.randn(2, 128, D, device=device, dtype=dtype) +with torch.no_grad(): + y3 = blk3(x, x0) + yF = blkF(x, x0) +fwd_err = (y3 - yF).abs().max().item() +print(f" fwd max err (3-lin vs fused): {fwd_err:.3e} (bf16 tol ~1e-2)") + +# -------------------------------------------------------------------------- +# Build + compile all variants, then bench each twice +# -------------------------------------------------------------------------- +def build_variant(backend_name, use_triton_xsa, fused_qkv): + backend = fa3_backend if backend_name == "FA3" else sdpa_backend + torch.manual_seed(0) + blk = Block(dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=4.0, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + layer_idx=7, ln_scale=True).to(device).to(dtype) + for p in blk.parameters(): + if p.ndim < 2: p.data = p.data.float() + blk.parallel = True; blk.attn.use_xsa = True + blk.attn.rope_dims = 16 + blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) + if fused_qkv: + fuse_qkv_weights(blk.attn) + blk.attn.forward = types.MethodType( + make_fused_qkv_forward(backend, use_triton_xsa), blk.attn) + else: + blk.attn.forward = types.MethodType( + make_three_linear_forward(backend, use_triton_xsa), blk.attn) + return blk + +variants = [] +for backend in ("SDPA", "FA3"): + for use_triton_xsa in (False, True): + for fused in (False, True): + variants.append((backend, use_triton_xsa, fused)) + +# Build + compile + warm up each variant +print("\n--- Compiling + warming up all variants (this takes a few min) ---") +x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) +x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) +compiled_blocks = {} +for i, (backend, triton_xsa, fused) in enumerate(variants): + tag = f"{backend}/{'triton' if triton_xsa else 'torch '}XSA/{'fused' if fused else '3-lin'}" + t0 = time.perf_counter() + blk = build_variant(backend, triton_xsa, fused) + cfn = torch.compile(blk, dynamic=False, mode="max-autotune-no-cudagraphs") + try: + for _ in range(25): + x.grad = None; x0.grad = None + y = cfn(x, x0); y.sum().backward() + torch.cuda.synchronize() + compiled_blocks[(backend, triton_xsa, fused)] = cfn + print(f" 
[{i+1}/{len(variants)}] {tag} warmup OK in {time.perf_counter()-t0:.1f}s") + except Exception as e: + print(f" [{i+1}/{len(variants)}] {tag} FAIL: {type(e).__name__}: {str(e)[:100]}") + +# -------------------------------------------------------------------------- +# Bench each variant TWICE, report both to see drift +# -------------------------------------------------------------------------- +def bench_block(callable_, n_iter=150): + # no warmup here, we already warmed up + x.grad = None; x0.grad = None + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + x.grad = None; x0.grad = None + y = callable_(x, x0); y.sum().backward() + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 / n_iter + +print("\n--- Compiled block bench (two measurements; ms/iter) ---") +print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH}") +print(f"{'variant':<35} {'run1':>9} {'run2':>9} {'drift%':>8}") +print("-" * 63) +# First pass +measurements = {} +for (backend, triton_xsa, fused), cfn in compiled_blocks.items(): + ms1 = bench_block(cfn) + measurements[(backend, triton_xsa, fused)] = [ms1] +# Second pass (different iteration order — reversed — to detect thermal/cache drift) +for (backend, triton_xsa, fused), cfn in reversed(list(compiled_blocks.items())): + ms2 = bench_block(cfn) + measurements[(backend, triton_xsa, fused)].append(ms2) + +# Print in logical order +best = (float("inf"), None) +baseline = None +for backend in ("SDPA", "FA3"): + for triton_xsa in (False, True): + for fused in (False, True): + key = (backend, triton_xsa, fused) + if key not in measurements: continue + xsa_lbl = "triton-XSA" if triton_xsa else "torch-XSA " + qkv_lbl = "fused-QKV" if fused else "3-lin-QKV" + tag = f"{backend:<4} {xsa_lbl} {qkv_lbl}" + ms1, ms2 = measurements[key] + drift = abs(ms1 - ms2) / min(ms1, ms2) * 100 + print(f"{tag:<35} {ms1:>9.3f} {ms2:>9.3f} {drift:>7.1f}%") + use_ms = min(ms1, ms2) + if backend == "SDPA" and not triton_xsa and not fused: + baseline = use_ms + if use_ms < best[0]: + best = (use_ms, key) + +print("\n--- Summary ---") +if baseline is not None and best[1] is not None: + best_ms, best_key = best + print(f"SDPA baseline (torch-XSA, 3-lin-QKV): {baseline:.3f} ms/iter") + print(f"Best variant: {best_key}: {best_ms:.3f} ms/iter") + print(f"Speedup: {(baseline/best_ms - 1)*100:.1f}%") + +print("=== PHASE 2d DONE ===") +import datetime; print(datetime.datetime.utcnow().isoformat() + "Z") +PY diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2e.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2e.sh new file mode 100644 index 0000000000..cf915bda22 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase2e.sh @@ -0,0 +1,179 @@ +#!/usr/bin/env bash +# Phase 2e: robust block microbench. +# - torch._dynamo.reset() + empty_cache() between variants (fixes FA3 recompile-fallback bug). +# - Thermal prime (GPU hot) before each measurement. +# - CUDA events for per-iter timing; 300 samples; report p10/p50/p90. 
+set -euo pipefail +cd /workspace +echo "=== PHASE 2e ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +python - <<'PY' +import os, sys, time, types, warnings, torch, torch.nn.functional as F, torch._dynamo +warnings.filterwarnings("ignore") +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton as xsa_triton_fn + +src = open("/workspace/work/train_gpt_baseline.py").read() +ns = {"__name__": "pg_baseline"} +exec(compile(src, "train_gpt_baseline.py", "exec"), ns) +Block = ns["Block"]; Rotary = ns["Rotary"]; CastedLinear = ns["CastedLinear"] +apply_rotary_emb = ns["apply_rotary_emb"]; fa3_func = ns["flash_attn_3_func"] + +device = torch.device("cuda"); dtype = torch.bfloat16 +B, T, D = 8, 2048, 512; H, KVH = 8, 4 + +def sdpa_backend(q, k, v, causal=False): + return F.scaled_dot_product_attention( + q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), + is_causal=causal, enable_gqa=(q.size(-2)!=k.size(-2)), + ).transpose(1,2) + +def fa3_backend(q, k, v, causal=False): return fa3_func(q, k, v, causal=causal) + +def make_three_linear_forward(attn_backend, use_triton_xsa): + def forward(self, x): + bsz, seqlen, dim_ = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)); k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims); k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn_backend(q, k, v, causal=True) + if self.use_xsa: y = xsa_triton_fn(y, v) if use_triton_xsa else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim_); return self.proj(y) + return forward + +def make_fused_qkv_forward(attn_backend, use_triton_xsa): + def forward(self, x): + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x) + q_dim = self.num_heads * self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)); k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims); k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn_backend(q, k, v, causal=True) + if self.use_xsa: y = xsa_triton_fn(y, v) if use_triton_xsa else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, y.size(-2) * y.size(-1)); return self.proj(y) + return forward + +def fuse_qkv_weights(attn): + dim = attn.c_q.weight.size(1) + q_dim = attn.num_heads * attn.head_dim; kv_dim = attn.num_kv_heads * attn.head_dim + c_qkv = CastedLinear(dim, q_dim + 2*kv_dim, bias=False) + c_qkv = c_qkv.to(attn.c_q.weight.device).to(attn.c_q.weight.dtype) + with torch.no_grad(): + c_qkv.weight.copy_(torch.cat([attn.c_q.weight, attn.c_k.weight, attn.c_v.weight], dim=0)) + attn.c_qkv = c_qkv + del attn.c_q, attn.c_k, attn.c_v + +def build_variant(backend_name, use_triton_xsa, fused_qkv): + backend = fa3_backend if backend_name == "FA3" else sdpa_backend + torch.manual_seed(0) + blk = Block(dim=D, num_heads=H, num_kv_heads=KVH, mlp_mult=4.0, + rope_base=10000.0, qk_gain_init=5.0, train_seq_len=T, + 
layer_idx=7, ln_scale=True).to(device).to(dtype) + for p in blk.parameters(): + if p.ndim < 2: p.data = p.data.float() + blk.parallel = True; blk.attn.use_xsa = True + blk.attn.rope_dims = 16 + blk.attn.rotary = Rotary(D // H, base=10000.0, train_seq_len=T, rope_dims=16).to(device) + if fused_qkv: + fuse_qkv_weights(blk.attn) + blk.attn.forward = types.MethodType(make_fused_qkv_forward(backend, use_triton_xsa), blk.attn) + else: + blk.attn.forward = types.MethodType(make_three_linear_forward(backend, use_triton_xsa), blk.attn) + return blk + +# --- thermal primer: long GEMM loop to get GPU to turbo --- +def thermal_prime(seconds=3.0): + a = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + b = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + torch.cuda.synchronize() + t0 = time.perf_counter() + while time.perf_counter() - t0 < seconds: + c = a @ b + torch.cuda.synchronize() + +# --- per-variant measurement --- +def measure_variant(cfn, n_warmup=30, n_samples=300): + x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + for _ in range(n_warmup): + x.grad = None; x0.grad = None + y = cfn(x, x0); y.sum().backward() + torch.cuda.synchronize() + + thermal_prime(2.0) + + samples_ms = [] + starts = [torch.cuda.Event(enable_timing=True) for _ in range(n_samples)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(n_samples)] + for i in range(n_samples): + x.grad = None; x0.grad = None + starts[i].record() + y = cfn(x, x0); y.sum().backward() + ends[i].record() + torch.cuda.synchronize() + samples_ms = [s.elapsed_time(e) for s, e in zip(starts, ends)] + samples_ms.sort() + return samples_ms + +variants = [ + ("SDPA", False, False), # baseline + ("SDPA", False, True), # +fused-QKV + ("SDPA", True, False), # +triton-XSA + ("SDPA", True, True), # +both + ("FA3", False, False), + ("FA3", False, True), + ("FA3", True, False), + ("FA3", True, True), +] + +print("--- Bench (300 samples per variant, CUDA events, dynamo reset between) ---") +print(f"shape: B={B} T={T} D={D} H={H} KVH={KVH}") +print(f"{'variant':<32} {'p10':>7} {'p50':>7} {'p90':>7} {'min':>7}") +print("-" * 64) + +results = {} +for backend, triton_xsa, fused in variants: + torch._dynamo.reset() + torch.cuda.empty_cache() + blk = build_variant(backend, triton_xsa, fused) + cfn = torch.compile(blk, dynamic=False, mode="max-autotune-no-cudagraphs") + try: + s = measure_variant(cfn) + n = len(s) + p10 = s[int(n*0.10)]; p50 = s[int(n*0.50)]; p90 = s[int(n*0.90)]; smin = s[0] + xsa_lbl = "triton-XSA" if triton_xsa else "torch-XSA " + qkv_lbl = "fused-QKV" if fused else "3-lin-QKV" + label = f"{backend:<4} {xsa_lbl} {qkv_lbl}" + print(f"{label:<32} {p10:>7.3f} {p50:>7.3f} {p90:>7.3f} {smin:>7.3f}") + results[(backend, triton_xsa, fused)] = (p10, p50, p90, smin) + except Exception as e: + print(f"FAIL {backend} xsa={triton_xsa} fused={fused}: {type(e).__name__}: {str(e)[:80]}") + +print("\n--- Summary (using p50) ---") +base_key = ("SDPA", False, False) +base_p50 = results.get(base_key, (None, None, None, None))[1] +if base_p50: + print(f"Baseline (SDPA torch-XSA 3-lin-QKV) p50: {base_p50:.3f} ms") + print(f"\n{'variant':<32} {'Δ p50':>8} {'speedup':>8}") + for k, (p10, p50, p90, smin) in sorted(results.items(), key=lambda kv: kv[1][1]): + xsa_lbl = "triton-XSA" if k[1] else "torch-XSA " + qkv_lbl = "fused-QKV" if k[2] else "3-lin-QKV" + tag = f"{k[0]:<4} {xsa_lbl} {qkv_lbl}" + delta = p50 - base_p50 + spd = 
(base_p50 / p50 - 1) * 100 + print(f"{tag:<32} {delta:>+7.3f} {spd:>+7.1f}%") +print("\n=== PHASE 2e DONE ===") +PY +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3.sh new file mode 100644 index 0000000000..56093d00ea --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# Phase 3: 200-step training pilots on 1xH100. Compares: +# A. baseline (no patches) +# B. +QKV fusion only +# C. +QKV fusion + Triton-XSA +# Step_avg from the last logged line is the headline number. +set -euo pipefail +cd /workspace/parameter-golf + +export DATA_DIR=/workspace/parameter-golf/data/ +export ITERATIONS=200 +export VAL_LOSS_EVERY=0 +export TRAIN_LOG_EVERY=20 +export WARMUP_STEPS=10 +export MAX_WALLCLOCK_SECONDS=0 +export SEED=42 + +mkdir -p logs + +echo "=== PHASE 3 training pilots ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +run_pilot () { + local tag="$1" patch_qkv="$2" patch_xsa="$3" + echo + echo "--- run: $tag (PATCH_QKV=$patch_qkv PATCH_XSA=$patch_xsa) ---" + rm -f final_model.pt final_model.int6.ptz + RUN_ID="phase3_${tag}" \ + PATCH_QKV="$patch_qkv" \ + PATCH_XSA="$patch_xsa" \ + python3 /workspace/phase3_run.py 2>&1 | tail -60 + echo + echo "--- step_avg from logs/phase3_${tag}.txt ---" + grep -E '^step:[0-9]+.*step_avg' "logs/phase3_${tag}.txt" | tail -10 +} + +run_pilot baseline 0 0 +run_pilot qkvonly 1 0 +run_pilot qkv_xsa 1 1 + +echo +echo "--- headline step_avg (last logged step) per run ---" +for tag in baseline qkvonly qkv_xsa; do + last=$(grep -E '^step:[0-9]+.*step_avg' "logs/phase3_${tag}.txt" | tail -1) + echo " phase3_${tag}: $last" +done + +echo "=== PHASE 3 DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3b.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3b.sh new file mode 100644 index 0000000000..4fc5183b82 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bench_scripts/phase3b.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# Phase 3b: segfault-tolerant rerun of the patched pilots. +# Reuses the existing logs/phase3_baseline.txt from the last run. +set -eu # NO pipefail — segfaults at quant-time are expected +cd /workspace/parameter-golf + +export DATA_DIR=/workspace/parameter-golf/data/ +export ITERATIONS=200 +export VAL_LOSS_EVERY=0 +export TRAIN_LOG_EVERY=20 +export WARMUP_STEPS=10 +export MAX_WALLCLOCK_SECONDS=0 +export SEED=42 + +mkdir -p logs + +echo "=== PHASE 3b ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +run_pilot () { + local tag="$1" patch_qkv="$2" patch_xsa="$3" + echo + echo "--- run: $tag (PATCH_QKV=$patch_qkv PATCH_XSA=$patch_xsa) ---" + rm -f final_model.pt final_model.int6.ptz + rm -f "logs/phase3_${tag}.txt" + local t0=$(date +%s) + set +e + RUN_ID="phase3_${tag}" PATCH_QKV="$patch_qkv" PATCH_XSA="$patch_xsa" \ + python3 /workspace/phase3_run.py > "/tmp/pilot_${tag}.stdout" 2>&1 + rc=$? 
+ set -e + local t1=$(date +%s) + echo " wall: $((t1-t0))s exit: $rc (non-zero expected if quant path segfaults)" + echo " last training lines from logs/phase3_${tag}.txt:" + if [ -f "logs/phase3_${tag}.txt" ]; then + grep -E '^[0-9]+/[0-9]+ train_loss' "logs/phase3_${tag}.txt" | tail -3 | sed 's/^/ /' + else + echo " (log file missing)" + fi + echo " last 10 stdout lines:" + tail -10 "/tmp/pilot_${tag}.stdout" | sed 's/^/ /' +} + +run_pilot qkvonly 1 0 +run_pilot qkv_xsa 1 1 + +# --- SUMMARY --------------------------------------------------------------- +echo +echo "--- SUMMARY (step_avg from logs/phase3_*.txt) ---" +printf "%-20s %12s %14s %14s\n" "run" "step_avg(ms)" "train_loss" "tok/s" +printf "%-20s %12s %14s %14s\n" "---" "------------" "----------" "-----" +for tag in baseline qkvonly qkv_xsa; do + logf="logs/phase3_${tag}.txt" + if [ ! -f "$logf" ]; then + printf "%-20s %12s\n" "phase3_${tag}" "(no log)" + continue + fi + final=$(grep -E '^200/200 train_loss' "$logf" | tail -1 || true) + if [ -z "$final" ]; then + last_seen=$(grep -E '^[0-9]+/[0-9]+ train_loss' "$logf" | tail -1) + printf "%-20s %-40s\n" "phase3_${tag}" "(no 200/200 — last: $last_seen)" + continue + fi + time_m=$(echo "$final" | grep -oE 'train_time: [0-9.]+' | awk '{print $2}') + loss=$(echo "$final" | grep -oE 'train_loss: [0-9.]+' | awk '{print $2}') + toks=$(echo "$final" | grep -oE 'tok/s: [0-9]+' | awk '{print $2}') + step_ms=$(python3 -c "print(round($time_m * 60000 / 200, 1))") + printf "%-20s %12s %14s %14s\n" "phase3_${tag}" "$step_ms" "$loss" "$toks" +done + +# Also compute steady-state step_avg from steps 100..200 (after layer loop at step 70) +echo +echo "--- steady-state step_avg (last 100 steps after layer loop) ---" +for tag in baseline qkvonly qkv_xsa; do + logf="logs/phase3_${tag}.txt" + if [ ! -f "$logf" ]; then continue; fi + t100=$(grep -E '^100/200 train_loss' "$logf" | grep -oE 'train_time: [0-9.]+' | awk '{print $2}') + t200=$(grep -E '^200/200 train_loss' "$logf" | grep -oE 'train_time: [0-9.]+' | awk '{print $2}') + if [ -n "$t100" ] && [ -n "$t200" ]; then + step_ms=$(python3 -c "print(round(($t200 - $t100) * 60000 / 100, 1))") + echo " phase3_${tag}: ${step_ms}ms/step (from ${t100}m → ${t200}m)" + fi +done + +echo "=== PHASE 3b DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bootstrap.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bootstrap.sh new file mode 100644 index 0000000000..ea28af2967 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/bootstrap.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Bootstrap a FRESH RunPod 1xH100 pod (pytorch:1.0.2-cu1281-torch280-ubuntu2404) +# from scratch to the state where phase4.sh (H-Net M1 pilot) can run. +# +# Upload this file + unpack.py + hnet_m1/ to /workspace/ before running. +# +# Expected cost: ~5 min wallclock on 1xH100 = ~$0.30. +# +# Usage: bash bootstrap.sh 2>&1 | tee bootstrap.log +set -euo pipefail +cd /workspace + +echo "=== BOOTSTRAP ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# 1. env sanity -------------------------------------------------------------- +echo "--- env ---" +python - <<'PY' +import sys, torch, triton +print("python", sys.version.split()[0]) +print("torch", torch.__version__, "cuda", torch.version.cuda, "triton", triton.__version__) +print("gpu", torch.cuda.get_device_name(0), "sm", torch.cuda.get_device_capability(0)) +PY + +# 2. 
clone repo -------------------------------------------------------------- +echo "--- clone parameter-golf ---" +if [ ! -d parameter-golf ]; then + git clone --depth 1 https://github.com/openai/parameter-golf.git +fi +cd parameter-golf +git log -1 --oneline + +# 3. python deps ------------------------------------------------------------- +echo "--- pip deps ---" +pip install -q --no-input brotli sentencepiece huggingface-hub datasets tqdm 2>&1 | tail -3 + +# 4. SP8192 data (2 train shards smoke subset + full val) -------------------- +echo "--- SP8192 data ---" +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 2 2>&1 | tail -5 +ls -lh data/datasets/fineweb10B_sp8192/ +ls -lh data/tokenizers/ + +# 5. unpack bigbag's baseline ----------------------------------------------- +echo "--- unpack bigbag train_gpt.py ---" +mkdir -p /workspace/work +REC=records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT +python3 /workspace/unpack.py "$REC/train_gpt.py" /workspace/work/train_gpt_baseline.py +wc -l /workspace/work/train_gpt_baseline.py + +# 6. FA3 install ------------------------------------------------------------- +echo "--- FA3 install (cu128_torch280) ---" +pip install --quiet --no-deps flash_attn_3 \ + --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280/ 2>&1 | tail -3 +python -c "import flash_attn_interface as fa3; print('FA3 OK:', fa3.__file__)" + +echo +echo "=== BOOTSTRAP DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" +echo +echo "Next: bash /workspace/hnet_m1/phase4.sh 2>&1 | tee /workspace/hnet_m1_pilot.log" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/hnet_m1.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/hnet_m1.py new file mode 100644 index 0000000000..f41f075bdb --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/hnet_m1.py @@ -0,0 +1,156 @@ +"""H-Net Milestone 1 model: hierarchical byte-level stack with a FIXED chunker. + +Architecture: + bytes --> byte_emb (256 x D_enc) + --> byte_encoder (2 blocks at D_enc=256) + --> enc_to_main projection (D_enc -> D_main) + --> fixed chunker: x[:, ::CHUNK_STRIDE, :] + --> main_network (11 blocks at D_main=512) + --> main_to_dec projection (D_main -> D_enc) + --> upsampler: x.repeat_interleave(CHUNK_STRIDE, dim=1) (truncate/pad to T) + --> byte_decoder (1 block at D_enc=256) + --> final_norm + byte_head (D_enc -> 256) + --> per-byte logits + +Reuses bigbag's Block / RMSNorm / CastedLinear / Rotary / apply_rotary_emb via +namespace injection from an already-exec'd baseline module. + +No learned chunker. No ratio loss. No depth recurrence / parallel residuals in +M1 (can add for M2/M3). Plain AdamW on everything. 
+""" +from __future__ import annotations + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def build_hnet_m1(ns: dict, *, byte_seq_len: int, chunk_stride: int, + d_enc: int = 256, d_main: int = 512, + enc_num_layers: int = 2, main_num_layers: int = 11, dec_num_layers: int = 1, + enc_num_heads: int = 8, enc_num_kv_heads: int = 4, enc_mlp_mult: float = 3.0, + main_num_heads: int = 8, main_num_kv_heads: int = 4, main_mlp_mult: float = 4.0, + dec_num_heads: int = 8, dec_num_kv_heads: int = 4, dec_mlp_mult: float = 3.0, + rope_base: float = 10000.0, + enc_rope_dims: int = 16, main_rope_dims: int = 16, dec_rope_dims: int = 16, + enc_qk_gain_init: float = 1.5, main_qk_gain_init: float = 5.0, dec_qk_gain_init: float = 1.5): + """Factory that builds an HNetM1 module using classes from baseline `ns`.""" + Block = ns["Block"] + RMSNorm = ns["RMSNorm"] + CastedLinear = ns["CastedLinear"] + Rotary = ns["Rotary"] + apply_rotary_emb = ns["apply_rotary_emb"] # noqa: F841 (used inside Block) + + class HNetM1(nn.Module): + def __init__(self): + super().__init__() + self.byte_seq_len = byte_seq_len + self.chunk_stride = chunk_stride + self.d_enc = d_enc + self.d_main = d_main + self.main_seq_len = (byte_seq_len + chunk_stride - 1) // chunk_stride + + # byte embedding (256 vocab) + self.byte_emb = nn.Embedding(256, d_enc) + nn.init.normal_(self.byte_emb.weight, mean=0.0, std=0.01) + + # byte encoder + self.byte_encoder = nn.ModuleList([ + Block(dim=d_enc, num_heads=enc_num_heads, num_kv_heads=enc_num_kv_heads, + mlp_mult=enc_mlp_mult, rope_base=rope_base, qk_gain_init=enc_qk_gain_init, + train_seq_len=byte_seq_len, layer_idx=i, ln_scale=True) + for i in range(enc_num_layers) + ]) + for blk in self.byte_encoder: + blk.attn.rope_dims = enc_rope_dims + blk.attn.rotary = Rotary(d_enc // enc_num_heads, base=rope_base, + train_seq_len=byte_seq_len, rope_dims=enc_rope_dims) + + # enc -> main projection + self.enc_to_main = CastedLinear(d_enc, d_main, bias=False) + + # main network + self.main_blocks = nn.ModuleList([ + Block(dim=d_main, num_heads=main_num_heads, num_kv_heads=main_num_kv_heads, + mlp_mult=main_mlp_mult, rope_base=rope_base, qk_gain_init=main_qk_gain_init, + train_seq_len=self.main_seq_len, layer_idx=i, ln_scale=True) + for i in range(main_num_layers) + ]) + for blk in self.main_blocks: + blk.attn.rope_dims = main_rope_dims + blk.attn.rotary = Rotary(d_main // main_num_heads, base=rope_base, + train_seq_len=self.main_seq_len, rope_dims=main_rope_dims) + + # main -> dec projection + self.main_to_dec = CastedLinear(d_main, d_enc, bias=False) + + # byte decoder + self.byte_decoder = nn.ModuleList([ + Block(dim=d_enc, num_heads=dec_num_heads, num_kv_heads=dec_num_kv_heads, + mlp_mult=dec_mlp_mult, rope_base=rope_base, qk_gain_init=dec_qk_gain_init, + train_seq_len=byte_seq_len, layer_idx=i, ln_scale=True) + for i in range(dec_num_layers) + ]) + for blk in self.byte_decoder: + blk.attn.rope_dims = dec_rope_dims + blk.attn.rotary = Rotary(d_enc // dec_num_heads, base=rope_base, + train_seq_len=byte_seq_len, rope_dims=dec_rope_dims) + + # final norm + head + self.final_norm = RMSNorm() + self.byte_head = CastedLinear(d_enc, 256, bias=False) + nn.init.zeros_(self.byte_head.weight) # zero-init per baseline convention + + def forward_logits(self, input_bytes: torch.Tensor) -> torch.Tensor: + B, T = input_bytes.shape + assert T == self.byte_seq_len, f"byte_seq_len mismatch {T} vs {self.byte_seq_len}" + + x = self.byte_emb(input_bytes) + x = F.rms_norm(x, (x.size(-1),)) + 
x0 = x + for blk in self.byte_encoder: + x = blk(x, x0) + x_enc_final = x # (B, T, D_enc) — fine-grained byte representation; skip source + + # project + downsample (fixed stride) + x_main = self.enc_to_main(x_enc_final) + x_main = x_main[:, :: self.chunk_stride, :].contiguous() # (B, T_main, D_main) + x0_main = x_main + for blk in self.main_blocks: + x_main = blk(x_main, x0_main) + + # project + upsample (repeat each chunk by stride) + x_dec = self.main_to_dec(x_main) # (B, T_main, D_enc) + x_dec = x_dec.repeat_interleave(self.chunk_stride, dim=1) # (B, T_main*stride, D_enc) + if x_dec.size(1) < T: + x_dec = F.pad(x_dec, (0, 0, 0, T - x_dec.size(1))) + x_dec = x_dec[:, :T, :].contiguous() + + # H-Net-style byte skip: combine main-output + byte-encoder-output so the byte + # decoder has BOTH coarse (main) and fine (byte encoder) representations per byte. + x_dec = x_dec + x_enc_final + x0_dec = x_enc_final # initial residual lane = byte-level info directly + for blk in self.byte_decoder: + x_dec = blk(x_dec, x0_dec) + + x_dec = self.final_norm(x_dec) + logits = self.byte_head(x_dec) # (B, T, 256) + return logits + + def forward(self, input_bytes: torch.Tensor, target_bytes: torch.Tensor) -> torch.Tensor: + logits = self.forward_logits(input_bytes) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_bytes.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + return HNetM1() + + +def count_params(model: nn.Module, exclude_embeddings: bool = False) -> tuple[int, int]: + total = 0; nonembed = 0 + for name, p in model.named_parameters(): + total += p.numel() + if not any(k in name for k in ("byte_emb", "byte_head")): + nonembed += p.numel() + return total, nonembed diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/make_byte_shards.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/make_byte_shards.py new file mode 100644 index 0000000000..3f5cc76a28 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/make_byte_shards.py @@ -0,0 +1,107 @@ +"""Convert SP8192 token shards into UTF-8 byte shards. + +Reads the baseline's cached SP8192 bin files, decodes tokens to text via the +SentencePiece model, re-encodes to UTF-8 bytes, and writes byte shards in the +same on-disk layout (256 int32 header ints + np.ndarray: + header = np.fromfile(path, dtype=" None: + assert bytes_arr.dtype == np.uint16 + header = np.zeros(256, dtype=" np.ndarray: + """Decode a long token stream to UTF-8 bytes, processing in chunks. + + Returns a uint16 array of byte values in [0, 255]. 
+ """ + out_pieces: list[bytes] = [] + n = tokens.size + for i in range(0, n, chunk_size): + chunk = tokens[i : i + chunk_size] + text = sp.decode(chunk.tolist()) + b = text.encode("utf-8", errors="replace") + out_pieces.append(b) + if i % (chunk_size * 10) == 0: + so_far = sum(len(p) for p in out_pieces) + print(f" decoded {i:>10d}/{n} tokens -> {so_far:>12d} bytes", flush=True) + joined = b"".join(out_pieces) + return np.frombuffer(joined, dtype=np.uint8).astype(np.uint16) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--tokenizer", required=True, help="Path to SentencePiece .model") + ap.add_argument("--in-pattern", required=True, help="glob for input SP token shards") + ap.add_argument("--out-dir", required=True, help="output dir for byte shards") + ap.add_argument("--limit-shards", type=int, default=0, help="process only the first N shards (0 = all)") + args = ap.parse_args() + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer) + print(f"tokenizer vocab_size={sp.vocab_size()}", flush=True) + + in_paths = sorted(Path(p) for p in glob.glob(args.in_pattern)) + if args.limit_shards > 0: + in_paths = in_paths[: args.limit_shards] + print(f"input shards: {len(in_paths)}", flush=True) + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + for idx, inp in enumerate(in_paths): + t0 = time.perf_counter() + out_name = inp.name # keep filename for dataloader compat + out_path = out_dir / out_name + if out_path.exists(): + # Assume valid and skip + print(f"[{idx+1}/{len(in_paths)}] skip existing {out_path}", flush=True) + continue + print(f"[{idx+1}/{len(in_paths)}] reading {inp}", flush=True) + toks = read_sp_shard(inp) + print(f" tokens: {toks.size}", flush=True) + bytes_arr = decode_tokens_to_bytes(sp, toks) + print(f" -> {bytes_arr.size} bytes ({bytes_arr.size / max(toks.size, 1):.2f} bytes/token)", flush=True) + write_byte_shard(out_path, bytes_arr) + print(f" wrote {out_path} in {time.perf_counter() - t0:.1f}s", flush=True) + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/phase4.sh b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/phase4.sh new file mode 100644 index 0000000000..f40b7fe75b --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/phase4.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Phase 4: H-Net Milestone 1 pilot. +# +# Prerequisites (done once on the pod): +# - Phase 0 already ran (parameter-golf cloned, SP8192 data downloaded). +# - train_gpt_baseline.py already unpacked at /workspace/work/. +# - FA3 installed (phase 2a). +# +# This script: +# 1. Converts 2 SP8192 train shards + 1 val shard into UTF-8 byte shards. +# 2. Runs train_hnet_m1.py for ITERATIONS steps (default 300). +# 3. Prints final per-byte val_bpb. +set -euo pipefail +cd /workspace + +# Where H-Net code lives (uploaded alongside this script) +HNET_DIR="/workspace/hnet_m1" +DATA_DIR="/workspace/parameter-golf/data" +SP_SHARDS="${DATA_DIR}/datasets/fineweb10B_sp8192/fineweb_*_*.bin" +TOKENIZER="${DATA_DIR}/tokenizers/fineweb_8192_bpe.model" +BYTE_DIR="${DATA_DIR}/datasets/fineweb10B_bytes" + +echo "=== PHASE 4 (H-Net M1) ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" + +# --- 1. 
byte shard preprocessing -------------------------------------------- +if [ -z "$(ls -A ${BYTE_DIR} 2>/dev/null)" ]; then + echo "--- converting SP8192 shards to UTF-8 byte shards ---" + python "${HNET_DIR}/make_byte_shards.py" \ + --tokenizer "${TOKENIZER}" \ + --in-pattern "${SP_SHARDS}" \ + --out-dir "${BYTE_DIR}" +else + echo "--- byte shards already present at ${BYTE_DIR}, skipping preprocess ---" +fi +ls -lh "${BYTE_DIR}/" | head + +# --- 2. pilot training run -------------------------------------------------- +echo +echo "--- H-Net M1 pilot ---" +export DATA_DIR +export BYTE_DATA_DIR="${BYTE_DIR}" +export TOKENIZER_PATH="${TOKENIZER}" +export BASELINE_PATH="/workspace/work/train_gpt_baseline.py" +export ITERATIONS="${ITERATIONS:-300}" +export BYTE_SEQ_LEN="${BYTE_SEQ_LEN:-4096}" +export CHUNK_STRIDE="${CHUNK_STRIDE:-4}" +export BATCH_SIZE="${BATCH_SIZE:-8}" +export LR="${LR:-3e-4}" +export WARMUP_STEPS="${WARMUP_STEPS:-10}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-20}" +export RUN_ID="${RUN_ID:-hnet_m1_pilot}" +export SEED="${SEED:-42}" + +python "${HNET_DIR}/train_hnet_m1.py" + +echo +echo "--- tail of log ---" +tail -n 20 "/workspace/logs/${RUN_ID}.txt" + +echo "=== PHASE 4 DONE ===" +date -u +"%Y-%m-%dT%H:%M:%SZ" diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/train_hnet_m1.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/train_hnet_m1.py new file mode 100644 index 0000000000..21be1c89d4 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_m1/train_hnet_m1.py @@ -0,0 +1,269 @@ +"""H-Net Milestone 1 training pilot. + +Trains the HNetM1 model (hierarchical byte-level with fixed chunker) on the +cached byte shards (produced by make_byte_shards.py). Uses plain AdamW on all +parameters. Keeps it minimal: no TTT, no EMA, no Muon, no GPTQ. The question +this pilot answers: does the hierarchical architecture train at ~25M params on +FineWeb bytes? + +Env knobs: + DATA_DIR default /workspace/parameter-golf/data/ + BYTE_DATA_DIR default ${DATA_DIR}/datasets/fineweb10B_bytes/ + TOKENIZER_PATH default ${DATA_DIR}/tokenizers/fineweb_8192_bpe.model + (used for val-bpb byte-count sanity only, not for input) + ITERATIONS default 300 + WARMUP_STEPS default 10 + BYTE_SEQ_LEN default 4096 bytes per sample + CHUNK_STRIDE default 4 fixed chunker stride + BATCH_SIZE default 8 sequences per step + GRAD_ACCUM default 1 + LR default 3e-4 + WD default 0.01 + MAX_WALLCLOCK_SECONDS default 0 (0 = no cap) + RUN_ID default hnet_m1_pilot + VAL_EVERY default 0 (0 = only at end) + TRAIN_LOG_EVERY default 20 + +The training loop logs lines in the same shape as bigbag's, so our phase3 +parser works: + step:N/M train_loss:X train_time:Yms step_avg:Zms tok/s:W + +At end: prints per-byte val_loss (nats) and val_bpb. 
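(For reading that final line: bpb = per-byte nats / ln 2, so e.g. a per-byte CE
of 2.00 nats corresponds to 2.00 / 0.6931 ≈ 2.89 bpb.)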
+""" +from __future__ import annotations +import glob +import math +import os +import sys +import time +import uuid +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +# --------------------------------------------------------------------------- +# Config from env +# --------------------------------------------------------------------------- +DATA_DIR = os.environ.get("DATA_DIR", "/workspace/parameter-golf/data/") +BYTE_DATA_DIR = os.environ.get("BYTE_DATA_DIR", os.path.join(DATA_DIR, "datasets", "fineweb10B_bytes")) +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", os.path.join(DATA_DIR, "tokenizers", "fineweb_8192_bpe.model")) +ITERATIONS = int(os.environ.get("ITERATIONS", 300)) +WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", 10)) +BYTE_SEQ_LEN = int(os.environ.get("BYTE_SEQ_LEN", 4096)) +CHUNK_STRIDE = int(os.environ.get("CHUNK_STRIDE", 4)) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 8)) +GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", 1)) +LR = float(os.environ.get("LR", 3e-4)) +WD = float(os.environ.get("WD", 0.01)) +MAX_WALLCLOCK = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0)) +RUN_ID = os.environ.get("RUN_ID", "hnet_m1_pilot") +VAL_EVERY = int(os.environ.get("VAL_EVERY", 0)) +TRAIN_LOG_EVERY = int(os.environ.get("TRAIN_LOG_EVERY", 20)) +SEED = int(os.environ.get("SEED", 42)) +COMPILE = int(os.environ.get("COMPILE", 1)) # set 0 to skip torch.compile (eager) +WARMDOWN_FRAC = float(os.environ.get("WARMDOWN_FRAC", 0.3)) # final 30% of steps cosine to 0 + +LOG_DIR = Path("logs") +LOG_DIR.mkdir(exist_ok=True) +LOG_PATH = LOG_DIR / f"{RUN_ID}.txt" + +def log(msg: str, console: bool = True): + if console: + print(msg, flush=True) + with open(LOG_PATH, "a", encoding="utf-8") as f: + f.write(msg + "\n") + +# --------------------------------------------------------------------------- +# Load baseline module for Block/RMSNorm/... class defs. +# Register under sys.modules so torch.compile / dynamo can resolve globals +# referenced from Block.forward (flash_attn_3_func, F, torch, etc.) via a real +# importable module. 
+# --------------------------------------------------------------------------- +import types as _types +BASELINE_PATH = Path(os.environ.get("BASELINE_PATH", "/workspace/work/train_gpt_baseline.py")) +BASELINE_MOD_NAME = "baseline_ns" +_baseline_mod = _types.ModuleType(BASELINE_MOD_NAME) +_baseline_mod.__file__ = str(BASELINE_PATH) +sys.modules[BASELINE_MOD_NAME] = _baseline_mod +exec(compile(BASELINE_PATH.read_text(), str(BASELINE_PATH), "exec"), _baseline_mod.__dict__) +ns = _baseline_mod.__dict__ + +# Local import of the HNet model factory +sys.path.insert(0, str(Path(__file__).parent)) +from hnet_m1 import build_hnet_m1, count_params # noqa: E402 + +# --------------------------------------------------------------------------- +# Byte-shard loader (mirrors baseline load_data_shard but u16 values are bytes) +# --------------------------------------------------------------------------- +SHARD_MAGIC = 20240520 + +def load_byte_shard(path: Path) -> np.ndarray: + header = np.fromfile(path, dtype=" np.ndarray: + pieces = [] + remaining = n + while remaining > 0: + avail = self.buf.size - self.pos + if avail <= 0: + self._advance(); continue + k = min(remaining, avail) + pieces.append(self.buf[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return np.concatenate(pieces) if len(pieces) > 1 else pieces[0] + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- +def main(): + if LOG_PATH.exists(): + LOG_PATH.unlink() + + log(f"[hnet_m1] RUN_ID={RUN_ID} ITERATIONS={ITERATIONS} BYTE_SEQ_LEN={BYTE_SEQ_LEN} " + f"CHUNK_STRIDE={CHUNK_STRIDE} BATCH_SIZE={BATCH_SIZE} GRAD_ACCUM={GRAD_ACCUM} LR={LR} WD={WD}") + log(f"[hnet_m1] BYTE_DATA_DIR={BYTE_DATA_DIR}") + + torch.manual_seed(SEED); np.random.seed(SEED) + assert torch.cuda.is_available(), "CUDA required" + device = torch.device("cuda") + dtype = torch.bfloat16 + + # Build model + model = build_hnet_m1(ns, byte_seq_len=BYTE_SEQ_LEN, chunk_stride=CHUNK_STRIDE) + model = model.to(device).to(dtype) + # Keep small/control parameters in fp32 (baseline convention) + for name, p in model.named_parameters(): + if p.ndim < 2: + p.data = p.data.float() + + total_p, nonemb_p = count_params(model) + log(f"[hnet_m1] params total={total_p:,} non-embedding={nonemb_p:,}") + + # Optimizer: plain AdamW on matrix params; plain AdamW on vector/scalar params at same LR + matrix_params = [p for p in model.parameters() if p.ndim == 2] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + opt = torch.optim.AdamW( + [ + {"params": matrix_params, "lr": LR, "weight_decay": WD}, + {"params": scalar_params, "lr": LR, "weight_decay": 0.0}, + ], + betas=(0.9, 0.95), eps=1e-8, fused=True, + ) + for g in opt.param_groups: g["base_lr"] = g["lr"] + + # Data + train_stream = ByteStream(os.path.join(BYTE_DATA_DIR, "fineweb_train_*.bin")) + val_stream = ByteStream(os.path.join(BYTE_DATA_DIR, "fineweb_val_*.bin")) + + def next_batch(stream, B, T): + raw = stream.take(B * (T + 1)) + raw = raw[: B * (T + 1)].reshape(B, T + 1) + x = torch.from_numpy(raw[:, :-1].astype(np.int64)).to(device, non_blocking=True) + y = torch.from_numpy(raw[:, 1:].astype(np.int64)).to(device, non_blocking=True) + return x, y + + # compile (or not) + if COMPILE: + log("[hnet_m1] torch.compile enabled") + compiled = torch.compile(model, dynamic=False) + else: + log("[hnet_m1] torch.compile DISABLED (eager mode)") + compiled = model + + # warmup + for wstep in range(WARMUP_STEPS): + 
x, y = next_batch(train_stream, BATCH_SIZE, BYTE_SEQ_LEN) + with torch.autocast(device_type="cuda", dtype=dtype, enabled=True): + loss = compiled(x, y) + loss.backward() + opt.step(); opt.zero_grad(set_to_none=True) + if (wstep + 1) % max(1, WARMUP_STEPS // 4) == 0: + log(f"warmup_step:{wstep+1}/{WARMUP_STEPS}") + + # restart data stream so warmup tokens don't bias training window + train_stream = ByteStream(os.path.join(BYTE_DATA_DIR, "fineweb_train_*.bin")) + + log("[hnet_m1] entering main loop") + torch.cuda.synchronize() + t_start = time.perf_counter() + total_tokens = 0 + + warmdown_start = int(ITERATIONS * (1.0 - WARMDOWN_FRAC)) + for step in range(1, ITERATIONS + 1): + # cosine warmdown over the last WARMDOWN_FRAC fraction of steps + if step > warmdown_start: + frac = (step - warmdown_start) / max(ITERATIONS - warmdown_start, 1) + lr_scale = 0.5 * (1.0 + math.cos(math.pi * frac)) + else: + lr_scale = 1.0 + for g in opt.param_groups: + g["lr"] = g["base_lr"] * lr_scale + + x, y = next_batch(train_stream, BATCH_SIZE, BYTE_SEQ_LEN) + with torch.autocast(device_type="cuda", dtype=dtype, enabled=True): + loss = compiled(x, y) + (loss / GRAD_ACCUM).backward() + if step % GRAD_ACCUM == 0: + opt.step(); opt.zero_grad(set_to_none=True) + total_tokens += BATCH_SIZE * BYTE_SEQ_LEN + + torch.cuda.synchronize() + t_now = time.perf_counter() - t_start + + if step <= 5 or step % TRAIN_LOG_EVERY == 0 or step == ITERATIONS: + toks_per_s = total_tokens / max(t_now, 1e-9) + log(f"{step}/{ITERATIONS} train_loss: {loss.item():.4f} " + f"train_time: {t_now/60:.1f}m tok/s: {int(toks_per_s)}") + + if MAX_WALLCLOCK > 0 and t_now >= MAX_WALLCLOCK: + log(f"[hnet_m1] hit MAX_WALLCLOCK_SECONDS={MAX_WALLCLOCK} at step {step}") + break + + # final validation: per-byte CE across ~1M tokens + log("[hnet_m1] running final val") + model.eval() + val_loss_sum = torch.tensor(0.0, device=device, dtype=torch.float64) + val_tok_sum = torch.tensor(0.0, device=device, dtype=torch.float64) + val_batches = 16 + with torch.inference_mode(): + for _ in range(val_batches): + x, y = next_batch(val_stream, BATCH_SIZE, BYTE_SEQ_LEN) + with torch.autocast(device_type="cuda", dtype=dtype, enabled=True): + logits = model.forward_logits(x) + n = x.numel() + ce = F.cross_entropy(logits.float().reshape(-1, 256), y.reshape(-1), reduction="sum") + val_loss_sum += ce.double() + val_tok_sum += float(n) + val_nll = (val_loss_sum / val_tok_sum).item() + val_bpb = val_nll / math.log(2.0) # per-byte CE is already per-byte; ln -> bits + log(f"final val_nll: {val_nll:.4f} val_bpb: {val_bpb:.4f}") + log(f"[hnet_m1] total training tokens: {total_tokens:,}") + log(f"[hnet_m1] wallclock: {(time.perf_counter()-t_start)/60:.2f}m") + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_scope.md b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_scope.md new file mode 100644 index 0000000000..41ab0f5ccc --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/hnet_scope.md @@ -0,0 +1,136 @@ +# H-Net on Parameter Golf — Implementation Scope + +Companion doc to the investigation PR. Details the proposed H-Net variant targeted for the 16 MB / 10-min-on-8×H100 budget, milestone gates, and failure modes. + +Reference: Hwang et al., *Dynamic Chunking for End-to-End Hierarchical Sequence Modeling*, arXiv:2507.07955 (Jul 2025). + +## 1. 
Architecture at a glance + +``` +bytes → byte-encoder → dynamic chunker → main network → byte-decoder → per-byte logits + | (thin, (Wq, Wk; (deep, ( thin, + | D_enc=256) boundary_prob D_main=512) D_enc=256) + | + STE + EMA) ^ + | | + └─ byte emb (256 × D_enc, ≈66K params) ──────────┘ +``` + +The byte encoder/decoder are thin (2 layers each) transformers at D_enc=256. The main network is essentially bigbag's current SP8192 stack (11 layers, D_main=512) but now consumes / produces a *chunk* stream, not a SP8192-token stream. The dynamic chunker decides which encoder outputs are chunk boundaries. + +## 2. Parameter budget (int6 quantized + Brotli, targeting 16 MB) + +| component | params (≈) | quantized bytes (≈) | notes | +|---|---:|---:|---| +| byte embedding (256 × D_enc) | 66 K | 60 KB (int8 like SP8192) | vs. 2.1 M / 2 MB for SP8192 | +| byte encoder (2 × transformer blocks at D_enc=256, 8 heads/4 KV, MLP 3×) | 1.5 M | 1.2 MB | tied layer weights with decoder would halve this | +| chunker (Wq + Wk, both D_enc × D_enc) | 130 K | 100 KB | unchanged from paper | +| main network (11 L × D_main=512 × GQA × MLP 4×, same as bigbag) | 22 M | 14 MB | inherits SP8192-stack fusions: parallel residuals, 3-layer loop, QK-Gain, GPTQ+SDClip | +| byte decoder (1 × block at D_enc=256) | 0.8 M | 0.6 MB | can be tied with encoder layers if tight | +| **total** | **~24.5 M** | **~16 MB** | some slack for upsampler + final LN + skip weights | + +At 24–26 M params we're slightly under bigbag's 35.9 M (they spend ~4 MB on the SP8192 token embedding we eliminated). The saved budget goes into either (a) a deeper main network, (b) a wider byte encoder, or (c) an explicit upsampler with dedicated params. + +## 3. Dynamic-chunker design (copy directly from paper) + +Per-position boundary probability: + +$$ +p_t = \tfrac{1}{2} \Big( 1 - \frac{q_t^{\top} k_{t-1}}{\Vert q_t \Vert \, \Vert k_{t-1} \Vert} \Big) +$$ + +where $q_t = W_q \hat{x}_t$, $k_t = W_k \hat{x}_t$ and $\hat{x}_t$ is the byte-encoder output at position $t$. + +Downsampling: select encoder outputs where $b_t = \mathbf{1}[p_t > 0.5]$. Discard the rest. + +Smoothing for gradient flow (EMA): + +$$ +\bar{z}_t = P_t \hat{z}_t + (1 - P_t) \bar{z}_{t-1}, +$$ + +and a Straight-Through Estimator rounds $P_t$ to 1 in forward but preserves real-valued gradients. + +## 4. Losses + +$$ +\mathcal{L} = \mathcal{L}_{\text{AR}} + \alpha \, \mathcal{L}_{\text{ratio}}, \quad \alpha = 0.03 +$$ + +$\mathcal{L}_{\text{AR}}$: standard per-byte autoregressive cross-entropy (full 256-token vocab; no SentencePiece). + +$\mathcal{L}_{\text{ratio}}$: encourages a target compression ratio $r \approx 3.5$ (matches SP8192's effective bytes/token on English FineWeb). The paper's form: + +$$ +\mathcal{L}_{\text{ratio}} = \big( r \cdot F - G \big)^2, \quad F = \tfrac{1}{T}\sum_t b_t, \quad G = \tfrac{1}{T}\sum_t p_t +$$ + +(F is the actual fraction of chunk boundaries, G is the mean boundary probability.) + +## 5. Sequence lengths and compute + +SP8192 baseline: seq_len = 2048 tokens ≈ 2048 × 3.5 ≈ 7200 bytes per sequence. + +Byte-level H-Net at equivalent data coverage: seq_len_bytes = 7200. + +- Byte encoder runs on 7200 bytes per sequence. FA3 at 7200 seq is fine (memory-bound, ~2× slower than 2048). +- Main network runs on ~2000 chunks per sequence (after ~3.5× compression). Matches current SP8192 compute budget. +- Byte decoder runs on 7200 bytes again. + +Net forward FLOPs roughly doubled vs. SP8192 (the extra byte encoder + decoder passes). 
This is the main compute risk — we may need to reduce main-network depth (11 → 9 or 10 layers) to fit the 10-min budget. Milestone 1 measures this. + +## 6. Milestone plan (aligned with $500 dev grant scope) + +### Milestone 1 — hierarchical stack with **fixed** chunker ( ≈ $60 GPU) + +Before learning the chunker, prove the hierarchical stack trains. Fix boundaries at every $r$-th byte (stride-3 or stride-4 deterministic). Train byte-encoder + main-network + byte-decoder end-to-end at 1× H100 scale (reduced iters, SP8192-equivalent vocab). + +**Gate**: val_bpb better than a byte-level transformer of the same total param count (simple byte LM, no hierarchy). If worse, the upsampler / reinjection path is broken and we fix before learning the chunker. + +### Milestone 2 — learned chunker with EMA + STE ( ≈ $120 GPU) + +Drop the fixed chunker. Add Wq, Wk + ratio loss. Watch for: +- Degenerate collapse (all boundaries or no boundaries). +- Chunker "freezing" (boundary positions stop moving once set). +- Ratio-loss mis-specification. + +**Gate**: non-degenerate boundary distribution (F ∈ [0.2, 0.4]) and val_bpb ≤ Milestone 1 value. + +### Milestone 3 — full 16 MB submission ( ≈ $200 GPU) + +Turn on all the SP8192-stack bells: parallel residuals on main, 3-layer depth loop, MuonEq-R, GPTQ+SDClip int6, Brotli-11. 3-seed mean on 8× H100 SXM. + +**Gate**: artifact < 16 MB, eval < 10 min, val_bpb competitive with the best non-record SP8192 results (~1.10 BPP). Hitting the SP8192-record threshold (1.016 as of 2026-04-14) is aspirational and not required for the PR to be a creative contribution. + +### Milestone 4 — ablations ( ≈ $120 GPU) + +- Compression-ratio sweep (r ∈ {2, 3, 3.5, 4, 5}). +- Byte-encoder depth (1 / 2 / 3 layers). +- Chunker variants (cosine sim vs. small MLP, see paper §E). + +Published as an update to this PR, not a separate PR. + +## 7. Failure modes and abort criteria + +| risk | likelihood | mitigation / abort | +|---|---|---| +| Chunker collapses (always / never chunk) | medium | EMA + STE + ratio loss (paper's recipe). If still collapses at M2, try Gumbel-softmax variant before aborting. | +| Scale too small for hierarchy to help | medium-high | Paper only shows ≥680M. If M1 byte-stack is worse than a plain byte transformer at 25 M params, H-Net may simply not work this small — abort and pivot grant remainder to documenting the negative result. | +| Byte-level compute exceeds 10 min | medium | Reduce main-network depth from 11 → 9; or add FA3 to byte encoder too. | +| SOTA stack (depth recurrence, parallel residuals) incompatible with chunk stream | low | They operate on the main-network token dimension, should transfer directly. | +| Training instability from joint loss | low-medium | Warmup α from 0 to 0.03 over first 500 steps. | + +## 8. Why this is the right bet now + +- **Unclaimed** on the repo's explicit Requests-for-PRs list. +- **Complements** rather than competes with current SOTA (GDN-Hybrid, varlen + fused MLP, parallel residuals, adaptive TTT). The current SOTA stacks all start from SP8192 or SP4096; H-Net attacks the tokenizer axis they can't touch. +- **Concrete payoff** from the paper: 3.5–4× effective byte compression + ~4× better data efficiency on code / non-Latin / DNA. FineWeb has heavy code fragments; exactly the regime H-Net helps most. +- **Within our budget**: $500 grant covers 4 milestones. No milestone individually requires more than ~$200. + +## 9. 
Prior art we're NOT duplicating + +- `#1548 dljr-github` "Frozen Random Backbone + LoRA Adapters" — a different adapter-on-random-backbone idea, not H-Net. +- `#973 mrbese` "38-token structured alphabet + BPE" — a fixed alternative tokenizer, not learned chunking. +- `#1312 / #1480 / #1581` JEPA submissions — a different hierarchical idea (representation learning, not dynamic chunking). +- `#1582 / #1596` masked-diffusion submissions — orthogonal to tokenization. + +No open PR has attempted learned byte-level chunking as of 2026-04-14. diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/phase3_run.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/phase3_run.py new file mode 100644 index 0000000000..f631b93de8 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/phase3_run.py @@ -0,0 +1,101 @@ +"""Phase 3 training pilot. + +Loads the FA3 baseline module, optionally patches all blocks' attention to use +fused-QKV and/or Triton-XSA, then runs `main()` for ITERATIONS steps. + +Env knobs (defaults in code): + PATCH_QKV=0|1 fuse c_q/c_k/c_v into a single c_qkv GEMM + PATCH_XSA=0|1 swap `_xsa_efficient` to xsa_triton + RUN_ID=... baseline training log file name (train logs/.txt) + ITERATIONS=200 pilot step count + TRAIN_LOG_EVERY=20 + WARMUP_STEPS=10 + VAL_LOSS_EVERY=0 (skip in-training validation) + MAX_WALLCLOCK_SECONDS=0 (disable wallclock cap) +""" +from __future__ import annotations +import os, sys, types, torch, torch.nn.functional as F + +# -------- defaults ------------------------------------------------- +os.environ.setdefault("DATA_DIR", "/workspace/parameter-golf/data/") +os.environ.setdefault("ITERATIONS", "200") +os.environ.setdefault("VAL_LOSS_EVERY", "0") +os.environ.setdefault("TRAIN_LOG_EVERY", "20") +os.environ.setdefault("WARMUP_STEPS", "10") +os.environ.setdefault("MAX_WALLCLOCK_SECONDS", "0") + +PATCH_QKV = os.environ.get("PATCH_QKV", "0") == "1" +PATCH_XSA = os.environ.get("PATCH_XSA", "0") == "1" +print(f"[phase3] PATCH_QKV={PATCH_QKV} PATCH_XSA={PATCH_XSA}") + +sys.path.insert(0, "/workspace/work") +from xsa_triton import xsa_triton as xsa_triton_fn + +# -------- load baseline module via exec ---------------------------- +src = open("/workspace/work/train_gpt_baseline.py").read() +ns: dict = {"__name__": "pg_baseline", "__file__": "/workspace/work/train_gpt_baseline.py"} +exec(compile(src, "/workspace/work/train_gpt_baseline.py", "exec"), ns) + +CastedLinear = ns["CastedLinear"] +apply_rotary_emb = ns["apply_rotary_emb"] +fa3_func = ns["flash_attn_3_func"] + +# -------- patches -------------------------------------------------- +def fuse_qkv_weights(attn): + dim = attn.c_q.weight.size(1) + q_dim = attn.num_heads * attn.head_dim + kv_dim = attn.num_kv_heads * attn.head_dim + c_qkv = CastedLinear(dim, q_dim + 2 * kv_dim, bias=False) + c_qkv = c_qkv.to(attn.c_q.weight.device).to(attn.c_q.weight.dtype) + with torch.no_grad(): + c_qkv.weight.copy_(torch.cat([attn.c_q.weight, attn.c_k.weight, attn.c_v.weight], dim=0)) + attn.c_qkv = c_qkv + del attn.c_q, attn.c_k, attn.c_v + + +def make_fused_qkv_forward(use_triton_xsa): + def forward(self, x): + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x) + q_dim = self.num_heads * self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.reshape(bsz, 
seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = fa3_func(q, k, v, causal=True) + if self.use_xsa: + y = xsa_triton_fn(y, v) if use_triton_xsa else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, y.size(-2) * y.size(-1)) + return self.proj(y) + return forward + + +def patched_xsa_method(self, y, v): + return xsa_triton_fn(y, v) + + +# Intercept GPT.__init__ to apply post-construction patches +original_GPT_init = ns["GPT"].__init__ +def patched_GPT_init(self, *args, **kwargs): + original_GPT_init(self, *args, **kwargs) + n_patched = 0 + for blk in self.blocks: + if PATCH_QKV: + fuse_qkv_weights(blk.attn) + blk.attn.forward = types.MethodType(make_fused_qkv_forward(PATCH_XSA), blk.attn) + elif PATCH_XSA: + # Only XSA patch, keep 3-linear path + blk.attn._xsa_efficient = types.MethodType(patched_xsa_method, blk.attn) + n_patched += 1 if (PATCH_QKV or PATCH_XSA) else 0 + print(f"[phase3] patched {n_patched} blocks (QKV={PATCH_QKV} XSA={PATCH_XSA})") +ns["GPT"].__init__ = patched_GPT_init + +# -------- run main() ----------------------------------------------- +ns["main"]() diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/qkv_fuse.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/qkv_fuse.py new file mode 100644 index 0000000000..c21ae9e90e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/qkv_fuse.py @@ -0,0 +1,109 @@ +"""QKV weight-stacking patch for the SP8192 stack's `CausalSelfAttention`. + +Stacks `c_q / c_k / c_v` into a single `c_qkv` linear of shape + (num_heads * head_dim + 2 * num_kv_heads * head_dim, dim) +and rebinds the attention module's `forward` to split/reshape the stacked output. + +Numerical parity: forward-output at init is bit-identical to the 3-linear +baseline (we verified 0.000e+00 max elementwise error at multiple shapes). + +Caveat: NOT a numerically-equivalent systems-only change under Muon. See the +"Muon interaction" note below. This file is shipped as reference / for anyone +following up — the training pilot showed that Inductor already fuses these three +linears when compiled at model scope, so there is no step-time benefit in practice. +""" +from __future__ import annotations + +import types + +import torch +import torch.nn.functional as F + + +def fuse_qkv_weights(attn, CastedLinear): + """Replace attn.c_q / c_k / c_v with a single c_qkv of stacked weights. + + Preserves dtype and device of the original weights. Must be called with + `CastedLinear` from the same module namespace as `attn` (because the rebound + forward references its `__call__`). + """ + dim = attn.c_q.weight.size(1) + q_dim = attn.num_heads * attn.head_dim + kv_dim = attn.num_kv_heads * attn.head_dim + out_dim = q_dim + 2 * kv_dim + + c_qkv = CastedLinear(dim, out_dim, bias=False) + c_qkv = c_qkv.to(attn.c_q.weight.device).to(attn.c_q.weight.dtype) + with torch.no_grad(): + c_qkv.weight.copy_(torch.cat( + [attn.c_q.weight, attn.c_k.weight, attn.c_v.weight], + dim=0, + )) + + attn.c_qkv = c_qkv + del attn.c_q, attn.c_k, attn.c_v + + +def make_fused_qkv_forward(attn_backend, apply_rotary_emb, xsa_fn=None): + """Build a replacement `forward` for a fused-QKV CausalSelfAttention. 
+ + Args: + attn_backend: callable with signature `(q, k, v, causal=...) -> y` + accepting `(B, S, H, D)` layout for q,k,v (matches FA3's API). + apply_rotary_emb: the baseline's partial-RoPE function. + xsa_fn: optional replacement for `self._xsa_efficient(y, v)`. If None + the instance's existing `_xsa_efficient` is called. + + Returns a function suitable for `types.MethodType(fn, attn_instance)`. + """ + def forward(self, x): + bsz, seqlen, _ = x.shape + qkv = self.c_qkv(x) + q_dim = self.num_heads * self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn_backend(q, k, v, causal=True) + if self.use_xsa: + y = xsa_fn(y, v) if xsa_fn is not None else self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, y.size(-2) * y.size(-1)) + return self.proj(y) + return forward + + +def patch_attn(attn, CastedLinear, attn_backend, apply_rotary_emb, xsa_fn=None): + """One-shot convenience: fuse weights and rebind forward on an attn instance.""" + fuse_qkv_weights(attn, CastedLinear) + attn.forward = types.MethodType( + make_fused_qkv_forward(attn_backend, apply_rotary_emb, xsa_fn), + attn, + ) + + +# Muon interaction (important): +# --------------------------------------------------------------------------- +# The baseline Muon optimizer runs the Newton-Schulz-5 polynomial on each 2-D +# weight matrix's gradient independently. With 3 separate linears (c_q, c_k, c_v), +# NS5 orthogonalizes three independent gradient matrices. With a fused c_qkv of +# shape (D + 2*D_kv, D), NS5 orthogonalizes the *joint* gradient matrix — a +# different spectrum in general, and therefore a different effective update. +# The forward output is bit-identical at init, but the training trajectories +# diverge. In our 200-step pilot we observed ~0.01 nats train_loss drift at step +# 200 (bf16), which is within the noise floor but IS a real effect. +# +# A correct systems-only fused-QKV would need one of: +# (a) split the c_qkv gradient back into 3 slices and NS5 each independently, +# then reassemble (matches baseline dynamics bit-for-bit); +# (b) reformulate Muon so NS5 respects a block-diagonal / Kronecker structure +# on the stacked weight; +# (c) accept the divergence and claim the fused version as a slightly different +# model, requiring a full 3-seed mean for the record threshold. +# Neither (a) nor (b) is implemented here. 
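+
+
+# ---------------------------------------------------------------------------
+# Sketch of option (a) above. NOT implemented in this PR; included only to make
+# the fix concrete. It assumes `ns5` is the baseline Muon's Newton-Schulz-5
+# orthogonalization routine (2-D gradient in, orthogonalized matrix out); the
+# q/kv slice sizes reuse the same arithmetic as `fuse_qkv_weights` above.
+def ns5_per_projection(grad_qkv: torch.Tensor, q_dim: int, kv_dim: int, ns5) -> torch.Tensor:
+    """Orthogonalize a fused c_qkv gradient block-by-block (row order: q | k | v),
+    reproducing the baseline's per-matrix NS5 dynamics on the stacked weight."""
+    g_q, g_k, g_v = grad_qkv.split([q_dim, kv_dim, kv_dim], dim=0)
+    return torch.cat([ns5(g_q), ns5(g_k), ns5(g_v)], dim=0)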
diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/submission.json b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/submission.json new file mode 100644 index 0000000000..248754e75c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/submission.json @@ -0,0 +1,36 @@ +{ + "author": "diaslmb", + "github_id": "diaslmb", + "name": "Systems-Fusion Ceiling Investigation + H-Net Milestone 1 Pilot", + "date": "2026-04-14", + "track": "non_record_16mb", + "kind": "investigation + pilot", + "headline_results": { + "part1_xsa_triton_and_fused_qkv_on_PR_1493": { + "baseline_step_avg_ms": 1020, + "qkv_fusion_step_avg_ms": 1020, + "qkv_fusion_plus_triton_xsa_step_avg_ms": 1080, + "verdict": "Neither intervention helps: Inductor already at the D=512 fusion ceiling; autograd.Function graph break costs more than the XSA kernel saves" + }, + "part2_hnet_m1_pilots_1xH100": { + "no_skip_300_steps": {"tokens": 10000000, "val_bpb": 4.49}, + "no_skip_1500_steps": {"tokens": 49000000, "val_bpb": 4.40}, + "with_skip_1500_steps": {"tokens": 49000000, "val_bpb": 3.15}, + "with_skip_4500_steps": {"tokens": 147000000, "val_bpb": 2.51}, + "params_total": 33907824, + "params_main_network": 31742040, + "byte_skip_delta_bpb": -1.25, + "data_scaling_delta_bpb": -0.64 + } + }, + "hardware": "1xH100 80GB SXM (RunPod)", + "pytorch_version": "2.8.0+cu128", + "technique_summary": "Part 1: Triton XSA fwd+bwd kernel + fused-QKV patch on bigbag PR #1493, training-pilot negative result documenting Inductor ceiling. Part 2: H-Net Milestone 1 (hierarchical byte-level stack with fixed stride-4 chunker and byte-encoder->byte-decoder skip connection) showing clean signs of life. M2-M4 proposed for OpenAI dev grant.", + "attribution": { + "systems_baseline": "@bigbag (PR #1493) - SP8192 + 3-Layer Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT", + "hnet_paper": "Hwang et al., Dynamic Chunking for End-to-End Hierarchical Sequence Modeling, arXiv:2507.07955", + "fa3_wheel": "windreamer.github.io/flash-attention3-wheels (cu128_torch280)", + "scaling_laws_context": "Kaplan et al., arXiv:2001.08361" + }, + "notes": "Non-record investigation + pilot PR. Part 1 is a documented negative result; Part 2 is signs-of-life for the H-net tokenization entry on the repo's Requests-for-PRs list. No leaderboard val_bpb claim (M1 pilot at 33.9M params is over the 16MB int6 budget; fitting the budget is M3's job)." +} diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/unpack.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/unpack.py new file mode 100644 index 0000000000..d77e3310e2 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/unpack.py @@ -0,0 +1,26 @@ +"""Unpack bigbag's LZMA-packed train_gpt.py into readable source. 
+
+Run: python unpack.py <packed_path> <out_path>
+"""
+import base64, lzma, re, sys
+
+packed_path, out_path = sys.argv[1], sys.argv[2]
+blob = open(packed_path).read()
+
+m = re.search(r'B\.b85decode\("([^"]+)"\)', blob)
+if m is None:
+    raise SystemExit(f"No base85 blob found in {packed_path}")
+b85 = m.group(1)
+
+src = lzma.decompress(
+    base64.b85decode(b85),
+    format=lzma.FORMAT_RAW,
+    filters=[{"id": lzma.FILTER_LZMA2}],
+).decode()
+
+with open(out_path, "w", encoding="utf-8") as f:
+    f.write(src)
+
+print(f"packed   : {len(blob):>8} bytes")
+print(f"unpacked : {len(src):>8} bytes -> {out_path}")
+print(f"lines    : {src.count(chr(10)) + 1}")
diff --git a/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/xsa_triton.py b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/xsa_triton.py
new file mode 100644
index 0000000000..df21dcbf35
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-14_XSA_Fusion_Investigation_HNet_M1/xsa_triton.py
@@ -0,0 +1,179 @@
+"""XSA (vector-rejection) Triton kernels for GQA attention outputs.
+
+Replaces the torch decomposition:
+    y_g = y.reshape(B, T, Hkv, group, D)
+    vn = F.normalize(v, dim=-1).unsqueeze(-2)
+    proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
+    return (y_g - proj).reshape(B, T, H, D)
+
+with a pair of Triton kernels (fwd + bwd), each launching one program per
+(B, T, Hkv) and loading v exactly once.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+
+# ---------------------------------------------------------------------------
+# Forward
+# ---------------------------------------------------------------------------
+@triton.jit
+def _xsa_fwd_kernel(
+    y_ptr, v_ptr, y_out_ptr, s_ptr,
+    sy_b, sy_t, sy_h, sy_d,
+    sv_b, sv_t, sv_h, sv_d,
+    so_b, so_t, so_h, so_d,
+    ss_b, ss_t, ss_h,
+    T, Hkv,
+    BLOCK_D: tl.constexpr,
+    GROUP: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    b = pid // (T * Hkv)
+    r = pid - b * (T * Hkv)
+    t = r // Hkv
+    hk = r - t * Hkv
+
+    d_offs = tl.arange(0, BLOCK_D)
+
+    v = tl.load(v_ptr + b * sv_b + t * sv_t + hk * sv_h + d_offs * sv_d).to(tl.float32)
+    vnorm = tl.maximum(tl.sqrt(tl.sum(v * v)), 1e-12)
+    vhat = v / vnorm
+
+    for g in tl.static_range(GROUP):
+        h = hk * GROUP + g
+        y = tl.load(y_ptr + b * sy_b + t * sy_t + h * sy_h + d_offs * sy_d).to(tl.float32)
+        s = tl.sum(y * vhat)
+        y_out = y - s * vhat
+        tl.store(y_out_ptr + b * so_b + t * so_t + h * so_h + d_offs * so_d, y_out)
+        tl.store(s_ptr + b * ss_b + t * ss_t + h * ss_h, s)
+
+
+# ---------------------------------------------------------------------------
+# Backward
+# ---------------------------------------------------------------------------
+# Let S_g = ⟨y_g, v̂⟩, U_g = ⟨grad_f_g, v̂⟩.
+# grad_y_g = grad_f_g - U_g · v̂ (same projection as forward)
+# grad_v = -( Σ_g U_g y_g + Σ_g S_g grad_f_g - 2·(Σ_g U_g S_g) · v̂ ) / ||v||
+# Note: Σ_g U_g S_g = ⟨Σ_g U_g y_g, v̂⟩ (since S_g = ⟨y_g, v̂⟩ and v̂ is a unit vector)
+@triton.jit
+def _xsa_bwd_kernel(
+    y_ptr, v_ptr, s_ptr, gf_ptr,
+    gy_ptr, gv_ptr,
+    sy_b, sy_t, sy_h, sy_d,
+    sv_b, sv_t, sv_h, sv_d,
+    ss_b, ss_t, ss_h,
+    sgf_b, sgf_t, sgf_h, sgf_d,
+    sgy_b, sgy_t, sgy_h, sgy_d,
+    sgv_b, sgv_t, sgv_h, sgv_d,
+    T, Hkv,
+    BLOCK_D: tl.constexpr,
+    GROUP: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    b = pid // (T * Hkv)
+    r = pid - b * (T * Hkv)
+    t = r // Hkv
+    hk = r - t * Hkv
+
+    d_offs = tl.arange(0, BLOCK_D)
+
+    v = tl.load(v_ptr + b * sv_b + t * sv_t + hk * sv_h + d_offs * sv_d).to(tl.float32)
+    vnorm = tl.maximum(tl.sqrt(tl.sum(v * v)), 1e-12)
+    inv_vnorm = 1.0 / vnorm
+    vhat = v * inv_vnorm
+
+    sum_Uy = tl.zeros((BLOCK_D,), dtype=tl.float32)
+    sum_SGf = tl.zeros((BLOCK_D,), dtype=tl.float32)
+
+    for g in tl.static_range(GROUP):
+        h = hk * GROUP + g
+        y = tl.load(y_ptr + b * sy_b + t * sy_t + h * sy_h + d_offs * sy_d).to(tl.float32)
+        gf = tl.load(gf_ptr + b * sgf_b + t * sgf_t + h * sgf_h + d_offs * sgf_d).to(tl.float32)
+        s_g = tl.load(s_ptr + b * ss_b + t * ss_t + h * ss_h).to(tl.float32)
+
+        U = tl.sum(gf * vhat)
+        grad_y_vec = gf - U * vhat
+        tl.store(gy_ptr + b * sgy_b + t * sgy_t + h * sgy_h + d_offs * sgy_d, grad_y_vec)
+
+        sum_Uy = sum_Uy + U * y
+        sum_SGf = sum_SGf + s_g * gf
+
+    # Σ_g U_g S_g = ⟨Σ_g U_g y_g, v̂⟩ (using v̂ unit)
+    sum_US = tl.sum(sum_Uy * vhat)
+    grad_v = -(sum_Uy + sum_SGf - 2.0 * sum_US * vhat) * inv_vnorm
+    tl.store(gv_ptr + b * sgv_b + t * sgv_t + hk * sgv_h + d_offs * sgv_d, grad_v)
+
+
+# ---------------------------------------------------------------------------
+# autograd wrapper
+# ---------------------------------------------------------------------------
+class XSAFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        assert y.dim() == 4 and v.dim() == 4
+        B, T, H, D = y.shape
+        Bv, Tv, Hkv, Dv = v.shape
+        assert B == Bv and T == Tv and D == Dv, f"shape mismatch y={y.shape} v={v.shape}"
+        assert H % Hkv == 0, f"H={H} must be divisible by Hkv={Hkv}"
+        G = H // Hkv
+        assert D in (16, 32, 64, 128, 256), f"D={D} not supported"
+
+        y_out = torch.empty_like(y)
+        s = torch.empty(B, T, H, device=y.device, dtype=torch.float32)
+
+        grid = (B * T * Hkv,)
+        _xsa_fwd_kernel[grid](
+            y, v, y_out, s,
+            y.stride(0), y.stride(1), y.stride(2), y.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            y_out.stride(0), y_out.stride(1), y_out.stride(2), y_out.stride(3),
+            s.stride(0), s.stride(1), s.stride(2),
+            T, Hkv,
+            BLOCK_D=D, GROUP=G,
+        )
+        ctx.save_for_backward(y, v, s)
+        return y_out
+
+    @staticmethod
+    def backward(ctx, grad_f: torch.Tensor):
+        y, v, s = ctx.saved_tensors
+        B, T, H, D = y.shape
+        Hkv = v.size(-2)
+        G = H // Hkv
+
+        grad_y = torch.empty_like(y)
+        grad_v = torch.empty_like(v)
+
+        grid = (B * T * Hkv,)
+        _xsa_bwd_kernel[grid](
+            y, v, s, grad_f,
+            grad_y, grad_v,
+            y.stride(0), y.stride(1), y.stride(2), y.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            s.stride(0), s.stride(1), s.stride(2),
+            grad_f.stride(0), grad_f.stride(1), grad_f.stride(2), grad_f.stride(3),
+            grad_y.stride(0), grad_y.stride(1), grad_y.stride(2), grad_y.stride(3),
+            grad_v.stride(0), grad_v.stride(1), grad_v.stride(2), grad_v.stride(3),
+            T, Hkv,
+            BLOCK_D=D, GROUP=G,
+        )
+        return grad_y, grad_v
+
+
+def xsa_triton(y: 
torch.Tensor, v: torch.Tensor) -> torch.Tensor: + return XSAFunction.apply(y, v) + + +def xsa_torch(y: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Reference implementation matching bigbag's `_xsa_efficient`.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D)
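+
+
+# ---------------------------------------------------------------------------
+# Optional parity check (not part of the pilot harness). A minimal sketch of
+# how the README's §1.2 kernel-numerics table can be spot-checked on one GPU;
+# the shapes and the fp32 dtype here are illustrative, not the exact configs
+# behind the reported numbers.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    B, T, H, Hkv, D = 2, 256, 8, 4, 64
+    y = torch.randn(B, T, H, D, device="cuda", dtype=torch.float32, requires_grad=True)
+    v = torch.randn(B, T, Hkv, D, device="cuda", dtype=torch.float32, requires_grad=True)
+
+    out_ref, out_tri = xsa_torch(y, v), xsa_triton(y, v)
+    print("fwd max err   :", (out_ref - out_tri).abs().max().item())
+
+    g = torch.randn_like(out_ref)
+    gy_ref, gv_ref = torch.autograd.grad(out_ref, (y, v), g)
+    gy_tri, gv_tri = torch.autograd.grad(out_tri, (y, v), g)
+    print("grad_y max err:", (gy_ref - gy_tri).abs().max().item())
+    print("grad_v max err:", (gv_ref - gv_tri).abs().max().item())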