diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/README.md b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/README.md new file mode 100644 index 0000000000..4654f263d6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/README.md @@ -0,0 +1,90 @@ +# Record: Improved Parallel Residuals + Systems Optimization + +**val_bpb = 1.0752** (3-seed mean, std 0.0006) | **2.7773 nats** | **~15.98 MB** | 8xH100 SXM, 600s | Legal TTT + +This submission applies three systems-level performance optimizations to PR #1529's dual-lane parallel residual architecture. The ML is unchanged; faster per-step throughput yields ~20 extra training steps in the same 600s budget. + +> **Submission series:** This PR is one of three related submissions applying the same systems optimizations to different base stacks: +> +> 1. On PR #1493 (current merged SOTA) +> 2. On PR #1529 (pending review) -- **this PR** +> 3. On PR #1578 (pending review) +> +> The optimizations are identical across all three -- fused Muon kernel, batched EMA, and loader prealloc. We submit against multiple bases so that a ready-to-merge option exists regardless of how the pending PRs are resolved. Judges should feel free to evaluate whichever base(s) they consider valid and disregard the rest. + +**Note on record criteria:** This submission improves speed through systems optimization without changing the ML. Per the official contest rules: *"For submissions that improve speed through systems optimization without changing the ML, this requirement [0.005 nats] is waived."* The three changes (fused Muon kernel, batched EMA, loader prealloc) are purely systems-level and do not alter model architecture, optimizer logic, loss function, or any hyperparameter. + +## 3-Seed Results + +| Seed | Steps | ms/step | Post-EMA BPB | Sliding BPB | **TTT BPB** | Artifact | +|------|-------|---------|-------------|-------------|-------------|----------| +| 1337 | 4,745 | 123.8 | 1.0823 | 1.0756 | **1.0745** | 15,983,819 | +| 2024 | 4,724 | 124.3 | 1.0833 | 1.0769 | **1.0755** | 15,982,374 | +| 42 | 4,744 | 123.8 | 1.0832 | 1.0773 | **1.0755** | 15,979,637 | +| **Mean** | **4,738** | **123.9** | **1.0829** | **1.0766** | **1.0752** | **15,981,943** | +| **Std** | | | | | **0.0006** | | + +PR #1529 original (same seeds): **1.0753 BPB mean**. Delta: **-0.0001 BPB** (from extra training steps). + +## Systems Optimizations (3 changes, training-step only) + +1. **Fused Muon transform** -- Single `@torch.compile` function combining momentum update, Nesterov extrapolation, row normalization, and Newton-Schulz orthogonalization. Eliminates kernel launch overhead between sequential operations. (+0.43% step time on 2xH100 benchmark) + +2. **EMA foreach** -- Replaces per-tensor EMA loop with `torch._foreach_mul_` / `torch._foreach_add_` for batched parameter averaging. (+0.08% step time) + +3. **Numpy prealloc loader** -- Pre-allocates a reusable numpy buffer for data loading instead of allocating a new `np.array` per sequence. (+0.11% step time) + +No eval, serialization, or model architecture changes. The three optimizations together save ~0.5% step time, translating to ~20 extra steps over 600s. + +## Architecture (from PR #1529) + +11L x 512d x 8H / 4KV, MLP 4x, LeakyReLU(0.5)^2, Partial RoPE (16/64), layerwise LN scale, tied embeddings, logit softcap=30.0. Depth recurrence: loops layers 3-5 (activated at frac=0.35). Dual-lane parallel residuals from physical layer 8: attention and MLP write to both lanes with learned post-lambdas and residual-lambdas. Final output: mean of two lanes. Skip connections: lane0 only. + +Fused Triton TMA MLP kernel + CUTLASS EVT backward for throughput. + +## Training + +Muon optimizer (sharded reduce-scatter + all-gather, Newton-Schulz 5 steps), AdamW for embeddings/scalars. ~4,738 steps in 587s. Warmdown frac=0.667, Muon momentum=0.97, EMA decay=0.9965. GPTQ reserve 13s. + +## Quantization + +Full-Hessian GPTQ with SDClip: int6 for attention/MLP matrices (k=12.85), int8 for token embeddings (k=20.0). Byte-shuffle + Brotli-11 compression. + +## TTT (Test-Time Training) + +Score-first chunk-based SGD: 32K-token chunks, 3 epochs per chunk, cosine LR decay (lr=0.01, momentum=0.9). Hash embedding (16384-dim bigram hash, zero-initialized, learned during TTT). Gradient clipping at 1.0. + +## Compliance + +- **Condition 1 (Causality):** Sliding-window eval is strictly causal. +- **Condition 2 (Normalized distribution):** Standard softmax over full vocab. +- **Condition 3 (Score before update):** Each chunk scored under `torch.no_grad()` before SGD. +- **Condition 4 (Single pass):** Each token scored exactly once. +- No SLOT, no pre-quant TTT, no ETLB, no n-gram cache. + +## Reproducibility + +```bash +pip install brotli sentencepiece flash_attn_3 huggingface_hub +# CUTLASS EVT build (required for full throughput): +git clone https://github.com/NVIDIA/cutlass.git /opt/cutlass +cd /opt/cutlass && git checkout 08185b9c3e90510ee2b656662ed0d53b06d28157 +cd /workspace && pip install --no-build-isolation ./cutlass_evt_fusion + +# Data: +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192 + +# Training (per seed): +for SEED in 1337 2024 42; do + SEED=$SEED TTT_ENABLED=1 HASH_EMBED_ENABLED=1 TTT_LR=0.01 \ + MUON_MOMENTUM=0.97 PARALLEL_RESIDUAL_START=8 GPTQ_RESERVE_SECONDS=13 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +done +``` + +## Attribution + +- **PR #1529** (@msisovic): Dual-lane parallel residual architecture, Triton fused MLP, CUTLASS EVT +- **PR #1394** (@clarkkev): SP8192 tokenizer, GPTQ SDClip, depth recurrence base +- **PR #1413** (@dexhunter): Legal TTT framework +- **PR #1445** (@X-Abhishek-X): Hyperparameter tuning diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/__init__.py b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/__init__.py new file mode 100644 index 0000000000..ce6fc83058 --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from pathlib import Path +import torch + + +def _load_extension() -> None: + here = Path(__file__).resolve().parent + candidates = sorted(here.glob("cutlass_evt_fusion*.so")) + if not candidates: + raise ImportError(f"No compiled cutlass_evt_fusion extension found in {here}") + torch.ops.load_library(str(candidates[0])) + + +_load_extension() diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/gemm_act_grad.cu b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/gemm_act_grad.cu new file mode 100644 index 0000000000..e14fadac10 --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/gemm_act_grad.cu @@ -0,0 +1,164 @@ +// CUTLASS 3.x EVT kernel: fused GEMM * elementwise multiply +// Computes: dpre = (go @ down_w.T) * act_grad +// Where act_grad = f'(pre) is pre-computed in the forward pass. +// +// Layout convention: +// go: (M, K) bf16 row-major +// down_w: (K, N) bf16 row-major -- CUTLASS B(N,K) with RowMajor layout +// act_grad: (M, N) bf16 row-major +// dpre: (M, N) bf16 row-major output + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cute/tensor.hpp" +#include "cutlass/util/packed_stride.hpp" +#include + +using namespace cute; + +using ElementAcc = float; +using ElementCompute = float; +using ElementOutput = cutlass::bfloat16_t; +using ElementAux = cutlass::bfloat16_t; + +using namespace cutlass::epilogue::fusion; + +using TileShape = Shape<_128, _256, _64>; +using ClusterShape = Shape<_1, _1, _1>; +using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + +using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>; + +using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor< + EpiDesc, cutlass::layout::RowMajor, ElementAux>; + +using AuxLoad = Sm90AuxLoad< + AuxDesc::Stages, + typename EpiDesc::EpilogueTile, + typename AuxDesc::Element, + typename AuxDesc::Stride, + typename AuxDesc::SmemLayoutAtom, + typename AuxDesc::CopyOpS2R>; + +using Compute = Sm90Compute< + cutlass::multiplies, + ElementOutput, + ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest>; + +using EVT = Sm90EVT; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTile, + ElementAcc, ElementCompute, + ElementOutput, cutlass::layout::RowMajor, 8, + ElementOutput, cutlass::layout::RowMajor, 8, + EpilogueSchedule, + EVT +>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + ElementOutput, cutlass::layout::RowMajor, 8, + ElementOutput, cutlass::layout::RowMajor, 8, + ElementAcc, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + sizeof(typename CollectiveEpilogue::SharedStorage)>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative +>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + +using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + +void launch_gemm_mul( + void const* ptr_go, + void const* ptr_down_w, + void const* ptr_act_grad, + void* ptr_dpre, + int M, int N, int K, + cudaStream_t stream) +{ + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + + int L = 1; + auto prob_shape = make_shape(M, N, K, L); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_Aux = cutlass::make_cute_packed_stride( + typename AuxDesc::Stride{}, cute::make_shape(M, N, L)); + + typename EVT::Arguments evt_args{ + {}, + { + static_cast(ptr_act_grad), + ElementAux(0), + stride_Aux + }, + {} + }; + + typename GemmOp::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, + { + static_cast(ptr_go), + stride_A, + static_cast(ptr_down_w), + stride_B, + }, + { + evt_args, + static_cast(ptr_dpre), + stride_C, + static_cast(ptr_dpre), + stride_C, + } + }; + + GemmOp gemm_op; + size_t workspace_size = GemmOp::get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + cudaMalloc(&workspace, workspace_size); + } + + auto status = gemm_op.initialize(args, workspace, stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl; + if (workspace) cudaFree(workspace); + exit(EXIT_FAILURE); + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + cudaError_t cuda_err = cudaStreamSynchronize(stream); + std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status) + << " CUDA: " << cudaGetErrorString(cuda_err) << std::endl; + if (workspace) cudaFree(workspace); + exit(EXIT_FAILURE); + } + + if (workspace) cudaFree(workspace); +} diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/torch_binding.cpp b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/torch_binding.cpp new file mode 100644 index 0000000000..c128100f3e --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/csrc/torch_binding.cpp @@ -0,0 +1,46 @@ +// PyTorch C++ extension: CUTLASS EVT fused GEMM * elementwise multiply +// dpre = (go @ down_w.T) * act_grad +// Pass down_w directly (K, N) -- NOT down_w.T.contiguous() + +#include +#include + +void launch_gemm_mul( + void const*, void const*, void const*, void*, int, int, int, cudaStream_t); + +at::Tensor gemm_mul(at::Tensor go, at::Tensor down_w, at::Tensor act_grad) { + TORCH_CHECK(go.is_cuda() && go.is_contiguous()); + TORCH_CHECK(down_w.is_cuda() && down_w.is_contiguous()); + TORCH_CHECK(act_grad.is_cuda() && act_grad.is_contiguous()); + TORCH_CHECK(go.scalar_type() == at::kBFloat16); + TORCH_CHECK(down_w.scalar_type() == at::kBFloat16); + TORCH_CHECK(act_grad.scalar_type() == at::kBFloat16); + + int M = go.size(0); + int K = go.size(1); + int N = down_w.size(1); + + TORCH_CHECK(down_w.size(0) == K, + "K mismatch: go has K=", K, " but down_w has size(0)=", down_w.size(0)); + TORCH_CHECK(act_grad.size(0) == M && act_grad.size(1) == N, + "act_grad shape must be (M, N), got (", act_grad.size(0), ", ", act_grad.size(1), ")"); + + at::Tensor dpre = at::empty({M, N}, go.options()); + + launch_gemm_mul( + go.data_ptr(), down_w.data_ptr(), act_grad.data_ptr(), dpre.data_ptr(), + M, N, K, + at::cuda::getCurrentCUDAStream()); + + return dpre; +} + +TORCH_LIBRARY(cutlass_evt, m) { + m.def("gemm_mul(Tensor go, Tensor down_w, Tensor act_grad) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(cutlass_evt, CUDA, m) { + m.impl("gemm_mul", &gemm_mul); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/cutlass_evt_fusion.cpython-312-x86_64-linux-gnu.so b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/cutlass_evt_fusion.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000..dc43686891 Binary files /dev/null and b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/cutlass_evt_fusion.cpython-312-x86_64-linux-gnu.so differ diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/setup.py b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/setup.py new file mode 100644 index 0000000000..ec282243bd --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/cutlass_evt_fusion/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + +CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/opt/cutlass") + +setup( + name="cutlass_evt_fusion", + ext_modules=[ + CUDAExtension( + name="cutlass_evt_fusion", + sources=[ + "csrc/gemm_act_grad.cu", + "csrc/torch_binding.cpp", + ], + include_dirs=[ + f"{CUTLASS_PATH}/include", + f"{CUTLASS_PATH}/tools/util/include", + ], + extra_compile_args={ + "nvcc": [ + "-std=c++17", + "-arch=sm_90a", + "-O3", + "--use_fast_math", + "--expt-relaxed-constexpr", + "-DNDEBUG", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + ], + }, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/requirements.txt b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/requirements.txt new file mode 100644 index 0000000000..230a946b3e --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece +brotli diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/submission.json b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/submission.json new file mode 100644 index 0000000000..33bc1ffc50 --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/submission.json @@ -0,0 +1,59 @@ +{ + "author": "Benjamin Hadad", + "github_id": "codemath3000", + "name": "Improved Parallel Residuals + Systems Optimization", + "date": "2026-04-13", + "track": "10min_16mb", + "val_loss": 2.77730717, + "val_bpb": 1.07518215, + "val_bpb_std": 0.00058614, + "seeds": [1337, 2024, 42], + "seed_results": { + "1337": { + "steps": 4745, + "step_avg_ms": 123.8, + "val_loss": 2.77555969, + "val_bpb": 1.07450565, + "post_ema_val_bpb": 1.08225270, + "sliding_val_bpb": 1.07564236, + "artifact_bytes": 15983819 + }, + "2024": { + "steps": 4724, + "step_avg_ms": 124.3, + "val_loss": 2.77822748, + "val_bpb": 1.07553843, + "post_ema_val_bpb": 1.08329084, + "sliding_val_bpb": 1.07686704, + "artifact_bytes": 15982374 + }, + "42": { + "steps": 4744, + "step_avg_ms": 123.8, + "val_loss": 2.77813433, + "val_bpb": 1.07550237, + "post_ema_val_bpb": 1.08321330, + "sliding_val_bpb": 1.07732492, + "artifact_bytes": 15979637 + } + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "PR #1529 dual-lane parallel residuals + systems optimization (fused Muon transform, EMA foreach, numpy prealloc loader). Identical ML; faster step time yields ~20 extra training steps.", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + "parallel_residuals_architecture": "@msisovic (PR #1529)", + "base_stack": "@clarkkev (PR #1394), @dexhunter (PR #1413), @X-Abhishek-X (PR #1445)", + "systems_optimization": "@codemath3000 (fused Muon, EMA foreach, loader prealloc)" + } +} diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt.py b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt.py new file mode 100644 index 0000000000..f3cba197ee --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt.py @@ -0,0 +1,5 @@ +import lzma as L,base64 as B,linecache as C +S=L.decompress(B.b85decode('{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;ZAEuo?QSmn@VT6Qap3bt~@<3h>ok~)Km_aAcM1$ZA=RNsrI&uUw)pb_nMj0LFYCMl-ULtvz!0lTlkwZNfQb9u;zP;lKC6%NM(=|8~7kIg$g6~+qlxGpltspif;>W9Ih1cC**hctLEJ6B&YTfwKJ0OWvU@y)_x#hGZhJIRivtel&itT)Fq087gVUKXtD2j@Z3{-e*N3Em08eTqH^At=kOWE%ADhqBPB8yof3E%JBf=zKSFbu4w4u0CPt;&xAwL8q2psEB4q`T%zHUAfK76-mpdof;ejJ487{F@3b`mgvUZuw$uJooXBO%9iJpG(DEO`7Kh54ExsBd}&aEC`~|vgw;$M5%SyHOO|C5+W<_8Q!{Ol7``eh_@X7DV3i_fb2YZJlDagKbp&%;_pF5FGA>S)i+4xoE(A1{g+R(YJjVhtV2yjZUS$O)y?fJB4b`wYU1p<3VAT+rGakh&ji;kZ7;sJ%aDuUPg+~ylJISPfwfv?&MPG|nf^l{iL`qUr*SlmFhx>T@>W8;6MxfT!74#J|v2wem7N{p-lo4urwZ_o|H9!Vpv$Bs_-i*r1Se9{Qc8_i!}WH`+&3eq?uqln4|ZwYj(=ds>MO0vUWqpaCqivMfTs{#N1@=kYLA>B2F1FF{vuTk?J-N6p?RbIi+(wP(>(PEQD`O<5??4=N-3IoN2066Y3OX08=Y6P(+WoGI>Do5QrpzxMlEjr4I`TCb*xJBa+p(TO!;$y=V8P)StcVp8u_!y&ZNDrD27>K29SWWeRv6={47_kC+PF@jrr?KDg;1Eoc9(z28rZQZ65ovU(@51Z^SFQ8U_+ETEQnaNDcr+sFBvTyj)5J5RaDx5R8ht;ojo(UZH#U@a;9=I6*PRBHG$w0%|2q3*oL#F0%mR(l2)i&FF#6VlX!;%J(65%dKC(}WO1(Sry0JSU}*xV3c?DGx#66+RfkgkxUsr7I1_DLz-v=>RE_V%9cxIG%Y>yf7@Kbu&+vL9#!$z4dG$u7}b*3RwAQli_v^Ztv^wETr7qeC_b+6TkZE9TjvOdWGnINUavIp!mTe=;$z%-0LV?QDYH)HPta>CJz-bdl0dPgAyj-YO~6@?v_?m&?x{)EctU;}Q5F5-ybA>puc6%AmYm$NxEE_8w}k#Jq-7LIK5I3!=%|Gx}Cs9pY+O*FvuyR;-DgUvt}U&B$0h)Wk`U^m^Tl*;3#c~Mg#Lk)S-kMu{kcQMbBP*`HI5prQ`oFc8(Jr_!+^2ZZhC4D<}3L@OpvNy2j?9-%GQ5I_wXuGJyKkJq`N#$Xz*~x?xRA_+g}end#hB3Vs&kQ`h93^s6-eG+Kw2zbmf}Rz!udZ>#Y>47nKjLNWCdNurd!Bfye`vp~9(-mYP!ypd5Y6oEU6Nfwg`tqSa#^*!pf&#PGy$X77~dU_5G0>-2Ahea@2uKq`rJA+@Pi&)DX1-b?r7a~L!tKsxJXDml9X~Oid&e;Ag`{xk`}Jt?PK-L#qsHF2t-dFauQ`shc=csDedP+jdi(3fA?;IR=(VCx@?);N~Ib5wSH+*BdfzdC&50wydTuFw0k&jhCad2Ltq}zxm`ouqnn2C8*1E!z(rE*V55japLl?{C0y$*{rexUOOJh#P*1@T5*47`EYP7I_uVuH0vj9OooWYFLLc<2N@h_1kGZNv~GkwjOh09IJyedAJr<{NUgLy$##@8g*4EhyMBnS`RkX8dV%Th(Dqkufsn?1odN>UZ_t$durYQ>797*j^s>#g5mKPr*`ZD41W04&Zc>Sg*$l4}NvB{N7GUy?5>iS`NK&Lc-4ZgxYAtAv})oE&4?%_pp#vJ8?3xzNXQL6o{VDv(Y|O0VRaC>>Vz049ZPThUq%-A6Pc6i$7cKafMmGCL!-4|KLtDJi@X>+0#eVak3H;W50XbDl!nyK=SBDVOyiC9e$k)rAfg=+Qa_i$D>*&TCBzDRBuTR0dG87Z00psYYhYcYX4MGTnz(vBdz8Fk!Bf3`}9n9TL~qa7@OoTi8zRnd7sjyWm^e`p9vsqO+Fi&w^BYt=D>}UY@>0P>nHTtt9SW53ZbdNso~PbmIn8B}heu)Mq}1AU2U&rgKL4#WJDQHB@BE@qAVscXqCA`YR3CYTaJ(Ujy?huRAoa#mm6gBj<-OkwE?{f*8!&-ol}GGhUXehLoHbihc)mq1c^);SKbuNxBwn7l2A*kS643YJ@x>m^Gjfx8ur}ZZs^+uXy$Ru>NzB$0a}JBNKT$9L_8d-xtme;H`p#5)YWFRJB5x7<)iyI_UWlAOML7P1=y#YrNAW7UZ)Ov_DJaMnWfuQ!?xy#!9O&FJ@ah;D+ZP^*M(1xYR`;jdd8QZ>*U3^QsL{vx?le3i_Vf?5k#gKQH_;*lPoV`|Gj8BivBj+h)ADbB2`H!{n6#geMBRSDAxcwuo3`C9NX0phw+D0@ZBD_pt#cwr?sEBn^~6Vy-w-!I@Q`)Z+-u)V)2SS`%1?zz*b9i)j5HjJ^%qlo;gh~2^j2(&Ix?W%`4Q6=qG=*o-U?An!Z(h5>8HPkmn$MblMOYx3AJ`HQegfeL5r`u{Mp5r+|$vG!SvT$CU=k`k8*NW!U#YPpVbH|P9!LPR+3s|=|Zn*i%s;G=&4G=l0e>lIx+%7lBme}F}C&`PysXu&vqfTY4xiiMDPo1l`QIFMitKu;$Vj9;vs3aBdiRQyLX*ttlI_DvdY=D!F5!#c#un2oUg}w6=nxNQR2QY(~Qg1QQ)gM*Bv$nCKC^su})h`eP)@kFbA6#=2Cs=rTCQsoiFs>F;?c)yu_F1JB$|tt0mH4cbV2qlHsTW(LDhKuAkxUK-l=xCT>VMguG~_WMu3JWqKw_PDeH;X6|M2ul7QoF%@(hyL3wV)Rt_A#vj>Y*^Ud{kjHItk)05yRBQsM6mD@NvsvSKj^%C?Exvrx^>8cQ`H<1aEr^Px+Gp^3i?f~vw(#xdugzL+e=;O7ylpK7D|N(zNcg}Du}dZ6yfy(ACFau6Q+$s*0tXKk^dP6*gFq`C8`En*P0z*NFw0}k0}DtR8u;Z)Ixv5Bscr#T}^dJ%g=`fv3Y%5$w5ew2Ye8+G9o3(u#+$y^=UcW$pl^`7o8VE5x@b+}3`?|kzFNgkIT)ieqZKQ@@DSQ9v}R2X99m3M?!etY2-1)m(poEd=`jSlFO&xC?)O5$>Ii-%imnh(3Gw(F{e7Ilw4kD3>&$R_m*Fh-|W%78OBX-;&Fc5U=N#mLzj^}E+{<^?#mtd0nF}b4H`Wrf>hU_HezCfE5PttXy=9M?~TG^&D3@%<4-lHS5u2a-WG#jHQSIDr+4CXxO?;h447XgUXC0BCue#h%mjU+-(@Sz^WVu}UYRsTI56{WBcFG9b%mIiu$q*a`&uGQMUaL7z!kbBlIdJO_*B82_m)-%ZRbfluMyx725JAn*B0GXfl$VQm2Sf$=wOD!8*{}Y6e6XO0+glsJQ%h_jLm>kHhOE@?_hi!=odYHn6W`iLnP6P5+hq?1({V~9m;beay*Q}YnNRhXSS9oNS|^hpf-?uJ^mMCYQ`@8Z}3mUb<=qSpajZs&$fULv4?Vxq>dp?1KB1F3Hq9E3Z+vjip>sYlzmRZ6;{0y6;~mFnl1AG{`E}x4bJD#!mW5uADF3=g|wrRfinACfKtTLh;~aCwX_F5A7kZ?b~jcvqa37%_-$uzwIEQ?(IK{6OA56;r0<(&+67%`2PB;N>x{dCO6!lt*g6G!{h2etQczy@bRCa8=XFMYogK0O4#y;>1MfWT3SP2_T!D;7{2Dw6U6umn81A3IL1>K|$>FLk=7sa;K4WElUQ;iiJ$3l|fr2_U>d&IECv=A(F2f)YAj;aDo|PPdkP;`EA$LzPj$rIkl2#?=aO}@_;O`4g`U{*--K6_BP88@cs{V-?%u+DH`3@OT5KZ-QAZ#nVV5atp!fMocnuktrI?I9Wo%_&Kz~vziRL|WH?kdxH_iWU5Mxm4_pYp$dPDDYkv!!vG=#7fFdBdPK+=`M$!NuO@K^M)M>Ej84b1t=Dg*y`(dUlEKe&XI<+RSSw|%Tf_}@w%wvnvdjL293)R;$R%yF`e}Bz7D4cfn)9f+?A+)LeVqN0S}qXQBRfdCJ!s13%S=`vwYpf5lw=#56HA7Gnqw|(1OARNkf*Q^Y4Oxlrcpk!&**W6pw&-5Elwu8qymy1aNbCV6LuVE2qCQCo2*)tzGU_0;%gK*e$ZjLF4})&z~>M;Ja=)31H<4K|SPq@E>rYO!}aP!WgOBk=CR6lQXi#+B4YUy3u_@Q3bpd8ZzWu0gaYTpWZl;&9Htq&h&p9r0I$TV3xGOZj*}@cLsbOl#w`MGw4=GF{0sY8*|Wt)>^nF?;f^4Gw8#Pq1PoBP*!y(r)T+MYY!&?;Pm-OiHHzCZ^Pc{rq?5?7Tdn^~J?m)6Wqv(8j$)Niy1s*x>rd1V2`@>j3lm%W7HRf`m1aR&BA8%VrxH5Npk+FQM>3r=fd;oC?2{nx;zBm0I5icR}|LyFD7UNAffaB#}PcKy9*wGa*{-*H3byfu<8jCEhY}LGLq9osZ7CC^*0l^@3mVQ?~r>|DK!zHqct5GDw5RF+P5=?bhS!$t+tM@owy$9D50=Gm#pGH?UbkJjSbR&xy%G(3bL>R91D+Ph{44XF7rVdqJv>Q~HV2+=5ha*qJNzEacB)_y9%=m?4t0L1c)q^lOADH4|nA6l3$Xm=@u~D7IgxbSvM9Xd}rcaO&xj|UWFfej8cgd*_sU0Yqd-+Wx;<*9_4vx!jFJo`=O-$TD?%o73FCCG1H26lNT3*7;Pjq|Uhr#EOC8D6rZX5k6W^yJgi-?->&Rh5Q8u&KS2f_hmST~YnK_jjOsfj}mbpr@2f8uLb-J5Mb^wp;fSk_kVPbDNN@`I0ozcmr8@x_ylT&C|h$riaPa(VRxH+2{Bl8g0*cD*W;u%&}WL<||Ni6Gh!)5atx53|W#fmx8VTIT-rk4InR*1wv5r_yV81rvbmYmz%c&ozV6twIlNNP~Cd9m9Y_F!{IwUw4OiSb@EjaYJ}-G|zjHhS9S?MA+l(Ef_m1QE>P!XeSekX-%MHbinsoFj7nNqNA9;d&sDT2<41s5^6@r3!=pR!;5jYdcx<@6#+kIppiG73BfBGoB=FQpN%H&xivc!GE*Lvqz@u8exN%mzpD9@YE4C*+nv8DzHPnrS@JgT0~O@*=1qO?swB_WIwo2AV9C2@t?6*N>n<9*0?xq(fnrv|ZI{|2{K*LP*6NxxP;T~J`pXm`7zoa5wbpo&CVq0`IoHX0kef8{z=lbV-wZ>RKfIre>AlCL}Lb8@KFz!-@Vr1#=BFFtUbpnoKJpN@3u!bYOy???R!kun`54ga$q{PzqkHU5vP1fWKXp(YmCc*v&?ckYbtzLTUPZH4Seim8t}MZVv&?`_AHTeBRl7!3+W)Drf&T5pX@&9R8g+!IP=27raVrcm7)bCtJ6h?{|@P^oSYlE?bQ0Tw7#Usr}}7xp$OQ{$vxmbuY~qYyeK>?@rTo91DcR8wpcdGE0R-Hgl@HFvH3K0+I+d~8bW>9O23u#2%A&c$Kc&*2Hl$VHo!+T>GT&Ki>`0UjcsYZ}YMZ;yRIPIjYBr>{!^Fx+fFL_`FPQ>m^K=anCp4cM=7gXrzA+)_gXdetCCZ7%+|&lqerX2vl%aUrj026*m3%fQ7d%<(Qx6kysg3AO1+pwY|FC=9y!D|RF4zR&ax(E>n!R2M;V8Nn98X1~M}Rii8jXG$7Pfh*b&vr{!4Z)(Fn&gpB{2Aj@{PlMruPO1LJQCq5PwJYr&r~H&KCj6>LpIZi3v=7&S1FXGL+4hebXY-wUv-UTsF2GlN6_bDv)Z-$es|m!DR?9^~z#iWZ&fq!){Ue%@eDNp_3fx^fD_88NP!NV!`bP>R1U|wq9QE+83Uc_z*28(tpb;t`yZ%Z_oAEY^NQgFcf}{I|U@#EndHIR2a8F_CpR;{*aQD5sQk?>u;F73|wfd%MX=_ul#)daati+P2Z_GWU?ptI`Hw@69#_m6ENIeH^wUYasLTmngGFa?2t$Jfw{uqAES2;#ES#8{a8SiLWIeS?Z@tJWV6=px;DvLZ7(SWSi_t%&7M*m7lj9M4X-9G<|ul^@SPsGF12@9+Lj$H3AB&O#u`d%u41Rvd5`kB>OyeqPl-XOkFviZJ+&)$y&nh6q8uxEFdke?e(^fkBbJ1XF@5mcpXKu~oynQYXXSXh?YJwo|19XZ2xr$ZVPa)Au3Idc7@~9tQRn~Zwnr`EYrcjLm@&HYkJ7Xg@L^~z7iKEdp#i`6KcWnwd`@sy&LGd}D-uO6z#oQK1daSYg5((vefybh(x@MuykIsdm^-8+kOeVC=EQ+>RYXEx0f6E-hJRFJv*3RE4B+ZA$HU33u(s$?K;qWJx)DZdN(B*-j$!d}_CaK|eFw|%^7iW0Dxvb(8Gix2V;kMgt`n;$_3U+7JmW`Wq&Un3ue)>tzp1ExRmUEVO;d|4~%f-LCuagxMq-UAUl-C8#Y!$W~M(OH3!Q30RUu0e2kW^r2dkE^FE22d3pmHiUj;~YWGq6o~L?Iz2^O*S%V0KGU@C^@E$*78!K9fpIE2YT=a$fP=#SLz8y~@pjRe_nzg*m9~+mq0@j~l9icp`$m+aa!G24bw<=*_y8Q6mPotZTlUW45&6ZL{N;kAPDQJyI$Qk@{}q4}YTBb;UF|4CuV<_?WHfHMoI;X?s+ygoN&B{~CK|>pMFWVimq|#hL&iLYVtSc0eE6)b8I2#tVRUS+2O-IEQR37ptq|GoXA41CJjjB)@iAzjc7(XmTT?>&iY;>Y+!7`^r0uOyw=T(V(fDcQK1T-COK%SJtdLoM1Osgr6@>>#~iP6H4D&avm=;8;vYQXl_9fkm9h_{9%Hnd-6wwc&LD}^`_Xr{D7G`VUdxsBp|CJ=6wP7y_STa89NU@n1Y%$C7WT?_>tSv7WEUp&nZma5Z!Uo9?mha;yl7Ja48G|Ew3LVrC^+DZIVtUf5aeT}X~8;b5!^YgNj^~ZT>851aMNy`_Xqc(1cqbHPT2hjy$&4=*JY(c0H394+vygqL(MuL&I*Ql)B%bCN3wAQ0RT8_M_a=ZP$V+B_io?Bc?Yrf3Do^x_H6-~&~1h8B6p4lfe|XA{S@)~`V`s$94#=1>g3xyJ1k&oRfGnSfnF_yU6&*qJr>FlIGFDd{e0M`q5kS8P^q{#eAF^5TlWyEW@rzvr2o_wUZMsN^J)FroFb{%rZ{ba}lgjqw`JNUH^6a?Y+g7TXL6W!}acfowPWmOuU_n^MkpJAg@fLa4gmMnO5!q>kn}O9;q;LeqTFlEk`@+M4|E3H$*J^?fG28l@X|h+=gdSS)~u-SK4FUf_(4GA}%dBE8)1!w)o@ZLEe9;d_jXC6*{whBN+BUOLzP}#(sNXwlgkkk-rJP|dFBT4R&v#>zpw%-=D(vI{Ez~(dcOOyj5yb3jRhBR%3?4gzYJ?+Mg%{Hb|FjUq+I{G?&UL61sED|7|J5@xG|BYfHZ31PJ!7iZ%)NXMHn`|3bZ6@Q*Uez6r_{~Kv+h^eq)*ZZ^H)8P8!4qf`SfJP79_#zMk?E<@^t%CI3aF6eVLcOCW_gowgcs#0{9JxUDu`3(Q?_dM=Z}&w%#&J6E$?_0pipU!eTnxv;OC&mfCE^B5>c^&v<<>N!OqlN`fqI{&%gZAmZ*mfFzRoUUT@`P5JsMUXw&=oZ6!^U_)-64MG1SN|K1r@9yIVjuMwZv-c@X8sC*8JYTN;9e8}szv?vWQ{8}>oJto?sbm^Vg%4JFG=a?N=esT%*bH6J!)BJQPi*+b{Eb;WA2Z`N+e^rk`W8Nu2H&~*LNn$haOD0U*tN+!P*9xFgEOyp=2Ewu4E##+9?j8k)xM{JF!mMlOR-~dGL6Py}=zR!|LVL-`s$6(OWon;_(v0ivnb1v^R(&{tU3(N1lcZH&Hts<_7G_B+$>geSAx7ZDXEAUbDN-m^jg-VVrx6SD(mShQ+mw?xdV7oO-=vPY8Su<4RoL9DD5*Vno7j(|#mseg5LzPQPN=KmpPwkYS(DSIP1nFs^FOf)MU*+-0Wz3TZC_bFzi~s~RX00E4ixxYBMT%8xFw^HI-?u3>f>|wCgRTt(gm=T1*)I4>s^QNt=KFFfqh~8ZqO+7q5(x!m${Vd!6&A^UgUFjQ3xESI{XYopPe(;YLMOk77=%ig(tDLBiDUYF(g;`YW8Hzy>0uD+)&WT08-aVr9?@J8O{E4OAduDR_A*gYF5#f95mXa;rQ*wr~+_U7dDvrra$Qj1@jMIKG}ddF`!T*aabJ!*unv-h&NJEo>y$+BuTxtBWO_Jn7KgW%SDgs+I+^>m6dg^A^tI4{6KUoVK$B>q5xzlN6n&5xZXEpelH;J2lVpX4g3jP%?7q=5uD>jjT%+AeH;@68}bMx)T%`bY6NqvXq?xLyAbs502dC#Fe6Xa{OU8s8A!@wAtq+=R7dM+SXDet-4?cp$A?@b@xMLndAm)z}hiu=J_2YzH*O~Fttun?e}CP-OiPv9cN15P^VQbV?pG;KR>Uxog;%5?aOX0h7j#}J<0TmBsUIbCCw&c2frKaT54NAgzOU&&iOs_73&pN-Jj>N47LGw@&9OSp5vbr8;^5_QWtUuF*^~9sce%_nj5TR%w&Vw42JaN$E?PDny@ay6JdgrnO$Q_BqiDo%uHJ=wl^HB-QD_sUM|TLPq-j!BSB`qQuV^Y{MBsHDwl8b0;wtITVY5i1|51zK|N|r`BxP=hv~#gB@#6@iic3GYI+Q%6#e3gwvehwfhOyccN?@t_%~7}rrkUaha_>JJ(IlkVHK4_-d2y#36_R2xgus_Gs()1_EbS`1Dep&OJcd@59ncf*(2dNtACT8VX;TYYV$&zBOVtgg&q6x3~()vV-JP%wT^_rl+W$mBC+2kl-k!t2%yosd2S&SsK!aMQ{J_sKHa!X0==?X>eA>KH>*TK5V0ieNP&}tdm!0D#T9Et>GSE0F3;TgaZC@_C-lOcsU=Il@@+I~w+>lpohkJ9E28t*v7p_fXLe;5zfaW7PG(~%_gp);9ePGWKc)&B*hz6XroVz5%J-Pr6vH%kVX^O$rzg{+vKxPBa$+Se!waH#($qL{O!eccIO)93E<-*c1TF8lX}SyKNBv{xR?`fHA1M8SoqM(6lXFsRZb`@qfZ$z=e@yrEs{%q0US({M)Zy+KimKojAvAa~IC^ezw>arQb4SofAV5N4IHJ*5VY}lx!8W)-5ITkl0R$r0U}=fxTzT*kG}3F8wLT$$sEqETIh&F!{h9DMn#h5WyZ1ymP@`Q)%L+7pLs|8@dS$5_WAi2aZ2)-c%v>UcnG_u03Es?)SFpe&3=KYrj+SACbx22alc3@*Jw|EBm7ebPM^M{pqm3aNFa^O-j$SsVk9ern26ryBKOqhK{1$p@YXdkHOa(OQ(y0MO@4;hjo=l%G^tKQCX`yCy^%drf})>}>fJG~>{WHmr;F))iNcKy75)zy^E%p7o3exZlUm9zScX%<#9<@FMAG@54?C#s{in6h{JX?WLzkwudQIWQBm*=`%>Y(9Ag06Q>UZUqu+bIT>;v)i?;8m!Vle6urj*>3g89Gs(87YoJQrJyBlez$TZ7b?R+YCrU4O!WENs21KrYUtJVioSrGX(Yx+MILD90DXON2Sl!Uh-$NFE$A(E}9=#h0>QMh!?>pAv&QF$ShE*sNJGqu4N#Ip&`7a9Llx^N?n5jmO!K#+A{@tOsN;ulQM}B~ub9?%;UfF4jDT<_<`hrF+ZAvJHCV#{Ce{%Kr^#h`VW8iAw(?RB-=_uXJk;yny2}vu@_vI7?MwsT3%zx~Rfco1y*qW2#=x}xcwzG}=4tB-~P4S@u<~u5rWQ&IezhoppOe-i$hOy%016xm~T~P3h2JX{yDY$z-Dt1($v|&ptN!I+T{zeTVUI?O^qFh;j**P^h>{5@!=Jmfa)iMXpCDz+%Y2MyjnT&Rbk&^J$HABt|6XQT5-69sP)`-dTnpfL&e^97(zn=lQu+SZXb5$)ul*fKc>w(%gG8OJs+v209Pr9;P>I3FLqsMNHzW|4!?4QIT&xiB!eH^g~ZiJ7lNK_?1crwQ;ZVOdF6bBgIcH%|=lMAIwz%CgKN%vw5W~3h8l*o+%(7)$hN!T8YxEAQc!u>w=2_h%6$-*gF-AYM_a=IZG9ymFOJfbhZV4=knvV#WgM7A%eb(>cZp=$s+7fGMlDuz9i8iqKl)9M@4n%fEldh{pnPm10$1IE0xb9<)f&C`c3+UR^IaKv4P(6s`yu&zH*{!JR2Ps&ecjk!@RS20`3EKoCa=iCII1ks_xn>sVZ&idw^a$KRjG+r7NbHkg50tfjoRE^cGm8{<#>@HA-m9%P6R>F}qffugtEtAYUK<^!^bS8{#Fv%KLwBB*sOwjdmTKazz9-p|mtMsGR(!Y@uqjUs+jBUPGXrNiodukG~%BX%ZeL5}7?iHodxu;_cd>A!OIwyv@M#jGnk=bpWdEA&?WtX)K>OQ0Vxw=X>FE-4$^NjUqIA70iNX8*riB#CAf?GC4CIS&R}O=lh1wdn-qrCkTY5YXU9hVQN_+R{jb-sQplSs@ZWlLdZz%*v#v*{Hy5m>~lPh5T+-lom}BTB(lTTmYZUvh{o*3Xab|fBTu;bHeEs0IdcZ=77VQceNuzuUw3A0ggLXd05$S+L-p~F@?lw!>5zjwX@wSM#B^;_zDT@5`2hSShLR4vEaHp_6ccO?pgU44t&vj@v_nN|Jko)2y303KSgQ)H$tbhYXY_MIK3woOBJkM<$NG2^Ren32W+HkkmxL9GP=~>{>qL%8{Y&lYkZXwS!Xf3jSoAYjMX`Y?#?2^Eb;zm@HmA3ym_FKP6bcL2MTpPEsKtQXn{c7Vd7KIkp<)s3Rl~NC@4-`fwMf^-V&`tJQr)Pad#o*@zfB496;VdK{k_PwmU=IXNMQHuUyXnKpgr?b&DaFPzl#!Gc4-+f96hS_q6k>Vxd(~uH=gX@ju!8?!4`V!3rzCP3(yGbMZy!lI1;8Vqw&K~HraW2D;#-_zDra8L7rCiYZMZb^j+%i=x6npZJeb7#4|8y%@h^N3f(#W0TFsuqW^|DN5n_1BtM1CRFGNI%ebmig1NE#ETz;I&p2ik`W48k8lqp8wglhbwm)yo%77IWNa21YHZkQY2Ao+q<<5CXGw<;CO+CnN#Pp_HxR^Bat5-9&qH)@$LtMv$DXDp6xzo06=t-Lfb<|8L^pBPNXXWWOgE7&)u5+l^MZ1MUx0PtVMwl~$x@FEl{2C^r^w+14)Mm>^C;-$_I+14?LTm;tt8D86IOx|J;$m(9MPt^rcLTuk;U{mirBdGzcpB*oCC7n{qq6*#>9N8DqmLLY8=Mj=IW9YpvGFdJ3k~guWf(C!Aizf+1Qj4#%iELLEk)uuPVxpx_s1F7D%j$eSj~BUmc{irdM}JbAq)*jb4^gl#MJ%NmC}wn@qysZ4u;joF+0EhKaiK+X$i-J9LP7y*7x?ZP}jW}s|=64gTXt>@xA1b^+l&r(Fstz2Ii8T$w^$D2~3aLAXA$-w?3b_45hs%yO~pXQd?WAczR39Jwr<@KPyXTNHJAHPe!I2@c;I2o59(N*ysCss?u7qI;WMfd5+n?;b%yW0`Hdf{Sbc8Z86~`2q1Y%vi7FCP(xjlOzMu^gfWy2V+^vE%q(`E$Z_*%|Z9*8I8N*5Ju)V8V3w`K|ku-m;qWWY;B9NJa--<*k->HSamA`bv-W$%c?`w1kUh@KyciQ*JE4Bt}DeVtknmQoMsXa+S>3tCMTT#E922j>W*M1H@_5{Z~VlJewQNR}nYD>c$wijN&h4f+Izn0=F7J)WP=h6)bxZ<)?g4V;0*OqFODII{I_3y2`dR^^=o^rTxv5N=jU20P$Cfs`(krCC-%pMZQe8Zdgre)il;f(7#q3TEG@Lr8)M$Q9E}szkO@EItHq@$QJWuSTy(Z!xdjVn@cuYkvEzcB)ieE-t6wwA)@wI^qv!lBe*ZrEoN^N{!pkBUqC>9+>uV#HFWV2|_{E=3N*1f3Em^(z`o=Ul&E?p{5wO6mvkJAc$kk^2W~&SRwW1Awkn1PISb7(JUiL$DgrhB_wUA+E&4c@-Y=4KV(l*e>>(C*f_`$eC`q>7P@z6kNmLPQuiXM;_S#M1RYgeaS|lc2LQ55l0%bI`pA%78{x~pL6JR&dy*Ij>k_p_!Ig`vkiBk;al$i=0NUf;^gX<&t9i8m2@zr%Ny%DhR-)F6>o5&Q7ub0J&;iKY$R$~-y0|>ekxV~bMJwF!1jjr%w+VJIM5l8#KA7$nHxIg&qkgWZzxiPpf)p;J|M{SntDqzk)}m2Qx+0j&*1XueTN=MfF-=xuk~*;=20T=-_YL5xChBDG&=C>}xRqbouew1)pjP8hu9jZ~VS>s!r%hJsp(;ekC+6yqou;lPgX|p+%tY%jnl@k;mBQzk@=br6rrF367(j+tt1XH}@U#kPQs=9OETRK2&pLy{c^o2?;YJh}@@Z0$;$YNq;>5vZ7dfIX?ON6qGdxYO{H907^QBlfPnb;R5hn&_m&hp{-D$eWh7xi4q5j3^>^}1QgWAEq(UTxIwW0V(}P-4cdT~$YpP_tCWaoeILRaZPqi$kdsl4`my(W{B0w)NUS~1UCb?G|jIYbwp{qMg9~|>bb=m1R-~8$^CdS8>xane2KhFX&}e0^&yzDYrBbo=wex*0#j+C|6eqtm+72Fva@&90R)_28{x5o5ETdOD5m7W>M}A6N4|=jGo<2>aNi9c8A6eXXrXmO8IT`Q~LGM6dG=<6>!DMEx-#N;1u-Y-~EoR0};o#4v+;N=wxbCxXTV?7pwfYlZnXd#?d3?cCk=6LJ~~UkAhM-fr4ZJud@Hu}dzPxvW#v|_UccXQ|)Cp_)xb&~E+?>za-WcB+rp7>^1iG;_(A3Po+Rf7y(i$v;YHj&`cN&Jb{~2qlyejGtvwScAdM2FVFeIj%YM1Sji8gT>$y^$Bppu1#$ViFk@XqWAnYc+)p4r9Ke1VjZn$s80Cp1KeYZiImT<;tjB2~4wN_oeJ`pqDAD=iqBilBKuP$0k3@4}lnF3q`ul)F<8A#pL#E)-=~;NsgK$*O`t~>Ad_V_vGp`&E)lPuwvw{LdbHqDdTCY8Pg~kTwj<@urYJy@XajfCqg$ugqY4dp#N_5bFurT9cR?1uW3a-RTp&%2}S7>e^>9qwxP%Z~MvA`7L5dVv7Ij-+!f{8i@Wbd8O>$#lDi`JkuNi}?cw$Li%>-%)2tPviGTDHG!2!Xz^pODCRlFDUzEkE!nS<9%~6m?=mfHI^N&-xhKyz;DQ`>q8+BZDj3(f|Meh1X+XOJYi#00Fv>0qmRwwY+HlvBYQl0ssI200dcD')).decode() +F=__file__+'.__decompressed__.py' +C.cache[F]=(len(S),None,S.splitlines(True),F) +exec(compile(S,F,'exec')) diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt_human.py b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt_human.py new file mode 100644 index 0000000000..6454d84c5c --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_gpt_human.py @@ -0,0 +1,1707 @@ +import collections, copy, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import flash_attn_func as flash_attn_3_func +_HAS_TRITON_TMA = False +_HAS_CUTLASS_EVT = False +try: + import triton + import triton.language as tl + from triton.tools.tensor_descriptor import TensorDescriptor + _HAS_TRITON_TMA = True +except ImportError: + pass +try: + sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cutlass_evt_fusion')) + import cutlass_evt_fusion + import torch.library + + @torch.library.register_fake('cutlass_evt::gemm_mul') + def _gemm_mul_fake(go, down_w, act_grad): + return go.new_empty(go.size(0), down_w.size(1)) + _HAS_CUTLASS_EVT = True +except Exception: + pass +if _HAS_TRITON_TMA: + + @triton.jit + def _fused_leaky_relu_sq_kernel(a_desc, b_desc, c_desc, aux_desc, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c0_ag = tl.where(c0 > 0, 2.0 * c0, 0.5 * c0) + c_desc.store([offs_am, offs_bn], c0_ag) + c0_post = 0.5 * c0_ag * c0 + aux_desc.store([offs_am, offs_bn], c0_post) + c1 = acc1.to(dtype) + c1_ag = tl.where(c1 > 0, 2.0 * c1, 0.5 * c1) + c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1_ag) + c1_post = 0.5 * c1_ag * c1 + aux_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1_post) + + def _triton_fused_leaky_relu_sq(a, b): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + act_grad = torch.empty((M, N), device=a.device, dtype=a.dtype) + post = torch.empty((M, N), device=a.device, dtype=a.dtype) + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = (128, 256, 64) + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(act_grad, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(post, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min(NUM_SMS, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)),) + _fused_leaky_relu_sq_kernel[grid](a_desc, b_desc, c_desc, aux_desc, M, N, K, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=1, NUM_SMS=NUM_SMS, num_stages=4, num_warps=8) + return (act_grad, post) + + class _FusedMLP(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, fc_w, proj_w): + x_flat = x.reshape(-1, x.shape[-1]) + act_grad, post = _triton_fused_leaky_relu_sq(x_flat, fc_w) + out = F.linear(post, proj_w) + ctx.save_for_backward(x_flat, fc_w, proj_w, act_grad, post) + return out.reshape(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x_flat, fc_w, proj_w, act_grad, post = ctx.saved_tensors + go = grad_output.reshape(-1, grad_output.shape[-1]) + dW_proj = go.T @ post + if _HAS_CUTLASS_EVT: + dpre = torch.ops.cutlass_evt.gemm_mul(go, proj_w, act_grad) + else: + dpre = go @ proj_w * act_grad + dW_fc = dpre.T @ x_flat + dx = dpre @ fc_w + return (dx.reshape(grad_output.shape), dW_fc, dW_proj) + +class Hyperparameters: + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get('RUN_ID', str(uuid.uuid4())) + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.0)) + num_loops = int(os.environ.get('NUM_LOOPS', 2)) + loop_start = int(os.environ.get('LOOP_START', 3)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.35)) + parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', os.environ.get('PARALLEL_START_LAYER', 7))) + parallel_residual = bool(int(os.environ.get('PARALLEL_RESIDUAL', '1'))) + parallel_start_layer = parallel_residual_start + parallel_start_layer_is_physical = bool(int(os.environ.get('PARALLEL_START_LAYER_IS_PHYSICAL', '1'))) + parallel_final_lane = os.environ.get('PARALLEL_FINAL_LANE', 'mean') + parallel_freeze_lane0 = bool(int(os.environ.get('PARALLEL_FREEZE_LANE0', '0'))) + parallel_identity_init = bool(int(os.environ.get('PARALLEL_IDENTITY_INIT', '1'))) + parallel_skip_lane0_only = bool(int(os.environ.get('PARALLEL_SKIP_LANE0_ONLY', '1'))) + parallel_mlp_read_mix = bool(int(os.environ.get('PARALLEL_MLP_READ_MIX', '0'))) + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.022)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-08)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.095)) + embed_wd = float(os.environ.get('EMBED_WD', 0.095)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) + ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '1'))) + ttt_lr = float(os.environ.get('TTT_LR', 0.005)) + ttt_epochs = int(os.environ.get('TTT_EPOCHS', 3)) + ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', 32768)) + ttt_freeze_blocks = int(os.environ.get('TTT_FREEZE_BLOCKS', 0)) + ttt_momentum = float(os.environ.get('TTT_MOMENTUM', 0.9)) + ttt_batch_seqs = int(os.environ.get('TTT_BATCH_SEQS', 32)) + ttt_grad_clip = float(os.environ.get('TTT_GRAD_CLIP', 1.0)) + ttt_optimizer = os.environ.get('TTT_OPTIMIZER', 'sgd') + ttt_adamw_wd = float(os.environ.get('TTT_ADAMW_WD', 0.0)) + hash_embed_enabled = bool(int(os.environ.get('HASH_EMBED_ENABLED', '1'))) + hash_embed_size = int(os.environ.get('HASH_EMBED_SIZE', 16384)) + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + logfile = f'logs/{run_id}.txt' + model_path = 'final_model.pt' + quantized_model_path = 'final_model.int6.ptz' +_logger_hparams = None + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, 'a', encoding='utf-8') as f: + print(msg, file=f) + +class ValidationData: + + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError(f'VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}') + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = build_sentencepiece_luts(self.sp, h.vocab_size, device) + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith('▁'): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode('utf-8')) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern, seq_len): + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f'No files found for pattern: {pattern}') + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f'Validation split is too short for TRAIN_SEQ_LEN={seq_len}') + return tokens[:usable + 1] + +def load_data_shard(file): + header_bytes = 256 * np.dtype(' 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + self._window_buf[:] = mm[start_ind:start_ind + self.seq_len + 1] + window = torch.as_tensor(self._window_buf) + x[bi] = window[:-1] + y[bi] = window[1:] + return (x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)) + +class RMSNorm(nn.Module): + + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +class Rotary(nn.Module): + + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or (self._cos_cached.device != device): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return (self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)) + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = (x[..., :rope_dims], x[..., rope_dims:]) + half = rope_dims // 2 + x1, x2 = (x_rope[..., :half], x_rope[..., half:]) + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = (x[..., :half], x[..., half:]) + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len): + super().__init__() + if dim % num_heads != 0: + raise ValueError('model_dim must be divisible by num_heads') + if num_heads % num_kv_heads != 0: + raise ValueError('num_heads must be divisible by num_kv_heads') + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError('head_dim must be even for RoPE') + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w): + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, '_calib', False) else None + return F.linear(y, out_w.to(x.dtype)) + +class MLP(nn.Module): + + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if _HAS_TRITON_TMA and x.is_cuda and self.training and self.use_fused: + return _FusedMLP.apply(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, '_calib', False) else None + return F.linear(hidden, down_w.to(x.dtype)) + +class Block(nn.Module): + + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, train_seq_len, layer_idx=0, ln_scale=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + + def forward_attn(self, x, x0, q_w, k_w, v_w, out_w): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w) + return x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + + def forward_mlp(self, x, up_w, down_w): + return x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, up_w, down_w) + +class GPT(nn.Module): + + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f'logit_softcap must be positive, got {h.logit_softcap}') + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList([Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) for i in range(h.num_layers)]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min(len(self.encoder_indices), len(self.decoder_indices)) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.parallel_residual = h.parallel_residual + self.parallel_start_layer = max(0, h.parallel_start_layer) + self.parallel_start_layer_is_physical = h.parallel_start_layer_is_physical + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_freeze_lane0 = h.parallel_freeze_lane0 + self.parallel_identity_init = h.parallel_identity_init + self.parallel_skip_lane0_only = h.parallel_skip_lane0_only + self.parallel_mlp_read_mix = h.parallel_mlp_read_mix + if self.parallel_final_lane not in ('mlp', 'mean', 'attn'): + raise ValueError(f"PARALLEL_FINAL_LANE must be one of 'mlp', 'mean', or 'attn', got {h.parallel_final_lane!r}") + if self.parallel_residual: + if self.parallel_identity_init: + self.parallel_post_lambdas = nn.Parameter(torch.ones(h.num_layers, 2, 2, dtype=torch.float32)) + self.parallel_resid_lambdas = nn.Parameter(torch.full((h.num_layers, 2), 1.1, dtype=torch.float32)) + else: + self.parallel_post_lambdas = nn.Parameter(torch.ones(h.num_layers, 2, 2, dtype=torch.float32)) + self.parallel_resid_lambdas = nn.Parameter(torch.full((h.num_layers, 2), 1.1 ** 0.5, dtype=torch.float32)) + else: + self.parallel_post_lambdas = None + self.parallel_resid_lambdas = None + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, '_zero_init', False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and (module.weight.shape[1] >= 64): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return (self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i]) + + def _parallel_active_for_layer(self, physical_idx, virtual_idx): + if self.parallel_post_lambdas is None: + return False + if self.parallel_start_layer_is_physical: + return physical_idx >= self.parallel_start_layer + return virtual_idx >= self.parallel_start_layer + + def _mix_with_x0(self, lane, x0, resid_mix): + mix = resid_mix.to(dtype=lane.dtype) + return mix[0][None, None, :] * lane + mix[1][None, None, :] * x0 + + def _apply_skip_single(self, x, skip, skip_idx): + if isinstance(skip, tuple): + skip = skip[1] + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skip + if self.skip_gates is None: + return x + scaled_skip + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + return torch.lerp(scaled_skip, x, g) + + def _apply_skip_parallel(self, lane0, lane1, skip, skip_idx): + if isinstance(skip, tuple): + skip0, skip1 = skip + else: + skip0 = skip1 = skip + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.parallel_skip_lane0_only: + if self.skip_gates is None: + next_lane0 = lane0 if self.parallel_freeze_lane0 else lane0 + w * skip0 + return (next_lane0, lane1) + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + next_lane0 = lane0 if self.parallel_freeze_lane0 else torch.lerp(w * skip0, lane0, g) + return (next_lane0, lane1) + if self.skip_gates is None: + next_lane0 = lane0 if self.parallel_freeze_lane0 else lane0 + w * skip0 + return (next_lane0, lane1 + w * skip1) + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + next_lane0 = lane0 if self.parallel_freeze_lane0 else torch.lerp(w * skip0, lane0, g) + return (next_lane0, torch.lerp(w * skip1, lane1, g)) + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == 'mlp': + return lane1 + if self.parallel_final_lane == 'attn': + return lane0 + return 0.5 * (lane0 + lane1) + + def _parallel_block(self, block_idx, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w): + if self.parallel_post_lambdas is None or self.parallel_resid_lambdas is None: + raise RuntimeError('parallel residual weights are not initialized') + block = self.blocks[block_idx] + attn_read = self._mix_with_x0(lane0, x0, block.resid_mix) + attn_out = block.attn(block.attn_norm(attn_read) * block.ln_scale_factor, q_w, k_w, v_w, out_w) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = self._mix_with_x0(lane1, x0, block.resid_mix) if self.parallel_mlp_read_mix else lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp(block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + next_lane0 = lane0 + if not self.parallel_freeze_lane0: + next_lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + next_lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return (next_lane0, next_lane1) + + def _forward_logits_from_embeddings(self, x): + x0 = x + skips = [] + enc_iter = list(self.encoder_indices) if self.looping_active else list(range(self.num_encoder_layers)) + dec_iter = list(self.decoder_indices) if self.looping_active else list(range(self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers)) + lane0 = None + lane1 = None + for virtual_idx, block_idx in enumerate(enc_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(block_idx) + if self._parallel_active_for_layer(block_idx, virtual_idx): + if lane0 is None or lane1 is None: + lane0 = x + lane1 = x + lane0, lane1 = self._parallel_block(block_idx, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w) + skips.append((lane0, lane1)) + else: + x = self.blocks[block_idx](x, x0, q_w, k_w, v_w, out_w, up_w, down_w) + skips.append(x) + dec_offset = len(enc_iter) + for skip_idx, block_idx in enumerate(dec_iter): + virtual_idx = dec_offset + skip_idx + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(block_idx) + if self._parallel_active_for_layer(block_idx, virtual_idx): + if lane0 is None or lane1 is None: + if self.parallel_skip_lane0_only and skip_idx < self.num_skip_weights and skips: + x = self._apply_skip_single(x, skips.pop(), skip_idx) + lane0 = x + lane1 = x + elif skip_idx < self.num_skip_weights and skips: + lane0, lane1 = self._apply_skip_parallel(lane0, lane1, skips.pop(), skip_idx) + lane0, lane1 = self._parallel_block(block_idx, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w) + else: + if skip_idx < self.num_skip_weights and skips: + x = self._apply_skip_single(x, skips.pop(), skip_idx) + x = self.blocks[block_idx](x, x0, q_w, k_w, v_w, out_w, up_w, down_w) + if lane1 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits(self, input_ids): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + return self._forward_logits_from_embeddings(x) + + def forward(self, input_ids, target_ids): + logits = self.forward_logits(input_ids) + return F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction='mean') + +def classify_param(name): + if 'tok_emb' in name or 'lm_head' in name: + return 'embed' + if '.mlp.' in name: + return 'mlp' + if '.attn.' in name or ('.proj.' in name and '.mlp.' not in name): + return 'attn' + return 'other' + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + a, b, c = (3.4445, -4.775, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +@torch.compile +def _muon_fused_transform(g, momentum_buf, momentum, nesterov, row_normalize, backend_steps): + new_buf = momentum_buf * momentum + g + if nesterov: + g = g + momentum * new_buf + else: + g = new_buf.clone() + if row_normalize: + row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + g = g / row_norms.to(g.dtype) + a, b, c = (3.4445, -4.775, 2.0315) + was_2d = g.ndim == 2 + if was_2d: + g = g.unsqueeze(0) + X = g.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-07) + for _ in range(backend_steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X, new_buf + +class Muon(torch.optim.Optimizer): + + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, weight_decay=0.0, row_normalize=False): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay, row_normalize=row_normalize)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group['params']: + B = p.shape[0] + padded_B = (B + ws - 1) // ws * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({'p': p, 'B': B, 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5}) + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + backend_steps = group['backend_steps'] + nesterov = group['nesterov'] + wd = group.get('weight_decay', 0.0) + row_normalize = group.get('row_normalize', False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, '_rs_futures') + for idx, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + update, new_buf = _muon_fused_transform(g, buf, momentum, nesterov, row_normalize, backend_steps) + buf.copy_(new_buf) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor(m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + if hasattr(self, '_rs_futures'): + del self._rs_futures + return loss +CONTROL_TENSOR_NAME_PATTERNS = tuple((pattern for pattern in os.environ.get('CONTROL_TENSOR_NAME_PATTERNS', 'attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas').split(',') if pattern)) + +class Optimizers: + + def __init__(self, h, base_model): + matrix_params = [base_model.qo_bank, base_model.kv_bank, base_model.mlp_up_bank, base_model.mlp_down_bank] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params if p.ndim < 2 or any((pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{'params': [base_model.tok_emb.weight], 'lr': token_lr, 'base_lr': token_lr}] + self.optimizer_tok = torch.optim.AdamW(tok_params, betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.embed_wd, fused=True) + self.optimizer_muon = Muon(matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum, backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd, row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups: + group['base_lr'] = h.matrix_lr + scalar_groups = [{'params': scalar_params, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr, 'weight_decay': h.adam_wd}] + self.optimizer_scalar = torch.optim.AdamW(scalar_groups, betas=(h.beta1, h.beta2), eps=h.adam_eps, fused=True) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam([{'params': [base_model.lm_head.weight], 'lr': h.head_lr, 'base_lr': h.head_lr}], betas=(h.beta1, h.beta2), eps=h.adam_eps, fused=True) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + self.replicated_params = list(tok_params[0]['params']) + self.replicated_params.extend(scalar_params) + if base_model.lm_head is not None: + self.replicated_params.append(base_model.lm_head.weight) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + for p in self.replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_head is not None: + self.optimizer_head.step() + self.optimizer_muon.step() + self.zero_grad_all() + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any((pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, 'qo_bank'): + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + +def log_parallel_residual_converged(log0, model): + if getattr(model, 'parallel_post_lambdas', None) is None: + return + if model.looping_active: + v2p = list(model.encoder_indices) + list(model.decoder_indices) + else: + v2p = list(range(model.num_encoder_layers + model.num_decoder_layers)) + used_layers = [vi for vi, pi in enumerate(v2p) if model._parallel_active_for_layer(pi, vi)] + post = model.parallel_post_lambdas.detach().cpu() + resid = model.parallel_resid_lambdas.detach().cpu() + mode = 'physical' if model.parallel_start_layer_is_physical else 'virtual' + log0(f'parallel_residual:converged active=1 start_layer={model.parallel_start_layer} start_mode={mode} final_lane={model.parallel_final_lane} freeze_lane0={int(model.parallel_freeze_lane0)} identity_init={int(model.parallel_identity_init)} skip_lane0_only={int(model.parallel_skip_lane0_only)} mlp_read_mix={int(model.parallel_mlp_read_mix)} used_layers={len(used_layers)}') + for vi in used_layers: + pi = int(v2p[vi]) + if not (0 <= pi < post.shape[0] and 0 <= pi < resid.shape[0]): + log0(f'parallel_residual layer:{vi} physical:{pi} skipped=out_of_range') + continue + log0(f'parallel_residual layer:{vi} physical:{pi} attn_resid:{resid[pi, 0]:.4f} attn_to_attn:{post[pi, 0, 0]:.4f} attn_to_mlp:{post[pi, 0, 1]:.4f} mlp_resid:{resid[pi, 1]:.4f} mlp_to_attn:{post[pi, 1, 0]:.4f} mlp_to_mlp:{post[pi, 1, 1]:.4f}') + +def _unbank_state_dict(state_dict, num_layers): + sd = {} + n = num_layers + for k, v in state_dict.items(): + t = v.detach().cpu() + if k == 'qo_bank': + for i in range(n): + sd[f'blocks.{i}.attn.c_q.weight'] = t[i] + sd[f'blocks.{i}.attn.proj.weight'] = t[n + i] + elif k == 'kv_bank': + for i in range(n): + sd[f'blocks.{i}.attn.c_k.weight'] = t[i] + sd[f'blocks.{i}.attn.c_v.weight'] = t[n + i] + elif k == 'mlp_up_bank': + for i in range(n): + sd[f'blocks.{i}.mlp.fc.weight'] = t[i] + elif k == 'mlp_down_bank': + for i in range(n): + sd[f'blocks.{i}.mlp.proj.weight'] = t[i] + else: + sd[k] = t + return sd + +def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim): + sd = {} + n = num_layers + sd['qo_bank'] = torch.zeros(2 * n, model_dim, model_dim) + sd['kv_bank'] = torch.zeros(2 * n, kv_dim, model_dim) + sd['mlp_up_bank'] = torch.zeros(n, hidden_dim, model_dim) + sd['mlp_down_bank'] = torch.zeros(n, model_dim, hidden_dim) + for i in range(n): + sd['qo_bank'][i] = flat_sd[f'blocks.{i}.attn.c_q.weight'] + sd['qo_bank'][n + i] = flat_sd[f'blocks.{i}.attn.proj.weight'] + sd['kv_bank'][i] = flat_sd[f'blocks.{i}.attn.c_k.weight'] + sd['kv_bank'][n + i] = flat_sd[f'blocks.{i}.attn.c_v.weight'] + sd['mlp_up_bank'][i] = flat_sd[f'blocks.{i}.mlp.fc.weight'] + sd['mlp_down_bank'][i] = flat_sd[f'blocks.{i}.mlp.proj.weight'] + for k, v in flat_sd.items(): + if not (k.startswith('blocks.') and any((p in k for p in ['.attn.c_q.', '.attn.c_k.', '.attn.c_v.', '.attn.proj.', '.mlp.fc.', '.mlp.proj.']))): + sd[k] = v + return sd + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + n = model.num_layers + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + + def make_attn_hook(layer_idx): + + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ['c_q', 'c_k', 'c_v']: + name = f'blocks.{layer_idx}.attn.{suffix}.weight' + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f'blocks.{layer_idx}.attn.proj.weight' + if name not in hessians: + hessians[name] = torch.zeros(y.shape[1], y.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f'blocks.{layer_idx}.mlp.fc.weight' + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f'blocks.{layer_idx}.mlp.proj.weight' + if name not in hessians: + hessians[name] = torch.zeros(h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + if model.tie_embeddings: + hook_module = model.head_proj if model.head_proj is not None else model.final_norm + + def make_output_hook(name): + + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return (Q[:, invperm], s) + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = 'passthrough (float16)' + continue + cs = h.embed_clip_sigmas if 'tok_emb' in name else h.matrix_clip_sigmas + bits = h.embed_bits if 'tok_emb' in name else h.matrix_bits + q, s = gptq_quantize_weight(t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1) + result[name + '.q'] = q + result[name + '.scale'] = s + meta[name] = f'gptq (int{bits})' + categories = collections.defaultdict(set) + for name, cat in meta.items(): + short = re.sub('\\.\\d+$', '', re.sub('blocks\\.\\d+', 'blocks', name)) + categories[cat].add(short) + log('Quantized weights:') + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return (result, meta) + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if 'passthrough' in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = (result[name + '.q'], result[name + '.scale']) + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +_BSHF_MAGIC = b'BSHF' + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == 'lzma': + return lzma.compress(data, preset=6) + elif compressor == 'brotli': + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f'Unknown compressor: {compressor!r}') + +def _decompress(data, compressor): + if compressor == 'lzma': + raw = lzma.decompress(data) + elif compressor == 'brotli': + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f'Unknown compressor: {compressor!r}') + raw = _byte_unshuffle(raw) + return raw + +def serialize(h, base_model, code): + code_bytes = len(code.encode('utf-8')) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f'Serialized model: {model_bytes} bytes') + log(f'Code size: {code_bytes} bytes') + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + device = torch.device('cuda', h.local_rank) + log('GPTQ:collecting Hessians from calibration data...') + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians(base_model, calib_loader, h, device, n_calibration_batches=h.gptq_calibration_batches) + log(f'GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s') + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({'w': quant_result, 'm': quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, 'wb') as f: + f.write(quant_blob) + log(f'Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes') + log(f'Total submission size quantized+{h.compressor}: {bytes_total} bytes') + return (bytes_total, quant_file_bytes) + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, 'rb') as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location='cpu') + deq_flat = dequantize_mixed(quant_state['w'], quant_state['m'], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return (val_loss, val_bpb) + +def eval_val(h, device, val_data, model): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError(f'VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}') + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + +def eval_val_sliding(h, device, val_data, base_model, batch_seqs=32): + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = logits_fn(x_batch) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + +def eval_val_sliding_ttt(h, base_model, rank, world_size, device, val_data, stride): + for m in base_model.modules(): + if isinstance(m, MLP): + m.use_fused = False + seq_len = h.eval_seq_len + total_tokens = val_data.val_tokens.numel() - 1 + ttt_chunk = h.ttt_chunk_tokens + context_size = seq_len - stride + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else context_size + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + hash_emb = None + orig_forward_logits = None + if h.hash_embed_enabled: + hash_emb = nn.Embedding(h.hash_embed_size, h.model_dim).to(device) + nn.init.zeros_(hash_emb.weight) + orig_forward_logits = base_model.forward_logits + + def forward_logits_with_hash(input_ids): + x = base_model.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + prev_ids = torch.zeros_like(input_ids) + prev_ids[:, 1:] = input_ids[:, :-1] + h_idx = (prev_ids * 2039 + input_ids) % h.hash_embed_size + x = x + hash_emb(h_idx) + if base_model.embed_proj is not None: + x = base_model.embed_proj(x) + return base_model._forward_logits_from_embeddings(x) + base_model.forward_logits = forward_logits_with_hash + log(f'ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} freeze_blocks={h.ttt_freeze_blocks} optimizer={h.ttt_optimizer} hash_embed={h.hash_embed_enabled}') + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + frozen_block_ids = set(range(min(h.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f'blocks.{bi}.' in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + if hash_emb is not None: + ttt_params.append(hash_emb.weight) + total_unfrozen = sum((p.numel() for p in ttt_params)) + total_frozen = sum((p.numel() for p in base_model.parameters() if not p.requires_grad)) + log(f'ttt_sliding:params unfrozen={total_unfrozen} frozen={total_frozen}') + if h.ttt_optimizer == 'adamw': + optimizer = torch.optim.AdamW(ttt_params, lr=h.ttt_lr, weight_decay=h.ttt_adamw_wd, fused=False) + else: + optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) + t0 = time.perf_counter() + batch_seqs = h.ttt_batch_seqs + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = len(windows) * rank // world_size + my_e = len(windows) * (rank + 1) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_data.val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = ci == num_chunks - 1 + if not is_last_chunk and h.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = chunk_seqs * rank // world_size + my_seq_e = chunk_seqs * (rank + 1) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_data.val_tokens.numel(): + continue + local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, h.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) + log(f' ttt_chunk [{ci + 1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s') + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + if orig_forward_logits is not None: + base_model.forward_logits = orig_forward_logits + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log(f'ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.8f} elapsed={time.perf_counter() - t0:.1f}s') + return (val_loss, val_bpb) + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f'{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms') + return (val_loss, val_bpb) + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + log(f'model_params:{sum((p.numel() for p in base_model.parameters()))}') + log(f"parallel_residual:active={int(base_model.parallel_post_lambdas is not None)} start_layer={base_model.parallel_start_layer} start_mode={'physical' if base_model.parallel_start_layer_is_physical else 'virtual'} final_lane={base_model.parallel_final_lane} freeze_lane0={int(base_model.parallel_freeze_lane0)} identity_init={int(base_model.parallel_identity_init)} skip_lane0_only={int(base_model.parallel_skip_lane0_only)} mlp_read_mix={int(base_model.parallel_mlp_read_mix)}") + optimizers = Optimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0 + log(f'gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms') + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group['momentum'] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group['lr'] = group['base_lr'] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f'warmup_step: {warmup_step + 1}/{h.warmup_steps}') + if h.num_loops > 0: + base_model.looping_active = True + log(f'loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}') + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f'loop_warmup_step: {warmup_step + 1}/{h.warmup_steps}') + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = ShuffledSequenceLoader(h, device) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f'{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}') + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log(f'stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}') + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if h.num_loops > 0 and (not base_model.looping_active) and (frac >= h.enable_looping_at): + base_model.looping_active = True + log(f'layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}') + train_loss = step_fn(step, scale) + with torch.no_grad(): + sd = base_model.state_dict() + ema_vals = [ema_state[k] for k in sd] + live_vals = [sd[k].detach().float() for k in sd] + torch._foreach_mul_(ema_vals, ema_decay) + torch._foreach_add_(ema_vals, live_vals, alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log(f'{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}') + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log(f'peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB') + log('ema:applying EMA weights') + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + log_parallel_residual_converged(log, base_model) + return (base_model, compiled_model) + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f'val_tokens: {val_data.val_tokens.numel() - 1}') + base_model, compiled_model = train_model(h, device, val_data) + torch._dynamo.reset() + timed_eval('pre-quantization post-ema', eval_val, h, device, val_data, compiled_model) + serialize(h, base_model, Path(__file__).read_text(encoding='utf-8')) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval('quantized', eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval('quantized_sliding_window', eval_val_sliding, h, device, val_data, eval_model) + if h.ttt_enabled: + del eval_model, compiled_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + timed_eval('legal_ttt_exact', eval_val_sliding_ttt, h, ttt_model, h.rank, h.world_size, device, val_data, stride=h.eval_stride) + del ttt_model + +def main(): + world_size = int(os.environ.get('WORLD_SIZE', '1')) + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ + if not torch.cuda.is_available(): + raise RuntimeError('CUDA is required') + if world_size <= 0: + raise ValueError(f'WORLD_SIZE must be positive, got {world_size}') + if 8 % world_size != 0: + raise ValueError(f'WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral') + device = torch.device('cuda', local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend='nccl', device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision('high') + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs', exist_ok=True) + log(100 * '=', console=False) + log('Hyperparameters:', console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith('_'): + log(f' {k}: {v}', console=True) + log('=' * 100, console=False) + log(f'Running Python {sys.version}', console=False) + log(f'Running PyTorch {torch.__version__}', console=False) + log(subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log('=' * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() +if __name__ == '__main__': + main() diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed1337.log b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed1337.log new file mode 100644 index 0000000000..a7974b4810 --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed1337.log @@ -0,0 +1,291 @@ +W0413 00:53:43.664000 103112 torch/distributed/run.py:803] +W0413 00:53:43.664000 103112 torch/distributed/run.py:803] ***************************************** +W0413 00:53:43.664000 103112 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0413 00:53:43.664000 103112 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.095 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + hash_embed_enabled: True + hash_embed_size: 16384 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/8417e04a-cb72-4fc2-8a25-26bb5bf31ca4.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_freeze_lane0: False + parallel_identity_init: True + parallel_mlp_read_mix: False + parallel_residual: True + parallel_residual_start: 8 + parallel_skip_lane0_only: True + parallel_start_layer: 8 + parallel_start_layer_is_physical: True + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 8417e04a-cb72-4fc2-8a25-26bb5bf31ca4 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_adamw_wd: 0.0 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.01 + ttt_momentum: 0.9 + ttt_optimizer: sgd + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +parallel_residual:active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 +gptq:reserving 13s, effective=587000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0095 val_bpb: 3.4878 +1/20000 train_loss: 9.0103 train_time: 0.0m tok/s: 8310235 +2/20000 train_loss: 12.3889 train_time: 0.0m tok/s: 13069122 +3/20000 train_loss: 11.1292 train_time: 0.0m tok/s: 10952417 +4/20000 train_loss: 9.5969 train_time: 0.0m tok/s: 9875428 +5/20000 train_loss: 8.4219 train_time: 0.0m tok/s: 9572091 +500/20000 train_loss: 3.3681 train_time: 0.8m tok/s: 8015719 +1000/20000 train_loss: 3.2743 train_time: 1.6m tok/s: 7992840 +1500/20000 train_loss: 3.1761 train_time: 2.5m tok/s: 7991826 +2000/20000 train_loss: 3.0924 train_time: 3.3m tok/s: 7998008 +layer_loop:enabled step:2089 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1424 train_time: 4.4m tok/s: 7430458 +3000/20000 train_loss: 2.9150 train_time: 5.6m tok/s: 7011663 +3500/20000 train_loss: 2.9542 train_time: 6.8m tok/s: 6741498 +4000/20000 train_loss: 2.8339 train_time: 8.0m tok/s: 6552396 +4000/20000 val_loss: 2.8894 val_bpb: 1.1186 +4500/20000 train_loss: 2.8468 train_time: 9.2m tok/s: 6412976 +4745/20000 val_loss: 2.7956 val_bpb: 1.0823 +stopping_early: wallclock_cap train_time: 587180ms step: 4745/20000 +peak memory allocated: 39972 MiB reserved: 40024 MiB +ema:applying EMA weights +parallel_residual:converged active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 used_layers=3 +parallel_residual layer:14 physical:8 attn_resid:3.2508 attn_to_attn:-0.3672 attn_to_mlp:0.4779 mlp_resid:0.4243 mlp_to_attn:0.0263 mlp_to_mlp:0.6252 +parallel_residual layer:15 physical:9 attn_resid:0.7172 attn_to_attn:-0.0746 attn_to_mlp:0.3969 mlp_resid:0.4664 mlp_to_attn:0.2109 mlp_to_mlp:0.5664 +parallel_residual layer:16 physical:10 attn_resid:-0.0339 attn_to_attn:0.1468 attn_to_mlp:0.1468 mlp_resid:0.5235 mlp_to_attn:0.5712 mlp_to_mlp:0.5712 +pre-quantization post-ema val_loss:2.79557113 val_bpb:1.08225270 eval_time:6243ms +Serialized model: 135409136 bytes +Code size: 23191 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15960628 bytes +Total submission size quantized+brotli: 15983819 bytes +quantized val_loss:2.82196286 val_bpb:1.09246976 eval_time:29284ms +quantized_sliding_window val_loss:2.77849593 val_bpb:1.07564236 eval_time:124671ms +ttt_sliding:start chunks=1238 chunk_tokens=32768 total_windows=633409 stride=64 ttt_lr=0.01 ttt_epochs=3 freeze_blocks=0 optimizer=sgd hash_embed=True +ttt_sliding:params unfrozen=44333210 frozen=0 + ttt_chunk [1/1238] bpb=1.111309 time=46.3s + ttt_chunk [11/1238] bpb=1.067088 time=70.5s + ttt_chunk [21/1238] bpb=1.103087 time=73.1s + ttt_chunk [31/1238] bpb=1.096945 time=75.8s + ttt_chunk [41/1238] bpb=1.090075 time=78.4s + ttt_chunk [51/1238] bpb=1.083189 time=81.0s + ttt_chunk [61/1238] bpb=1.074716 time=83.7s + ttt_chunk [71/1238] bpb=1.081960 time=86.3s + ttt_chunk [81/1238] bpb=1.075413 time=88.9s + ttt_chunk [91/1238] bpb=1.072101 time=91.6s + ttt_chunk [101/1238] bpb=1.071789 time=94.3s + ttt_chunk [111/1238] bpb=1.070028 time=97.0s + ttt_chunk [121/1238] bpb=1.073103 time=99.6s + ttt_chunk [131/1238] bpb=1.076775 time=102.2s + ttt_chunk [141/1238] bpb=1.077485 time=104.8s + ttt_chunk [151/1238] bpb=1.077241 time=107.5s + ttt_chunk [161/1238] bpb=1.077876 time=110.1s + ttt_chunk [171/1238] bpb=1.077727 time=112.7s + ttt_chunk [181/1238] bpb=1.076406 time=115.4s + ttt_chunk [191/1238] bpb=1.076113 time=118.0s + ttt_chunk [201/1238] bpb=1.073676 time=120.6s + ttt_chunk [211/1238] bpb=1.078113 time=123.2s + ttt_chunk [221/1238] bpb=1.078522 time=125.8s + ttt_chunk [231/1238] bpb=1.080276 time=128.4s + ttt_chunk [241/1238] bpb=1.078525 time=131.0s + ttt_chunk [251/1238] bpb=1.078542 time=133.7s + ttt_chunk [261/1238] bpb=1.079616 time=136.3s + ttt_chunk [271/1238] bpb=1.080011 time=138.9s + ttt_chunk [281/1238] bpb=1.079360 time=141.5s + ttt_chunk [291/1238] bpb=1.080478 time=144.2s + ttt_chunk [301/1238] bpb=1.080631 time=146.8s + ttt_chunk [311/1238] bpb=1.079524 time=149.4s + ttt_chunk [321/1238] bpb=1.079376 time=152.0s + ttt_chunk [331/1238] bpb=1.079655 time=154.6s + ttt_chunk [341/1238] bpb=1.078734 time=157.3s + ttt_chunk [351/1238] bpb=1.079433 time=159.9s + ttt_chunk [361/1238] bpb=1.078418 time=162.5s + ttt_chunk [371/1238] bpb=1.076881 time=165.2s + ttt_chunk [381/1238] bpb=1.077275 time=167.8s + ttt_chunk [391/1238] bpb=1.076899 time=170.4s + ttt_chunk [401/1238] bpb=1.076994 time=173.0s + ttt_chunk [411/1238] bpb=1.077535 time=175.7s + ttt_chunk [421/1238] bpb=1.077036 time=178.3s + ttt_chunk [431/1238] bpb=1.077208 time=180.9s + ttt_chunk [441/1238] bpb=1.077265 time=183.5s + ttt_chunk [451/1238] bpb=1.078411 time=186.1s + ttt_chunk [461/1238] bpb=1.076651 time=188.8s + ttt_chunk [471/1238] bpb=1.076575 time=191.4s + ttt_chunk [481/1238] bpb=1.076742 time=194.1s + ttt_chunk [491/1238] bpb=1.077207 time=196.7s + ttt_chunk [501/1238] bpb=1.076779 time=199.3s + ttt_chunk [511/1238] bpb=1.076372 time=201.9s + ttt_chunk [521/1238] bpb=1.075884 time=204.5s + ttt_chunk [531/1238] bpb=1.075858 time=207.2s + ttt_chunk [541/1238] bpb=1.075931 time=209.8s + ttt_chunk [551/1238] bpb=1.075472 time=212.4s + ttt_chunk [561/1238] bpb=1.074787 time=215.0s + ttt_chunk [571/1238] bpb=1.074243 time=217.7s + ttt_chunk [581/1238] bpb=1.074578 time=220.3s + ttt_chunk [591/1238] bpb=1.074786 time=222.9s + ttt_chunk [601/1238] bpb=1.074672 time=225.5s + ttt_chunk [611/1238] bpb=1.075267 time=228.1s + ttt_chunk [621/1238] bpb=1.076132 time=230.8s + ttt_chunk [631/1238] bpb=1.076166 time=233.4s + ttt_chunk [641/1238] bpb=1.076596 time=236.0s + ttt_chunk [651/1238] bpb=1.076911 time=238.6s + ttt_chunk [661/1238] bpb=1.076232 time=241.2s + ttt_chunk [671/1238] bpb=1.076007 time=243.9s + ttt_chunk [681/1238] bpb=1.077332 time=246.5s + ttt_chunk [691/1238] bpb=1.077501 time=249.2s + ttt_chunk [701/1238] bpb=1.077259 time=251.8s + ttt_chunk [711/1238] bpb=1.077947 time=254.4s + ttt_chunk [721/1238] bpb=1.078238 time=257.1s + ttt_chunk [731/1238] bpb=1.077592 time=259.7s + ttt_chunk [741/1238] bpb=1.077275 time=262.3s + ttt_chunk [751/1238] bpb=1.076351 time=264.9s + ttt_chunk [761/1238] bpb=1.075745 time=267.6s + ttt_chunk [771/1238] bpb=1.074736 time=270.2s + ttt_chunk [781/1238] bpb=1.074732 time=272.8s + ttt_chunk [791/1238] bpb=1.075066 time=275.5s + ttt_chunk [801/1238] bpb=1.075362 time=278.1s + ttt_chunk [811/1238] bpb=1.074860 time=280.7s + ttt_chunk [821/1238] bpb=1.073635 time=283.4s + ttt_chunk [831/1238] bpb=1.073298 time=286.0s + ttt_chunk [841/1238] bpb=1.072810 time=288.6s + ttt_chunk [851/1238] bpb=1.072505 time=291.2s + ttt_chunk [861/1238] bpb=1.072152 time=293.9s + ttt_chunk [871/1238] bpb=1.072041 time=296.5s + ttt_chunk [881/1238] bpb=1.071559 time=299.2s + ttt_chunk [891/1238] bpb=1.071033 time=301.8s + ttt_chunk [901/1238] bpb=1.071379 time=304.4s + ttt_chunk [911/1238] bpb=1.071055 time=307.0s + ttt_chunk [921/1238] bpb=1.071296 time=309.6s + ttt_chunk [931/1238] bpb=1.071982 time=312.3s + ttt_chunk [941/1238] bpb=1.072346 time=314.9s + ttt_chunk [951/1238] bpb=1.072260 time=317.6s + ttt_chunk [961/1238] bpb=1.073072 time=320.2s + ttt_chunk [971/1238] bpb=1.073457 time=322.8s + ttt_chunk [981/1238] bpb=1.073823 time=325.4s + ttt_chunk [991/1238] bpb=1.073583 time=328.1s + ttt_chunk [1001/1238] bpb=1.073608 time=330.7s + ttt_chunk [1011/1238] bpb=1.073942 time=333.3s + ttt_chunk [1021/1238] bpb=1.074637 time=336.0s + ttt_chunk [1031/1238] bpb=1.075081 time=338.6s + ttt_chunk [1041/1238] bpb=1.075539 time=341.2s + ttt_chunk [1051/1238] bpb=1.075455 time=343.8s + ttt_chunk [1061/1238] bpb=1.075448 time=346.5s + ttt_chunk [1071/1238] bpb=1.075596 time=349.1s + ttt_chunk [1081/1238] bpb=1.075483 time=351.8s + ttt_chunk [1091/1238] bpb=1.075678 time=354.4s + ttt_chunk [1101/1238] bpb=1.076203 time=357.1s + ttt_chunk [1111/1238] bpb=1.076498 time=359.7s + ttt_chunk [1121/1238] bpb=1.076657 time=362.3s + ttt_chunk [1131/1238] bpb=1.076296 time=365.0s + ttt_chunk [1141/1238] bpb=1.075962 time=367.6s + ttt_chunk [1151/1238] bpb=1.075970 time=370.3s + ttt_chunk [1161/1238] bpb=1.076097 time=372.9s + ttt_chunk [1171/1238] bpb=1.075866 time=375.5s + ttt_chunk [1181/1238] bpb=1.075382 time=378.2s + ttt_chunk [1191/1238] bpb=1.075513 time=380.8s + ttt_chunk [1201/1238] bpb=1.075587 time=383.5s + ttt_chunk [1211/1238] bpb=1.075254 time=386.1s + ttt_chunk [1221/1238] bpb=1.074799 time=388.7s + ttt_chunk [1231/1238] bpb=1.074433 time=391.4s + ttt_chunk [1238/1238] bpb=1.074445 time=414.6s +ttt_sliding:done val_loss=2.775560 val_bpb=1.07450565 elapsed=415.0s +legal_ttt_exact val_loss:2.77555969 val_bpb:1.07450565 eval_time:415268ms diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed2024.log b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed2024.log new file mode 100644 index 0000000000..92d90f5ccc --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed2024.log @@ -0,0 +1,291 @@ +W0413 00:53:58.704000 95934 torch/distributed/run.py:803] +W0413 00:53:58.704000 95934 torch/distributed/run.py:803] ***************************************** +W0413 00:53:58.704000 95934 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0413 00:53:58.704000 95934 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.095 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + hash_embed_enabled: True + hash_embed_size: 16384 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/a0be6aeb-e504-4007-a142-50278288f53a.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_freeze_lane0: False + parallel_identity_init: True + parallel_mlp_read_mix: False + parallel_residual: True + parallel_residual_start: 8 + parallel_skip_lane0_only: True + parallel_start_layer: 8 + parallel_start_layer_is_physical: True + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: a0be6aeb-e504-4007-a142-50278288f53a + scalar_lr: 0.02 + seed: 2024 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_adamw_wd: 0.0 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.01 + ttt_momentum: 0.9 + ttt_optimizer: sgd + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +parallel_residual:active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 +gptq:reserving 13s, effective=587000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0090 val_bpb: 3.4877 +1/20000 train_loss: 9.0109 train_time: 0.0m tok/s: 8254990 +2/20000 train_loss: 12.4898 train_time: 0.0m tok/s: 13002332 +3/20000 train_loss: 11.1637 train_time: 0.0m tok/s: 10874479 +4/20000 train_loss: 9.6277 train_time: 0.0m tok/s: 9981702 +5/20000 train_loss: 8.4097 train_time: 0.0m tok/s: 9494866 +500/20000 train_loss: 3.3815 train_time: 0.8m tok/s: 7972213 +1000/20000 train_loss: 3.2741 train_time: 1.6m tok/s: 7953599 +1500/20000 train_loss: 3.1812 train_time: 2.5m tok/s: 7952599 +2000/20000 train_loss: 3.0990 train_time: 3.3m tok/s: 7955243 +layer_loop:enabled step:2078 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1395 train_time: 4.4m tok/s: 7381750 +3000/20000 train_loss: 2.9162 train_time: 5.6m tok/s: 6970116 +3500/20000 train_loss: 2.9568 train_time: 6.8m tok/s: 6703481 +4000/20000 train_loss: 2.8342 train_time: 8.0m tok/s: 6516775 +4000/20000 val_loss: 2.8898 val_bpb: 1.1187 +4500/20000 train_loss: 2.8513 train_time: 9.2m tok/s: 6379420 +4724/20000 val_loss: 2.7981 val_bpb: 1.0832 +stopping_early: wallclock_cap train_time: 587210ms step: 4724/20000 +peak memory allocated: 39972 MiB reserved: 40024 MiB +ema:applying EMA weights +parallel_residual:converged active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 used_layers=3 +parallel_residual layer:14 physical:8 attn_resid:2.5397 attn_to_attn:0.0143 attn_to_mlp:0.4384 mlp_resid:0.4159 mlp_to_attn:-0.2457 mlp_to_mlp:0.7306 +parallel_residual layer:15 physical:9 attn_resid:0.4134 attn_to_attn:1.6293 attn_to_mlp:0.0727 mlp_resid:0.4515 mlp_to_attn:-0.0005 mlp_to_mlp:0.6379 +parallel_residual layer:16 physical:10 attn_resid:0.0049 attn_to_attn:0.3035 attn_to_mlp:0.3035 mlp_resid:0.4304 mlp_to_attn:0.6048 mlp_to_mlp:0.6048 +pre-quantization post-ema val_loss:2.79825273 val_bpb:1.08329084 eval_time:6208ms +Serialized model: 135409136 bytes +Code size: 23191 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15959183 bytes +Total submission size quantized+brotli: 15982374 bytes +quantized val_loss:2.82569916 val_bpb:1.09391620 eval_time:29285ms +quantized_sliding_window val_loss:2.78165941 val_bpb:1.07686704 eval_time:124296ms +ttt_sliding:start chunks=1238 chunk_tokens=32768 total_windows=633409 stride=64 ttt_lr=0.01 ttt_epochs=3 freeze_blocks=0 optimizer=sgd hash_embed=True +ttt_sliding:params unfrozen=44333210 frozen=0 + ttt_chunk [1/1238] bpb=1.112863 time=46.2s + ttt_chunk [11/1238] bpb=1.066513 time=70.4s + ttt_chunk [21/1238] bpb=1.103108 time=73.1s + ttt_chunk [31/1238] bpb=1.097603 time=75.7s + ttt_chunk [41/1238] bpb=1.090968 time=78.3s + ttt_chunk [51/1238] bpb=1.084285 time=80.9s + ttt_chunk [61/1238] bpb=1.076184 time=83.6s + ttt_chunk [71/1238] bpb=1.083247 time=86.2s + ttt_chunk [81/1238] bpb=1.076489 time=88.8s + ttt_chunk [91/1238] bpb=1.073023 time=91.4s + ttt_chunk [101/1238] bpb=1.072798 time=94.1s + ttt_chunk [111/1238] bpb=1.071218 time=96.7s + ttt_chunk [121/1238] bpb=1.074226 time=99.3s + ttt_chunk [131/1238] bpb=1.077939 time=101.9s + ttt_chunk [141/1238] bpb=1.078480 time=104.6s + ttt_chunk [151/1238] bpb=1.078337 time=107.2s + ttt_chunk [161/1238] bpb=1.078877 time=109.8s + ttt_chunk [171/1238] bpb=1.078735 time=112.4s + ttt_chunk [181/1238] bpb=1.077425 time=115.1s + ttt_chunk [191/1238] bpb=1.077242 time=117.7s + ttt_chunk [201/1238] bpb=1.074778 time=120.3s + ttt_chunk [211/1238] bpb=1.079251 time=122.9s + ttt_chunk [221/1238] bpb=1.079716 time=125.5s + ttt_chunk [231/1238] bpb=1.081389 time=128.2s + ttt_chunk [241/1238] bpb=1.079666 time=130.8s + ttt_chunk [251/1238] bpb=1.079637 time=133.4s + ttt_chunk [261/1238] bpb=1.080621 time=136.0s + ttt_chunk [271/1238] bpb=1.080985 time=138.7s + ttt_chunk [281/1238] bpb=1.080317 time=141.3s + ttt_chunk [291/1238] bpb=1.081438 time=143.9s + ttt_chunk [301/1238] bpb=1.081636 time=146.5s + ttt_chunk [311/1238] bpb=1.080496 time=149.2s + ttt_chunk [321/1238] bpb=1.080339 time=151.8s + ttt_chunk [331/1238] bpb=1.080641 time=154.4s + ttt_chunk [341/1238] bpb=1.079753 time=157.1s + ttt_chunk [351/1238] bpb=1.080463 time=159.7s + ttt_chunk [361/1238] bpb=1.079375 time=162.3s + ttt_chunk [371/1238] bpb=1.077809 time=164.9s + ttt_chunk [381/1238] bpb=1.078170 time=167.5s + ttt_chunk [391/1238] bpb=1.077858 time=170.2s + ttt_chunk [401/1238] bpb=1.077935 time=172.8s + ttt_chunk [411/1238] bpb=1.078466 time=175.4s + ttt_chunk [421/1238] bpb=1.077916 time=178.1s + ttt_chunk [431/1238] bpb=1.078071 time=180.7s + ttt_chunk [441/1238] bpb=1.078093 time=183.3s + ttt_chunk [451/1238] bpb=1.079237 time=185.9s + ttt_chunk [461/1238] bpb=1.077500 time=188.5s + ttt_chunk [471/1238] bpb=1.077468 time=191.1s + ttt_chunk [481/1238] bpb=1.077640 time=193.7s + ttt_chunk [491/1238] bpb=1.078069 time=196.4s + ttt_chunk [501/1238] bpb=1.077687 time=199.1s + ttt_chunk [511/1238] bpb=1.077315 time=201.7s + ttt_chunk [521/1238] bpb=1.076806 time=204.4s + ttt_chunk [531/1238] bpb=1.076769 time=207.0s + ttt_chunk [541/1238] bpb=1.076854 time=209.6s + ttt_chunk [551/1238] bpb=1.076381 time=212.2s + ttt_chunk [561/1238] bpb=1.075653 time=214.9s + ttt_chunk [571/1238] bpb=1.075095 time=217.5s + ttt_chunk [581/1238] bpb=1.075464 time=220.1s + ttt_chunk [591/1238] bpb=1.075676 time=222.7s + ttt_chunk [601/1238] bpb=1.075569 time=225.4s + ttt_chunk [611/1238] bpb=1.076171 time=228.0s + ttt_chunk [621/1238] bpb=1.076979 time=230.6s + ttt_chunk [631/1238] bpb=1.077007 time=233.2s + ttt_chunk [641/1238] bpb=1.077464 time=235.9s + ttt_chunk [651/1238] bpb=1.077778 time=238.5s + ttt_chunk [661/1238] bpb=1.077117 time=241.1s + ttt_chunk [671/1238] bpb=1.076887 time=243.7s + ttt_chunk [681/1238] bpb=1.078205 time=246.3s + ttt_chunk [691/1238] bpb=1.078398 time=249.0s + ttt_chunk [701/1238] bpb=1.078182 time=251.6s + ttt_chunk [711/1238] bpb=1.078843 time=254.2s + ttt_chunk [721/1238] bpb=1.079148 time=256.8s + ttt_chunk [731/1238] bpb=1.078484 time=259.5s + ttt_chunk [741/1238] bpb=1.078144 time=262.1s + ttt_chunk [751/1238] bpb=1.077203 time=264.7s + ttt_chunk [761/1238] bpb=1.076585 time=267.4s + ttt_chunk [771/1238] bpb=1.075556 time=270.0s + ttt_chunk [781/1238] bpb=1.075517 time=272.6s + ttt_chunk [791/1238] bpb=1.075863 time=275.2s + ttt_chunk [801/1238] bpb=1.076119 time=277.9s + ttt_chunk [811/1238] bpb=1.075627 time=280.5s + ttt_chunk [821/1238] bpb=1.074412 time=283.1s + ttt_chunk [831/1238] bpb=1.074080 time=285.7s + ttt_chunk [841/1238] bpb=1.073595 time=288.3s + ttt_chunk [851/1238] bpb=1.073293 time=291.0s + ttt_chunk [861/1238] bpb=1.072948 time=293.6s + ttt_chunk [871/1238] bpb=1.072840 time=296.3s + ttt_chunk [881/1238] bpb=1.072355 time=298.9s + ttt_chunk [891/1238] bpb=1.071798 time=301.5s + ttt_chunk [901/1238] bpb=1.072164 time=304.1s + ttt_chunk [911/1238] bpb=1.071850 time=306.7s + ttt_chunk [921/1238] bpb=1.072114 time=309.3s + ttt_chunk [931/1238] bpb=1.072794 time=311.9s + ttt_chunk [941/1238] bpb=1.073193 time=314.6s + ttt_chunk [951/1238] bpb=1.073105 time=317.2s + ttt_chunk [961/1238] bpb=1.073902 time=319.9s + ttt_chunk [971/1238] bpb=1.074282 time=322.5s + ttt_chunk [981/1238] bpb=1.074661 time=325.2s + ttt_chunk [991/1238] bpb=1.074438 time=327.8s + ttt_chunk [1001/1238] bpb=1.074443 time=330.4s + ttt_chunk [1011/1238] bpb=1.074764 time=333.0s + ttt_chunk [1021/1238] bpb=1.075463 time=335.7s + ttt_chunk [1031/1238] bpb=1.075918 time=338.3s + ttt_chunk [1041/1238] bpb=1.076376 time=340.9s + ttt_chunk [1051/1238] bpb=1.076297 time=343.5s + ttt_chunk [1061/1238] bpb=1.076302 time=346.1s + ttt_chunk [1071/1238] bpb=1.076466 time=348.7s + ttt_chunk [1081/1238] bpb=1.076346 time=351.3s + ttt_chunk [1091/1238] bpb=1.076535 time=354.0s + ttt_chunk [1101/1238] bpb=1.077067 time=356.6s + ttt_chunk [1111/1238] bpb=1.077359 time=359.2s + ttt_chunk [1121/1238] bpb=1.077539 time=361.8s + ttt_chunk [1131/1238] bpb=1.077189 time=364.5s + ttt_chunk [1141/1238] bpb=1.076830 time=367.1s + ttt_chunk [1151/1238] bpb=1.076860 time=369.7s + ttt_chunk [1161/1238] bpb=1.076982 time=372.3s + ttt_chunk [1171/1238] bpb=1.076742 time=374.9s + ttt_chunk [1181/1238] bpb=1.076257 time=377.6s + ttt_chunk [1191/1238] bpb=1.076389 time=380.2s + ttt_chunk [1201/1238] bpb=1.076475 time=382.8s + ttt_chunk [1211/1238] bpb=1.076166 time=385.5s + ttt_chunk [1221/1238] bpb=1.075701 time=388.1s + ttt_chunk [1231/1238] bpb=1.075324 time=390.7s + ttt_chunk [1238/1238] bpb=1.075326 time=413.8s +ttt_sliding:done val_loss=2.778227 val_bpb=1.07553843 elapsed=414.5s +legal_ttt_exact val_loss:2.77822748 val_bpb:1.07553843 eval_time:414740ms diff --git a/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed42.log b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed42.log new file mode 100644 index 0000000000..27247969fb --- /dev/null +++ b/records/track_10min_16mb/2026-04-13_SystemsOpt_ImprovedParallelResiduals/train_seed42.log @@ -0,0 +1,291 @@ +W0413 01:28:52.150000 201825 torch/distributed/run.py:803] +W0413 01:28:52.150000 201825 torch/distributed/run.py:803] ***************************************** +W0413 01:28:52.150000 201825 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0413 01:28:52.150000 201825 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.095 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + hash_embed_enabled: True + hash_embed_size: 16384 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/acea95f7-6d5e-4924-bf21-c7f7a4a73752.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_freeze_lane0: False + parallel_identity_init: True + parallel_mlp_read_mix: False + parallel_residual: True + parallel_residual_start: 8 + parallel_skip_lane0_only: True + parallel_start_layer: 8 + parallel_start_layer_is_physical: True + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: acea95f7-6d5e-4924-bf21-c7f7a4a73752 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_adamw_wd: 0.0 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.01 + ttt_momentum: 0.9 + ttt_optimizer: sgd + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +parallel_residual:active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 +gptq:reserving 13s, effective=587000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0078 val_bpb: 3.4872 +1/20000 train_loss: 9.0109 train_time: 0.0m tok/s: 8348431 +2/20000 train_loss: 12.5061 train_time: 0.0m tok/s: 13111921 +3/20000 train_loss: 11.1931 train_time: 0.0m tok/s: 10857132 +4/20000 train_loss: 9.6270 train_time: 0.0m tok/s: 9993343 +5/20000 train_loss: 8.4157 train_time: 0.0m tok/s: 9494350 +500/20000 train_loss: 3.3660 train_time: 0.8m tok/s: 8007116 +1000/20000 train_loss: 3.2741 train_time: 1.6m tok/s: 7992588 +1500/20000 train_loss: 3.1788 train_time: 2.5m tok/s: 7992900 +2000/20000 train_loss: 3.0999 train_time: 3.3m tok/s: 8002163 +layer_loop:enabled step:2091 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.1455 train_time: 4.4m tok/s: 7425662 +3000/20000 train_loss: 2.9204 train_time: 5.6m tok/s: 7008287 +3500/20000 train_loss: 2.9606 train_time: 6.8m tok/s: 6738957 +4000/20000 train_loss: 2.8348 train_time: 8.0m tok/s: 6550404 +4000/20000 val_loss: 2.8918 val_bpb: 1.1195 +4500/20000 train_loss: 2.8505 train_time: 9.2m tok/s: 6411047 +4744/20000 val_loss: 2.7979 val_bpb: 1.0832 +stopping_early: wallclock_cap train_time: 587144ms step: 4744/20000 +peak memory allocated: 39964 MiB reserved: 40026 MiB +ema:applying EMA weights +parallel_residual:converged active=1 start_layer=8 start_mode=physical final_lane=mean freeze_lane0=0 identity_init=1 skip_lane0_only=1 mlp_read_mix=0 used_layers=3 +parallel_residual layer:14 physical:8 attn_resid:2.8245 attn_to_attn:-0.0192 attn_to_mlp:0.4629 mlp_resid:0.3730 mlp_to_attn:-0.1014 mlp_to_mlp:0.6794 +parallel_residual layer:15 physical:9 attn_resid:1.1370 attn_to_attn:0.4943 attn_to_mlp:0.3705 mlp_resid:0.4319 mlp_to_attn:0.1938 mlp_to_mlp:0.5627 +parallel_residual layer:16 physical:10 attn_resid:-0.0136 attn_to_attn:0.2324 attn_to_mlp:0.2324 mlp_resid:0.4923 mlp_to_attn:0.5796 mlp_to_mlp:0.5796 +pre-quantization post-ema val_loss:2.79805246 val_bpb:1.08321330 eval_time:6190ms +Serialized model: 135409136 bytes +Code size: 23191 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15956446 bytes +Total submission size quantized+brotli: 15979637 bytes +quantized val_loss:2.82693488 val_bpb:1.09439459 eval_time:8840ms +quantized_sliding_window val_loss:2.78284216 val_bpb:1.07732492 eval_time:92523ms +ttt_sliding:start chunks=1238 chunk_tokens=32768 total_windows=633409 stride=64 ttt_lr=0.01 ttt_epochs=3 freeze_blocks=0 optimizer=sgd hash_embed=True +ttt_sliding:params unfrozen=44333210 frozen=0 + ttt_chunk [1/1238] bpb=1.110149 time=5.1s + ttt_chunk [11/1238] bpb=1.069383 time=10.4s + ttt_chunk [21/1238] bpb=1.105840 time=13.0s + ttt_chunk [31/1238] bpb=1.099127 time=15.7s + ttt_chunk [41/1238] bpb=1.091953 time=18.3s + ttt_chunk [51/1238] bpb=1.085263 time=20.9s + ttt_chunk [61/1238] bpb=1.076903 time=23.5s + ttt_chunk [71/1238] bpb=1.083742 time=26.1s + ttt_chunk [81/1238] bpb=1.076863 time=28.8s + ttt_chunk [91/1238] bpb=1.073252 time=31.4s + ttt_chunk [101/1238] bpb=1.073247 time=34.0s + ttt_chunk [111/1238] bpb=1.071509 time=36.6s + ttt_chunk [121/1238] bpb=1.074299 time=39.2s + ttt_chunk [131/1238] bpb=1.078076 time=41.9s + ttt_chunk [141/1238] bpb=1.078484 time=44.5s + ttt_chunk [151/1238] bpb=1.078290 time=47.1s + ttt_chunk [161/1238] bpb=1.078785 time=49.7s + ttt_chunk [171/1238] bpb=1.078692 time=52.4s + ttt_chunk [181/1238] bpb=1.077160 time=55.0s + ttt_chunk [191/1238] bpb=1.076898 time=57.6s + ttt_chunk [201/1238] bpb=1.074459 time=60.2s + ttt_chunk [211/1238] bpb=1.078957 time=62.8s + ttt_chunk [221/1238] bpb=1.079355 time=65.5s + ttt_chunk [231/1238] bpb=1.080996 time=68.1s + ttt_chunk [241/1238] bpb=1.079232 time=70.7s + ttt_chunk [251/1238] bpb=1.079108 time=73.3s + ttt_chunk [261/1238] bpb=1.080093 time=75.9s + ttt_chunk [271/1238] bpb=1.080521 time=78.6s + ttt_chunk [281/1238] bpb=1.079801 time=81.2s + ttt_chunk [291/1238] bpb=1.080995 time=83.8s + ttt_chunk [301/1238] bpb=1.081134 time=86.4s + ttt_chunk [311/1238] bpb=1.079983 time=89.0s + ttt_chunk [321/1238] bpb=1.079817 time=91.6s + ttt_chunk [331/1238] bpb=1.080088 time=94.2s + ttt_chunk [341/1238] bpb=1.079175 time=96.9s + ttt_chunk [351/1238] bpb=1.079968 time=99.5s + ttt_chunk [361/1238] bpb=1.078920 time=102.1s + ttt_chunk [371/1238] bpb=1.077371 time=104.7s + ttt_chunk [381/1238] bpb=1.077789 time=107.3s + ttt_chunk [391/1238] bpb=1.077478 time=110.0s + ttt_chunk [401/1238] bpb=1.077559 time=112.6s + ttt_chunk [411/1238] bpb=1.078103 time=115.2s + ttt_chunk [421/1238] bpb=1.077626 time=117.8s + ttt_chunk [431/1238] bpb=1.077835 time=120.4s + ttt_chunk [441/1238] bpb=1.077858 time=123.0s + ttt_chunk [451/1238] bpb=1.079035 time=125.7s + ttt_chunk [461/1238] bpb=1.077260 time=128.3s + ttt_chunk [471/1238] bpb=1.077291 time=130.9s + ttt_chunk [481/1238] bpb=1.077462 time=133.5s + ttt_chunk [491/1238] bpb=1.077895 time=136.1s + ttt_chunk [501/1238] bpb=1.077516 time=138.7s + ttt_chunk [511/1238] bpb=1.077146 time=141.3s + ttt_chunk [521/1238] bpb=1.076674 time=143.9s + ttt_chunk [531/1238] bpb=1.076617 time=146.6s + ttt_chunk [541/1238] bpb=1.076695 time=149.2s + ttt_chunk [551/1238] bpb=1.076223 time=151.8s + ttt_chunk [561/1238] bpb=1.075536 time=154.4s + ttt_chunk [571/1238] bpb=1.074980 time=157.0s + ttt_chunk [581/1238] bpb=1.075340 time=159.7s + ttt_chunk [591/1238] bpb=1.075558 time=162.3s + ttt_chunk [601/1238] bpb=1.075476 time=164.9s + ttt_chunk [611/1238] bpb=1.076039 time=167.5s + ttt_chunk [621/1238] bpb=1.076876 time=170.1s + ttt_chunk [631/1238] bpb=1.076920 time=172.8s + ttt_chunk [641/1238] bpb=1.077368 time=175.4s + ttt_chunk [651/1238] bpb=1.077711 time=178.0s + ttt_chunk [661/1238] bpb=1.077061 time=180.6s + ttt_chunk [671/1238] bpb=1.076842 time=183.2s + ttt_chunk [681/1238] bpb=1.078134 time=185.8s + ttt_chunk [691/1238] bpb=1.078316 time=188.5s + ttt_chunk [701/1238] bpb=1.078105 time=191.1s + ttt_chunk [711/1238] bpb=1.078769 time=193.7s + ttt_chunk [721/1238] bpb=1.079075 time=196.3s + ttt_chunk [731/1238] bpb=1.078416 time=199.0s + ttt_chunk [741/1238] bpb=1.078093 time=201.6s + ttt_chunk [751/1238] bpb=1.077181 time=204.3s + ttt_chunk [761/1238] bpb=1.076581 time=206.9s + ttt_chunk [771/1238] bpb=1.075563 time=209.5s + ttt_chunk [781/1238] bpb=1.075526 time=212.1s + ttt_chunk [791/1238] bpb=1.075846 time=214.7s + ttt_chunk [801/1238] bpb=1.076134 time=217.3s + ttt_chunk [811/1238] bpb=1.075640 time=220.0s + ttt_chunk [821/1238] bpb=1.074439 time=222.6s + ttt_chunk [831/1238] bpb=1.074095 time=225.2s + ttt_chunk [841/1238] bpb=1.073622 time=227.8s + ttt_chunk [851/1238] bpb=1.073322 time=230.4s + ttt_chunk [861/1238] bpb=1.072973 time=233.0s + ttt_chunk [871/1238] bpb=1.072868 time=235.6s + ttt_chunk [881/1238] bpb=1.072403 time=238.2s + ttt_chunk [891/1238] bpb=1.071863 time=240.9s + ttt_chunk [901/1238] bpb=1.072193 time=243.5s + ttt_chunk [911/1238] bpb=1.071869 time=246.1s + ttt_chunk [921/1238] bpb=1.072129 time=248.7s + ttt_chunk [931/1238] bpb=1.072818 time=251.3s + ttt_chunk [941/1238] bpb=1.073180 time=253.9s + ttt_chunk [951/1238] bpb=1.073108 time=256.5s + ttt_chunk [961/1238] bpb=1.073945 time=259.2s + ttt_chunk [971/1238] bpb=1.074330 time=261.8s + ttt_chunk [981/1238] bpb=1.074675 time=264.4s + ttt_chunk [991/1238] bpb=1.074439 time=267.1s + ttt_chunk [1001/1238] bpb=1.074454 time=269.7s + ttt_chunk [1011/1238] bpb=1.074793 time=272.3s + ttt_chunk [1021/1238] bpb=1.075476 time=274.9s + ttt_chunk [1031/1238] bpb=1.075906 time=277.5s + ttt_chunk [1041/1238] bpb=1.076365 time=280.1s + ttt_chunk [1051/1238] bpb=1.076281 time=282.7s + ttt_chunk [1061/1238] bpb=1.076290 time=285.3s + ttt_chunk [1071/1238] bpb=1.076452 time=287.9s + ttt_chunk [1081/1238] bpb=1.076338 time=290.6s + ttt_chunk [1091/1238] bpb=1.076510 time=293.2s + ttt_chunk [1101/1238] bpb=1.077055 time=295.9s + ttt_chunk [1111/1238] bpb=1.077348 time=298.5s + ttt_chunk [1121/1238] bpb=1.077520 time=301.1s + ttt_chunk [1131/1238] bpb=1.077158 time=303.7s + ttt_chunk [1141/1238] bpb=1.076810 time=306.3s + ttt_chunk [1151/1238] bpb=1.076827 time=308.9s + ttt_chunk [1161/1238] bpb=1.076958 time=311.5s + ttt_chunk [1171/1238] bpb=1.076726 time=314.1s + ttt_chunk [1181/1238] bpb=1.076246 time=316.8s + ttt_chunk [1191/1238] bpb=1.076387 time=319.4s + ttt_chunk [1201/1238] bpb=1.076446 time=322.0s + ttt_chunk [1211/1238] bpb=1.076123 time=324.6s + ttt_chunk [1221/1238] bpb=1.075650 time=327.3s + ttt_chunk [1231/1238] bpb=1.075270 time=329.9s + ttt_chunk [1238/1238] bpb=1.075268 time=333.8s +ttt_sliding:done val_loss=2.778134 val_bpb=1.07550237 elapsed=334.2s +legal_ttt_exact val_loss:2.77813433 val_bpb:1.07550237 eval_time:334486ms