
(Nonrecord) Applied Async Prefetching Potentially Boosts Performance#785

Open
SirSaltySalmon wants to merge 7 commits into openai:main from SirSaltySalmon:async_prefetching_on_leakyrelu

Conversation


@SirSaltySalmon SirSaltySalmon commented Mar 26, 2026

LeakyReLU^2 + Legal TTT + Parallel Muon + systems: prefetch & fusion-friendly MLP

Reference baseline: 2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md

Outcome

This variant improves throughput slightly, but does not improve quality versus the original 3-seed 8xH100 runs.

  • Mean steps in 600s: 7184.7 -> 7191.3 (+6.7 steps)
  • Mean step_avg: 83.53ms -> 83.44ms (faster)
  • Mean pre-TTT val_bpb (final_int6_sliding_window_exact): 1.12184 -> 1.12334 (worse by +0.00151)
  • Mean post-TTT val_bpb (legal_ttt_exact): 1.11938 -> 1.12096 (worse by +0.00158)

3-seed comparison (8xH100, 600s train budget)

| Seed | Baseline steps / post-TTT bpb | This run steps / post-TTT bpb | Delta |
| --- | --- | --- | --- |
| 42 | 7182 / 1.12002032 | 7189 / 1.12119101 | +7 steps, +0.00117069 bpb |
| 1337 | 7179 / 1.11922988 | 7191 / 1.12088391 | +12 steps, +0.00165403 bpb |
| 2025 | 7193 / 1.11888882 | 7194 / 1.12081146 | +1 step, +0.00192264 bpb |
| Mean | 7184.7 / 1.11937967 | 7191.3 / 1.12096213 | +6.7 steps, +0.00158245 bpb |

1xH100 ablation (Modal sanity check, 600s train budget)

| Configuration | Steps / ms per step | Post-TTT bpb | Delta vs base |
| --- | --- | --- | --- |
| Base record train_gpt | 924 / 649.71ms | 1.55027402 | - |
| + prefetch only | 942 / 637.55ms | 1.53744178 | +18 steps, -0.01283224 bpb |
| + prefetch + MLP fusion form | 943 / 636.73ms | 1.53642888 | +19 steps, -0.01384514 bpb |

Interpretation

The data is consistent across all three seeds: the systems changes increase training throughput, but that throughput gain does not translate into better final validation quality in this setup.

So the result here is best described as a speed optimization with neutral-to-slightly-negative quality impact relative to the original record recipe. The small quality regression most likely reflects noise, since the training math and process are exactly the same.

On 1xH100, the same systems changes looked clearly positive (more steps and better post-TTT bpb), while on 8xH100 they remain speed-positive but quality-negative. The practical interpretation is that prefetch/fusion behavior does not transfer linearly from single-GPU to multi-GPU quality outcomes and should be treated as a throughput optimization first. Most likely, I/O is no longer the bottleneck at large scale, and inter-GPU communication becomes the dominant cost instead.

I will continue iterating on this, as the increased training speed shows promise. This attempt tries to show that async prefetching and memory pinning can improve the throughput of most approaches, but more experimentation is needed to investigate compatibility with other methods. Next, I aim to improve the optimization's compatibility with parallel GPUs.

What changed vs. base record

All differences are in data loading and the MLP forward; model architecture, banking, Parallel Muon, FlashAttention-3, torch.compile usage, TTT protocol, and env-driven hyperparameters are otherwise aligned with the base PR.

1. Pinned async prefetch (PrefetchingDistributedTokenLoader)

  • Imports: queue, threading.
  • Hyperparameters (env):
    • TRAIN_PREFETCH (default 1)
    • TRAIN_PREFETCH_QUEUE (default 2)
    • TRAIN_COPY_STREAM (default 1) — when enabled with prefetch, H2D uses a dedicated torch.cuda.Stream and the default stream waits on it.
  • Helpers: _cpu_batch_from_stream, _h2d_int64_batches.
  • Loader: a daemon thread builds the next (x, y) on CPU, makes it contiguous().pin_memory(), and pushes it onto a bounded queue.Queue; next_batch dequeues and copies to device.
  • Training loop: make_train_loader() factory; after an optimizer state rewind (e.g. the SWA branch), the existing prefetch thread is shut down via shutdown() before a fresh loader is created, so the token stream does not advance in the background.
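The loader steps above can be sketched as follows. This is a CPU-only, stdlib-only illustration of the bounded-queue prefetch pattern, not the PR's actual `PrefetchingDistributedTokenLoader`: the class name here is a stand-in, batches are plain Python objects, and the `pin_memory()` / H2D-copy steps are only noted in comments.

```python
import queue
import threading

class PrefetchingLoader:
    """Toy stand-in for a pinned async prefetch loader (no CUDA)."""
    _SENTINEL = object()

    def __init__(self, stream, queue_depth=2):
        self._stream = iter(stream)                 # source token stream
        self._q = queue.Queue(maxsize=queue_depth)  # bounded: producer blocks when full
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._producer, daemon=True)
        self._thread.start()

    def _producer(self):
        # Real loader: build (x, y) on CPU, then .contiguous().pin_memory()
        # before enqueueing so the later H2D copy can be asynchronous.
        for batch in self._stream:
            if self._stop.is_set():
                break
            self._q.put(batch)
        self._q.put(self._SENTINEL)

    def next_batch(self):
        # Real loader: dequeue, then issue the H2D copy, optionally on a
        # dedicated copy stream that the default stream waits on.
        item = self._q.get()
        if item is self._SENTINEL:
            raise StopIteration
        return item

    def shutdown(self):
        # Stop the producer so the token stream does not advance in the
        # background (needed before a rewind creates a fresh loader).
        self._stop.set()
        while True:
            try:
                self._q.get_nowait()
            except queue.Empty:
                break
        self._thread.join(timeout=1.0)
```

The bounded queue is what keeps the prefetch thread at most `queue_depth` batches ahead, and `shutdown()` is what the training loop calls before rebuilding the loader after an optimizer rewind.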

2. Fusion-friendly LeakyReLU² MLP

Base:

```python
x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5)
return F.linear(x.square(), down_w.to(x.dtype))
```

This submission:

```python
x_dtype = x.dtype
up_w = up_w.to(dtype=x_dtype)
down_w = down_w.to(dtype=x_dtype)
h = F.leaky_relu(F.linear(x, up_w), negative_slope=0.5)
return F.linear(h * h, down_w)
```

Mathematically identical to LeakyReLU(0.5)² feeding the down projection; the change only provides layout / fusion hints for the compiled training graph, so Inductor can fuse or simplify more than before.
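As a sanity check on the "mathematically identical" claim, the two forms can be compared in a pure-Python toy model. This is an illustration only, not the PyTorch code: `linear` and `leaky_relu` here are hand-rolled list-based stand-ins for `F.linear` / `F.leaky_relu`, and `z * z` stands in for both `.square()` and `h * h`.

```python
def linear(x, w):
    # x: input vector; w: weight matrix as list of rows (out x in)
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]

def leaky_relu(v, slope=0.5):
    return [z if z >= 0 else slope * z for z in v]

def mlp_base(x, up_w, down_w):
    # Base form: LeakyReLU(0.5)(x @ up_w), then .square() analogue
    h = leaky_relu(linear(x, up_w))
    return linear([z * z for z in h], down_w)

def mlp_rewritten(x, up_w, down_w):
    # Submission form: bind h once, then h * h analogue
    h = leaky_relu(linear(x, up_w))
    hh = [a * b for a, b in zip(h, h)]
    return linear(hh, down_w)
```

For any input, both functions return bit-identical results, since `z * z` and `a * b` with `a = b = z` are the same floating-point operation; the rewrite changes only how the computation is expressed to the compiler.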

ENV

Same as the base run command, with optional prefetch toggles (defaults match optimized script):

```shell
TRAIN_PREFETCH=1 TRAIN_PREFETCH_QUEUE=2 TRAIN_COPY_STREAM=1 \
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \
VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \
TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
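For reference, the three prefetch toggles could be read inside the training script roughly like this. The helper name `env_int` and the exact parsing are assumptions; only the env-var names and their defaults come from the PR description above.

```python
import os

def env_int(name, default):
    # Hypothetical helper: read an integer flag from the environment,
    # falling back to the documented default when the var is unset.
    return int(os.environ.get(name, default))

TRAIN_PREFETCH = env_int("TRAIN_PREFETCH", 1)              # 1 = async prefetch on
TRAIN_PREFETCH_QUEUE = env_int("TRAIN_PREFETCH_QUEUE", 2)  # bounded queue depth
TRAIN_COPY_STREAM = env_int("TRAIN_COPY_STREAM", 1)        # 1 = dedicated H2D copy stream
```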

Credits

@SirSaltySalmon SirSaltySalmon changed the title Applied Async Prefetching Boost Performance of Any Approach (Nonrecord) Applied Async Prefetching Potentially Boosts Performance Mar 26, 2026
@SirSaltySalmon SirSaltySalmon marked this pull request as ready for review March 26, 2026 19:05

Community Review — (Nonrecord) Applied Async Prefetching Potentially Boosts Performance

BPB: (not parsed — see PR title) | Compliance: LOOKS CLEAN — score-first-per-chunk TTT (legal #1416/#1423 pattern)

What I found in the code (head SHA d3fc20b82ccd, file records/track_non_record_16mb/2026-03-26_LeakyReLU_LegalTTT_ParallelMuon_PinnedPrefetch_FusionFriendlyMLP/train_gpt.py):

The TTT path at line 1207 implements the score-first-per-chunk pattern: each chunk is scored under torch.no_grad() / inference_mode() before the base_model.train() + SGD adaptation runs on that same chunk, with an is_last_chunk guard so the final chunk gets no adaptation pass. This is the structural shape the legal frontier uses (PRs #1416 erichroepke, #1423 aryanbhosale).

Per Issue #402 and Issue #677, TTT is legal when each token is scored before the adapter updates on it, and that's what the code does here — chunk ci is scored under weights adapted only on chunks 0..ci-1. No prequant_ttt_adapt_adamw(val_tokens, ...) multi-epoch fine-tune, no scored-region SLOT, no target-in-key n-gram cache.
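The score-first-per-chunk shape described above can be sketched with a toy scalar model. This is an illustration of the control flow only, not the PR's TTT code: the "model" is a single weight, the loss is a hypothetical squared error, and the real code would score under `torch.no_grad()` before an SGD adaptation step.

```python
def ttt_eval(chunks, weight, lr=0.1):
    """Score chunk i under weights adapted only on chunks 0..i-1."""
    scores = []
    for ci, chunk in enumerate(chunks):
        # 1) Score this chunk with the current, not-yet-adapted weight
        #    (real code: no_grad / inference_mode forward pass).
        scores.append(sum((weight - tok) ** 2 for tok in chunk))
        # 2) Adapt on the same chunk, unless it is the last one
        #    (the is_last_chunk guard: the final chunk gets no update).
        is_last_chunk = ci == len(chunks) - 1
        if not is_last_chunk:
            grad = sum(2 * (weight - tok) for tok in chunk)
            weight -= lr * grad
    return scores, weight
```

The key legality property is visible in the loop order: the score for chunk `ci` is recorded before any update uses `ci`'s tokens, so no token is ever scored by weights that have already seen it.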

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.04s, dim=512, layers=11, vocab=1024, code=95277 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending standard checks (3-seed validation, 16MB artifact cap, 10-min wallclock on 8×H100 SXM). The compliance picture matches the legal reference frontier and no flags were raised by the classification pass.

Auto-classification caveat: this review was drafted by the AST-based classifier against a template derived from manually-reviewed cluster PRs (#1420, #1450, #1487, #1541, #1529, #1533, #1518). If I've misread a subtlety in your eval path — e.g., multi-epoch TTT that I mistook for single-pass, or a target-in-key lookup I missed in a helper function — please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka / The Agora. Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting — if the template misread your code, please call it out so I can iterate the classifier.
