diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/DATASET_AUDIT.md b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/DATASET_AUDIT.md new file mode 100644 index 0000000000..377b7708e3 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/DATASET_AUDIT.md @@ -0,0 +1,128 @@ +# Dataset Audit: CaseOps Train Shards and Full Validation + +This note records how the dataset used for this submission was constructed and how it was checked against the merged CaseOps leaderboard lineage. + +## Verdict + +The submitted runs use the CaseOps SP8192 lossless-caps tokenizer and byte-sidecar BPB accounting. The 80 training shards were verified byte-for-byte against the output of the merged CaseOps leader's `prepare_caseops_data.py` default path. Evaluation uses the full CaseOps validation shard/sidecar reported by the leaderboard logs (`val_tokens: 47851520`). + +This is the same structural setup used by the CaseOps leaderboard lineage: 80 train shards, SP8192 lossless-caps tokenization, BOS-delimited documents, and byte sidecars for validation BPB accounting. + +## Sources + +- Dataset stream: the canonical FineWeb document stream used by the CaseOps records, `docs_selected.jsonl`. +- Tokenizer: `tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model`. +- Text transform: `lossless_caps.py`. +- Dataset script in this submission: `prepare_caseops_data.py`. +- Reference merged-leader script: `records/track_10min_16mb/2026-04-27_SP8192_LQER_SparseGate_BOSSmearFix_9HpStack_1.0611/prepare_caseops_data.py` from commit `1e439663209730edeac34e659039d7de62d85908` in `codemath3000/parameter-golf` (`https://github.com/codemath3000/parameter-golf/blob/1e439663209730edeac34e659039d7de62d85908/records/track_10min_16mb/2026-04-27_SP8192_LQER_SparseGate_BOSSmearFix_9HpStack_1.0611/prepare_caseops_data.py`). + +The relevant reference-script behavior is: + +- `SHARD_TOKENS = 10_000_000` +- `BOS_ID = 1` +- `--val-docs` default is `10_000` +- documents before `val_docs` are written to `fineweb_val_*.bin` and `fineweb_val_bytes_*.bin` +- documents after that boundary are written to `fineweb_train_*.bin` + +## Exact Train-Shard Replication + +On the AP RunPod, we rebuilt the dataset using the exact merged-leader `prepare_caseops_data.py` behavior with its default `--val-docs=10000`. Because the full document stream is large and the reference script has no "stop after 80 train shards" option, the monitor stopped the producer after the regenerated dataset had passed the first 80 train shards. The monitor observed 82 train shards at its next polling interval, but the audit comparison intentionally uses only the first 80 shards, matching the record runs. + +The regenerated output was compared to the compact CaseOps archive used for staging. The comparison result is stored in `dataset_verification/manifest_compare.json`: + +```json +{ + "exact_train_shards_seen": 82, + "exact_train_first80_tokens": 800000000, + "archive_train_first80_tokens": 800000000, + "train_first80_hash_mismatches": 0, + "train_mismatches_first5": [], + "exact_val_shards": 1, + "exact_val_byte_shards": 1, + "exact_val_tokens": 9662502, + "exact_val_byte_entries": 9662502, + "archive_val_tokens": 9662502, + "archive_val_byte_entries": 9662502, + "val_hash_match": true, + "val_bytes_hash_match": true +} +``` + +Interpretation: + +- The compact archive's first 80 train shards contain exactly `800000000` tokens. +- Those first 80 train shards have zero hash mismatches against the exact-script rebuild. +- The compact archive's default 10k validation token shard and byte sidecar also match the exact-script rebuild. + +This proves that the compact archive used for training is a faithful byte-for-byte staging of the merged CaseOps script's first 80 training shards. + +## Why Full 50k Validation Is Used + +The exact merged script's default `--val-docs=10000` produces a small validation set with `9662502` raw validation entries. That default validation output is useful for proving the archive's provenance, but it is not the leaderboard-comparable validation set. + +The CaseOps leaderboard logs report: + +```text +val_tokens: 47851520 +``` + +The submitted logs also report the same validation length: + +- `train_seed42.log`: `val_tokens: 47851520` +- `train_seed1337.log`: `val_tokens: 47851520` +- `train_seed2026.log`: `val_tokens: 47851520` +- `ap_pod_seed0/run.log`: `val_tokens: 47851520` + +So the final scoring dataset keeps the verified 80 train shards and replaces only the validation token/byte sidecar shards with the full 50,000-document CaseOps validation set. + +## Full-Validation Repair + +The full-validation repair leaves all `fineweb_train_*.bin` shards unchanged. It removes/replaces only: + +- `fineweb_val_*.bin` +- `fineweb_val_bytes_*.bin` + +Those validation files were regenerated from the first 50,000 documents of the same canonical document stream, using the same SP8192 tokenizer and `lossless_caps.py` transform. + +The AP pod repair log is stored at `dataset_verification/repair_full50k_val_ap.log` and ends with: + +```text +done docs=50000 val_shards=5 val_tokens=47853344 +``` + +`47853344` is the raw number of validation token/byte entries. `train_gpt.py` rounds the scored validation stream to the eval sequence length (`EVAL_SEQ_LEN=2560`), yielding: + +```text +val_tokens: 47851520 +``` + +This matches the leaderboard lineage and the submitted logs. + +## Base-Training/Eval Data Separation + +Base training reads only: + +```text +fineweb_train_*.bin +``` + +Validation/eval, including score-first TTT, reads: + +```text +fineweb_val_*.bin +fineweb_val_bytes_*.bin +``` + +The byte sidecar is not a training target. It is used for BPB accounting and document-aware validation/eval processing. Eval-time TTT uses validation tokens only in the score-first order documented in `README.md`: tokens are scored before they are used for any global or LoRA update. The full-validation repair does not change base-training data or model code; it only restores the validation stream length to the leaderboard-comparable `47851520` scored tokens. + +## Evidence Files + +- `dataset_verification/manifest_compare.json` - hash/token comparison between exact-script rebuild and compact archive. +- `dataset_verification/monitor.log` - timestamped exact-script rebuild monitor and embedded manifest result. +- `dataset_verification/repair_full50k_val_ap.log` - full 50k validation regeneration log. +- `train_seed42.log`, `train_seed1337.log`, `train_seed2026.log` - final 3-seed submission logs. +- `ap_pod_seed0/run.log` - independent AP pod check using the same train/full-validation construction. + +## Important Non-Comparable Check + +We also ran the candidate against the exact script's default 10k validation output. That run produced a much worse BPB while having a similar validation loss, because BPB depends on the validation byte sidecar and the 10k validation slice has a different token/byte ratio. It is therefore a useful debugging check, but it is not comparable to leaderboard logs reporting `val_tokens: 47851520`. diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/README.md b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/README.md new file mode 100644 index 0000000000..607221697c --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/README.md @@ -0,0 +1,113 @@ +# Record: Gated XSA + LQER top-1 + strict token-only n-gram TTT (val_bpb: 1.047) + +**val_bpb: 1.04722074** (3-seed mean, std 0.00104816) | **max artifact: 15,996,490 bytes** | 8xH100 SXM | strict in-timer TTT eval + +**Improvement vs merged PR #1855 SOTA (1.06107587 BPB):** **-0.01385513 BPB / -0.00960 nats per byte**, clearing the README's 0.005-nats record threshold by about 1.92x. + +| Metric | Seed 42 | Seed 1337 | Seed 2026 | 3-seed | +|---|---:|---:|---:|---:| +| Stop step | 4,914 | 4,926 | 4,916 | 4,918.7 mean | +| Train time | 596.127 s | 596.167 s | 596.080 s | 596.125 s mean | +| Pre-quant BPB | 1.04930686 | 1.05124428 | 1.05029930 | 1.05028348 mean | +| Quantized BPB | 1.05773513 | 1.05990331 | 1.05886641 | 1.05883495 mean | +| **Post-TTT BPB** | **1.04616727** | **1.04826351** | **1.04723144** | **1.04722074 mean** | +| Eval time | 471.457 s | 465.480 s | 463.281 s | 466.739 s mean | +| Artifact bytes | 15,995,574 | 15,992,746 | 15,996,490 | 15,996,490 max | + +All reported eval time above includes the n-gram hint precompute inside the measured TTT eval timer (`NGRAM_HINT_PRECOMPUTE_OUTSIDE=0`). + +## Summary + +This submission picks up on the PR #1967 / CaseOps lineage and then applies a training-time attention change plus a conservative eval-time n-gram path: + +1. **Gated XSA.** Each attention layer gets a learned per-head scalar `xsa_alpha`; the existing XSA subtraction coefficient is multiplied by `tanh(xsa_alpha)`. The gate is zero-initialized, so the model starts as a strict superset of the base stack. +2. **LQER top-1.** `LQER_TOP_K=1` keeps the best LQER correction tensor. This saves artifact bytes versus the top-3 setting and was a favorable knob in the PR #1948 lineage. +3. **Strict token-only n-gram tilt.** In response to the current-token class-routing concern, this update adopts the conservative PR #1514 workaround: disable the within-word and word-level experts and retain only the token-16 expert. The token hint is emitted from `token_context_hash(st)` over prefix state before the current token is pushed into the online state. +4. **In-timer hint precompute.** The n-gram hint pass is included in the final eval timer (`NGRAM_HINT_PRECOMPUTE_OUTSIDE=0`). A token-only native fast path keeps the full eval under the 10-minute cap. +5. **Cheaper phased TTT.** The final eval uses one score-first global TTT phase over a 1,000-document prefix, then scores the remaining stream with the adapted global model plus per-document LoRA TTT. + +What did not work: Skylight/NorMuon was tested but is disabled in this submission (`SKYLIGHT_MUON=0`) because it destabilized this stack. + +## Compliance notes + +- **Artifact size:** max artifact is 15,996,490 bytes, under the decimal 16,000,000-byte cap. +- **Training budget:** all three seeds stop on the 600-second wallclock cap at about 596.1 s. +- **Eval budget:** all three token-only final TTT evals are under 600 s. The n-gram hint precompute is included in that timer. +- **Score-first TTT:** the phased TTT path scores validation tokens before using them for global or LoRA updates. The global phase only trains on already-scored prefix documents. +- **Token-only n-gram tilt:** the tilt applies a closed-form renormalized one-token boost, `p'(a) = exp(beta * 1[a=h]) p(a) / Z`, where `Z = 1 + p(h)(exp(beta)-1)`. Hints are generated left-to-right from prefix token state. +- **No within-word or word-level experts:** the final logs show `token_gate=628130 within_gate=0 word_gate=0 agree2plus=0` for every seed. +- **Gate population diagnostic:** `token_only_fast_evals/token_only_gate_population.json` reproduces the production hint pass and reports the same `token_gate=628130`, with `within_gate=0` and `word_gate=0`. +- **Dataset/tokenizer:** uses the CaseOps SP8192 lossless-caps tokenizer and byte-sidecar BPB accounting from the CaseOps lineage. The 80 training shards match the merged CaseOps leader's `prepare_caseops_data.py` default `val_docs=10000` output byte-for-byte. Evaluation uses the full CaseOps validation shard/sidecar reported by the leaderboard lineage (`val_tokens: 47851520`). See `DATASET_AUDIT.md`. + +## Key settings + +| Setting | Value | +|---|---| +| Base stack | PR #1967 V21 + LeakyReLU 0.3 + n-gram tilt lineage | +| Model | 11 layers, 512 dim, 8 heads / 4 KV heads | +| Tokenizer | SP8192 lossless-caps CaseOps v1 reserved | +| Eval sequence length | 2560 | +| TTT mask | `no_qv` | +| TTT LoRA rank | 80 | +| TTT local LR mult | 0.75 | +| QK gain init | 5.25 | +| Matrix LR | 0.026 | +| Min LR | 0.1 | +| LQER | rank 4, asymmetric, top-1 | +| N-gram precompute | inside timer (`NGRAM_HINT_PRECOMPUTE_OUTSIDE=0`) | +| N-gram expert | token-16 only | +| Within/word experts | disabled (`WITHIN_BOOST=0`, `WORD_BOOST=0`) | +| Phased TTT | 1 phase, 1,000 prefix docs | +| Gated XSA | enabled | +| Skylight Muon | disabled | + +## Reproducing + +Install Python dependencies from `requirements.txt`, install FlashAttention 3 as described there, and install the `lrzip` system package before launching the run. The script itself does not install packages or make network calls during training/evaluation. + +```bash +SEED=42 \ +NGRAM_HINT_PRECOMPUTE_OUTSIDE=0 \ +DATA_PATH=./data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved \ +TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model \ +CASEOPS_ENABLED=1 VOCAB_SIZE=8192 ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=600 \ +TTT_ENABLED=1 PHASED_TTT_ENABLED=1 PHASED_TTT_NUM_PHASES=1 PHASED_TTT_PREFIX_DOCS=1000 \ +TTT_LORA_RANK=80 TTT_MASK=no_qv TTT_Q_LORA=0 TTT_V_LORA=0 \ +TTT_LOCAL_LR_MULT=0.75 EVAL_SEQ_LEN=2560 TTT_EVAL_SEQ_LEN=2560 \ +QK_GAIN_INIT=5.25 \ +MATRIX_LR=0.026 MIN_LR=0.1 EMBED_BITS=7 GRAD_CLIP_NORM=0.3 \ +MATRIX_CLIP_SIGMAS=12.85 ATTN_CLIP_SIGMAS=13.0 MLP_CLIP_SIGMAS=11.5 EMBED_CLIP_SIGMAS=14.0 \ +FUSED_CE_ENABLED=1 SMEAR_GATE_ENABLED=1 GATE_WINDOW=12 \ +SPARSE_ATTN_GATE_ENABLED=1 LQER_ENABLED=1 LQER_RANK=4 LQER_TOP_K=1 \ +LQER_GROUP_SIZE=64 LQER_ASYM_ENABLED=1 LQER_ASYM_GROUP=64 \ +AWQ_LITE_ENABLED=1 ASYM_LOGIT_RESCALE=1 NGRAM_TILT_ENABLED=1 \ +TOKEN_ORDER=16 TOKEN_THRESHOLD=0.800 TOKEN_BOOST=2.625 \ +WITHIN_TAU=999 WITHIN_BOOST=0 WORD_TAU=999 WORD_BOOST=0 AGREE_ADD_BOOST=0 \ +GATED_XSA=1 SKYLIGHT_MUON=0 \ +GPTQ_RESERVE_SECONDS=4.0 GPTQ_CALIBRATION_BATCHES=16 \ +COMPRESSOR=pergroup \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Files + +- `train_gpt.py` - complete training/eval script. +- `online_ngram_tilt.py`, `online_ngram_state.c` - token-only n-gram hint/tilt helper from the PR #1967 lineage with the conservative fast path. +- `prepare_caseops_data.py`, `lossless_caps.py` - CaseOps dataset preparation helpers. +- `DATASET_AUDIT.md`, `dataset_verification/` - dataset construction audit and verification logs. +- `tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model` - tokenizer model. +- `train_seed42.log`, `train_seed1337.log`, `train_seed2026.log` - original full per-seed training logs for the saved artifacts. +- `token_only_fast_evals/` - eval-only replay logs from the saved artifacts using the conservative token-only n-gram path. +- `submission.json` - structured metadata for the token-only 3-seed result. + +## Credits + +This work is a small stack on top of a long public lineage: + +- PR #1967 by `ndokutovich` for the V21 + LeakyReLU 0.3 + closed-form n-gram tilt stack. +- PR #1953 by `andrewbaggio1` for the long-context/no-QV TTT and QK-gain settings. +- PR #1945 by `alertcat` for the V21/AWQ-lite/asymmetric-logit-rescale base. +- PR #1948 by `TimS-ml` and `lijuncheng16` for the LQER-top-k sweep and LeakyReLU work. +- PR #1514 by `codemath3000` for the conservative token-only n-gram workaround precedent. +- PR #1145 by `AnirudhRahul` for the online n-gram augmentation lineage. +- The CaseOps lineage from `romeerp`, `dexhunter`, `aquariouseworkman`, `codemath3000`, and others for the SP8192 lossless-caps tokenizer, byte-sidecar BPB accounting, and score-first phased TTT. diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/ap_pod_seed0/run.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/ap_pod_seed0/run.log new file mode 100644 index 0000000000..69208f321a --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/ap_pod_seed0/run.log @@ -0,0 +1,488 @@ +W0430 20:55:55.562000 943867 torch/distributed/run.py:803] +W0430 20:55:55.562000 943867 torch/distributed/run.py:803] ***************************************** +W0430 20:55:55.562000 943867 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0430 20:55:55.562000 943867 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/records/miracle_gatedxsa_full50kval_ap/seed0 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/pr1855_compact_train_full50k_val/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/records/miracle_gatedxsa_full50kval_ap/seed0/miracle_gatedxsa_full50kval_ap_s0.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/records/miracle_gatedxsa_full50kval_ap/seed0/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/records/miracle_gatedxsa_full50kval_ap/seed0/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: miracle_gatedxsa_full50kval_ap_s0 + scalar_lr: 0.02 + seed: 0 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/pr1855_compact_train_full50k_val/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/pr1855_compact_train_full50k_val/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/pr1855_compact_train_full50k_val/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/pr1855_compact_train_full50k_val/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +model_params:35945761 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0094 val_bpb: 4.1166 +1/20000 train_loss: 9.0099 train_time: 0.0m tok/s: 15194170 +2/20000 train_loss: 12.9098 train_time: 0.0m tok/s: 10973614 +3/20000 train_loss: 10.2004 train_time: 0.0m tok/s: 10009477 +4/20000 train_loss: 8.6800 train_time: 0.0m tok/s: 9529757 +5/20000 train_loss: 7.8744 train_time: 0.0m tok/s: 9255583 +500/20000 train_loss: 2.7223 train_time: 0.8m tok/s: 8266988 +1000/20000 train_loss: 2.7769 train_time: 1.6m tok/s: 8235087 +1500/20000 train_loss: 2.5838 train_time: 2.4m tok/s: 8231573 +2000/20000 train_loss: 2.5776 train_time: 3.2m tok/s: 8232989 +layer_loop:enabled step:2184 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 2.4728 train_time: 4.2m tok/s: 7768854 +3000/20000 train_loss: 2.6396 train_time: 5.4m tok/s: 7298755 +3500/20000 train_loss: 2.4162 train_time: 6.6m tok/s: 6977483 +4000/20000 train_loss: 2.5719 train_time: 7.7m tok/s: 6770128 +4000/20000 val_loss: 2.3850 val_bpb: 1.0898 +4500/20000 train_loss: 2.3871 train_time: 8.9m tok/s: 6617525 +4936/20000 val_loss: 2.2990 val_bpb: 1.0505 +stopping_early: wallclock_cap train_time: 596057ms step: 4936/20000 +peak memory allocated: 41724 MiB reserved: 46960 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.29988038 val_bpb:1.05086399 eval_time:9646ms +Serialized model: 135421514 bytes +Code size (uncompressed): 187853 bytes +Code size (compressed): 47669 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn.xsa_alpha, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 120.1s +Serialized model quantized+pergroup: 15951038 bytes +Total submission size quantized+pergroup: 15998707 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 21.0s +diagnostic quantized val_loss:2.31861082 val_bpb:1.05942233 eval_time:28979ms +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 21.1s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (104.6s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=13023303 token_gate=628130 within_gate=9866847 word_gate=2891588 agree2plus=303177 +ngram_tilt:precompute_done elapsed=168.13s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b776/782 bl:2.2247 bb:1.0548 rl:2.2247 rb:1.0548 dl:7534-8350 gd:0 +ttp: b773/782 bl:2.1754 bb:1.0245 rl:2.2029 rb:1.0413 dl:6104-6447 gd:0 +ttp: b769/782 bl:2.2979 bb:1.0697 rl:2.2283 rb:1.0490 dl:5097-5309 gd:0 +ttp: b765/782 bl:2.2872 bb:1.0697 rl:2.2393 rb:1.0529 dl:4393-4510 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:253.3s +tttg: c1/154 lr:0.001000 t:0.3s +tttg: c2/154 lr:0.001000 t:0.4s +tttg: c3/154 lr:0.001000 t:0.5s +tttg: c4/154 lr:0.000999 t:0.6s +tttg: c5/154 lr:0.000998 t:0.7s +tttg: c6/154 lr:0.000997 t:0.8s +tttg: c7/154 lr:0.000996 t:0.9s +tttg: c8/154 lr:0.000995 t:0.9s +tttg: c9/154 lr:0.000993 t:1.0s +tttg: c10/154 lr:0.000991 t:1.1s +tttg: c11/154 lr:0.000989 t:1.2s +tttg: c12/154 lr:0.000987 t:1.3s +tttg: c13/154 lr:0.000985 t:1.3s +tttg: c14/154 lr:0.000982 t:1.4s +tttg: c15/154 lr:0.000979 t:1.5s +tttg: c16/154 lr:0.000976 t:1.6s +tttg: c17/154 lr:0.000973 t:1.7s +tttg: c18/154 lr:0.000970 t:1.7s +tttg: c19/154 lr:0.000966 t:1.8s +tttg: c20/154 lr:0.000962 t:1.9s +tttg: c21/154 lr:0.000958 t:2.0s +tttg: c22/154 lr:0.000954 t:2.1s +tttg: c23/154 lr:0.000950 t:2.2s +tttg: c24/154 lr:0.000945 t:2.2s +tttg: c25/154 lr:0.000941 t:2.3s +tttg: c26/154 lr:0.000936 t:2.4s +tttg: c27/154 lr:0.000930 t:2.5s +tttg: c28/154 lr:0.000925 t:2.6s +tttg: c29/154 lr:0.000920 t:2.6s +tttg: c30/154 lr:0.000914 t:2.7s +tttg: c31/154 lr:0.000908 t:2.8s +tttg: c32/154 lr:0.000902 t:2.9s +tttg: c33/154 lr:0.000896 t:3.0s +tttg: c34/154 lr:0.000890 t:3.1s +tttg: c35/154 lr:0.000883 t:3.1s +tttg: c36/154 lr:0.000876 t:3.2s +tttg: c37/154 lr:0.000870 t:3.3s +tttg: c38/154 lr:0.000863 t:3.4s +tttg: c39/154 lr:0.000855 t:3.5s +tttg: c40/154 lr:0.000848 t:3.5s +tttg: c41/154 lr:0.000841 t:3.6s +tttg: c42/154 lr:0.000833 t:3.7s +tttg: c43/154 lr:0.000825 t:3.8s +tttg: c44/154 lr:0.000817 t:3.9s +tttg: c45/154 lr:0.000809 t:4.0s +tttg: c46/154 lr:0.000801 t:4.0s +tttg: c47/154 lr:0.000793 t:4.1s +tttg: c48/154 lr:0.000785 t:4.2s +tttg: c49/154 lr:0.000776 t:4.3s +tttg: c50/154 lr:0.000768 t:4.4s +tttg: c51/154 lr:0.000759 t:4.5s +tttg: c52/154 lr:0.000750 t:4.5s +tttg: c53/154 lr:0.000741 t:4.6s +tttg: c54/154 lr:0.000732 t:4.7s +tttg: c55/154 lr:0.000723 t:4.8s +tttg: c56/154 lr:0.000714 t:4.9s +tttg: c57/154 lr:0.000704 t:4.9s +tttg: c58/154 lr:0.000695 t:5.0s +tttg: c59/154 lr:0.000685 t:5.1s +tttg: c60/154 lr:0.000676 t:5.2s +tttg: c61/154 lr:0.000666 t:5.3s +tttg: c62/154 lr:0.000656 t:5.3s +tttg: c63/154 lr:0.000647 t:5.4s +tttg: c64/154 lr:0.000637 t:5.5s +tttg: c65/154 lr:0.000627 t:5.6s +tttg: c66/154 lr:0.000617 t:5.7s +tttg: c67/154 lr:0.000607 t:5.8s +tttg: c68/154 lr:0.000597 t:5.8s +tttg: c69/154 lr:0.000587 t:5.9s +tttg: c70/154 lr:0.000577 t:6.0s +tttg: c71/154 lr:0.000567 t:6.1s +tttg: c72/154 lr:0.000556 t:6.2s +tttg: c73/154 lr:0.000546 t:6.3s +tttg: c74/154 lr:0.000536 t:6.3s +tttg: c75/154 lr:0.000526 t:6.4s +tttg: c76/154 lr:0.000515 t:6.5s +tttg: c77/154 lr:0.000505 t:6.6s +tttg: c78/154 lr:0.000495 t:6.7s +tttg: c79/154 lr:0.000485 t:6.7s +tttg: c80/154 lr:0.000474 t:6.8s +tttg: c81/154 lr:0.000464 t:6.9s +tttg: c82/154 lr:0.000454 t:7.0s +tttg: c83/154 lr:0.000444 t:7.1s +tttg: c84/154 lr:0.000433 t:7.1s +tttg: c85/154 lr:0.000423 t:7.2s +tttg: c86/154 lr:0.000413 t:7.3s +tttg: c87/154 lr:0.000403 t:7.4s +tttg: c88/154 lr:0.000393 t:7.5s +tttg: c89/154 lr:0.000383 t:7.5s +tttg: c90/154 lr:0.000373 t:7.6s +tttg: c91/154 lr:0.000363 t:7.7s +tttg: c92/154 lr:0.000353 t:7.8s +tttg: c93/154 lr:0.000344 t:7.9s +tttg: c94/154 lr:0.000334 t:7.9s +tttg: c95/154 lr:0.000324 t:8.0s +tttg: c96/154 lr:0.000315 t:8.1s +tttg: c97/154 lr:0.000305 t:8.2s +tttg: c98/154 lr:0.000296 t:8.3s +tttg: c99/154 lr:0.000286 t:8.4s +tttg: c100/154 lr:0.000277 t:8.4s +tttg: c101/154 lr:0.000268 t:8.5s +tttg: c102/154 lr:0.000259 t:8.6s +tttg: c103/154 lr:0.000250 t:8.7s +tttg: c104/154 lr:0.000241 t:8.8s +tttg: c105/154 lr:0.000232 t:8.8s +tttg: c106/154 lr:0.000224 t:8.9s +tttg: c107/154 lr:0.000215 t:9.0s +tttg: c108/154 lr:0.000207 t:9.1s +tttg: c109/154 lr:0.000199 t:9.2s +tttg: c110/154 lr:0.000191 t:9.3s +tttg: c111/154 lr:0.000183 t:9.3s +tttg: c112/154 lr:0.000175 t:9.4s +tttg: c113/154 lr:0.000167 t:9.5s +tttg: c114/154 lr:0.000159 t:9.6s +tttg: c115/154 lr:0.000152 t:9.7s +tttg: c116/154 lr:0.000145 t:9.7s +tttg: c117/154 lr:0.000137 t:9.8s +tttg: c118/154 lr:0.000130 t:9.9s +tttg: c119/154 lr:0.000124 t:10.0s +tttg: c120/154 lr:0.000117 t:10.1s +tttg: c121/154 lr:0.000110 t:10.2s +tttg: c122/154 lr:0.000104 t:10.2s +tttg: c123/154 lr:0.000098 t:10.3s +tttg: c124/154 lr:0.000092 t:10.4s +tttg: c125/154 lr:0.000086 t:10.5s +tttg: c126/154 lr:0.000080 t:10.6s +tttg: c127/154 lr:0.000075 t:10.6s +tttg: c128/154 lr:0.000070 t:10.7s +tttg: c129/154 lr:0.000064 t:10.8s +tttg: c130/154 lr:0.000059 t:10.9s +tttg: c131/154 lr:0.000055 t:11.0s +tttg: c132/154 lr:0.000050 t:11.0s +tttg: c133/154 lr:0.000046 t:11.1s +tttg: c134/154 lr:0.000042 t:11.2s +tttg: c135/154 lr:0.000038 t:11.3s +tttg: c136/154 lr:0.000034 t:11.4s +tttg: c137/154 lr:0.000030 t:11.5s +tttg: c138/154 lr:0.000027 t:11.6s +tttg: c139/154 lr:0.000024 t:11.6s +tttg: c140/154 lr:0.000021 t:11.7s +tttg: c141/154 lr:0.000018 t:11.8s +tttg: c142/154 lr:0.000015 t:11.9s +tttg: c143/154 lr:0.000013 t:12.0s +tttg: c144/154 lr:0.000011 t:12.0s +tttg: c145/154 lr:0.000009 t:12.1s +tttg: c146/154 lr:0.000007 t:12.2s +tttg: c147/154 lr:0.000005 t:12.3s +tttg: c148/154 lr:0.000004 t:12.4s +tttg: c149/154 lr:0.000003 t:12.4s +tttg: c150/154 lr:0.000002 t:12.5s +tttg: c151/154 lr:0.000001 t:12.6s +tttg: c152/154 lr:0.000000 t:12.7s +tttg: c153/154 lr:0.000000 t:12.8s +ttpr: phase:1/1 t:267.8s +ttp: b753/782 bl:2.1866 bb:0.9871 rl:2.2329 rb:1.0445 dl:3284-3344 gd:1 +ttp: b750/782 bl:2.3600 bb:1.0604 rl:2.2460 rb:1.0462 dl:3090-3149 gd:1 +ttp: b748/782 bl:2.2883 bb:1.0679 rl:2.2498 rb:1.0482 dl:2992-3039 gd:1 +ttp: b745/782 bl:2.2018 bb:1.0080 rl:2.2460 rb:1.0450 dl:2842-2883 gd:1 +ttp: b741/782 bl:2.2840 bb:1.0243 rl:2.2486 rb:1.0435 dl:2686-2730 gd:1 +ttp: b738/782 bl:2.2853 bb:1.0348 rl:2.2509 rb:1.0429 dl:2583-2618 gd:1 +ttp: b735/782 bl:2.3578 bb:1.0847 rl:2.2570 rb:1.0453 dl:2495-2526 gd:1 +ttp: b733/782 bl:2.3483 bb:1.0514 rl:2.2619 rb:1.0456 dl:2441-2468 gd:1 +ttp: b728/782 bl:2.3306 bb:1.0670 rl:2.2651 rb:1.0467 dl:2306-2324 gd:1 +ttp: b725/782 bl:2.2872 bb:1.0289 rl:2.2661 rb:1.0459 dl:2232-2254 gd:1 +ttp: b724/782 bl:2.2946 bb:1.0477 rl:2.2673 rb:1.0459 dl:2203-2231 gd:1 +ttp: b719/782 bl:2.2959 bb:1.0339 rl:2.2684 rb:1.0455 dl:2106-2125 gd:1 +ttp: b717/782 bl:2.2269 bb:1.0196 rl:2.2669 rb:1.0445 dl:2070-2088 gd:1 +ttp: b714/782 bl:2.2812 bb:1.0104 rl:2.2674 rb:1.0433 dl:2018-2035 gd:1 +ttp: b707/782 bl:2.3297 bb:1.0353 rl:2.2693 rb:1.0431 dl:1910-1923 gd:1 +ttp: b701/782 bl:2.2866 bb:1.0252 rl:2.2698 rb:1.0425 dl:1835-1847 gd:1 +ttp: b695/782 bl:2.3146 bb:1.0675 rl:2.2710 rb:1.0432 dl:1769-1779 gd:1 +ttp: b687/782 bl:2.2828 bb:1.0424 rl:2.2713 rb:1.0432 dl:1685-1696 gd:1 +ttp: b680/782 bl:2.2567 bb:1.0163 rl:2.2710 rb:1.0425 dl:1618-1628 gd:1 +ttp: b674/782 bl:2.3772 bb:1.0767 rl:2.2734 rb:1.0433 dl:1571-1578 gd:1 +ttp: b668/782 bl:2.3028 bb:1.0528 rl:2.2740 rb:1.0435 dl:1521-1530 gd:1 +ttp: b661/782 bl:2.3719 bb:1.0723 rl:2.2760 rb:1.0441 dl:1474-1480 gd:1 +ttp: b654/782 bl:2.2649 bb:1.0244 rl:2.2758 rb:1.0437 dl:1425-1432 gd:1 +ttp: b647/782 bl:2.2454 bb:1.0191 rl:2.2752 rb:1.0433 dl:1382-1387 gd:1 +ttp: b640/782 bl:2.2741 bb:1.0360 rl:2.2752 rb:1.0431 dl:1337-1343 gd:1 +ttp: b633/782 bl:2.2485 bb:1.0103 rl:2.2748 rb:1.0426 dl:1297-1302 gd:1 +ttp: b625/782 bl:2.3754 bb:1.0364 rl:2.2764 rb:1.0425 dl:1255-1260 gd:1 +ttp: b618/782 bl:2.3758 bb:1.0574 rl:2.2779 rb:1.0427 dl:1216-1221 gd:1 +ttp: b611/782 bl:2.2638 bb:1.0109 rl:2.2777 rb:1.0423 dl:1182-1186 gd:1 +ttp: b605/782 bl:2.2151 bb:1.0102 rl:2.2768 rb:1.0418 dl:1154-1159 gd:1 +ttp: b600/782 bl:2.2297 bb:0.9991 rl:2.2761 rb:1.0412 dl:1133-1137 gd:1 +ttp: b589/782 bl:2.2438 bb:0.9965 rl:2.2757 rb:1.0406 dl:1086-1089 gd:1 +ttp: b582/782 bl:2.3178 bb:1.0181 rl:2.2763 rb:1.0403 dl:1056-1060 gd:1 +ttp: b573/782 bl:2.3381 bb:1.0540 rl:2.2770 rb:1.0405 dl:1021-1025 gd:1 +ttp: b566/782 bl:2.2704 bb:1.0141 rl:2.2769 rb:1.0402 dl:997-1001 gd:1 +ttp: b558/782 bl:2.3386 bb:1.0459 rl:2.2776 rb:1.0403 dl:968-972 gd:1 +ttp: b551/782 bl:2.3008 bb:1.0399 rl:2.2778 rb:1.0403 dl:946-949 gd:1 +ttp: b543/782 bl:2.3083 bb:1.0451 rl:2.2781 rb:1.0403 dl:921-924 gd:1 +ttp: b536/782 bl:2.2804 bb:1.0269 rl:2.2781 rb:1.0402 dl:899-902 gd:1 +ttp: b528/782 bl:2.2988 bb:1.0275 rl:2.2783 rb:1.0401 dl:875-878 gd:1 +ttp: b520/782 bl:2.2988 bb:0.9912 rl:2.2785 rb:1.0396 dl:852-854 gd:1 +ttp: b512/782 bl:2.2699 bb:1.0483 rl:2.2784 rb:1.0397 dl:829-832 gd:1 +ttp: b504/782 bl:2.2909 bb:1.0220 rl:2.2786 rb:1.0395 dl:807-809 gd:1 +ttp: b496/782 bl:2.3869 bb:1.0333 rl:2.2794 rb:1.0395 dl:785-788 gd:1 +ttp: b488/782 bl:2.2675 bb:0.9978 rl:2.2793 rb:1.0391 dl:766-769 gd:1 +ttp: b484/782 bl:2.3360 bb:1.0351 rl:2.2798 rb:1.0391 dl:756-759 gd:1 +ttp: b476/782 bl:2.2451 bb:1.0175 rl:2.2795 rb:1.0389 dl:738-740 gd:1 +ttp: b468/782 bl:2.3222 bb:1.0453 rl:2.2798 rb:1.0390 dl:719-721 gd:1 +ttp: b461/782 bl:2.3470 bb:1.0268 rl:2.2803 rb:1.0389 dl:703-706 gd:1 +ttp: b453/782 bl:2.3031 bb:1.0406 rl:2.2805 rb:1.0389 dl:687-689 gd:1 +ttp: b445/782 bl:2.3232 bb:1.0325 rl:2.2807 rb:1.0388 dl:670-672 gd:1 +ttp: b437/782 bl:2.2665 bb:1.0428 rl:2.2807 rb:1.0389 dl:653-655 gd:1 +ttp: b429/782 bl:2.2172 bb:1.0112 rl:2.2803 rb:1.0387 dl:638-640 gd:1 +ttp: b421/782 bl:2.2611 bb:0.9899 rl:2.2801 rb:1.0384 dl:622-624 gd:1 +ttp: b413/782 bl:2.3455 bb:1.0512 rl:2.2805 rb:1.0385 dl:607-609 gd:1 +ttp: b405/782 bl:2.3242 bb:1.0430 rl:2.2808 rb:1.0385 dl:592-593 gd:1 +ttp: b397/782 bl:2.3307 bb:1.0336 rl:2.2810 rb:1.0385 dl:577-579 gd:1 +ttp: b386/782 bl:2.3017 bb:1.0809 rl:2.2812 rb:1.0387 dl:557-559 gd:1 +ttp: b378/782 bl:2.3962 bb:1.0397 rl:2.2817 rb:1.0387 dl:544-545 gd:1 +ttp: b370/782 bl:2.3333 bb:1.0681 rl:2.2820 rb:1.0388 dl:530-532 gd:1 +ttp: b361/782 bl:2.3244 bb:1.0852 rl:2.2822 rb:1.0391 dl:515-517 gd:1 +ttp: b354/782 bl:2.2703 bb:1.0503 rl:2.2821 rb:1.0391 dl:503-504 gd:1 +ttp: b348/782 bl:2.3322 bb:1.0460 rl:2.2824 rb:1.0391 dl:494-495 gd:1 +ttp: b339/782 bl:2.3115 bb:1.0673 rl:2.2825 rb:1.0393 dl:480-482 gd:1 +ttp: b331/782 bl:2.3107 bb:1.0678 rl:2.2826 rb:1.0394 dl:468-469 gd:1 +ttp: b324/782 bl:2.2869 bb:1.0692 rl:2.2826 rb:1.0395 dl:458-459 gd:1 +ttp: b316/782 bl:2.3298 bb:1.0628 rl:2.2828 rb:1.0396 dl:445-446 gd:1 +ttp: b307/782 bl:2.3032 bb:1.1134 rl:2.2829 rb:1.0399 dl:432-433 gd:1 +ttp: b303/782 bl:2.3530 bb:1.0733 rl:2.2832 rb:1.0400 dl:426-427 gd:1 +ttp: b295/782 bl:2.2384 bb:1.0501 rl:2.2830 rb:1.0400 dl:414-415 gd:1 +ttp: b287/782 bl:2.3679 bb:1.0788 rl:2.2833 rb:1.0402 dl:402-403 gd:1 +ttp: b279/782 bl:2.2794 bb:1.0770 rl:2.2833 rb:1.0403 dl:391-392 gd:1 +ttp: b270/782 bl:2.2842 bb:1.0451 rl:2.2833 rb:1.0403 dl:379-380 gd:1 +ttp: b262/782 bl:2.4131 bb:1.1289 rl:2.2837 rb:1.0406 dl:369-370 gd:1 +ttp: b255/782 bl:2.3285 bb:1.0739 rl:2.2839 rb:1.0407 dl:360-361 gd:1 +ttp: b247/782 bl:2.3192 bb:1.0795 rl:2.2840 rb:1.0408 dl:350-351 gd:1 +ttp: b239/782 bl:2.3477 bb:1.0901 rl:2.2842 rb:1.0410 dl:340-341 gd:1 +ttp: b230/782 bl:2.4168 bb:1.1341 rl:2.2846 rb:1.0412 dl:329-330 gd:1 +ttp: b222/782 bl:2.3411 bb:1.0943 rl:2.2847 rb:1.0414 dl:320-321 gd:1 +ttp: b214/782 bl:2.3151 bb:1.1078 rl:2.2848 rb:1.0415 dl:310-312 gd:1 +ttp: b206/782 bl:2.3670 bb:1.0889 rl:2.2850 rb:1.0417 dl:302-303 gd:1 +ttp: b198/782 bl:2.3649 bb:1.0462 rl:2.2852 rb:1.0417 dl:294-295 gd:1 +ttp: b190/782 bl:2.3176 bb:1.0655 rl:2.2853 rb:1.0417 dl:284-285 gd:1 +ttp: b182/782 bl:2.3221 bb:1.1040 rl:2.2854 rb:1.0419 dl:276-277 gd:1 +ttp: b174/782 bl:2.4134 bb:1.1383 rl:2.2857 rb:1.0421 dl:268-269 gd:1 +ttp: b166/782 bl:2.4478 bb:1.0941 rl:2.2860 rb:1.0422 dl:260-262 gd:1 +ttp: b158/782 bl:2.3040 bb:1.0893 rl:2.2861 rb:1.0423 dl:253-254 gd:1 +ttp: b150/782 bl:2.3126 bb:1.0981 rl:2.2861 rb:1.0424 dl:245-246 gd:1 +ttp: b142/782 bl:2.3448 bb:1.0914 rl:2.2863 rb:1.0425 dl:237-238 gd:1 +ttp: b134/782 bl:2.3953 bb:1.1233 rl:2.2865 rb:1.0427 dl:230-231 gd:1 +ttp: b124/782 bl:2.3429 bb:1.1444 rl:2.2866 rb:1.0428 dl:220-222 gd:1 +ttp: b117/782 bl:2.4403 bb:1.1858 rl:2.2868 rb:1.0431 dl:214-215 gd:1 +ttp: b110/782 bl:2.3186 bb:1.1003 rl:2.2869 rb:1.0432 dl:208-208 gd:1 +ttp: b101/782 bl:2.4798 bb:1.1398 rl:2.2872 rb:1.0433 dl:200-201 gd:1 +ttp: b93/782 bl:2.4427 bb:1.1716 rl:2.2875 rb:1.0435 dl:192-193 gd:1 +ttp: b86/782 bl:2.4373 bb:1.1245 rl:2.2877 rb:1.0437 dl:186-187 gd:1 +ttp: b79/782 bl:2.3617 bb:1.1290 rl:2.2878 rb:1.0438 dl:180-181 gd:1 +ttp: b70/782 bl:2.4865 bb:1.2117 rl:2.2881 rb:1.0440 dl:172-173 gd:1 +ttp: b62/782 bl:2.4104 bb:1.1614 rl:2.2883 rb:1.0442 dl:165-166 gd:1 +ttp: b55/782 bl:2.5796 bb:1.2142 rl:2.2887 rb:1.0444 dl:158-159 gd:1 +ttp: b47/782 bl:2.4064 bb:1.1234 rl:2.2888 rb:1.0445 dl:150-151 gd:1 +ttp: b39/782 bl:2.4054 bb:1.1644 rl:2.2890 rb:1.0446 dl:142-143 gd:1 +ttp: b31/782 bl:2.3914 bb:1.1344 rl:2.2891 rb:1.0447 dl:134-135 gd:1 +ttp: b23/782 bl:2.5520 bb:1.1986 rl:2.2893 rb:1.0449 dl:126-127 gd:1 +ttp: b15/782 bl:2.6189 bb:1.2162 rl:2.2897 rb:1.0450 dl:115-117 gd:1 +ttp: b6/782 bl:2.6751 bb:1.1927 rl:2.2900 rb:1.0452 dl:99-101 gd:1 +quantized_ttt_phased val_loss:2.29051850 val_bpb:1.04667757 eval_time:639926ms +total_eval_time:639.9s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/manifest_compare.json b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/manifest_compare.json new file mode 100644 index 0000000000..2a4983eb8e --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/manifest_compare.json @@ -0,0 +1,15 @@ +{ + "exact_train_shards_seen": 82, + "exact_train_first80_tokens": 800000000, + "archive_train_first80_tokens": 800000000, + "train_first80_hash_mismatches": 0, + "train_mismatches_first5": [], + "exact_val_shards": 1, + "exact_val_byte_shards": 1, + "exact_val_tokens": 9662502, + "exact_val_byte_entries": 9662502, + "archive_val_tokens": 9662502, + "archive_val_byte_entries": 9662502, + "val_hash_match": true, + "val_bytes_hash_match": true +} diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/monitor.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/monitor.log new file mode 100644 index 0000000000..29c9a31970 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/monitor.log @@ -0,0 +1,43 @@ +monitor_start=2026-04-30T19:33:09+00:00 +2026-04-30T19:33:09+00:00 train_shards=10 +2026-04-30T19:34:09+00:00 train_shards=13 +2026-04-30T19:35:09+00:00 train_shards=16 +2026-04-30T19:36:09+00:00 train_shards=19 +2026-04-30T19:37:09+00:00 train_shards=22 +2026-04-30T19:38:09+00:00 train_shards=25 +2026-04-30T19:39:09+00:00 train_shards=28 +2026-04-30T19:40:09+00:00 train_shards=31 +2026-04-30T19:41:09+00:00 train_shards=34 +2026-04-30T19:42:09+00:00 train_shards=37 +2026-04-30T19:43:09+00:00 train_shards=40 +2026-04-30T19:44:09+00:00 train_shards=43 +2026-04-30T19:45:09+00:00 train_shards=46 +2026-04-30T19:46:09+00:00 train_shards=49 +2026-04-30T19:47:09+00:00 train_shards=52 +2026-04-30T19:48:09+00:00 train_shards=55 +2026-04-30T19:49:09+00:00 train_shards=58 +2026-04-30T19:50:09+00:00 train_shards=61 +2026-04-30T19:51:09+00:00 train_shards=64 +2026-04-30T19:52:09+00:00 train_shards=67 +2026-04-30T19:53:09+00:00 train_shards=70 +2026-04-30T19:54:09+00:00 train_shards=73 +2026-04-30T19:55:09+00:00 train_shards=76 +2026-04-30T19:56:10+00:00 train_shards=79 +2026-04-30T19:57:10+00:00 train_shards=82 +reached train_shards=82; stopping exact long producer/consumer +{ + "exact_train_shards_seen": 82, + "exact_train_first80_tokens": 800000000, + "archive_train_first80_tokens": 800000000, + "train_first80_hash_mismatches": 0, + "train_mismatches_first5": [], + "exact_val_shards": 1, + "exact_val_byte_shards": 1, + "exact_val_tokens": 9662502, + "exact_val_byte_entries": 9662502, + "archive_val_tokens": 9662502, + "archive_val_byte_entries": 9662502, + "val_hash_match": true, + "val_bytes_hash_match": true +} +monitor_done=2026-04-30T19:58:06+00:00 diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/repair_full50k_val_ap.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/repair_full50k_val_ap.log new file mode 100644 index 0000000000..fcc699f138 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/dataset_verification/repair_full50k_val_ap.log @@ -0,0 +1,11 @@ +processed 5000 docs val_shards=0 buffered=4868711 +processed 10000 docs val_shards=0 buffered=9662502 +processed 15000 docs val_shards=1 buffered=4359790 +processed 20000 docs val_shards=1 buffered=9231754 +processed 25000 docs val_shards=2 buffered=4164835 +processed 30000 docs val_shards=2 buffered=8727017 +processed 35000 docs val_shards=3 buffered=3527019 +processed 40000 docs val_shards=3 buffered=8448288 +processed 45000 docs val_shards=4 buffered=3163420 +processed 50000 docs val_shards=4 buffered=7853344 +done docs=50000 val_shards=5 val_tokens=47853344 diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/lossless_caps.py b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/lossless_caps.py new file mode 100644 index 0000000000..98e472f824 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/lossless_caps.py @@ -0,0 +1,833 @@ +"""Lossless capitalization pre-encoding helpers. + +This module provides a narrow, reversible transform that only touches +ASCII capital letters `A-Z`. Each uppercase ASCII letter is rewritten as +``, where `sentinel` is a private-use Unicode +character that is escaped by doubling if it appears literally in the +input text. + +Example with the default sentinel `\\uE000`: + + "The NASA Launch" -> "\\uE000the \\uE000n\\uE000a\\uE000s\\uE000a \\uE000launch" + +The transform is intentionally simple for v1: + +- lowercase ASCII letters are unchanged +- uppercase ASCII letters become sentinel + lowercase letter +- non-ASCII characters are left untouched +- literal sentinel characters are escaped as sentinel + sentinel + +This makes the transform exactly invertible while allowing a downstream +tokenizer to reuse lowercase subwords across case variants. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Callable, Iterable + +LOSSLESS_CAPS_V1 = "lossless_caps_v1" +LOSSLESS_CAPS_V2 = "lossless_caps_v2" +LOSSLESS_CAPS_V3 = "lossless_caps_v3" +LOSSLESS_CAPS_V4 = "lossless_caps_v4" +LOSSLESS_CAPS_V5 = "lossless_caps_v5" +LOSSLESS_CAPS_V6 = "lossless_caps_v6" +LOSSLESS_CAPS_V7 = "lossless_caps_v7" +LOSSLESS_CAPS_CASEOPS_V1 = "lossless_caps_caseops_v1" +IDENTITY = "identity" +DEFAULT_SENTINEL = "\uE000" +DEFAULT_V2_TITLE = "\uE001" +DEFAULT_V2_ALLCAPS = "\uE002" +DEFAULT_V2_CAPNEXT = "\uE003" +DEFAULT_V2_ESC = "\uE004" +DEFAULT_V5_TITLE_MIN_LEN = 7 +DEFAULT_V6_ALLCAPS_MIN_LEN = 3 +DEFAULT_V7_ALLCAPS_MIN_LEN = 4 + + +class LosslessCapsError(ValueError): + """Raised when a transformed string is malformed.""" + + +def _is_ascii_upper(ch: str) -> bool: + return "A" <= ch <= "Z" + + +def _is_ascii_lower(ch: str) -> bool: + return "a" <= ch <= "z" + + +def _is_ascii_alpha(ch: str) -> bool: + return _is_ascii_lower(ch) or _is_ascii_upper(ch) + + +def _validate_distinct_single_chars(*chars: str) -> None: + if any(len(ch) != 1 for ch in chars): + raise ValueError("all control characters must be exactly one character") + if len(set(chars)) != len(chars): + raise ValueError("control characters must be distinct") + + +def encode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Encode ASCII capitals reversibly using a one-character sentinel.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + for ch in text: + if ch == sentinel: + out.append(sentinel) + out.append(sentinel) + elif _is_ascii_upper(ch): + out.append(sentinel) + out.append(ch.lower()) + else: + out.append(ch) + return "".join(out) + + +def decode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Decode the `lossless_caps_v1` transform back to the original text.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch != sentinel: + out.append(ch) + i += 1 + continue + if i + 1 >= n: + raise LosslessCapsError("dangling capitalization sentinel at end of string") + nxt = text[i + 1] + if nxt == sentinel: + out.append(sentinel) + elif _is_ascii_lower(nxt): + out.append(nxt.upper()) + else: + raise LosslessCapsError( + f"invalid sentinel escape sequence {sentinel + nxt!r}; " + "expected doubled sentinel or sentinel + lowercase ASCII letter" + ) + i += 2 + return "".join(out) + + +def encode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode ASCII word capitalization with cheap word-level markers. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - mixed-case words use: + - optional `title` when the first letter is uppercase + - `capnext + lowercase(letter)` for subsequent uppercase letters + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, capnext, esc) + controls = {title, allcaps, capnext, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + lower_word = word.lower() + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(lower_word) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(lower_word) + else: + if _is_ascii_upper(word[0]): + out.append(title) + out.append(lower_word[0]) + for orig_ch, lower_ch in zip(word[1:], lower_word[1:], strict=True): + if _is_ascii_upper(orig_ch): + out.append(capnext) + out.append(lower_ch) + i = j + return "".join(out) + + +def decode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v2` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, capnext, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + if active_allcaps: + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only common word-level capitalization patterns. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - all other mixed-case words are left unchanged + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v3` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_word_mode is not None: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only ALLCAPS ASCII words, leaving all other case untouched.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v4` transform back to the original text.""" + _validate_distinct_single_chars(allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_allcaps = False + in_ascii_word = False + active_allcaps = False + + for ch in text: + if pending_escape: + if pending_allcaps and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending allcaps mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == allcaps: + if pending_allcaps or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_allcaps = True + continue + + if _is_ascii_alpha(ch): + if not in_ascii_word: + active_allcaps = pending_allcaps + pending_allcaps = False + in_ascii_word = True + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_allcaps: + raise LosslessCapsError("allcaps marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_allcaps: + raise LosslessCapsError("dangling allcaps marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + title_min_len: int = DEFAULT_V5_TITLE_MIN_LEN, +) -> str: + """Encode ALLCAPS words and only sufficiently long TitleCase words.""" + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif len(word) >= title_min_len and _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v5` transform back to the original text.""" + return decode_lossless_caps_v3(text, title=title, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V6_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= allcaps_min_len.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= allcaps_min_len and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v6` transform back to the original text.""" + return decode_lossless_caps_v4(text, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V7_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= 4.""" + return encode_lossless_caps_v6( + text, + allcaps=allcaps, + esc=esc, + allcaps_min_len=allcaps_min_len, + ) + + +def decode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v7` transform back to the original text.""" + return decode_lossless_caps_v6(text, allcaps=allcaps, esc=esc) + + +def get_text_transform(name: str | None) -> Callable[[str], str]: + """Return the forward text transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return encode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return encode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return encode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return encode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return encode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return encode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return encode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return encode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def get_text_inverse_transform(name: str | None) -> Callable[[str], str]: + """Return the inverse transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return decode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return decode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return decode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return decode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return decode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return decode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return decode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return decode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def normalize_text_transform_name(name: str | None) -> str: + """Normalize empty/None transform names to the identity transform.""" + return IDENTITY if name in {None, "", IDENTITY} else str(name) + + +def get_text_transform_control_symbols(name: str | None) -> list[str]: + """Return reserved control symbols used by a transform, if any.""" + normalized = normalize_text_transform_name(name) + if normalized == IDENTITY: + return [] + if normalized == LOSSLESS_CAPS_V1: + return [DEFAULT_SENTINEL] + if normalized == LOSSLESS_CAPS_V2: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5}: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + return [DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + raise ValueError(f"unsupported text_transform={name!r}") + + +def infer_text_transform_from_manifest(tokenizer_path: str | Path) -> str: + """Best-effort lookup of a tokenizer's text transform from a local manifest.""" + tokenizer_path = Path(tokenizer_path).expanduser().resolve() + manifest_candidates = [ + tokenizer_path.parent.parent / "manifest.json", + tokenizer_path.parent / "manifest.json", + ] + for manifest_path in manifest_candidates: + if not manifest_path.is_file(): + continue + try: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + continue + tokenizers = payload.get("tokenizers") + if not isinstance(tokenizers, list): + continue + for tokenizer_meta in tokenizers: + if not isinstance(tokenizer_meta, dict): + continue + model_path = tokenizer_meta.get("model_path") or tokenizer_meta.get("path") + if not model_path: + continue + candidate = (manifest_path.parent / str(model_path)).resolve() + if candidate == tokenizer_path: + return normalize_text_transform_name(tokenizer_meta.get("text_transform")) + return IDENTITY + + +def surface_piece_original_byte_counts( + surfaces: Iterable[str], + *, + text_transform_name: str | None = None, + sentinel: str = DEFAULT_SENTINEL, +) -> list[int]: + """Return exact original UTF-8 byte counts contributed by each surface piece. + + `surfaces` must be the exact decoded text fragments emitted by SentencePiece + in order, e.g. `piece.surface` from `encode_as_immutable_proto`. + """ + normalized = normalize_text_transform_name(text_transform_name) + if normalized == IDENTITY: + return [len(surface.encode("utf-8")) for surface in surfaces] + if normalized == LOSSLESS_CAPS_V1: + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + sentinel_bytes = len(sentinel.encode("utf-8")) + pending_sentinel = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_sentinel: + if ch == sentinel: + piece_bytes += sentinel_bytes + elif _is_ascii_lower(ch): + piece_bytes += 1 + else: + raise LosslessCapsError( + f"invalid continuation {ch!r} after capitalization sentinel" + ) + pending_sentinel = False + continue + if ch == sentinel: + pending_sentinel = True + else: + piece_bytes += len(ch.encode("utf-8")) + counts.append(piece_bytes) + if pending_sentinel: + raise LosslessCapsError("dangling capitalization sentinel across piece boundary") + return counts + if normalized not in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7, LOSSLESS_CAPS_CASEOPS_V1}: + raise ValueError(f"unsupported text_transform={text_transform_name!r}") + + title = DEFAULT_V2_TITLE + allcaps = DEFAULT_V2_ALLCAPS + capnext = DEFAULT_V2_CAPNEXT + esc = DEFAULT_V2_ESC + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1}: + _validate_distinct_single_chars(title, allcaps, capnext, esc) + elif normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + _validate_distinct_single_chars(allcaps, esc) + else: + _validate_distinct_single_chars(title, allcaps, esc) + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + piece_bytes += len(ch.encode("utf-8")) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + if ch == esc: + pending_escape = True + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_CASEOPS_V1} and ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1} and ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + piece_bytes += 1 + active_allcaps = pending_word_mode == "allcaps" + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + piece_bytes += 1 + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + piece_bytes += len(ch.encode("utf-8")) + in_ascii_word = False + active_allcaps = False + counts.append(piece_bytes) + if pending_escape: + raise LosslessCapsError("dangling escape marker across piece boundary") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker across piece boundary") + return counts diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_state.c b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_state.c new file mode 100644 index 0000000000..c8d1c1a118 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_state.c @@ -0,0 +1,496 @@ +#include +#include +#include + +#define COEFF_COUNT 32 + +static const uint64_t ROLLING_COEFFS[COEFF_COUNT] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL, + 262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL, + 1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL, + 1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL, + 2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL, + 3407903ULL, 3538973ULL, +}; + +static const uint64_t PAIR_MIX = 1000003ULL; +static const uint64_t PREFIX_BASE = 1099511628211ULL; +static const uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL; +static const uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL; + +typedef struct { + uint64_t key; + uint32_t total; + uint32_t top_count; + uint16_t top_tok; + uint16_t _pad; +} CtxBucket; + +typedef struct { + uint64_t key; + uint32_t count; + uint32_t _pad; +} PairBucket; + +typedef struct { + int token_ctx_len; + int token_prefix_len; + int token_head; + uint16_t *token_ring; + + CtxBucket *token_ctx_tbl; + uint8_t *token_ctx_used; + size_t token_ctx_mask; + + PairBucket *token_pair_tbl; + uint8_t *token_pair_used; + size_t token_pair_mask; + + uint64_t within_hash; + uint32_t within_len; + + CtxBucket *within_ctx_tbl; + uint8_t *within_ctx_used; + size_t within_ctx_mask; + + PairBucket *within_pair_tbl; + uint8_t *within_pair_used; + size_t within_pair_mask; +} OnlineNgramState; + +static inline size_t mix_index(uint64_t key, size_t mask) { + return (size_t)((key * TABLE_MIX) & mask); +} + +static inline size_t find_ctx_slot( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline size_t find_pair_slot( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline uint64_t token_pair_key(uint64_t ctx_key, uint16_t tok, int ctx_len) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[(size_t)ctx_len % COEFF_COUNT]); +} + +static inline uint64_t within_pair_key(uint64_t ctx_key, uint16_t tok) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[0]); +} + +static inline uint64_t extend_prefix_hash(uint64_t current_hash, uint16_t tok, uint32_t pos) { + return (current_hash * PREFIX_BASE) ^ (((uint64_t)tok + 1ULL) * ROLLING_COEFFS[(size_t)pos % COEFF_COUNT]); +} + +static inline uint32_t pair_increment( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key +) { + int found = 0; + size_t idx = find_pair_slot(tbl, used, mask, key, &found); + if (found < 0) { + return 0U; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].count = 1U; + return 1U; + } + tbl[idx].count += 1U; + return tbl[idx].count; +} + +static inline int ctx_increment( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + uint16_t tok, + uint32_t pair_count +) { + int found = 0; + size_t idx = find_ctx_slot(tbl, used, mask, key, &found); + if (found < 0) { + return -1; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].total = 1U; + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + return 0; + } + tbl[idx].total += 1U; + if (pair_count > tbl[idx].top_count) { + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + } + return 0; +} + +static inline uint64_t token_context_hash(const OnlineNgramState *st) { + uint64_t h = 0ULL; + if (st->token_ctx_len <= 0) { + return h; + } + for (int j = 0; j < st->token_ctx_len; ++j) { + const int ring_idx = (st->token_head + j) % st->token_ctx_len; + h ^= ((uint64_t)st->token_ring[ring_idx]) * ROLLING_COEFFS[(size_t)j]; + } + return h; +} + +static inline void token_push(OnlineNgramState *st, uint16_t tok) { + if (st->token_ctx_len <= 0) { + return; + } + if (st->token_prefix_len < st->token_ctx_len) { + st->token_ring[st->token_prefix_len] = tok; + st->token_prefix_len += 1; + return; + } + st->token_ring[st->token_head] = tok; + st->token_head = (st->token_head + 1) % st->token_ctx_len; +} + +static void *xcalloc(size_t count, size_t size) { + if (count == 0 || size == 0) { + return NULL; + } + return calloc(count, size); +} + +static int alloc_tables( + size_t table_bits, + CtxBucket **ctx_tbl, + uint8_t **ctx_used, + size_t *ctx_mask, + PairBucket **pair_tbl, + uint8_t **pair_used, + size_t *pair_mask +) { + const size_t size = 1ULL << table_bits; + *ctx_tbl = (CtxBucket *)xcalloc(size, sizeof(CtxBucket)); + *ctx_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + *pair_tbl = (PairBucket *)xcalloc(size, sizeof(PairBucket)); + *pair_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + if (!*ctx_tbl || !*ctx_used || !*pair_tbl || !*pair_used) { + return -1; + } + *ctx_mask = size - 1U; + *pair_mask = size - 1U; + return 0; +} + +void *online_ngram_state_create( + int token_ctx_len, + int token_table_bits, + int within_table_bits +) { + if (token_ctx_len < 0 || token_table_bits <= 0 || within_table_bits <= 0) { + return NULL; + } + OnlineNgramState *st = (OnlineNgramState *)calloc(1, sizeof(OnlineNgramState)); + if (!st) { + return NULL; + } + st->token_ctx_len = token_ctx_len; + if (token_ctx_len > 0) { + st->token_ring = (uint16_t *)xcalloc((size_t)token_ctx_len, sizeof(uint16_t)); + if (!st->token_ring) { + free(st); + return NULL; + } + } + if (alloc_tables( + (size_t)token_table_bits, + &st->token_ctx_tbl, + &st->token_ctx_used, + &st->token_ctx_mask, + &st->token_pair_tbl, + &st->token_pair_used, + &st->token_pair_mask + ) != 0) { + free(st->token_ring); + free(st); + return NULL; + } + if (alloc_tables( + (size_t)within_table_bits, + &st->within_ctx_tbl, + &st->within_ctx_used, + &st->within_ctx_mask, + &st->within_pair_tbl, + &st->within_pair_used, + &st->within_pair_mask + ) != 0) { + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); + return NULL; + } + return (void *)st; +} + +void online_ngram_state_destroy(void *ptr) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + free(st->within_pair_used); + free(st->within_pair_tbl); + free(st->within_ctx_used); + free(st->within_ctx_tbl); + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); +} + +void online_ngram_state_seed_prefix_token(void *ptr, uint16_t tok) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + token_push(st, tok); +} + +int online_ngram_state_process_chunk( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + const uint8_t *starts_new_word_lut, + const uint8_t *boundary_lut, + uint16_t *token_top_token, + float *token_top_prob, + uint16_t *within_top_token, + float *within_top_prob, + uint8_t *within_valid +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + const uint8_t is_boundary = boundary_lut[tok]; + const uint8_t is_new_word = starts_new_word_lut[tok]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + uint64_t within_ctx_key = 0ULL; + if (!is_boundary && !is_new_word && st->within_len > 0U) { + within_ctx_key = st->within_hash ^ ((uint64_t)st->within_len * LEN_MIX); + int found = 0; + size_t idx = find_ctx_slot( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + &found + ); + within_valid[i] = 1U; + if (found > 0) { + within_top_token[i] = st->within_ctx_tbl[idx].top_tok; + within_top_prob[i] = + (float)st->within_ctx_tbl[idx].top_count / (float)st->within_ctx_tbl[idx].total; + } else { + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + } else { + within_valid[i] = 0U; + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } + token_push(st, tok); + + if (is_boundary) { + st->within_hash = 0ULL; + st->within_len = 0U; + continue; + } + if (is_new_word || st->within_len == 0U) { + st->within_hash = extend_prefix_hash(0ULL, tok, 0U); + st->within_len = 1U; + continue; + } + const uint32_t within_pair_count = pair_increment( + st->within_pair_tbl, + st->within_pair_used, + st->within_pair_mask, + within_pair_key(within_ctx_key, tok) + ); + if (within_pair_count == 0U) { + return -4; + } + if (ctx_increment( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + tok, + within_pair_count + ) != 0) { + return -5; + } + st->within_hash = extend_prefix_hash(st->within_hash, tok, st->within_len); + st->within_len += 1U; + } + return 0; +} + +int online_ngram_state_process_chunk_token_only( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + uint16_t *token_top_token, + float *token_top_prob +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + token_push(st, tok); + } + return 0; +} diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_tilt.py b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_tilt.py new file mode 100644 index 0000000000..9c188b053c --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/online_ngram_tilt.py @@ -0,0 +1,431 @@ +""" +Vendored online n-gram tilt helpers from PR #1145 (AnirudhRahul, valerio-endorsed). + +Provides causal, normalized, prefix-only n-gram experts that propose at most one +hinted token per scored position. Caller obtains q_t = p(h_t | x) from the model +(post-TTT-adapt logits) and applies multiplicative-boost-with-renorm: + + p'(a) = exp(beta * 1[a == h_t]) * p(a) / Z_t + Z_t = 1 - q_t + exp(beta) * q_t = 1 + q_t * (exp(beta) - 1) + -log p'(y_realized) = -log p(y) - beta * 1[y == h_t] + log Z_t + = ptl - beta * is_hit + log1p(q_t * (exp(beta) - 1)) + +Compliance: +- C1 causal: hint h_t computed from strict prefix (tokens 0..t-1 only) +- C2 normalized over Sigma: closed-form Z_t over full vocab softmax +- C3 score-before-update: hints precomputed in single L->R pass; loss uses prefix-only +- C4 single pass: process_chunk advances state monotonically + +Compatible with both #1934/#1855 base architectures via Hyperparameter env-var gates. +""" + +from __future__ import annotations + +import ctypes +import math +import os +import subprocess +from collections import deque +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch + + +SCRIPT_DIR = Path(__file__).resolve().parent +ONLINE_NGRAM_SRC = SCRIPT_DIR / "online_ngram_state.c" +ONLINE_NGRAM_LIB = SCRIPT_DIR / "libonline_ngram_state.so" + +WHITESPACE_BYTE_IDS = {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 36} +EDGE_PUNCT = ".,:;!?()[]{}<>\"'`" + + +def normalize_word(text: str, mode: str) -> str: + text = text.strip() + if mode == "lower": + return text.lower() + if mode == "identity": + return text + if mode == "strip_punct_lower": + return text.strip(EDGE_PUNCT).lower() + raise ValueError(f"Unknown word normalization mode: {mode}") + + +def suggest_table_bits(expected_entries: int, load_factor: float) -> int: + if expected_entries <= 0: + return 16 + target = max(int(expected_entries / max(load_factor, 1e-6)), 1) + bits = max(int(math.ceil(math.log2(target))), 12) + return min(bits, 28) + + +def ensure_online_ngram_lib(log0=print) -> ctypes.CDLL: + needs_build = (not ONLINE_NGRAM_LIB.exists()) or ( + ONLINE_NGRAM_SRC.stat().st_mtime_ns > ONLINE_NGRAM_LIB.stat().st_mtime_ns + ) + if needs_build: + log0(f"ngram_tilt:building_native_helper src={ONLINE_NGRAM_SRC.name}") + subprocess.run( + [ + "gcc", "-O3", "-march=native", "-shared", "-fPIC", + "-o", str(ONLINE_NGRAM_LIB), + str(ONLINE_NGRAM_SRC), + ], + check=True, + ) + lib = ctypes.CDLL(str(ONLINE_NGRAM_LIB)) + lib.online_ngram_state_create.restype = ctypes.c_void_p + lib.online_ngram_state_create.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] + lib.online_ngram_state_destroy.restype = None + lib.online_ngram_state_destroy.argtypes = [ctypes.c_void_p] + lib.online_ngram_state_seed_prefix_token.restype = None + lib.online_ngram_state_seed_prefix_token.argtypes = [ctypes.c_void_p, ctypes.c_uint16] + lib.online_ngram_state_process_chunk.restype = ctypes.c_int + lib.online_ngram_state_process_chunk.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint8), + ] + lib.online_ngram_state_process_chunk_token_only.restype = ctypes.c_int + lib.online_ngram_state_process_chunk_token_only.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ] + return lib + + +class OnlineNgramState: + def __init__( + self, *, lib, token_ctx_len, token_table_bits, within_table_bits, + starts_new_word_lut, boundary_lut, seed_prefix_token, + ): + self.lib = lib + self.state = lib.online_ngram_state_create(token_ctx_len, token_table_bits, within_table_bits) + if not self.state: + raise RuntimeError( + f"Native ngram state alloc failed token_table_bits={token_table_bits} within_table_bits={within_table_bits}" + ) + self.starts_new_word_lut = np.ascontiguousarray(starts_new_word_lut.astype(np.uint8, copy=False)) + self.boundary_lut = np.ascontiguousarray(boundary_lut.astype(np.uint8, copy=False)) + self.lib.online_ngram_state_seed_prefix_token(self.state, ctypes.c_uint16(int(seed_prefix_token))) + + def close(self): + if self.state: + self.lib.online_ngram_state_destroy(self.state) + self.state = None + + def __del__(self): + self.close() + + def process_chunk(self, chunk_tokens): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + within_top_token = np.zeros(n, dtype=np.uint16) + within_top_prob = np.zeros(n, dtype=np.float32) + within_valid = np.zeros(n, dtype=np.uint8) + rc = self.lib.online_ngram_state_process_chunk( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + self.starts_new_word_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + self.boundary_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + within_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_valid.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ) + if rc != 0: + raise RuntimeError(f"Native ngram process_chunk failed rc={rc}") + return token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid.astype(bool) + + def process_chunk_token_only(self, chunk_tokens): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + rc = self.lib.online_ngram_state_process_chunk_token_only( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + ) + if rc != 0: + raise RuntimeError(f"Native ngram token-only process_chunk failed rc={rc}") + return token_top_token, token_top_prob + + +class WordStartState: + def __init__(self, *, sp, order, normalize_mode): + self.sp = sp + self.ctx_w = max(order - 1, 0) + self.normalize_mode = normalize_mode + self.prev_word_ids: deque = deque(maxlen=self.ctx_w) + self.current_word_tokens: list = [] + self.word_to_id: dict = {} + self.next_word_id = 1 + self.ctx_total: dict = {} + self.pair_count: dict = {} + self.ctx_best_token: dict = {} + self.ctx_best_count: dict = {} + + def _flush_current_word(self): + if not self.current_word_tokens: + return + text = normalize_word(self.sp.decode(self.current_word_tokens), self.normalize_mode) + if text: + wid = self.word_to_id.get(text) + if wid is None: + wid = self.next_word_id + self.word_to_id[text] = wid + self.next_word_id += 1 + if self.ctx_w > 0: + self.prev_word_ids.append(wid) + self.current_word_tokens = [] + + def process_chunk(self, chunk_tokens, *, starts_new_word_lut, boundary_lut): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + top_token = np.zeros(chunk_tokens.size, dtype=np.uint16) + top_prob = np.zeros(chunk_tokens.size, dtype=np.float32) + for i, tok_u16 in enumerate(chunk_tokens): + tok = int(tok_u16) + is_boundary = bool(boundary_lut[tok]) + is_word_start = bool(starts_new_word_lut[tok]) or not self.current_word_tokens + if is_boundary: + self._flush_current_word() + continue + if bool(starts_new_word_lut[tok]): + self._flush_current_word() + ctx_key = None + if is_word_start and len(self.prev_word_ids) >= self.ctx_w: + ctx_key = tuple(self.prev_word_ids) if self.ctx_w > 0 else () + total = self.ctx_total.get(ctx_key, 0) + if total > 0: + top_token[i] = np.uint16(self.ctx_best_token[ctx_key]) + top_prob[i] = np.float32(self.ctx_best_count[ctx_key] / total) + if is_word_start: + if ctx_key is not None: + pair_key = (ctx_key, tok) + pair = self.pair_count.get(pair_key, 0) + 1 + self.pair_count[pair_key] = pair + total = self.ctx_total.get(ctx_key, 0) + 1 + self.ctx_total[ctx_key] = total + best_count = self.ctx_best_count.get(ctx_key, 0) + if pair > best_count: + self.ctx_best_count[ctx_key] = pair + self.ctx_best_token[ctx_key] = tok + self.current_word_tokens = [tok] + else: + self.current_word_tokens.append(tok) + return top_token, top_prob + + +def build_piece_luts(*, tokenizer_path, vocab_size): + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + pieces = [sp.id_to_piece(i) for i in range(sp.vocab_size())] + starts_new_word_lut = np.zeros(vocab_size, dtype=np.uint8) + for i, piece in enumerate(pieces): + starts_new_word_lut[i] = 1 if piece.startswith("▁") else 0 + boundary_lut = np.zeros(vocab_size, dtype=np.uint8) + bos_id = sp.bos_id() + if bos_id >= 0 and bos_id < vocab_size: + boundary_lut[bos_id] = 1 + for tok in range(min(sp.vocab_size(), vocab_size)): + if sp.is_byte(tok) and tok in WHITESPACE_BYTE_IDS: + boundary_lut[tok] = 1 + return sp, starts_new_word_lut, boundary_lut + + +def build_hints_for_targets( + *, target_token_ids_np, tokenizer_path, vocab_size, log0=print, + token_order=16, token_threshold=0.800, token_boost=2.625, + within_tau=999.0, within_boost=0.0, + word_order=4, word_normalize="strip_punct_lower", + word_tau=999.0, word_boost=0.0, + agree_add_boost=0.0, +): + """Single L->R pass. Returns dict with hint_ids, gate_mask, boost_per_pos. + + target_token_ids_np: np.uint16 array of realized targets (length = total_targets). + Output arrays are aligned to target_token_ids_np indexing. + + For each scored position t we pick at most one hint h_t: + - prefer the expert with highest expected gain = p_top * boost - log1p(p_top * (exp(boost)-1)) + - if multiple experts agree on the same h_t, additive boost agree_add_boost + - gate (don't tilt) when no expert clears its threshold + + The realized loss formula used by the caller: + ptl' = ptl - beta * 1[y == h_t] + log1p(q_t * (exp(beta) - 1)) when gate_mask == True + ptl' = ptl when gate_mask == False + """ + sp, starts_new_word_lut, boundary_lut = build_piece_luts( + tokenizer_path=tokenizer_path, vocab_size=vocab_size + ) + total = int(target_token_ids_np.size) + if total == 0: + return { + "hint_ids": np.zeros(0, dtype=np.int64), + "gate_mask": np.zeros(0, dtype=bool), + "boost": np.zeros(0, dtype=np.float32), + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + token_table_bits = suggest_table_bits(total, load_factor=0.55) + within_table_bits = suggest_table_bits(max(total // 2, 1), load_factor=0.60) + token_only = ( + float(within_boost) == 0.0 + and float(word_boost) == 0.0 + ) + online_lib = ensure_online_ngram_lib(log0) + ngram_state = OnlineNgramState( + lib=online_lib, + token_ctx_len=max(token_order - 1, 0), + token_table_bits=token_table_bits, + within_table_bits=within_table_bits, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + seed_prefix_token=int(target_token_ids_np[0]), + ) + if token_only: + token_top_tok, token_top_prob = ngram_state.process_chunk_token_only(target_token_ids_np) + token_gate = token_top_prob >= np.float32(token_threshold) + hint_ids = np.where(token_gate, token_top_tok.astype(np.int64), 0).astype(np.int64) + boost = np.where(token_gate, np.float32(token_boost), np.float32(0.0)).astype(np.float32) + log0( + f"ngram_tilt:hints total={total} gated={int(token_gate.sum())} " + f"token_gate={int(token_gate.sum())} within_gate=0 word_gate=0 agree2plus=0" + ) + return { + "hint_ids": hint_ids, + "gate_mask": token_gate, + "boost": boost, + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + word_state = WordStartState(sp=sp, order=word_order, normalize_mode=word_normalize) + + token_top_tok, token_top_prob, within_top_tok, within_top_prob, within_valid = ( + ngram_state.process_chunk(target_token_ids_np) + ) + word_top_tok, word_top_prob = word_state.process_chunk( + target_token_ids_np, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + ) + + def _expected_gain(p_top, boost): + # E[ -log p'(y) under -log p(y)] when y ~ p + # = p_top * boost - log1p(p_top * (exp(boost) - 1)) + # Maximizing this over experts => pick the most informative hint. + log_norm = np.log1p(p_top * (math.exp(boost) - 1.0)) + return p_top * boost - log_norm + + token_gate = token_top_prob >= np.float32(token_threshold) + within_gate = within_valid & (within_top_prob >= np.float32(within_tau)) + word_gate = word_top_prob >= np.float32(word_tau) + + token_gain = np.where(token_gate, _expected_gain(token_top_prob.astype(np.float64), token_boost), -np.inf) + within_gain = np.where(within_gate, _expected_gain(within_top_prob.astype(np.float64), within_boost), -np.inf) + word_gain = np.where(word_gate, _expected_gain(word_top_prob.astype(np.float64), word_boost), -np.inf) + + stack = np.stack([token_gain, within_gain, word_gain], axis=1) + best_idx = np.argmax(stack, axis=1) + best_gain = np.max(stack, axis=1) + any_gate = best_gain > -np.inf + + hint_ids = np.zeros(total, dtype=np.int64) + boost = np.zeros(total, dtype=np.float32) + base_boost_per_expert = np.array([token_boost, within_boost, word_boost], dtype=np.float32) + hint_per_expert = np.stack([ + token_top_tok.astype(np.int64), + within_top_tok.astype(np.int64), + word_top_tok.astype(np.int64), + ], axis=1) + + rows = np.arange(total) + hint_ids[any_gate] = hint_per_expert[rows[any_gate], best_idx[any_gate]] + boost[any_gate] = base_boost_per_expert[best_idx[any_gate]] + + # Agreement bonus: if 2+ experts agree on the same hint as best, add agree_add_boost + gate_mask_each = np.stack([token_gate, within_gate, word_gate], axis=1) + expert_hints = hint_per_expert.copy() + expert_hints[~gate_mask_each] = -1 + agreements = (expert_hints == hint_ids[:, None]).sum(axis=1) + agreement_extra = np.where(agreements >= 2, np.float32(agree_add_boost), np.float32(0.0)) + boost = (boost + agreement_extra).astype(np.float32) + + log0( + f"ngram_tilt:hints total={total} gated={int(any_gate.sum())} " + f"token_gate={int(token_gate.sum())} within_gate={int(within_gate.sum())} word_gate={int(word_gate.sum())} " + f"agree2plus={int((agreements >= 2).sum())}" + ) + + return { + "hint_ids": hint_ids, + "gate_mask": any_gate, + "boost": boost, + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + +def apply_tilt_to_ptl_torch( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """Closed-form tilt applied to per-token NLL. + + All tensors same shape [..., L]. + ptl_tilted = ptl - beta * 1[y == h] + log1p(q * (exp(beta) - 1)) if gate else ptl + """ + boost64 = boost.to(torch.float64) + q = log_q_hint.to(torch.float64).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float64) + log_Z = torch.log1p(q * (torch.expm1(boost64))) + ptl_tilted = ptl.to(torch.float64) - boost64 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl.to(torch.float64)).to(ptl.dtype) + + +def apply_tilt_to_ptl_torch_fast( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """fp32 variant of apply_tilt — cast removed where safe. + + BPB downstream accumulator is fp64, so per-token tilt computation in + fp32 has no impact on final precision. Saves ~10-15s per eval pass on + H100 (avoids fp64 ALU + double memory traffic). + """ + boost32 = boost.to(torch.float32) + q = log_q_hint.to(torch.float32).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float32) + log_Z = torch.log1p(q * (torch.expm1(boost32))) + ptl_f32 = ptl.to(torch.float32) + ptl_tilted = ptl_f32 - boost32 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl_f32).to(ptl.dtype) diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/prepare_caseops_data.py b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/prepare_caseops_data.py new file mode 100644 index 0000000000..5c3f13e69c --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/prepare_caseops_data.py @@ -0,0 +1,177 @@ +"""Prepare CaseOps-tokenized FineWeb shards + per-token byte sidecar. + +CaseOps (``lossless_caps_caseops_v1``) is a bijective, character-level text +transform that introduces four operator tokens in place of explicit +capitalization: TITLE, ALLCAPS, CAPNEXT, ESC. The transform is fully +reversible — no information is lost relative to the untransformed UTF-8 +text, so BPB stays computable on TRUE byte counts. + +Forward pipeline: + 1. Read the canonical FineWeb-10B doc stream (``docs_selected.jsonl`` + produced by ``data/download_hf_docs_and_tokenize.py`` in the root repo). + 2. Apply ``encode_lossless_caps_v2`` (the caseops_v1 alias) to each doc. + 3. Tokenize with the shipped SP model + ``tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model`` + (reserves TITLE/ALLCAPS/CAPNEXT/ESC + sentinel as user_defined_symbols). + 4. Write uint16 train/val shards (``fineweb_{train,val}_XXXXXX.bin``). + 5. For the VAL stream only, emit per-token byte sidecar shards + (``fineweb_val_bytes_XXXXXX.bin``, uint16 parallel arrays) that record + each token's ORIGINAL pre-transform UTF-8 byte count. BPB is computed + from these canonical bytes so the score is on the untransformed text + (not the transformed representation). + +Output layout — matches what ``train_gpt.py`` expects under +``DATA_DIR=./data`` with ``CASEOPS_ENABLED=1``: + + data/datasets/fineweb10B_sp8192_caseops/datasets/ + tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/ + fineweb_train_000000.bin + fineweb_train_000001.bin + ... + fineweb_val_000000.bin + fineweb_val_bytes_000000.bin + +Usage: + + python3 prepare_caseops_data.py \\ + --docs ./fineweb10B_raw/docs_selected.jsonl \\ + --out ./data/datasets/fineweb10B_sp8192_caseops/datasets \\ + --sp ./tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + +Requirements: sentencepiece, numpy. CPU-only. Runs once; reused across seeds. +""" +from __future__ import annotations + +import argparse +import json +import pathlib +import struct +import sys + +import numpy as np +import sentencepiece as spm + +# Local import — lossless_caps.py ships next to this script. +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) +from lossless_caps import ( # noqa: E402 + LOSSLESS_CAPS_CASEOPS_V1, + encode_lossless_caps_v2, + surface_piece_original_byte_counts, +) + + +SHARD_MAGIC = 20240520 +SHARD_VERSION = 1 +SHARD_TOKENS = 10_000_000 # tokens per shard — matches the main pipeline +BOS_ID = 1 # SP model's control token; train_gpt.py:_find_docs requires BOS per doc + + +def _write_shard(out_path: pathlib.Path, arr: np.ndarray) -> None: + """Write a uint16 shard in the standard header-prefixed format.""" + assert arr.dtype == np.uint16 + header = np.zeros(256, dtype=np.int32) + header[0] = SHARD_MAGIC + header[1] = SHARD_VERSION + header[2] = int(arr.size) + with out_path.open("wb") as fh: + fh.write(header.tobytes()) + fh.write(arr.tobytes()) + + +def _iter_docs(docs_path: pathlib.Path): + """Yield doc strings from a jsonl file (one json object per line).""" + with docs_path.open("r", encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if not line: + continue + obj = json.loads(line) + # Support both {"text": ...} and raw strings. + yield obj["text"] if isinstance(obj, dict) else obj + + +def _token_original_byte_counts( + sp: spm.SentencePieceProcessor, + original_text: str, + transformed_text: str, +) -> np.ndarray: + """Per-token canonical (pre-transform) UTF-8 byte counts. + + Delegates to ``surface_piece_original_byte_counts`` in ``lossless_caps.py`` + — the canonical exporter used by the PR #1729 / HF-hosted CaseOps dataset. + Operator pieces (U+E001..U+E004) contribute 0 original bytes; letter pieces + contribute their pre-transform UTF-8 byte count. + """ + proto = sp.encode_as_immutable_proto(transformed_text) + byte_counts = surface_piece_original_byte_counts( + (piece.surface for piece in proto.pieces), + text_transform_name=LOSSLESS_CAPS_CASEOPS_V1, + ) + return np.asarray(list(byte_counts), dtype=np.uint16) + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--docs", required=True, type=pathlib.Path, help="Path to docs_selected.jsonl") + ap.add_argument("--out", required=True, type=pathlib.Path, help="Output datasets dir") + ap.add_argument("--sp", required=True, type=pathlib.Path, help="Path to CaseOps SP model") + ap.add_argument("--val-docs", type=int, default=10_000, help="Validation docs count") + args = ap.parse_args() + + sp = spm.SentencePieceProcessor(model_file=str(args.sp)) + print(f"loaded sp: vocab={sp.vocab_size()}", flush=True) + + train_out = args.out / "datasets" / "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved" + train_out.mkdir(parents=True, exist_ok=True) + + val_buf_tokens: list[int] = [] + val_buf_bytes: list[int] = [] + train_buf: list[int] = [] + val_written = 0 + train_written = 0 + n_docs = 0 + + for text in _iter_docs(args.docs): + transformed = encode_lossless_caps_v2(text) + token_ids = [BOS_ID] + sp.encode(transformed, out_type=int) + if n_docs < args.val_docs: + # Validation doc — also compute byte sidecar + byte_counts = _token_original_byte_counts(sp, text, transformed) + val_buf_tokens.extend(token_ids) + val_buf_bytes.append(0) # BOS contributes 0 original bytes + val_buf_bytes.extend(int(b) for b in byte_counts) + if len(val_buf_tokens) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens[:SHARD_TOKENS], dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes[:SHARD_TOKENS], dtype=np.uint16)) + val_buf_tokens = val_buf_tokens[SHARD_TOKENS:] + val_buf_bytes = val_buf_bytes[SHARD_TOKENS:] + val_written += 1 + else: + train_buf.extend(token_ids) + if len(train_buf) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf[:SHARD_TOKENS], dtype=np.uint16)) + train_buf = train_buf[SHARD_TOKENS:] + train_written += 1 + n_docs += 1 + if n_docs % 10_000 == 0: + print(f" processed {n_docs} docs train_shards={train_written} val_shards={val_written}", flush=True) + + # Flush tail buffers into final (possibly short) shards. + if val_buf_tokens: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens, dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes, dtype=np.uint16)) + if train_buf: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf, dtype=np.uint16)) + + print(f"done. docs={n_docs} train_shards={train_written + (1 if train_buf else 0)} val_shards={val_written + (1 if val_buf_tokens else 0)}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/requirements.txt b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/requirements.txt new file mode 100644 index 0000000000..b6c55e13aa --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/requirements.txt @@ -0,0 +1,13 @@ +# Python deps. Install with: pip install -r requirements.txt +torch==2.9.1+cu128 +sentencepiece +brotli +huggingface_hub +numpy +python-minifier + +# FlashAttention 3 must be installed separately (not on PyPI): +# pip install --no-deps flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/ + +# System dep (apt): lrzip (used by per-group compressor) +# apt-get install -y lrzip diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/run.sh b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/run.sh new file mode 100755 index 0000000000..42bd8e9f1d --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/run.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +set -euo pipefail + +SEED="${SEED:-42}" + +NGRAM_HINT_PRECOMPUTE_OUTSIDE="${NGRAM_HINT_PRECOMPUTE_OUTSIDE:-0}" \ +SEED="$SEED" \ +DATA_PATH="${DATA_PATH:-./data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved}" \ +TOKENIZER_PATH="${TOKENIZER_PATH:-./data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model}" \ +CASEOPS_ENABLED=1 VOCAB_SIZE=8192 ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=600 \ +TTT_ENABLED=1 PHASED_TTT_ENABLED=1 PHASED_TTT_NUM_PHASES=1 PHASED_TTT_PREFIX_DOCS=1000 \ +TTT_LORA_RANK=80 TTT_MASK=no_qv TTT_Q_LORA=0 TTT_V_LORA=0 \ +TTT_LOCAL_LR_MULT=0.75 EVAL_SEQ_LEN=2560 TTT_EVAL_SEQ_LEN=2560 \ +QK_GAIN_INIT=5.25 \ +MATRIX_LR=0.026 MIN_LR=0.1 EMBED_BITS=7 GRAD_CLIP_NORM=0.3 \ +MATRIX_CLIP_SIGMAS=12.85 ATTN_CLIP_SIGMAS=13.0 MLP_CLIP_SIGMAS=11.5 EMBED_CLIP_SIGMAS=14.0 \ +FUSED_CE_ENABLED=1 SMEAR_GATE_ENABLED=1 GATE_WINDOW=12 \ +SPARSE_ATTN_GATE_ENABLED=1 LQER_ENABLED=1 LQER_RANK=4 LQER_TOP_K=1 \ +LQER_GROUP_SIZE=64 LQER_ASYM_ENABLED=1 LQER_ASYM_GROUP=64 \ +AWQ_LITE_ENABLED=1 ASYM_LOGIT_RESCALE=1 NGRAM_TILT_ENABLED=1 \ +TOKEN_ORDER=16 TOKEN_THRESHOLD=0.800 TOKEN_BOOST=2.625 \ +WITHIN_TAU=999 WITHIN_BOOST=0 WORD_TAU=999 WORD_BOOST=0 AGREE_ADD_BOOST=0 \ +GATED_XSA=1 SKYLIGHT_MUON=0 \ +GPTQ_RESERVE_SECONDS=4.0 GPTQ_CALIBRATION_BATCHES=16 \ +COMPRESSOR=pergroup \ +torchrun --standalone --nproc_per_node="${NPROC_PER_NODE:-8}" train_gpt.py diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/submission.json b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/submission.json new file mode 100644 index 0000000000..cde406d9e0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/submission.json @@ -0,0 +1,72 @@ +{ + "author": "Simon Marcus", + "github_id": "simon-marcus", + "name": "Gated XSA + LQER top-1 + strict token-only n-gram TTT", + "blurb": "PR #1967 lineage with a zero-initialized per-head Gated XSA training-time attention superset, LQER_TOP_K=1 for artifact headroom, and strict in-timer token-only closed-form n-gram tilt plus one-phase score-first TTT over a 1000-document prefix. 3-seed mean: 1.04722074 BPB.", + "date": "2026-04-30", + "track": "10min_16mb", + "val_loss": 2.29170716, + "val_bpb": 1.04722074, + "val_loss_std": 0.00229377, + "val_bpb_std": 0.00104816, + "seeds": [ + 42, + 1337, + 2026 + ], + "seed_results": { + "42": { + "val_loss": 2.28940177, + "val_bpb": 1.04616727, + "prequant_val_bpb": 1.04930686, + "quantized_val_bpb": 1.05773513, + "artifact_bytes": 15995574, + "steps": 4914, + "train_time_ms": 596127, + "eval_time_ms": 471457 + }, + "1337": { + "val_loss": 2.29398913, + "val_bpb": 1.04826351, + "prequant_val_bpb": 1.05124428, + "quantized_val_bpb": 1.05990331, + "artifact_bytes": 15992746, + "steps": 4926, + "train_time_ms": 596167, + "eval_time_ms": 465480 + }, + "2026": { + "val_loss": 2.29173058, + "val_bpb": 1.04723144, + "prequant_val_bpb": 1.0502993, + "quantized_val_bpb": 1.05886641, + "artifact_bytes": 15996490, + "steps": 4916, + "train_time_ms": 596080, + "eval_time_ms": 463281 + } + }, + "artifact_bytes_mean": 15994936.67, + "artifact_bytes_max": 15996490, + "bytes_total": 15996490, + "train_steps_mean": 4918.67, + "train_time_ms_mean": 596124.67, + "eval_time_ms_mean": 466739.33, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "cuda_version": "12.8", + "flash_attn_version": "FA3 (cu128_torch291 wheel)", + "technique_summary": "Gated XSA per-head tanh gate, LQER_TOP_K=1, PR #1514-style token-only PR #1967 closed-form n-gram tilt with NGRAM_HINT_PRECOMPUTE_OUTSIDE=0, and one-phase score-first phased TTT with PHASED_TTT_PREFIX_DOCS=1000.", + "ngram_gate_counts": { + "token_gate": 628130, + "within_gate": 0, + "word_gate": 0, + "agree2plus": 0 + }, + "comparison_baseline_pr": "#1855", + "comparison_baseline_bpb": 1.06107587, + "delta_vs_merged_leader_bpb": -0.01385513, + "delta_vs_merged_leader_nats": -0.00960364, + "record_threshold_nats": 0.005, + "record_threshold_multiple": 1.92 +} diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/controller.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/controller.log new file mode 100644 index 0000000000..c30b0934e3 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/controller.log @@ -0,0 +1,4 @@ +=== fast token-only seed 1337 start 2026-05-01T00:48:12+00:00 === +=== fast token-only seed 1337 end 2026-05-01T00:58:50+00:00 === +=== fast token-only seed 2026 start 2026-05-01T00:58:50+00:00 === +=== fast token-only seed 2026 end 2026-05-01T01:09:28+00:00 === diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed1337/run_tokenonly_fast.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed1337/run_tokenonly_fast.log new file mode 100644 index 0000000000..b9718ebf74 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed1337/run_tokenonly_fast.log @@ -0,0 +1,432 @@ +W0501 00:48:13.334000 48378 torch/distributed/run.py:803] +W0501 00:48:13.334000 48378 torch/distributed/run.py:803] ***************************************** +W0501 00:48:13.334000 48378 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 00:48:13.334000 48378 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed1337 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed1337/tokenonly_fast_p1000_n1_s1337.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed1337/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed1337/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: tokenonly_fast_p1000_n1_s1337 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.0 + within_tau: 999.0 + word_boost: 0.0 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 999.0 + world_size: 8 + xsa_last_n: 11 +train_shards: 0 +val_tokens: 47851520 +TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +ttt_weight_decay: 2.0 +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 21.1s +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 20.9s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (108.2s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=628130 token_gate=628130 within_gate=0 word_gate=0 agree2plus=0 +ngram_tilt:precompute_done elapsed=11.83s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b781/782 bl:2.1212 bb:1.0379 rl:2.1212 rb:1.0379 dl:17258-30330 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:254.4s +tttg: c1/154 lr:0.001000 t:0.5s +tttg: c2/154 lr:0.001000 t:0.5s +tttg: c3/154 lr:0.001000 t:0.6s +tttg: c4/154 lr:0.000999 t:0.7s +tttg: c5/154 lr:0.000998 t:0.8s +tttg: c6/154 lr:0.000997 t:0.8s +tttg: c7/154 lr:0.000996 t:0.9s +tttg: c8/154 lr:0.000995 t:1.0s +tttg: c9/154 lr:0.000993 t:1.1s +tttg: c10/154 lr:0.000991 t:1.2s +tttg: c11/154 lr:0.000989 t:1.2s +tttg: c12/154 lr:0.000987 t:1.3s +tttg: c13/154 lr:0.000985 t:1.4s +tttg: c14/154 lr:0.000982 t:1.5s +tttg: c15/154 lr:0.000979 t:1.5s +tttg: c16/154 lr:0.000976 t:1.6s +tttg: c17/154 lr:0.000973 t:1.7s +tttg: c18/154 lr:0.000970 t:1.8s +tttg: c19/154 lr:0.000966 t:1.8s +tttg: c20/154 lr:0.000962 t:1.9s +tttg: c21/154 lr:0.000958 t:2.0s +tttg: c22/154 lr:0.000954 t:2.1s +tttg: c23/154 lr:0.000950 t:2.2s +tttg: c24/154 lr:0.000945 t:2.2s +tttg: c25/154 lr:0.000941 t:2.3s +tttg: c26/154 lr:0.000936 t:2.4s +tttg: c27/154 lr:0.000930 t:2.5s +tttg: c28/154 lr:0.000925 t:2.5s +tttg: c29/154 lr:0.000920 t:2.6s +tttg: c30/154 lr:0.000914 t:2.7s +tttg: c31/154 lr:0.000908 t:2.8s +tttg: c32/154 lr:0.000902 t:2.9s +tttg: c33/154 lr:0.000896 t:2.9s +tttg: c34/154 lr:0.000890 t:3.0s +tttg: c35/154 lr:0.000883 t:3.1s +tttg: c36/154 lr:0.000876 t:3.2s +tttg: c37/154 lr:0.000870 t:3.2s +tttg: c38/154 lr:0.000863 t:3.3s +tttg: c39/154 lr:0.000855 t:3.4s +tttg: c40/154 lr:0.000848 t:3.5s +tttg: c41/154 lr:0.000841 t:3.5s +tttg: c42/154 lr:0.000833 t:3.6s +tttg: c43/154 lr:0.000825 t:3.7s +tttg: c44/154 lr:0.000817 t:3.8s +tttg: c45/154 lr:0.000809 t:3.9s +tttg: c46/154 lr:0.000801 t:3.9s +tttg: c47/154 lr:0.000793 t:4.0s +tttg: c48/154 lr:0.000785 t:4.1s +tttg: c49/154 lr:0.000776 t:4.2s +tttg: c50/154 lr:0.000768 t:4.3s +tttg: c51/154 lr:0.000759 t:4.3s +tttg: c52/154 lr:0.000750 t:4.4s +tttg: c53/154 lr:0.000741 t:4.5s +tttg: c54/154 lr:0.000732 t:4.6s +tttg: c55/154 lr:0.000723 t:4.6s +tttg: c56/154 lr:0.000714 t:4.7s +tttg: c57/154 lr:0.000704 t:4.8s +tttg: c58/154 lr:0.000695 t:4.9s +tttg: c59/154 lr:0.000685 t:4.9s +tttg: c60/154 lr:0.000676 t:5.0s +tttg: c61/154 lr:0.000666 t:5.1s +tttg: c62/154 lr:0.000656 t:5.2s +tttg: c63/154 lr:0.000647 t:5.3s +tttg: c64/154 lr:0.000637 t:5.3s +tttg: c65/154 lr:0.000627 t:5.4s +tttg: c66/154 lr:0.000617 t:5.5s +tttg: c67/154 lr:0.000607 t:5.6s +tttg: c68/154 lr:0.000597 t:5.6s +tttg: c69/154 lr:0.000587 t:5.7s +tttg: c70/154 lr:0.000577 t:5.8s +tttg: c71/154 lr:0.000567 t:5.9s +tttg: c72/154 lr:0.000556 t:5.9s +tttg: c73/154 lr:0.000546 t:6.0s +tttg: c74/154 lr:0.000536 t:6.1s +tttg: c75/154 lr:0.000526 t:6.2s +tttg: c76/154 lr:0.000515 t:6.3s +tttg: c77/154 lr:0.000505 t:6.3s +tttg: c78/154 lr:0.000495 t:6.4s +tttg: c79/154 lr:0.000485 t:6.5s +tttg: c80/154 lr:0.000474 t:6.6s +tttg: c81/154 lr:0.000464 t:6.6s +tttg: c82/154 lr:0.000454 t:6.7s +tttg: c83/154 lr:0.000444 t:6.8s +tttg: c84/154 lr:0.000433 t:6.9s +tttg: c85/154 lr:0.000423 t:7.0s +tttg: c86/154 lr:0.000413 t:7.0s +tttg: c87/154 lr:0.000403 t:7.1s +tttg: c88/154 lr:0.000393 t:7.2s +tttg: c89/154 lr:0.000383 t:7.3s +tttg: c90/154 lr:0.000373 t:7.4s +tttg: c91/154 lr:0.000363 t:7.4s +tttg: c92/154 lr:0.000353 t:7.5s +tttg: c93/154 lr:0.000344 t:7.6s +tttg: c94/154 lr:0.000334 t:7.7s +tttg: c95/154 lr:0.000324 t:7.7s +tttg: c96/154 lr:0.000315 t:7.8s +tttg: c97/154 lr:0.000305 t:7.9s +tttg: c98/154 lr:0.000296 t:8.0s +tttg: c99/154 lr:0.000286 t:8.0s +tttg: c100/154 lr:0.000277 t:8.1s +tttg: c101/154 lr:0.000268 t:8.2s +tttg: c102/154 lr:0.000259 t:8.3s +tttg: c103/154 lr:0.000250 t:8.3s +tttg: c104/154 lr:0.000241 t:8.4s +tttg: c105/154 lr:0.000232 t:8.5s +tttg: c106/154 lr:0.000224 t:8.6s +tttg: c107/154 lr:0.000215 t:8.7s +tttg: c108/154 lr:0.000207 t:8.7s +tttg: c109/154 lr:0.000199 t:8.8s +tttg: c110/154 lr:0.000191 t:8.9s +tttg: c111/154 lr:0.000183 t:9.0s +tttg: c112/154 lr:0.000175 t:9.0s +tttg: c113/154 lr:0.000167 t:9.1s +tttg: c114/154 lr:0.000159 t:9.2s +tttg: c115/154 lr:0.000152 t:9.3s +tttg: c116/154 lr:0.000145 t:9.4s +tttg: c117/154 lr:0.000137 t:9.4s +tttg: c118/154 lr:0.000130 t:9.5s +tttg: c119/154 lr:0.000124 t:9.6s +tttg: c120/154 lr:0.000117 t:9.7s +tttg: c121/154 lr:0.000110 t:9.7s +tttg: c122/154 lr:0.000104 t:9.8s +tttg: c123/154 lr:0.000098 t:9.9s +tttg: c124/154 lr:0.000092 t:10.0s +tttg: c125/154 lr:0.000086 t:10.1s +tttg: c126/154 lr:0.000080 t:10.1s +tttg: c127/154 lr:0.000075 t:10.2s +tttg: c128/154 lr:0.000070 t:10.3s +tttg: c129/154 lr:0.000064 t:10.4s +tttg: c130/154 lr:0.000059 t:10.5s +tttg: c131/154 lr:0.000055 t:10.5s +tttg: c132/154 lr:0.000050 t:10.6s +tttg: c133/154 lr:0.000046 t:10.7s +tttg: c134/154 lr:0.000042 t:10.8s +tttg: c135/154 lr:0.000038 t:10.9s +tttg: c136/154 lr:0.000034 t:10.9s +tttg: c137/154 lr:0.000030 t:11.0s +tttg: c138/154 lr:0.000027 t:11.1s +tttg: c139/154 lr:0.000024 t:11.2s +tttg: c140/154 lr:0.000021 t:11.3s +tttg: c141/154 lr:0.000018 t:11.3s +tttg: c142/154 lr:0.000015 t:11.4s +tttg: c143/154 lr:0.000013 t:11.5s +tttg: c144/154 lr:0.000011 t:11.6s +tttg: c145/154 lr:0.000009 t:11.6s +tttg: c146/154 lr:0.000007 t:11.7s +tttg: c147/154 lr:0.000005 t:11.8s +tttg: c148/154 lr:0.000004 t:11.9s +tttg: c149/154 lr:0.000003 t:11.9s +tttg: c150/154 lr:0.000002 t:12.0s +tttg: c151/154 lr:0.000001 t:12.1s +tttg: c152/154 lr:0.000000 t:12.2s +tttg: c153/154 lr:0.000000 t:12.3s +ttpr: phase:1/1 t:268.5s +ttp: b756/782 bl:2.3070 bb:1.0268 rl:2.1467 rb:1.0362 dl:3466-3549 gd:1 +ttp: b750/782 bl:2.3619 bb:1.0612 rl:2.1701 rb:1.0391 dl:3090-3149 gd:1 +ttp: b746/782 bl:2.3803 bb:1.0488 rl:2.1895 rb:1.0401 dl:2884-2943 gd:1 +ttp: b741/782 bl:2.2866 bb:1.0255 rl:2.1972 rb:1.0389 dl:2686-2730 gd:1 +ttp: b738/782 bl:2.2883 bb:1.0362 rl:2.2036 rb:1.0387 dl:2583-2618 gd:1 +ttp: b734/782 bl:2.2335 bb:1.0161 rl:2.2055 rb:1.0372 dl:2469-2495 gd:1 +ttp: b731/782 bl:2.3125 bb:1.0314 rl:2.2116 rb:1.0368 dl:2377-2414 gd:1 +ttp: b724/782 bl:2.2978 bb:1.0492 rl:2.2160 rb:1.0375 dl:2203-2231 gd:1 +ttp: b720/782 bl:2.3255 bb:1.0518 rl:2.2210 rb:1.0382 dl:2125-2144 gd:1 +ttp: b719/782 bl:2.2971 bb:1.0344 rl:2.2244 rb:1.0380 dl:2106-2125 gd:1 +ttp: b712/782 bl:2.3092 bb:1.0473 rl:2.2277 rb:1.0384 dl:1984-2002 gd:1 +ttp: b710/782 bl:2.1997 bb:1.0298 rl:2.2267 rb:1.0381 dl:1952-1966 gd:1 +ttp: b704/782 bl:2.2564 bb:1.0252 rl:2.2277 rb:1.0376 dl:1872-1885 gd:1 +ttp: b700/782 bl:2.2500 bb:1.0047 rl:2.2284 rb:1.0365 dl:1824-1834 gd:1 +ttp: b696/782 bl:2.2832 bb:1.0398 rl:2.2301 rb:1.0366 dl:1779-1790 gd:1 +ttp: b693/782 bl:2.3133 bb:1.0392 rl:2.2326 rb:1.0367 dl:1746-1757 gd:1 +ttp: b687/782 bl:2.2864 bb:1.0441 rl:2.2341 rb:1.0369 dl:1685-1696 gd:1 +ttp: b680/782 bl:2.2590 bb:1.0173 rl:2.2347 rb:1.0364 dl:1618-1628 gd:1 +ttp: b674/782 bl:2.3839 bb:1.0797 rl:2.2384 rb:1.0374 dl:1571-1578 gd:1 +ttp: b667/782 bl:2.3391 bb:1.0574 rl:2.2407 rb:1.0379 dl:1514-1521 gd:1 +ttp: b660/782 bl:2.3453 bb:1.0368 rl:2.2430 rb:1.0379 dl:1466-1474 gd:1 +ttp: b653/782 bl:2.2665 bb:1.0275 rl:2.2434 rb:1.0377 dl:1419-1425 gd:1 +ttp: b646/782 bl:2.2444 bb:1.0378 rl:2.2435 rb:1.0377 dl:1375-1382 gd:1 +ttp: b639/782 bl:2.2815 bb:1.0188 rl:2.2442 rb:1.0373 dl:1331-1337 gd:1 +ttp: b632/782 bl:2.3216 bb:1.0214 rl:2.2456 rb:1.0370 dl:1290-1297 gd:1 +ttp: b626/782 bl:2.2863 bb:1.0159 rl:2.2462 rb:1.0366 dl:1260-1265 gd:1 +ttp: b619/782 bl:2.3004 bb:1.0490 rl:2.2471 rb:1.0368 dl:1221-1226 gd:1 +ttp: b612/782 bl:2.2053 bb:0.9991 rl:2.2465 rb:1.0362 dl:1186-1190 gd:1 +ttp: b605/782 bl:2.2218 bb:1.0132 rl:2.2461 rb:1.0359 dl:1154-1159 gd:1 +ttp: b598/782 bl:2.3301 bb:1.0539 rl:2.2473 rb:1.0362 dl:1124-1129 gd:1 +ttp: b591/782 bl:2.2756 bb:1.0184 rl:2.2477 rb:1.0359 dl:1093-1098 gd:1 +ttp: b584/782 bl:2.2636 bb:1.0234 rl:2.2479 rb:1.0357 dl:1064-1069 gd:1 +ttp: b576/782 bl:2.3540 bb:1.0828 rl:2.2492 rb:1.0363 dl:1033-1037 gd:1 +ttp: b569/782 bl:2.2825 bb:1.0321 rl:2.2497 rb:1.0363 dl:1007-1010 gd:1 +ttp: b560/782 bl:2.2343 bb:0.9943 rl:2.2495 rb:1.0358 dl:975-979 gd:1 +ttp: b554/782 bl:2.4063 bb:1.0833 rl:2.2512 rb:1.0363 dl:955-959 gd:1 +ttp: b547/782 bl:2.3010 bb:1.0342 rl:2.2518 rb:1.0363 dl:934-937 gd:1 +ttp: b540/782 bl:2.3248 bb:1.0619 rl:2.2526 rb:1.0366 dl:912-915 gd:1 +ttp: b532/782 bl:2.3642 bb:1.0559 rl:2.2537 rb:1.0368 dl:887-889 gd:1 +ttp: b526/782 bl:2.2991 bb:1.0133 rl:2.2541 rb:1.0365 dl:869-872 gd:1 +ttp: b519/782 bl:2.2687 bb:1.0292 rl:2.2543 rb:1.0365 dl:850-852 gd:1 +ttp: b512/782 bl:2.2723 bb:1.0494 rl:2.2544 rb:1.0366 dl:829-832 gd:1 +ttp: b505/782 bl:2.3032 bb:1.0533 rl:2.2549 rb:1.0367 dl:809-812 gd:1 +ttp: b499/782 bl:2.3068 bb:1.0416 rl:2.2553 rb:1.0368 dl:794-796 gd:1 +ttp: b491/782 bl:2.2537 bb:1.0166 rl:2.2553 rb:1.0366 dl:773-776 gd:1 +ttp: b483/782 bl:2.2268 bb:1.0160 rl:2.2551 rb:1.0364 dl:754-756 gd:1 +ttp: b477/782 bl:2.3740 bb:1.0224 rl:2.2560 rb:1.0363 dl:740-742 gd:1 +ttp: b469/782 bl:2.3010 bb:1.0119 rl:2.2564 rb:1.0361 dl:721-724 gd:1 +ttp: b458/782 bl:2.1756 bb:1.0090 rl:2.2558 rb:1.0359 dl:697-700 gd:1 +ttp: b451/782 bl:2.3795 bb:1.0767 rl:2.2566 rb:1.0362 dl:682-685 gd:1 +ttp: b443/782 bl:2.2087 bb:1.0392 rl:2.2563 rb:1.0362 dl:666-668 gd:1 +ttp: b437/782 bl:2.2644 bb:1.0419 rl:2.2564 rb:1.0363 dl:653-655 gd:1 +ttp: b430/782 bl:2.3561 bb:1.0303 rl:2.2570 rb:1.0362 dl:640-642 gd:1 +ttp: b422/782 bl:2.2796 bb:1.0757 rl:2.2572 rb:1.0365 dl:624-626 gd:1 +ttp: b410/782 bl:2.2966 bb:1.0084 rl:2.2574 rb:1.0363 dl:601-603 gd:1 +ttp: b402/782 bl:2.2171 bb:0.9866 rl:2.2572 rb:1.0360 dl:586-588 gd:1 +ttp: b394/782 bl:2.2214 bb:0.9780 rl:2.2570 rb:1.0356 dl:571-573 gd:1 +ttp: b389/782 bl:2.2601 bb:1.0704 rl:2.2570 rb:1.0358 dl:563-564 gd:1 +ttp: b382/782 bl:2.2663 bb:1.0708 rl:2.2570 rb:1.0360 dl:550-552 gd:1 +ttp: b373/782 bl:2.3840 bb:1.0878 rl:2.2577 rb:1.0363 dl:535-537 gd:1 +ttp: b365/782 bl:2.3063 bb:1.0248 rl:2.2579 rb:1.0362 dl:522-524 gd:1 +ttp: b357/782 bl:2.3044 bb:1.0565 rl:2.2582 rb:1.0363 dl:508-510 gd:1 +ttp: b349/782 bl:2.3191 bb:1.0114 rl:2.2585 rb:1.0362 dl:495-496 gd:1 +ttp: b341/782 bl:2.2770 bb:1.0665 rl:2.2585 rb:1.0363 dl:483-485 gd:1 +ttp: b333/782 bl:2.3915 bb:1.0644 rl:2.2591 rb:1.0365 dl:471-472 gd:1 +ttp: b324/782 bl:2.2898 bb:1.0705 rl:2.2593 rb:1.0366 dl:458-459 gd:1 +ttp: b316/782 bl:2.3316 bb:1.0637 rl:2.2596 rb:1.0367 dl:445-446 gd:1 +ttp: b309/782 bl:2.3844 bb:1.0941 rl:2.2601 rb:1.0370 dl:435-437 gd:1 +ttp: b299/782 bl:2.2902 bb:1.0876 rl:2.2602 rb:1.0372 dl:420-421 gd:1 +ttp: b290/782 bl:2.3090 bb:1.0577 rl:2.2604 rb:1.0372 dl:406-407 gd:1 +ttp: b283/782 bl:2.3401 bb:1.1127 rl:2.2607 rb:1.0375 dl:396-398 gd:1 +ttp: b274/782 bl:2.2777 bb:1.0588 rl:2.2608 rb:1.0376 dl:384-385 gd:1 +ttp: b272/782 bl:2.3512 bb:1.0860 rl:2.2611 rb:1.0378 dl:382-383 gd:1 +ttp: b264/782 bl:2.3941 bb:1.0910 rl:2.2615 rb:1.0379 dl:371-372 gd:1 +ttp: b256/782 bl:2.5066 bb:1.1064 rl:2.2623 rb:1.0382 dl:361-362 gd:1 +ttp: b247/782 bl:2.3258 bb:1.0826 rl:2.2626 rb:1.0383 dl:350-351 gd:1 +ttp: b240/782 bl:2.2686 bb:1.0413 rl:2.2626 rb:1.0383 dl:341-342 gd:1 +ttp: b231/782 bl:2.2724 bb:1.0675 rl:2.2626 rb:1.0384 dl:330-331 gd:1 +ttp: b225/782 bl:2.4085 bb:1.1027 rl:2.2630 rb:1.0386 dl:323-324 gd:1 +ttp: b217/782 bl:2.3448 bb:1.1195 rl:2.2633 rb:1.0388 dl:314-315 gd:1 +ttp: b209/782 bl:2.3903 bb:1.1180 rl:2.2636 rb:1.0390 dl:305-306 gd:1 +ttp: b192/782 bl:2.3382 bb:1.1355 rl:2.2638 rb:1.0393 dl:286-288 gd:1 +ttp: b184/782 bl:2.3618 bb:1.1134 rl:2.2640 rb:1.0395 dl:278-279 gd:1 +ttp: b176/782 bl:2.2846 bb:1.1096 rl:2.2641 rb:1.0396 dl:270-271 gd:1 +ttp: b167/782 bl:2.4852 bb:1.1088 rl:2.2646 rb:1.0398 dl:262-263 gd:1 +ttp: b159/782 bl:2.4558 bb:1.1393 rl:2.2650 rb:1.0400 dl:254-255 gd:1 +ttp: b152/782 bl:2.3551 bb:1.1280 rl:2.2652 rb:1.0402 dl:247-248 gd:1 +ttp: b144/782 bl:2.3299 bb:1.0951 rl:2.2654 rb:1.0403 dl:239-240 gd:1 +ttp: b137/782 bl:2.3795 bb:1.1368 rl:2.2656 rb:1.0405 dl:233-233 gd:1 +ttp: b128/782 bl:2.3536 bb:1.1375 rl:2.2658 rb:1.0407 dl:224-225 gd:1 +ttp: b120/782 bl:2.3622 bb:1.0976 rl:2.2660 rb:1.0408 dl:217-218 gd:1 +ttp: b111/782 bl:2.3764 bb:1.1587 rl:2.2662 rb:1.0410 dl:208-210 gd:1 +ttp: b104/782 bl:2.4678 bb:1.1650 rl:2.2665 rb:1.0412 dl:202-203 gd:1 +ttp: b96/782 bl:2.4445 bb:1.1867 rl:2.2668 rb:1.0414 dl:195-196 gd:1 +ttp: b88/782 bl:2.4452 bb:1.1669 rl:2.2671 rb:1.0416 dl:188-189 gd:1 +ttp: b80/782 bl:2.4225 bb:1.1291 rl:2.2674 rb:1.0418 dl:181-182 gd:1 +ttp: b73/782 bl:2.5159 bb:1.2349 rl:2.2678 rb:1.0420 dl:174-175 gd:1 +ttp: b66/782 bl:2.6078 bb:1.2204 rl:2.2683 rb:1.0423 dl:169-169 gd:1 +ttp: b57/782 bl:2.4350 bb:1.1466 rl:2.2685 rb:1.0424 dl:160-161 gd:1 +ttp: b49/782 bl:2.4249 bb:1.1531 rl:2.2687 rb:1.0426 dl:152-153 gd:1 +ttp: b40/782 bl:2.4572 bb:1.1383 rl:2.2689 rb:1.0427 dl:143-144 gd:1 +ttp: b32/782 bl:2.5693 bb:1.1980 rl:2.2693 rb:1.0429 dl:135-136 gd:1 +ttp: b24/782 bl:2.4248 bb:1.1436 rl:2.2695 rb:1.0430 dl:127-128 gd:1 +ttp: b15/782 bl:2.6236 bb:1.2184 rl:2.2698 rb:1.0432 dl:115-117 gd:1 +ttp: b7/782 bl:2.7243 bb:1.2261 rl:2.2702 rb:1.0433 dl:101-103 gd:1 +quantized_ttt_phased val_loss:2.29398913 val_bpb:1.04826351 eval_time:465480ms +total_eval_time:465.5s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed2026/run_tokenonly_fast.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed2026/run_tokenonly_fast.log new file mode 100644 index 0000000000..3b2d01ee9e --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed2026/run_tokenonly_fast.log @@ -0,0 +1,432 @@ +W0501 00:58:51.286000 54573 torch/distributed/run.py:803] +W0501 00:58:51.286000 54573 torch/distributed/run.py:803] ***************************************** +W0501 00:58:51.286000 54573 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 00:58:51.286000 54573 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed2026 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed2026/tokenonly_fast_p1000_n1_s2026.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed2026/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed2026/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: tokenonly_fast_p1000_n1_s2026 + scalar_lr: 0.02 + seed: 2026 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.0 + within_tau: 999.0 + word_boost: 0.0 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 999.0 + world_size: 8 + xsa_last_n: 11 +train_shards: 0 +val_tokens: 47851520 +TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +ttt_weight_decay: 2.0 +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 21.1s +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 21.3s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (110.1s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=628130 token_gate=628130 within_gate=0 word_gate=0 agree2plus=0 +ngram_tilt:precompute_done elapsed=11.88s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b781/782 bl:2.1203 bb:1.0374 rl:2.1203 rb:1.0374 dl:17258-30330 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:254.8s +tttg: c1/154 lr:0.001000 t:0.4s +tttg: c2/154 lr:0.001000 t:0.5s +tttg: c3/154 lr:0.001000 t:0.6s +tttg: c4/154 lr:0.000999 t:0.7s +tttg: c5/154 lr:0.000998 t:0.7s +tttg: c6/154 lr:0.000997 t:0.8s +tttg: c7/154 lr:0.000996 t:0.9s +tttg: c8/154 lr:0.000995 t:1.0s +tttg: c9/154 lr:0.000993 t:1.1s +tttg: c10/154 lr:0.000991 t:1.1s +tttg: c11/154 lr:0.000989 t:1.2s +tttg: c12/154 lr:0.000987 t:1.3s +tttg: c13/154 lr:0.000985 t:1.4s +tttg: c14/154 lr:0.000982 t:1.4s +tttg: c15/154 lr:0.000979 t:1.5s +tttg: c16/154 lr:0.000976 t:1.6s +tttg: c17/154 lr:0.000973 t:1.7s +tttg: c18/154 lr:0.000970 t:1.7s +tttg: c19/154 lr:0.000966 t:1.8s +tttg: c20/154 lr:0.000962 t:1.9s +tttg: c21/154 lr:0.000958 t:2.0s +tttg: c22/154 lr:0.000954 t:2.0s +tttg: c23/154 lr:0.000950 t:2.1s +tttg: c24/154 lr:0.000945 t:2.2s +tttg: c25/154 lr:0.000941 t:2.3s +tttg: c26/154 lr:0.000936 t:2.4s +tttg: c27/154 lr:0.000930 t:2.4s +tttg: c28/154 lr:0.000925 t:2.5s +tttg: c29/154 lr:0.000920 t:2.6s +tttg: c30/154 lr:0.000914 t:2.7s +tttg: c31/154 lr:0.000908 t:2.7s +tttg: c32/154 lr:0.000902 t:2.8s +tttg: c33/154 lr:0.000896 t:2.9s +tttg: c34/154 lr:0.000890 t:3.0s +tttg: c35/154 lr:0.000883 t:3.1s +tttg: c36/154 lr:0.000876 t:3.1s +tttg: c37/154 lr:0.000870 t:3.2s +tttg: c38/154 lr:0.000863 t:3.3s +tttg: c39/154 lr:0.000855 t:3.4s +tttg: c40/154 lr:0.000848 t:3.4s +tttg: c41/154 lr:0.000841 t:3.5s +tttg: c42/154 lr:0.000833 t:3.6s +tttg: c43/154 lr:0.000825 t:3.7s +tttg: c44/154 lr:0.000817 t:3.7s +tttg: c45/154 lr:0.000809 t:3.8s +tttg: c46/154 lr:0.000801 t:3.9s +tttg: c47/154 lr:0.000793 t:4.0s +tttg: c48/154 lr:0.000785 t:4.1s +tttg: c49/154 lr:0.000776 t:4.1s +tttg: c50/154 lr:0.000768 t:4.2s +tttg: c51/154 lr:0.000759 t:4.3s +tttg: c52/154 lr:0.000750 t:4.4s +tttg: c53/154 lr:0.000741 t:4.4s +tttg: c54/154 lr:0.000732 t:4.5s +tttg: c55/154 lr:0.000723 t:4.6s +tttg: c56/154 lr:0.000714 t:4.7s +tttg: c57/154 lr:0.000704 t:4.8s +tttg: c58/154 lr:0.000695 t:4.8s +tttg: c59/154 lr:0.000685 t:4.9s +tttg: c60/154 lr:0.000676 t:5.0s +tttg: c61/154 lr:0.000666 t:5.1s +tttg: c62/154 lr:0.000656 t:5.2s +tttg: c63/154 lr:0.000647 t:5.2s +tttg: c64/154 lr:0.000637 t:5.3s +tttg: c65/154 lr:0.000627 t:5.4s +tttg: c66/154 lr:0.000617 t:5.5s +tttg: c67/154 lr:0.000607 t:5.5s +tttg: c68/154 lr:0.000597 t:5.6s +tttg: c69/154 lr:0.000587 t:5.7s +tttg: c70/154 lr:0.000577 t:5.8s +tttg: c71/154 lr:0.000567 t:5.8s +tttg: c72/154 lr:0.000556 t:5.9s +tttg: c73/154 lr:0.000546 t:6.0s +tttg: c74/154 lr:0.000536 t:6.1s +tttg: c75/154 lr:0.000526 t:6.2s +tttg: c76/154 lr:0.000515 t:6.2s +tttg: c77/154 lr:0.000505 t:6.3s +tttg: c78/154 lr:0.000495 t:6.4s +tttg: c79/154 lr:0.000485 t:6.5s +tttg: c80/154 lr:0.000474 t:6.6s +tttg: c81/154 lr:0.000464 t:6.6s +tttg: c82/154 lr:0.000454 t:6.7s +tttg: c83/154 lr:0.000444 t:6.8s +tttg: c84/154 lr:0.000433 t:6.9s +tttg: c85/154 lr:0.000423 t:7.0s +tttg: c86/154 lr:0.000413 t:7.0s +tttg: c87/154 lr:0.000403 t:7.1s +tttg: c88/154 lr:0.000393 t:7.2s +tttg: c89/154 lr:0.000383 t:7.3s +tttg: c90/154 lr:0.000373 t:7.3s +tttg: c91/154 lr:0.000363 t:7.4s +tttg: c92/154 lr:0.000353 t:7.5s +tttg: c93/154 lr:0.000344 t:7.6s +tttg: c94/154 lr:0.000334 t:7.6s +tttg: c95/154 lr:0.000324 t:7.7s +tttg: c96/154 lr:0.000315 t:7.8s +tttg: c97/154 lr:0.000305 t:7.9s +tttg: c98/154 lr:0.000296 t:7.9s +tttg: c99/154 lr:0.000286 t:8.0s +tttg: c100/154 lr:0.000277 t:8.1s +tttg: c101/154 lr:0.000268 t:8.2s +tttg: c102/154 lr:0.000259 t:8.2s +tttg: c103/154 lr:0.000250 t:8.3s +tttg: c104/154 lr:0.000241 t:8.4s +tttg: c105/154 lr:0.000232 t:8.5s +tttg: c106/154 lr:0.000224 t:8.6s +tttg: c107/154 lr:0.000215 t:8.6s +tttg: c108/154 lr:0.000207 t:8.7s +tttg: c109/154 lr:0.000199 t:8.8s +tttg: c110/154 lr:0.000191 t:8.9s +tttg: c111/154 lr:0.000183 t:8.9s +tttg: c112/154 lr:0.000175 t:9.0s +tttg: c113/154 lr:0.000167 t:9.1s +tttg: c114/154 lr:0.000159 t:9.2s +tttg: c115/154 lr:0.000152 t:9.2s +tttg: c116/154 lr:0.000145 t:9.3s +tttg: c117/154 lr:0.000137 t:9.4s +tttg: c118/154 lr:0.000130 t:9.5s +tttg: c119/154 lr:0.000124 t:9.6s +tttg: c120/154 lr:0.000117 t:9.6s +tttg: c121/154 lr:0.000110 t:9.7s +tttg: c122/154 lr:0.000104 t:9.8s +tttg: c123/154 lr:0.000098 t:9.9s +tttg: c124/154 lr:0.000092 t:9.9s +tttg: c125/154 lr:0.000086 t:10.0s +tttg: c126/154 lr:0.000080 t:10.1s +tttg: c127/154 lr:0.000075 t:10.2s +tttg: c128/154 lr:0.000070 t:10.2s +tttg: c129/154 lr:0.000064 t:10.3s +tttg: c130/154 lr:0.000059 t:10.4s +tttg: c131/154 lr:0.000055 t:10.5s +tttg: c132/154 lr:0.000050 t:10.5s +tttg: c133/154 lr:0.000046 t:10.6s +tttg: c134/154 lr:0.000042 t:10.7s +tttg: c135/154 lr:0.000038 t:10.8s +tttg: c136/154 lr:0.000034 t:10.9s +tttg: c137/154 lr:0.000030 t:10.9s +tttg: c138/154 lr:0.000027 t:11.0s +tttg: c139/154 lr:0.000024 t:11.1s +tttg: c140/154 lr:0.000021 t:11.2s +tttg: c141/154 lr:0.000018 t:11.2s +tttg: c142/154 lr:0.000015 t:11.3s +tttg: c143/154 lr:0.000013 t:11.4s +tttg: c144/154 lr:0.000011 t:11.5s +tttg: c145/154 lr:0.000009 t:11.5s +tttg: c146/154 lr:0.000007 t:11.6s +tttg: c147/154 lr:0.000005 t:11.7s +tttg: c148/154 lr:0.000004 t:11.8s +tttg: c149/154 lr:0.000003 t:11.9s +tttg: c150/154 lr:0.000002 t:11.9s +tttg: c151/154 lr:0.000001 t:12.0s +tttg: c152/154 lr:0.000000 t:12.1s +tttg: c153/154 lr:0.000000 t:12.2s +ttpr: phase:1/1 t:268.8s +ttp: b758/782 bl:2.2792 bb:1.0623 rl:2.1430 rb:1.0411 dl:3634-3740 gd:1 +ttp: b748/782 bl:2.2891 bb:1.0683 rl:2.1584 rb:1.0441 dl:2992-3039 gd:1 +ttp: b745/782 bl:2.2039 bb:1.0090 rl:2.1625 rb:1.0407 dl:2842-2883 gd:1 +ttp: b741/782 bl:2.2834 bb:1.0240 rl:2.1720 rb:1.0393 dl:2686-2730 gd:1 +ttp: b737/782 bl:2.2860 bb:1.0277 rl:2.1800 rb:1.0385 dl:2550-2583 gd:1 +ttp: b733/782 bl:2.3475 bb:1.0510 rl:2.1904 rb:1.0393 dl:2441-2468 gd:1 +ttp: b731/782 bl:2.3111 bb:1.0307 rl:2.1973 rb:1.0388 dl:2377-2414 gd:1 +ttp: b724/782 bl:2.2923 bb:1.0467 rl:2.2021 rb:1.0392 dl:2203-2231 gd:1 +ttp: b720/782 bl:2.3259 bb:1.0520 rl:2.2079 rb:1.0398 dl:2125-2144 gd:1 +ttp: b718/782 bl:2.2609 bb:1.0147 rl:2.2102 rb:1.0387 dl:2089-2106 gd:1 +ttp: b715/782 bl:2.3243 bb:1.0133 rl:2.2148 rb:1.0376 dl:2036-2053 gd:1 +ttp: b708/782 bl:2.2822 bb:1.0208 rl:2.2173 rb:1.0369 dl:1924-1937 gd:1 +ttp: b704/782 bl:2.2550 bb:1.0246 rl:2.2186 rb:1.0365 dl:1872-1885 gd:1 +ttp: b700/782 bl:2.2478 bb:1.0038 rl:2.2196 rb:1.0354 dl:1824-1834 gd:1 +ttp: b696/782 bl:2.2797 bb:1.0382 rl:2.2214 rb:1.0354 dl:1779-1790 gd:1 +ttp: b692/782 bl:2.2657 bb:1.0171 rl:2.2227 rb:1.0349 dl:1737-1746 gd:1 +ttp: b687/782 bl:2.2855 bb:1.0437 rl:2.2245 rb:1.0351 dl:1685-1696 gd:1 +ttp: b681/782 bl:2.3086 bb:1.0321 rl:2.2267 rb:1.0351 dl:1628-1637 gd:1 +ttp: b673/782 bl:2.3365 bb:1.0488 rl:2.2293 rb:1.0354 dl:1562-1571 gd:1 +ttp: b666/782 bl:2.3851 bb:1.0528 rl:2.2329 rb:1.0358 dl:1507-1514 gd:1 +ttp: b659/782 bl:2.2811 bb:1.0294 rl:2.2340 rb:1.0357 dl:1459-1466 gd:1 +ttp: b652/782 bl:2.2207 bb:1.0095 rl:2.2337 rb:1.0351 dl:1411-1419 gd:1 +ttp: b645/782 bl:2.2689 bb:1.0152 rl:2.2344 rb:1.0347 dl:1367-1375 gd:1 +ttp: b638/782 bl:2.3133 bb:1.0540 rl:2.2358 rb:1.0351 dl:1325-1331 gd:1 +ttp: b631/782 bl:2.2780 bb:0.9920 rl:2.2366 rb:1.0343 dl:1285-1290 gd:1 +ttp: b624/782 bl:2.3253 bb:1.0526 rl:2.2381 rb:1.0346 dl:1249-1255 gd:1 +ttp: b616/782 bl:2.3806 bb:1.0326 rl:2.2404 rb:1.0346 dl:1205-1211 gd:1 +ttp: b609/782 bl:2.2421 bb:1.0045 rl:2.2404 rb:1.0341 dl:1172-1177 gd:1 +ttp: b602/782 bl:2.3501 bb:1.0366 rl:2.2420 rb:1.0341 dl:1141-1146 gd:1 +ttp: b597/782 bl:2.3377 bb:1.0395 rl:2.2434 rb:1.0342 dl:1119-1124 gd:1 +ttp: b590/782 bl:2.2830 bb:1.0461 rl:2.2439 rb:1.0344 dl:1089-1093 gd:1 +ttp: b583/782 bl:2.2964 bb:1.0204 rl:2.2446 rb:1.0342 dl:1060-1064 gd:1 +ttp: b574/782 bl:2.3310 bb:1.0460 rl:2.2457 rb:1.0343 dl:1025-1029 gd:1 +ttp: b567/782 bl:2.2345 bb:1.0031 rl:2.2456 rb:1.0339 dl:1001-1004 gd:1 +ttp: b561/782 bl:2.2133 bb:0.9984 rl:2.2452 rb:1.0335 dl:979-983 gd:1 +ttp: b554/782 bl:2.3993 bb:1.0801 rl:2.2469 rb:1.0341 dl:955-959 gd:1 +ttp: b548/782 bl:2.2189 bb:1.0366 rl:2.2466 rb:1.0341 dl:937-939 gd:1 +ttp: b541/782 bl:2.2974 bb:1.0194 rl:2.2472 rb:1.0339 dl:915-918 gd:1 +ttp: b533/782 bl:2.3440 bb:1.0544 rl:2.2481 rb:1.0341 dl:890-892 gd:1 +ttp: b527/782 bl:2.3210 bb:1.0187 rl:2.2489 rb:1.0340 dl:872-875 gd:1 +ttp: b520/782 bl:2.2979 bb:0.9908 rl:2.2493 rb:1.0335 dl:852-854 gd:1 +ttp: b512/782 bl:2.2707 bb:1.0486 rl:2.2495 rb:1.0337 dl:829-832 gd:1 +ttp: b505/782 bl:2.3010 bb:1.0523 rl:2.2500 rb:1.0338 dl:809-812 gd:1 +ttp: b499/782 bl:2.3050 bb:1.0408 rl:2.2505 rb:1.0339 dl:794-796 gd:1 +ttp: b492/782 bl:2.2530 bb:1.0233 rl:2.2505 rb:1.0338 dl:776-778 gd:1 +ttp: b485/782 bl:2.2587 bb:1.0175 rl:2.2506 rb:1.0337 dl:759-761 gd:1 +ttp: b477/782 bl:2.3736 bb:1.0222 rl:2.2515 rb:1.0336 dl:740-742 gd:1 +ttp: b469/782 bl:2.3006 bb:1.0117 rl:2.2519 rb:1.0334 dl:721-724 gd:1 +ttp: b459/782 bl:2.2474 bb:1.0290 rl:2.2519 rb:1.0334 dl:700-701 gd:1 +ttp: b451/782 bl:2.3742 bb:1.0743 rl:2.2527 rb:1.0337 dl:682-685 gd:1 +ttp: b443/782 bl:2.2059 bb:1.0379 rl:2.2524 rb:1.0337 dl:666-668 gd:1 +ttp: b437/782 bl:2.2671 bb:1.0431 rl:2.2525 rb:1.0338 dl:653-655 gd:1 +ttp: b429/782 bl:2.2174 bb:1.0113 rl:2.2523 rb:1.0336 dl:638-640 gd:1 +ttp: b422/782 bl:2.2739 bb:1.0730 rl:2.2524 rb:1.0339 dl:624-626 gd:1 +ttp: b411/782 bl:2.3336 bb:1.0474 rl:2.2529 rb:1.0339 dl:603-605 gd:1 +ttp: b404/782 bl:2.3388 bb:1.0474 rl:2.2534 rb:1.0340 dl:590-592 gd:1 +ttp: b396/782 bl:2.2513 bb:1.0590 rl:2.2534 rb:1.0342 dl:575-577 gd:1 +ttp: b394/782 bl:2.2165 bb:0.9758 rl:2.2532 rb:1.0338 dl:571-573 gd:1 +ttp: b386/782 bl:2.3090 bb:1.0844 rl:2.2535 rb:1.0341 dl:557-559 gd:1 +ttp: b378/782 bl:2.3935 bb:1.0386 rl:2.2543 rb:1.0341 dl:544-545 gd:1 +ttp: b370/782 bl:2.3402 bb:1.0713 rl:2.2547 rb:1.0343 dl:530-532 gd:1 +ttp: b358/782 bl:2.3735 bb:1.0652 rl:2.2553 rb:1.0345 dl:510-512 gd:1 +ttp: b350/782 bl:2.3052 bb:1.0476 rl:2.2555 rb:1.0345 dl:497-498 gd:1 +ttp: b342/782 bl:2.3484 bb:1.1110 rl:2.2560 rb:1.0349 dl:485-486 gd:1 +ttp: b335/782 bl:2.3240 bb:1.0528 rl:2.2563 rb:1.0349 dl:474-476 gd:1 +ttp: b327/782 bl:2.2987 bb:1.0688 rl:2.2565 rb:1.0351 dl:462-463 gd:1 +ttp: b319/782 bl:2.3598 bb:1.0641 rl:2.2569 rb:1.0352 dl:450-451 gd:1 +ttp: b311/782 bl:2.3173 bb:1.0681 rl:2.2572 rb:1.0354 dl:438-439 gd:1 +ttp: b303/782 bl:2.3538 bb:1.0736 rl:2.2575 rb:1.0355 dl:426-427 gd:1 +ttp: b293/782 bl:2.4114 bb:1.0873 rl:2.2581 rb:1.0357 dl:410-412 gd:1 +ttp: b288/782 bl:2.2083 bb:1.0051 rl:2.2579 rb:1.0356 dl:403-405 gd:1 +ttp: b279/782 bl:2.2825 bb:1.0785 rl:2.2580 rb:1.0358 dl:391-392 gd:1 +ttp: b272/782 bl:2.3445 bb:1.0829 rl:2.2583 rb:1.0359 dl:382-383 gd:1 +ttp: b264/782 bl:2.3872 bb:1.0878 rl:2.2588 rb:1.0361 dl:371-372 gd:1 +ttp: b256/782 bl:2.5046 bb:1.1055 rl:2.2596 rb:1.0363 dl:361-362 gd:1 +ttp: b248/782 bl:2.4251 bb:1.1704 rl:2.2601 rb:1.0367 dl:351-352 gd:1 +ttp: b237/782 bl:2.3028 bb:1.0817 rl:2.2603 rb:1.0369 dl:337-338 gd:1 +ttp: b233/782 bl:2.3416 bb:1.1188 rl:2.2605 rb:1.0371 dl:333-334 gd:1 +ttp: b223/782 bl:2.2901 bb:1.1056 rl:2.2606 rb:1.0373 dl:321-322 gd:1 +ttp: b218/782 bl:2.4278 bb:1.0950 rl:2.2611 rb:1.0375 dl:315-316 gd:1 +ttp: b212/782 bl:2.3351 bb:1.0659 rl:2.2613 rb:1.0376 dl:308-309 gd:1 +ttp: b205/782 bl:2.2993 bb:1.1009 rl:2.2614 rb:1.0377 dl:301-302 gd:1 +ttp: b196/782 bl:2.4005 bb:1.0954 rl:2.2617 rb:1.0379 dl:291-292 gd:1 +ttp: b171/782 bl:2.4486 bb:1.1291 rl:2.2622 rb:1.0381 dl:266-266 gd:1 +ttp: b163/782 bl:2.3470 bb:1.1057 rl:2.2624 rb:1.0382 dl:257-259 gd:1 +ttp: b156/782 bl:2.2691 bb:1.1329 rl:2.2624 rb:1.0384 dl:251-252 gd:1 +ttp: b148/782 bl:2.2981 bb:1.0873 rl:2.2625 rb:1.0385 dl:243-244 gd:1 +ttp: b140/782 bl:2.4027 bb:1.1218 rl:2.2628 rb:1.0387 dl:235-236 gd:1 +ttp: b132/782 bl:2.4107 bb:1.1449 rl:2.2631 rb:1.0389 dl:228-229 gd:1 +ttp: b125/782 bl:2.4407 bb:1.1245 rl:2.2634 rb:1.0391 dl:222-222 gd:1 +ttp: b116/782 bl:2.4532 bb:1.1139 rl:2.2638 rb:1.0392 dl:213-214 gd:1 +ttp: b108/782 bl:2.3758 bb:1.1453 rl:2.2640 rb:1.0394 dl:206-207 gd:1 +ttp: b100/782 bl:2.3932 bb:1.1449 rl:2.2642 rb:1.0396 dl:199-200 gd:1 +ttp: b92/782 bl:2.4047 bb:1.1442 rl:2.2644 rb:1.0398 dl:191-192 gd:1 +ttp: b85/782 bl:2.4756 bb:1.1856 rl:2.2648 rb:1.0400 dl:185-186 gd:1 +ttp: b77/782 bl:2.4777 bb:1.2170 rl:2.2651 rb:1.0402 dl:178-179 gd:1 +ttp: b69/782 bl:2.4374 bb:1.1898 rl:2.2654 rb:1.0405 dl:171-172 gd:1 +ttp: b61/782 bl:2.4238 bb:1.1997 rl:2.2656 rb:1.0407 dl:164-165 gd:1 +ttp: b53/782 bl:2.4717 bb:1.1778 rl:2.2659 rb:1.0408 dl:156-157 gd:1 +ttp: b44/782 bl:2.5344 bb:1.1826 rl:2.2662 rb:1.0410 dl:147-148 gd:1 +ttp: b36/782 bl:2.5008 bb:1.2067 rl:2.2665 rb:1.0412 dl:139-140 gd:1 +ttp: b29/782 bl:2.5957 bb:1.2008 rl:2.2669 rb:1.0414 dl:132-133 gd:1 +ttp: b22/782 bl:2.5282 bb:1.1835 rl:2.2672 rb:1.0415 dl:124-126 gd:1 +ttp: b15/782 bl:2.6158 bb:1.2148 rl:2.2675 rb:1.0417 dl:115-117 gd:1 +ttp: b9/782 bl:2.7217 bb:1.2419 rl:2.2679 rb:1.0419 dl:105-107 gd:1 +quantized_ttt_phased val_loss:2.29173058 val_bpb:1.04723144 eval_time:463281ms +total_eval_time:463.3s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed42/run_tokenonly_fast.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed42/run_tokenonly_fast.log new file mode 100644 index 0000000000..cf64b897df --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/seed42/run_tokenonly_fast.log @@ -0,0 +1,413 @@ +W0501 00:35:47.233000 42063 torch/distributed/run.py:803] +W0501 00:35:47.233000 42063 torch/distributed/run.py:803] ***************************************** +W0501 00:35:47.233000 42063 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 00:35:47.233000 42063 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed42 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed42/tokenonly_fast_p1000_n1_s42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed42/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf-fallback/records/fallback_token_only_fast/seed42/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: tokenonly_fast_p1000_n1_s42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.0 + within_tau: 999.0 + word_boost: 0.0 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 999.0 + world_size: 8 + xsa_last_n: 11 +train_shards: 0 +val_tokens: 47851520 +TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +ttt_weight_decay: 2.0 +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 20.9s +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 20.8s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (108.8s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=628130 token_gate=628130 within_gate=0 word_gate=0 agree2plus=0 +ngram_tilt:precompute_done elapsed=11.78s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b782/782 bl:2.1027 bb:0.9957 rl:2.1027 rb:0.9957 dl:30339-97114 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:254.6s +tttg: c1/154 lr:0.001000 t:0.4s +tttg: c2/154 lr:0.001000 t:0.5s +tttg: c3/154 lr:0.001000 t:0.6s +tttg: c4/154 lr:0.000999 t:0.7s +tttg: c5/154 lr:0.000998 t:0.7s +tttg: c6/154 lr:0.000997 t:0.8s +tttg: c7/154 lr:0.000996 t:0.9s +tttg: c8/154 lr:0.000995 t:1.0s +tttg: c9/154 lr:0.000993 t:1.1s +tttg: c10/154 lr:0.000991 t:1.1s +tttg: c11/154 lr:0.000989 t:1.2s +tttg: c12/154 lr:0.000987 t:1.3s +tttg: c13/154 lr:0.000985 t:1.4s +tttg: c14/154 lr:0.000982 t:1.5s +tttg: c15/154 lr:0.000979 t:1.5s +tttg: c16/154 lr:0.000976 t:1.6s +tttg: c17/154 lr:0.000973 t:1.7s +tttg: c18/154 lr:0.000970 t:1.8s +tttg: c19/154 lr:0.000966 t:1.8s +tttg: c20/154 lr:0.000962 t:1.9s +tttg: c21/154 lr:0.000958 t:2.0s +tttg: c22/154 lr:0.000954 t:2.1s +tttg: c23/154 lr:0.000950 t:2.1s +tttg: c24/154 lr:0.000945 t:2.2s +tttg: c25/154 lr:0.000941 t:2.3s +tttg: c26/154 lr:0.000936 t:2.4s +tttg: c27/154 lr:0.000930 t:2.5s +tttg: c28/154 lr:0.000925 t:2.5s +tttg: c29/154 lr:0.000920 t:2.6s +tttg: c30/154 lr:0.000914 t:2.7s +tttg: c31/154 lr:0.000908 t:2.8s +tttg: c32/154 lr:0.000902 t:2.8s +tttg: c33/154 lr:0.000896 t:2.9s +tttg: c34/154 lr:0.000890 t:3.0s +tttg: c35/154 lr:0.000883 t:3.1s +tttg: c36/154 lr:0.000876 t:3.1s +tttg: c37/154 lr:0.000870 t:3.2s +tttg: c38/154 lr:0.000863 t:3.3s +tttg: c39/154 lr:0.000855 t:3.4s +tttg: c40/154 lr:0.000848 t:3.5s +tttg: c41/154 lr:0.000841 t:3.5s +tttg: c42/154 lr:0.000833 t:3.6s +tttg: c43/154 lr:0.000825 t:3.7s +tttg: c44/154 lr:0.000817 t:3.8s +tttg: c45/154 lr:0.000809 t:3.8s +tttg: c46/154 lr:0.000801 t:3.9s +tttg: c47/154 lr:0.000793 t:4.0s +tttg: c48/154 lr:0.000785 t:4.1s +tttg: c49/154 lr:0.000776 t:4.2s +tttg: c50/154 lr:0.000768 t:4.2s +tttg: c51/154 lr:0.000759 t:4.3s +tttg: c52/154 lr:0.000750 t:4.4s +tttg: c53/154 lr:0.000741 t:4.5s +tttg: c54/154 lr:0.000732 t:4.6s +tttg: c55/154 lr:0.000723 t:4.6s +tttg: c56/154 lr:0.000714 t:4.7s +tttg: c57/154 lr:0.000704 t:4.8s +tttg: c58/154 lr:0.000695 t:4.9s +tttg: c59/154 lr:0.000685 t:5.0s +tttg: c60/154 lr:0.000676 t:5.0s +tttg: c61/154 lr:0.000666 t:5.1s +tttg: c62/154 lr:0.000656 t:5.2s +tttg: c63/154 lr:0.000647 t:5.3s +tttg: c64/154 lr:0.000637 t:5.3s +tttg: c65/154 lr:0.000627 t:5.4s +tttg: c66/154 lr:0.000617 t:5.5s +tttg: c67/154 lr:0.000607 t:5.6s +tttg: c68/154 lr:0.000597 t:5.6s +tttg: c69/154 lr:0.000587 t:5.7s +tttg: c70/154 lr:0.000577 t:5.8s +tttg: c71/154 lr:0.000567 t:5.9s +tttg: c72/154 lr:0.000556 t:5.9s +tttg: c73/154 lr:0.000546 t:6.0s +tttg: c74/154 lr:0.000536 t:6.1s +tttg: c75/154 lr:0.000526 t:6.2s +tttg: c76/154 lr:0.000515 t:6.3s +tttg: c77/154 lr:0.000505 t:6.3s +tttg: c78/154 lr:0.000495 t:6.4s +tttg: c79/154 lr:0.000485 t:6.5s +tttg: c80/154 lr:0.000474 t:6.6s +tttg: c81/154 lr:0.000464 t:6.6s +tttg: c82/154 lr:0.000454 t:6.7s +tttg: c83/154 lr:0.000444 t:6.8s +tttg: c84/154 lr:0.000433 t:6.9s +tttg: c85/154 lr:0.000423 t:7.0s +tttg: c86/154 lr:0.000413 t:7.1s +tttg: c87/154 lr:0.000403 t:7.1s +tttg: c88/154 lr:0.000393 t:7.2s +tttg: c89/154 lr:0.000383 t:7.3s +tttg: c90/154 lr:0.000373 t:7.4s +tttg: c91/154 lr:0.000363 t:7.4s +tttg: c92/154 lr:0.000353 t:7.5s +tttg: c93/154 lr:0.000344 t:7.6s +tttg: c94/154 lr:0.000334 t:7.7s +tttg: c95/154 lr:0.000324 t:7.8s +tttg: c96/154 lr:0.000315 t:7.8s +tttg: c97/154 lr:0.000305 t:7.9s +tttg: c98/154 lr:0.000296 t:8.0s +tttg: c99/154 lr:0.000286 t:8.1s +tttg: c100/154 lr:0.000277 t:8.1s +tttg: c101/154 lr:0.000268 t:8.2s +tttg: c102/154 lr:0.000259 t:8.3s +tttg: c103/154 lr:0.000250 t:8.4s +tttg: c104/154 lr:0.000241 t:8.5s +tttg: c105/154 lr:0.000232 t:8.5s +tttg: c106/154 lr:0.000224 t:8.6s +tttg: c107/154 lr:0.000215 t:8.7s +tttg: c108/154 lr:0.000207 t:8.8s +tttg: c109/154 lr:0.000199 t:8.8s +tttg: c110/154 lr:0.000191 t:8.9s +tttg: c111/154 lr:0.000183 t:9.0s +tttg: c112/154 lr:0.000175 t:9.1s +tttg: c113/154 lr:0.000167 t:9.1s +tttg: c114/154 lr:0.000159 t:9.2s +tttg: c115/154 lr:0.000152 t:9.3s +tttg: c116/154 lr:0.000145 t:9.4s +tttg: c117/154 lr:0.000137 t:9.4s +tttg: c118/154 lr:0.000130 t:9.5s +tttg: c119/154 lr:0.000124 t:9.6s +tttg: c120/154 lr:0.000117 t:9.7s +tttg: c121/154 lr:0.000110 t:9.8s +tttg: c122/154 lr:0.000104 t:9.8s +tttg: c123/154 lr:0.000098 t:9.9s +tttg: c124/154 lr:0.000092 t:10.0s +tttg: c125/154 lr:0.000086 t:10.1s +tttg: c126/154 lr:0.000080 t:10.1s +tttg: c127/154 lr:0.000075 t:10.2s +tttg: c128/154 lr:0.000070 t:10.3s +tttg: c129/154 lr:0.000064 t:10.4s +tttg: c130/154 lr:0.000059 t:10.4s +tttg: c131/154 lr:0.000055 t:10.5s +tttg: c132/154 lr:0.000050 t:10.6s +tttg: c133/154 lr:0.000046 t:10.7s +tttg: c134/154 lr:0.000042 t:10.8s +tttg: c135/154 lr:0.000038 t:10.8s +tttg: c136/154 lr:0.000034 t:10.9s +tttg: c137/154 lr:0.000030 t:11.0s +tttg: c138/154 lr:0.000027 t:11.1s +tttg: c139/154 lr:0.000024 t:11.1s +tttg: c140/154 lr:0.000021 t:11.2s +tttg: c141/154 lr:0.000018 t:11.3s +tttg: c142/154 lr:0.000015 t:11.4s +tttg: c143/154 lr:0.000013 t:11.5s +tttg: c144/154 lr:0.000011 t:11.5s +tttg: c145/154 lr:0.000009 t:11.6s +tttg: c146/154 lr:0.000007 t:11.7s +tttg: c147/154 lr:0.000005 t:11.8s +tttg: c148/154 lr:0.000004 t:11.9s +tttg: c149/154 lr:0.000003 t:11.9s +tttg: c150/154 lr:0.000002 t:12.0s +tttg: c151/154 lr:0.000001 t:12.1s +tttg: c152/154 lr:0.000000 t:12.2s +tttg: c153/154 lr:0.000000 t:12.2s +ttpr: phase:1/1 t:268.6s +ttp: b758/782 bl:2.2769 bb:1.0612 rl:2.1428 rb:1.0110 dl:3634-3740 gd:1 +ttp: b690/782 bl:2.2658 bb:1.0518 rl:2.1548 rb:1.0150 dl:1715-1725 gd:1 +ttp: b684/782 bl:2.3366 bb:1.0294 rl:2.1704 rb:1.0163 dl:1658-1665 gd:1 +ttp: b677/782 bl:2.2813 bb:1.0221 rl:2.1788 rb:1.0168 dl:1595-1601 gd:1 +ttp: b668/782 bl:2.2972 bb:1.0502 rl:2.1868 rb:1.0191 dl:1521-1530 gd:1 +ttp: b661/782 bl:2.3710 bb:1.0719 rl:2.1982 rb:1.0224 dl:1474-1480 gd:1 +ttp: b653/782 bl:2.2657 bb:1.0272 rl:2.2020 rb:1.0227 dl:1419-1425 gd:1 +ttp: b645/782 bl:2.2680 bb:1.0148 rl:2.2053 rb:1.0223 dl:1367-1375 gd:1 +ttp: b637/782 bl:2.3302 bb:1.0627 rl:2.2112 rb:1.0242 dl:1320-1325 gd:1 +ttp: b629/782 bl:2.3176 bb:0.9974 rl:2.2159 rb:1.0230 dl:1276-1280 gd:1 +ttp: b621/782 bl:2.2605 bb:1.0323 rl:2.2177 rb:1.0233 dl:1231-1237 gd:1 +ttp: b613/782 bl:2.3082 bb:1.0277 rl:2.2211 rb:1.0235 dl:1190-1195 gd:1 +ttp: b606/782 bl:2.3331 bb:1.0542 rl:2.2250 rb:1.0246 dl:1159-1164 gd:1 +ttp: b597/782 bl:2.3349 bb:1.0383 rl:2.2286 rb:1.0251 dl:1119-1124 gd:1 +ttp: b589/782 bl:2.2402 bb:0.9949 rl:2.2290 rb:1.0241 dl:1086-1089 gd:1 +ttp: b582/782 bl:2.3154 bb:1.0170 rl:2.2315 rb:1.0239 dl:1056-1060 gd:1 +ttp: b571/782 bl:2.2688 bb:0.9925 rl:2.2325 rb:1.0230 dl:1014-1017 gd:1 +ttp: b563/782 bl:2.2289 bb:1.0016 rl:2.2324 rb:1.0224 dl:987-990 gd:1 +ttp: b556/782 bl:2.3456 bb:1.0545 rl:2.2352 rb:1.0232 dl:961-965 gd:1 +ttp: b548/782 bl:2.2110 bb:1.0329 rl:2.2346 rb:1.0235 dl:937-939 gd:1 +ttp: b541/782 bl:2.2935 bb:1.0177 rl:2.2360 rb:1.0233 dl:915-918 gd:1 +ttp: b534/782 bl:2.2968 bb:1.0288 rl:2.2372 rb:1.0234 dl:893-896 gd:1 +ttp: b526/782 bl:2.2997 bb:1.0136 rl:2.2385 rb:1.0232 dl:869-872 gd:1 +ttp: b517/782 bl:2.3268 bb:1.0153 rl:2.2402 rb:1.0231 dl:843-846 gd:1 +ttp: b511/782 bl:2.3641 bb:1.0400 rl:2.2425 rb:1.0234 dl:826-829 gd:1 +ttp: b502/782 bl:2.2876 bb:1.0137 rl:2.2433 rb:1.0232 dl:802-804 gd:1 +ttp: b494/782 bl:2.2912 bb:1.0443 rl:2.2441 rb:1.0236 dl:780-783 gd:1 +ttp: b486/782 bl:2.3787 bb:1.0687 rl:2.2463 rb:1.0243 dl:761-764 gd:1 +ttp: b479/782 bl:2.3770 bb:1.0680 rl:2.2484 rb:1.0250 dl:744-747 gd:1 +ttp: b471/782 bl:2.3675 bb:1.0689 rl:2.2502 rb:1.0257 dl:726-728 gd:1 +ttp: b463/782 bl:2.2832 bb:1.0274 rl:2.2506 rb:1.0257 dl:708-710 gd:1 +ttp: b455/782 bl:2.2623 bb:1.0196 rl:2.2508 rb:1.0256 dl:691-693 gd:1 +ttp: b447/782 bl:2.2981 bb:1.0557 rl:2.2514 rb:1.0260 dl:674-676 gd:1 +ttp: b439/782 bl:2.2979 bb:1.0253 rl:2.2520 rb:1.0260 dl:657-659 gd:1 +ttp: b432/782 bl:2.3081 bb:1.0259 rl:2.2527 rb:1.0260 dl:643-645 gd:1 +ttp: b424/782 bl:2.3159 bb:1.0501 rl:2.2535 rb:1.0263 dl:629-630 gd:1 +ttp: b416/782 bl:2.3358 bb:1.0270 rl:2.2544 rb:1.0263 dl:613-615 gd:1 +ttp: b409/782 bl:2.2869 bb:1.0496 rl:2.2548 rb:1.0266 dl:598-601 gd:1 +ttp: b400/782 bl:2.2838 bb:1.0277 rl:2.2551 rb:1.0266 dl:582-584 gd:1 +ttp: b392/782 bl:2.2117 bb:1.0173 rl:2.2547 rb:1.0265 dl:568-570 gd:1 +ttp: b383/782 bl:2.2435 bb:1.0287 rl:2.2546 rb:1.0265 dl:552-554 gd:1 +ttp: b376/782 bl:2.2858 bb:1.0252 rl:2.2549 rb:1.0265 dl:540-542 gd:1 +ttp: b368/782 bl:2.3367 bb:1.0883 rl:2.2556 rb:1.0271 dl:527-528 gd:1 +ttp: b361/782 bl:2.3234 bb:1.0847 rl:2.2562 rb:1.0276 dl:515-517 gd:1 +ttp: b353/782 bl:2.1754 bb:0.9948 rl:2.2555 rb:1.0273 dl:501-503 gd:1 +ttp: b344/782 bl:2.3504 bb:1.0475 rl:2.2563 rb:1.0275 dl:488-489 gd:1 +ttp: b336/782 bl:2.3717 bb:1.0688 rl:2.2573 rb:1.0278 dl:476-477 gd:1 +ttp: b328/782 bl:2.2534 bb:1.0016 rl:2.2573 rb:1.0276 dl:463-465 gd:1 +ttp: b300/782 bl:2.3062 bb:1.0418 rl:2.2576 rb:1.0277 dl:421-422 gd:1 +ttp: b292/782 bl:2.2978 bb:1.0879 rl:2.2579 rb:1.0281 dl:409-410 gd:1 +ttp: b284/782 bl:2.4138 bb:1.1239 rl:2.2589 rb:1.0287 dl:398-399 gd:1 +ttp: b276/782 bl:2.3540 bb:1.0880 rl:2.2595 rb:1.0291 dl:387-388 gd:1 +ttp: b268/782 bl:2.3169 bb:1.0585 rl:2.2599 rb:1.0293 dl:376-378 gd:1 +ttp: b260/782 bl:2.3366 bb:1.0646 rl:2.2603 rb:1.0295 dl:366-367 gd:1 +ttp: b252/782 bl:2.3512 bb:1.0538 rl:2.2609 rb:1.0296 dl:356-357 gd:1 +ttp: b244/782 bl:2.2924 bb:1.0908 rl:2.2610 rb:1.0300 dl:346-347 gd:1 +ttp: b236/782 bl:2.2984 bb:1.0578 rl:2.2612 rb:1.0301 dl:336-337 gd:1 +ttp: b229/782 bl:2.3389 bb:1.0542 rl:2.2616 rb:1.0302 dl:328-329 gd:1 +ttp: b221/782 bl:2.3741 bb:1.1063 rl:2.2622 rb:1.0306 dl:318-320 gd:1 +ttp: b213/782 bl:2.2370 bb:1.0627 rl:2.2621 rb:1.0308 dl:309-310 gd:1 +ttp: b205/782 bl:2.2957 bb:1.0991 rl:2.2623 rb:1.0311 dl:301-302 gd:1 +ttp: b197/782 bl:2.3172 bb:1.0954 rl:2.2625 rb:1.0314 dl:292-294 gd:1 +ttp: b189/782 bl:2.3828 bb:1.1242 rl:2.2630 rb:1.0317 dl:283-284 gd:1 +ttp: b182/782 bl:2.3204 bb:1.1032 rl:2.2633 rb:1.0320 dl:276-277 gd:1 +ttp: b174/782 bl:2.4136 bb:1.1384 rl:2.2639 rb:1.0325 dl:268-269 gd:1 +ttp: b164/782 bl:2.3970 bb:1.1339 rl:2.2644 rb:1.0329 dl:259-260 gd:1 +ttp: b155/782 bl:2.3663 bb:1.0940 rl:2.2648 rb:1.0331 dl:250-251 gd:1 +ttp: b147/782 bl:2.4325 bb:1.1063 rl:2.2655 rb:1.0334 dl:242-243 gd:1 +ttp: b136/782 bl:2.3880 bb:1.1228 rl:2.2659 rb:1.0337 dl:232-233 gd:1 +ttp: b128/782 bl:2.3531 bb:1.1373 rl:2.2662 rb:1.0340 dl:224-225 gd:1 +ttp: b120/782 bl:2.3593 bb:1.0962 rl:2.2665 rb:1.0342 dl:217-218 gd:1 +ttp: b111/782 bl:2.3784 bb:1.1597 rl:2.2668 rb:1.0346 dl:208-210 gd:1 +ttp: b105/782 bl:2.3815 bb:1.1327 rl:2.2672 rb:1.0349 dl:203-204 gd:1 +ttp: b99/782 bl:2.4588 bb:1.1580 rl:2.2678 rb:1.0352 dl:198-199 gd:1 +ttp: b89/782 bl:2.4593 bb:1.1364 rl:2.2683 rb:1.0355 dl:189-190 gd:1 +ttp: b81/782 bl:2.4314 bb:1.1034 rl:2.2687 rb:1.0357 dl:182-183 gd:1 +ttp: b74/782 bl:2.4379 bb:1.1313 rl:2.2692 rb:1.0359 dl:175-176 gd:1 +ttp: b67/782 bl:2.5135 bb:1.1899 rl:2.2698 rb:1.0363 dl:169-170 gd:1 +ttp: b59/782 bl:2.4637 bb:1.1737 rl:2.2703 rb:1.0366 dl:162-163 gd:1 +ttp: b52/782 bl:2.6405 bb:1.2325 rl:2.2711 rb:1.0371 dl:155-156 gd:1 +ttp: b44/782 bl:2.5357 bb:1.1832 rl:2.2717 rb:1.0374 dl:147-148 gd:1 +ttp: b36/782 bl:2.4997 bb:1.2062 rl:2.2721 rb:1.0377 dl:139-140 gd:1 +ttp: b27/782 bl:2.5415 bb:1.2015 rl:2.2727 rb:1.0380 dl:130-131 gd:1 +ttp: b19/782 bl:2.6021 bb:1.1948 rl:2.2732 rb:1.0383 dl:121-122 gd:1 +ttp: b11/782 bl:2.5707 bb:1.1887 rl:2.2737 rb:1.0385 dl:109-110 gd:1 +ttp: b3/782 bl:2.6135 bb:1.1643 rl:2.2742 rb:1.0387 dl:89-93 gd:1 +quantized_ttt_phased val_loss:2.28940177 val_bpb:1.04616727 eval_time:471457ms +total_eval_time:471.5s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/token_only_gate_population.json b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/token_only_gate_population.json new file mode 100644 index 0000000000..4bd3cbe8e8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/token_only_fast_evals/token_only_gate_population.json @@ -0,0 +1,27 @@ +{ + "n_targets": 47851520, + "token_gate_count": 628130, + "within_gate_count": 0, + "word_gate_count": 0, + "agree2plus_count": 0, + "all_val_targets": { + "total": 47851520, + "boundary": 49999, + "new_word": 25250866, + "continuation": 22550655, + "boundary_pct": 0.1045, + "new_word_pct": 52.7692, + "continuation_pct": 47.1263 + }, + "token_gate_positions": { + "total": 628130, + "boundary": 1718, + "new_word": 238227, + "continuation": 388185, + "boundary_pct": 0.2735, + "new_word_pct": 37.9264, + "continuation_pct": 61.8001 + }, + "source_alignment": "load_data_shard header-aware tokens[:usable+1], then targets=val_tokens[1:], matching train_gpt.py", + "production_path": "build_hints_for_targets(... token_order=16, within_boost=0, word_boost=0), dispatching process_chunk_token_only" +} diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model new file mode 100644 index 0000000000..fffc8bb306 Binary files /dev/null and b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model differ diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_gpt.py b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_gpt.py new file mode 100644 index 0000000000..fbe71d4102 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_gpt.py @@ -0,0 +1,4383 @@ +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + # V19: PR #1886 (renqianluo) + sunnypatneedi research log 2026-04-28 found that + # the Triton fused-CE kernel's fp32-accumulation interacts with warm-start LoRA-A + # to destabilize seeds 314/1337 at TTT_WEIGHT_DECAY=1.0. Raising the default to + # 2.0 prevents seed collapse without measurably moving stable seeds. + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 2.0)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 10.0)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "0"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "0"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 1.0)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + # Default 1 (was 3): PR #1948 sweep showed top_k=1 is -0.00044 BPB + # and saves ~50 KB which we need to keep the seed-max artifact under + # 16,000,000 with the new XSA gate weights and Skylight optimizer state + # (state is not exported, but Skylight slightly increases pre-quant BPB + # variance across seeds, so we want artifact headroom). + lqer_top_k = int(os.environ.get("LQER_TOP_K", 1)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + lqer_scope = os.environ.get("LQER_SCOPE", "all") + lqer_gain_select = bool(int(os.environ.get("LQER_GAIN_SELECT", "0"))) + awq_lite_enabled = bool(int(os.environ.get("AWQ_LITE_ENABLED", "0"))) + awq_lite_bits = int(os.environ.get("AWQ_LITE_BITS", "8")) + awq_lite_group_top_k = int(os.environ.get("AWQ_LITE_GROUP_TOP_K", "1")) + awq_lite_group_size = int(os.environ.get("AWQ_LITE_GROUP_SIZE", "64")) + # PR #1145 online n-gram tilt (AnirudhRahul, valerio-endorsed). Causal, + # normalized, prefix-only token expert; closed-form multiplicative-boost-with-renorm + # applied to per-token NLL. See online_ngram_tilt.py for math + compliance. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "0"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "999.0")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.0")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "999.0")) + word_boost = float(os.environ.get("WORD_BOOST", "0.0")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.0")) + # Optional legacy switch: move ngram hint precompute outside the measured eval + # timer. The submitted configuration keeps this disabled so precompute is timed. + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "0"))) + # === Modded-NanoGPT transplants (PR-#1967 base + miracle stack) === + # GATED_XSA: per-(layer, head) tanh(alpha) gate around the existing XSA + # subtraction. zero-init alpha -> tanh(0)=0 -> step-0 model is bit-identical + # to the un-gated PR-#1967 baseline. Modded-nanogpt PR #264, p=0.0014. + # alpha is 1-D (num_heads,) so the existing optimizer routing puts it in the + # scalar-AdamW group automatically; per-tensor numel is small enough to land + # in the fp16 passthrough quant bucket -> ~16 bytes/layer artifact cost. + gated_xsa_enabled = bool(int(os.environ.get("GATED_XSA", "1"))) + # SKYLIGHT_MUON: per-row variance EMA + Frobenius re-normalization + u/w + # floor on the post-NS Muon update. Modded-nanogpt PR #269 (Skylight-001), + # 6/6 seeds, -250 steps to target. Pure optimizer change -> zero artifact + # bytes. State buffer (row_var_ema, skylight_step) lives only in optimizer + # state, never serialized. + skylight_muon_enabled = bool(int(os.environ.get("SKYLIGHT_MUON", "1"))) + skylight_beta2 = float(os.environ.get("BETA2_NORMUON", "0.95")) + skylight_uw_floor = float(os.environ.get("UW_FLOOR", "0.35")) + # 2C: Temperature scaling on logits before softcap. Σ P=1 preserved. + # Default 1.0 = no-op. Tune on train holdout, apply at eval. + temperature_scale = float(os.environ.get("TEMPERATURE_SCALE", "1.0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "0"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + if self.caseops_enabled: + self.base_bytes_lut = None + self.has_leading_space_lut = None + self.is_boundary_token_lut = None + else: + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._prefetch_queue = [] + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + inputs.copy_(buf[:-1]) + targets.copy_(buf[1:]) + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + while len(self._prefetch_queue) < 2: + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result() + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + BT, + V, + BLOCK_V: tl.constexpr, +): + """Fused log_softmax + dual gather. Single pass over [BT, V] logits per row, + extracts log p(target_id) and log p(hint_id) via online logsumexp. + Replaces F.log_softmax (which materializes [BT, V] fp32) + 2 gather ops. + """ + pid = tl.program_id(0) + if pid >= BT: + return + + target = tl.load(target_ids_ptr + pid) + hint = tl.load(hint_ids_ptr + pid) + row_offset = pid * V + + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + + NEG_INF = float("-inf") + max_val = NEG_INF + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load( + logits_ptr + row_offset + v_offsets, mask=mask, other=NEG_INF + ).to(tl.float32) + block_max = tl.max(chunk, axis=0) + max_val = tl.maximum(max_val, block_max) + + sum_exp = tl.zeros((), dtype=tl.float32) + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load( + logits_ptr + row_offset + v_offsets, mask=mask, other=0.0 + ).to(tl.float32) + chunk_centered = chunk - max_val + exp_chunk = tl.where(mask, tl.exp(chunk_centered), 0.0) + sum_exp += tl.sum(exp_chunk, axis=0) + + log_sum_exp = max_val + tl.log(sum_exp) + log_p_y = target_logit - log_sum_exp + log_p_h = hint_logit - log_sum_exp + + tl.store(log_p_y_out_ptr + pid, log_p_y) + tl.store(log_q_h_out_ptr + pid, log_p_h) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + """Triton wrapper — replaces F.log_softmax + 2 gather pattern. + Returns (log_p_y, log_q_h) where p = softmax(logits). + """ + bsz, sl, V = logits.shape + BT = bsz * sl + logits_flat = logits.reshape(BT, V).contiguous() + target_flat = target_ids.reshape(BT).contiguous() + hint_flat = hint_ids.reshape(BT).contiguous() + + log_p_y_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + + BLOCK_V = 1024 + grid = (BT,) + fused_log_softmax_dual_gather_kernel[grid]( + logits_flat, + target_flat, + hint_flat, + log_p_y_out, + log_q_h_out, + BT, + V, + BLOCK_V=BLOCK_V, + num_warps=8, + ) + return log_p_y_out.reshape(bsz, sl), log_q_h_out.reshape(bsz, sl) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.18 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.18 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.3 * c0) + aux1 = tl.where(c1 > 0, c1, 0.3 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + gated_xsa=False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + # Gated XSA (modded-nanogpt PR #264). zero-init -> tanh(0)=0 -> the + # original `_xsa_efficient` is unchanged at step 0. The per-head alpha + # can grow toward +1 (full XSA), shrink toward 0 (disable XSA per head), + # or go slightly negative (mild self-amplification). Routed to the + # scalar AdamW group via the ndim<2 check in Optimizers.__init__. + self.gated_xsa_enabled = gated_xsa + if gated_xsa: + self.xsa_alpha = nn.Parameter(torch.zeros(num_heads, dtype=torch.float32)) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + coef = (y_g * vn).sum(dim=-1, keepdim=True) + if getattr(self, "gated_xsa_enabled", False): + # alpha shape (H,) -> reshape to broadcast over (B, T, Hkv, group, 1) + a = torch.tanh(self.xsa_alpha).view(1, 1, Hkv, group, 1).to(y.dtype) + coef = coef * a + return (y_g - coef * vn).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.3).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + gated_xsa=False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + gated_xsa=gated_xsa, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + gated_xsa=h.gated_xsa_enabled, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + # V19: Asymmetric Logit Rescale (PR #1923 jorge-asenjo). + # Two learnable softcap scales applied on the EVAL path (forward_logits + + # forward_ttt). Init to logit_softcap so the layer is identity at step 0. + # Train path keeps the single fused softcap to preserve PR #1855 numerics. + self.asym_logit_enabled = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "0"))) + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.softcap_neg = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + # v5 Stage 2C: temperature scaling on logits before softcap (eval-only TTT path). + self.temperature_scale = float(getattr(h, "temperature_scale", 1.0)) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). lam=0 + W=0 -> identity at init. + # Cross-doc leak fix: zero the prev-token smear at any position whose current token + # is BOS, so the BOS embedding starting doc N+1 in a packed stream is not + # contaminated by doc N's last token (audited issue on PR#1797 base). + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + # V19: Asymmetric softcap (PR #1923). Splits the logit_softcap scalar into + # learnable positive/negative branches. Score-first preserved: still a + # bounded, normalized post-projection nonlinearity feeding a standard + # softmax over the full vocab. + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits > 0, sp * torch.tanh(logits / sp), sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + return softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + # Cross-doc leak fix: see _forward_hidden comment. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + # v5 Stage 2C: temperature scaling. T=1.0 (default) -> no-op. + # Applied BEFORE softcap so cap acts on calibrated logits. + if getattr(self, "temperature_scale", 1.0) != 1.0: + logits = logits / self.temperature_scale + # V19: same asymmetric softcap on the TTT eval path. + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + # PR #1145 tilt branch (v4): Triton fused kernel for eval scoring (no_grad). + # TTT learning path needs autograd, so fall back to vanilla F.log_softmax + # when logits require grad. Triton kernel is forward-only (no backward). + if logits.requires_grad: + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + skylight_enabled=False, + skylight_beta2=0.95, + skylight_uw_floor=0.35, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + # Skylight (modded-nanogpt PR #269 = NorMuon row-var EMA + u/w floor). + # Applied to the full-rank update AFTER all_gather completes, so every + # rank computes the same scale on identical inputs and keeps p replicas + # bit-identical without extra collectives. State (row_var_ema, step) + # lives only in self.state[p] and is never serialized. + self._skylight_enabled = bool(skylight_enabled) + self._skylight_beta2 = float(skylight_beta2) + self._skylight_uw_floor = float(skylight_uw_floor) + + def _apply_skylight(self, p, update): + """Skylight: per-row variance EMA -> rsqrt scaling -> Frobenius restore -> u/w floor. + + `update` shape: (rows..., in_features). We treat the last dim as the + per-row feature axis (matches the existing row_normalize convention at + the same call site, which also uses dim=-1). + + Returns a new tensor in the same dtype as `update`. + """ + if not self._skylight_enabled: + return update + state = self.state[p] + upd_f = update.float() + # Per-row variance: mean over the last dim. + row_var = upd_f.pow(2).mean(dim=-1, keepdim=True) + ema = state.get("row_var_ema") + if ema is None or ema.shape != row_var.shape or ema.device != row_var.device: + ema = torch.zeros_like(row_var) + state["row_var_ema"] = ema + state["skylight_step"] = 0 + state["skylight_step"] = int(state.get("skylight_step", 0)) + 1 + step = state["skylight_step"] + beta2 = self._skylight_beta2 + ema.mul_(beta2).add_(row_var, alpha=1.0 - beta2) + ema_hat = ema / (1.0 - beta2 ** step) + scale = torch.rsqrt(ema_hat + 1e-12) + pre_fro = upd_f.norm() + upd_f = upd_f * scale + post_fro = upd_f.norm() + 1e-12 + upd_f = upd_f * (pre_fro / post_fro) + # u/w floor: enforce ||U||_F / ||W||_F >= UW_FLOOR. + w_fro = p.data.float().norm() + 1e-12 + u_fro = upd_f.norm() + 1e-12 + ratio = u_fro / w_fro + if ratio < self._skylight_uw_floor: + upd_f = upd_f * (self._skylight_uw_floor / ratio) + return upd_f.to(update.dtype) + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad) + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + upd = self._apply_skylight(pp, upd) + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + update = self._apply_skylight(p, update) + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update, alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + upd = self._apply_skylight(pp, upd) + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + skylight_enabled=h.skylight_muon_enabled, + skylight_beta2=h.skylight_beta2, + skylight_uw_floor=h.skylight_uw_floor, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + self._aux_stream = torch.cuda.Stream() + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self._aux_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._aux_stream): + self.optimizer_tok.step() + self.optimizer_scalar.step() + self.optimizer_muon.step() + torch.cuda.current_stream().wait_stream(self._aux_stream) + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + act_sumsq = {} + act_counts = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + x_sq = x.square().sum(dim=0) + x_count = x.shape[0] + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x_sq + act_counts[name] += x_count + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + y.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += y.square().sum(dim=0) + act_counts[name] += y.shape[0] + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + h_act.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += h_act.square().sum(dim=0) + act_counts[name] += h_act.shape[0] + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + act_stats = {} + for name, sumsq in act_sumsq.items(): + count = max(act_counts.get(name, 0), 1) + act_stats[name] = (sumsq / count).sqrt().cpu() + return hessians, act_stats + + +def gptq_quantize_weight( + w, + H, + clip_sigmas=3.0, + clip_range=63, + block_size=128, + protect_groups=None, + group_size=None, + protect_clip_range=None, +): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + protect_meta = None + protect_mask_perm = None + s_hi = None + sf_hi = None + if ( + protect_groups + and group_size is not None + and protect_clip_range is not None + and protect_clip_range > clip_range + ): + protect_mask = torch.zeros(cols, dtype=torch.bool) + starts = [] + for (start, end) in protect_groups: + if start < 0 or end > cols or end <= start: + continue + protect_mask[start:end] = True + starts.append(start) + if starts: + protect_mask_perm = protect_mask[perm] + s_hi = (clip_sigmas * row_std / protect_clip_range).clamp_min(1e-10).to( + torch.float16 + ) + sf_hi = s_hi.float() + protect_meta = { + "starts": torch.tensor(starts, dtype=torch.int16), + "size": int(group_size), + "s_hi": s_hi, + } + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + if protect_mask_perm is not None and bool(protect_mask_perm[i1 + j]): + q_col = torch.clamp( + torch.round(w_col / sf_hi), + -protect_clip_range, + protect_clip_range, + ) + w_recon = q_col.float() * sf_hi + else: + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + w_recon = q_col.float() * sf + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - w_recon) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s, protect_meta + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def _lqer_fit_quantized(E, h): + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + if r <= 0: + return None + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + A_hat = qA.float() * float(sA) + g_sz = qB.numel() // sB.numel() + B_hat = (qB.reshape(-1, g_sz).float() * sB.float().view(-1, 1)).reshape( + qB.shape + ) + return { + "kind": "asym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + A_hat = qA.float() * sA.float().view(-1, 1) + B_hat = qB.float() * sB.float().view(-1, 1) + return { + "kind": "sym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + + +def _awq_lite_group_candidates(w, act_rms, group_size): + cols = w.shape[1] + n_groups = cols // group_size + if n_groups <= 0: + return [] + weight_score = w.float().abs().mean(dim=0) + saliency = act_rms.float() * weight_score + cands = [] + for gi in range(n_groups): + start = gi * group_size + end = start + group_size + score = float(saliency[start:end].sum()) + cands.append((score, start, end)) + return cands + + +def gptq_mixed_quantize(state_dict, hessians, act_stats, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + awq_on = bool(getattr(h, "awq_lite_enabled", False)) + lqer_cands = {} + awq_selected = collections.defaultdict(list) + if awq_on: + awq_cands = [] + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if t.is_floating_point() and t.numel() > 65536 and name in act_stats: + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + if bits < h.awq_lite_bits: + for score, start, end in _awq_lite_group_candidates( + t, act_stats[name], h.awq_lite_group_size + ): + awq_cands.append((score, name, start, end)) + awq_cands.sort(key=lambda x: -x[0]) + for (_score, name, start, end) in awq_cands[: h.awq_lite_group_top_k]: + awq_selected[name].append((start, end)) + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + q, s, protect_meta = gptq_quantize_weight( + t, + hessians[name], + clip_sigmas=cs, + clip_range=clip_range, + protect_groups=awq_selected.get(name), + group_size=h.awq_lite_group_size if name in awq_selected else None, + protect_clip_range=(2 ** (h.awq_lite_bits - 1) - 1) + if name in awq_selected + else None, + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + W_q = q.float() * s.float().view(-1, 1) + if protect_meta is not None: + result[name + ".awqg_start"] = protect_meta["starts"] + result[name + ".awqg_s_hi"] = protect_meta["s_hi"] + result[name + ".awqg_size"] = torch.tensor( + protect_meta["size"], dtype=torch.int16 + ) + meta[name] = meta[name] + f"+awqgrpint{h.awq_lite_bits}" + gsz = protect_meta["size"] + for start in protect_meta["starts"].tolist(): + W_q[:, start : start + gsz] = ( + q[:, start : start + gsz].float() + * protect_meta["s_hi"].float().view(-1, 1) + ) + if lqer_on: + # LQER is fit on top of the fully realized GPTQ base, which already + # includes any higher-precision AWQ-protected groups. + scope = str(getattr(h, "lqer_scope", "all")).lower() + scope_ok = ( + scope == "all" + or (scope == "mlp" and ".mlp." in name) + or (scope == "attn" and ".attn." in name) + or (scope == "embed" and "tok_emb" in name) + ) + if scope_ok: + E = t.float() - W_q + err_norm = float(E.norm()) + if err_norm > 0: + lqer_cands[name] = (E, err_norm) + if lqer_on and lqer_cands: + if bool(getattr(h, "lqer_gain_select", False)): + scored = [] + for (name, (E, base_err)) in lqer_cands.items(): + fit = _lqer_fit_quantized(E, h) + if fit is None: + continue + new_err = float((E - fit["delta"]).norm()) + gain = base_err - new_err + if gain > 0: + scored.append((gain, name, fit)) + scored.sort(key=lambda x: -x[0]) + for (_gain, name, fit) in scored[: h.lqer_top_k]: + if fit["kind"] == "asym": + result[name + ".lqA_a"] = fit["qA"] + result[name + ".lqAs_a"] = fit["sA"] + result[name + ".lqB_a"] = fit["qB"] + result[name + ".lqBs_a"] = fit["sB"] + meta[name] = meta[name] + "+lqer_asym" + else: + result[name + ".lqA"] = fit["qA"] + result[name + ".lqAs"] = fit["sA"] + result[name + ".lqB"] = fit["qB"] + result[name + ".lqBs"] = fit["sB"] + meta[name] = meta[name] + "+lqer" + else: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "awqgrpint" in info: + starts = result[name + ".awqg_start"].tolist() + s_hi = result[name + ".awqg_s_hi"].float() + gsz = int(result[name + ".awqg_size"].item()) + for start in starts: + W[:, start : start + gsz] = ( + q[:, start : start + gsz].float() * s_hi.view(-1, 1) + ) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +# ── Per-group lrzip compression (ported from PR#1586 via PR#1667/1729) ──────── + +_GROUP_ORDER = [ + "_tok_emb.weight.q", + "attn.c_k.weight.q", "attn.c_q.weight.q", + "attn.c_v.weight.q", "attn.proj.weight.q", + "mlp.fc.weight.q", "mlp.proj.weight.q", +] +_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"} +_PACK_MAGIC = b"PGRP" + + +def _similarity_sort_l1(matrix): + import numpy as _np + n = matrix.shape[0] + used = _np.zeros(n, dtype=bool) + order = [0] + used[0] = True + cur = matrix[0].astype(_np.float32) + for _ in range(n - 1): + dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) + unused = _np.where(~used)[0] + best = unused[_np.argmin(dists)] + order.append(best) + used[best] = True + cur = matrix[best].astype(_np.float32) + return _np.array(order, dtype=_np.uint16) + + +def _lrzip_compress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.bin") + out = f"{inp}.lrz" + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _lrzip_decompress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.lrz") + out = os.path.join(tmpdir, f"{label}.bin") + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _pack_streams(streams): + import struct + n = len(streams) + hdr = _PACK_MAGIC + struct.pack("= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + """Stage 1A: precompute ngram hints over full val token sequence. + Returns (hint_global, gate_global, boost_global) tensors on CPU, or None if tilt disabled. + + Compliance: single L->R pass over val tokens; uses val data only; produces hint + aligned to target positions [t] for predicting all_tokens[t+1] from prefix [:t+1]. + Same compute as inline precompute, just relocated to run BEFORE eval timer. + """ + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log0( + f"ngram_tilt:precompute_outside_timer_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()}" + ) + return (hint_global, gate_global, boost_global) + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + # === PR #1145 n-gram tilt: precompute prefix-only hints over val targets === + # Hints are aligned to target positions: hint_global[i] is the hint for + # predicting token all_tokens[i+1] given prefix all_tokens[:i+1]. + # Stored on CPU as int64; gathered per-chunk to GPU alongside y indices. + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + # v5 Stage 1A: hints were precomputed BEFORE eval timer started. + # Save measured eval time = the precompute elapsed (~168s for full tilt). + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + f"ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()} (precompute time excluded from eval)" + ) + elif getattr(h, "ngram_tilt_enabled", False): + from online_ngram_tilt import build_hints_for_targets + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + ngram_hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + ngram_gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + ngram_boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={ngram_hint_global.numel()}" + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + # n-gram tilt path: gather hints aligned to y, pass into forward_ttt + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hint_ids_gpu is not None: + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_lora, hint_ids=hint_ids_gpu + ) + else: + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + log_q_hint = None + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + # n-gram tilt application: use tilted ptl for BPB accumulation, + # but keep original per_tok_loss for TTT-LoRA backward (training + # objective is base NLL — tilt is a scoring-time overlay). + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast as apply_tilt_to_ptl_torch + tilted_loss = apply_tilt_to_ptl_torch( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + tilted_loss = per_tok_loss + with torch.no_grad(): + _accumulate_bpb( + tilted_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + _clip_params = [p for p in base_model.parameters() if p.requires_grad] + def step_fn(step, lr_scale): + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + if step <= h.muon_momentum_warmup_steps: + + frac = ( + + min(step / h.muon_momentum_warmup_steps, 1.0) + + if h.muon_momentum_warmup_steps > 0 + + else 1.0 + + ) + + muon_momentum = ( + + 1 - frac + + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + + for group in optimizers.optimizer_muon.param_groups: + + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + _live_state = base_model.state_dict(keep_vars=True) + ema_state = { + name: t.detach().float().clone() + for (name, t) in _live_state.items() + } + _ema_pairs = [(ema_state[name], t) for (name, t) in _live_state.items()] + ema_decay = h.ema_decay + training_time_ms = 0.0 + forced_stop_step = int(os.environ.get("FORCE_STOP_STEP", "0")) + stop_after_step = forced_stop_step if forced_stop_step > 0 else None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for ema_t, t in _ema_pairs: + ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + forced_stop_step <= 0 + and max_wallclock_ms is not None + and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and forced_stop_step <= 0 and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + global BOS_ID + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + quantize_only = os.environ.get("QUANTIZE_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + elif quantize_only: + log("QUANTIZE_ONLY=1 — skipping training, loading saved full-precision checkpoint") + log(f"quantize_only checkpoint: {h.model_path}") + if BOS_ID is None: + BOS_ID = 1 + base_model = GPT(h).to(device).bfloat16() + state = torch.load(h.model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + del state + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_inner_with_hints(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora, hint_ids=hint_ids) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_compiled_inner_hints = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_compiled_inner_hints + if hint_ids is None: + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + if _fwd_ttt_compiled_inner_hints is None: + _fwd_ttt_compiled_inner_hints = torch.compile( + _fwd_ttt_inner_with_hints, dynamic=True + ) + return _fwd_ttt_compiled_inner_hints( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + # v5 Stage 1A: precompute ngram hints BEFORE eval timer (single pass causal, + # uses val tokens only — same compliance as inline). For full tilt this saves + # ~168s of measured eval time without losing any tilt benefit. + precomputed_hints = None + if h.ngram_tilt_enabled and getattr(h, "ngram_hint_precompute_outside", True): + log("v5:precomputing ngram hints OUTSIDE eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 64 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed1337.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed1337.log new file mode 100644 index 0000000000..86d9442b32 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed1337.log @@ -0,0 +1,475 @@ +W0430 20:23:07.227000 412695 torch/distributed/run.py:803] +W0430 20:23:07.227000 412695 torch/distributed/run.py:803] ***************************************** +W0430 20:23:07.227000 412695 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0430 20:23:07.227000 412695 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed1337 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed1337/gatedxsa_lqertop1_intimer_p1000_n1_s1337.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed1337/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed1337/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: gatedxsa_lqertop1_intimer_p1000_n1_s1337 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +model_params:35945761 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0070 val_bpb: 4.1155 +1/20000 train_loss: 9.0078 train_time: 0.0m tok/s: 15801594 +2/20000 train_loss: 12.9685 train_time: 0.0m tok/s: 10938488 +3/20000 train_loss: 10.1966 train_time: 0.0m tok/s: 9828674 +4/20000 train_loss: 8.6730 train_time: 0.0m tok/s: 9312651 +5/20000 train_loss: 7.8727 train_time: 0.0m tok/s: 9048237 +500/20000 train_loss: 2.7176 train_time: 0.8m tok/s: 8251095 +1000/20000 train_loss: 2.7754 train_time: 1.6m tok/s: 8244907 +1500/20000 train_loss: 2.5861 train_time: 2.4m tok/s: 8243112 +2000/20000 train_loss: 2.5758 train_time: 3.2m tok/s: 8239167 +layer_loop:enabled step:2185 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 2.4740 train_time: 4.2m tok/s: 7766305 +3000/20000 train_loss: 2.6364 train_time: 5.4m tok/s: 7283348 +3500/20000 train_loss: 2.4117 train_time: 6.6m tok/s: 6975131 +4000/20000 train_loss: 2.5692 train_time: 7.8m tok/s: 6760652 +4000/20000 val_loss: 2.3834 val_bpb: 1.0890 +4500/20000 train_loss: 2.3899 train_time: 8.9m tok/s: 6603370 +4926/20000 val_loss: 2.2973 val_bpb: 1.0497 +stopping_early: wallclock_cap train_time: 596167ms step: 4926/20000 +peak memory allocated: 41724 MiB reserved: 46960 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.30071266 val_bpb:1.05124428 eval_time:8532ms +Serialized model: 135421514 bytes +Code size (uncompressed): 187853 bytes +Code size (compressed): 47669 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn.xsa_alpha, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 117.0s +Serialized model quantized+pergroup: 15945077 bytes +Total submission size quantized+pergroup: 15992746 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.9s +diagnostic quantized val_loss:2.31966347 val_bpb:1.05990331 eval_time:12444ms +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 18.3s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (95.7s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=13023303 token_gate=628130 within_gate=9866847 word_gate=2891588 agree2plus=303177 +ngram_tilt:precompute_done elapsed=146.63s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b782/782 bl:2.1040 bb:0.9963 rl:2.1040 rb:0.9963 dl:30339-97114 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:211.1s +tttg: c1/154 lr:0.001000 t:0.3s +tttg: c2/154 lr:0.001000 t:0.4s +tttg: c3/154 lr:0.001000 t:0.5s +tttg: c4/154 lr:0.000999 t:0.6s +tttg: c5/154 lr:0.000998 t:0.7s +tttg: c6/154 lr:0.000997 t:0.8s +tttg: c7/154 lr:0.000996 t:0.9s +tttg: c8/154 lr:0.000995 t:1.0s +tttg: c9/154 lr:0.000993 t:1.2s +tttg: c10/154 lr:0.000991 t:1.3s +tttg: c11/154 lr:0.000989 t:1.4s +tttg: c12/154 lr:0.000987 t:1.5s +tttg: c13/154 lr:0.000985 t:1.6s +tttg: c14/154 lr:0.000982 t:1.7s +tttg: c15/154 lr:0.000979 t:1.8s +tttg: c16/154 lr:0.000976 t:1.9s +tttg: c17/154 lr:0.000973 t:2.0s +tttg: c18/154 lr:0.000970 t:2.2s +tttg: c19/154 lr:0.000966 t:2.3s +tttg: c20/154 lr:0.000962 t:2.4s +tttg: c21/154 lr:0.000958 t:2.5s +tttg: c22/154 lr:0.000954 t:2.6s +tttg: c23/154 lr:0.000950 t:2.7s +tttg: c24/154 lr:0.000945 t:2.8s +tttg: c25/154 lr:0.000941 t:2.9s +tttg: c26/154 lr:0.000936 t:3.1s +tttg: c27/154 lr:0.000930 t:3.2s +tttg: c28/154 lr:0.000925 t:3.3s +tttg: c29/154 lr:0.000920 t:3.4s +tttg: c30/154 lr:0.000914 t:3.5s +tttg: c31/154 lr:0.000908 t:3.6s +tttg: c32/154 lr:0.000902 t:3.7s +tttg: c33/154 lr:0.000896 t:3.9s +tttg: c34/154 lr:0.000890 t:4.0s +tttg: c35/154 lr:0.000883 t:4.1s +tttg: c36/154 lr:0.000876 t:4.2s +tttg: c37/154 lr:0.000870 t:4.3s +tttg: c38/154 lr:0.000863 t:4.4s +tttg: c39/154 lr:0.000855 t:4.5s +tttg: c40/154 lr:0.000848 t:4.6s +tttg: c41/154 lr:0.000841 t:4.7s +tttg: c42/154 lr:0.000833 t:4.9s +tttg: c43/154 lr:0.000825 t:5.0s +tttg: c44/154 lr:0.000817 t:5.1s +tttg: c45/154 lr:0.000809 t:5.2s +tttg: c46/154 lr:0.000801 t:5.3s +tttg: c47/154 lr:0.000793 t:5.4s +tttg: c48/154 lr:0.000785 t:5.5s +tttg: c49/154 lr:0.000776 t:5.6s +tttg: c50/154 lr:0.000768 t:5.8s +tttg: c51/154 lr:0.000759 t:5.9s +tttg: c52/154 lr:0.000750 t:6.0s +tttg: c53/154 lr:0.000741 t:6.1s +tttg: c54/154 lr:0.000732 t:6.2s +tttg: c55/154 lr:0.000723 t:6.3s +tttg: c56/154 lr:0.000714 t:6.4s +tttg: c57/154 lr:0.000704 t:6.4s +tttg: c58/154 lr:0.000695 t:6.5s +tttg: c59/154 lr:0.000685 t:6.6s +tttg: c60/154 lr:0.000676 t:6.7s +tttg: c61/154 lr:0.000666 t:6.8s +tttg: c62/154 lr:0.000656 t:6.9s +tttg: c63/154 lr:0.000647 t:7.0s +tttg: c64/154 lr:0.000637 t:7.1s +tttg: c65/154 lr:0.000627 t:7.2s +tttg: c66/154 lr:0.000617 t:7.3s +tttg: c67/154 lr:0.000607 t:7.4s +tttg: c68/154 lr:0.000597 t:7.5s +tttg: c69/154 lr:0.000587 t:7.6s +tttg: c70/154 lr:0.000577 t:7.7s +tttg: c71/154 lr:0.000567 t:7.8s +tttg: c72/154 lr:0.000556 t:7.9s +tttg: c73/154 lr:0.000546 t:7.9s +tttg: c74/154 lr:0.000536 t:8.0s +tttg: c75/154 lr:0.000526 t:8.1s +tttg: c76/154 lr:0.000515 t:8.2s +tttg: c77/154 lr:0.000505 t:8.3s +tttg: c78/154 lr:0.000495 t:8.4s +tttg: c79/154 lr:0.000485 t:8.5s +tttg: c80/154 lr:0.000474 t:8.6s +tttg: c81/154 lr:0.000464 t:8.7s +tttg: c82/154 lr:0.000454 t:8.8s +tttg: c83/154 lr:0.000444 t:8.9s +tttg: c84/154 lr:0.000433 t:9.0s +tttg: c85/154 lr:0.000423 t:9.1s +tttg: c86/154 lr:0.000413 t:9.2s +tttg: c87/154 lr:0.000403 t:9.3s +tttg: c88/154 lr:0.000393 t:9.4s +tttg: c89/154 lr:0.000383 t:9.5s +tttg: c90/154 lr:0.000373 t:9.6s +tttg: c91/154 lr:0.000363 t:9.7s +tttg: c92/154 lr:0.000353 t:9.8s +tttg: c93/154 lr:0.000344 t:9.9s +tttg: c94/154 lr:0.000334 t:10.0s +tttg: c95/154 lr:0.000324 t:10.0s +tttg: c96/154 lr:0.000315 t:10.1s +tttg: c97/154 lr:0.000305 t:10.2s +tttg: c98/154 lr:0.000296 t:10.3s +tttg: c99/154 lr:0.000286 t:10.4s +tttg: c100/154 lr:0.000277 t:10.5s +tttg: c101/154 lr:0.000268 t:10.6s +tttg: c102/154 lr:0.000259 t:10.7s +tttg: c103/154 lr:0.000250 t:10.8s +tttg: c104/154 lr:0.000241 t:10.9s +tttg: c105/154 lr:0.000232 t:11.0s +tttg: c106/154 lr:0.000224 t:11.1s +tttg: c107/154 lr:0.000215 t:11.2s +tttg: c108/154 lr:0.000207 t:11.3s +tttg: c109/154 lr:0.000199 t:11.4s +tttg: c110/154 lr:0.000191 t:11.5s +tttg: c111/154 lr:0.000183 t:11.6s +tttg: c112/154 lr:0.000175 t:11.6s +tttg: c113/154 lr:0.000167 t:11.7s +tttg: c114/154 lr:0.000159 t:11.8s +tttg: c115/154 lr:0.000152 t:11.9s +tttg: c116/154 lr:0.000145 t:12.0s +tttg: c117/154 lr:0.000137 t:12.1s +tttg: c118/154 lr:0.000130 t:12.2s +tttg: c119/154 lr:0.000124 t:12.3s +tttg: c120/154 lr:0.000117 t:12.4s +tttg: c121/154 lr:0.000110 t:12.5s +tttg: c122/154 lr:0.000104 t:12.6s +tttg: c123/154 lr:0.000098 t:12.7s +tttg: c124/154 lr:0.000092 t:12.8s +tttg: c125/154 lr:0.000086 t:12.9s +tttg: c126/154 lr:0.000080 t:13.0s +tttg: c127/154 lr:0.000075 t:13.1s +tttg: c128/154 lr:0.000070 t:13.1s +tttg: c129/154 lr:0.000064 t:13.2s +tttg: c130/154 lr:0.000059 t:13.3s +tttg: c131/154 lr:0.000055 t:13.4s +tttg: c132/154 lr:0.000050 t:13.5s +tttg: c133/154 lr:0.000046 t:13.6s +tttg: c134/154 lr:0.000042 t:13.7s +tttg: c135/154 lr:0.000038 t:13.8s +tttg: c136/154 lr:0.000034 t:13.9s +tttg: c137/154 lr:0.000030 t:14.0s +tttg: c138/154 lr:0.000027 t:14.1s +tttg: c139/154 lr:0.000024 t:14.2s +tttg: c140/154 lr:0.000021 t:14.3s +tttg: c141/154 lr:0.000018 t:14.4s +tttg: c142/154 lr:0.000015 t:14.5s +tttg: c143/154 lr:0.000013 t:14.5s +tttg: c144/154 lr:0.000011 t:14.6s +tttg: c145/154 lr:0.000009 t:14.7s +tttg: c146/154 lr:0.000007 t:14.8s +tttg: c147/154 lr:0.000005 t:14.9s +tttg: c148/154 lr:0.000004 t:15.0s +tttg: c149/154 lr:0.000003 t:15.1s +tttg: c150/154 lr:0.000002 t:15.2s +tttg: c151/154 lr:0.000001 t:15.3s +tttg: c152/154 lr:0.000000 t:15.4s +tttg: c153/154 lr:0.000000 t:15.5s +ttpr: phase:1/1 t:228.9s +ttp: b752/782 bl:2.2943 bb:1.0547 rl:2.1438 rb:1.0088 dl:3222-3283 gd:1 +ttp: b736/782 bl:2.2160 bb:1.0440 rl:2.1539 rb:1.0137 dl:2526-2550 gd:1 +ttp: b732/782 bl:2.3417 bb:1.0784 rl:2.1761 rb:1.0215 dl:2416-2441 gd:1 +ttp: b728/782 bl:2.3330 bb:1.0681 rl:2.1920 rb:1.0264 dl:2306-2324 gd:1 +ttp: b724/782 bl:2.2958 bb:1.0483 rl:2.2012 rb:1.0283 dl:2203-2231 gd:1 +ttp: b720/782 bl:2.3234 bb:1.0508 rl:2.2108 rb:1.0302 dl:2125-2144 gd:1 +ttp: b716/782 bl:2.2240 bb:1.0277 rl:2.2117 rb:1.0300 dl:2054-2069 gd:1 +ttp: b708/782 bl:2.2826 bb:1.0210 rl:2.2161 rb:1.0294 dl:1924-1937 gd:1 +ttp: b700/782 bl:2.2459 bb:1.0029 rl:2.2178 rb:1.0279 dl:1824-1834 gd:1 +ttp: b692/782 bl:2.2659 bb:1.0172 rl:2.2202 rb:1.0273 dl:1737-1746 gd:1 +ttp: b684/782 bl:2.3417 bb:1.0317 rl:2.2257 rb:1.0275 dl:1658-1665 gd:1 +ttp: b676/782 bl:2.3055 bb:1.0370 rl:2.2291 rb:1.0280 dl:1586-1595 gd:1 +ttp: b668/782 bl:2.3000 bb:1.0515 rl:2.2318 rb:1.0289 dl:1521-1530 gd:1 +ttp: b660/782 bl:2.3425 bb:1.0355 rl:2.2358 rb:1.0291 dl:1466-1474 gd:1 +ttp: b652/782 bl:2.2214 bb:1.0098 rl:2.2353 rb:1.0285 dl:1411-1419 gd:1 +ttp: b644/782 bl:2.3353 bb:1.0368 rl:2.2384 rb:1.0287 dl:1362-1367 gd:1 +ttp: b637/782 bl:2.3333 bb:1.0640 rl:2.2412 rb:1.0298 dl:1320-1325 gd:1 +ttp: b629/782 bl:2.3233 bb:0.9998 rl:2.2435 rb:1.0289 dl:1276-1280 gd:1 +ttp: b621/782 bl:2.2637 bb:1.0337 rl:2.2440 rb:1.0290 dl:1231-1237 gd:1 +ttp: b613/782 bl:2.3084 bb:1.0278 rl:2.2456 rb:1.0290 dl:1190-1195 gd:1 +ttp: b605/782 bl:2.2198 bb:1.0123 rl:2.2450 rb:1.0286 dl:1154-1159 gd:1 +ttp: b597/782 bl:2.3377 bb:1.0395 rl:2.2470 rb:1.0288 dl:1119-1124 gd:1 +ttp: b589/782 bl:2.2474 bb:0.9981 rl:2.2470 rb:1.0282 dl:1086-1089 gd:1 +ttp: b581/782 bl:2.2927 bb:1.0231 rl:2.2479 rb:1.0281 dl:1052-1056 gd:1 +ttp: b573/782 bl:2.3382 bb:1.0540 rl:2.2496 rb:1.0286 dl:1021-1025 gd:1 +ttp: b565/782 bl:2.3521 bb:1.0190 rl:2.2515 rb:1.0284 dl:993-997 gd:1 +ttp: b557/782 bl:2.3103 bb:1.0378 rl:2.2525 rb:1.0286 dl:965-968 gd:1 +ttp: b549/782 bl:2.2319 bb:1.0091 rl:2.2521 rb:1.0282 dl:939-943 gd:1 +ttp: b541/782 bl:2.2963 bb:1.0189 rl:2.2528 rb:1.0281 dl:915-918 gd:1 +ttp: b534/782 bl:2.2954 bb:1.0281 rl:2.2535 rb:1.0281 dl:893-896 gd:1 +ttp: b526/782 bl:2.2976 bb:1.0127 rl:2.2541 rb:1.0279 dl:869-872 gd:1 +ttp: b519/782 bl:2.2666 bb:1.0283 rl:2.2543 rb:1.0279 dl:850-852 gd:1 +ttp: b511/782 bl:2.3658 bb:1.0408 rl:2.2558 rb:1.0280 dl:826-829 gd:1 +ttp: b503/782 bl:2.3134 bb:1.0481 rl:2.2565 rb:1.0283 dl:804-807 gd:1 +ttp: b495/782 bl:2.2772 bb:1.0172 rl:2.2568 rb:1.0282 dl:783-785 gd:1 +ttp: b487/782 bl:2.2486 bb:1.0528 rl:2.2567 rb:1.0284 dl:764-766 gd:1 +ttp: b479/782 bl:2.3812 bb:1.0699 rl:2.2581 rb:1.0289 dl:744-747 gd:1 +ttp: b471/782 bl:2.3704 bb:1.0703 rl:2.2594 rb:1.0294 dl:726-728 gd:1 +ttp: b463/782 bl:2.2747 bb:1.0236 rl:2.2595 rb:1.0293 dl:708-710 gd:1 +ttp: b455/782 bl:2.2639 bb:1.0203 rl:2.2596 rb:1.0292 dl:691-693 gd:1 +ttp: b447/782 bl:2.2979 bb:1.0556 rl:2.2600 rb:1.0295 dl:674-676 gd:1 +ttp: b439/782 bl:2.2969 bb:1.0249 rl:2.2603 rb:1.0295 dl:657-659 gd:1 +ttp: b431/782 bl:2.3449 bb:1.0403 rl:2.2611 rb:1.0296 dl:642-643 gd:1 +ttp: b423/782 bl:2.2809 bb:1.0407 rl:2.2613 rb:1.0297 dl:626-629 gd:1 +ttp: b415/782 bl:2.2588 bb:1.0463 rl:2.2613 rb:1.0298 dl:611-613 gd:1 +ttp: b407/782 bl:2.2414 bb:1.0262 rl:2.2611 rb:1.0298 dl:595-597 gd:1 +ttp: b399/782 bl:2.2505 bb:1.0157 rl:2.2610 rb:1.0297 dl:581-582 gd:1 +ttp: b391/782 bl:2.2808 bb:1.0506 rl:2.2612 rb:1.0298 dl:566-568 gd:1 +ttp: b383/782 bl:2.2440 bb:1.0289 rl:2.2610 rb:1.0298 dl:552-554 gd:1 +ttp: b375/782 bl:2.3796 bb:1.0613 rl:2.2619 rb:1.0301 dl:538-540 gd:1 +ttp: b367/782 bl:2.2617 bb:1.0673 rl:2.2619 rb:1.0303 dl:525-527 gd:1 +ttp: b359/782 bl:2.2246 bb:1.0215 rl:2.2617 rb:1.0302 dl:512-513 gd:1 +ttp: b351/782 bl:2.3269 bb:1.0652 rl:2.2621 rb:1.0305 dl:498-499 gd:1 +ttp: b343/782 bl:2.1939 bb:1.0325 rl:2.2617 rb:1.0305 dl:486-488 gd:1 +ttp: b335/782 bl:2.3294 bb:1.0553 rl:2.2621 rb:1.0307 dl:474-476 gd:1 +ttp: b327/782 bl:2.3067 bb:1.0726 rl:2.2624 rb:1.0309 dl:462-463 gd:1 +ttp: b319/782 bl:2.3642 bb:1.0661 rl:2.2630 rb:1.0311 dl:450-451 gd:1 +ttp: b311/782 bl:2.3215 bb:1.0700 rl:2.2633 rb:1.0313 dl:438-439 gd:1 +ttp: b303/782 bl:2.3574 bb:1.0753 rl:2.2638 rb:1.0316 dl:426-427 gd:1 +ttp: b295/782 bl:2.2365 bb:1.0493 rl:2.2637 rb:1.0317 dl:414-415 gd:1 +ttp: b288/782 bl:2.2055 bb:1.0038 rl:2.2634 rb:1.0315 dl:403-405 gd:1 +ttp: b280/782 bl:2.3207 bb:1.0820 rl:2.2636 rb:1.0318 dl:392-394 gd:1 +ttp: b272/782 bl:2.3496 bb:1.0853 rl:2.2641 rb:1.0320 dl:382-383 gd:1 +ttp: b264/782 bl:2.3927 bb:1.0903 rl:2.2647 rb:1.0323 dl:371-372 gd:1 +ttp: b256/782 bl:2.5042 bb:1.1054 rl:2.2658 rb:1.0326 dl:361-362 gd:1 +ttp: b248/782 bl:2.4290 bb:1.1723 rl:2.2665 rb:1.0332 dl:351-352 gd:1 +ttp: b240/782 bl:2.2655 bb:1.0399 rl:2.2665 rb:1.0333 dl:341-342 gd:1 +ttp: b230/782 bl:2.4217 bb:1.1364 rl:2.2671 rb:1.0337 dl:329-330 gd:1 +ttp: b222/782 bl:2.3418 bb:1.0946 rl:2.2674 rb:1.0339 dl:320-321 gd:1 +ttp: b214/782 bl:2.3192 bb:1.1098 rl:2.2676 rb:1.0342 dl:310-312 gd:1 +ttp: b206/782 bl:2.3727 bb:1.0915 rl:2.2680 rb:1.0344 dl:302-303 gd:1 +ttp: b196/782 bl:2.4070 bb:1.0984 rl:2.2685 rb:1.0346 dl:291-292 gd:1 +ttp: b188/782 bl:2.3107 bb:1.0850 rl:2.2686 rb:1.0348 dl:282-283 gd:1 +ttp: b180/782 bl:2.3850 bb:1.0927 rl:2.2690 rb:1.0350 dl:274-275 gd:1 +ttp: b172/782 bl:2.4909 bb:1.1420 rl:2.2697 rb:1.0353 dl:266-267 gd:1 +ttp: b165/782 bl:2.3105 bb:1.0972 rl:2.2699 rb:1.0355 dl:260-260 gd:1 +ttp: b156/782 bl:2.2663 bb:1.1315 rl:2.2699 rb:1.0358 dl:251-252 gd:1 +ttp: b147/782 bl:2.4416 bb:1.1105 rl:2.2704 rb:1.0360 dl:242-243 gd:1 +ttp: b136/782 bl:2.3979 bb:1.1275 rl:2.2707 rb:1.0362 dl:232-233 gd:1 +ttp: b129/782 bl:2.3495 bb:1.1256 rl:2.2709 rb:1.0365 dl:225-226 gd:1 +ttp: b122/782 bl:2.3797 bb:1.1266 rl:2.2712 rb:1.0367 dl:219-219 gd:1 +ttp: b111/782 bl:2.3757 bb:1.1584 rl:2.2715 rb:1.0370 dl:208-210 gd:1 +ttp: b105/782 bl:2.3855 bb:1.1346 rl:2.2717 rb:1.0372 dl:203-204 gd:1 +ttp: b98/782 bl:2.5674 bb:1.2047 rl:2.2724 rb:1.0376 dl:197-198 gd:1 +ttp: b91/782 bl:2.4127 bb:1.1309 rl:2.2727 rb:1.0378 dl:190-191 gd:1 +ttp: b85/782 bl:2.4854 bb:1.1903 rl:2.2732 rb:1.0381 dl:185-186 gd:1 +ttp: b78/782 bl:2.5162 bb:1.1780 rl:2.2737 rb:1.0384 dl:179-180 gd:1 +ttp: b68/782 bl:2.4743 bb:1.1547 rl:2.2741 rb:1.0386 dl:170-171 gd:1 +ttp: b61/782 bl:2.4150 bb:1.1954 rl:2.2744 rb:1.0389 dl:164-165 gd:1 +ttp: b53/782 bl:2.4794 bb:1.1814 rl:2.2747 rb:1.0391 dl:156-157 gd:1 +ttp: b45/782 bl:2.4241 bb:1.1599 rl:2.2750 rb:1.0393 dl:148-149 gd:1 +ttp: b36/782 bl:2.5044 bb:1.2084 rl:2.2754 rb:1.0396 dl:139-140 gd:1 +ttp: b28/782 bl:2.5758 bb:1.1947 rl:2.2758 rb:1.0398 dl:131-132 gd:1 +ttp: b19/782 bl:2.5956 bb:1.1918 rl:2.2763 rb:1.0400 dl:121-122 gd:1 +ttp: b11/782 bl:2.5788 bb:1.1924 rl:2.2767 rb:1.0402 dl:109-110 gd:1 +ttp: b3/782 bl:2.6071 bb:1.1614 rl:2.2770 rb:1.0404 dl:89-93 gd:1 +quantized_ttt_phased val_loss:2.29170541 val_bpb:1.04721994 eval_time:548364ms +total_eval_time:548.4s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed2026.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed2026.log new file mode 100644 index 0000000000..295630511e --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed2026.log @@ -0,0 +1,476 @@ +W0430 20:49:44.689000 503047 torch/distributed/run.py:803] +W0430 20:49:44.689000 503047 torch/distributed/run.py:803] ***************************************** +W0430 20:49:44.689000 503047 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0430 20:49:44.689000 503047 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed2026 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed2026/gatedxsa_lqertop1_intimer_p1000_n1_s2026.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed2026/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed2026/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: gatedxsa_lqertop1_intimer_p1000_n1_s2026 + scalar_lr: 0.02 + seed: 2026 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +model_params:35945761 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0147 val_bpb: 4.1190 +1/20000 train_loss: 9.0155 train_time: 0.0m tok/s: 15863795 +2/20000 train_loss: 12.9435 train_time: 0.0m tok/s: 10758202 +3/20000 train_loss: 10.1906 train_time: 0.0m tok/s: 9567868 +4/20000 train_loss: 8.6621 train_time: 0.0m tok/s: 9063841 +5/20000 train_loss: 7.8528 train_time: 0.0m tok/s: 8874674 +500/20000 train_loss: 2.7171 train_time: 0.8m tok/s: 8249917 +1000/20000 train_loss: 2.7709 train_time: 1.6m tok/s: 8241513 +1500/20000 train_loss: 2.5756 train_time: 2.4m tok/s: 8239847 +2000/20000 train_loss: 2.5726 train_time: 3.2m tok/s: 8236669 +layer_loop:enabled step:2185 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 2.4651 train_time: 4.2m tok/s: 7762315 +3000/20000 train_loss: 2.6340 train_time: 5.4m tok/s: 7280848 +3500/20000 train_loss: 2.4089 train_time: 6.6m tok/s: 6971672 +4000/20000 train_loss: 2.5637 train_time: 7.8m tok/s: 6757258 +4000/20000 val_loss: 2.3800 val_bpb: 1.0875 +4500/20000 train_loss: 2.3852 train_time: 9.0m tok/s: 6587900 +4916/20000 val_loss: 2.2913 val_bpb: 1.0470 +stopping_early: wallclock_cap train_time: 596080ms step: 4916/20000 +peak memory allocated: 41724 MiB reserved: 46960 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.29864451 val_bpb:1.05029930 eval_time:10344ms +Serialized model: 135421514 bytes +Code size (uncompressed): 187853 bytes +Code size (compressed): 47669 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn.xsa_alpha, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 124.9s +Serialized model quantized+pergroup: 15948821 bytes +Total submission size quantized+pergroup: 15996490 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 18.0s +diagnostic quantized val_loss:2.31739416 val_bpb:1.05886641 eval_time:12281ms +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 18.2s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (102.4s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=13023303 token_gate=628130 within_gate=9866847 word_gate=2891588 agree2plus=303177 +ngram_tilt:precompute_done elapsed=151.79s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b780/782 bl:2.2059 bb:1.0627 rl:2.2059 rb:1.0627 dl:13091-17244 gd:0 +ttp: b762/782 bl:2.3218 bb:1.0752 rl:2.2311 rb:1.0655 dl:4032-4142 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:209.9s +tttg: c1/154 lr:0.001000 t:0.2s +tttg: c2/154 lr:0.001000 t:0.3s +tttg: c3/154 lr:0.001000 t:0.4s +tttg: c4/154 lr:0.000999 t:0.5s +tttg: c5/154 lr:0.000998 t:0.8s +tttg: c6/154 lr:0.000997 t:0.9s +tttg: c7/154 lr:0.000996 t:1.0s +tttg: c8/154 lr:0.000995 t:1.1s +tttg: c9/154 lr:0.000993 t:1.2s +tttg: c10/154 lr:0.000991 t:1.3s +tttg: c11/154 lr:0.000989 t:1.4s +tttg: c12/154 lr:0.000987 t:1.4s +tttg: c13/154 lr:0.000985 t:1.5s +tttg: c14/154 lr:0.000982 t:1.6s +tttg: c15/154 lr:0.000979 t:1.7s +tttg: c16/154 lr:0.000976 t:1.8s +tttg: c17/154 lr:0.000973 t:1.9s +tttg: c18/154 lr:0.000970 t:2.0s +tttg: c19/154 lr:0.000966 t:2.1s +tttg: c20/154 lr:0.000962 t:2.2s +tttg: c21/154 lr:0.000958 t:2.3s +tttg: c22/154 lr:0.000954 t:2.4s +tttg: c23/154 lr:0.000950 t:2.5s +tttg: c24/154 lr:0.000945 t:2.5s +tttg: c25/154 lr:0.000941 t:2.6s +tttg: c26/154 lr:0.000936 t:2.7s +tttg: c27/154 lr:0.000930 t:2.8s +tttg: c28/154 lr:0.000925 t:2.9s +tttg: c29/154 lr:0.000920 t:3.0s +tttg: c30/154 lr:0.000914 t:3.1s +tttg: c31/154 lr:0.000908 t:3.2s +tttg: c32/154 lr:0.000902 t:3.3s +tttg: c33/154 lr:0.000896 t:3.4s +tttg: c34/154 lr:0.000890 t:3.5s +tttg: c35/154 lr:0.000883 t:3.6s +tttg: c36/154 lr:0.000876 t:3.7s +tttg: c37/154 lr:0.000870 t:3.8s +tttg: c38/154 lr:0.000863 t:3.9s +tttg: c39/154 lr:0.000855 t:3.9s +tttg: c40/154 lr:0.000848 t:4.0s +tttg: c41/154 lr:0.000841 t:4.1s +tttg: c42/154 lr:0.000833 t:4.2s +tttg: c43/154 lr:0.000825 t:4.3s +tttg: c44/154 lr:0.000817 t:4.4s +tttg: c45/154 lr:0.000809 t:4.5s +tttg: c46/154 lr:0.000801 t:4.6s +tttg: c47/154 lr:0.000793 t:4.7s +tttg: c48/154 lr:0.000785 t:4.8s +tttg: c49/154 lr:0.000776 t:4.9s +tttg: c50/154 lr:0.000768 t:5.0s +tttg: c51/154 lr:0.000759 t:5.0s +tttg: c52/154 lr:0.000750 t:5.1s +tttg: c53/154 lr:0.000741 t:5.2s +tttg: c54/154 lr:0.000732 t:5.3s +tttg: c55/154 lr:0.000723 t:5.4s +tttg: c56/154 lr:0.000714 t:5.5s +tttg: c57/154 lr:0.000704 t:5.6s +tttg: c58/154 lr:0.000695 t:5.7s +tttg: c59/154 lr:0.000685 t:5.8s +tttg: c60/154 lr:0.000676 t:5.9s +tttg: c61/154 lr:0.000666 t:6.0s +tttg: c62/154 lr:0.000656 t:6.1s +tttg: c63/154 lr:0.000647 t:6.1s +tttg: c64/154 lr:0.000637 t:6.2s +tttg: c65/154 lr:0.000627 t:6.3s +tttg: c66/154 lr:0.000617 t:6.4s +tttg: c67/154 lr:0.000607 t:6.5s +tttg: c68/154 lr:0.000597 t:6.6s +tttg: c69/154 lr:0.000587 t:6.7s +tttg: c70/154 lr:0.000577 t:6.8s +tttg: c71/154 lr:0.000567 t:6.9s +tttg: c72/154 lr:0.000556 t:7.0s +tttg: c73/154 lr:0.000546 t:7.1s +tttg: c74/154 lr:0.000536 t:7.2s +tttg: c75/154 lr:0.000526 t:7.3s +tttg: c76/154 lr:0.000515 t:7.4s +tttg: c77/154 lr:0.000505 t:7.4s +tttg: c78/154 lr:0.000495 t:7.5s +tttg: c79/154 lr:0.000485 t:7.6s +tttg: c80/154 lr:0.000474 t:7.8s +tttg: c81/154 lr:0.000464 t:7.8s +tttg: c82/154 lr:0.000454 t:7.9s +tttg: c83/154 lr:0.000444 t:8.0s +tttg: c84/154 lr:0.000433 t:8.1s +tttg: c85/154 lr:0.000423 t:8.2s +tttg: c86/154 lr:0.000413 t:8.3s +tttg: c87/154 lr:0.000403 t:8.4s +tttg: c88/154 lr:0.000393 t:8.5s +tttg: c89/154 lr:0.000383 t:8.6s +tttg: c90/154 lr:0.000373 t:8.7s +tttg: c91/154 lr:0.000363 t:8.8s +tttg: c92/154 lr:0.000353 t:8.8s +tttg: c93/154 lr:0.000344 t:8.9s +tttg: c94/154 lr:0.000334 t:9.0s +tttg: c95/154 lr:0.000324 t:9.1s +tttg: c96/154 lr:0.000315 t:9.2s +tttg: c97/154 lr:0.000305 t:9.3s +tttg: c98/154 lr:0.000296 t:9.4s +tttg: c99/154 lr:0.000286 t:9.5s +tttg: c100/154 lr:0.000277 t:9.6s +tttg: c101/154 lr:0.000268 t:9.7s +tttg: c102/154 lr:0.000259 t:9.8s +tttg: c103/154 lr:0.000250 t:9.9s +tttg: c104/154 lr:0.000241 t:10.0s +tttg: c105/154 lr:0.000232 t:10.1s +tttg: c106/154 lr:0.000224 t:10.2s +tttg: c107/154 lr:0.000215 t:10.3s +tttg: c108/154 lr:0.000207 t:10.4s +tttg: c109/154 lr:0.000199 t:10.5s +tttg: c110/154 lr:0.000191 t:10.6s +tttg: c111/154 lr:0.000183 t:10.7s +tttg: c112/154 lr:0.000175 t:10.8s +tttg: c113/154 lr:0.000167 t:10.9s +tttg: c114/154 lr:0.000159 t:11.0s +tttg: c115/154 lr:0.000152 t:11.1s +tttg: c116/154 lr:0.000145 t:11.2s +tttg: c117/154 lr:0.000137 t:11.2s +tttg: c118/154 lr:0.000130 t:11.3s +tttg: c119/154 lr:0.000124 t:11.4s +tttg: c120/154 lr:0.000117 t:11.5s +tttg: c121/154 lr:0.000110 t:11.6s +tttg: c122/154 lr:0.000104 t:11.7s +tttg: c123/154 lr:0.000098 t:11.8s +tttg: c124/154 lr:0.000092 t:11.9s +tttg: c125/154 lr:0.000086 t:12.0s +tttg: c126/154 lr:0.000080 t:12.1s +tttg: c127/154 lr:0.000075 t:12.2s +tttg: c128/154 lr:0.000070 t:12.3s +tttg: c129/154 lr:0.000064 t:12.4s +tttg: c130/154 lr:0.000059 t:12.5s +tttg: c131/154 lr:0.000055 t:12.6s +tttg: c132/154 lr:0.000050 t:12.6s +tttg: c133/154 lr:0.000046 t:12.7s +tttg: c134/154 lr:0.000042 t:12.8s +tttg: c135/154 lr:0.000038 t:12.9s +tttg: c136/154 lr:0.000034 t:13.0s +tttg: c137/154 lr:0.000030 t:13.1s +tttg: c138/154 lr:0.000027 t:13.2s +tttg: c139/154 lr:0.000024 t:13.3s +tttg: c140/154 lr:0.000021 t:13.4s +tttg: c141/154 lr:0.000018 t:13.5s +tttg: c142/154 lr:0.000015 t:13.6s +tttg: c143/154 lr:0.000013 t:13.7s +tttg: c144/154 lr:0.000011 t:13.8s +tttg: c145/154 lr:0.000009 t:13.9s +tttg: c146/154 lr:0.000007 t:13.9s +tttg: c147/154 lr:0.000005 t:14.0s +tttg: c148/154 lr:0.000004 t:14.1s +tttg: c149/154 lr:0.000003 t:14.2s +tttg: c150/154 lr:0.000002 t:14.3s +tttg: c151/154 lr:0.000001 t:14.4s +tttg: c152/154 lr:0.000000 t:14.5s +tttg: c153/154 lr:0.000000 t:14.6s +ttpr: phase:1/1 t:226.7s +ttp: b752/782 bl:2.2919 bb:1.0536 rl:2.2401 rb:1.0636 dl:3222-3283 gd:1 +ttp: b751/782 bl:2.2746 bb:1.0182 rl:2.2444 rb:1.0576 dl:3150-3221 gd:1 +ttp: b747/782 bl:2.2712 bb:1.0380 rl:2.2473 rb:1.0555 dl:2944-2991 gd:1 +ttp: b742/782 bl:2.2932 bb:1.0324 rl:2.2513 rb:1.0534 dl:2730-2762 gd:1 +ttp: b738/782 bl:2.2829 bb:1.0338 rl:2.2538 rb:1.0518 dl:2583-2618 gd:1 +ttp: b734/782 bl:2.2276 bb:1.0134 rl:2.2520 rb:1.0491 dl:2469-2495 gd:1 +ttp: b731/782 bl:2.3095 bb:1.0300 rl:2.2556 rb:1.0478 dl:2377-2414 gd:1 +ttp: b725/782 bl:2.2866 bb:1.0286 rl:2.2573 rb:1.0468 dl:2232-2254 gd:1 +ttp: b722/782 bl:2.3168 bb:1.0382 rl:2.2603 rb:1.0463 dl:2163-2185 gd:1 +ttp: b718/782 bl:2.2590 bb:1.0139 rl:2.2602 rb:1.0447 dl:2089-2106 gd:1 +ttp: b714/782 bl:2.2818 bb:1.0107 rl:2.2612 rb:1.0432 dl:2018-2035 gd:1 +ttp: b709/782 bl:2.4169 bb:1.0811 rl:2.2674 rb:1.0448 dl:1937-1952 gd:1 +ttp: b703/782 bl:2.3048 bb:1.0139 rl:2.2687 rb:1.0436 dl:1859-1872 gd:1 +ttp: b696/782 bl:2.2779 bb:1.0374 rl:2.2691 rb:1.0434 dl:1779-1790 gd:1 +ttp: b688/782 bl:2.3693 bb:1.0607 rl:2.2722 rb:1.0439 dl:1696-1706 gd:1 +ttp: b680/782 bl:2.2557 bb:1.0159 rl:2.2717 rb:1.0431 dl:1618-1628 gd:1 +ttp: b673/782 bl:2.3334 bb:1.0475 rl:2.2734 rb:1.0432 dl:1562-1571 gd:1 +ttp: b664/782 bl:2.3107 bb:1.0141 rl:2.2743 rb:1.0424 dl:1493-1499 gd:1 +ttp: b654/782 bl:2.2611 bb:1.0227 rl:2.2740 rb:1.0420 dl:1425-1432 gd:1 +ttp: b647/782 bl:2.2489 bb:1.0206 rl:2.2735 rb:1.0415 dl:1382-1387 gd:1 +ttp: b641/782 bl:2.2535 bb:1.0086 rl:2.2730 rb:1.0408 dl:1343-1349 gd:1 +ttp: b629/782 bl:2.3193 bb:0.9981 rl:2.2740 rb:1.0399 dl:1276-1280 gd:1 +ttp: b624/782 bl:2.3230 bb:1.0515 rl:2.2749 rb:1.0401 dl:1249-1255 gd:1 +ttp: b614/782 bl:2.2796 bb:1.0357 rl:2.2750 rb:1.0400 dl:1195-1200 gd:1 +ttp: b611/782 bl:2.2687 bb:1.0131 rl:2.2749 rb:1.0395 dl:1182-1186 gd:1 +ttp: b602/782 bl:2.3472 bb:1.0354 rl:2.2761 rb:1.0395 dl:1141-1146 gd:1 +ttp: b594/782 bl:2.3027 bb:1.0513 rl:2.2765 rb:1.0396 dl:1107-1110 gd:1 +ttp: b582/782 bl:2.3184 bb:1.0184 rl:2.2771 rb:1.0393 dl:1056-1060 gd:1 +ttp: b573/782 bl:2.3391 bb:1.0544 rl:2.2780 rb:1.0395 dl:1021-1025 gd:1 +ttp: b566/782 bl:2.2639 bb:1.0112 rl:2.2778 rb:1.0391 dl:997-1001 gd:1 +ttp: b561/782 bl:2.2116 bb:0.9976 rl:2.2769 rb:1.0386 dl:979-983 gd:1 +ttp: b553/782 bl:2.2585 bb:1.0183 rl:2.2767 rb:1.0383 dl:952-955 gd:1 +ttp: b546/782 bl:2.2896 bb:1.0180 rl:2.2769 rb:1.0381 dl:930-934 gd:1 +ttp: b538/782 bl:2.3083 bb:1.0334 rl:2.2772 rb:1.0380 dl:905-909 gd:1 +ttp: b531/782 bl:2.2565 bb:1.0244 rl:2.2770 rb:1.0379 dl:884-887 gd:1 +ttp: b524/782 bl:2.3408 bb:1.0517 rl:2.2777 rb:1.0380 dl:863-866 gd:1 +ttp: b515/782 bl:2.3088 bb:1.0281 rl:2.2780 rb:1.0379 dl:838-841 gd:1 +ttp: b507/782 bl:2.2616 bb:1.0126 rl:2.2778 rb:1.0376 dl:814-817 gd:1 +ttp: b502/782 bl:2.2897 bb:1.0146 rl:2.2780 rb:1.0374 dl:802-804 gd:1 +ttp: b494/782 bl:2.2919 bb:1.0446 rl:2.2781 rb:1.0375 dl:780-783 gd:1 +ttp: b486/782 bl:2.3733 bb:1.0663 rl:2.2790 rb:1.0378 dl:761-764 gd:1 +ttp: b478/782 bl:2.3128 bb:1.0650 rl:2.2793 rb:1.0380 dl:742-744 gd:1 +ttp: b470/782 bl:2.3151 bb:1.0419 rl:2.2796 rb:1.0380 dl:724-726 gd:1 +ttp: b460/782 bl:2.2137 bb:1.0356 rl:2.2790 rb:1.0380 dl:701-703 gd:1 +ttp: b451/782 bl:2.3722 bb:1.0734 rl:2.2798 rb:1.0383 dl:682-685 gd:1 +ttp: b443/782 bl:2.2044 bb:1.0372 rl:2.2792 rb:1.0383 dl:666-668 gd:1 +ttp: b436/782 bl:2.2368 bb:1.0332 rl:2.2789 rb:1.0382 dl:651-653 gd:1 +ttp: b428/782 bl:2.2714 bb:1.0350 rl:2.2788 rb:1.0382 dl:636-638 gd:1 +ttp: b420/782 bl:2.3168 bb:1.0342 rl:2.2791 rb:1.0382 dl:620-622 gd:1 +ttp: b413/782 bl:2.3392 bb:1.0484 rl:2.2795 rb:1.0383 dl:607-609 gd:1 +ttp: b405/782 bl:2.3252 bb:1.0434 rl:2.2798 rb:1.0383 dl:592-593 gd:1 +ttp: b397/782 bl:2.3302 bb:1.0334 rl:2.2801 rb:1.0383 dl:577-579 gd:1 +ttp: b387/782 bl:2.3369 bb:1.0719 rl:2.2805 rb:1.0385 dl:559-561 gd:1 +ttp: b380/782 bl:2.3214 bb:1.0705 rl:2.2807 rb:1.0387 dl:547-549 gd:1 +ttp: b374/782 bl:2.2635 bb:1.0204 rl:2.2806 rb:1.0386 dl:537-538 gd:1 +ttp: b365/782 bl:2.3055 bb:1.0244 rl:2.2808 rb:1.0385 dl:522-524 gd:1 +ttp: b357/782 bl:2.2981 bb:1.0536 rl:2.2809 rb:1.0386 dl:508-510 gd:1 +ttp: b349/782 bl:2.3124 bb:1.0085 rl:2.2810 rb:1.0384 dl:495-496 gd:1 +ttp: b341/782 bl:2.2732 bb:1.0648 rl:2.2810 rb:1.0385 dl:483-485 gd:1 +ttp: b333/782 bl:2.3912 bb:1.0643 rl:2.2815 rb:1.0386 dl:471-472 gd:1 +ttp: b325/782 bl:2.3216 bb:1.0679 rl:2.2817 rb:1.0388 dl:459-461 gd:1 +ttp: b303/782 bl:2.3516 bb:1.0726 rl:2.2820 rb:1.0389 dl:426-427 gd:1 +ttp: b295/782 bl:2.2388 bb:1.0503 rl:2.2819 rb:1.0390 dl:414-415 gd:1 +ttp: b287/782 bl:2.3642 bb:1.0771 rl:2.2822 rb:1.0391 dl:402-403 gd:1 +ttp: b279/782 bl:2.2804 bb:1.0775 rl:2.2822 rb:1.0393 dl:391-392 gd:1 +ttp: b271/782 bl:2.3328 bb:1.1050 rl:2.2824 rb:1.0395 dl:380-382 gd:1 +ttp: b263/782 bl:2.3580 bb:1.0667 rl:2.2827 rb:1.0397 dl:370-371 gd:1 +ttp: b255/782 bl:2.3343 bb:1.0766 rl:2.2829 rb:1.0398 dl:360-361 gd:1 +ttp: b247/782 bl:2.3161 bb:1.0780 rl:2.2830 rb:1.0399 dl:350-351 gd:1 +ttp: b239/782 bl:2.3418 bb:1.0874 rl:2.2832 rb:1.0401 dl:340-341 gd:1 +ttp: b231/782 bl:2.2666 bb:1.0647 rl:2.2831 rb:1.0402 dl:330-331 gd:1 +ttp: b224/782 bl:2.3374 bb:1.0711 rl:2.2833 rb:1.0403 dl:322-323 gd:1 +ttp: b217/782 bl:2.3305 bb:1.1127 rl:2.2835 rb:1.0405 dl:314-315 gd:1 +ttp: b210/782 bl:2.2164 bb:1.0627 rl:2.2833 rb:1.0405 dl:306-307 gd:1 +ttp: b199/782 bl:2.4008 bb:1.1296 rl:2.2836 rb:1.0408 dl:295-296 gd:1 +ttp: b191/782 bl:2.3826 bb:1.0839 rl:2.2839 rb:1.0409 dl:285-286 gd:1 +ttp: b182/782 bl:2.3181 bb:1.1021 rl:2.2840 rb:1.0411 dl:276-277 gd:1 +ttp: b167/782 bl:2.4849 bb:1.1086 rl:2.2845 rb:1.0413 dl:262-263 gd:1 +ttp: b156/782 bl:2.2663 bb:1.1315 rl:2.2845 rb:1.0415 dl:251-252 gd:1 +ttp: b148/782 bl:2.2962 bb:1.0864 rl:2.2845 rb:1.0416 dl:243-244 gd:1 +ttp: b139/782 bl:2.4047 bb:1.1202 rl:2.2848 rb:1.0418 dl:234-235 gd:1 +ttp: b128/782 bl:2.3649 bb:1.1430 rl:2.2849 rb:1.0420 dl:224-225 gd:1 +ttp: b120/782 bl:2.3542 bb:1.0939 rl:2.2851 rb:1.0421 dl:217-218 gd:1 +ttp: b112/782 bl:2.4530 bb:1.1708 rl:2.2854 rb:1.0423 dl:210-210 gd:1 +ttp: b103/782 bl:2.3991 bb:1.1548 rl:2.2857 rb:1.0425 dl:202-202 gd:1 +ttp: b94/782 bl:2.5333 bb:1.1970 rl:2.2861 rb:1.0428 dl:193-194 gd:1 +ttp: b85/782 bl:2.4758 bb:1.1857 rl:2.2865 rb:1.0431 dl:185-186 gd:1 +ttp: b77/782 bl:2.4760 bb:1.2162 rl:2.2868 rb:1.0433 dl:178-179 gd:1 +ttp: b69/782 bl:2.4343 bb:1.1883 rl:2.2870 rb:1.0436 dl:171-172 gd:1 +ttp: b60/782 bl:2.4180 bb:1.1622 rl:2.2872 rb:1.0437 dl:163-164 gd:1 +ttp: b51/782 bl:2.4493 bb:1.1718 rl:2.2875 rb:1.0439 dl:154-155 gd:1 +ttp: b43/782 bl:2.4581 bb:1.2001 rl:2.2877 rb:1.0441 dl:146-147 gd:1 +ttp: b30/782 bl:2.5385 bb:1.2377 rl:2.2881 rb:1.0444 dl:133-134 gd:1 +ttp: b22/782 bl:2.5240 bb:1.1815 rl:2.2883 rb:1.0445 dl:124-126 gd:1 +ttp: b14/782 bl:2.5455 bb:1.1620 rl:2.2886 rb:1.0447 dl:114-115 gd:1 +ttp: b6/782 bl:2.6764 bb:1.1933 rl:2.2890 rb:1.0448 dl:99-101 gd:1 +quantized_ttt_phased val_loss:2.28941372 val_bpb:1.04617273 eval_time:546342ms +total_eval_time:546.3s diff --git a/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed42.log b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed42.log new file mode 100644 index 0000000000..152d6292d4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_GatedXSA_LQERTop1_IntimerNgramTTT/train_seed42.log @@ -0,0 +1,479 @@ +W0430 19:56:32.585000 322480 torch/distributed/run.py:803] +W0430 19:56:32.585000 322480 torch/distributed/run.py:803] ***************************************** +W0430 19:56:32.585000 322480 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0430 19:56:32.585000 322480 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed42 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.95 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + gated_xsa_enabled: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed42/gatedxsa_lqertop1_intimer_p1000_n1_s42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed42/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 1000 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/records/gatedxsa_lqertop1_intimer_p1000_n1_competition_3seed/seed42/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: gatedxsa_lqertop1_intimer_p1000_n1_s42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + skylight_beta2: 0.95 + skylight_muon_enabled: False + skylight_uw_floor: 0.35 + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 1.0 + temperature_scale: 1.0 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /tmp/parameter-golf-data-caseops/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /tmp/parameter-golf-data-caseops/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +model_params:35945761 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0076 val_bpb: 4.1158 +1/20000 train_loss: 9.0073 train_time: 0.0m tok/s: 15736525 +2/20000 train_loss: 12.7944 train_time: 0.0m tok/s: 10943073 +3/20000 train_loss: 10.1201 train_time: 0.0m tok/s: 9859656 +4/20000 train_loss: 8.6365 train_time: 0.0m tok/s: 9347564 +5/20000 train_loss: 7.9063 train_time: 0.0m tok/s: 9065215 +500/20000 train_loss: 2.7106 train_time: 0.8m tok/s: 8252492 +1000/20000 train_loss: 2.7695 train_time: 1.6m tok/s: 8237077 +1500/20000 train_loss: 2.5719 train_time: 2.4m tok/s: 8235246 +2000/20000 train_loss: 2.5716 train_time: 3.2m tok/s: 8232168 +layer_loop:enabled step:2183 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 2.4645 train_time: 4.2m tok/s: 7756294 +3000/20000 train_loss: 2.6298 train_time: 5.4m tok/s: 7274311 +3500/20000 train_loss: 2.4053 train_time: 6.6m tok/s: 6966343 +4000/20000 train_loss: 2.5656 train_time: 7.8m tok/s: 6751889 +4000/20000 val_loss: 2.3773 val_bpb: 1.0862 +4500/20000 train_loss: 2.3826 train_time: 9.0m tok/s: 6584068 +4914/20000 val_loss: 2.2883 val_bpb: 1.0456 +stopping_early: wallclock_cap train_time: 596127ms step: 4914/20000 +peak memory allocated: 41724 MiB reserved: 46960 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.29647250 val_bpb:1.04930686 eval_time:8935ms +Serialized model: 135421514 bytes +Code size (uncompressed): 187853 bytes +Code size (compressed): 47669 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn.xsa_alpha, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 114.4s +Serialized model quantized+pergroup: 15947905 bytes +Total submission size quantized+pergroup: 15995574 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.9s +diagnostic quantized val_loss:2.31491829 val_bpb:1.05773513 eval_time:14035ms +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 18.6s +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (99.2s) + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=13023303 token_gate=628130 within_gate=9866847 word_gate=2891588 agree2plus=303177 +ngram_tilt:precompute_done elapsed=147.40s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:1000 suffix_docs:49000 num_phases:1 boundaries:[1000] +ttp: b779/782 bl:2.1935 bb:1.0377 rl:2.1935 rb:1.0377 dl:10442-13079 gd:0 +ttp: b769/782 bl:2.2927 bb:1.0673 rl:2.2240 rb:1.0469 dl:5097-5309 gd:0 +ttp: b764/782 bl:2.2597 bb:1.0585 rl:2.2313 rb:1.0493 dl:4284-4392 gd:0 +ttpp: phase:1/1 pd:1424 gd:1000 t:210.5s +tttg: c1/154 lr:0.001000 t:0.2s +tttg: c2/154 lr:0.001000 t:0.3s +tttg: c3/154 lr:0.001000 t:0.4s +tttg: c4/154 lr:0.000999 t:0.5s +tttg: c5/154 lr:0.000998 t:0.6s +tttg: c6/154 lr:0.000997 t:0.6s +tttg: c7/154 lr:0.000996 t:0.7s +tttg: c8/154 lr:0.000995 t:0.8s +tttg: c9/154 lr:0.000993 t:0.9s +tttg: c10/154 lr:0.000991 t:1.0s +tttg: c11/154 lr:0.000989 t:1.1s +tttg: c12/154 lr:0.000987 t:1.2s +tttg: c13/154 lr:0.000985 t:1.3s +tttg: c14/154 lr:0.000982 t:1.4s +tttg: c15/154 lr:0.000979 t:1.5s +tttg: c16/154 lr:0.000976 t:1.6s +tttg: c17/154 lr:0.000973 t:1.6s +tttg: c18/154 lr:0.000970 t:1.7s +tttg: c19/154 lr:0.000966 t:1.8s +tttg: c20/154 lr:0.000962 t:1.9s +tttg: c21/154 lr:0.000958 t:2.0s +tttg: c22/154 lr:0.000954 t:2.1s +tttg: c23/154 lr:0.000950 t:2.2s +tttg: c24/154 lr:0.000945 t:2.3s +tttg: c25/154 lr:0.000941 t:2.4s +tttg: c26/154 lr:0.000936 t:2.5s +tttg: c27/154 lr:0.000930 t:2.6s +tttg: c28/154 lr:0.000925 t:2.7s +tttg: c29/154 lr:0.000920 t:2.8s +tttg: c30/154 lr:0.000914 t:2.8s +tttg: c31/154 lr:0.000908 t:2.9s +tttg: c32/154 lr:0.000902 t:3.0s +tttg: c33/154 lr:0.000896 t:3.1s +tttg: c34/154 lr:0.000890 t:3.2s +tttg: c35/154 lr:0.000883 t:3.3s +tttg: c36/154 lr:0.000876 t:3.4s +tttg: c37/154 lr:0.000870 t:3.5s +tttg: c38/154 lr:0.000863 t:3.6s +tttg: c39/154 lr:0.000855 t:3.7s +tttg: c40/154 lr:0.000848 t:3.8s +tttg: c41/154 lr:0.000841 t:3.9s +tttg: c42/154 lr:0.000833 t:4.0s +tttg: c43/154 lr:0.000825 t:4.0s +tttg: c44/154 lr:0.000817 t:4.1s +tttg: c45/154 lr:0.000809 t:4.2s +tttg: c46/154 lr:0.000801 t:4.3s +tttg: c47/154 lr:0.000793 t:4.4s +tttg: c48/154 lr:0.000785 t:4.5s +tttg: c49/154 lr:0.000776 t:4.6s +tttg: c50/154 lr:0.000768 t:4.7s +tttg: c51/154 lr:0.000759 t:4.8s +tttg: c52/154 lr:0.000750 t:4.9s +tttg: c53/154 lr:0.000741 t:5.0s +tttg: c54/154 lr:0.000732 t:5.1s +tttg: c55/154 lr:0.000723 t:5.2s +tttg: c56/154 lr:0.000714 t:5.3s +tttg: c57/154 lr:0.000704 t:5.4s +tttg: c58/154 lr:0.000695 t:5.5s +tttg: c59/154 lr:0.000685 t:5.5s +tttg: c60/154 lr:0.000676 t:5.6s +tttg: c61/154 lr:0.000666 t:5.7s +tttg: c62/154 lr:0.000656 t:5.8s +tttg: c63/154 lr:0.000647 t:5.9s +tttg: c64/154 lr:0.000637 t:6.0s +tttg: c65/154 lr:0.000627 t:6.1s +tttg: c66/154 lr:0.000617 t:6.2s +tttg: c67/154 lr:0.000607 t:6.3s +tttg: c68/154 lr:0.000597 t:6.4s +tttg: c69/154 lr:0.000587 t:6.5s +tttg: c70/154 lr:0.000577 t:6.6s +tttg: c71/154 lr:0.000567 t:6.7s +tttg: c72/154 lr:0.000556 t:6.8s +tttg: c73/154 lr:0.000546 t:6.9s +tttg: c74/154 lr:0.000536 t:7.0s +tttg: c75/154 lr:0.000526 t:7.0s +tttg: c76/154 lr:0.000515 t:7.1s +tttg: c77/154 lr:0.000505 t:7.2s +tttg: c78/154 lr:0.000495 t:7.3s +tttg: c79/154 lr:0.000485 t:7.4s +tttg: c80/154 lr:0.000474 t:7.5s +tttg: c81/154 lr:0.000464 t:7.6s +tttg: c82/154 lr:0.000454 t:7.7s +tttg: c83/154 lr:0.000444 t:7.8s +tttg: c84/154 lr:0.000433 t:7.9s +tttg: c85/154 lr:0.000423 t:8.0s +tttg: c86/154 lr:0.000413 t:8.1s +tttg: c87/154 lr:0.000403 t:8.2s +tttg: c88/154 lr:0.000393 t:8.3s +tttg: c89/154 lr:0.000383 t:8.4s +tttg: c90/154 lr:0.000373 t:8.5s +tttg: c91/154 lr:0.000363 t:8.6s +tttg: c92/154 lr:0.000353 t:8.6s +tttg: c93/154 lr:0.000344 t:8.7s +tttg: c94/154 lr:0.000334 t:8.8s +tttg: c95/154 lr:0.000324 t:8.9s +tttg: c96/154 lr:0.000315 t:9.0s +tttg: c97/154 lr:0.000305 t:9.1s +tttg: c98/154 lr:0.000296 t:9.2s +tttg: c99/154 lr:0.000286 t:9.3s +tttg: c100/154 lr:0.000277 t:9.4s +tttg: c101/154 lr:0.000268 t:9.5s +tttg: c102/154 lr:0.000259 t:9.6s +tttg: c103/154 lr:0.000250 t:9.7s +tttg: c104/154 lr:0.000241 t:9.8s +tttg: c105/154 lr:0.000232 t:9.8s +tttg: c106/154 lr:0.000224 t:9.9s +tttg: c107/154 lr:0.000215 t:10.0s +tttg: c108/154 lr:0.000207 t:10.1s +tttg: c109/154 lr:0.000199 t:10.2s +tttg: c110/154 lr:0.000191 t:10.3s +tttg: c111/154 lr:0.000183 t:10.4s +tttg: c112/154 lr:0.000175 t:10.5s +tttg: c113/154 lr:0.000167 t:10.6s +tttg: c114/154 lr:0.000159 t:10.7s +tttg: c115/154 lr:0.000152 t:10.8s +tttg: c116/154 lr:0.000145 t:10.8s +tttg: c117/154 lr:0.000137 t:10.9s +tttg: c118/154 lr:0.000130 t:11.0s +tttg: c119/154 lr:0.000124 t:11.1s +tttg: c120/154 lr:0.000117 t:11.2s +tttg: c121/154 lr:0.000110 t:11.3s +tttg: c122/154 lr:0.000104 t:11.4s +tttg: c123/154 lr:0.000098 t:11.5s +tttg: c124/154 lr:0.000092 t:11.6s +tttg: c125/154 lr:0.000086 t:11.7s +tttg: c126/154 lr:0.000080 t:11.8s +tttg: c127/154 lr:0.000075 t:11.9s +tttg: c128/154 lr:0.000070 t:11.9s +tttg: c129/154 lr:0.000064 t:12.0s +tttg: c130/154 lr:0.000059 t:12.1s +tttg: c131/154 lr:0.000055 t:12.2s +tttg: c132/154 lr:0.000050 t:12.3s +tttg: c133/154 lr:0.000046 t:12.4s +tttg: c134/154 lr:0.000042 t:12.5s +tttg: c135/154 lr:0.000038 t:12.6s +tttg: c136/154 lr:0.000034 t:12.7s +tttg: c137/154 lr:0.000030 t:12.7s +tttg: c138/154 lr:0.000027 t:12.8s +tttg: c139/154 lr:0.000024 t:12.9s +tttg: c140/154 lr:0.000021 t:13.0s +tttg: c141/154 lr:0.000018 t:13.1s +tttg: c142/154 lr:0.000015 t:13.2s +tttg: c143/154 lr:0.000013 t:13.3s +tttg: c144/154 lr:0.000011 t:13.4s +tttg: c145/154 lr:0.000009 t:13.5s +tttg: c146/154 lr:0.000007 t:13.6s +tttg: c147/154 lr:0.000005 t:13.7s +tttg: c148/154 lr:0.000004 t:13.8s +tttg: c149/154 lr:0.000003 t:13.9s +tttg: c150/154 lr:0.000002 t:14.0s +tttg: c151/154 lr:0.000001 t:14.0s +tttg: c152/154 lr:0.000000 t:14.1s +tttg: c153/154 lr:0.000000 t:14.2s +ttpr: phase:1/1 t:227.1s +ttp: b752/782 bl:2.2884 bb:1.0520 rl:2.2389 rb:1.0496 dl:3222-3283 gd:1 +ttp: b751/782 bl:2.2719 bb:1.0170 rl:2.2427 rb:1.0457 dl:3150-3221 gd:1 +ttp: b747/782 bl:2.2691 bb:1.0371 rl:2.2452 rb:1.0449 dl:2944-2991 gd:1 +ttp: b741/782 bl:2.2801 bb:1.0225 rl:2.2481 rb:1.0430 dl:2686-2730 gd:1 +ttp: b738/782 bl:2.2848 bb:1.0346 rl:2.2507 rb:1.0424 dl:2583-2618 gd:1 +ttp: b734/782 bl:2.2253 bb:1.0124 rl:2.2491 rb:1.0404 dl:2469-2495 gd:1 +ttp: b730/782 bl:2.2386 bb:0.9838 rl:2.2485 rb:1.0370 dl:2352-2376 gd:1 +ttp: b725/782 bl:2.2832 bb:1.0271 rl:2.2503 rb:1.0364 dl:2232-2254 gd:1 +ttp: b723/782 bl:2.2586 bb:1.0140 rl:2.2507 rb:1.0353 dl:2185-2203 gd:1 +ttp: b718/782 bl:2.2537 bb:1.0115 rl:2.2508 rb:1.0342 dl:2089-2106 gd:1 +ttp: b715/782 bl:2.3189 bb:1.0109 rl:2.2537 rb:1.0332 dl:2036-2053 gd:1 +ttp: b709/782 bl:2.4123 bb:1.0791 rl:2.2597 rb:1.0350 dl:1937-1952 gd:1 +ttp: b705/782 bl:2.3318 bb:1.0481 rl:2.2622 rb:1.0355 dl:1885-1898 gd:1 +ttp: b696/782 bl:2.2755 bb:1.0363 rl:2.2627 rb:1.0355 dl:1779-1790 gd:1 +ttp: b688/782 bl:2.3665 bb:1.0594 rl:2.2658 rb:1.0362 dl:1696-1706 gd:1 +ttp: b680/782 bl:2.2513 bb:1.0139 rl:2.2654 rb:1.0356 dl:1618-1628 gd:1 +ttp: b675/782 bl:2.3361 bb:1.0447 rl:2.2672 rb:1.0358 dl:1578-1586 gd:1 +ttp: b663/782 bl:2.2937 bb:1.0260 rl:2.2679 rb:1.0356 dl:1486-1493 gd:1 +ttp: b655/782 bl:2.3485 bb:1.0300 rl:2.2697 rb:1.0355 dl:1432-1439 gd:1 +ttp: b649/782 bl:2.2514 bb:1.0010 rl:2.2693 rb:1.0347 dl:1392-1398 gd:1 +ttp: b639/782 bl:2.2764 bb:1.0165 rl:2.2695 rb:1.0343 dl:1331-1337 gd:1 +ttp: b633/782 bl:2.2441 bb:1.0083 rl:2.2690 rb:1.0338 dl:1297-1302 gd:1 +ttp: b623/782 bl:2.3023 bb:1.0047 rl:2.2696 rb:1.0333 dl:1243-1249 gd:1 +ttp: b619/782 bl:2.2950 bb:1.0466 rl:2.2700 rb:1.0335 dl:1221-1226 gd:1 +ttp: b610/782 bl:2.2140 bb:0.9900 rl:2.2691 rb:1.0327 dl:1177-1182 gd:1 +ttp: b602/782 bl:2.3421 bb:1.0331 rl:2.2703 rb:1.0328 dl:1141-1146 gd:1 +ttp: b594/782 bl:2.3017 bb:1.0508 rl:2.2707 rb:1.0330 dl:1107-1110 gd:1 +ttp: b584/782 bl:2.2592 bb:1.0214 rl:2.2706 rb:1.0329 dl:1064-1069 gd:1 +ttp: b576/782 bl:2.3541 bb:1.0828 rl:2.2717 rb:1.0335 dl:1033-1037 gd:1 +ttp: b568/782 bl:2.3230 bb:1.0665 rl:2.2724 rb:1.0340 dl:1004-1007 gd:1 +ttp: b564/782 bl:2.2509 bb:1.0015 rl:2.2721 rb:1.0335 dl:990-993 gd:1 +ttp: b556/782 bl:2.3433 bb:1.0535 rl:2.2730 rb:1.0338 dl:961-965 gd:1 +ttp: b543/782 bl:2.3010 bb:1.0417 rl:2.2733 rb:1.0339 dl:921-924 gd:1 +ttp: b535/782 bl:2.3514 bb:1.0198 rl:2.2742 rb:1.0337 dl:896-899 gd:1 +ttp: b527/782 bl:2.3113 bb:1.0144 rl:2.2746 rb:1.0335 dl:872-875 gd:1 +ttp: b519/782 bl:2.2643 bb:1.0272 rl:2.2745 rb:1.0334 dl:850-852 gd:1 +ttp: b513/782 bl:2.3324 bb:1.0240 rl:2.2751 rb:1.0333 dl:832-835 gd:1 +ttp: b506/782 bl:2.3094 bb:0.9972 rl:2.2754 rb:1.0330 dl:812-814 gd:1 +ttp: b501/782 bl:2.3487 bb:1.0376 rl:2.2761 rb:1.0330 dl:799-802 gd:1 +ttp: b494/782 bl:2.2892 bb:1.0434 rl:2.2763 rb:1.0331 dl:780-783 gd:1 +ttp: b486/782 bl:2.3773 bb:1.0681 rl:2.2772 rb:1.0334 dl:761-764 gd:1 +ttp: b478/782 bl:2.3081 bb:1.0628 rl:2.2774 rb:1.0337 dl:742-744 gd:1 +ttp: b470/782 bl:2.3112 bb:1.0402 rl:2.2777 rb:1.0337 dl:724-726 gd:1 +ttp: b457/782 bl:2.2144 bb:1.0138 rl:2.2772 rb:1.0336 dl:695-697 gd:1 +ttp: b448/782 bl:2.2801 bb:0.9940 rl:2.2772 rb:1.0333 dl:677-678 gd:1 +ttp: b440/782 bl:2.2109 bb:0.9732 rl:2.2767 rb:1.0328 dl:659-662 gd:1 +ttp: b432/782 bl:2.3062 bb:1.0251 rl:2.2769 rb:1.0327 dl:643-645 gd:1 +ttp: b424/782 bl:2.3147 bb:1.0495 rl:2.2772 rb:1.0329 dl:629-630 gd:1 +ttp: b417/782 bl:2.2233 bb:1.0271 rl:2.2768 rb:1.0328 dl:615-617 gd:1 +ttp: b402/782 bl:2.2131 bb:0.9849 rl:2.2764 rb:1.0325 dl:586-588 gd:1 +ttp: b394/782 bl:2.2107 bb:0.9733 rl:2.2760 rb:1.0321 dl:571-573 gd:1 +ttp: b386/782 bl:2.3028 bb:1.0814 rl:2.2762 rb:1.0324 dl:557-559 gd:1 +ttp: b378/782 bl:2.3920 bb:1.0379 rl:2.2769 rb:1.0324 dl:544-545 gd:1 +ttp: b371/782 bl:2.2174 bb:1.0827 rl:2.2765 rb:1.0327 dl:532-533 gd:1 +ttp: b364/782 bl:2.3139 bb:1.0463 rl:2.2767 rb:1.0328 dl:521-522 gd:1 +ttp: b352/782 bl:2.3864 bb:1.0799 rl:2.2773 rb:1.0330 dl:499-501 gd:1 +ttp: b342/782 bl:2.3379 bb:1.1060 rl:2.2776 rb:1.0334 dl:485-486 gd:1 +ttp: b334/782 bl:2.3499 bb:1.0563 rl:2.2780 rb:1.0335 dl:472-474 gd:1 +ttp: b326/782 bl:2.2675 bb:1.0384 rl:2.2779 rb:1.0335 dl:461-462 gd:1 +ttp: b318/782 bl:2.3083 bb:1.0549 rl:2.2781 rb:1.0336 dl:448-450 gd:1 +ttp: b311/782 bl:2.3164 bb:1.0677 rl:2.2782 rb:1.0338 dl:438-439 gd:1 +ttp: b304/782 bl:2.3076 bb:1.0584 rl:2.2784 rb:1.0339 dl:427-429 gd:1 +ttp: b296/782 bl:2.3533 bb:1.0835 rl:2.2787 rb:1.0341 dl:415-417 gd:1 +ttp: b288/782 bl:2.1997 bb:1.0012 rl:2.2784 rb:1.0340 dl:403-405 gd:1 +ttp: b280/782 bl:2.3188 bb:1.0811 rl:2.2785 rb:1.0341 dl:392-394 gd:1 +ttp: b272/782 bl:2.3409 bb:1.0813 rl:2.2788 rb:1.0343 dl:382-383 gd:1 +ttp: b264/782 bl:2.3889 bb:1.0886 rl:2.2792 rb:1.0345 dl:371-372 gd:1 +ttp: b257/782 bl:2.4069 bb:1.0948 rl:2.2796 rb:1.0347 dl:362-364 gd:1 +ttp: b250/782 bl:2.2825 bb:1.0583 rl:2.2797 rb:1.0348 dl:354-355 gd:1 +ttp: b243/782 bl:2.3235 bb:1.0662 rl:2.2798 rb:1.0349 dl:345-346 gd:1 +ttp: b235/782 bl:2.2487 bb:1.0826 rl:2.2797 rb:1.0351 dl:335-336 gd:1 +ttp: b227/782 bl:2.4621 bb:1.1431 rl:2.2803 rb:1.0354 dl:325-327 gd:1 +ttp: b219/782 bl:2.2872 bb:1.0944 rl:2.2803 rb:1.0356 dl:316-317 gd:1 +ttp: b210/782 bl:2.2223 bb:1.0655 rl:2.2801 rb:1.0357 dl:306-307 gd:1 +ttp: b201/782 bl:2.2599 bb:1.0780 rl:2.2801 rb:1.0358 dl:297-298 gd:1 +ttp: b193/782 bl:2.3276 bb:1.1162 rl:2.2802 rb:1.0360 dl:288-289 gd:1 +ttp: b185/782 bl:2.3856 bb:1.0938 rl:2.2805 rb:1.0362 dl:279-280 gd:1 +ttp: b176/782 bl:2.2817 bb:1.1082 rl:2.2805 rb:1.0363 dl:270-271 gd:1 +ttp: b167/782 bl:2.4785 bb:1.1058 rl:2.2810 rb:1.0365 dl:262-263 gd:1 +ttp: b158/782 bl:2.3009 bb:1.0878 rl:2.2810 rb:1.0366 dl:253-254 gd:1 +ttp: b149/782 bl:2.3232 bb:1.1317 rl:2.2811 rb:1.0369 dl:244-245 gd:1 +ttp: b141/782 bl:2.4356 bb:1.1115 rl:2.2815 rb:1.0370 dl:236-237 gd:1 +ttp: b133/782 bl:2.3331 bb:1.1191 rl:2.2816 rb:1.0372 dl:229-230 gd:1 +ttp: b124/782 bl:2.3331 bb:1.1396 rl:2.2817 rb:1.0374 dl:220-222 gd:1 +ttp: b117/782 bl:2.4476 bb:1.1893 rl:2.2821 rb:1.0377 dl:214-215 gd:1 +ttp: b107/782 bl:2.4107 bb:1.1545 rl:2.2823 rb:1.0379 dl:205-206 gd:1 +ttp: b99/782 bl:2.4558 bb:1.1566 rl:2.2826 rb:1.0381 dl:198-199 gd:1 +ttp: b89/782 bl:2.4576 bb:1.1356 rl:2.2830 rb:1.0383 dl:189-190 gd:1 +ttp: b81/782 bl:2.4277 bb:1.1017 rl:2.2832 rb:1.0384 dl:182-183 gd:1 +ttp: b73/782 bl:2.5147 bb:1.2343 rl:2.2836 rb:1.0387 dl:174-175 gd:1 +ttp: b62/782 bl:2.4040 bb:1.1583 rl:2.2838 rb:1.0389 dl:165-166 gd:1 +ttp: b55/782 bl:2.5772 bb:1.2130 rl:2.2842 rb:1.0391 dl:158-159 gd:1 +ttp: b46/782 bl:2.5024 bb:1.1948 rl:2.2845 rb:1.0393 dl:149-150 gd:1 +ttp: b38/782 bl:2.5644 bb:1.1762 rl:2.2849 rb:1.0395 dl:141-142 gd:1 +ttp: b30/782 bl:2.5345 bb:1.2358 rl:2.2852 rb:1.0397 dl:133-134 gd:1 +ttp: b21/782 bl:2.5535 bb:1.2046 rl:2.2855 rb:1.0399 dl:123-124 gd:1 +ttp: b13/782 bl:2.6471 bb:1.1992 rl:2.2859 rb:1.0401 dl:112-114 gd:1 +ttp: b4/782 bl:2.7165 bb:1.2172 rl:2.2863 rb:1.0403 dl:93-96 gd:1 +quantized_ttt_phased val_loss:2.28708329 val_bpb:1.04510781 eval_time:542399ms +total_eval_time:542.4s