Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Record: Improved Parallel Residuals + Systems Optimization

**val_bpb = 1.0752** (3-seed mean, std 0.0006) | **2.7773 nats** | **~15.98 MB** | 8xH100 SXM, 600s | Legal TTT

This submission applies three systems-level performance optimizations to PR #1529's dual-lane parallel residual architecture. The ML is unchanged; faster per-step throughput yields ~20 extra training steps in the same 600s budget.

> **Submission series:** This PR is one of three related submissions applying the same systems optimizations to different base stacks:
>
> 1. On PR #1493 (current merged SOTA)
> 2. On PR #1529 (pending review) -- **this PR**
> 3. On PR #1578 (pending review)
>
> The optimizations are identical across all three -- fused Muon kernel, batched EMA, and loader prealloc. We submit against multiple bases so that a ready-to-merge option exists regardless of how the pending PRs are resolved. Judges should feel free to evaluate whichever base(s) they consider valid and disregard the rest.

**Note on record criteria:** This submission improves speed through systems optimization without changing the ML. Per the official contest rules: *"For submissions that improve speed through systems optimization without changing the ML, this requirement [0.005 nats] is waived."* The three changes (fused Muon kernel, batched EMA, loader prealloc) are purely systems-level and do not alter model architecture, optimizer logic, loss function, or any hyperparameter.

## 3-Seed Results

| Seed | Steps | ms/step | Post-EMA BPB | Sliding BPB | **TTT BPB** | Artifact |
|------|-------|---------|-------------|-------------|-------------|----------|
| 1337 | 4,745 | 123.8 | 1.0823 | 1.0756 | **1.0745** | 15,983,819 |
| 2024 | 4,724 | 124.3 | 1.0833 | 1.0769 | **1.0755** | 15,982,374 |
| 42 | 4,744 | 123.8 | 1.0832 | 1.0773 | **1.0755** | 15,979,637 |
| **Mean** | **4,738** | **123.9** | **1.0829** | **1.0766** | **1.0752** | **15,981,943** |
| **Std** | | | | | **0.0006** | |

PR #1529 original (same seeds): **1.0753 BPB mean**. Delta: **-0.0001 BPB** (from extra training steps).

## Systems Optimizations (3 changes, training-step only)

1. **Fused Muon transform** -- Single `@torch.compile` function combining momentum update, Nesterov extrapolation, row normalization, and Newton-Schulz orthogonalization. Eliminates kernel launch overhead between sequential operations. (+0.43% step time on 2xH100 benchmark)

2. **EMA foreach** -- Replaces per-tensor EMA loop with `torch._foreach_mul_` / `torch._foreach_add_` for batched parameter averaging. (+0.08% step time)

3. **Numpy prealloc loader** -- Pre-allocates a reusable numpy buffer for data loading instead of allocating a new `np.array` per sequence. (+0.11% step time)

No eval, serialization, or model architecture changes. The three optimizations together save ~0.5% step time, translating to ~20 extra steps over 600s.

## Architecture (from PR #1529)

11L x 512d x 8H / 4KV, MLP 4x, LeakyReLU(0.5)^2, Partial RoPE (16/64), layerwise LN scale, tied embeddings, logit softcap=30.0. Depth recurrence: loops layers 3-5 (activated at frac=0.35). Dual-lane parallel residuals from physical layer 8: attention and MLP write to both lanes with learned post-lambdas and residual-lambdas. Final output: mean of two lanes. Skip connections: lane0 only.

Fused Triton TMA MLP kernel + CUTLASS EVT backward for throughput.

## Training

Muon optimizer (sharded reduce-scatter + all-gather, Newton-Schulz 5 steps), AdamW for embeddings/scalars. ~4,738 steps in 587s. Warmdown frac=0.667, Muon momentum=0.97, EMA decay=0.9965. GPTQ reserve 13s.

## Quantization

Full-Hessian GPTQ with SDClip: int6 for attention/MLP matrices (k=12.85), int8 for token embeddings (k=20.0). Byte-shuffle + Brotli-11 compression.

## TTT (Test-Time Training)

Score-first chunk-based SGD: 32K-token chunks, 3 epochs per chunk, cosine LR decay (lr=0.01, momentum=0.9). Hash embedding (16384-dim bigram hash, zero-initialized, learned during TTT). Gradient clipping at 1.0.

## Compliance

- **Condition 1 (Causality):** Sliding-window eval is strictly causal.
- **Condition 2 (Normalized distribution):** Standard softmax over full vocab.
- **Condition 3 (Score before update):** Each chunk scored under `torch.no_grad()` before SGD.
- **Condition 4 (Single pass):** Each token scored exactly once.
- No SLOT, no pre-quant TTT, no ETLB, no n-gram cache.

## Reproducibility

```bash
pip install brotli sentencepiece flash_attn_3 huggingface_hub
# CUTLASS EVT build (required for full throughput):
git clone https://github.com/NVIDIA/cutlass.git /opt/cutlass
cd /opt/cutlass && git checkout 08185b9c3e90510ee2b656662ed0d53b06d28157
cd /workspace && pip install --no-build-isolation ./cutlass_evt_fusion

# Data:
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192

# Training (per seed):
for SEED in 1337 2024 42; do
SEED=$SEED TTT_ENABLED=1 HASH_EMBED_ENABLED=1 TTT_LR=0.01 \
MUON_MOMENTUM=0.97 PARALLEL_RESIDUAL_START=8 GPTQ_RESERVE_SECONDS=13 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
done
```

## Attribution

- **PR #1529** (@msisovic): Dual-lane parallel residual architecture, Triton fused MLP, CUTLASS EVT
- **PR #1394** (@clarkkev): SP8192 tokenizer, GPTQ SDClip, depth recurrence base
- **PR #1413** (@dexhunter): Legal TTT framework
- **PR #1445** (@X-Abhishek-X): Hyperparameter tuning
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

from pathlib import Path
import torch


def _load_extension() -> None:
here = Path(__file__).resolve().parent
candidates = sorted(here.glob("cutlass_evt_fusion*.so"))
if not candidates:
raise ImportError(f"No compiled cutlass_evt_fusion extension found in {here}")
torch.ops.load_library(str(candidates[0]))


_load_extension()
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// CUTLASS 3.x EVT kernel: fused GEMM * elementwise multiply
// Computes: dpre = (go @ down_w.T) * act_grad
// Where act_grad = f'(pre) is pre-computed in the forward pass.
//
// Layout convention:
// go: (M, K) bf16 row-major
// down_w: (K, N) bf16 row-major -- CUTLASS B(N,K) with RowMajor layout
// act_grad: (M, N) bf16 row-major
// dpre: (M, N) bf16 row-major output

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include <iostream>

using namespace cute;

using ElementAcc = float;
using ElementCompute = float;
using ElementOutput = cutlass::bfloat16_t;
using ElementAux = cutlass::bfloat16_t;

using namespace cutlass::epilogue::fusion;

using TileShape = Shape<_128, _256, _64>;
using ClusterShape = Shape<_1, _1, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>;

using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
EpiDesc, cutlass::layout::RowMajor, ElementAux>;

using AuxLoad = Sm90AuxLoad<
AuxDesc::Stages,
typename EpiDesc::EpilogueTile,
typename AuxDesc::Element,
typename AuxDesc::Stride,
typename AuxDesc::SmemLayoutAtom,
typename AuxDesc::CopyOpS2R>;

using Compute = Sm90Compute<
cutlass::multiplies,
ElementOutput,
ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest>;

using EVT = Sm90EVT<Compute, Sm90AccFetch, AuxLoad>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTile,
ElementAcc, ElementCompute,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementOutput, cutlass::layout::RowMajor, 8,
EpilogueSchedule,
EVT
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementOutput, cutlass::layout::RowMajor, 8,
ElementAcc,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

void launch_gemm_mul(
void const* ptr_go,
void const* ptr_down_w,
void const* ptr_act_grad,
void* ptr_dpre,
int M, int N, int K,
cudaStream_t stream)
{
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;

int L = 1;
auto prob_shape = make_shape(M, N, K, L);

auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
auto stride_Aux = cutlass::make_cute_packed_stride(
typename AuxDesc::Stride{}, cute::make_shape(M, N, L));

typename EVT::Arguments evt_args{
{},
{
static_cast<ElementAux const*>(ptr_act_grad),
ElementAux(0),
stride_Aux
},
{}
};

typename GemmOp::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape,
{
static_cast<ElementOutput const*>(ptr_go),
stride_A,
static_cast<ElementOutput const*>(ptr_down_w),
stride_B,
},
{
evt_args,
static_cast<ElementOutput const*>(ptr_dpre),
stride_C,
static_cast<ElementOutput*>(ptr_dpre),
stride_C,
}
};

GemmOp gemm_op;
size_t workspace_size = GemmOp::get_workspace_size(args);
void* workspace = nullptr;
if (workspace_size > 0) {
cudaMalloc(&workspace, workspace_size);
}

auto status = gemm_op.initialize(args, workspace, stream);
if (status != cutlass::Status::kSuccess) {
std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl;
if (workspace) cudaFree(workspace);
exit(EXIT_FAILURE);
}

status = gemm_op.run(stream);
if (status != cutlass::Status::kSuccess) {
cudaError_t cuda_err = cudaStreamSynchronize(stream);
std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status)
<< " CUDA: " << cudaGetErrorString(cuda_err) << std::endl;
if (workspace) cudaFree(workspace);
exit(EXIT_FAILURE);
}

if (workspace) cudaFree(workspace);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// PyTorch C++ extension: CUTLASS EVT fused GEMM * elementwise multiply
// dpre = (go @ down_w.T) * act_grad
// Pass down_w directly (K, N) -- NOT down_w.T.contiguous()

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

void launch_gemm_mul(
void const*, void const*, void const*, void*, int, int, int, cudaStream_t);

at::Tensor gemm_mul(at::Tensor go, at::Tensor down_w, at::Tensor act_grad) {
TORCH_CHECK(go.is_cuda() && go.is_contiguous());
TORCH_CHECK(down_w.is_cuda() && down_w.is_contiguous());
TORCH_CHECK(act_grad.is_cuda() && act_grad.is_contiguous());
TORCH_CHECK(go.scalar_type() == at::kBFloat16);
TORCH_CHECK(down_w.scalar_type() == at::kBFloat16);
TORCH_CHECK(act_grad.scalar_type() == at::kBFloat16);

int M = go.size(0);
int K = go.size(1);
int N = down_w.size(1);

TORCH_CHECK(down_w.size(0) == K,
"K mismatch: go has K=", K, " but down_w has size(0)=", down_w.size(0));
TORCH_CHECK(act_grad.size(0) == M && act_grad.size(1) == N,
"act_grad shape must be (M, N), got (", act_grad.size(0), ", ", act_grad.size(1), ")");

at::Tensor dpre = at::empty({M, N}, go.options());

launch_gemm_mul(
go.data_ptr(), down_w.data_ptr(), act_grad.data_ptr(), dpre.data_ptr(),
M, N, K,
at::cuda::getCurrentCUDAStream());

return dpre;
}

TORCH_LIBRARY(cutlass_evt, m) {
m.def("gemm_mul(Tensor go, Tensor down_w, Tensor act_grad) -> Tensor");
}

TORCH_LIBRARY_IMPL(cutlass_evt, CUDA, m) {
m.impl("gemm_mul", &gemm_mul);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os

CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/opt/cutlass")

setup(
name="cutlass_evt_fusion",
ext_modules=[
CUDAExtension(
name="cutlass_evt_fusion",
sources=[
"csrc/gemm_act_grad.cu",
"csrc/torch_binding.cpp",
],
include_dirs=[
f"{CUTLASS_PATH}/include",
f"{CUTLASS_PATH}/tools/util/include",
],
extra_compile_args={
"nvcc": [
"-std=c++17",
"-arch=sm_90a",
"-O3",
"--use_fast_math",
"--expt-relaxed-constexpr",
"-DNDEBUG",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
],
},
),
],
cmdclass={"build_ext": BuildExtension},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
numpy
tqdm
huggingface-hub
kernels
setuptools
typing-extensions==4.15.0
datasets
tiktoken
sentencepiece
brotli
Loading