
Newton-Muon × PR #1874's document-packed loader: a controlled negative result #1907

Open

GodlyDonuts wants to merge 2 commits into openai:main from GodlyDonuts:nm-doc-packing-negative-result

Conversation

@GodlyDonuts

Newton-Muon × PR #1874's Document-Packed Loader: A Controlled Negative Result

When PR #1874 landed and started showing val_bpb in the 1.067 band, follow-up leaderboard discussions (including threads under PR #1900) pointed at Newton–Schulz-style residual orthogonalization in the optimizer — "Newton-Muon" — as a plausible next bump on top of the Polar Express NS already used in #1874. We spent a weekend trying it. This PR documents what we found, with full logs and trained artifacts, in case it saves others the same compute.

Headline

In a strict same-seed (42), same-data, same-batch-size, same-step-count A/B against the PR #1874 stack, enabling Newton-Muon regressed quantized + TTT-phased val_bpb from 1.06928 → 1.10705 (a +0.0378 bpb hit). This is not noise or an artifact of missing hyperparameter sweeps; it is a consistent, reproducible regression. Both runs used the same train_gpt.py included here; the only delta was the environment variable NEWTON_MUON_ENABLED.
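For clarity on the control: the toggle is a single environment variable read once at startup. A minimal sketch of the gating (illustrative, not copied from train_gpt.py):

```python
import os

# The single A/B delta between the two runs; everything else -- script,
# data, seed, batch size, step budget -- is bit-identical.
NEWTON_MUON_ENABLED = os.environ.get("NEWTON_MUON_ENABLED", "0") == "1"
```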

Why it Regresses (The Root Cause)

The Newton-Muon implementation we evaluated installs forward-pre-hooks on every Linear module to track a per-module integer counter _nm_K_count, triggering a Newton–Schulz preconditioning step every $K$-th forward pass. While this works in a vanilla static-shape training loop, it fails on the PR #1874 stack for the following reasons (a sketch of the hook pattern follows the list):

  1. Recompile Limit: Within roughly ten training steps, every transformer block hits torch._dynamo.config.recompile_limit=16.
  2. Static Specialization: torch._dynamo treats integer attributes on nn.Module as static values. Every step, the hook increments module._nm_K_count. Dynamo sees this as a new value and emits a new specialization for every single step.
  3. Shape Conflict: Because the loader is document-packed, each step also brings a fresh cu_seqlens shape. These two specialization axes together saturate the recompile budget almost immediately.
  4. Performance Collapse: Once the limit is hit, Dynamo silently falls back to eager mode. fullgraph=True is violated, and FlashAttention 3's fused paths stop emitting cleanly. Step time inflates roughly 2.4×, meaning fewer training steps fit inside the 600s budget.
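For concreteness, here is a minimal reconstruction of the hook pattern being described. The counter name `_nm_K_count` comes from the logs; `K` and the Newton–Schulz body are illustrative stand-ins, not the evaluated code:

```python
import torch
import torch.nn as nn

K = 8  # illustrative trigger period; the evaluated value is not part of this PR

@torch.no_grad()
def _newton_schulz_step(module: nn.Linear) -> None:
    # Illustrative stand-in: one cubic Newton-Schulz iteration nudging the
    # weight toward an orthogonal factor. The evaluated implementation's
    # actual preconditioning body is not reproduced here.
    W = module.weight
    X = W / (W.norm() + 1e-7)
    W.copy_(1.5 * X - 0.5 * (X @ X.mT @ X))

def _nm_pre_hook(module: nn.Module, args):
    # A plain Python int mutated on the module inside the compiled region:
    # dynamo specializes on its value, so every increment invalidates guards
    # and soon emits a fresh graph.
    module._nm_K_count += 1
    if module._nm_K_count % K == 0:
        _newton_schulz_step(module)

def install_nm_hooks(model: nn.Module) -> None:
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m._nm_K_count = 0
            m.register_forward_pre_hook(_nm_pre_hook)
```

Under fullgraph compilation, each increment of `_nm_K_count` is a new specialization, and the document-packed cu_seqlens adds the second dynamic axis — exactly the saturation described in points 1–3 above.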

The PyTorch hint (torch._dynamo.config.allow_unspec_int_on_nn_module = True) suppresses the log warnings but does not fix the underlying fragmentation. Dynamo still has to reconcile the mutating integer against the variable cu_seqlens, and the FA3 kernels do not survive the transition.
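To observe the fragmentation directly rather than infer it from step times, the stock dynamo logging knobs are enough. A minimal sketch using plain PyTorch 2.x APIs (nothing here is specific to this PR's code):

```python
import torch
from torch._dynamo.utils import counters

# Emit a line per recompile (with the failing guard) and per graph break.
torch._logging.set_logs(recompiles=True, graph_breaks=True)

# The hint discussed above: silences the integer-attribute guard warnings
# but does not stop the respecialization itself.
torch._dynamo.config.allow_unspec_int_on_nn_module = True

# ... run a handful of training steps, then inspect the aggregate stats;
# a saturated recompile budget shows up as many unique graphs per frame.
print(dict(counters["stats"]))
```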

Structural Fixes

To resolve this, the $K$-counter and preconditioning trigger must be moved out of the compiled region; a sketch follows the list:

  • Drive Newton–Schulz from optimizer.step() rather than a forward hook.
  • Convert _nm_K_count to a torch.Tensor scalar held outside nn.Module.__dict__.
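A minimal sketch of that restructuring, reusing the illustrative `_newton_schulz_step` from the earlier snippet. The class and method names here are ours, not a reference implementation:

```python
import torch
import torch.nn as nn

class NewtonMuonStepper:
    """Holds the K-counter as a tensor scalar outside any nn.Module, so the
    compiled forward never observes (or guards on) the mutation."""

    def __init__(self, model: nn.Module, k: int = 8):
        self.linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
        self.k = k
        self.count = torch.zeros((), dtype=torch.int64)  # tensor, not a Python int

    @torch.no_grad()
    def maybe_precondition(self) -> None:
        # Call this right after optimizer.step() in the (eager) training
        # loop, entirely outside the fullgraph-compiled region.
        self.count += 1
        if int(self.count) % self.k == 0:
            for m in self.linears:
                _newton_schulz_step(m)  # illustrative NS body from above
```

With the counter out of the traced region, cu_seqlens should be the only remaining dynamic axis, removing the second specialization axis that saturates the recompile budget.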

What is in the Folder

  • train_gpt.py — the exact training script used for both runs
  • train_nm_default.log, train_nm_smoke.log — Newton-Muon diagnostic logs
  • train_baseline_seed42.log — same-seed PR #1874 baseline log
  • models/nm_default.int6.ptz, models/nm_smoke.int6.ptz, models/baseline_pr1874_seed42.int6.ptz — trained int6 artifacts (~46 MB total; per-run details in the commit messages below)
  • README with a verified CPU-only inspection snippet

Conclusion

This is not a bug in any individual component, but rather a predictable interaction between hook-based integer state, document-packed variable shapes, and fullgraph compilation. We are filing this as a non-record submission to ensure the community doesn't "burn" compute on this specific implementation path.

Compute spent: ~$12 on 8×H100 SXM (RunPod).
Submitted by: [Saicharan Ramineni (GodlyDonuts)](https://github.com/GodlyDonuts)

… bpb)

Grafting hook-based Newton-Schulz residual orthogonalization onto PR openai#1874's
full stack regresses val_bpb by +0.0378 bpb in a controlled same-seed A/B.

Root cause is dynamo recompile fragmentation: the per-module integer counter
`_nm_K_count` is mutated inside a forward-pre-hook, dynamo treats integer
attributes on nn.Module as static, every transformer block hits
config.recompile_limit=16 within ~10 steps, fullgraph compilation silently
breaks, FA3 fused kernels stop emitting cleanly, and step time inflates ~2.4x.

Filed with full diagnostic logs (train_nm_default.log, train_nm_smoke.log)
and a same-seed PR openai#1874 baseline (train_baseline_seed42.log) for direct
comparison, so other competitors don't burn the same compute.

Author: Saicharan Ramineni <[email protected]>
…he README

- models/nm_default.int6.ptz             — Newton-Muon enabled, full 600s, seed=42 (val_bpb 1.10705)
- models/nm_smoke.int6.ptz               — Newton-Muon enabled, 180s smoke run
- models/baseline_pr1874_seed42.int6.ptz — PR openai#1874 baseline, NM disabled, seed=42 (val_bpb 1.06928)

This lets a reviewer inspect the actual trained quantized weights without
having to retrain. Total ~46 MB of binary artifacts. README updated with a
verified CPU-only inspection snippet (brotli decompress -> _byte_unshuffle ->
torch.load) demonstrating the artifacts are well-formed int6 GPTQ state dicts.
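For reference, a hedged reconstruction of that inspection flow; the `_byte_unshuffle` import path and signature are assumptions here, not verified against the README:

```python
import io
import brotli  # pip install brotli
import torch

from train_gpt import _byte_unshuffle  # assumed location; adjust to the repo

with open("models/baseline_pr1874_seed42.int6.ptz", "rb") as f:
    payload = _byte_unshuffle(brotli.decompress(f.read()))  # assumed bytes -> bytes

state = torch.load(io.BytesIO(payload), map_location="cpu")
print(sorted(state.keys())[:5])  # spot-check the int6 GPTQ state dict
```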

Author: Saicharan Ramineni <[email protected]>
