55 changes: 55 additions & 0 deletions records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/README.md
@@ -0,0 +1,55 @@
# 12L QAT Int4-MLP + Int6-Attn on PR #549 Stack

**val_bpb: TBD** (post-quant int4/int6+lzma, sliding window stride=32, legal TTT)

## Summary

Built on the #1 SOTA (1.1194, PR #549 by @abaybektursun). Adds full QAT with mixed-precision fake-quantization (int4 for MLP, int6 for attention) and uses the byte savings to fund a 12th transformer layer.

## Key Changes from SOTA (1.1194)

### 1. Quantization-Aware Training (QAT) with Mixed Precision
Full QAT via a Straight-Through Estimator (STE) applied to the banked weights during every forward pass (minimal sketch after this list):
- **MLP weights: int4** (clip_range=7) — MLPs are less precision-sensitive
- **Attention weights: int6** (clip_range=31) — attention is more sensitive
- Applied directly in MLP.forward and CausalSelfAttention.forward (not via CastedLinear)
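
A minimal sketch of the fake-quantization step, assuming per-tensor symmetric absmax scaling (the PR's actual scale granularity may differ):

```python
import torch

def fake_quantize(w: torch.Tensor, clip: int) -> torch.Tensor:
    """Fake-quantize weights with a straight-through estimator (STE).

    clip is the max integer level: 7 for int4 MLP weights, 31 for int6
    attention weights.  Per-tensor symmetric absmax scaling is an
    assumption here; the PR may use per-channel scales.
    """
    scale = w.abs().amax().clamp(min=1e-8) / clip
    w_q = (w / scale).round().clamp(-clip, clip) * scale
    # STE: forward sees the quantized weights, backward passes gradients
    # straight through to the full-precision weights.
    return w + (w_q - w).detach()

# Hypothetical call site inside MLP.forward, e.g.:
#   y = torch.nn.functional.linear(x, fake_quantize(self.w_up, clip=7))
```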

### 2. Int4 MLP Post-Training Quantization
Post-training quantization uses GPTQ-lite clip search with int4 for MLP (clip=7) and int6 for attention (clip=31). Int4 compresses ~3x with LZMA vs ~1.6x for int6, saving ~3MB.
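
A hedged sketch of such a clip search: a per-tensor MSE search over shrunken absmax scales (the grid, granularity, and error metric in the actual PR may differ):

```python
import torch

def clip_search_quantize(w: torch.Tensor, clip: int, n_grid: int = 20):
    """Pick the quantization scale that minimizes weight reconstruction MSE.

    clip=7 for int4 (MLP), clip=31 for int6 (attention).  Returns the
    integer codes and the chosen scale.
    """
    absmax = w.abs().amax().clamp(min=1e-8).item()
    best = None
    for frac in torch.linspace(0.5, 1.0, n_grid):
        scale = absmax * frac.item() / clip
        q = (w / scale).round().clamp(-clip, clip)
        err = ((q * scale - w) ** 2).mean().item()
        if best is None or err < best[0]:
            best = (err, q.to(torch.int8), scale)
    return best[1], best[2]
```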

### 3. 12th Transformer Layer
The byte savings from int4 MLP fund an extra layer (12 vs 11). Architecture: 6 encoder + 6 decoder with U-Net skip connections.
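
The skip wiring, sketched (block internals omitted; plain additive skips are an assumption, the stack may weight them):

```python
import torch

def forward_unet(x: torch.Tensor, blocks: list) -> torch.Tensor:
    # 12 blocks split 6/6: decoder block i consumes the activation of
    # encoder block (5 - i) via an additive skip connection.
    n = len(blocks) // 2
    skips = []
    for blk in blocks[:n]:       # encoder half
        x = blk(x)
        skips.append(x)
    for blk in blocks[n:]:       # decoder half
        x = x + skips.pop()      # last-in, first-out pairing
        x = blk(x)
    return x
```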

### 4. Eval Stride 32 (was 64)
Halved the sliding-window eval stride so each scored token sees more overlapping context.
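
A minimal sketch of the scoring loop (window size and first-window edge handling are assumptions):

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_bits(model, tokens: list, window: int = 2048, stride: int = 32):
    """Total bits assigned to `tokens`; divide by the raw byte count of the
    eval text to get bpb.  Each window scores only its last `stride`
    targets, so every scored token keeps >= window - stride tokens of
    preceding context."""
    total_bits = 0.0
    for end in range(window, len(tokens) + 1, stride):
        ctx = torch.tensor(tokens[end - window:end])[None]
        logits = model(ctx[:, :-1])                 # predicts positions 1..window-1
        nll = F.cross_entropy(logits[0, -stride:],  # score only the last `stride` targets
                              ctx[0, -stride:], reduction="sum")
        total_bits += nll.item() / math.log(2)      # nats -> bits
    return total_bits
```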

## Inherited Techniques (from PR #549 stack)
- LeakyReLU(0.5)^2 activation
- Legal Score-First TTT (3 epochs SGD per 32K chunk)
- Parallel Muon optimizer with Parameter Banking
- XSA on last 4 layers
- Partial RoPE (16/64 dims)
- LN Scale 1/sqrt(layer+1)
- EMA (decay=0.997)
- SmearGate + BigramHash(2048)
- Value Embedding on layers 9,10
- LZMA compression (packing sketch below)
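
The ~3x LZMA ratio on int4 quoted in section 2 comes from the tiny 16-symbol alphabet. A sketch of the pack-then-compress step (exact packing order in the PR may differ):

```python
import lzma
import numpy as np

def pack_int4_lzma(q: np.ndarray) -> bytes:
    """Pack int4 codes (values in [-7, 7] for clip=7) two per byte, then
    LZMA-compress the result for the final artifact."""
    u = (q.astype(np.int8) + 8).astype(np.uint8).ravel()  # shift to [1, 15]
    if u.size % 2:
        u = np.concatenate([u, np.zeros(1, dtype=np.uint8)])
    packed = (u[0::2] << 4) | u[1::2]                     # two nibbles per byte
    return lzma.compress(packed.tobytes(), preset=9)
```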

## Architecture
- 12 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x expansion (1536), LeakyReLU(0.5)^2
- ~29M params, estimated artifact ~15MB
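
A back-of-envelope check of the ~29M figure (norms, gates, smear/bigram tables, and value embeddings omitted, so this lands slightly under):

```python
d, vocab, n_layers = 512, 1024, 12
head_dim = d // 8                                  # 64
attn = d * d + 2 * d * (4 * head_dim) + d * d      # Wq + GQA Wk/Wv + Wo = 0.79M
mlp = 2 * d * (3 * d)                              # 3x up + down = 1.57M
total = n_layers * (attn + mlp) + vocab * d        # + token embedding
print(f"{total / 1e6:.1f}M")                       # -> 28.8M, consistent with ~29M
```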

## Run Command

```bash
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Expected Improvement
- 12th layer: est. -0.002 to -0.003 bpb (based on 10L->11L gains)
- QAT reducing quant penalty: est. -0.001 to -0.002 bpb
- Stride 32 eval: est. -0.001 bpb
- Net target: ~1.114-1.117 bpb
92 changes: 92 additions & 0 deletions records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/setup.sh
@@ -0,0 +1,92 @@
#!/bin/bash
# ---------------------------------------------------------------
# Setup script for 12L QAT Int4-MLP submission
# Run from the repo root: bash records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/setup.sh
# ---------------------------------------------------------------
set -e

REPO_ROOT="$(cd "$(dirname "$0")/../../.." && pwd)"
cd "$REPO_ROOT"

echo "============================================"
echo " 12L QAT Int4-MLP — Environment Setup"
echo "============================================"

# ---------------------------------------------------------------
# 1. Python dependencies
# ---------------------------------------------------------------
echo ""
echo "[1/3] Installing Python dependencies..."
pip install --upgrade pip -q
pip install numpy tqdm sentencepiece huggingface-hub -q

TORCH_VER=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null || echo "none")
if [[ "$TORCH_VER" != *"2.9"*"+cu128"* ]]; then
    echo "  Upgrading torch to 2.9.1+cu128 (current: $TORCH_VER)..."
    pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128 --no-cache-dir --force-reinstall -q
else
    echo "  torch $TORCH_VER already OK."
fi
echo " Done."

# ---------------------------------------------------------------
# 2. Flash Attention 3 (Hopper)
# ---------------------------------------------------------------
echo ""
echo "[2/3] Installing Flash Attention 3..."

if python3 -c "from flash_attn_interface import flash_attn_func; print('ok')" 2>/dev/null | grep -q ok; then
    echo "  Already installed and working — skipping."
else
    pip install flash_attn_3 --no-deps --force-reinstall --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
    echo "  Installed."
fi

# ---------------------------------------------------------------
# 3. Dataset + Tokenizer (sp1024)
# ---------------------------------------------------------------
echo ""
echo "[3/3] Downloading dataset (sp1024)..."
python3 data/cached_challenge_fineweb.py --variant sp1024
echo " Done."

# ---------------------------------------------------------------
# Verification
# ---------------------------------------------------------------
echo ""
echo "============================================"
echo " Verification"
echo "============================================"

python3 - << 'PYEOF'
import sys
import glob
import torch

print(f"Python       : {sys.version.split()[0]}")
print(f"PyTorch      : {torch.__version__}")
print(f"CUDA         : {torch.cuda.is_available()}")
print(f"GPUs         : {torch.cuda.device_count()}")

for i in range(torch.cuda.device_count()):
    p = torch.cuda.get_device_properties(i)
    print(f"  GPU {i}      : {p.name} ({p.total_memory // 1024**3}GB)")

try:
    from flash_attn_interface import flash_attn_func
    print("FlashAttn3   : OK")
except ImportError:
    print("FlashAttn3   : MISSING!")

train = sorted(glob.glob("./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin"))
val = sorted(glob.glob("./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin"))
print(f"Train shards : {len(train)}")
print(f"Val shards   : {len(val)}")
PYEOF

echo ""
echo "============================================"
echo " Setup complete. Run training with:"
echo ""
echo " tmux"
echo " torchrun --nproc_per_node=8 records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/train_gpt.py"
echo ""
echo "============================================"
@@ -0,0 +1,10 @@
{
"author": "Meirzhan05",
"github_id": "Meirzhan05",
"val_bpb": null,
"date": "2026-03-25",
"hardware": "8xH100 SXM",
"train_time_seconds": 600,
"artifact_bytes": null,
"summary": "12L QAT Int4-MLP + Int6-Attn on PR #549 stack: LeakyReLU² + TTT + Parallel Muon + XSA + Partial RoPE"
}