diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/README.md b/records/track_10min_16mb/2026-04-29_MUDD_Connections/README.md
new file mode 100644
index 0000000000..1a379e7180
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/README.md
@@ -0,0 +1,111 @@
+# Record: MUDD Connections + SP8192 + 3-Layer Recurrence + Parallel Residuals + QK-Gain 5.25 + Legal TTT
+
+**val_bpb = 1.0769** (3-seed mean, std 0.0004) | **~15.99 MB** | 8xH100 HBM3
+
+## 3-Seed Results
+
+| Seed | Sliding BPB | **TTT BPB** | Artifact (bytes) |
+|------|-------------|-------------|------------------|
+| 42 | 1.0788 | **1.0774** | 15996370 |
+| 423 | 1.0787 | **1.0767** | 15998081 |
+| 424 | 1.0777 | **1.0765** | 15996509 |
+| **Mean** | **1.0784** | **1.0769** | **15996987** |
+| **Std** | **0.0005** | **0.0004** | |
+
+Baseline SOTA (PR #1493): **1.0810 BPB**. Delta: **-0.0041 BPB ~= -0.0107 nats**, which meets the 0.005-nat threshold.
+
+## Key Techniques
+
+### My Contribution
+
+**MUDD Connections**
+I introduce a lightweight version of **MU**ltiway **D**ynamic **D**ense (MUDD) Connections ([**MUDDFormer**](https://arxiv.org/abs/2502.12170)) and remove the sigmoid-gated U-Net connections and the residual mixing with x0, both of which can be seen as special cases of MUDD Connections. Although MUDD Connections are more general, the full version adds more overhead than it gains in performance, so I restrict the connections along three axes:
+- **Query side**: roughly a stride-2 query dilation; layer aggregation is applied only at virtual layer indices [2, 4, 6, 8, 10, 12, 15, 16]
+- **Key side**: interleaved local/global windows along the layer dimension, with more global windows for the upper layers; the corresponding window sizes are [2, None, 2, None, 2, None, None, None] (None denotes a global window)
+- **Number of ways**: keep the V-stream and R-stream (the two most important of the Q/K/V/R streams) for layers 12 and 15, and retain only the R-stream for the other layers; the corresponding numbers of ways are [1, 1, 1, 1, 1, 2, 2, 1]
+
+In addition, MUDD Connections prefer a wide V-stream, so I switch from GQA back to MHA.
+
+### Previous Credits
+1. **SP8192 + GPTQ SDClip** — int6 matrices (k=12.85), int8 embeddings (k=20.0), zero selective pruning (PR #1394 @clarkkev)
+2. **3-Layer Depth Recurrence** (layers 3,4,5, activated at frac=0.35) — 17 virtual layers from 11 physical (PR #1331 @dexhunter, PR #1437 @dexhunter)
+3. **Parallel Residuals** (layers 7+) — GPT-J style, attention and MLP read from the same input (PR #1412 @Robby955, PR #1204 @msisovic)
+4. **QK-Gain 5.25** — learnable per-head query scaling, monotonic improvement from 4.0 to 5.25
+5. **Legal Score-First TTT** — SGD (lr=0.005, momentum=0.9), 3 epochs per 32K-token chunk, cosine LR decay. Score-before-update ordering. (PR #549 @abaybektursun, PR #1413 @dexhunter)
+6. **Tuned Hyperparameters** — WD=0.095, MLR=0.022, EMA=0.9965, warmdown=0.72 (PR #1445 @X-Abhishek-X)
+7. **LZMA code wrapper** — saves bytes on the training code
+
+## Architecture
+
+11L x 512d x 8H / 8KV, MLP 3.5x, LeakyReLU(0.5)^2, Partial RoPE (16/64 dims), layerwise LN scale, tied embeddings, logit softcap=30.0. Depth recurrence: encoder [0,1,2,3,4,5,3,4], decoder [5,3,4,5,6,7,8,9,10] (loops layers 3-5, activated at ~35% of the wallclock budget, around step 1900-2000). Parallel residuals from layer 7: attention and MLP operate on the same pre-residual input. Sigmoid-gated U-Net connections and x0 residual mixing are replaced by **MUDD Connections**.
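+
+As a quick illustration of the depth-recurrence schedule above, the sketch below mirrors the `num_loops`/`loop_start`/`loop_end` logic in `train_gpt.py`; the helper name `build_loop_schedule` is only for illustration and does not appear in the training script.
+
+```python
+# Minimal sketch: derive the 17-entry virtual-layer schedule used by the
+# 3-layer depth recurrence (layers 3-5 are traversed three times in total).
+def build_loop_schedule(num_layers=11, loop_start=3, loop_end=5, num_loops=2):
+    loop_seg = list(range(loop_start, loop_end + 1))      # [3, 4, 5]
+    indices = list(range(loop_start))                     # [0, 1, 2]
+    for _ in range(num_loops + 1):                        # the segment runs 1 + num_loops times
+        indices.extend(loop_seg)
+    indices.extend(range(loop_end + 1, num_layers))       # [6, 7, 8, 9, 10]
+    num_enc = len(indices) // 2                           # split into encoder/decoder halves
+    return indices[:num_enc], indices[num_enc:]
+
+encoder, decoder = build_loop_schedule()
+print(encoder)  # [0, 1, 2, 3, 4, 5, 3, 4]
+print(decoder)  # [5, 3, 4, 5, 6, 7, 8, 9, 10]
+```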
+
+## Training
+
+MuonEq-R optimizer (row-normalized Muon, Newton-Schulz 5 steps), AdamW for embeddings/scalars. 4367 steps in 588s on 8xH100 HBM3. Linear warmdown to LR=0 over the final 72% of training. EMA decay 0.9965.
+
+## Quantization
+
+Full-Hessian GPTQ with SDClip: `clip = k * std(row)` for a principled rate-distortion trade-off. int6 for attention/MLP matrices and part of the dynamic dense matrices, int8 for token embeddings. Byte-shuffle + Brotli-11 compression. Zero selective pruning needed -- the model fits natively under 16MB.
+
+## TTT (Test-Time Training)
+
+Score-first, chunk-based SGD adaptation at eval time (a minimal sketch of the loop is given in the appendix at the end of this README):
+- Chunk val tokens into 32K-token chunks
+- For each chunk: (1) score all sliding windows under `torch.no_grad()`, (2) train the model on the already-scored chunk tokens with SGD
+- 3 epochs per chunk, cosine LR decay across chunks
+- Gradient clipping at 1.0, distributed all-reduce for multi-GPU
+- Total TTT eval time: ~371s (within the 600s eval budget)
+
+## Compliance
+
+Per Issue #1017 (Track B -- legal eval-time adaptation):
+
+- **Condition 1 (Causality):** Sliding-window eval is strictly causal. Each position is scored from prefix tokens only.
+- **Condition 2 (Normalized distribution):** Standard softmax over the full vocab. No n-gram cache, no logit biasing.
+- **Condition 3 (Score before update):** Each chunk is fully scored under `torch.no_grad()` BEFORE any SGD update. Training happens only on already-scored tokens.
+- **Condition 4 (Single pass):** Each token is scored exactly once. No rescoring, no multi-pass selection.
+
+Additional:
+- No SLOT (standard or causal)
+- No pre-quant TTT on val data (the model is quantized once during training; TTT adapts at eval time)
+- No ETLB (eval-time logit bias)
+- No n-gram cache or tilt
+- All artifacts under 16,000,000 bytes on all 3 seeds
+- Training under 600s on all 3 seeds (~588s actual)
+- Eval (sliding + TTT) under 600s on all 3 seeds (~500s actual)
+
+## Reproduction
+
+```bash
+pip install brotli sentencepiece
+pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
+MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192
+
+NCCL_NET=Socket NCCL_DEBUG=WARN SEED=42 QK_GAIN_INIT=5.25 TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \
+RUN_ID=baseline0409_mudd_seed42 USE_MUDD=1 KEEP_UNET=0 MLP_MULT=3.5 NUM_KV_HEADS=8 WARMUP_STEPS=150 TENSORBOARD_DIR='' \
+torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+## Credits
+
+- **@bigbigbag** — SOTA Baseline 04-09 (PR #1493)
+- **@clarkkev** — SP8192 + GPTQ Embeddings + SDClip + MuonEq-R + depth recurrence (PR #1394)
+- **@dexhunter** — 3-layer depth recurrence (PR #1331, #1437), legal TTT on SP8192 (PR #1413)
+- **@abaybektursun** — Score-first TTT framework (PR #549, merged precedent)
+- **@Robby955** — Parallel residuals on SP8192 (PR #1412)
+- **@msisovic** — Parallel residuals concept (PR #1204)
+- **@X-Abhishek-X** — Hyperparameter tuning: WD=0.095, MLR=0.022, EMA=0.9965 (PR #1445, #1471)
+
+## Acknowledgements
+
+Thanks to **ColorfulClouds Tech** for providing compute. Thanks to **@Lisennlp** and **@xiaoda99** for valuable discussions on reducing the overhead of MUDD Connections.
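+
+## Appendix: Score-First TTT Sketch
+
+The block below is a minimal, illustrative sketch of the score-first TTT loop described in the TTT and Compliance sections. It is not the exact code from `train_gpt.py`: the real implementation scores with a strided sliding window, shards chunks across ranks with a distributed all-reduce, and adapts a specific parameter subset; here the scoring and parameter selection are simplified, and `model(x, y)` is assumed to return the mean training loss as in the training script.
+
+```python
+# Sketch only: score each 32K-token chunk BEFORE updating on it (Condition 3),
+# with 3 SGD epochs per chunk and cosine LR decay across chunks.
+import math
+import torch
+
+def ttt_eval(model, val_tokens, seq_len=2048, chunk_tokens=32768,
+             ttt_lr=0.005, ttt_epochs=3):
+    params = [p for p in model.parameters() if p.requires_grad]
+    opt = torch.optim.SGD(params, lr=ttt_lr, momentum=0.9)
+    starts = list(range(0, val_tokens.numel() - 1, chunk_tokens))
+    chunk_losses = []
+    for ci, start in enumerate(starts):
+        usable = min(chunk_tokens, val_tokens.numel() - 1 - start) // seq_len * seq_len
+        if usable == 0:
+            continue
+        chunk = val_tokens[start:start + usable + 1]
+        x = chunk[:-1].view(-1, seq_len)
+        y = chunk[1:].view(-1, seq_len)
+        # (1) Score the whole chunk under no_grad BEFORE any parameter update.
+        model.eval()
+        with torch.no_grad():
+            chunk_losses.append(model(x, y).item())
+        # (2) Only then adapt on the already-scored tokens (cosine LR across chunks).
+        lr = ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(len(starts) - 1, 1)))
+        for g in opt.param_groups:
+            g["lr"] = lr
+        model.train()
+        for _ in range(ttt_epochs):
+            opt.zero_grad(set_to_none=True)
+            loss = model(x, y)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(params, 1.0)
+            opt.step()
+    return sum(chunk_losses) / max(len(chunk_losses), 1)
+```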
+
+## Included Files
+
+- `README.md` (this file)
+- `submission.json`
+- `train_gpt.py`
+- `train_seed42.txt`
+- `train_seed423.txt`
+- `train_seed424.txt`
diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/submission.json b/records/track_10min_16mb/2026-04-29_MUDD_Connections/submission.json
new file mode 100644
index 0000000000..462f62550c
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/submission.json
@@ -0,0 +1,37 @@
+{
+  "author": "Qingye Meng",
+  "github_id": "hilbertmeng",
+  "name": "MUDD Connections",
+  "date": "2026-04-29",
+  "track": "10min_16mb",
+  "val_bpb": 1.0769,
+  "val_bpb_std": 0.00042,
+  "seeds": [42, 423, 424],
+  "seed_results": {
+    "42": {"val_bpb": 1.0788, "artifact_bytes": 15996370},
+    "423": {"val_bpb": 1.0787, "artifact_bytes": 15998081},
+    "424": {"val_bpb": 1.0777, "artifact_bytes": 15996509}
+  },
+  "hardware": "NVIDIA H100 80GB HBM3",
+  "pytorch_version": "2.9.1+cu128",
+  "technique_summary": "MUDD Connections + SP8192 + 3-Layer Depth Recurrence (L3-5) + Parallel Residuals (L7+) + QK-Gain 5.25 + EMA 0.9965 + WD 0.095 + Score-First TTT (SGD 3ep) + GPTQ SDClip + Brotli",
+  "compliance": {
+    "train_under_600s": true,
+    "artifact_under_16mb": true,
+    "eval_under_600s": true,
+    "no_slot": true,
+    "no_pre_quant_ttt": true,
+    "no_etlb": true,
+    "no_ngram_cache": true,
+    "score_first_ttt": true,
+    "three_seeds": true
+  },
+  "attribution": {
+    "baseline": "@bigbag (PR #1493)",
+    "sp8192_gptq_sdclip": "@clarkkev (PR #1394)",
+    "depth_recurrence": "@dexhunter (PR #1331, #1437)",
+    "parallel_residuals": "@Robby955 (PR #1412), @msisovic (PR #1204)",
+    "legal_ttt_framework": "@abaybektursun (PR #549), @dexhunter (PR #1413)",
+    "hyperparameter_tuning": "@X-Abhishek-X (PR #1445)"
+  }
+}
diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_gpt.py b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_gpt.py
new file mode 100644
index 0000000000..d3d25bcf7d
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_gpt.py
@@ -0,0 +1,658 @@
+import collections,copy,glob,io,lzma,math,os
+from pathlib import Path
+import random,re,subprocess,sys,time,uuid,numpy as np,sentencepiece as spm,torch,torch.distributed as dist,torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch import Tensor,nn
+from torch.utils.tensorboard import SummaryWriter
+from flash_attn_interface import flash_attn_func as flash_attn_3_func
+from einops import rearrange
+from typing import Optional, Tuple
+
+import lzma, base64
+
+class 
Hyperparameters:data_dir=os.environ.get('DATA_DIR','./data/');seed=int(os.environ.get('SEED',1337));run_id=os.environ.get('RUN_ID',str(uuid.uuid4()));iterations=int(os.environ.get('ITERATIONS',20000));warmdown_frac=float(os.environ.get('WARMDOWN_FRAC',.72));warmup_steps=int(os.environ.get('WARMUP_STEPS',20));train_batch_tokens=int(os.environ.get('TRAIN_BATCH_TOKENS',786432));train_seq_len=int(os.environ.get('TRAIN_SEQ_LEN',2048));train_log_every=int(os.environ.get('TRAIN_LOG_EVERY',500));max_wallclock_seconds=float(os.environ.get('MAX_WALLCLOCK_SECONDS',6e2));val_batch_tokens=int(os.environ.get('VAL_BATCH_TOKENS',524288));eval_seq_len=int(os.environ.get('EVAL_SEQ_LEN',2048));val_loss_every=int(os.environ.get('VAL_LOSS_EVERY',4000));sliding_window_enabled=bool(int(os.environ.get('SLIDING_WINDOW_ENABLED','1')));vocab_size=int(os.environ.get('VOCAB_SIZE',8192));num_layers=int(os.environ.get('NUM_LAYERS',11));xsa_last_n=int(os.environ.get('XSA_LAST_N',11));model_dim=int(os.environ.get('MODEL_DIM',512));embedding_dim=int(os.environ.get('EMBEDDING_DIM',512));num_kv_heads=int(os.environ.get('NUM_KV_HEADS',4));num_heads=int(os.environ.get('NUM_HEADS',8));mlp_mult=float(os.environ.get('MLP_MULT',4.));skip_gates_enabled=bool(int(os.environ.get('SKIP_GATES_ENABLED','1')));tie_embeddings=bool(int(os.environ.get('TIE_EMBEDDINGS','1')));logit_softcap=float(os.environ.get('LOGIT_SOFTCAP',3e1));rope_base=float(os.environ.get('ROPE_BASE',1e4));rope_dims=int(os.environ.get('ROPE_DIMS',16));rope_train_seq_len=int(os.environ.get('ROPE_TRAIN_SEQ_LEN',2048));ln_scale=bool(int(os.environ.get('LN_SCALE','1')));qk_gain_init=float(os.environ.get('QK_GAIN_INIT',5.));num_loops=int(os.environ.get('NUM_LOOPS',2));loop_start=int(os.environ.get('LOOP_START',3));loop_end=int(os.environ.get('LOOP_END',5));enable_looping_at=float(os.environ.get('ENABLE_LOOPING_AT',.35));parallel_residual_start=int(os.environ.get('PARALLEL_RESIDUAL_START',7));min_lr=float(os.environ.get('MIN_LR',.0));embed_lr=float(os.environ.get('EMBED_LR',.6));head_lr=float(os.environ.get('HEAD_LR',.008));tied_embed_lr=float(os.environ.get('TIED_EMBED_LR',.03));tied_embed_init_std=float(os.environ.get('TIED_EMBED_INIT_STD',.005));matrix_lr=float(os.environ.get('MATRIX_LR',.022));scalar_lr=float(os.environ.get('SCALAR_LR',.02));muon_momentum=float(os.environ.get('MUON_MOMENTUM',.99));muon_backend_steps=int(os.environ.get('MUON_BACKEND_STEPS',5));muon_momentum_warmup_start=float(os.environ.get('MUON_MOMENTUM_WARMUP_START',.92));muon_momentum_warmup_steps=int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS',1500));muon_row_normalize=bool(int(os.environ.get('MUON_ROW_NORMALIZE','1')));beta1=float(os.environ.get('BETA1',.9));beta2=float(os.environ.get('BETA2',.95));adam_eps=float(os.environ.get('ADAM_EPS',1e-08));grad_clip_norm=float(os.environ.get('GRAD_CLIP_NORM',.3));eval_stride=int(os.environ.get('EVAL_STRIDE',64));muon_beta2=float(os.environ.get('MUON_BETA2',.95));adam_wd=float(os.environ.get('ADAM_WD',.02));muon_wd=float(os.environ.get('MUON_WD',.095));embed_wd=float(os.environ.get('EMBED_WD',.085));ema_decay=float(os.environ.get('EMA_DECAY',.9965));ttt_enabled=bool(int(os.environ.get('TTT_ENABLED','0')));ttt_lr=float(os.environ.get('TTT_LR',.005));ttt_epochs=int(os.environ.get('TTT_EPOCHS',3));ttt_momentum=float(os.environ.get('TTT_MOMENTUM',.9));ttt_chunk_tokens=int(os.environ.get('TTT_CHUNK_TOKENS',32768));etlb_enabled=bool(int(os.environ.get('ETLB_ENABLED','0')));etlb_lr=float(os.environ.get('ETLB_LR',.05));etlb_steps=int(os.environ.get('ETLB_STEPS',5));etl
b_clip=float(os.environ.get('ETLB_CLIP',3.));compressor=os.environ.get('COMPRESSOR','brotli');gptq_calibration_batches=int(os.environ.get('GPTQ_CALIBRATION_BATCHES',64));gptq_reserve_seconds=float(os.environ.get('GPTQ_RESERVE_SECONDS',12.));matrix_bits=int(os.environ.get('MATRIX_BITS',6));embed_bits=int(os.environ.get('EMBED_BITS',8));matrix_clip_sigmas=float(os.environ.get('MATRIX_CLIP_SIGMAS',12.85));embed_clip_sigmas=float(os.environ.get('EMBED_CLIP_SIGMAS',2e1));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ;rank=int(os.environ.get('RANK','0'));world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));is_main_process=rank==0;grad_accum_steps=8//world_size;datasets_dir=os.path.join(data_dir,'datasets',f"fineweb10B_sp{vocab_size}");train_files=os.path.join(datasets_dir,'fineweb_train_*.bin');val_files=os.path.join(datasets_dir,'fineweb_val_*.bin');tokenizer_path=os.path.join(data_dir,'tokenizers',f"fineweb_{vocab_size}_bpe.model");logfile=f"logs/{run_id}.txt";model_path='final_model.pt';quantized_model_path='final_model.int6.ptz';tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "./logs/tensorboard");use_mudd=bool(int(os.environ.get('USE_MUDD','0')));mudd_q_dilation=int(os.environ.get('MUDD_Q_DILATION','1'));mudd_k_dilation=int(os.environ.get('MUDD_K_DILATION','1'));keep_unet=bool(int(os.environ.get('KEEP_UNET','1')));mudd_emb=bool(int(os.environ.get('MUDD_EMB','0'))); +_logger_hparams=None +def set_logging_hparams(h):global _logger_hparams;_logger_hparams=h +def log(msg,console=True): + if _logger_hparams is None:print(msg);return + if _logger_hparams.is_main_process: + if console:print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile,'a',encoding='utf-8')as f:print(msg,file=f) +class ValidationData: + def __init__(self,h,device): + self.sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size())!=h.vocab_size:raise ValueError(f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}") + self.val_tokens=load_validation_tokens(h.val_files,h.eval_seq_len);self.base_bytes_lut,self.has_leading_space_lut,self.is_boundary_token_lut=build_sentencepiece_luts(self.sp,h.vocab_size,device) +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());assert sp.piece_to_id('▁')!=sp.unk_id(),"Tokenizer must have '▁' (space) as its own token for correct BPB byte counting";table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=False + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if piece.startswith('▁'):has_leading_space_np[token_id]=True;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode('utf-8')) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in 
files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def load_data_shard(file): + header_bytes=256*np.dtype('0 else 0;num_sequences=(self.num_tokens[si]-1-phase)//self.seq_len;sequence_order=self.rng.permutation(num_sequences);self.start_inds[si]=(phase+sequence_order*self.seq_len).tolist() + def next_batch(self,global_tokens,grad_accum_steps): + device_tokens=global_tokens//(self.world_size*grad_accum_steps);device_batch_size=device_tokens//self.seq_len;remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);x=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64);y=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64) + for bi in range(device_batch_size): + total=remaining.sum() + if total<=0: + for si in range(len(self.files)):self._reset_shard(si) + remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);total=remaining.sum() + probs=remaining/total;si=int(self.rng.choice(len(self.files),p=probs));start_ind=self.start_inds[si].pop();remaining[si]-=1;mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[start_ind:start_ind+self.seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + return x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self,eps=None):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x):w=self.weight.to(x.dtype);bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=1./base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=False);self._seq_len_cached=0;self._cos_cached=None;self._sin_cached=None + def forward(self,seq_len,device,dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=1./new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[None,:,None,:];self._sin_cached=freqs.sin()[None,:,None,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims C B T L', C=self.C) + return dw + def layer_mix(self, x, all_hids, dw, hidden_masks=None): + L = dw.shape[3] + if L == 1: # aggregate only embedding + hids = all_hids[:1] + elif L > 2: + hids = all_hids[:1] + all_hids[-(L-2):] + all_hids[-1:] + else: # aggregate embedding and current layer + hids = all_hids[:1] + all_hids[-1:] + scale = self.scale.to(dtype=hids[0].dtype).view(self.C, 1, 1, -1) + weighted = dw[:, :, :, 0, None] * hids[0] + for j in range(1, L):weighted = weighted + dw[:, :, :, j, None] * hids[j] + result = F.rms_norm(weighted, (weighted.size(-1),)) * scale + return tuple(result[c] for c in range(self.C)) if 
self.C > 1 else result[0] +class CausalSelfAttention(nn.Module): + def __init__(self,dim,num_heads,num_kv_heads,rope_base,qk_gain_init,train_seq_len): + super().__init__() + if dim%num_heads!=0:raise ValueError('model_dim must be divisible by num_heads') + if num_heads%num_kv_heads!=0:raise ValueError('num_heads must be divisible by num_kv_heads') + self.num_heads=num_heads;self.num_kv_heads=num_kv_heads;self.head_dim=dim//num_heads + if self.head_dim%2!=0:raise ValueError('head_dim must be even for RoPE') + kv_dim=self.num_kv_heads*self.head_dim;self.c_q=CastedLinear(dim,dim,bias=False);self.c_k=CastedLinear(dim,kv_dim,bias=False);self.c_v=CastedLinear(dim,kv_dim,bias=False);self.proj=CastedLinear(dim,dim,bias=False);self.proj._zero_init=True;self.q_gain=nn.Parameter(torch.full((num_heads,),qk_gain_init,dtype=torch.float32));self.rope_dims=0;self.rotary=Rotary(self.head_dim,base=rope_base,train_seq_len=train_seq_len);self.use_xsa=False + def _xsa_efficient(self,y,v):B,T,H,D=y.shape;Hkv=v.size(-2);group=H//Hkv;y_g=y.reshape(B,T,Hkv,group,D);vn=F.normalize(v,dim=-1).unsqueeze(-2);proj=(y_g*vn).sum(dim=-1,keepdim=True)*vn;return(y_g-proj).reshape(B,T,H,D) + def forward(self,x, qkvway=None): + bsz, seqlen, dim = x.shape + xq, xk, xv = x, x, x + if qkvway is not None: + if isinstance(qkvway, tuple) and len(qkvway) == 3: + q=self.c_q(xq+qkvway[0]).reshape(bsz,seqlen,self.num_heads,self.head_dim) + k=self.c_k(xk+qkvway[1]).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + v= self.c_v(xv + qkvway[2]).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + elif isinstance(qkvway, tuple) and len(qkvway) == 2: + q=self.c_q(xq+qkvway[0]).reshape(bsz,seqlen,self.num_heads,self.head_dim) + k=self.c_k(xk+qkvway[0]).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + v= self.c_v(xv + qkvway[1]).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + elif isinstance(qkvway, tuple) and len(qkvway) == 1: + q=self.c_q(xq).reshape(bsz,seqlen,self.num_heads,self.head_dim) + k=self.c_k(xk).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + v= self.c_v(xv + qkvway[0]).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + else: + q=self.c_q(xq).reshape(bsz,seqlen,self.num_heads,self.head_dim) + k=self.c_k(xk).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + v= self.c_v(xv).reshape(bsz,seqlen,self.num_kv_heads,self.head_dim) + q=F.rms_norm(q,(q.size(-1),));k=F.rms_norm(k,(k.size(-1),));cos,sin=self.rotary(seqlen,xq.device,q.dtype);q=apply_rotary_emb(q,cos,sin,self.rope_dims);k=apply_rotary_emb(k,cos,sin,self.rope_dims);q=q*self.q_gain.to(dtype=q.dtype)[None,None,:,None] + y=flash_attn_3_func(q,k,v,causal=True) + if self.use_xsa:y=self._xsa_efficient(y,v) + y=y.reshape(bsz,seqlen,dim);return self.proj(y) + +class MLP(nn.Module): + def __init__(self,dim,mlp_mult, lidx=None, use_mudd=False): + super().__init__() + hidden= int(mlp_mult*dim) + if lidx == 2: # reduce mlp hidden dim to keep the total params under the limit of 16M + hidden -= 64 + self.fc=CastedLinear(dim,hidden,bias=False);self.proj=CastedLinear(hidden,dim,bias=False);self.proj._zero_init=True + def forward(self,x):return self.proj(F.leaky_relu(self.fc(x),negative_slope=.5).square()) +class Block(nn.Module): + def __init__(self,dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,train_seq_len,layer_idx=0,ln_scale=False,use_mudd=False,keep_unet=True): + super().__init__();self.attn_norm=RMSNorm();self.mlp_norm=RMSNorm() + self.attn=CausalSelfAttention(dim,num_heads,num_kv_heads,rope_base,qk_gain_init,train_seq_len) + 
self.attn_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32)) + self.mlp=MLP(dim,mlp_mult, lidx=layer_idx, use_mudd=use_mudd) + self.mlp_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32));self.resid_mix=nn.Parameter(torch.stack((torch.ones(dim),torch.zeros(dim))).float()) if keep_unet else None; self.keep_unet=keep_unet + self.ln_scale_factor=1./math.sqrt(layer_idx+1)if ln_scale else 1.;self.parallel=False;self.use_mudd=use_mudd + def forward(self,x, x0, vm=None, is_recurrent=False, looping_active=False): + if is_recurrent and not looping_active: + return x[-1] if isinstance(x, tuple) else x + if self.resid_mix is None: + x_in = x + else: + mix=self.resid_mix.to(dtype=x.dtype) + x_in=mix[0][None,None,:]*x+mix[1][None,None,:]*x0 + qkvway = None + if vm is not None: + if isinstance(vm, tuple): # 2/3-way + qkvway = vm[:-1];normed_x = self.attn_norm(x_in)*self.ln_scale_factor + x_in = x_in + vm[-1] + else: # 1-way + x_in = x_in + vm + normed_x = self.attn_norm(x_in)*self.ln_scale_factor + else: + normed_x = self.attn_norm(x_in)*self.ln_scale_factor + attn_out=self.attn(normed_x, qkvway=qkvway) * self.attn_scale.to(dtype=x_in.dtype)[None,None,:] + if self.parallel: + mlp_out=self.mlp(self.mlp_norm(x_in)*self.ln_scale_factor) + x_out=x_in+attn_out+self.mlp_scale.to(dtype=x_in.dtype)[None,None,:]*mlp_out + else: + x_out=x_in+attn_out;x_out=x_out+self.mlp_scale.to(dtype=x_out.dtype)[None,None,:]*self.mlp(self.mlp_norm(x_out)*self.ln_scale_factor) + return x_out +class GPT(nn.Module): + def __init__(self,h): + super().__init__() + if h.logit_softcap<=.0:raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings=h.tie_embeddings;self.tied_embed_init_std=h.tied_embed_init_std;self.logit_softcap=h.logit_softcap;self.tok_emb=nn.Embedding(h.vocab_size,h.embedding_dim) + self.embedding_dim=h.embedding_dim + if h.embedding_dim!=h.model_dim:self.embed_proj=CastedLinear(h.embedding_dim,h.model_dim,bias=False);self.head_proj=CastedLinear(h.model_dim,h.embedding_dim,bias=False) + else:self.embed_proj=None;self.head_proj=None + self.use_mudd=h.use_mudd;self.num_layers=h.num_layers + self.num_encoder_layers=h.num_layers//2;self.num_decoder_layers=h.num_layers-self.num_encoder_layers + self.blocks=nn.ModuleList([Block(h.model_dim,h.num_heads,h.num_kv_heads,h.mlp_mult,h.rope_base,h.qk_gain_init,h.train_seq_len,layer_idx=i,ln_scale=h.ln_scale,use_mudd=h.use_mudd, keep_unet=h.keep_unet)for i in range(h.num_layers)]) + if h.rope_dims>0: + head_dim=h.model_dim//h.num_heads + for block in self.blocks: + if block.attn is not None: + block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.final_norm=RMSNorm();self.lm_head=None if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=False) + if self.lm_head is not None:self.lm_head._zero_init=True + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers): + if self.blocks[i].attn is not None: + self.blocks[i].attn.use_xsa=True + if h.parallel_residual_start>=0: + for i in range(h.parallel_residual_start,h.num_layers):self.blocks[i].parallel=True + self.looping_active=False + self.is_recur_indices = [False] * (h.num_layers + h.num_loops* (h.loop_end- h.loop_start+1)) + if h.num_loops>0: + loop_seg=list(range(h.loop_start,h.loop_end+1));all_indices=list(range(h.loop_start)) + for _ in range(h.num_loops+1):all_indices.extend(loop_seg) + 
all_indices.extend(range(h.loop_end+1,h.num_layers));num_enc=len(all_indices)//2;self.encoder_indices=all_indices[:num_enc];self.decoder_indices=all_indices[num_enc:] + if h.use_mudd: + is_recur_indices = [] + for i,j in enumerate(all_indices): + if i==0: + is_recur_indices.append(False) + else: + is_recur_indices.append(j in all_indices[:i-1]) + self.is_recur_indices = is_recur_indices + else:self.encoder_indices=list(range(self.num_encoder_layers));self.decoder_indices=list(range(self.num_encoder_layers,h.num_layers)) + self.num_skip_weights=min(len(self.encoder_indices),len(self.decoder_indices)) + use_mudd = h.use_mudd + self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,h.model_dim,dtype=torch.float32)) if h.keep_unet else None + self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32)) if h.skip_gates_enabled and h.keep_unet else None + if h.use_mudd: + if h.mudd_emb: + self.num_mudd_embs = 1 + self.mudd_emb = nn.Embedding(h.vocab_size,h.embedding_dim * self.num_mudd_embs) + else: + self.mudd_emb = None + looped_num_layers = len(all_indices) + self.mudd_q_dilation=h.mudd_q_dilation + self.mudd_k_dilation=h.mudd_k_dilation + num_base_layers=1 if self.mudd_emb is None else 2 + self.num_ways = [1]*12 + [2] * 5 + self.mudd_q_indices = [2,4,6,8,10,12,15,16] + local_window_sizes= [None, None, 2,None]*5 + self.dynamic_dense=nn.ModuleList([MultiwayDynamicDenseBlock(h.model_dim,i,last_layer=i==looped_num_layers-1,multiway=True,k_dilation=self.mudd_k_dilation,base_layer=num_base_layers,num_ways=self.num_ways[i],local_window_size=local_window_sizes[i]) if i in self.mudd_q_indices else None for i in range(looped_num_layers)]) + else:self.dynamic_dense=nn.ModuleList() + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=.0,std=self.tied_embed_init_std) + if self.mudd_emb:nn.init.normal_(self.mudd_emb.weight,mean=.0,std=self.tied_embed_init_std) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',False):nn.init.zeros_(module.weight) + elif'dynamic_dense'in name:nn.init.normal_(module.weight,mean=0.0,std=0.006) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=1.) 
+ def forward_logits(self,input_ids): + x=self.tok_emb(input_ids);x=F.rms_norm(x,(x.size(-1),)) + if self.embed_proj is not None:x=self.embed_proj(x) + x0=x;hiddens=[];skips=[];enc_iter=self.encoder_indices if self.looping_active or self.use_mudd else range(self.num_encoder_layers);dec_iter=self.decoder_indices if self.looping_active or self.use_mudd else range(self.num_encoder_layers,self.num_encoder_layers+self.num_decoder_layers) + if self.mudd_emb is not None: + mudd_embs = self.mudd_emb(input_ids) + for i in range(self.num_mudd_embs): + hiddens.append(F.rms_norm(mudd_embs[:,:,i*self.embedding_dim:(i+1)*self.embedding_dim],(self.embedding_dim,))) + if self.use_mudd:hiddens.append(x) + mudd_idx=0;looped_num_layers=len(enc_iter)+len(dec_iter) if self.use_mudd else 0 + vm = None # value and mlp way + for _idx, i in enumerate(enc_iter): + x=self.blocks[i](x,x0, vm=vm, is_recurrent=self.is_recur_indices[_idx], looping_active=self.looping_active) + if self.skip_weights is not None: + skips.append(x) + if self.use_mudd: + if mudd_idx%self.mudd_k_dilation==0:hiddens.append(x) + if mudd_idx in self.mudd_q_indices: + dw=self.dynamic_dense[mudd_idx](x);mixed=self.dynamic_dense[mudd_idx].layer_mix(x,hiddens,dw) + vm=mixed + else: + vm=None + mudd_idx+=1 + for(skip_idx,i)in enumerate(dec_iter): + if self.skip_weights is not None and skip_idxG.size(1) + if transposed:X=X.T + for _ in range(steps):A=X@X.T;B=b*A+c*A@A;X=a*X+B@X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=True,weight_decay=.0,row_normalize=False):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay,row_normalize=row_normalize)) + @torch.no_grad() + def step(self,closure=None): + loss=None + if closure is not None: + with torch.enable_grad():loss=closure() + distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0 + for group in self.param_groups: + params=group['params'] + if not params:continue + lr=group['lr'];momentum=group['momentum'];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if i%world_size==rank and p.grad is not None: + g=p.grad;state=self.state[p] + if'momentum_buffer'not in state:state['momentum_buffer']=torch.zeros_like(g) + buf=state['momentum_buffer'];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + if group.get('row_normalize',False):row_norms=g.float().norm(dim=-1,keepdim=True).clamp_min(1e-07);g=g/row_norms.to(g.dtype) + g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',.0);curr=0 + for p in params: + if wd>.0:p.data.mul_(1.-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates').split(',')if pattern) +class Optimizers: + def __init__(self,h,base_model): + 
block_named_params=list(base_model.blocks.named_parameters());matrix_params=[p for(name,p)in block_named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights is not None and base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + mudd_scalar_params=[];mudd_matrix_params=[] + if base_model.use_mudd: + for block in base_model.dynamic_dense: + if block is not None:mudd_scalar_params.extend([p for p in block.parameters() if p.ndim<2]);mudd_matrix_params.extend([p for p in block.parameters() if p.ndim==2]) + scalar_params.extend(mudd_scalar_params+mudd_matrix_params) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{'params':[base_model.tok_emb.weight]+([base_model.mudd_emb.weight] if base_model.mudd_emb is not None else []),'lr':token_lr,'base_lr':token_lr}];self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=True);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd,row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups:group['base_lr']=h.matrix_lr + self.optimizer_scalar=torch.optim.AdamW([{'params':scalar_params,'lr':h.scalar_lr,'base_lr':h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=True);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not None:self.optimizer_head=torch.optim.Adam([{'params':[base_model.lm_head.weight],'lr':h.head_lr,'base_lr':h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=True);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=None + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=True) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module,CastedLinear):module.float() + for(name,param)in model.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +def collect_hessians(model,train_loader,h,device,n_calibration_batches=64): + hessians={};hooks=[] + def make_hook(name): + def hook_fn(module,inp,out): + x=inp[0].detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + for(name,module)in model.named_modules(): + if isinstance(module, CastedLinear) and ( + module.weight.numel() > 65536//16 or + (hasattr(module, '_parent_name') and 'mudd' in module._parent_name and '.w1' in module._parent_name) + ): + cat=classify_param(name+'.weight') + if cat in('mlp','attn'):hooks.append(module.register_forward_hook(make_hook(name+'.weight'))) + if model.tie_embeddings: + hook_module=model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module,inp,out): + x=out.detach().float() + if 
x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x) + for hook in hooks:hook.remove() + for name in hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def gptq_quantize_weight(w,H,clip_sigmas=3.,clip_range=63,block_size=128): + W_orig=w.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=.01*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=True);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm];Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=True);row_std=W_orig.std(dim=1);s=(clip_sigmas*row_std/clip_range).clamp_min(1e-10).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=h.ttt_lr*.5*(1.+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg['lr']=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0,my_chunk_seqs,batch_seqs): + be=min(bs+batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_data.val_tokens.numel():continue + local=val_data.val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not None:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,1.);optimizer.step() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + for p in base_model.parameters():p.requires_grad_(True) + base_model.eval();return _loss_bpb(loss_sum,token_count,byte_count) +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} 
eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def train_model(h,device,val_data): + base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model) + compiled_model=torch.compile(base_model,dynamic=False,fullgraph=True) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=False,find_unused_parameters=False) + else:model=compiled_model + if h.is_main_process: + print('model parameters:') + for name,param in model.named_parameters(): + print(name, param.shape, param.dtype, param.mean().item(), param.std().item()) + if h.use_mudd: + for block in base_model.dynamic_dense: + if block is not None: + print('mudd lidx', block.lidx, 'num_ways', block.C, 'local_window_size', block.local_window_size, 'scale', block.mudd_scale, 'k_components', block.w2.weight.shape[0]//block.C) + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}");optimizers=Optimizers(h,base_model);train_loader=ShuffledSequenceLoader(h,device);tb_writer=None + if h.tensorboard_dir and h.is_main_process: + print('use tensorboard!!!') + tensorboard_dir=os.path.join(h.tensorboard_dir,h.run_id);os.makedirs(tensorboard_dir,exist_ok=True);tb_writer=SummaryWriter(log_dir=tensorboard_dir) + max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else None + if max_wallclock_ms is not None:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is None:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + def lr_mul(frac): + if h.warmdown_frac<=0:return 1. + if frac>=1.-h.warmdown_frac:return max((1.-frac)/h.warmdown_frac,h.min_lr) + return 1. + def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=micro_step==h.grad_accum_steps-1 + x,y=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):loss=model(x,y) + train_loss+=loss.detach();(loss/h.grad_accum_steps).backward() + train_loss/=h.grad_accum_steps;frac=min(step/h.muon_momentum_warmup_steps,1.)if h.muon_momentum_warmup_steps>0 else 1.;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group['momentum']=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group['lr']=group['base_lr']*lr_scale + if h.grad_clip_norm>0:raw_grad_norm=torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + cur_lr=float(optimizers.optimizer_muon.param_groups[0]['lr']) + optimizers.step() + return train_loss,raw_grad_norm,cur_lr + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + _,_,_=step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops>0: #and not h.use_mudd: + base_model.looping_active=True;log(f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + _,_,_=step_fn(warmup_step,1.) 
+ if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active=False + base_model.load_state_dict(initial_model_state,strict=True) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=True):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=True + train_loader=ShuffledSequenceLoader(h,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;training_time_ms=.0;stop_after_step=None;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while True: + last_step=step==h.iterations or stop_after_step is not None and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 and step>=0 + if should_validate: + torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + if tb_writer is not None:tb_writer.add_scalar('val/loss',val_loss,step);tb_writer.add_scalar('val/bpb',val_bpb,step) + torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not None and step0 and not base_model.looping_active and frac>=h.enable_looping_at:base_model.looping_active=True;log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + train_loss,raw_grad_norm,cur_lr=step_fn(step,scale) + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=1.-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not None) + if should_log_train:tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m, step_avg: {approx_training_time_ms/step:.2f}ms, raw_grad_norm: {raw_grad_norm:.4f}, tok/s: {tok_per_sec:.0f}")#;log(f'scale: {[m.scale.mean().item() for m in base_model.dynamic_dense if m is not None]}') + if tb_writer is not None and should_log_train:tb_writer.add_scalar('train/loss',train_loss.item(),step);tb_writer.add_scalar('train/raw_grad_norm',float(raw_grad_norm),step);tb_writer.add_scalar('train/learning_rate',cur_lr,step); tb_writer.add_scalar('perf/step_avg_ms',approx_training_time_ms/step,step) + reached_cap=max_wallclock_ms is not None and approx_training_time_ms>=max_wallclock_ms + if h.distributed and max_wallclock_ms is not None:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap:stop_after_step=step + if tb_writer is not None:tb_writer.flush();tb_writer.close() + log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=True);return base_model,compiled_model +def train_and_eval(h,device): + 
random.seed(h.seed);np.random.seed(h.seed);torch.manual_seed(h.seed);torch.cuda.manual_seed_all(h.seed);val_data=ValidationData(h,device);log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob("fineweb_train_*.bin")))}");log(f"val_tokens: {val_data.val_tokens.numel()-1}");base_model,compiled_model=train_model(h,device,val_data);torch._dynamo.reset();timed_eval('pre-quantization post-ema',eval_val,h,device,val_data,compiled_model);serialize(h,base_model,Path(__file__).read_text(encoding='utf-8')) + if h.distributed:dist.barrier() + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + compiled_model=torch.compile(eval_model,dynamic=False,fullgraph=True);timed_eval('quantized',eval_val,h,device,val_data,compiled_model) + if h.sliding_window_enabled:timed_eval('quantized_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.ttt_enabled and h.sliding_window_enabled: + del eval_model,compiled_model;torch._dynamo.reset();torch.cuda.empty_cache();ttt_model=deserialize(h,device) + if h.num_loops>0:ttt_model.looping_active=True + timed_eval('quantized_ttt',eval_val_ttt,h,device,val_data,ttt_model);del ttt_model + if h.etlb_enabled and h.sliding_window_enabled: + if'eval_model'not in dir(): + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + timed_eval('quantized_sliding_etlb',eval_val_sliding_etlb,h,device,val_data,eval_model) +def main(): + world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + if world_size<=0:raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8%world_size!=0:raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + device=torch.device('cuda',local_rank);torch.cuda.set_device(device) + if distributed:dist.init_process_group(backend='nccl',device_id=device);dist.barrier() + torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(False);enable_math_sdp(False);torch._dynamo.config.optimize_ddp=False;h=Hyperparameters();set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs',exist_ok=True);log(100*'=',console=False);log('Hyperparameters:',console=True) + for(k,v)in sorted(vars(type(h)).items()): + if not k.startswith('_'):log(f" {k}: {v}",console=True) + log('='*100,console=False);log(f"Running Python {sys.version}",console=False);log(f"Running PyTorch {torch.__version__}",console=False);log(subprocess.run(['nvidia-smi'],stdout=subprocess.PIPE,stderr=subprocess.PIPE,text=True,check=False).stdout,console=False);log('='*100,console=False) + train_and_eval(h,device) + if distributed:dist.destroy_process_group() +if __name__=='__main__':main() diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed42.txt b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed42.txt new file mode 100644 index 0000000000..dd358e02e6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed42.txt @@ -0,0 +1,236 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + 
beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + keep_unet: False + ln_scale: True + local_rank: 0 + logfile: logs/baseline0409_mudd_seed42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 3.5 + model_dim: 512 + model_path: final_model.pt + mudd_emb: False + mudd_k_dilation: 1 + mudd_q_dilation: 1 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 8 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: baseline0409_mudd_seed42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tensorboard_dir: + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_mudd: True + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 150 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] +Running PyTorch 2.9.1+cu128 +Wed Apr 29 09:26:07 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 186629 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 1 N/A N/A 186630 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 2 N/A N/A 186631 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 3 N/A N/A 186632 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 4 N/A N/A 186633 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 5 N/A N/A 186634 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 6 N/A N/A 186635 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 7 N/A N/A 186636 C ...ing/miniconda3/bin/python3.13 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 40540160 +model_params:35919402 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/150 +warmup_step: 2/150 +warmup_step: 3/150 +warmup_step: 4/150 +warmup_step: 5/150 +warmup_step: 6/150 +warmup_step: 10/150 +warmup_step: 20/150 +warmup_step: 30/150 +warmup_step: 40/150 +warmup_step: 50/150 +warmup_step: 60/150 +warmup_step: 70/150 +warmup_step: 80/150 +warmup_step: 90/150 +warmup_step: 100/150 +warmup_step: 110/150 +warmup_step: 120/150 +warmup_step: 130/150 +warmup_step: 140/150 +warmup_step: 150/150 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 
5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/150 +loop_warmup_step: 2/150 +loop_warmup_step: 3/150 +loop_warmup_step: 4/150 +loop_warmup_step: 5/150 +loop_warmup_step: 6/150 +loop_warmup_step: 10/150 +loop_warmup_step: 20/150 +loop_warmup_step: 30/150 +loop_warmup_step: 40/150 +loop_warmup_step: 50/150 +loop_warmup_step: 60/150 +loop_warmup_step: 70/150 +loop_warmup_step: 80/150 +loop_warmup_step: 90/150 +loop_warmup_step: 100/150 +loop_warmup_step: 110/150 +loop_warmup_step: 120/150 +loop_warmup_step: 130/150 +loop_warmup_step: 140/150 +loop_warmup_step: 150/150 +0/20000 val_loss: 9.0082 val_bpb: 3.4873 +1/20000 train_loss: 9.0096 train_time: 0.0m, step_avg: 100.15ms, raw_grad_norm: 0.3832, tok/s: 7852685 +2/20000 train_loss: 12.6101 train_time: 0.0m, step_avg: 101.12ms, raw_grad_norm: 2.7562, tok/s: 7777098 +3/20000 train_loss: 12.1070 train_time: 0.0m, step_avg: 102.18ms, raw_grad_norm: 3.0300, tok/s: 7696504 +4/20000 train_loss: 10.5908 train_time: 0.0m, step_avg: 102.82ms, raw_grad_norm: 3.5750, tok/s: 7649000 +5/20000 train_loss: 9.1196 train_time: 0.0m, step_avg: 103.20ms, raw_grad_norm: 3.3087, tok/s: 7620276 +500/20000 train_loss: 3.3711 train_time: 0.9m, step_avg: 106.30ms, raw_grad_norm: 0.2764, tok/s: 7398353 +1000/20000 train_loss: 3.2739 train_time: 1.8m, step_avg: 106.29ms, raw_grad_norm: 0.2078, tok/s: 7398713 +1500/20000 train_loss: 3.1742 train_time: 2.7m, step_avg: 106.28ms, raw_grad_norm: 0.1309, tok/s: 7399836 +layer_loop:enabled step:1937 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0623 train_time: 3.6m, step_avg: 107.89ms, raw_grad_norm: 0.1213, tok/s: 7289192 +2500/20000 train_loss: 3.0975 train_time: 4.9m, step_avg: 117.78ms, raw_grad_norm: 0.1019, tok/s: 6677403 +3000/20000 train_loss: 2.8771 train_time: 6.2m, step_avg: 124.37ms, raw_grad_norm: 0.0949, tok/s: 6323417 +3500/20000 train_loss: 2.9162 train_time: 7.5m, step_avg: 129.07ms, raw_grad_norm: 0.0902, tok/s: 6093110 +4000/20000 train_loss: 2.7856 train_time: 8.8m, step_avg: 132.59ms, raw_grad_norm: 0.0736, tok/s: 5931096 +4000/20000 val_loss: 2.8440 val_bpb: 1.1010 +4367/20000 val_loss: 2.8057 val_bpb: 1.0862 +stopping_early: wallclock_cap train_time: 588005ms step: 4367/20000 +peak memory allocated: 42260 MiB reserved: 43748 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80271282 val_bpb:1.08501747 eval_time:10000ms +Serialized model: 135338079 bytes +Code size: 19388 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 71 Hessians in 14.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight, dynamic_dense.12.w1.weight, dynamic_dense.15.w1.weight, dynamic_dense.16.w1.weight, dynamic_dense.8.w1.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, dynamic_dense.10.scale, dynamic_dense.10.w1.weight, dynamic_dense.10.w2.bias, dynamic_dense.10.w2.weight, dynamic_dense.12.scale, dynamic_dense.12.w2.bias, dynamic_dense.12.w2.weight, dynamic_dense.15.scale, dynamic_dense.15.w2.bias, dynamic_dense.15.w2.weight, dynamic_dense.16.scale, dynamic_dense.16.w2.bias, dynamic_dense.16.w2.weight, dynamic_dense.2.scale, dynamic_dense.2.w1.weight, dynamic_dense.2.w2.bias, dynamic_dense.2.w2.weight, dynamic_dense.4.scale, dynamic_dense.4.w1.weight, dynamic_dense.4.w2.bias, dynamic_dense.4.w2.weight, dynamic_dense.6.scale, dynamic_dense.6.w1.weight, dynamic_dense.6.w2.bias, dynamic_dense.6.w2.weight, dynamic_dense.8.scale, dynamic_dense.8.w2.bias, dynamic_dense.8.w2.weight +Serialized model quantized+brotli: 15976982 bytes +Total submission size quantized+brotli: 15996370 bytes +quantized val_loss:2.83102779 val_bpb:1.09597908 eval_time:13282ms +quantized_sliding_window val_loss:2.78669863 val_bpb:1.07881788 eval_time:104031ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78313705 val_bpb:1.07743908 eval_time:376290ms diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed423.txt b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed423.txt new file mode 100644 index 0000000000..ace0b37ace --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed423.txt @@ -0,0 +1,548 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + keep_unet: False + ln_scale: True + local_rank: 0 + logfile: logs/baseline0409_mudd_seed423.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 3.5 + model_dim: 512 + model_path: final_model.pt + mudd_emb: False + mudd_k_dilation: 1 + mudd_q_dilation: 1 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 8 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: baseline0409_mudd_seed423 + scalar_lr: 0.02 + seed: 423 + skip_gates_enabled: True + sliding_window_enabled: True + tensorboard_dir: + tie_embeddings: True + tied_embed_init_std: 0.005 + 
tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_mudd: True + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 150 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] +Running PyTorch 2.9.1+cu128 +Wed Apr 29 08:51:06 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 38C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 104819 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 1 N/A N/A 104820 C ...ing/miniconda3/bin/python3.13 
1512MiB | +| 2 N/A N/A 104821 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 3 N/A N/A 104822 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 4 N/A N/A 104823 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 5 N/A N/A 104824 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 6 N/A N/A 104825 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 7 N/A N/A 104826 C ...ing/miniconda3/bin/python3.13 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 40540160 +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + keep_unet: False + ln_scale: True + local_rank: 0 + logfile: logs/baseline0409_mudd_seed423.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 3.5 + model_dim: 512 + model_path: final_model.pt + mudd_emb: False + mudd_k_dilation: 1 + mudd_q_dilation: 1 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 8 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: baseline0409_mudd_seed423 + scalar_lr: 0.02 + seed: 423 + skip_gates_enabled: True + sliding_window_enabled: True + tensorboard_dir: + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_mudd: True + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 150 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] +Running PyTorch 2.9.1+cu128 +Wed Apr 29 08:51:42 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 38C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 38C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 38C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 105168 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 1 N/A N/A 105169 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 2 N/A N/A 105170 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 3 N/A N/A 105171 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 4 N/A N/A 105172 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 5 N/A N/A 105173 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 6 N/A N/A 105174 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 7 N/A N/A 105175 C ...ing/miniconda3/bin/python3.13 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 40540160 +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + 
gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + keep_unet: False + ln_scale: True + local_rank: 0 + logfile: logs/baseline0409_mudd_seed423.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 3.5 + model_dim: 512 + model_path: final_model.pt + mudd_emb: False + mudd_k_dilation: 1 + mudd_q_dilation: 1 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 8 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: baseline0409_mudd_seed423 + scalar_lr: 0.02 + seed: 423 + skip_gates_enabled: True + sliding_window_enabled: True + tensorboard_dir: + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_mudd: True + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 150 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] +Running PyTorch 2.9.1+cu128 +Wed Apr 29 08:52:30 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1521MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 4% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 38C P0 120W / 700W | 1521MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 38C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 38C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 105566 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 1 N/A N/A 105567 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 2 N/A N/A 105568 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 3 N/A N/A 105569 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 4 N/A N/A 105570 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 5 N/A N/A 105571 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 6 N/A N/A 105572 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 7 N/A N/A 105573 C ...ing/miniconda3/bin/python3.13 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 40540160 +model_params:35919402 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/150 +warmup_step: 2/150 +warmup_step: 3/150 +warmup_step: 4/150 +warmup_step: 5/150 +warmup_step: 6/150 +warmup_step: 10/150 +warmup_step: 20/150 +warmup_step: 30/150 +warmup_step: 40/150 +warmup_step: 50/150 +warmup_step: 60/150 +warmup_step: 70/150 +warmup_step: 80/150 +warmup_step: 90/150 +warmup_step: 100/150 +warmup_step: 110/150 +warmup_step: 120/150 +warmup_step: 130/150 +warmup_step: 140/150 +warmup_step: 150/150 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 
5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/150 +loop_warmup_step: 2/150 +loop_warmup_step: 3/150 +loop_warmup_step: 4/150 +loop_warmup_step: 5/150 +loop_warmup_step: 6/150 +loop_warmup_step: 10/150 +loop_warmup_step: 20/150 +loop_warmup_step: 30/150 +loop_warmup_step: 40/150 +loop_warmup_step: 50/150 +loop_warmup_step: 60/150 +loop_warmup_step: 70/150 +loop_warmup_step: 80/150 +loop_warmup_step: 90/150 +loop_warmup_step: 100/150 +loop_warmup_step: 110/150 +loop_warmup_step: 120/150 +loop_warmup_step: 130/150 +loop_warmup_step: 140/150 +loop_warmup_step: 150/150 +0/20000 val_loss: 9.0070 val_bpb: 3.4869 +1/20000 train_loss: 9.0086 train_time: 0.0m, step_avg: 102.69ms, raw_grad_norm: 0.3926, tok/s: 7658547 +2/20000 train_loss: 12.6078 train_time: 0.0m, step_avg: 102.64ms, raw_grad_norm: 2.6945, tok/s: 7662029 +3/20000 train_loss: 11.9242 train_time: 0.0m, step_avg: 103.39ms, raw_grad_norm: 3.4404, tok/s: 7606698 +4/20000 train_loss: 10.7060 train_time: 0.0m, step_avg: 103.74ms, raw_grad_norm: 3.4466, tok/s: 7581050 +5/20000 train_loss: 9.2075 train_time: 0.0m, step_avg: 103.96ms, raw_grad_norm: 3.2168, tok/s: 7564852 +500/20000 train_loss: 3.3776 train_time: 0.9m, step_avg: 106.31ms, raw_grad_norm: 0.2758, tok/s: 7397703 +1000/20000 train_loss: 3.2683 train_time: 1.8m, step_avg: 106.32ms, raw_grad_norm: 0.2050, tok/s: 7396697 +1500/20000 train_loss: 3.1803 train_time: 2.7m, step_avg: 106.33ms, raw_grad_norm: 0.1963, tok/s: 7395952 +layer_loop:enabled step:1936 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0647 train_time: 3.6m, step_avg: 107.96ms, raw_grad_norm: 0.1337, tok/s: 7284477 +2500/20000 train_loss: 3.1005 train_time: 4.9m, step_avg: 117.82ms, raw_grad_norm: 0.0905, tok/s: 6674960 +3000/20000 train_loss: 2.8725 train_time: 6.2m, step_avg: 124.38ms, raw_grad_norm: 0.0909, tok/s: 6322858 +3500/20000 train_loss: 2.9117 train_time: 7.5m, step_avg: 129.06ms, raw_grad_norm: 0.0804, tok/s: 6093615 +4000/20000 train_loss: 2.7908 train_time: 8.8m, step_avg: 132.56ms, raw_grad_norm: 0.0678, tok/s: 5932428 +4000/20000 val_loss: 2.8424 val_bpb: 1.1004 +4369/20000 val_loss: 2.8040 val_bpb: 1.0855 +stopping_early: wallclock_cap train_time: 588136ms step: 4369/20000 +peak memory allocated: 42259 MiB reserved: 42352 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80107152 val_bpb:1.08438207 eval_time:10500ms +Serialized model: 135338079 bytes +Code size: 19388 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 71 Hessians in 14.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight, dynamic_dense.12.w1.weight, dynamic_dense.15.w1.weight, dynamic_dense.16.w1.weight, dynamic_dense.8.w1.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, dynamic_dense.10.scale, dynamic_dense.10.w1.weight, dynamic_dense.10.w2.bias, dynamic_dense.10.w2.weight, dynamic_dense.12.scale, dynamic_dense.12.w2.bias, dynamic_dense.12.w2.weight, dynamic_dense.15.scale, dynamic_dense.15.w2.bias, dynamic_dense.15.w2.weight, dynamic_dense.16.scale, dynamic_dense.16.w2.bias, dynamic_dense.16.w2.weight, dynamic_dense.2.scale, dynamic_dense.2.w1.weight, dynamic_dense.2.w2.bias, dynamic_dense.2.w2.weight, dynamic_dense.4.scale, dynamic_dense.4.w1.weight, dynamic_dense.4.w2.bias, dynamic_dense.4.w2.weight, dynamic_dense.6.scale, dynamic_dense.6.w1.weight, dynamic_dense.6.w2.bias, dynamic_dense.6.w2.weight, dynamic_dense.8.scale, dynamic_dense.8.w2.bias, dynamic_dense.8.w2.weight +Serialized model quantized+brotli: 15978693 bytes +Total submission size quantized+brotli: 15998081 bytes +quantized val_loss:2.83134234 val_bpb:1.09610085 eval_time:33496ms +quantized_sliding_window val_loss:2.78633911 val_bpb:1.07867870 eval_time:145156ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78113991 val_bpb:1.07666593 eval_time:422589ms diff --git a/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed424.txt b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed424.txt new file mode 100644 index 0000000000..b45c27d91b --- /dev/null +++ b/records/track_10min_16mb/2026-04-29_MUDD_Connections/train_seed424.txt @@ -0,0 +1,236 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + keep_unet: False + ln_scale: True + local_rank: 0 + logfile: logs/baseline0409_mudd_seed424.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 3.5 + model_dim: 512 + model_path: final_model.pt + mudd_emb: False + mudd_k_dilation: 1 + mudd_q_dilation: 1 + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 8 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.25 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: baseline0409_mudd_seed424 + scalar_lr: 0.02 + seed: 424 + skip_gates_enabled: True + sliding_window_enabled: True + tensorboard_dir: + tie_embeddings: True + tied_embed_init_std: 0.005 + 
tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_mudd: True + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 150 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] +Running PyTorch 2.9.1+cu128 +Wed Apr 29 09:47:48 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 48C P0 130W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 42C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 48C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 47C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 42C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 44C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 41C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 189217 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 1 N/A N/A 189218 C ...ing/miniconda3/bin/python3.13 
1512MiB | +| 2 N/A N/A 189219 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 3 N/A N/A 189220 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 4 N/A N/A 189221 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 5 N/A N/A 189222 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 6 N/A N/A 189223 C ...ing/miniconda3/bin/python3.13 1512MiB | +| 7 N/A N/A 189224 C ...ing/miniconda3/bin/python3.13 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 40540160 +model_params:35919402 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/150 +warmup_step: 2/150 +warmup_step: 3/150 +warmup_step: 4/150 +warmup_step: 5/150 +warmup_step: 6/150 +warmup_step: 10/150 +warmup_step: 20/150 +warmup_step: 30/150 +warmup_step: 40/150 +warmup_step: 50/150 +warmup_step: 60/150 +warmup_step: 70/150 +warmup_step: 80/150 +warmup_step: 90/150 +warmup_step: 100/150 +warmup_step: 110/150 +warmup_step: 120/150 +warmup_step: 130/150 +warmup_step: 140/150 +warmup_step: 150/150 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/150 +loop_warmup_step: 2/150 +loop_warmup_step: 3/150 +loop_warmup_step: 4/150 +loop_warmup_step: 5/150 +loop_warmup_step: 6/150 +loop_warmup_step: 10/150 +loop_warmup_step: 20/150 +loop_warmup_step: 30/150 +loop_warmup_step: 40/150 +loop_warmup_step: 50/150 +loop_warmup_step: 60/150 +loop_warmup_step: 70/150 +loop_warmup_step: 80/150 +loop_warmup_step: 90/150 +loop_warmup_step: 100/150 +loop_warmup_step: 110/150 +loop_warmup_step: 120/150 +loop_warmup_step: 130/150 +loop_warmup_step: 140/150 +loop_warmup_step: 150/150 +0/20000 val_loss: 9.0081 val_bpb: 3.4873 +1/20000 train_loss: 9.0099 train_time: 0.0m, step_avg: 101.02ms, raw_grad_norm: 0.3888, tok/s: 7784578 +2/20000 train_loss: 12.6135 train_time: 0.0m, step_avg: 101.43ms, raw_grad_norm: 2.5912, tok/s: 7753673 +3/20000 train_loss: 11.9046 train_time: 0.0m, step_avg: 102.42ms, raw_grad_norm: 2.9550, tok/s: 7678511 +4/20000 train_loss: 10.5487 train_time: 0.0m, step_avg: 102.99ms, raw_grad_norm: 3.6068, tok/s: 7636280 +5/20000 train_loss: 9.1175 train_time: 0.0m, step_avg: 103.35ms, raw_grad_norm: 3.5574, tok/s: 7609512 +500/20000 train_loss: 3.3744 train_time: 0.9m, step_avg: 106.45ms, raw_grad_norm: 0.2754, tok/s: 7387651 +1000/20000 train_loss: 3.2711 train_time: 1.8m, step_avg: 106.46ms, raw_grad_norm: 0.2055, tok/s: 7386856 +1500/20000 train_loss: 3.1747 train_time: 2.7m, step_avg: 106.42ms, raw_grad_norm: 0.1461, tok/s: 7389744 +layer_loop:enabled step:1934 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0641 train_time: 3.6m, step_avg: 108.09ms, raw_grad_norm: 0.1161, tok/s: 7275447 +2500/20000 train_loss: 3.1008 train_time: 4.9m, step_avg: 117.96ms, raw_grad_norm: 0.1166, tok/s: 6667041 +3000/20000 train_loss: 2.8739 train_time: 6.2m, step_avg: 124.52ms, raw_grad_norm: 0.0833, tok/s: 6315560 +3500/20000 train_loss: 2.9125 train_time: 7.5m, step_avg: 129.21ms, raw_grad_norm: 0.0819, tok/s: 6086246 +4000/20000 train_loss: 2.7870 train_time: 8.8m, step_avg: 132.73ms, raw_grad_norm: 0.0666, tok/s: 5925259 +4000/20000 val_loss: 2.8413 val_bpb: 1.1000 +4364/20000 val_loss: 2.8032 val_bpb: 1.0852 +stopping_early: wallclock_cap train_time: 588084ms step: 4364/20000 +peak memory allocated: 42260 MiB reserved: 43748 MiB +ema:applying EMA weights 
+pre-quantization post-ema val_loss:2.80017544 val_bpb:1.08403518 eval_time:9890ms +Serialized model: 135338079 bytes +Code size: 19388 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 71 Hessians in 14.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight, dynamic_dense.12.w1.weight, dynamic_dense.15.w1.weight, dynamic_dense.16.w1.weight, dynamic_dense.8.w1.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, dynamic_dense.10.scale, dynamic_dense.10.w1.weight, dynamic_dense.10.w2.bias, dynamic_dense.10.w2.weight, dynamic_dense.12.scale, dynamic_dense.12.w2.bias, dynamic_dense.12.w2.weight, dynamic_dense.15.scale, dynamic_dense.15.w2.bias, dynamic_dense.15.w2.weight, dynamic_dense.16.scale, dynamic_dense.16.w2.bias, dynamic_dense.16.w2.weight, dynamic_dense.2.scale, dynamic_dense.2.w1.weight, dynamic_dense.2.w2.bias, dynamic_dense.2.w2.weight, dynamic_dense.4.scale, dynamic_dense.4.w1.weight, dynamic_dense.4.w2.bias, dynamic_dense.4.w2.weight, dynamic_dense.6.scale, dynamic_dense.6.w1.weight, dynamic_dense.6.w2.bias, dynamic_dense.6.w2.weight, dynamic_dense.8.scale, dynamic_dense.8.w2.bias, dynamic_dense.8.w2.weight +Serialized model quantized+brotli: 15977121 bytes +Total submission size quantized+brotli: 15996509 bytes +quantized val_loss:2.82808086 val_bpb:1.09483823 eval_time:13253ms +quantized_sliding_window val_loss:2.78376433 val_bpb:1.07768192 eval_time:103567ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78059298 val_bpb:1.07645419 eval_time:371222ms