# Record: Custom Casefold Tokenizer + Parallel Residuals + Systems Optimization

**val_bpb = 1.0639** (3-seed mean, std 0.0006) | **3.0705 nats** | **~15.98 MB** | 8xH100 SXM, 600s | Legal TTT

This submission applies systems-level performance optimizations to PR #1578's casefold tokenizer + PR #1529's parallel residual architecture. The ML is unchanged; faster per-step throughput yields extra training steps in the same 600s budget.

> **Submission series:** This PR is one of three related submissions applying the same systems optimizations to different base stacks:
>
> 1. On PR #1493 (current merged SOTA)
> 2. On PR #1529 (pending review)
> 3. On PR #1578 (pending review) -- **this PR**
>
> The optimizations are identical across all three -- fused Muon kernel, batched EMA, and loader prealloc. We submit against multiple bases so that a ready-to-merge option exists regardless of how the pending PRs are resolved. Judges should feel free to evaluate whichever base(s) they consider valid and disregard the rest.

**Note on record criteria:** Per the official contest rules: *"For submissions that improve speed through systems optimization without changing the ML, this requirement [0.005 nats] is waived."* The three changes are purely systems-level. That said, this submission also clears the 0.005-nat threshold outright (0.0083 nats vs PR #1578).

## 3-Seed Results

| Seed | Steps | ms/step | Post-EMA BPB | **TTT BPB** | Artifact (bytes) |
|------|-------|---------|--------------|-------------|------------------|
| 1337 | 4,716 | 124.5 | 1.0709 | **1.0646** | 15,985,530 |
| 2024 | 4,731 | 124.1 | 1.0697 | **1.0634** | 15,980,244 |
| 42 | 4,726 | 124.3 | 1.0701 | **1.0639** | 15,982,918 |
| **Mean** | **4,724** | **124.3** | **1.0702** | **1.0639** | **15,982,897** |
| **Std** | | | | **0.0006** | |

PR #1578 original (same seeds): **1.0668 BPB mean**. Delta: **-0.0029 BPB** / **-0.0083 nats**.

## Systems Optimizations (3 changes, training-step only)

1. **Fused Muon transform** -- Single `@torch.compile` function combining momentum update, Nesterov extrapolation, row normalization, and Newton-Schulz orthogonalization. Eliminates kernel launch overhead between sequential operations.

2. **EMA foreach** -- Replaces per-tensor EMA loop with `torch._foreach_mul_` / `torch._foreach_add_` for batched parameter averaging.

3. **Numpy prealloc loader** -- Pre-allocates a reusable numpy buffer for data loading instead of allocating a new `np.array` per sequence. All three changes are sketched after this list.
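A minimal sketch of the three changes, assuming speedrun-style Muon semantics. Function and buffer names are illustrative, not the actual `train_gpt.py` code:

```python
import numpy as np
import torch

# (1) Fused Muon transform: one compiled graph instead of separate kernels.
@torch.compile
def fused_muon_transform(grad, momentum_buf, beta=0.97, ns_steps=5):
    momentum_buf.lerp_(grad, 1 - beta)                 # momentum update
    g = grad.lerp(momentum_buf, beta)                  # Nesterov extrapolation
    g = g / (g.norm(dim=-1, keepdim=True) + 1e-7)      # row normalization
    X = g.mT if g.size(-2) > g.size(-1) else g         # NS prefers short-fat
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    a, b, c = 3.4445, -4.7750, 2.0315                  # quintic NS coefficients
    for _ in range(ns_steps):                          # Newton-Schulz orthogonalization
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X.mT if g.size(-2) > g.size(-1) else X

# (2) Batched EMA: two foreach kernels replace a per-tensor Python loop.
@torch.no_grad()
def ema_update(ema_params, params, decay=0.997):
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, params, alpha=1 - decay)

# (3) Preallocated loader buffer: one reusable array, no per-sequence np.array.
class PreallocLoader:
    def __init__(self, n_tokens: int):
        self.buf = np.empty(n_tokens, dtype=np.int64)

    def next_batch(self, tokens_mmap: np.ndarray, offset: int) -> torch.Tensor:
        # copyto casts (e.g. uint16 -> int64) into the existing buffer in place
        np.copyto(self.buf, tokens_mmap[offset:offset + self.buf.size])
        return torch.from_numpy(self.buf)  # zero-copy view of the same buffer
```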

No eval, serialization, or model architecture changes.

## What Changed vs PR #1578 (Only Systems Optimization)

The **only difference** from PR #1578 is the three systems optimizations listed above. Architecture, optimizer logic, hyperparameters, tokenizer, dataset, TTT, and quantization are all identical. The casefold v2 vocabulary and retokenized dataset are unchanged from PR #1578.

## Architecture (from PR #1529)

11L x 512d x 8H / 4KV, MLP 4x, LeakyReLU(0.5)^2, Partial RoPE (16/64), layerwise LN scale, tied embeddings, logit softcap=30.0. Depth recurrence: loops layers 3-5 (activated at frac=0.35). Dual-lane parallel residuals from physical layer 8. Fused Triton TMA MLP kernel + CUTLASS EVT backward.
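For orientation, the spec above corresponds to roughly this configuration. Field names are hypothetical; values are read off the line above:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    n_layer: int = 11                 # physical layers
    d_model: int = 512
    n_head: int = 8                   # query heads
    n_kv_head: int = 4                # GQA key/value heads
    mlp_ratio: int = 4                # MLP hidden = 4 * d_model
    rope_dims: int = 16               # partial RoPE: 16 of 64 head dims rotated
    logit_softcap: float = 30.0
    recur_layers: tuple = (3, 4, 5)   # looped once recurrence activates
    recur_start_frac: float = 0.35    # activation point as a fraction of training
    parallel_residual_start: int = 8  # dual-lane residuals from this physical layer
```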

## Tokenizer (from PR #1578)

Casefold v2 vocabulary: SP8192 retrained on NFKC + lowercased text. 374 freed case-duplicate slots refilled with BPP-optimized subwords. ~10.4% better compression. Byte counting verified correct on 15.4M FineWeb docs (0 mismatches). See `CASEFOLD_TOKENIZER.md` and `verify_bytes.py`.
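The normalization the casefold vocabulary is trained on is just NFKC followed by lowercasing, the same preprocessing `verify_bytes.py` applies:

```python
import unicodedata

def casefold_normalize(text: str) -> str:
    # NFKC folds fullwidth forms, compatibility characters, etc.; then lowercase
    return unicodedata.normalize("NFKC", text).lower()

# fullwidth H (U+FF28) and ideographic space (U+3000) fold to ASCII
assert casefold_normalize("\uFF28ello\u3000World") == "hello world"
```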

## Training

Muon optimizer (sharded reduce-scatter + all-gather, Newton-Schulz 5 steps), AdamW for embeddings/scalars. ~4,724 steps in 587s. Warmdown frac=0.72, Muon momentum=0.97, EMA decay=0.997. GPTQ reserve 13s.
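A sketch of the warmdown schedule, under the assumption (common in speedrun forks) that `warmdown_frac` is the fraction of steps spent in a final linear decay:

```python
def lr_scale(step: int, total_steps: int, warmdown_frac: float = 0.72) -> float:
    # constant LR for the first (1 - warmdown_frac) of training,
    # then linear decay to zero over the remaining steps
    decay_start = int(total_steps * (1 - warmdown_frac))
    if step < decay_start:
        return 1.0
    return max(0.0, (total_steps - step) / (total_steps - decay_start))
```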

## TTT (Test-Time Training)

Score-first chunk-based SGD: 32K-token chunks, 3 epochs per chunk, cosine LR decay (lr=0.005, momentum=0.9). Hash embedding (16384-dim bigram hash, zero-initialized, learned during TTT). Gradient clipping at 1.0.
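Condensed into code, the loop looks roughly like this, assuming an HF-style `model(chunk).loss` interface; the hash-embedding parameters are simply part of `model.parameters()` here (illustrative, not the actual TTT code):

```python
import torch

def ttt_eval(model, chunks, lr=0.005, momentum=0.9, epochs=3):
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=len(chunks) * epochs)
    nats, toks = 0.0, 0
    for chunk in chunks:                       # consecutive 32K-token chunks
        with torch.no_grad():                  # score FIRST, before any update
            nats += model(chunk).loss.item() * chunk.numel()
            toks += chunk.numel()
        for _ in range(epochs):                # then adapt on the chunk just scored
            opt.zero_grad(set_to_none=True)
            model(chunk).loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
    return nats / toks                         # mean val loss in nats/token
```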

## Compliance

- **Condition 1 (Causality):** Sliding-window eval is strictly causal.
- **Condition 2 (Normalized distribution):** Standard softmax over full vocab.
- **Condition 3 (Score before update):** Each chunk scored under `torch.no_grad()` before SGD.
- **Condition 4 (Single pass):** Each token scored exactly once.
- No SLOT, no pre-quant TTT, no ETLB, no n-gram cache.

## Reproducibility

```bash
pip install brotli sentencepiece flash_attn_3 huggingface_hub

# CUTLASS EVT build:
git clone https://github.com/NVIDIA/cutlass.git /opt/cutlass
cd /opt/cutlass && git checkout 08185b9c3e90510ee2b656662ed0d53b06d28157
cd /workspace && pip install --no-build-isolation ./cutlass_evt_fusion

# Casefold data (from HuggingFace):
python3 -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='Mikeapedia/fineweb10B-sp8192-casefold-v2', repo_type='dataset', local_dir='data/datasets/fineweb10B_sp8192_casefold_v2', allow_patterns='*.bin')"

# Training (per seed):
for SEED in 1337 2024 42; do
SEED=$SEED TTT_ENABLED=1 HASH_EMBED_ENABLED=1 TTT_LR=0.005 \
MUON_MOMENTUM=0.97 PARALLEL_RESIDUAL_START=8 GPTQ_RESERVE_SECONDS=13 \
EMA_DECAY=0.997 WARMDOWN_FRAC=0.72 \
DATASETS_DIR=./data/datasets/fineweb10B_sp8192_casefold_v2 \
TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe_casefold_refined_v2.model \
torchrun --standalone --nproc_per_node=8 train_gpt.py
done
```

## Attribution

- **PR #1578** (@mikeapedia): Casefold v2 vocabulary and retokenized dataset
- **PR #1529** (@msisovic): Dual-lane parallel residual architecture, Triton fused MLP, CUTLASS EVT
- **PR #1394** (@clarkkev): SP8192 tokenizer, GPTQ SDClip
- **PR #1413** (@dexhunter): Legal TTT framework

---
# Byte Verification: Why This Proves Our BPB Is Accurate

## The BPB Formula

```
BPB = (val_loss / ln2) × (tokens / bytes)
```

`val_loss` and `tokens` come directly from the model's forward pass — they
can't be wrong. The only variable a custom tokenizer can affect is **bytes**:
the denominator that converts token-level metrics into byte-level metrics.
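As a quick sanity check with the headline numbers (the tokens/bytes ratio below is back-solved for illustration, not measured):

```python
import math

val_loss = 3.0705          # nats per token, from the record above
tokens_per_byte = 0.2402   # illustrative ratio, ~4.16 bytes per token
bpb = (val_loss / math.log(2)) * tokens_per_byte
print(f"{bpb:.4f}")        # 1.0640 -- matches the reported 1.0639 up to rounding
```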

## How Bytes Are Counted During Evaluation

The training script (`train_gpt_human.py`, lines 982-984) counts bytes using
a **lookup table (LUT)** built from the tokenizer vocabulary:

```python
token_bytes = base_bytes_lut[tgt_ids]
token_bytes += has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]
val_byte_count += token_bytes.sum()
```

For each predicted token, the LUT returns:
- **base_bytes**: UTF-8 byte width of the piece text (with interior `▁`
replaced by ASCII space)
- **+1** if the piece has a leading `▁` and the previous token isn't a
boundary (BOS, control, etc.) — this counts the space byte

This accumulates `val_byte_count`, which becomes the `bytes` in the BPB
formula.
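A sketch of how such LUTs can be built from the vocabulary; the actual builder lives in the training script, and `is_boundary_token_lut` here simply marks control/unknown pieces:

```python
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor(
    model_file="fineweb_8192_bpe_casefold_refined_v2.model")
V = sp.get_piece_size()
base_bytes_lut = torch.zeros(V, dtype=torch.long)
has_leading_space_lut = torch.zeros(V, dtype=torch.bool)
is_boundary_token_lut = torch.zeros(V, dtype=torch.bool)
for t in range(V):
    piece = sp.id_to_piece(t)
    has_leading_space_lut[t] = piece.startswith("▁")
    is_boundary_token_lut[t] = sp.is_control(t) or sp.is_unknown(t)
    visible = piece.removeprefix("▁").replace("▁", " ")  # interior ▁ -> space
    base_bytes_lut[t] = len(visible.encode("utf-8"))     # leading ▁ counted via +1
```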

## What `verify_bytes.py` Proves

The script verifies that the LUT byte count **exactly equals** the true
byte count of the text the model sees.

For each document:

1. Apply our normalization: `NFKC(text).lower()`
2. Apply SentencePiece's internal normalization (`sp.normalize()`) — this
is the exact text SP tokenizes (whitespace collapsing, etc.)
3. Count the UTF-8 bytes of that normalized text → **ground truth**
4. Tokenize with SP, then accumulate bytes using the same LUT logic as
`eval_val()` → **LUT bytes**
5. Assert: `LUT bytes == ground truth`

If this holds on every document, the byte denominator in BPB is correct.
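Condensed into code, the per-document check is roughly as follows (the bundled `verify_bytes.py` is the authoritative implementation; helper names here are illustrative):

```python
import unicodedata

def verify_doc(sp, raw_text, base_lut, lead_lut, boundary_lut, bos_id) -> bool:
    text = unicodedata.normalize("NFKC", raw_text).lower()  # our normalization
    norm = sp.normalize(text)                 # exact text SP tokenizes
    ground_truth = len(norm.encode("utf-8"))
    lut_bytes, prev = 0, bos_id               # BOS counts as a boundary token
    for t in sp.encode(text):
        lut_bytes += int(base_lut[t])
        if lead_lut[t] and not boundary_lut[prev]:
            lut_bytes += 1                    # the space byte before this piece
        prev = t
    return lut_bytes == ground_truth
```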

## Why This Is Sufficient

The LUT and the model operate on the same token stream. When `eval_val()`
computes loss on token `t`, it also looks up `base_bytes_lut[t]` for the
byte count. If the sum of all LUT lookups equals the true byte count of
the text, then:

- No bytes are double-counted
- No bytes are missed
- The BPB formula's denominator is accurate

## What SentencePiece Normalization Means

SentencePiece applies `nmt_nfkc` normalization internally before tokenizing.
This goes beyond Python's `unicodedata.normalize("NFKC")`:

- Newlines → spaces
- Consecutive whitespace → single space
- Various Unicode normalizations

This means the model never sees the raw text — it sees SP's normalized
version. The LUT correctly counts bytes of this normalized text, not the
raw text. This is the right thing to measure: BPB should reflect the
information content the model actually predicts.

## Results

**200-doc spot check (bundled):**
```
Documents verified: 200
Ground-truth bytes: 1,489,674
LUT bytes: 1,489,674
Mismatched docs: 0 / 200
RESULT: ALL CHECKS PASSED
```

**Full 15.4M-document FineWeb corpus** (results in `verify_results.txt`):
```
Documents verified: 15,368,808
Tokens: 11,423,532,518
Ground-truth bytes: 47,707,155,846
LUT bytes: 47,707,155,846
Mismatched docs: 0 / 15,368,808
```
Zero mismatches across the entire dataset.

## Running It Yourself

Spot check (bundled sample, ~30 seconds, no GPU):
```bash
pip install sentencepiece
python verify_bytes.py --docs verify_docs.jsonl
```

Full verification (requires FineWeb corpus):
```bash
python verify_bytes.py --docs data/docs_selected.jsonl --max-docs 0
```

---

```python
from __future__ import annotations

from pathlib import Path
import torch


def _load_extension() -> None:
here = Path(__file__).resolve().parent
candidates = sorted(here.glob("cutlass_evt_fusion*.so"))
if not candidates:
raise ImportError(f"No compiled cutlass_evt_fusion extension found in {here}")
torch.ops.load_library(str(candidates[0]))


_load_extension()
```

---

```cpp
// CUTLASS 3.x EVT kernel: GEMM fused with an elementwise multiply in the epilogue.
// Computes: dpre = (go @ down_w) * act_grad
// Where act_grad = f'(pre) is pre-computed in the forward pass.
//
// Layout convention:
// go: (M, K) bf16 row-major
// down_w: (K, N) bf16 row-major -- CUTLASS B(N,K) with RowMajor layout
// act_grad: (M, N) bf16 row-major
// dpre: (M, N) bf16 row-major output

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include <iostream>

using namespace cute;

using ElementAcc = float;
using ElementCompute = float;
using ElementOutput = cutlass::bfloat16_t;
using ElementAux = cutlass::bfloat16_t;

using namespace cutlass::epilogue::fusion;

using TileShape = Shape<_128, _256, _64>;
using ClusterShape = Shape<_1, _1, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>;

using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
EpiDesc, cutlass::layout::RowMajor, ElementAux>;

using AuxLoad = Sm90AuxLoad<
AuxDesc::Stages,
typename EpiDesc::EpilogueTile,
typename AuxDesc::Element,
typename AuxDesc::Stride,
typename AuxDesc::SmemLayoutAtom,
typename AuxDesc::CopyOpS2R>;

using Compute = Sm90Compute<
cutlass::multiplies,
ElementOutput,
ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest>;

using EVT = Sm90EVT<Compute, Sm90AccFetch, AuxLoad>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTile,
ElementAcc, ElementCompute,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementOutput, cutlass::layout::RowMajor, 8,
EpilogueSchedule,
EVT
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementAcc,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

void launch_gemm_mul(
void const* ptr_go,
void const* ptr_down_w,
void const* ptr_act_grad,
void* ptr_dpre,
int M, int N, int K,
cudaStream_t stream)
{
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;

int L = 1;
auto prob_shape = make_shape(M, N, K, L);

auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
auto stride_Aux = cutlass::make_cute_packed_stride(
typename AuxDesc::Stride{}, cute::make_shape(M, N, L));

typename EVT::Arguments evt_args{
{},
{
static_cast<ElementAux const*>(ptr_act_grad),
ElementAux(0),
stride_Aux
},
{}
};

typename GemmOp::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape,
{
static_cast<ElementOutput const*>(ptr_go),
stride_A,
static_cast<ElementOutput const*>(ptr_down_w),
stride_B,
},
{
evt_args,
static_cast<ElementOutput const*>(ptr_dpre),
stride_C,
static_cast<ElementOutput*>(ptr_dpre),
stride_C,
}
};

GemmOp gemm_op;
size_t workspace_size = GemmOp::get_workspace_size(args);
void* workspace = nullptr;
if (workspace_size > 0) {
cudaMalloc(&workspace, workspace_size);
}

auto status = gemm_op.initialize(args, workspace, stream);
if (status != cutlass::Status::kSuccess) {
std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl;
if (workspace) cudaFree(workspace);
exit(EXIT_FAILURE);
}

status = gemm_op.run(stream);
if (status != cutlass::Status::kSuccess) {
cudaError_t cuda_err = cudaStreamSynchronize(stream);
std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status)
<< " CUDA: " << cudaGetErrorString(cuda_err) << std::endl;
if (workspace) cudaFree(workspace);
exit(EXIT_FAILURE);
}

if (workspace) cudaFree(workspace);
}
```