diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/README.md b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/README.md new file mode 100644 index 0000000000..38f13498bd --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/README.md @@ -0,0 +1,826 @@ +# XNOR-Net LLM for OpenAI Parameter Golf Challenge + +**Author:** Ciprian-Florin Ifrim -- April 2026 + +A full XNOR-Net language model that binarizes both weights and activations, trained for the [OpenAI Parameter Golf Challenge](https://openai.com/parameter-golf). The challenge requires training the best possible LLM that fits within a 16MB compressed artifact, evaluated on bits-per-byte (bpb) on the FineWeb validation set. + +This work extends the Binary-Weight-Network (BWN) and ternary submissions with a true XNOR-Net implementation -- the first known application of full activation binarization to transformer language models at this scale. + +**Best results:** + +| Track | Run | Config | Roundtrip bpb | Sliding bpb | Size | +|-------|-----|--------|---------------|-------------|------| +| 10-minute (8xH100) | R40 | 1024d 10L embed=384 BF16 scales | 1.578 | -- | 15.96MB | +| 10-minute (3 seeds) | P3/P4/P5 | Same as R40 | 1.602 +/- 0.012 | 1.567 +/- 0.012 | 15.96MB | +| Notable (100k steps) | N2 | R40 + scale QAT | 1.575 | 1.539 | 15.91MB | + +--- + +## Table of Contents + +1. [Architecture Overview](#architecture-overview) +2. [Key Technical Contributions](#key-technical-contributions) +3. [Development Timeline](#development-timeline) +4. [Activation Binarization Modes](#activation-binarization-modes) +5. [Activation Function Analysis](#activation-function-analysis) +6. [Triton XNOR Kernel](#triton-xnor-kernel) +7. [Compression Pipeline](#compression-pipeline) +8. [Optimizer Exploration](#optimizer-exploration) +9. [Learning Rate and Schedule Analysis](#learning-rate-and-schedule-analysis) +10. [Sequence Length Scheduling](#sequence-length-scheduling) +11. [Batch Size Sweep](#batch-size-sweep) +12. [Architecture Ablations](#architecture-ablations) +13. [Scale QAT and Roundtrip Gap](#scale-qat-and-roundtrip-gap) +14. [Attention Residuals](#attention-residuals) +15. [Complete Run Log](#complete-run-log) +16. [EGGROLL Exploration](#eggroll-exploration) +17. [Multi-Seed Variance](#multi-seed-variance) +18. [Final Configuration](#final-configuration) +19. [Reproduction](#reproduction) +20. [Key Insights](#key-insights) + +--- + +## Architecture Overview + +The model is a U-Net transformer with skip connections between encoder and decoder halves. All large weight matrices (QKV projections, attention output, MLP up/down) are binarized using the XNOR-Net approach from Rastegari et al. (2016). Small parameters (RMSNorm scales, skip weights, residual mixing, QK gains) remain in full precision. 
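+
+As a reference point for everything below, here is a minimal PyTorch sketch of the per-group weight binarization with STE (illustrative names; it assumes the group size divides the row length):
+
+```python
+import torch
+
+def binarize_weight_ste(w: torch.Tensor, group_size: int = 256) -> torch.Tensor:
+    """Forward with sign(w) * mean(|w|) per group; gradient passes straight through."""
+    flat = w.reshape(-1, group_size)
+    alpha = flat.abs().mean(dim=-1, keepdim=True)   # one scale per group of 256
+    w_bin = torch.where(flat >= 0, alpha, -alpha)   # sign(w) * alpha
+    w_ste = flat + (w_bin - flat).detach()          # STE: identity gradient
+    return w_ste.reshape(w.shape)
+```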
+ +### Model Configuration (Best: R40 / N2) + +| Component | Value | +|-----------|-------| +| Model dimension | 1024 | +| Layers | 10 | +| Attention heads | 8 | +| KV heads | 4 (GQA) | +| MLP multiplier | 4x | +| Embedding dimension | 384 (R40) / 256 (R34) | +| BPE vocabulary | 1024 tokens | +| Total parameters | 117.6M | +| Binary parameters | 115.3M (98.0%) | +| FP parameters | 2.3M (2.0%) | +| Activation function | signsq (x * abs(x)) | +| Activation binarization | Mode 2 (XNOR except MLP down proj) | +| Group size | 256 | +| RoPE | YaRN (base=5000, max_len=2048) | +| Logit softcap | 10.0 (polynomial) | +| Tied embeddings | Yes | +| FP param storage | FP8 (e4m3fn) | +| Scale storage | BF16 (with FP8 STE for scale QAT) | + +### U-Net Structure + +The transformer is split into encoder (first N/2 layers) and decoder (remaining layers). Skip connections with learnable weights connect corresponding encoder-decoder pairs, initialized to ones. This provides error correction for the information loss inherent in binary quantization -- early features bypass the deepest (most lossy) layers. + +### Weight Binarization (STE) + +Each weight matrix W is binarized per group: +``` +W_binary = sign(W) * alpha, where alpha = mean(|W|) per group of 256 elements +``` +During training, the Straight-Through Estimator (STE) passes gradients through sign() as if it were the identity function. The real-valued weights are maintained in float32 and updated normally; only the forward pass uses binary weights. + +--- + +## Key Technical Contributions + +### 1. Activation Binarization Mode 2 + +Full XNOR (binarizing all activations) plateaus at ~2.0 bpb regardless of training duration. The root cause: the MLP down projection receives all-positive inputs from activation functions, so sign() always returns +1 -- carrying zero information. Mode 2 skips activation binarization on the MLP down projection only, breaking through to 1.575 bpb while keeping all other projections binary. + +### 2. signsq Activation Function + +`signsq(x) = x * |x|` replaces relu^2 for Mode 2. Unlike relu^2 (which outputs only positive values), signsq produces negative outputs, so subsequent sign() operations in the attention path carry real information. This is critical for quality when activations are binarized. + +### 3. Scale QAT (Quantization-Aware Training) + +Binary weight group scales (alpha = mean(|W|) per group) are stored in FP8 at save time. Without scale QAT, the model trains with float32 scales but encounters FP8 quantization error at roundtrip, causing catastrophic degradation at long training runs (0.87 bpb gap at 200k steps). Scale QAT simulates FP8 quantization via STE during training, so the model learns to compensate for precision loss. Result: gap drops from 0.87 to 0.006 bpb. + +### 4. Triton XNOR+POPCOUNT Kernel + +A custom Triton kernel performs true 1-bit matrix multiplication using XOR and population count instructions. The kernel operates on packed int32 words (32 binary weights per word) with per-group scaling factors. + +### 5. Cosine LR Schedule + +Binary STE training with a flat learning rate followed by warmdown wastes ~70% of training in a divergent plateau. Cosine decay from the start keeps every step productive and enables 4x higher peak LR (0.008 vs 0.002). + +### 6. Sequence Length Scheduling + +Training starts at seq_len=128 and ramps through 256->512->1024 over four equal time phases. 
Short sequences give 8x more gradient updates per second during the early phase, when the model just needs to learn token frequencies. All torch.compile graphs are cached during warmup by running one forward pass at each sequence length.
+
+### 7. Low Momentum for Binary STE
+
+Standard momentum (0.95) amplifies noisy STE gradients, causing destructive sign oscillations. Reducing momentum to 0.80 dampens this noise, giving a 0.027 bpb improvement.
+
+---
+
+## Development Timeline
+
+### Phase 1: RTX 5090 Development (T-series runs)
+
+Initial development on a single RTX 5090 32GB (Blackwell SM120) on Vast.ai. Established the architecture, debugged FlashAttention3 compatibility, developed and debugged the Triton XNOR kernel, and tested basic training dynamics. Key discovery: per-group alpha scaling is essential (per-row loses 0.62 bpb).
+
+### Phase 2: 8xH100 Scaling (S-series runs)
+
+Moved to 8xH100 SXM 80GB in a Docker container (driver 565.57, CUDA 12.7). Discovered that batch size dramatically affects binary training -- 65536 tokens outperforms 524288 by 0.07 bpb, because binary networks benefit from frequent, small updates rather than rare, large ones.
+
+### Phase 3: Record Attempts (R-series runs)
+
+Systematic hyperparameter optimization covering 42 runs. Discovered the cosine LR schedule (R25-R29), momentum reduction (R31-R35), gradient clipping (R30), sequence length scheduling (R33), and scale storage optimization (R37-R40). Cumulative improvement from R1 to R40: 2.074 -> 1.574 bpb (0.500 bpb gain).
+
+### Phase 4: Notable Track (N-series runs)
+
+Extended training at 100k-200k steps. N1 revealed the roundtrip gap problem caused by FP8 scale quantization error accumulating over long training runs. Scale QAT (N2) fixed this, achieving 1.575 roundtrip bpb.
+
+### Phase 5: EGGROLL Exploration (E-series runs)
+
+Attempted gradient-free evolution strategies using the EGGROLL algorithm (Sarkar et al. 2026). Tested full perturbation, layer-limited perturbation, and LoRA-based perturbation across 11 runs. Found that STE+Muon converges to a basin that zeroth-order methods cannot improve upon at 115M parameters.
+
+### Phase 6: Attention Residuals (R41-R42, N3)
+
+Implemented Attention Residuals from the Kimi Team (2026) paper as an alternative to U-Net skip connections. Each layer attends over all prior outputs via learned depth-wise attention. The 33% overhead from the depth-wise softmax reduced the number of training steps achievable in 10 minutes, resulting in worse final quality than the simpler U-Net skips.
+
+---
+
+## Activation Binarization Modes
+
+`BINARIZE_ACTIVATIONS` controls which layers have their input activations binarized:
+
+| Mode | Description | Best bpb | Notes |
+|------|-------------|----------|-------|
+| 0 | BWN -- weights only, float activations | 1.16* | Separate BWN submission |
+| 1 | Full XNOR -- all activations binarized | 2.00 | Information bottleneck |
+| 2 | XNOR except MLP down projection | **1.575** | Best quality |
+
+*BWN result from separate Binary BitNet submission, not this XNOR codebase.
+
+### Why Full XNOR Plateaus at 2.0 bpb
+
+With relu^2 or signsq activation, MLP hidden states passed through the activation are either all-positive (relu^2) or mixed-sign (signsq). When the down projection's input goes through sign(), the quality depends entirely on whether these signs carry information:
+
+With relu^2: every hidden element is positive, so sign() returns all +1.
The binary dot product `sign(x) * sign(w)` degenerates to just `sum(sign(w))` -- the activation signs carry no information. This bottleneck limits quality to ~2.0 bpb regardless of model size or training duration. + +With signsq in Mode 2: the down projection receives un-binarized signsq outputs (mixed signs with magnitude information), bypassing the bottleneck. All other projections (QKV, attention out, MLP up) still use full XNOR with binarized activations. + +--- + +## Activation Function Analysis + +| Activation | Formula | Pros | Cons | +|-----------|---------|------|------| +| relu^2 | relu(x)^2 | Excellent LZMA compression (structured signs) | Quality ceiling at ~2.0 bpb (all-positive) | +| signsq | x * abs(x) | Produces negative outputs, best quality | Poor compression (random signs) | +| swiglu | silu(gate) * up | Standard for LLMs | Higher param count | + +### relu^2 Compression Phenomenon + +relu^2 makes all MLP hidden activations positive. The down projection's weight signs evolve to be highly structured (correlated within groups) because the gradient signal only comes through the positive activation channel. LZMA compresses these structured signs extremely well -- a 196M param model (16L) compresses to 15.5MB. + +However, this compression comes at the cost of quality. The model trades information capacity for compressibility. With signsq, the signs are high-entropy (incompressible) but carry genuine information, yielding much better bpb. + +--- + +## Triton XNOR Kernel + +### Architecture + +The kernel performs binary matrix multiplication using XOR + population count: +``` +dot(sign(x), sign(w)) = group_size - 2 * popcount(x_packed XOR w_packed) +``` + +Each group of 256 weights is packed into 8 int32 words. The kernel accumulates per-group dot products, scales by per-group alpha, and sums across groups. + +### Per-Group Alpha Scaling + +The kernel supports per-group weight scaling factors (alpha = mean(|w|) per group), matching the STE reference path exactly. Initial versions used per-row alpha which lost 0.62 bpb of quality. + +### Bug Fix: int64 Promotion + +Triton promotes int32 to int64 during 2D broadcast operations. When `xv[:, None] ^ wv[None, :]` creates a [BLOCK_M, BLOCK_N] tensor, the result is int64. `popc()` then dispatches to `__nv_popcll` (64-bit popcount) instead of `__nv_popc` (32-bit), counting 32 extra zero-bits for every positive int32 and 32 extra one-bits for every negative int32. + +Fix: cast the XOR result back to int32 before calling popc: +```python +diff = (xv[:, None] ^ wv[None, :]).to(tl.int32) +group_acc += tl.extra.cuda.libdevice.popc(diff) +``` + +### bfloat16 Accumulation + +The kernel accumulates in bfloat16 (not float32) to match the precision of the STE reference path and the roundtrip reconstruction. This reduced the quantization gap from 0.008 to 0.003 bpb. + +### Performance + +At the current model size (1024d, 65536 batch tokens, 8 GPUs), the Triton kernel shows no speed improvement over the BF16 STE path (~38ms/step for both). The matrices are too small for the kernel launch overhead to be amortized. The kernel's value is correctness verification and future larger models. 
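+
+The identity the kernel is built on is easy to verify outside Triton (an illustrative check, not the kernel code itself):
+
+```python
+import numpy as np
+import torch
+
+g = 256  # group size
+x = torch.where(torch.randn(g) >= 0, 1, -1)  # binarized activations in {-1, +1}
+w = torch.where(torch.randn(g) >= 0, 1, -1)  # binarized weights in {-1, +1}
+
+ref = int((x * w).sum())  # reference dot product of sign vectors
+
+# Pack signs to bits (1 = positive), then XOR + popcount:
+xp = np.packbits((x > 0).numpy())
+wp = np.packbits((w > 0).numpy())
+mismatches = int(np.unpackbits(xp ^ wp).sum())  # popcount of XOR = # differing signs
+
+assert ref == g - 2 * mismatches  # dot(sign(x), sign(w)) = g - 2 * popcount(xor)
+```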
+ +--- + +## Compression Pipeline + +### Storage Formats + +| Component | Format | Size (R40) | +|-----------|--------|------------| +| Binary weights | Packed bits (1 bit/param) | 14.87MB (pre-compression) | +| Group scales (g=256) | BF16 | 0.90MB | +| Embeddings, head, projections | FP8 (e4m3fn) | 0.86MB | +| Code | UTF-8 | 0.08MB | + +### Compression Comparison + +| Algorithm | Compressed Size (R34) | Compressed Size (R40) | +|-----------|-----------------------|-----------------------| +| LZMA preset 9 | 15.37MB | 15.95MB | +| **Brotli quality 11** | **15.30MB** | **15.89MB** | +| zstd level 22 | 17.08MB | 17.73MB | + +Brotli consistently wins by ~50KB. zstd is worst for binary data -- it's optimized for structured text, not near-random bit patterns. The save process tries all three and picks the smallest, with a 1-byte header indicating the method for the decompressor. + +### FP8 Scale Storage vs BF16 + +| Scale Storage | Extra Size | Roundtrip Gap | Notes | +|--------------|-----------|---------------|-------| +| FP8 | -0.45MB | 0.013 (R34) | Smaller artifact, minor precision loss | +| BF16 | baseline | 0.005 (R40) | Better roundtrip, essential for long training | + +FP8 scale storage saves ~0.45MB but introduces quantization error on per-group scales. For 10-minute runs (15k steps), the gap is tolerable. For 100k+ steps, the error compounds and BF16 scales are essential (or scale QAT is needed). + +### Sign-Sort Permutation + +Post-training, MLP hidden dimensions are permuted so same-sign weight columns are adjacent. The corresponding rows of the paired projection are permuted identically, preserving model output. Intended to create long runs of identical bits for LZMA compression. Result: did not help for signsq (signs are high-entropy), only useful for relu^2 which has structured signs. + +### Compression Regularizer + +A differentiable penalty using `tanh(10*w)` that pushes weight signs within each group toward uniformity. Controlled by `SIGN_COMPRESS_REG`. Result: hurt quality, not worth it. The regularizer fights against the STE gradient signal, reducing model capacity without sufficient compression gain. + +--- + +## Optimizer Exploration + +### Muon (Momentum + NS Orthogonalization) + +The primary optimizer for binary weight matrices. Uses Newton-Schulz orthogonalization on the gradient before applying the update. Muon was chosen because it produces well-conditioned updates that help binary STE training converge faster than Adam. + +| NS Variant | Steps | Precision | val_bpb | ms/step | +|-----------|-------|-----------|---------|---------| +| **Original ns_orth** | **3** | **bf16** | **1.671** | **38.5** | +| Our Gram NS | 5 | bf16 | 1.713 | 38.8 | +| Library Gram NS | 5 | fp16 | 1.740 | 39.1 | + +Original 3-step NS wins. Binary STE gradients are inherently noisy because sign() is a discontinuous function. More precise orthogonalization (Gram NS with 5 steps) doesn't help because the gradient itself is approximate. The library's float16 precision actively hurts because bfloat16's larger dynamic range matters more than mantissa precision for binary training. + +### NS Step Count Ablation + +| Steps | val_bpb | roundtrip | gap | ms/step | +|-------|---------|-----------|-----|---------| +| 2 | 1.684 | 1.733 | 0.049 | 37.8 | +| **3** | **1.671** | **1.672** | **0.001** | 38.5 | +| 5 | 1.713 | 1.719 | 0.006 | 38.9 | + +3 steps is the sweet spot. 2 steps under-orthogonalizes, producing updates that are poorly conditioned and create a huge roundtrip gap (0.049). 
5 steps over-orthogonalizes noisy STE gradients, wasting compute on precision that doesn't exist in the signal. + +### Momentum + +Momentum controls how much of the previous gradient update is carried forward. In standard float training, high momentum (0.95) smooths out mini-batch noise. But for binary STE training, each gradient is fundamentally approximate because sign() is not differentiable. High momentum amplifies these approximation errors, causing weights to oscillate across zero (flipping their sign back and forth unproductively). + +| Momentum | val_bpb | Roundtrip | Gap | Notes | +|----------|---------|-----------|-----|-------| +| 0.95 | 1.636 | 1.639 | 0.003 | Standard, too noisy | +| 0.85 | 1.616 | 1.627 | 0.011 | Better | +| **0.80** | **1.589** | **1.602** | **0.013** | Best balance | +| 0.75 | 1.591 | 1.613 | 0.022 | Under-smoothed, worse gap | + +At 0.80, the noise from STE gradient errors is dampened enough that the model trains stably, but there is still enough momentum to escape shallow local optima. At 0.75, gradients become too noisy (not enough smoothing), and the roundtrip gap doubles -- weights jitter more and quantize poorly. + +### EMA (Exponential Moving Average) + +| EMA Config | val_bpb | roundtrip | gap | +|-----------|---------|-----------|-----| +| **Off** | **1.589** | **1.602** | **0.013** | +| Start at 60% | 1.590 | 1.612 | 0.022 | +| Start at 0% | 1.674 | 1.909 | 0.235 | + +EMA averages weights over recent training history. For float models this smooths out noise, but for binary models it's catastrophic. The averaged weights have less decisive signs -- they sit closer to zero where sign() is maximally sensitive to perturbation. During roundtrip (load from compressed artifact), these near-zero weights flip unpredictably, destroying quality. EMA is harmful for binary models. + +--- + +## Learning Rate and Schedule Analysis + +### LR Schedule: Linear Warmdown vs Cosine + +The training loss curve for binary STE networks shows a distinctive "wandering" pattern. After an initial drop (steps 0-3000), loss increases and oscillates for thousands of steps (3000-9000) before dropping again during warmdown. This happens because the LR is too high for stable binary training -- each step flips thousands of weight signs, some productive, some destructive. The productive and destructive flips roughly cancel out, so the model wanders sideways. + +With cosine decay, the LR starts decreasing immediately after warmup. There is no sustained high-LR plateau, so the wandering phase is compressed. More importantly, cosine enables a much higher peak LR (0.008 vs 0.002) because the rapid decay prevents the accumulated noise from causing divergence. + +| Schedule | Peak LR | val_bpb | +|----------|---------|---------| +| Linear warmdown 0.3 | 0.002 | 1.654 | +| **Cosine** | **0.008** | **1.629** | + +### Cosine LR Sweep + +| LR | val_bpb | roundtrip | +|----|---------|-----------| +| 0.002 | 1.653 | 1.667 | +| 0.004 | 1.636 | 1.639 | +| 0.006 | 1.632 | 1.637 | +| **0.008** | **1.629** | **1.635** | +| 0.012 | 1.635 | 1.640 | + +The peak is at 0.008. Below that, the model learns too slowly in the available training time. Above that, excessive sign flips early in training prevent the model from finding a good basin. + +### Gradient Clipping + +| Grad Clip | val_bpb | +|-----------|---------| +| 0.0 (off) | 1.629 | +| **1.0** | **1.626** | + +Small improvement. Gradient clipping prevents any single batch from causing a catastrophic cascade of sign flips. 
In binary networks, a large gradient can flip the sign of many weights simultaneously, and the resulting binary network can be dramatically different from what the optimizer expected. + +--- + +## Sequence Length Scheduling + +Training starts at seq_len=128 and doubles at equal time intervals: 128->256->512->1024. + +The reasoning: early in training, the model needs to learn basic token frequencies and simple bigram patterns. These require only short context. Processing short sequences is 8x faster than full 1024 (attention is quadratic), giving 8x more gradient updates per second. Once the model has learned local patterns, longer sequences allow it to learn long-range dependencies. + +Implementation details: the schedule is based on either wall-clock time (if MAX_WALLCLOCK_SECONDS > 0) or step count (if using iterations). Each torch.compile graph is cached during warmup by running one forward-backward pass at each sequence length. + +| Run | Config | val_bpb | +|-----|--------|---------| +| R31 | Cosine + momentum 0.85, no schedule | 1.616 | +| **R33** | **+ seq_len schedule** | **1.597** | +| R34 | + momentum 0.80 | 1.589 | + +The 0.019 bpb gain from scheduling is entirely free -- the same total tokens are processed, just in a more efficient order. The model gets ~4x more gradient updates during the first quarter of training. + +--- + +## Batch Size Sweep + +Binary STE training strongly prefers smaller batches with more frequent updates. Each sign() decision is discrete -- once a weight's sign flips, the effect on the network is immediate and discontinuous. More frequent updates mean the model can react to the consequences of each sign flip sooner, correcting mistakes before they propagate. + +| Batch Size | ms/step | Final bpb | +|-----------|---------|-----------| +| 524288 | 127.9 | 2.070 | +| 262144 | 69.1 | 2.072 | +| 131072 | 39.6 | 2.016 | +| **65536** | **38.6** | **1.999** | +| 32768 | 38.7 | 2.004 | + +65536 is the sweet spot. Below that, per-batch gradient noise increases without speed benefit (DDP communication overhead dominates at small batch sizes). Above that, the model makes fewer sign-flip decisions per second, losing the benefit of frequent updates. + +With 8 GPUs at 65536 total batch: 8192 tokens per GPU, well within VRAM. Step time is dominated by DDP synchronization, not compute. + +--- + +## Architecture Ablations + +### Group Size + +The group size controls how many weights share a single scaling factor (alpha). Smaller groups give finer-grained scaling but noisier per-group statistics (fewer elements to average over). Larger groups give stable statistics but coarser approximation. + +| Group Size | val_bpb | Size | +|-----------|---------|------| +| 128 | 1.671 | 15.51MB | +| **256** | **1.654** | **15.43MB** | +| 512 | 1.689 | 15.37MB | +| 1024 | 1.680 | 15.36MB | + +256 is optimal. At 128, per-group mean(|w|) over 128 elements is noisy. At 512+, a single alpha must represent weights with different magnitudes, losing precision. + +### Layers vs MLP Width + +| Layers | MLP | Params | val_bpb | +|--------|-----|--------|---------| +| **10** | **4x** | **117M** | **1.654** | +| 13 | 3x | 125M | 1.745 | +| 10 | 5x | 138M | 1.671 | + +Wider MLP is strictly better than deeper for binary networks. Each layer applies sign() to its output, which is a lossy operation that compounds across depth. More layers = more compounding information loss. A wider MLP gives more capacity per layer without the compounding. The 10L 4x config fits the 16MB budget optimally. 
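+
+The budget arithmetic behind that choice can be checked with a short script (a sketch; it assumes the GQA projection shapes from the configuration table, 1 bit per binary weight, and one BF16 scale per group of 256):
+
+```python
+d, layers, heads, kv_heads, mlp = 1024, 10, 8, 4, 4
+head_dim = d // heads                  # 128
+kv_dim = kv_heads * head_dim           # 512
+
+per_layer = (d * d                     # Q projection
+             + 2 * kv_dim * d          # K and V projections (GQA)
+             + d * d                   # attention output
+             + 2 * mlp * d * d)        # MLP up + down
+binary = layers * per_layer            # 115,343,360 -- the 115.3M in the config table
+
+packed_mb = binary / 8 / 1e6           # ~14.4MB of raw packed sign bits
+scales_mb = binary / 256 * 2 / 1e6     # ~0.90MB of BF16 group scales
+print(f"{binary/1e6:.1f}M binary params, {packed_mb:.2f}MB packed, {scales_mb:.2f}MB scales")
+```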
+ +### Wider Model (768d) vs Standard (1024d) + +| Config | val_bpb | Size | +|--------|---------|------| +| **1024d x 10L** | **1.589** | 15.37MB | +| 768d x 18L (embed=512) | 1.634 | 15.90MB | +| 768d x 18L (embed=256) | -- | 15.50MB | + +Even with 80% more layers, the narrower model is worse. Binary networks lose information per layer, so depth hurts more than width helps. + +### BPE Vocabulary + +| BPE Size | val_bpb | FP Params | Size | Fits? | +|----------|---------|-----------|------|-------| +| **1024** | **1.654** | 0.85MB | 15.43MB | YES | +| 8192 | 1.673 | 2.70MB | 16.83MB | NO | + +Smaller vocabulary saves 1.85MB of embedding FP params, allowing more binary parameters within the 16MB budget. The larger vocabulary doesn't compensate for the lost binary capacity. + +### Embedding Dimension + +| Embed Dim | Storage | val_bpb | Roundtrip | Size | Notes | +|-----------|---------|---------|-----------|------|-------| +| 256 | FP8 | 1.589 | 1.602 | 15.37MB | R34 best 10-min | +| **384** | **FP8+BF16 scales** | **1.574** | **1.578** | **15.96MB** | **R40 best overall** | +| 512 | FP8+FP8 scales | 1.569 | 2.435 | 15.66MB | Roundtrip catastrophic | +| 512 | FP8+BF16 scales | -- | -- | 16.26MB | Over budget | + +384 embed_dim with BF16 scales is the sweet spot -- richer embedding space within budget, and the BF16 scales avoid roundtrip degradation. 512 embed with FP8 scales destroys roundtrip at long training due to accumulated scale quantization error. + +### Logit Softcap + +| Softcap | val_bpb | +|---------|---------| +| **10** | **1.671** | +| 15 | 1.684 | + +10 is better. Lower softcap constrains logits more, regularizing the model. Uses polynomial approximation (`x * (1 - x^2/3 + x^4/15)`) instead of tanh because tanh doesn't fuse with torch.compile. + +### Smear Module + +| Smear | val_bpb | Notes | +|-------|---------|-------| +| **Off** | **1.589** | Saves ~1ms/step | +| On | 1.603 | Doesn't help with seq_len scheduling | + +Smear didn't help with sequence length scheduling enabled. The scheduling already provides the "easy then hard" curriculum that smear approximates. + +### Size Check Runs (T28-T34) + +Architecture variants tested for 16MB budget fit: + +| Run | Config | Params | Size | Fits? | +|-----|--------|--------|------|-------| +| T28 | 10L 1024d embed=512 | 118.0M | 15.87MB | YES | +| T29 | 11L 1024d embed=256 | 128.8M | 16.80MB | NO | +| T30 | 14L 768d embed=256 | 92.3M | 12.10MB | YES | +| T31 | 20L 768d embed=256 | 131.3M | 17.09MB | NO | +| T32 | 18L 768d embed=256 | 118.3M | 15.43MB | YES | +| T33 | 19L 768d embed=256 | 124.8M | 16.26MB | NO | +| T34 | 18L 768d embed=512 | 118.9M | 15.89MB | YES | + +--- + +## Scale QAT and Roundtrip Gap + +### The Problem + +During training, per-group weight scales (alpha = mean(|w|)) are computed in float32. At save time, these scales are quantized to FP8 for storage. Each step introduces a tiny error that the model never sees during training. Over 200k steps, the model becomes precisely tuned to float32 scale values that FP8 cannot represent, causing catastrophic roundtrip degradation. 
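+
+The magnitude of the per-scale error is easy to see directly (illustrative):
+
+```python
+import torch
+
+alpha = torch.rand(1_000_000) * 0.05 + 0.02        # plausible per-group scale magnitudes
+alpha_fp8 = alpha.to(torch.float8_e4m3fn).float()  # what gets stored without scale QAT
+rel_err = (alpha - alpha_fp8).abs() / alpha
+print(f"mean {rel_err.mean():.2%}, max {rel_err.max():.2%}")
+# e4m3 keeps 3 mantissa bits, so each stored scale can be off by up to ~6% --
+# an error the float32-trained model never experienced.
+```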
+ +| Run | Steps | FP Storage | Scale Storage | Scale QAT | val_bpb | Roundtrip | Gap | +|-----|-------|-----------|---------------|-----------|---------|-----------|-----| +| R34 | 15k | FP8 | FP8 | No | 1.589 | 1.602 | 0.013 | +| N1 | 200k | FP8 | FP8 | No | 1.569 | 2.435 | **0.866** | +| R40 | 15k | FP8 | BF16 | No | 1.574 | 1.578 | 0.005 | +| **N2** | **100k** | **FP8** | **BF16** | **Yes** | **1.569** | **1.575** | **0.006** | + +N1's 100k checkpoint had roundtrip 1.986, 30k checkpoint had 2.121 -- the error compounds monotonically with training steps. + +### The Fix + +Scale QAT simulates FP8 quantization on scales during the forward pass via STE: +```python +alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) +alpha = alpha + (alpha_q - alpha).detach() # STE +``` + +The model sees the quantized scale values during training and learns to compensate. Combined with BF16 scale storage (which has negligible quantization error), the roundtrip gap stays below 0.006 bpb even at 100k steps. + +--- + +## Attention Residuals + +### Background + +Attention Residuals (Kimi Team, 2026) replace standard residual connections with learned depth-wise attention. Instead of `h_l = h_{l-1} + f(h_{l-1})`, each layer attends over ALL prior outputs: `h_l = softmax_weighted_sum(all previous outputs)`. This allows later layers to selectively retrieve information from any earlier layer, bypassing lossy intermediate sign() operations. + +### Implementation + +Two modes were implemented: +- **Mode 2 (pass-level):** One stored tensor per block. 10 query vectors x 1024 dim = 10K params. +- **Mode 1 (sub-layer):** One stored tensor per sub-layer (attention + MLP). 20 queries x 1024 dim = 20K params. + +Queries are zero-initialized so the model starts with uniform weights (equivalent to standard residual). Keys are RMSNorm'd stored outputs. No projection matrices needed. + +### Results + +| Run | Mode | Steps | ms/step | val_bpb | roundtrip | sliding | +|-----|------|-------|---------|---------|-----------|---------| +| R40 | U-Net (0) | 15560 | 38.6 | 1.574 | 1.578 | -- | +| R41 | AttnRes (2) | 11560 | 51.7 | 1.594 | 1.598 | -- | +| R42 | AttnRes (1) | -- | -- | crash | -- | -- | +| N2 | U-Net (0) | 100k | 39.0 | 1.569 | 1.575 | 1.539 | +| N3 | AttnRes (2) | 100k | 51.8 | 1.583 | 1.596 | 1.563 | + +### Analysis + +AttnRes adds 33% overhead (51.7ms vs 38.6ms) from the depth-wise softmax computation over stored tensors. In the 10-minute track, this overhead means ~4000 fewer training steps, which more than negates any architectural benefit. Even at 100k steps (N3 vs N2), U-Net wins by 0.021 bpb in roundtrip. + +The overhead comes from: storing 10+ tensors, computing 10 einsum operations for logits, softmax, and weighted sum each forward pass. Torch.compile partially fuses these but the softmax reduction dimension is too small (10 elements) for efficient GPU execution. + +Mode 1 (sub-layer) crashed with an Inductor OOM error -- the backward graph with 20 stored tensors exceeds Triton's register file limits for the fused RMSNorm backward kernel. + +Conclusion: for binary networks, simple weighted skip connections (U-Net) provide sufficient error correction at much lower overhead than learned depth-wise attention. 
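+
+For concreteness, a minimal sketch of the pass-level variant (Mode 2) described above, with illustrative names; queries are zero-initialized so the depth-wise softmax starts uniform, as noted in the implementation section:
+
+```python
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class PassLevelAttnRes(nn.Module):
+    """Depth-wise attention over all stored block outputs (sketch)."""
+    def __init__(self, num_layers: int, dim: int):
+        super().__init__()
+        self.queries = nn.Parameter(torch.zeros(num_layers, dim))  # zero-init
+
+    def mix(self, layer_idx: int, stored: list[torch.Tensor]) -> torch.Tensor:
+        # stored: outputs of blocks 0..layer_idx, each [batch, seq, dim]
+        outs = torch.stack(stored)                                 # [L, B, S, D]
+        keys = F.rms_norm(outs, (outs.size(-1),))                  # RMSNorm'd keys
+        logits = torch.einsum("d,lbsd->lbs", self.queries[layer_idx], keys)
+        w = logits.softmax(dim=0)                                  # softmax over depth
+        return torch.einsum("lbs,lbsd->bsd", w, outs)              # weighted sum
+```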
+ +--- + +## Complete Run Log + +### T-series: RTX 5090 Testing + +| Run | Config | Steps | val_bpb | Size | Notes | +|-----|--------|-------|---------|------|-------| +| T7 | relu2, mode 2, 15L 1024d | 200 | 1.807 | 26.08MB | Mode 2 first test | +| T9 | relu2, mode 1, 16L 1024d | 50 | 2.058 | 18.46MB | Size check | +| T10 | relu2, mode 1, 15L 1024d | 50 | 2.060 | 17.46MB | Size check | +| T11 | relu2, mode 1, 12L 1024d | 50 | 2.056 | 14.46MB | Fits | +| T12 | relu2, mode 1, 16L 1024d | 22500 | 2.569+ | -- | Diverged | +| T13 | relu2, mode 1, 16L LR=0.01 | 2000 | 1.879 | -- | Still diverged | +| T14 | signsq, mode 1, 16L | 2000 | 1.985 | 27.44MB | No compression | +| T15 | signsq, mode 1, 16L (sign-sort) | 50 | 2.079 | 27.35MB | Sign-sort no help | +| T16 | signsq, mode 1, 16L (sign-sort) | 5000 | 1.941 | 27.51MB | Sign-sort no help | +| T17 | signsq, mode 1, 11L 1024d | 500 | 1.842 | 19.65MB | Over | +| T18 | signsq, mode 1, 11L 1024 BPE | 500 | 2.001 | 17.71MB | Over | +| T19 | signsq, mode 1, 10L 1024 BPE | 500 | 2.042 | 16.15MB | Over | +| T20 | signsq, mode 1, 10L g=256 | 500 | 2.019 | 15.76MB | **Fits** | +| T21 | signsq, mode 1, 10L 262k batch | 15000 | diverged | -- | LR too high | +| T22 | signsq, mode 1, EMA, LR=0.005 | 1500 | 1.995 | 15.75MB | EMA helped | +| T23 | signsq, mode 1, no EMA, LR=0.005 | 1500 | 2.005 | 15.75MB | LR sufficient | +| T25 | signsq, mode 1, 10L LR=0.005 | 2500 | 1.984 | 15.76MB | Best mode 1 | +| T26 | Triton kernel (per-row alpha, buggy) | 1000 | 2.451 | 14.39MB | Kernel bug | +| T27 | Triton kernel (per-row, no compile) | 1000 | 2.481 | 14.39MB | Bug confirmed | +| T28 | Triton kernel (per-group, fixed) | 1000 | 2.028 | 15.68MB | **Kernel works** | + +### S-series: 8xH100 Scaling + +| Run | Config | Steps | val_bpb | Size | Notes | +|-----|--------|-------|---------|------|-------| +| S1 | BF16 scales, 524k batch | 5000 | 2.120 | 16.35MB | Over | +| S2 | FP8 scales, 524k batch | 5000 | 2.070 | 15.75MB | Fits | +| S3 | FP8 scales + compress reg 0.01 | 5000 | 2.089 | 15.75MB | Reg hurts | +| S4 | FP8 scales, 262k batch | 5000 | 2.072 | 15.76MB | Same quality | +| S5 | FP8 scales, 131k batch | 5000 | 2.016 | 15.76MB | Better | +| **S6** | **FP8 scales, 65k batch** | **5000** | **1.999** | **15.76MB** | **Best batch** | +| S7 | FP8 scales, 32k batch | 5000 | 2.004 | 15.78MB | Diminishing returns | +| S8 | LR=0.002, 65k batch | 5000 | 2.044 | 15.75MB | Too conservative | +| S9 | LR=0.003, 65k batch | 5000 | 1.995 | 15.76MB | Good | +| S10 | LR=0.004, 65k batch | 5000 | 1.994 | 15.76MB | Best LR | +| S11 | Gram NS library, 65k batch | 1000 | 2.004 | 15.51MB | Slightly slower | + +### R-series: Record Attempts + +| Run | Key Changes | val_bpb | RT bpb | Size | Fits? 
| +|-----|-------------|---------|--------|------|-------| +| R1 | LR=0.003, 600s | 2.074 | 2.075 | 15.68MB | YES | +| R2 | LR=0.001 | 2.121 | 2.124 | 15.67MB | YES | +| R3 | **Mode 2** (MLP down BWN) | 1.699 | -- | 15.67MB | YES | +| R4 | Mode 2 + EMA@60% | 1.668 | 1.787 | 15.64MB | YES | +| R5 | Mode 2 + EMA@0% | 1.674 | 1.909 | 15.59MB | YES | +| R6 | Triton per-row (buggy) | 2.323 | 2.363 | 14.82MB | YES | +| R7 | Triton fixed, LR=0.002 | 1.659 | 1.667 | 15.66MB | YES | +| R8 | Triton, LR=0.001 | 1.700 | 1.707 | 15.67MB | YES | +| R9 | Triton bf16, BF16 scales | 1.676 | 1.679 | 15.67MB | YES | +| R10 | Triton bf16, FP8 scales | 1.654 | 1.663 | 15.43MB | YES | +| R11 | BF16 everything | 1.665 | 1.666 | 16.35MB | NO | +| R12 | 13L MLP 3x | 1.745 | 1.752 | 16.37MB | NO | +| R13 | 10L MLP 5x | 1.671 | 1.685 | 18.09MB | NO | +| R14 | 11L MLP 4x | 1.708 | 1.707 | 16.89MB | NO | +| R15 | g=512 | 1.689 | 1.694 | 15.37MB | YES | +| R16 | g=1024 | 1.680 | 1.685 | 15.36MB | YES | +| R17 | g=128, softcap=15 | 1.684 | 1.691 | 15.51MB | YES | +| R18 | g=128, softcap=10 | 1.671 | 1.675 | 15.51MB | YES | +| R19 | 8192 BPE | 1.673 | 1.684 | 16.83MB | NO | +| R20 | Gram NS library | 1.740 | 1.743 | 15.47MB | YES | +| R21 | Our Gram NS (bf16) | 1.713 | 1.716 | 15.48MB | YES | +| R22 | Original NS, 3 steps | 1.671 | 1.672 | 15.42MB | YES | +| R23 | NS 2 steps | 1.684 | 1.733 | 15.37MB | YES | +| R24 | NS 5 steps | 1.713 | 1.719 | 15.47MB | YES | +| R25 | Cosine LR=0.004 | 1.636 | 1.639 | 15.43MB | YES | +| R26 | Cosine LR=0.002 | 1.653 | 1.667 | 15.41MB | YES | +| R27 | Cosine LR=0.006 | 1.632 | 1.637 | 15.45MB | YES | +| R28 | Cosine LR=0.008 | 1.629 | 1.635 | 15.46MB | YES | +| R29 | Cosine LR=0.012 | 1.635 | 1.640 | 15.46MB | YES | +| R30 | + Grad clip 1.0 | 1.626 | 1.631 | 15.46MB | YES | +| R31 | + Momentum 0.85 | 1.616 | 1.627 | 15.38MB | YES | +| R32 | Momentum 0.75 | 1.617 | 1.645 | 15.36MB | YES | +| R33 | + Seq len schedule | 1.597 | 1.612 | 15.39MB | YES | +| **R34** | **+ Momentum 0.80** | **1.589** | **1.602** | **15.37MB** | **YES** | +| R35 | Momentum 0.75 + schedule | 1.591 | 1.613 | 15.35MB | YES | +| R36 | 768d 18L embed=512 | 1.634 | 1.651 | 15.90MB | YES | +| R37 | 1024d 10L embed=512 | 1.602 | 1.623 | 15.82MB | YES | +| R38 | 1024d 10L embed=512 + smear | 1.603 | -- | 15.83MB | YES | +| R39 | R34 + EMA@60% | 1.590 | 1.612 | 15.36MB | YES | +| **R40** | **embed=384, BF16 scales** | **1.574** | **1.578** | **15.96MB** | **YES** | +| R41 | AttnRes mode 2 | 1.594 | 1.598 | 15.91MB | YES | +| R42 | AttnRes mode 1 | crash | -- | -- | -- | + +### P-series: Push/Submit (10-min track, 3 seeds) + +| Run | Seed | val_bpb | roundtrip | sliding (s=48) | gap | +|-----|------|---------|-----------|----------------|-----| +| P3 | 42 | 1.582 | 1.591 | 1.556 | 0.009 | +| P4 | 7 | 1.605 | 1.615 | 1.580 | 0.010 | +| P5 | 1337 | 1.598 | 1.600 | 1.565 | 0.002 | +| **Mean** | -- | **1.595** | **1.602** | **1.567** | **0.007** | + +### N-series: Notable Track + +| Run | Config | Steps | val_bpb | roundtrip | sliding | Gap | +|-----|--------|-------|---------|-----------|---------|-----| +| N1 | embed=512, FP8 scales, no QAT | 200k | 1.569 | 2.435 | -- | 0.866 | +| **N2** | **embed=384, BF16 scales, scale QAT** | **100k** | **1.569** | **1.575** | **1.539** | **0.006** | +| N3 | AttnRes mode 2 | 100k | 1.583 | 1.596 | 1.563 | 0.013 | + +--- + +## EGGROLL Exploration + +### Background + +EGGROLL (Sarkar et al. 2026) uses rank-r low-rank perturbations for efficient evolution strategies. 
Instead of sampling full-rank noise matrices, it samples A in R^(m x r) and B in R^(n x r) and forms E = (1/sqrt(r)) * AB^T. This enables gradient-free optimization that bypasses the STE entirely, directly optimizing the loss function over the binary weight space.
+
+The motivation for trying EGGROLL on our binary network: the STE is a fundamentally approximate gradient. EGGROLL evaluates the true loss function (with actual sign() and quantization), so it could potentially find better solutions than STE-based gradient descent.
+
+### Implementation
+
+Three approaches were implemented:
+
+1. **Full perturbation:** Perturb all 115M binary weight parameters directly. Each perturbation adds sigma * (1/sqrt(r)) * AB^T to the float weights before binarization.
+
+2. **Layer-limited perturbation:** Perturb only the last N layers (controlled by EGGROLL_LAYERS). Reduces dimensionality from 115M to 11.5M-34M.
+
+3. **LoRA perturbation:** Create LoRA adapter pairs (A, B) for each binary weight matrix. Perturb only the LoRA parameters (~614K params at rank 4). Before each forward pass, merge LoRA into base weights, evaluate, then unmerge. Final model merges LoRA permanently.
+
+### Results: Full Perturbation
+
+| Run | Start | Pop | Sigma | LR | Rank | Fitness | Result |
+|-----|-------|-----|-------|-----|------|---------|--------|
+| E1 | Random | 256 | 0.01 | 0.001 | 1 | -6.936 | No learning |
+| E2 | Random | 256 | 0.5 | 0.1 | 1 | -6.937 | No learning |
+| E3 | Pretrained | 256 | 0.01 | 0.001 | 1 | -10.48 | Diverged |
+| E4 | Pretrained | 256 | 0.0001 | 0.00001 | 1 | -10.46 | Diverged slowly |
+| E5 | Pretrained | 4096 | 0.0001 | 0.0001 | 1 | -2.98 | Stable, flat |
+| E6 | Pretrained | 4096 | 0.0001 | 0.00001 | 8 | -3.24 | Slow divergence |
+| E7 | Pretrained | 4096 | 0.0001 | 0.000001 | 8 | -3.01 | Stable, flat |
+
+From scratch (E1-E2), ES cannot navigate the 115M-dimensional landscape at any sigma or population size. Even a population of 4096 provides zero useful gradient signal.
+
+From pretrained weights (E3-E7), the perturbation scale and population are critical. Too large a sigma (E3, sigma=0.01): every perturbation destroys the trained model, so the "best" direction is just "least bad." Too small a population at that sigma (E4, sigma=0.0001, pop=256): fitness differences are noise-dominated and the model drifts slowly away. The sweet spot (E5, sigma=0.0001, pop=4096) is stable but shows zero improvement -- every perturbation direction is uphill from the STE-found basin.
+
+### Results: Layer-Limited Perturbation
+
+| Run | Layers Perturbed | Params | val_bpb after 10 steps | Degradation |
+|-----|-----------------|--------|----------------------|-------------|
+| E7 | All 40 tensors | 115M | 1.633 | +0.064 |
+| E8 | Last 3 blocks (12 tensors) | 34M | 1.629 | +0.060 |
+| E9 | Last 1 block (4 tensors) | 11.5M | 1.609 | +0.040 |
+
+Fewer parameters to perturb means less damage per step, but still no improvement. The model is at a local optimum in every direction, even when only searching an 11.5M-dimensional subspace.
+
+### Results: LoRA Perturbation
+
+| Run | LoRA Rank | Pop | Sigma | LR | LoRA Params | val_bpb after 10 steps |
+|-----|-----------|-----|-------|-----|-------------|----------------------|
+| E10 | 4 | 4096 | 0.01 | 0.001 | 614K | 1.573 (+0.004) |
+| E11 | 4 | 16384 | 0.01 | 0.001 | 614K | -- (3.7 min/step) |
+
+LoRA brings the perturbation dimensionality down to 614K -- manageable for ES. E10 with pop=4096 was nearly stable (only +0.004 degradation vs +0.040 for direct perturbation of the same parameters).
But still no improvement, and E11 at pop=16384 was too slow at 229s/step to be practical.
+
+### Why EGGROLL Cannot Improve on STE+Muon
+
+The fundamental issue is signal-to-noise ratio. With rank-1 perturbations in d-dimensional space, the cosine similarity between any random perturbation and the true gradient is approximately 1/sqrt(d). For d=115M, this gives ~0.00009. Population size N improves this by sqrt(N), so pop=4096 gives ~0.006 -- still 99.4% noise.
+
+The EGGROLL paper's successful pretraining used a 256-dim model with up to 1M population. For 115M params, the required population would be orders of magnitude larger than is practical.
+
+LoRA reduces d to 614K, giving ~0.0013 per perturbation and ~0.08 with pop=4096. Better, but the LoRA subspace may not contain the improvement direction. The STE+Muon optimizer has access to 115M-dimensional gradient information per step, which is fundamentally more informative than 4096 scalar fitness samples.
+
+---
+
+## Multi-Seed Variance
+
+Three seeds (42, 7, 1337) were run with the P1 config to estimate variance:
+
+| Metric | Seed 42 | Seed 7 | Seed 1337 | Mean | Std |
+|--------|---------|--------|-----------|------|-----|
+| val_bpb | 1.582 | 1.605 | 1.598 | 1.595 | 0.012 |
+| roundtrip | 1.591 | 1.615 | 1.600 | 1.602 | 0.012 |
+| sliding (s=48) | 1.556 | 1.580 | 1.565 | 1.567 | 0.012 |
+| gap | 0.009 | 0.010 | 0.002 | 0.007 | 0.004 |
+
+The standard deviation across seeds is ~0.012 bpb. This is typical for binary networks, where early sign choices cascade -- a different random initialization puts the model into a different basin, and small differences compound through the sign() operations.
+
+---
+
+## Final Configuration
+
+### P1: 10-Minute Track Submission (R40 config)
+
+```bash
+# Architecture
+NUM_LAYERS=10 MODEL_DIM=1024 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=4
+EMBED_DIM=384 VOCAB_SIZE=1024 ACTIVATION=signsq ATTN_RES=0
+# XNOR
+XNOR_GROUP_SIZE=256 BINARIZE_ACTIVATIONS=2 USE_TRITON_KERNEL=1
+# Storage
+FP_STORAGE=FP8 SCALE_STORAGE=BF16
+# Optimizer
+MATRIX_OPTIMIZER=muon MATRIX_LR=0.008 MUON_MOMENTUM=0.80
+MUON_BACKEND_STEPS=3 MUON_WD=0.04
+LR_SCHEDULE=cosine GRAD_CLIP_NORM=1.0
+# Schedule
+SEQ_LEN_SCHEDULE=1 TRAIN_BATCH_TOKENS=65536
+MAX_WALLCLOCK_SECONDS=600
+```
+
+**Best single run: R40 -- 1.574 val, 1.578 roundtrip, 15.96MB**
+**Three-seed mean: 1.602 +/- 0.012 roundtrip, 1.567 +/- 0.012 sliding**
+
+### N2: Notable Track (100k steps)
+
+Same as P1 but:
+```bash
+MAX_WALLCLOCK_SECONDS=0 ITERATIONS=100000 CHECKPOINT_EVERY=25000
+SLIDING_EVAL=1 SLIDING_EVAL_STRIDE=16 TEMP_SCALING=1
+```
+
+**Result: 1.569 val, 1.575 roundtrip, 1.539 sliding, 15.91MB**
+
+---
+
+## Reproduction
+
+### Requirements
+
+- 8x NVIDIA H100 80GB SXM (or equivalent)
+- PyTorch 2.10.0+cu128
+- FlashAttention 3
+- Triton 3.6.0
+- Python 3.13
+
+### Setup
+
+```bash
+bash setup.sh
+conda activate golf
+pip install brotli zstandard --break-system-packages
+```
+
+### Training
+
+```bash
+# 10-minute track
+bash run_cuda_xnor_v2.sh
+
+# Notable track (100k steps, ~65 minutes)
+bash run_cuda_xnor_notable.sh
+
+# EGGROLL exploration
+bash run_cuda_eggroll.sh
+```
+
+### Data
+
+FineWeb 10B dataset with 1024 BPE tokenizer. 80 training shards, 1 validation shard (~40.5M tokens).
+
+---
+
+## Key Insights
+
+1. **Binary networks need frequent, small updates.** Batch size 65536 >> 524288 for quality. Each sign() is a discrete decision -- more decisions per second means faster convergence.
+
+2.
**Full XNOR activation binarization has a quality ceiling around 2.0 bpb** due to the MLP information bottleneck. Mode 2 (skipping MLP down proj) breaks through to 1.575. + +3. **Momentum should be lower than standard (0.80 vs 0.95)** because STE gradient noise is amplified by momentum, causing destructive sign oscillations. + +4. **Cosine LR schedule is essential** for binary STE training. Flat LR with warmdown wastes 70% of training time in a divergent plateau. + +5. **Sequence length scheduling provides free improvement** -- short sequences at the start give 8x more gradient updates during the phase where the model needs to learn token frequencies. + +6. **Wider is better than deeper** for binary networks. Each sign() compounds information loss across layers, but wider MLP gives more capacity per layer. + +7. **EMA is harmful** for binary models -- the averaged weights have less decisive signs that don't survive quantization. + +8. **Scale QAT is essential for long training runs.** Without it, FP8 scale quantization error accumulates over steps and causes catastrophic roundtrip degradation (0.87 bpb gap at 200k steps). + +9. **Attention Residuals add overhead without benefit for binary networks.** The 33% slower steps reduce training progress more than depth-wise attention helps. Simple U-Net skips are sufficient. + +10. **EGGROLL cannot improve on STE+Muon at 115M parameters.** The signal-to-noise ratio of zeroth-order methods is too low for practical population sizes. Even LoRA-based EGGROLL (614K params, pop=4096) shows no improvement from the STE-found basin. + +--- + +## References + +- Rastegari, M. et al. "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks." ECCV 2016. +- Sarkar, B. et al. "Evolution Strategies at the Hyperscale." arXiv:2511.16652, 2026. +- Zhang, J. et al. "Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon." dao-ailab, 2026. +- Kimi Team. "Attention Residuals." 2026. + +--- + +## License + +This project is part of the OpenAI Parameter Golf Challenge submission. diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_per-pass-attention-residuals.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_per-pass-attention-residuals.txt new file mode 100644 index 0000000000..3a8e297f56 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_per-pass-attention-residuals.txt @@ -0,0 +1,1818 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim — April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +This is the REFERENCE implementation using STE-simulated XNOR via F.linear. +The Triton INT8×INT8 kernel version comes later for H100 deployment. + +Architecture: U-Net transformer with skip connections — provides error correction +for the information loss inherent in binary quantization of both weights and activations. 
+""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. + H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, 
int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + attn_res = _e("ATTN_RES", 0, int) # 0=disabled, 1=sub-layer, 2=pass-level + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + 
bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + 
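+
+# Round-trip note (illustrative): applying deq_sd to a q_sd result recovers
+# sign(w) * mean(|w|) per group for binary tensors, so any roundtrip gap comes
+# from scale precision (BF16 vs FP8) rather than the bit packing itself, e.g.:
+#   q, _ = q_sd({"w": torch.randn(4096, 1024)}, group_size=256)
+#   w_rt = deq_sd(q)["w"]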
+# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# --------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. bfloat16 throughout.""" + X = G.bfloat16() + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + X = X / (X.norm() + eps) + n = X.size(0) + I = torch.eye(n, device=X.device, dtype=X.dtype) + + # Phase 1: steps 0-1 + R = X @ X.T + # Step 0: a=4.0848, b=-6.8946, c=2.9270 + Z = -6.8946 * R + 2.9270 * (R @ R) + Q = Z + 4.0848 * I + RZ = 4.0848 * R + R @ Z + R = 4.0848 * RZ + Z @ RZ + # Step 1: a=3.9505, b=-6.3029, c=2.6377 + Z = -6.3029 * R + 2.6377 * (R @ R) + Q = 3.9505 * Q + Q @ Z + # No R update (next is reset) + + # Reset + X = Q @ X + R = X @ X.T + + # Phase 2: steps 2-4 + # Step 2: a=3.7418, b=-5.5913, c=2.3037 + Z = -5.5913 * R + 2.3037 * (R @ R) + Q = Z + 3.7418 * I + RZ = 3.7418 * R + R @ Z + R = 3.7418 * RZ + Z @ RZ + # Step 3: a=2.8769, b=-3.1427, c=1.2046 + Z = -3.1427 * R + 1.2046 * (R @ R) + Q = 2.8769 * Q + Q @ Z + RZ = 2.8769 * R + R @ Z + R = 2.8769 * RZ + Z @ RZ + # Step 4: a=2.8366, b=-3.0525, c=1.2012 + Z = -3.0525 * R + 1.2012 * (R @ R) + Q = 2.8366 * Q + Q @ Z + + X = Q @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = 
p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() + g = ns_orth(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("wd", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- +def ld_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x, y + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set 
from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. + """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. 
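+    Bit i of packed word j holds the sign of element j*32 + i (LSB-first within each word).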
+ returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc = y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + 
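+        # Full-precision x/w and both scales are saved so backward can rebuild
+        # exactly the binarized operands the kernel consumed and apply the STE.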
+        ctx.group_size = group_size
+        return Y.reshape(*orig_shape, N)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, w, w_alpha, x_beta_kd = ctx.saved_tensors
+        dtype = grad_output.dtype
+        group_size = ctx.group_size
+
+        # Reconstruct per-group binary weights for backward
+        N, K = w.shape
+        num_groups = K // group_size
+        w_binary = (w.sign().reshape(N, num_groups, group_size) *
+                    w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype)
+        x_binary = (x.sign() * x_beta_kd).to(dtype)
+
+        grad_x = grad_output @ w_binary
+        M = grad_output.shape[:-1].numel()
+        grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1)
+        return grad_x, grad_w, None  # None for group_size
+
+class XNORLinear(nn.Linear):
+    """XNOR-Net linear layer: binarizes BOTH weights AND activations.
+
+    Forward pass:
+    1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group
+    2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token
+    3. Output = F.linear(X_bin, W_bin)
+
+    Backward pass: Straight-Through Estimator (STE) — gradients flow through as
+    if the binarization was identity. This matches the XNOR-Net paper's training
+    algorithm (Algorithm 1).
+
+    When binarize_act=False, this degrades to Binary-Weight-Network (BWN).
+    """
+    def __init__(self, in_features, out_features, bias=False, group_size=128,
+                 binarize_act=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.group_size = group_size
+        self.binarize_act = binarize_act
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Triton XNOR+POPCOUNT path (full XNOR only, per-group alpha)
+        if _TRITON_KERNEL and self.binarize_act:
+            y = _TritonXNORFn.apply(x, self.weight, self.group_size)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+        # INT8 tensor core path (full XNOR only, per-row alpha)
+        if _INT8_KERNEL and self.binarize_act:
+            y = _INT8XNORFn.apply(x, self.weight)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+
+        # Standard STE path (per-group alpha)
+        w = self.weight.bfloat16()
+        g = self.group_size
+
+        # --- Weight binarization (STE) ---
+        w_g = w.reshape(-1, g)
+        alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8)
+        # Simulate FP8 scale quantization so model compensates for roundtrip error
+        if _SCALE_FP8_STE:
+            alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype)
+            alpha = alpha + (alpha_q - alpha).detach()
+        w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g))
+        w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach()
+
+        # --- Activation binarization (STE) ---
+        if self.binarize_act:
+            beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
+            x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
+            x_binary = x + ((x_sign * beta) - x).detach()
+        else:
+            x_binary = x
+
+        return F.linear(x_binary, w_binary,
+                        self.bias.to(x.dtype) if self.bias is not None else None)
+
+
+class NormedXNORLinear(XNORLinear):
+    """XNOR linear with RMSNorm on input.
+    Used for output projections receiving un-normalized activations (attention out, MLP down).
+    The RMSNorm fixes the input scale before binarization, standing in for the zero-mean
+    BatchNorm of the XNOR-Net paper's B-A-C-P block order, which it reports as critical
+    for sign() quality (Table 3b: 30.3% vs 44.2%).
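+    The norm here is parameter-free (F.rms_norm with no learned weight); learned
+    output scaling lives in the blocks' attn_scale / mlp_scale parameters instead.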
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Attention Residuals (AttnRes) — depth-wise attention over prior outputs +# From: "Attention Residuals" (Kimi Team, 2026) +# --------------------------------------------------------------------------- +class AttnRes(nn.Module): + def __init__(self, dim, n_queries): + super().__init__() + # Zero-init -> uniform weights at start -> equivalent to standard residual + self.queries = nn.ParameterList([ + nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + for _ in range(n_queries) + ]) + self.key_norm_weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def attend(self, stored, query_idx): + w = self.queries[query_idx] + logits = [] + for v in stored: + k = F.rms_norm(v, (v.size(-1),)) * self.key_norm_weight.to(v.dtype) + logits.append(torch.einsum('d, b t d -> b t', w.to(v.dtype), k)) + logits = torch.stack(logits, dim=0) + weights = F.softmax(logits, dim=0) + out = torch.zeros_like(stored[0]) + for i, v in enumerate(stored): + out = out + weights[i].unsqueeze(-1) * v + return out + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + 
rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False, attn_res=0): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + self.attn_res_mode = attn_res + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Attention Residuals (replaces U-Net skips when enabled) + self.attn_res = None + if attn_res == 2: # pass-level + self.attn_res = AttnRes(model_dim, n_queries=num_layers) + elif attn_res == 1: # sub-layer + self.attn_res = AttnRes(model_dim, n_queries=num_layers * 2) + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if 
tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.attn_res_mode == 2: + # Pass-level AttnRes: attend over all previous block outputs + stored = [x] + for i, block in enumerate(self.blocks): + x = self.attn_res.attend(stored, i) + x = block(x, x0) + stored.append(x) + elif self.attn_res_mode == 1: + # Sub-layer AttnRes: attend before each attn and mlp sub-layer + stored = [x] + q_idx = 0 + for i, block in enumerate(self.blocks): + # Attention sub-layer + x = self.attn_res.attend(stored, q_idx) + mix = block.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + attn_out = block.attn_scale.to(dtype=x.dtype) * block.attn(block.attn_norm(x)) + stored.append(attn_out) + q_idx += 1 + # MLP sub-layer + x = self.attn_res.attend(stored, q_idx) + mlp_out = block.mlp_scale.to(dtype=x.dtype) * block.mlp(block.mlp_norm(x)) + stored.append(mlp_out) + q_idx += 1 + if block.smear is not None: + x = block.smear(self.attn_res.attend(stored, q_idx - 1)) + else: + # Standard U-Net (mode 0) + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = 
int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), 
device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + 
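+    # Also restore the CUDA generator so a resumed run replays identical random draws.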
torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, attn_res=args.attn_res, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else 
None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = 
min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) + for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and 
step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = 
zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch._dynamo.reset() + torch.cuda.empty_cache() + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +no checkpoint, starting fresh +--- Hyperparameters --- +activation_type=signsq adam_eps=1e-08 adam_lr=0.02 adam_wd=0.04 attn_res=2 beta1=0.9 beta2=0.95 binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=10000 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 ema=False ema_decay=0.995 ema_start_fraction=0.0 embed_dim=384 embed_lr=0.6 fp_storage=True grad_clip_norm=1.0 head_lr=0.02 iterations=100000 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=5 matrix_lr=0.008 matrix_optimizer=muon max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 muon_backend_steps=3 muon_momentum=0.8 muon_momentum_warmup_start=0.8 muon_momentum_warmup_steps=200 muon_wd=0.04 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=N3_full-xnor_8xH100_1775401853 scalar_lr=0.01 scale_fp8=False seed=42 seq_len_schedule=True sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.05 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1000 train_seq_len=1024 use_gram_ns=False use_int8_kernel=False use_triton_kernel=True val_batch_size=524288 val_loss_every=0 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-Net LLM | params:117630032 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 ga:1 +seq_len_schedule: [128, 256, 512, 1024] +warmup:1/4 seq_len=128 +warmup:2/4 seq_len=256 +warmup:3/4 seq_len=512 
+warmup:4/4 seq_len=1024 +step:1000/100000 loss:3.5850 t:51774ms step_avg:51.8ms +step:2000/100000 loss:3.4241 t:103414ms step_avg:51.7ms +step:3000/100000 loss:3.2302 t:154965ms step_avg:51.7ms +step:4000/100000 loss:3.3566 t:206551ms step_avg:51.6ms +step:5000/100000 loss:3.4339 t:258140ms step_avg:51.6ms +step:6000/100000 loss:3.3440 t:309693ms step_avg:51.6ms +step:7000/100000 loss:3.4384 t:361285ms step_avg:51.6ms +step:8000/100000 loss:3.6339 t:412879ms step_avg:51.6ms +step:9000/100000 loss:3.8280 t:464430ms step_avg:51.6ms +step:10000 checkpoint saved +step:10000/100000 loss:3.5758 t:516117ms step_avg:51.6ms +step:11000/100000 loss:3.5132 t:568246ms step_avg:51.7ms +step:12000/100000 loss:3.5622 t:619801ms step_avg:51.7ms +step:13000/100000 loss:3.5778 t:671388ms step_avg:51.6ms +step:14000/100000 loss:3.6370 t:722978ms step_avg:51.6ms +step:15000/100000 loss:3.5388 t:774528ms step_avg:51.6ms +step:16000/100000 loss:3.6200 t:826113ms step_avg:51.6ms +step:17000/100000 loss:3.4811 t:877721ms step_avg:51.6ms +step:18000/100000 loss:3.6142 t:929273ms step_avg:51.6ms +step:19000/100000 loss:3.4890 t:980862ms step_avg:51.6ms +step:20000 checkpoint saved +step:20000/100000 loss:3.3466 t:1032459ms step_avg:51.6ms +step:21000/100000 loss:3.8485 t:1084675ms step_avg:51.7ms +step:22000/100000 loss:3.6800 t:1136269ms step_avg:51.6ms +step:23000/100000 loss:3.5808 t:1187874ms step_avg:51.6ms +step:24000/100000 loss:3.6088 t:1239429ms step_avg:51.6ms +step:25000/100000 loss:3.5170 t:1291016ms step_avg:51.6ms +step:26000/100000 loss:3.4403 t:1342710ms step_avg:51.6ms +step:27000/100000 loss:3.3500 t:1394354ms step_avg:51.6ms +step:28000/100000 loss:3.7462 t:1446035ms step_avg:51.6ms +step:29000/100000 loss:3.6899 t:1497725ms step_avg:51.6ms +step:30000 checkpoint saved +step:30000/100000 loss:3.4281 t:1549376ms step_avg:51.6ms +step:31000/100000 loss:3.6072 t:1601549ms step_avg:51.7ms +step:32000/100000 loss:3.5159 t:1653193ms step_avg:51.7ms +step:33000/100000 loss:3.2772 t:1704989ms step_avg:51.7ms +step:34000/100000 loss:3.3412 t:1756682ms step_avg:51.7ms +step:35000/100000 loss:3.6511 t:1808340ms step_avg:51.7ms +step:36000/100000 loss:3.6248 t:1860044ms step_avg:51.7ms +step:37000/100000 loss:3.3697 t:1911986ms step_avg:51.7ms +step:38000/100000 loss:3.4508 t:1963652ms step_avg:51.7ms +step:39000/100000 loss:3.4820 t:2015352ms step_avg:51.7ms +step:40000 checkpoint saved +step:40000/100000 loss:3.5149 t:2067048ms step_avg:51.7ms +step:41000/100000 loss:3.7823 t:2119263ms step_avg:51.7ms +step:42000/100000 loss:3.5010 t:2170973ms step_avg:51.7ms +step:43000/100000 loss:3.4672 t:2222669ms step_avg:51.7ms +step:44000/100000 loss:3.3739 t:2274429ms step_avg:51.7ms +step:45000/100000 loss:3.7859 t:2326135ms step_avg:51.7ms +step:46000/100000 loss:3.2487 t:2377828ms step_avg:51.7ms +step:47000/100000 loss:3.4967 t:2429487ms step_avg:51.7ms +step:48000/100000 loss:3.1303 t:2481190ms step_avg:51.7ms +step:49000/100000 loss:3.2962 t:2532887ms step_avg:51.7ms +step:50000 checkpoint saved +step:50000/100000 loss:3.3109 t:2584546ms step_avg:51.7ms +step:51000/100000 loss:3.4626 t:2636881ms step_avg:51.7ms +step:52000/100000 loss:3.2043 t:2688715ms step_avg:51.7ms +step:53000/100000 loss:3.2786 t:2740514ms step_avg:51.7ms +step:54000/100000 loss:3.2035 t:2792352ms step_avg:51.7ms +step:55000/100000 loss:3.4879 t:2844187ms step_avg:51.7ms +step:56000/100000 loss:3.0030 t:2896086ms step_avg:51.7ms +step:57000/100000 loss:3.2256 t:2947925ms step_avg:51.7ms +step:58000/100000 loss:3.3797 t:2999761ms 
step_avg:51.7ms +step:59000/100000 loss:3.2903 t:3051559ms step_avg:51.7ms +step:60000 checkpoint saved +step:60000/100000 loss:3.2593 t:3103397ms step_avg:51.7ms +step:61000/100000 loss:3.2437 t:3155687ms step_avg:51.7ms +step:62000/100000 loss:3.0563 t:3207521ms step_avg:51.7ms +step:63000/100000 loss:3.2527 t:3259352ms step_avg:51.7ms +step:64000/100000 loss:3.1472 t:3311150ms step_avg:51.7ms +step:65000/100000 loss:3.2172 t:3362984ms step_avg:51.7ms +step:66000/100000 loss:2.8629 t:3414809ms step_avg:51.7ms +step:67000/100000 loss:3.0554 t:3466608ms step_avg:51.7ms +step:68000/100000 loss:3.0823 t:3518546ms step_avg:51.7ms +step:69000/100000 loss:3.2198 t:3570381ms step_avg:51.7ms +step:70000 checkpoint saved +step:70000/100000 loss:3.5317 t:3622180ms step_avg:51.7ms +step:71000/100000 loss:3.1825 t:3674518ms step_avg:51.8ms +step:72000/100000 loss:3.0766 t:3726364ms step_avg:51.8ms +step:73000/100000 loss:3.2229 t:3778161ms step_avg:51.8ms +step:74000/100000 loss:3.1886 t:3829991ms step_avg:51.8ms +step:75000/100000 loss:3.3816 t:3881828ms step_avg:51.8ms +step:76000/100000 loss:3.0165 t:3933963ms step_avg:51.8ms +step:77000/100000 loss:3.1094 t:3986130ms step_avg:51.8ms +step:78000/100000 loss:3.2248 t:4038307ms step_avg:51.8ms +step:79000/100000 loss:2.9869 t:4090543ms step_avg:51.8ms +step:80000 checkpoint saved +step:80000/100000 loss:3.0858 t:4142711ms step_avg:51.8ms +step:81000/100000 loss:2.5007 t:4195381ms step_avg:51.8ms +step:82000/100000 loss:3.0648 t:4247515ms step_avg:51.8ms +step:83000/100000 loss:2.9326 t:4299690ms step_avg:51.8ms +step:84000/100000 loss:3.2452 t:4351872ms step_avg:51.8ms +step:85000/100000 loss:3.0102 t:4404007ms step_avg:51.8ms +step:86000/100000 loss:2.9922 t:4456176ms step_avg:51.8ms +step:87000/100000 loss:2.9035 t:4508348ms step_avg:51.8ms +step:88000/100000 loss:2.9340 t:4560481ms step_avg:51.8ms +step:89000/100000 loss:3.1269 t:4612652ms step_avg:51.8ms +step:90000 checkpoint saved +step:90000/100000 loss:2.9661 t:4664785ms step_avg:51.8ms +step:91000/100000 loss:2.9156 t:4717573ms step_avg:51.8ms +step:92000/100000 loss:3.0960 t:4769733ms step_avg:51.8ms +step:93000/100000 loss:3.0488 t:4821850ms step_avg:51.8ms +step:94000/100000 loss:2.7601 t:4874014ms step_avg:51.9ms +step:95000/100000 loss:2.8182 t:4926167ms step_avg:51.9ms +step:96000/100000 loss:2.5139 t:4978286ms step_avg:51.9ms +step:97000/100000 loss:2.7799 t:5030447ms step_avg:51.9ms +step:98000/100000 loss:3.0646 t:5082598ms step_avg:51.9ms +step:99000/100000 loss:2.6594 t:5134717ms step_avg:51.9ms +step:100000 checkpoint saved +step:100000/100000 loss:2.8225 t:5186880ms step_avg:51.9ms +step:100000/100000 val_loss:2.8411 val_bpb:1.5832 train_time:5187412ms +compression: lzma=15.90MB brotli=15.82MB zstd=17.75MB +using: brotli +artifact:15.82MB binary:115343360(15319040B) fp:1238096(1270944B) code:78442 +budget:15903256/16000000 (15.90/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8633 val_bpb:1.5956 +temp_scaling optimal_T:1.05 time:753ms +final_sliding val_loss:2.8040 val_bpb:1.5626 (stride=16, T=1.05) time:1762005ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_u-net-skips.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_u-net-skips.txt new file mode 100644 index 0000000000..fc89a74031 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_100k_steps_u-net-skips.txt @@ -0,0 +1,1747 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf 
Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. + H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = 
_e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + 
"skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = 
entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# --------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. 
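+    One step in Gram form (a sketch; the unrolled code below applies exactly
+    this with per-step Polar Express coefficients a, b, c):
+
+        R = X @ X.T                 # Gram matrix, one per step
+        Z = b * R + c * (R @ R)
+        Q = a * I + Z               # Q = p(R) with p(t) = a + b*t + c*t**2
+        X <- Q @ X                  # one Newton-Schulz step
+        R <- Q @ R @ Q              # valid because Z commutes with R
+
+    Q is accumulated across steps and applied to X only at the phase
+    boundary, which is what makes the step-2 restart cheap.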
bfloat16 throughout."""
+    X = G.bfloat16()
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    X = X / (X.norm() + eps)
+    n = X.size(0)
+    I = torch.eye(n, device=X.device, dtype=X.dtype)
+
+    # Phase 1: steps 0-1
+    R = X @ X.T
+    # Step 0: a=4.0848, b=-6.8946, c=2.9270
+    Z = -6.8946 * R + 2.9270 * (R @ R)
+    Q = Z + 4.0848 * I
+    RZ = 4.0848 * R + R @ Z
+    R = 4.0848 * RZ + Z @ RZ
+    # Step 1: a=3.9505, b=-6.3029, c=2.6377
+    Z = -6.3029 * R + 2.6377 * (R @ R)
+    Q = 3.9505 * Q + Q @ Z
+    # No R update (next is reset)
+
+    # Reset
+    X = Q @ X
+    R = X @ X.T
+
+    # Phase 2: steps 2-4
+    # Step 2: a=3.7418, b=-5.5913, c=2.3037
+    Z = -5.5913 * R + 2.3037 * (R @ R)
+    Q = Z + 3.7418 * I
+    RZ = 3.7418 * R + R @ Z
+    R = 3.7418 * RZ + Z @ RZ
+    # Step 3: a=2.8769, b=-3.1427, c=1.2046
+    Z = -3.1427 * R + 1.2046 * (R @ R)
+    Q = 2.8769 * Q + Q @ Z
+    RZ = 2.8769 * R + R @ Z
+    R = 2.8769 * RZ + Z @ RZ
+    # Step 4: a=2.8366, b=-3.0525, c=1.2012
+    Z = -3.0525 * R + 1.2012 * (R @ R)
+    Q = 2.8366 * Q + Q @ Z
+
+    X = Q @ X
+    return X.T if transposed else X
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0):
+        super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd))
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr, momentum = group["lr"], group["momentum"]
+            backend_steps, nesterov = group["backend_steps"], group["nesterov"]
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16()
+                    g = ns_orth(g, steps=backend_steps)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr:curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+            wd = group.get("wd", 0.0)
+            curr = 0
+            for p in params:
+                g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                if wd > 0:
+                    p.mul_(1 - lr * wd)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+        return loss
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with open(file, "rb") as f:
+        # llm.c-style .bin shard convention: 256-int32 header, header[2] = token count
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    """Sequential reader over sharded token files; wraps around at the end."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = -1
+        self.tokens = torch.empty(0, dtype=torch.int32)
+        self.pos = 0
+        self._advance_file()
+
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = 
TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x, y + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. 
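+    Shape sketch (names are illustrative, mirroring forward() below):
+
+        x: [..., K] -> x_int8 = sign(x) in {-1,+1}, beta  = mean(|x|) per token
+        w: [N, K]   -> w_int8 = sign(w) in {-1,+1}, alpha = mean(|w|) per row
+        y = _int_mm(x_int8.view(M, K), w_int8.T) * beta * alpha    # [M, N]
+
+    i.e. one INT8 matmul plus two broadcast rescales approximates x @ w.T.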
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
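+    Worked example for one group of 4 (illustrative only; the real group_size
+    is 128 or 256):
+
+        W_group  = [ 0.30, -0.10,  0.20, -0.40]
+        alpha    = mean(|W_group|) = 0.25
+        W_binary = [+0.25, -0.25, +0.25, -0.25]
+
+    The STE backward treats d(W_binary)/dW as identity, so the float32 master
+    weights keep receiving full-precision gradients even though the forward
+    pass only ever sees {-alpha, +alpha}.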
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + 
num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" 
".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) 
+ for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = 
min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log 
+= f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +no checkpoint, starting fresh +--- Hyperparameters --- +activation_type=signsq adam_eps=1e-08 adam_lr=0.02 adam_wd=0.04 beta1=0.9 beta2=0.95 binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=10000 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 ema=False ema_decay=0.995 ema_start_fraction=0.0 embed_dim=384 embed_lr=0.6 fp_storage=True grad_clip_norm=1.0 head_lr=0.02 iterations=100000 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=5 matrix_lr=0.008 matrix_optimizer=muon max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 muon_backend_steps=3 muon_momentum=0.8 muon_momentum_warmup_start=0.8 
muon_momentum_warmup_steps=200 muon_wd=0.04 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=N2_full-xnor_8xH100_1775379546 scalar_lr=0.01 scale_fp8=False seed=42 seq_len_schedule=True sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.05 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1000 train_seq_len=1024 use_gram_ns=False use_int8_kernel=False use_triton_kernel=True val_batch_size=524288 val_loss_every=0 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-Net LLM | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 ga:1 +seq_len_schedule: [128, 256, 512, 1024] +warmup:1/4 seq_len=128 +warmup:2/4 seq_len=256 +warmup:3/4 seq_len=512 +warmup:4/4 seq_len=1024 +step:1000/100000 loss:3.6041 t:38617ms step_avg:38.6ms +step:2000/100000 loss:3.4874 t:76994ms step_avg:38.5ms +step:3000/100000 loss:3.2816 t:115793ms step_avg:38.6ms +step:4000/100000 loss:3.4228 t:154525ms step_avg:38.6ms +step:5000/100000 loss:3.4680 t:193214ms step_avg:38.6ms +step:6000/100000 loss:3.4292 t:232381ms step_avg:38.7ms +step:7000/100000 loss:3.4720 t:271138ms step_avg:38.7ms +step:8000/100000 loss:3.6766 t:310677ms step_avg:38.8ms +step:9000/100000 loss:3.8428 t:350041ms step_avg:38.9ms +step:10000 checkpoint saved +step:10000/100000 loss:3.6277 t:388673ms step_avg:38.9ms +step:11000/100000 loss:3.5396 t:427599ms step_avg:38.9ms +step:12000/100000 loss:3.5255 t:466355ms step_avg:38.9ms +step:13000/100000 loss:3.5803 t:505176ms step_avg:38.9ms +step:14000/100000 loss:3.5926 t:544155ms step_avg:38.9ms +step:15000/100000 loss:3.6088 t:582510ms step_avg:38.8ms +step:16000/100000 loss:3.6105 t:620776ms step_avg:38.8ms +step:17000/100000 loss:3.4647 t:659525ms step_avg:38.8ms +step:18000/100000 loss:3.6159 t:699174ms step_avg:38.8ms +step:19000/100000 loss:3.5035 t:738207ms step_avg:38.9ms +step:20000 checkpoint saved +step:20000/100000 loss:3.2823 t:777288ms step_avg:38.9ms +step:21000/100000 loss:3.9198 t:816796ms step_avg:38.9ms +step:22000/100000 loss:3.7064 t:855652ms step_avg:38.9ms +step:23000/100000 loss:3.6011 t:894490ms step_avg:38.9ms +step:24000/100000 loss:3.7519 t:933328ms step_avg:38.9ms +step:25000/100000 loss:3.5111 t:972951ms step_avg:38.9ms +step:26000/100000 loss:3.5498 t:1012608ms step_avg:38.9ms +step:27000/100000 loss:3.4294 t:1051701ms step_avg:39.0ms +step:28000/100000 loss:3.8179 t:1090872ms step_avg:39.0ms +step:29000/100000 loss:3.6463 t:1130278ms step_avg:39.0ms +step:30000 checkpoint saved +step:30000/100000 loss:3.4794 t:1169400ms step_avg:39.0ms +step:31000/100000 loss:3.6463 t:1208233ms step_avg:39.0ms +step:32000/100000 loss:3.5177 t:1247119ms step_avg:39.0ms +step:33000/100000 loss:3.2191 t:1286438ms step_avg:39.0ms +step:34000/100000 loss:3.3552 t:1325943ms step_avg:39.0ms +step:35000/100000 loss:3.6644 t:1364691ms step_avg:39.0ms +step:36000/100000 loss:3.7277 t:1403163ms step_avg:39.0ms +step:37000/100000 loss:3.4332 t:1441904ms step_avg:39.0ms +step:38000/100000 loss:3.4913 t:1481237ms step_avg:39.0ms +step:39000/100000 loss:3.4081 t:1520241ms step_avg:39.0ms +step:40000 checkpoint saved +step:40000/100000 loss:3.5163 t:1559017ms step_avg:39.0ms +step:41000/100000 loss:3.8149 t:1598127ms step_avg:39.0ms +step:42000/100000 loss:3.5277 t:1637357ms step_avg:39.0ms +step:43000/100000 loss:3.5065 
t:1676018ms step_avg:39.0ms +step:44000/100000 loss:3.4695 t:1714637ms step_avg:39.0ms +step:45000/100000 loss:3.7332 t:1754371ms step_avg:39.0ms +step:46000/100000 loss:3.3099 t:1793947ms step_avg:39.0ms +step:47000/100000 loss:3.4211 t:1832801ms step_avg:39.0ms +step:48000/100000 loss:3.2141 t:1871894ms step_avg:39.0ms +step:49000/100000 loss:3.2859 t:1910638ms step_avg:39.0ms +step:50000 checkpoint saved +step:50000/100000 loss:3.3327 t:1949530ms step_avg:39.0ms +step:51000/100000 loss:3.5148 t:1988558ms step_avg:39.0ms +step:52000/100000 loss:3.2302 t:2027360ms step_avg:39.0ms +step:53000/100000 loss:3.3083 t:2067148ms step_avg:39.0ms +step:54000/100000 loss:3.2770 t:2106216ms step_avg:39.0ms +step:55000/100000 loss:3.5900 t:2145050ms step_avg:39.0ms +step:56000/100000 loss:3.0136 t:2184015ms step_avg:39.0ms +step:57000/100000 loss:3.2202 t:2223107ms step_avg:39.0ms +step:58000/100000 loss:3.3426 t:2261890ms step_avg:39.0ms +step:59000/100000 loss:3.2997 t:2301273ms step_avg:39.0ms +step:60000 checkpoint saved +step:60000/100000 loss:3.2592 t:2340534ms step_avg:39.0ms +step:61000/100000 loss:3.2692 t:2379510ms step_avg:39.0ms +step:62000/100000 loss:3.0064 t:2417685ms step_avg:39.0ms +step:63000/100000 loss:3.2605 t:2457029ms step_avg:39.0ms +step:64000/100000 loss:3.1490 t:2496026ms step_avg:39.0ms +step:65000/100000 loss:3.1985 t:2534998ms step_avg:39.0ms +step:66000/100000 loss:2.8600 t:2573775ms step_avg:39.0ms +step:67000/100000 loss:3.0304 t:2612783ms step_avg:39.0ms +step:68000/100000 loss:3.0754 t:2651488ms step_avg:39.0ms +step:69000/100000 loss:3.2654 t:2689756ms step_avg:39.0ms +step:70000 checkpoint saved +step:70000/100000 loss:3.5813 t:2728896ms step_avg:39.0ms +step:71000/100000 loss:3.1731 t:2767980ms step_avg:39.0ms +step:72000/100000 loss:3.0859 t:2806693ms step_avg:39.0ms +step:73000/100000 loss:3.2842 t:2845296ms step_avg:39.0ms +step:74000/100000 loss:3.1496 t:2884203ms step_avg:39.0ms +step:75000/100000 loss:3.3874 t:2922846ms step_avg:39.0ms +step:76000/100000 loss:3.0111 t:2961600ms step_avg:39.0ms +step:77000/100000 loss:3.1280 t:3000796ms step_avg:39.0ms +step:78000/100000 loss:3.2465 t:3039921ms step_avg:39.0ms +step:79000/100000 loss:3.0398 t:3079377ms step_avg:39.0ms +step:80000 checkpoint saved +step:80000/100000 loss:3.1211 t:3117950ms step_avg:39.0ms +step:81000/100000 loss:2.5261 t:3157257ms step_avg:39.0ms +step:82000/100000 loss:3.0721 t:3196095ms step_avg:39.0ms +step:83000/100000 loss:2.9856 t:3234842ms step_avg:39.0ms +step:84000/100000 loss:3.2483 t:3274036ms step_avg:39.0ms +step:85000/100000 loss:3.0236 t:3312447ms step_avg:39.0ms +step:86000/100000 loss:3.0041 t:3351429ms step_avg:39.0ms +step:87000/100000 loss:2.8871 t:3391245ms step_avg:39.0ms +step:88000/100000 loss:2.9250 t:3429962ms step_avg:39.0ms +step:89000/100000 loss:3.1118 t:3468230ms step_avg:39.0ms +step:90000 checkpoint saved +step:90000/100000 loss:2.9893 t:3506587ms step_avg:39.0ms +step:91000/100000 loss:2.9344 t:3545886ms step_avg:39.0ms +step:92000/100000 loss:3.1058 t:3584828ms step_avg:39.0ms +step:93000/100000 loss:3.0639 t:3623757ms step_avg:39.0ms +step:94000/100000 loss:2.7357 t:3662528ms step_avg:39.0ms +step:95000/100000 loss:2.8264 t:3701326ms step_avg:39.0ms +step:96000/100000 loss:2.4896 t:3740877ms step_avg:39.0ms +step:97000/100000 loss:2.7780 t:3780026ms step_avg:39.0ms +step:98000/100000 loss:3.0280 t:3819744ms step_avg:39.0ms +step:99000/100000 loss:2.6402 t:3858941ms step_avg:39.0ms +step:100000 checkpoint saved +step:100000/100000 loss:2.7904 t:3898337ms 
step_avg:39.0ms +step:100000/100000 val_loss:2.8153 val_bpb:1.5689 train_time:3898778ms +compression: lzma=15.88MB brotli=15.83MB zstd=18.28MB +using: brotli +artifact:15.83MB binary:115343360(15319040B) fp:1226832(1248416B) code:75015 +budget:15908225/16000000 (15.91/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8255 val_bpb:1.5746 +temp_scaling optimal_T:1.00 time:584ms +final_sliding val_loss:2.7621 val_bpb:1.5392 (stride=16, T=1.00) time:1282427ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-1337.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-1337.txt new file mode 100644 index 0000000000..adc043be81 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-1337.txt @@ -0,0 +1,1723 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim — April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +This is the REFERENCE implementation using STE-simulated XNOR via F.linear. +The Triton INT8×INT8 kernel version comes later for H100 deployment. + +Architecture: U-Net transformer with skip connections — provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + attn_res = _e("ATTN_RES", 0, int) # 0=disabled, 1=sub-layer, 
2=pass-level + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight 
matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# 
--------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. bfloat16 throughout.""" + X = G.bfloat16() + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + X = X / (X.norm() + eps) + n = X.size(0) + I = torch.eye(n, device=X.device, dtype=X.dtype) + + # Phase 1: steps 0-1 + R = X @ X.T + # Step 0: a=4.0848, b=-6.8946, c=2.9270 + Z = -6.8946 * R + 2.9270 * (R @ R) + Q = Z + 4.0848 * I + RZ = 4.0848 * R + R @ Z + R = 4.0848 * RZ + Z @ RZ + # Step 1: a=3.9505, b=-6.3029, c=2.6377 + Z = -6.3029 * R + 2.6377 * (R @ R) + Q = 3.9505 * Q + Q @ Z + # No R update (next is reset) + + # Reset + X = Q @ X + R = X @ X.T + + # Phase 2: steps 2-4 + # Step 2: a=3.7418, b=-5.5913, c=2.3037 + Z = -5.5913 * R + 2.3037 * (R @ R) + Q = Z + 3.7418 * I + RZ = 3.7418 * R + R @ Z + R = 3.7418 * RZ + Z @ RZ + # Step 3: a=2.8769, b=-3.1427, c=1.2046 + Z = -3.1427 * R + 1.2046 * (R @ R) + Q = 2.8769 * Q + Q @ Z + RZ = 2.8769 * R + R @ Z + R = 2.8769 * RZ + Z @ RZ + # Step 4: a=2.8366, b=-3.0525, c=1.2012 + Z = -3.0525 * R + 1.2012 * (R @ R) + Q = 2.8366 * Q + Q @ Z + + X = Q @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() + g = ns_orth(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("wd", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------------------------------------------------------------- +# Data loading +# 
---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+ # assumed shard layout (span reconstructed): 256-int32 header, then uint16 token ids
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ tokens = np.memmap(file, dtype=np.uint16, mode="r", offset=header_bytes)
+ return torch.from_numpy(np.asarray(tokens, dtype=np.int32))
+
+class TokenStream:
+ """Endless token stream over the sorted shard files matching `pattern`."""
+ def __init__(self, pattern: str):
+ self.files = sorted(glob.glob(pattern))
+ assert self.files, f"No files: {pattern}"
+ self.file_idx = 0
+ self.tokens = ld_shard(Path(self.files[0]))
+ self.pos = 0
+ def _advance_file(self):
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = ld_shard(Path(self.files[self.file_idx]))
+ self.pos = 0
+ def take(self, n: int) -> Tensor:
+ chunks = []
+ remaining = n
+ while remaining > 0:
+ avail = self.tokens.numel() - self.pos
+ if avail <= 0:
+ self._advance_file()
+ continue
+ k = min(remaining, avail)
+ chunks.append(self.tokens[self.pos:self.pos + k])
+ self.pos += k
+ remaining -= k
+ return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+ self.rank, self.world_size, self.device = rank, world_size, device
+ self.stream = TokenStream(pattern)
+ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+ local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+ per_rank_span = local_tokens + 1
+ chunk = self.stream.take(per_rank_span * self.world_size)
+ start = self.rank * per_rank_span
+ local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+ x = local[:-1].reshape(-1, seq_len)
+ y = local[1:].reshape(-1, seq_len)
+ return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+ def __init__(self, eps=None):
+ super().__init__()
+ self.eps = eps
+ def forward(self, x: Tensor) -> Tensor:
+ return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+ if not fp_storage:
+ return w
+ w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+ return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+ def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+ super().__init__(in_features, out_features, bias=bias)
+ self.fp_storage = fp_storage
+ def forward(self, x: Tensor) -> Tensor:
+ w = apply_qat_ste(self.weight, self.fp_storage)
+ return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+ def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+ super().__init__(num_embeddings, embedding_dim)
+ self.fp_storage = fp_storage
+ def forward(self, input: Tensor) -> Tensor:
+ w = apply_qat_ste(self.weight, self.fp_storage)
+ return F.embedding(input, w, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+# ---------------------------------------------------------------------------
+# XNOR-Net Linear Layers
+# ---------------------------------------------------------------------------
+_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main()
+
+class _INT8XNORFn(torch.autograd.Function):
+ """INT8 tensor core forward, BF16 backward.
+ Uses per-row alpha (not per-group) so we get a single torch._int_mm call.
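+ The ±1 sign patterns are cast to int8 so that a single torch._int_mm yields
+ the exact binary dot product in INT32; the per-row alpha and per-token beta
+ scales are applied to that result afterwards.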
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
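+
+ Worked example (illustrative, group_size=4): a weight group w = [0.5, -1.0, 0.25, -0.25]
+ has alpha = mean(|w|) = 0.5, so W_bin = sign(w) * alpha = [+0.5, -0.5, +0.5, -0.5];
+ at save time only the sign bits plus one scale per group are stored.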
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Attention Residuals (AttnRes) — depth-wise attention over prior outputs +# From: "Attention Residuals" (Kimi Team, 2026) +# --------------------------------------------------------------------------- +class AttnRes(nn.Module): + def __init__(self, dim, n_queries): + super().__init__() + # Zero-init -> uniform weights at start -> equivalent to standard residual + self.queries = nn.ParameterList([ + nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + for _ in range(n_queries) + ]) + self.key_norm_weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def attend(self, stored, query_idx): + w = self.queries[query_idx] + logits = [] + for v in stored: + k = F.rms_norm(v, (v.size(-1),)) * self.key_norm_weight.to(v.dtype) + logits.append(torch.einsum('d, b t d -> b t', w.to(v.dtype), k)) + logits = torch.stack(logits, dim=0) + weights = F.softmax(logits, dim=0) + out = torch.zeros_like(stored[0]) + for i, v in enumerate(stored): + out = out + weights[i].unsqueeze(-1) * v + return out + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + 
rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False, attn_res=0): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + self.attn_res_mode = attn_res + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Attention Residuals (replaces U-Net skips when enabled) + self.attn_res = None + if attn_res == 2: # pass-level + self.attn_res = AttnRes(model_dim, n_queries=num_layers) + elif attn_res == 1: # sub-layer + self.attn_res = AttnRes(model_dim, n_queries=num_layers * 2) + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if 
tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.attn_res_mode == 2: + # Pass-level AttnRes: attend over all previous block outputs + stored = [x] + for i, block in enumerate(self.blocks): + x = self.attn_res.attend(stored, i) + x = block(x, x0) + stored.append(x) + elif self.attn_res_mode == 1: + # Sub-layer AttnRes: attend before each attn and mlp sub-layer + stored = [x] + q_idx = 0 + for i, block in enumerate(self.blocks): + # Attention sub-layer + x = self.attn_res.attend(stored, q_idx) + mix = block.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + attn_out = block.attn_scale.to(dtype=x.dtype) * block.attn(block.attn_norm(x)) + stored.append(attn_out) + q_idx += 1 + # MLP sub-layer + x = self.attn_res.attend(stored, q_idx) + mlp_out = block.mlp_scale.to(dtype=x.dtype) * block.mlp(block.mlp_norm(x)) + stored.append(mlp_out) + q_idx += 1 + if block.smear is not None: + x = block.smear(self.attn_res.attend(stored, q_idx - 1)) + else: + # Standard U-Net (mode 0) + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = 
int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), 
device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + 
torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, attn_res=args.attn_res, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else 
None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = 
min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) + for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and 
step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = 
zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch._dynamo.reset() + torch.cuda.empty_cache() + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +--- Hyperparameters --- +activation_type=signsq adam_eps=1e-08 adam_lr=0.02 adam_wd=0.04 attn_res=0 beta1=0.9 beta2=0.95 binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 ema=False ema_decay=0.995 ema_start_fraction=0.0 embed_dim=384 embed_lr=0.6 fp_storage=True grad_clip_norm=1.0 head_lr=0.02 iterations=20000 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=5 matrix_lr=0.008 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_mult=4 model_dim=1024 muon_backend_steps=3 muon_momentum=0.8 muon_momentum_warmup_start=0.8 muon_momentum_warmup_steps=200 muon_wd=0.04 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=P5_full-xnor_8xH100_1775400387 scalar_lr=0.01 scale_fp8=False seed=1337 seq_len_schedule=True sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=48 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.05 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1000 train_seq_len=1024 use_gram_ns=False use_int8_kernel=False use_triton_kernel=True val_batch_size=524288 val_loss_every=0 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-Net LLM | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 ga:1 +seq_len_schedule: [128, 256, 512, 1024] +warmup:1/4 seq_len=128 +warmup:2/4 seq_len=256 +warmup:3/4 seq_len=512 +warmup:4/4 seq_len=1024 +step:1000/20000 
loss:3.5790 t:37359ms step_avg:37.4ms +step:2000/20000 loss:3.4716 t:75368ms step_avg:37.7ms +step:3000/20000 loss:3.2363 t:113908ms step_avg:38.0ms +step:4000/20000 loss:3.3267 t:152524ms step_avg:38.1ms +step:5000/20000 loss:3.2797 t:189536ms step_avg:37.9ms +step:6000/20000 loss:3.2266 t:228176ms step_avg:38.0ms +step:7000/20000 loss:3.2360 t:266157ms step_avg:38.0ms +step:8000/20000 loss:3.3412 t:303942ms step_avg:38.0ms +step:9000/20000 loss:3.4709 t:341970ms step_avg:38.0ms +step:10000/20000 loss:3.1597 t:380526ms step_avg:38.1ms +step:11000/20000 loss:3.0887 t:418259ms step_avg:38.0ms +step:12000/20000 loss:3.0631 t:456502ms step_avg:38.0ms +step:13000/20000 loss:3.0616 t:494196ms step_avg:38.0ms +step:14000/20000 loss:2.9467 t:532560ms step_avg:38.0ms +step:15000/20000 loss:2.8517 t:570126ms step_avg:38.0ms +step:15760/20000 val_loss:2.8681 val_bpb:1.5983 train_time:599050ms +stopping_early: wallclock_cap train_time:599050ms +compression: lzma=15.95MB brotli=15.88MB zstd=17.86MB +using: brotli +artifact:15.88MB binary:115343360(15319040B) fp:1226832(1248416B) code:78442 +budget:15958791/16000000 (15.96/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8714 val_bpb:1.6001 +temp_scaling optimal_T:1.00 time:315ms +final_sliding val_loss:2.8082 val_bpb:1.5649 (stride=48, T=1.00) time:427419ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-42.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-42.txt new file mode 100644 index 0000000000..e07f9fec31 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-42.txt @@ -0,0 +1,1723 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim — April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +This is the REFERENCE implementation using STE-simulated XNOR via F.linear. +The Triton INT8×INT8 kernel version comes later for H100 deployment. + +Architecture: U-Net transformer with skip connections — provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + attn_res = _e("ATTN_RES", 0, int) # 0=disabled, 1=sub-layer, 
2=pass-level + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight 
matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# 
--------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. bfloat16 throughout.""" + X = G.bfloat16() + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + X = X / (X.norm() + eps) + n = X.size(0) + I = torch.eye(n, device=X.device, dtype=X.dtype) + + # Phase 1: steps 0-1 + R = X @ X.T + # Step 0: a=4.0848, b=-6.8946, c=2.9270 + Z = -6.8946 * R + 2.9270 * (R @ R) + Q = Z + 4.0848 * I + RZ = 4.0848 * R + R @ Z + R = 4.0848 * RZ + Z @ RZ + # Step 1: a=3.9505, b=-6.3029, c=2.6377 + Z = -6.3029 * R + 2.6377 * (R @ R) + Q = 3.9505 * Q + Q @ Z + # No R update (next is reset) + + # Reset + X = Q @ X + R = X @ X.T + + # Phase 2: steps 2-4 + # Step 2: a=3.7418, b=-5.5913, c=2.3037 + Z = -5.5913 * R + 2.3037 * (R @ R) + Q = Z + 3.7418 * I + RZ = 3.7418 * R + R @ Z + R = 3.7418 * RZ + Z @ RZ + # Step 3: a=2.8769, b=-3.1427, c=1.2046 + Z = -3.1427 * R + 1.2046 * (R @ R) + Q = 2.8769 * Q + Q @ Z + RZ = 2.8769 * R + R @ Z + R = 2.8769 * RZ + Z @ RZ + # Step 4: a=2.8366, b=-3.0525, c=1.2012 + Z = -3.0525 * R + 1.2012 * (R @ R) + Q = 2.8366 * Q + Q @ Z + + X = Q @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() + g = ns_orth(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("wd", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------------------------------------------------------------- +# Data loading +# 
---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # Assumed shard layout (standard fineweb .bin): 256-int32 header
+    # (magic, version, token count), then uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with open(file, "rb") as f:
+        f.seek(header_bytes)
+        data = np.frombuffer(f.read(), dtype=np.uint16)
+    return torch.from_numpy(data.astype(np.int32))
+
+class TokenStream:
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+    def _advance_file(self):
+        # Wrap around to the first shard after the last one is exhausted
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+    if not fp_storage:
+        return w
+    w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+    return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+        super().__init__(in_features, out_features, bias=bias)
+        self.fp_storage = fp_storage
+    def forward(self, x: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+        super().__init__(num_embeddings, embedding_dim)
+        self.fp_storage = fp_storage
+    def forward(self, input: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.embedding(input, w, self.padding_idx, self.max_norm,
+                           self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+# ---------------------------------------------------------------------------
+# XNOR-Net Linear Layers
+# ---------------------------------------------------------------------------
+_INT8_KERNEL = False  # set from Hyperparameters.use_int8_kernel in main()
+
+class _INT8XNORFn(torch.autograd.Function):
+    """INT8 tensor core forward, BF16 backward.
+    Uses per-row alpha (not per-group) so we get a single torch._int_mm call.
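+
+    Why an INT8 matmul computes the exact binary dot product: with x and w
+    mapped to {-1, +1}, every elementwise product is again +/-1, so
+    torch._int_mm accumulates the true sign-dot in INT32 with no overflow.
+    Small worked identity (K=4): x=[+1,-1,+1,-1], w=[+1,+1,-1,-1] gives
+    dot = +1-1-1+1 = 0; equivalently, packing +1 -> bit 1 gives x=0b1010,
+    w=0b1100, and K - 2*popcount(x^w) = 4 - 2*2 = 0, which is the same
+    identity the Triton XNOR+POPCOUNT kernel below exploits.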
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
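+
+    Worked example with an illustrative group_size of 4: one weight row
+    w = [0.2, -0.1, 0.3, -0.4] has alpha = mean(|w|) = 0.25, so
+    W_bin = [+1, -1, +1, -1] * 0.25; a token x = [0.5, -1.5, 1.0, -1.0]
+    has beta = mean(|x|) = 1.0, so X_bin = [+1, -1, +1, -1] * 1.0, and the
+    output is (sign dot) * alpha * beta = 4 * 0.25 * 1.0 = 1.0, close to
+    the exact full-precision dot product of 0.95.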
+    """
+    def __init__(self, in_features, out_features, bias=False, group_size=128,
+                 binarize_act=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.group_size = group_size
+        self.binarize_act = binarize_act
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Triton XNOR+POPCOUNT path (full XNOR only, per-group alpha)
+        if _TRITON_KERNEL and self.binarize_act:
+            y = _TritonXNORFn.apply(x, self.weight, self.group_size)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+        # INT8 tensor core path (full XNOR only, per-row alpha)
+        if _INT8_KERNEL and self.binarize_act:
+            y = _INT8XNORFn.apply(x, self.weight)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+
+        # Standard STE path (per-group alpha)
+        w = self.weight.bfloat16()
+        g = self.group_size
+
+        # --- Weight binarization (STE) ---
+        w_g = w.reshape(-1, g)
+        alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8)
+        # Simulate FP8 scale quantization so model compensates for roundtrip error
+        if _SCALE_FP8_STE:
+            alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype)
+            alpha = alpha + (alpha_q - alpha).detach()
+        w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g))
+        w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach()
+
+        # --- Activation binarization (STE) ---
+        if self.binarize_act:
+            beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
+            x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
+            x_binary = x + ((x_sign * beta) - x).detach()
+        else:
+            x_binary = x
+
+        return F.linear(x_binary, w_binary,
+                        self.bias.to(x.dtype) if self.bias is not None else None)
+
+
+class NormedXNORLinear(XNORLinear):
+    """XNOR linear with RMSNorm on input.
+    Used for output projections receiving un-normalized activations (attention out, MLP down).
+    The RMSNorm normalizes the input to unit RMS before binarization, which is
+    critical for sign() quality (XNOR-Net paper: B-A-C-P block order, Table 3b:
+    30.3% vs 44.2%).
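+    (B-A-C-P in the paper is the block order BatchNorm -> Binary Activation ->
+    Binary Convolution -> Pooling, i.e. normalization sits immediately before
+    binarization; RMSNorm plays the BatchNorm role here.)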
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None]
+        # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized)
+        y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True)
+        y = y.reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+# ---------------------------------------------------------------------------
+# MLP
+# ---------------------------------------------------------------------------
+class MLP(nn.Module):
+    def __init__(self, dim, mlp_mult, group_size=128, activation="relu2",
+                 binarize_act=True, binarize_down=True):
+        super().__init__()
+        hidden = mlp_mult * dim
+        self.activation = activation
+        if activation == "swiglu":
+            self.gate_up = XNORLinear(dim, hidden * 2, bias=False,
+                                      group_size=group_size, binarize_act=binarize_act)
+        else:
+            self.fc = XNORLinear(dim, hidden, bias=False,
+                                 group_size=group_size, binarize_act=binarize_act)
+        self.proj = NormedXNORLinear(hidden, dim, bias=False,
+                                     group_size=group_size, binarize_act=binarize_down)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.activation == "swiglu":
+            gu = self.gate_up(x)
+            gate, up = gu.chunk(2, dim=-1)
+            return self.proj(F.silu(gate) * up)
+        elif self.activation == "signsq":
+            # signsq(h) = h * |h| is odd, so negative values survive, unlike relu^2
+            h = self.fc(x)
+            return self.proj(h * h.abs())
+        else:  # relu2
+            return self.proj(torch.relu(self.fc(x)).square())
+
+# ---------------------------------------------------------------------------
+# SmearModule
+# ---------------------------------------------------------------------------
+class SmearModule(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+    def forward(self, x: Tensor) -> Tensor:
+        # Running mean of the prefix x[:, :t]; tanh(gate) in (-1, 1) blends it with x
+        cumsum = x.cumsum(dim=1)
+        counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1)
+        smeared = cumsum / counts
+        gate = torch.tanh(self.gate.to(dtype=x.dtype))
+        return x + gate * (smeared - x)
+
+# ---------------------------------------------------------------------------
+# Attention Residuals (AttnRes) — depth-wise attention over prior outputs
+# From: "Attention Residuals" (Kimi Team, 2026)
+# ---------------------------------------------------------------------------
+class AttnRes(nn.Module):
+    def __init__(self, dim, n_queries):
+        super().__init__()
+        # Zero-init -> uniform weights at start -> equivalent to standard residual
+        self.queries = nn.ParameterList([
+            nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+            for _ in range(n_queries)
+        ])
+        self.key_norm_weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+
+    def attend(self, stored, query_idx):
+        # Softmax over depth: one scalar mixing weight per stored tensor, per token
+        w = self.queries[query_idx]
+        logits = []
+        for v in stored:
+            k = F.rms_norm(v, (v.size(-1),)) * self.key_norm_weight.to(v.dtype)
+            logits.append(torch.einsum('d, b t d -> b t', w.to(v.dtype), k))
+        logits = torch.stack(logits, dim=0)
+        weights = F.softmax(logits, dim=0)
+        out = torch.zeros_like(stored[0])
+        for i, v in enumerate(stored):
+            out = out + weights[i].unsqueeze(-1) * v
+        return out
+
+# ---------------------------------------------------------------------------
+# Transformer Block
+# ---------------------------------------------------------------------------
+class Block(nn.Module):
+    def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base,
+                 qk_gain_init, group_size=128, activation="relu2",
+
rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False, attn_res=0): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + self.attn_res_mode = attn_res + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Attention Residuals (replaces U-Net skips when enabled) + self.attn_res = None + if attn_res == 2: # pass-level + self.attn_res = AttnRes(model_dim, n_queries=num_layers) + elif attn_res == 1: # sub-layer + self.attn_res = AttnRes(model_dim, n_queries=num_layers * 2) + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if 
tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.attn_res_mode == 2: + # Pass-level AttnRes: attend over all previous block outputs + stored = [x] + for i, block in enumerate(self.blocks): + x = self.attn_res.attend(stored, i) + x = block(x, x0) + stored.append(x) + elif self.attn_res_mode == 1: + # Sub-layer AttnRes: attend before each attn and mlp sub-layer + stored = [x] + q_idx = 0 + for i, block in enumerate(self.blocks): + # Attention sub-layer + x = self.attn_res.attend(stored, q_idx) + mix = block.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + attn_out = block.attn_scale.to(dtype=x.dtype) * block.attn(block.attn_norm(x)) + stored.append(attn_out) + q_idx += 1 + # MLP sub-layer + x = self.attn_res.attend(stored, q_idx) + mlp_out = block.mlp_scale.to(dtype=x.dtype) * block.mlp(block.mlp_norm(x)) + stored.append(mlp_out) + q_idx += 1 + if block.smear is not None: + x = block.smear(self.attn_res.attend(stored, q_idx - 1)) + else: + # Standard U-Net (mode 0) + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = 
int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), 
device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + 
torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, attn_res=args.attn_res, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else 
None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = 
min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) + for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and 
step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = 
zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch._dynamo.reset() + torch.cuda.empty_cache() + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +--- Hyperparameters --- +activation_type=signsq adam_eps=1e-08 adam_lr=0.02 adam_wd=0.04 attn_res=0 beta1=0.9 beta2=0.95 binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 ema=False ema_decay=0.995 ema_start_fraction=0.0 embed_dim=384 embed_lr=0.6 fp_storage=True grad_clip_norm=1.0 head_lr=0.02 iterations=20000 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=5 matrix_lr=0.008 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_mult=4 model_dim=1024 muon_backend_steps=3 muon_momentum=0.8 muon_momentum_warmup_start=0.8 muon_momentum_warmup_steps=200 muon_wd=0.04 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=P3_full-xnor_8xH100_1775397720 scalar_lr=0.01 scale_fp8=False seed=42 seq_len_schedule=True sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=48 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.05 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1000 train_seq_len=1024 use_gram_ns=False use_int8_kernel=False use_triton_kernel=True val_batch_size=524288 val_loss_every=0 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-Net LLM | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 ga:1 +seq_len_schedule: [128, 256, 512, 1024] +warmup:1/4 seq_len=128 +warmup:2/4 seq_len=256 +warmup:3/4 seq_len=512 +warmup:4/4 seq_len=1024 +step:1000/20000 
loss:3.5811 t:38678ms step_avg:38.7ms +step:2000/20000 loss:3.4451 t:77288ms step_avg:38.6ms +step:3000/20000 loss:3.2128 t:115032ms step_avg:38.3ms +step:4000/20000 loss:3.2484 t:153372ms step_avg:38.3ms +step:5000/20000 loss:3.2370 t:191703ms step_avg:38.3ms +step:6000/20000 loss:3.1917 t:229530ms step_avg:38.3ms +step:7000/20000 loss:3.2547 t:268055ms step_avg:38.3ms +step:8000/20000 loss:3.2913 t:306838ms step_avg:38.4ms +step:9000/20000 loss:3.4470 t:345742ms step_avg:38.4ms +step:10000/20000 loss:3.1461 t:384190ms step_avg:38.4ms +step:11000/20000 loss:3.0617 t:422950ms step_avg:38.4ms +step:12000/20000 loss:3.0255 t:461140ms step_avg:38.4ms +step:13000/20000 loss:3.0365 t:500138ms step_avg:38.5ms +step:14000/20000 loss:2.9206 t:538713ms step_avg:38.5ms +step:15000/20000 loss:2.8475 t:577650ms step_avg:38.5ms +step:15540/20000 val_loss:2.8381 val_bpb:1.5816 train_time:598686ms +stopping_early: wallclock_cap train_time:598686ms +compression: lzma=15.94MB brotli=15.88MB zstd=18.35MB +using: brotli +artifact:15.88MB binary:115343360(15319040B) fp:1226832(1248416B) code:78442 +budget:15959813/16000000 (15.96/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8544 val_bpb:1.5907 +temp_scaling optimal_T:1.00 time:235ms +final_sliding val_loss:2.7923 val_bpb:1.5561 (stride=48, T=1.00) time:428012ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-7.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-7.txt new file mode 100644 index 0000000000..ff2ecaab2e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_10mins_seed-7.txt @@ -0,0 +1,1723 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim — April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +This is the REFERENCE implementation using STE-simulated XNOR via F.linear. +The Triton INT8×INT8 kernel version comes later for H100 deployment. + +Architecture: U-Net transformer with skip connections — provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
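+    Preference order, keyed on torch.cuda.get_device_capability():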
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + attn_res = _e("ATTN_RES", 0, int) # 0=disabled, 1=sub-layer, 
2=pass-level + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight 
matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# 
--------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. bfloat16 throughout.""" + X = G.bfloat16() + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + X = X / (X.norm() + eps) + n = X.size(0) + I = torch.eye(n, device=X.device, dtype=X.dtype) + + # Phase 1: steps 0-1 + R = X @ X.T + # Step 0: a=4.0848, b=-6.8946, c=2.9270 + Z = -6.8946 * R + 2.9270 * (R @ R) + Q = Z + 4.0848 * I + RZ = 4.0848 * R + R @ Z + R = 4.0848 * RZ + Z @ RZ + # Step 1: a=3.9505, b=-6.3029, c=2.6377 + Z = -6.3029 * R + 2.6377 * (R @ R) + Q = 3.9505 * Q + Q @ Z + # No R update (next is reset) + + # Reset + X = Q @ X + R = X @ X.T + + # Phase 2: steps 2-4 + # Step 2: a=3.7418, b=-5.5913, c=2.3037 + Z = -5.5913 * R + 2.3037 * (R @ R) + Q = Z + 3.7418 * I + RZ = 3.7418 * R + R @ Z + R = 3.7418 * RZ + Z @ RZ + # Step 3: a=2.8769, b=-3.1427, c=1.2046 + Z = -3.1427 * R + 1.2046 * (R @ R) + Q = 2.8769 * Q + Q @ Z + RZ = 2.8769 * R + R @ Z + R = 2.8769 * RZ + Z @ RZ + # Step 4: a=2.8366, b=-3.0525, c=1.2012 + Z = -3.0525 * R + 1.2012 * (R @ R) + Q = 2.8366 * Q + Q @ Z + + X = Q @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() + g = ns_orth(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("wd", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------------------------------------------------------------- +# Data loading +# 
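+# Shard layout assumed by ld_shard below: a 256-word uint32 header followed by a
+# flat stream of uint16 token ids; the header is skipped rather than validated,
+# and ids are widened to int16 (safe for the vocab sizes used here).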
---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<u4").itemsize
+    tokens = np.fromfile(file, dtype="<u2", offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int16))
+
+class TokenStream:
+    """Sequential token stream over the sorted shard files, wrapping at the end."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+    if not fp_storage:
+        return w
+    w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+    return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+        super().__init__(in_features, out_features, bias=bias)
+        self.fp_storage = fp_storage
+    def forward(self, x: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+        super().__init__(num_embeddings, embedding_dim)
+        self.fp_storage = fp_storage
+    def forward(self, input: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.embedding(input, w, self.padding_idx, self.max_norm,
+                           self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+# ---------------------------------------------------------------------------
+# XNOR-Net Linear Layers
+# ---------------------------------------------------------------------------
+_INT8_KERNEL = False  # set from Hyperparameters.use_int8_kernel in main()
+
+class _INT8XNORFn(torch.autograd.Function):
+    """INT8 tensor core forward, BF16 backward.
+    Uses per-row alpha (not per-group) so we get a single torch._int_mm call.
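+    Since both operands are +/-1 after sign(), the INT32 accumulator is the exact
+    dot product of the binarized tensors; the output is then rescaled as
+    y[m, n] = raw[m, n] * beta[m] * alpha[n].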
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
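+
+    Worked example (illustrative numbers): for one weight group
+    w = [0.5, -1.0, 0.25, 0.25], alpha = mean(|w|) = 0.5 and the committed group
+    is [+0.5, -0.5, +0.5, +0.5]; for a token x = [2.0, -1.0, 0.5, 0.5], beta = 1.0
+    and the binarized activation is [+1.0, -1.0, +1.0, +1.0].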
+    """
+    def __init__(self, in_features, out_features, bias=False, group_size=128,
+                 binarize_act=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.group_size = group_size
+        self.binarize_act = binarize_act
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Triton XNOR+POPCOUNT path (full XNOR only, per-group alpha)
+        if _TRITON_KERNEL and self.binarize_act:
+            y = _TritonXNORFn.apply(x, self.weight, self.group_size)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+        # INT8 tensor core path (full XNOR only, per-row alpha)
+        if _INT8_KERNEL and self.binarize_act:
+            y = _INT8XNORFn.apply(x, self.weight)
+            if self.bias is not None:
+                y = y + self.bias.to(y.dtype)
+            return y
+
+        # Standard STE path (per-group alpha)
+        w = self.weight.bfloat16()
+        g = self.group_size
+
+        # --- Weight binarization (STE) ---
+        w_g = w.reshape(-1, g)
+        alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8)
+        # Simulate FP8 scale quantization so model compensates for roundtrip error
+        if _SCALE_FP8_STE:
+            alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype)
+            alpha = alpha + (alpha_q - alpha).detach()
+        w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g))
+        w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach()
+
+        # --- Activation binarization (STE) ---
+        if self.binarize_act:
+            beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
+            x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
+            x_binary = x + ((x_sign * beta) - x).detach()
+        else:
+            x_binary = x
+
+        return F.linear(x_binary, w_binary,
+                        self.bias.to(x.dtype) if self.bias is not None else None)
+
+
+class NormedXNORLinear(XNORLinear):
+    """XNOR linear with RMSNorm on the input.
+    Used for output projections receiving un-normalized activations (attention out, MLP down).
+    The RMSNorm rescales the input to unit RMS before binarization, keeping the
+    per-token beta = mean(|x|) well-behaved; this is critical for sign() quality
+    (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%).
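+    Equivalent to XNORLinear.forward(F.rms_norm(x)); no learned gain is applied
+    here, since the surrounding Block already scales each projection's output with
+    its own attn_scale / mlp_scale parameters.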
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Attention Residuals (AttnRes) — depth-wise attention over prior outputs +# From: "Attention Residuals" (Kimi Team, 2026) +# --------------------------------------------------------------------------- +class AttnRes(nn.Module): + def __init__(self, dim, n_queries): + super().__init__() + # Zero-init -> uniform weights at start -> equivalent to standard residual + self.queries = nn.ParameterList([ + nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + for _ in range(n_queries) + ]) + self.key_norm_weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def attend(self, stored, query_idx): + w = self.queries[query_idx] + logits = [] + for v in stored: + k = F.rms_norm(v, (v.size(-1),)) * self.key_norm_weight.to(v.dtype) + logits.append(torch.einsum('d, b t d -> b t', w.to(v.dtype), k)) + logits = torch.stack(logits, dim=0) + weights = F.softmax(logits, dim=0) + out = torch.zeros_like(stored[0]) + for i, v in enumerate(stored): + out = out + weights[i].unsqueeze(-1) * v + return out + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + 
rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False, attn_res=0): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + self.attn_res_mode = attn_res + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Attention Residuals (replaces U-Net skips when enabled) + self.attn_res = None + if attn_res == 2: # pass-level + self.attn_res = AttnRes(model_dim, n_queries=num_layers) + elif attn_res == 1: # sub-layer + self.attn_res = AttnRes(model_dim, n_queries=num_layers * 2) + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if 
tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.attn_res_mode == 2: + # Pass-level AttnRes: attend over all previous block outputs + stored = [x] + for i, block in enumerate(self.blocks): + x = self.attn_res.attend(stored, i) + x = block(x, x0) + stored.append(x) + elif self.attn_res_mode == 1: + # Sub-layer AttnRes: attend before each attn and mlp sub-layer + stored = [x] + q_idx = 0 + for i, block in enumerate(self.blocks): + # Attention sub-layer + x = self.attn_res.attend(stored, q_idx) + mix = block.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + attn_out = block.attn_scale.to(dtype=x.dtype) * block.attn(block.attn_norm(x)) + stored.append(attn_out) + q_idx += 1 + # MLP sub-layer + x = self.attn_res.attend(stored, q_idx) + mlp_out = block.mlp_scale.to(dtype=x.dtype) * block.mlp(block.mlp_norm(x)) + stored.append(mlp_out) + q_idx += 1 + if block.smear is not None: + x = block.smear(self.attn_res.attend(stored, q_idx - 1)) + else: + # Standard U-Net (mode 0) + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = 
int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), 
device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + 
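+    # Restoring the CPU and CUDA RNG streams keeps a resumed run on the same
+    # random sequence as an uninterrupted one.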
torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, attn_res=args.attn_res, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else 
None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = 
min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) + for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and 
step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = 
zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch._dynamo.reset() + torch.cuda.empty_cache() + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +--- Hyperparameters --- +activation_type=signsq adam_eps=1e-08 adam_lr=0.02 adam_wd=0.04 attn_res=0 beta1=0.9 beta2=0.95 binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 ema=False ema_decay=0.995 ema_start_fraction=0.0 embed_dim=384 embed_lr=0.6 fp_storage=True grad_clip_norm=1.0 head_lr=0.02 iterations=20000 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=5 matrix_lr=0.008 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_mult=4 model_dim=1024 muon_backend_steps=3 muon_momentum=0.8 muon_momentum_warmup_start=0.8 muon_momentum_warmup_steps=200 muon_wd=0.04 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=P4_full-xnor_8xH100_1775399123 scalar_lr=0.01 scale_fp8=False seed=7 seq_len_schedule=True sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=48 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.05 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1000 train_seq_len=1024 use_gram_ns=False use_int8_kernel=False use_triton_kernel=True val_batch_size=524288 val_loss_every=0 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-Net LLM | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 ga:1 +seq_len_schedule: [128, 256, 512, 1024] +warmup:1/4 seq_len=128 +warmup:2/4 seq_len=256 +warmup:3/4 seq_len=512 +warmup:4/4 seq_len=1024 +step:1000/20000 
loss:3.6102 t:38441ms step_avg:38.4ms +step:2000/20000 loss:3.5059 t:77150ms step_avg:38.6ms +step:3000/20000 loss:3.2432 t:115821ms step_avg:38.6ms +step:4000/20000 loss:3.3050 t:154778ms step_avg:38.7ms +step:5000/20000 loss:3.2936 t:193583ms step_avg:38.7ms +step:6000/20000 loss:3.2143 t:231722ms step_avg:38.6ms +step:7000/20000 loss:3.2960 t:271246ms step_avg:38.7ms +step:8000/20000 loss:3.3074 t:310317ms step_avg:38.8ms +step:9000/20000 loss:3.4893 t:349224ms step_avg:38.8ms +step:10000/20000 loss:3.1546 t:387800ms step_avg:38.8ms +step:11000/20000 loss:3.1024 t:426151ms step_avg:38.7ms +step:12000/20000 loss:3.0742 t:463866ms step_avg:38.7ms +step:13000/20000 loss:3.0682 t:502665ms step_avg:38.7ms +step:14000/20000 loss:2.9603 t:540769ms step_avg:38.6ms +step:15000/20000 loss:2.8670 t:579905ms step_avg:38.7ms +step:15500/20000 val_loss:2.8798 val_bpb:1.6048 train_time:599224ms +stopping_early: wallclock_cap train_time:599224ms +compression: lzma=15.95MB brotli=15.89MB zstd=17.99MB +using: brotli +artifact:15.89MB binary:115343360(15319040B) fp:1226832(1248416B) code:78442 +budget:15964152/16000000 (15.96/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8986 val_bpb:1.6153 +temp_scaling optimal_T:0.95 time:315ms +final_sliding val_loss:2.8350 val_bpb:1.5799 (stride=48, T=0.95) time:427692ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_3-last-layers.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_3-last-layers.txt new file mode 100644 index 0000000000..fb035b29bd --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_3-last-layers.txt @@ -0,0 +1,1500 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
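+    Every candidate is used through the same flash_attn_func(q, k, v, causal=...) interface
+    on [batch, seq, heads, head_dim] tensors (the SDPA fallback transposes to match).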
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) 
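+    # Storage note: each group of xnor_group_size weights shares one alpha = mean(|w|),
+    # i.e. 1 bit per weight plus one BF16/FP8 scale per group (see q_sd / SCALE_STORAGE).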
+ # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = 
torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# EGGROLL: Evolution-Guided optimization with low-rank perturbations +# --------------------------------------------------------------------------- +def get_binary_params(model): + """Get list of (name, param) for parameters that will be binarized.""" + params = [] + for name, p in model.named_parameters(): + if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name: + params.append((name, p)) + return params + +def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"): + """Apply rank-r perturbation to binary params 
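+    (per layer: W += sign * sigma / sqrt(rank) * sum_r a_r b_r^T, with a_r, b_r ~ N(0, I))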
using deterministic RNG."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove previously applied perturbation (subtract it)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    Regenerates each seed's low-rank noise and accumulates the naive estimator
+    update = (1/N) * sum_i f_i * sum_r a_ir b_ir^T / sqrt(rank), applied as W += (lr/sigma) * update."""
+    N = len(seeds)
+    for name, p in binary_params:
+        m, n = p.shape
+        update = torch.zeros_like(p.data)
+        for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+            rng = torch.Generator(device=device)
+            rng.manual_seed(seed)
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+        p.data.add_(update, alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # Reconstructed from surrounding usage; the 256-entry int32 header and
+    # uint16 token payload assumed here follow the standard FineWeb .bin shard layout.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with open(file, "rb") as f:
+        f.seek(header_bytes)
+        data = np.frombuffer(f.read(), dtype=np.uint16)
+    return torch.from_numpy(data.astype(np.int32))
+
+class TokenStream:
+    """Sequential token stream over sharded .bin files, wrapping at the last shard."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def 
apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. + """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. 
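+    Bit j of each word is 1 iff element j is positive (LSB first), e.g.
+    signs [+1, -1, +1, +1, -1, ...all negative...] -> bits [1, 0, 1, 1, 0, ...] -> int32 13 (0b1101).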
+ returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc = y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + 
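+        # Save the full-precision operands: backward re-derives the STE-scaled
+        # binary tensors from x and w instead of reusing the packed int32 words.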
ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). + """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
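+    Without zero-centering, a mean-shifted input pushes sign() toward a single value
+    (cf. the all-positive MLP-down inputs that stall full XNOR), wasting binary capacity.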
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * 
(1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_neg = base_model(x, y) + 
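+            # Fitness is negative loss; the antithetic pair shares seed and batch, so
+            # rank shaping later uses the paired difference and common noise cancels.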
fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
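+    # e.g. columns with sign keys {1100.., 0011.., 1100..} sort to {1100.., 1100.., 0011..}:
+    # equal or similar columns become adjacent, so brotli/lzma finds longer repeats in the packed bits.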
+    if master_process:
+        sd = base_model.state_dict()
+        for layer_idx in range(args.num_layers):
+            # Find the MLP weight pairs
+            down_key = f"blocks.{layer_idx}.mlp.proj.weight"
+            if args.activation_type == "swiglu":
+                up_key = f"blocks.{layer_idx}.mlp.gate_up.weight"
+            else:
+                up_key = f"blocks.{layer_idx}.mlp.fc.weight"
+            if down_key not in sd or up_key not in sd:
+                continue
+            w_down = sd[down_key]  # [dim, hidden]
+            w_up = sd[up_key]      # [hidden, dim] (or [hidden*2, dim] for swiglu)
+            hidden = w_down.shape[1]
+            # Sort hidden dim columns of w_down by sign pattern
+            signs = (w_down >= 0).to(torch.int8)  # [dim, hidden]
+            # Pack sign columns into sortable keys (use first 64 rows as sort key)
+            n_key_rows = min(64, signs.shape[0])
+            sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device)
+            for r in range(n_key_rows):
+                sort_keys += signs[r, :].long() << (n_key_rows - 1 - r)
+            perm = sort_keys.argsort()
+            # Apply permutation
+            sd[down_key] = w_down[:, perm]
+            if args.activation_type == "swiglu":
+                half = w_up.shape[0] // 2
+                sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0)
+            else:
+                sd[up_key] = w_up[perm]
+        if base_model.tie_embeddings:
+            sd.pop("lm_head.weight", None)
+        q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size,
+                              fp_storage=args.fp_storage, scale_fp8=args.scale_fp8)
+        buf = io.BytesIO()
+        torch.save(q_obj, buf)
+        raw_bytes = buf.getvalue()
+        # Compare LZMA vs Brotli vs Zstd, keep smallest
+        lzma_blob = lzma.compress(raw_bytes, preset=9)
+        best_blob, best_method = lzma_blob, "lzma"
+        comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB"
+        if HAS_BROTLI:
+            brotli_blob = brotli.compress(raw_bytes, quality=11)
+            comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB"
+            if len(brotli_blob) < len(best_blob):
+                best_blob, best_method = brotli_blob, "brotli"
+        if HAS_ZSTD:
+            cctx = zstd.ZstdCompressor(level=22)
+            zstd_blob = cctx.compress(raw_bytes)
+            comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB"
+            if len(zstd_blob) < len(best_blob):
+                best_blob, best_method = zstd_blob, "zstd"
+        log0(comp_log)
+        log0(f"using: {best_method}")
+        with open("final_model.xnor.ptz", "wb") as f:
+            # Method byte: 0=lzma, 1=brotli, 2=zstd
+            method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method]
+            f.write(bytes([method_id]))
+            f.write(best_blob)
+        artifact_bytes = 1 + len(best_blob)
+        code_bytes = len(code.encode("utf-8"))
+        total = artifact_bytes + code_bytes
+        log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) "
+             f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}")
+        log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) "
+             f"{'FITS' if total <= 16000000 else 'OVER'}")
+
+    # --- Roundtrip eval ---
+    if distributed: dist.barrier()
+    with open("final_model.xnor.ptz", "rb") as f:
+        raw = f.read()
+    method_byte = raw[0]
+    compressed = raw[1:]
+    if method_byte == 2 and HAS_ZSTD:
+        dctx = zstd.ZstdDecompressor()
+        decompressed = dctx.decompress(compressed)
+    elif method_byte == 1 and HAS_BROTLI:
+        decompressed = brotli.decompress(compressed)
+    else:
+        decompressed = lzma.decompress(compressed)
+    loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False)
+    base_model.load_state_dict(deq_sd(loaded), strict=False)
+    torch._dynamo.reset()
+    q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1,
+                                     val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
+    log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +loading pretrained weights from checkpoints_n2_100k/ckpt_step0100000.pt +pretrained weights loaded +EGGROLL binary params: 40 layers, 115343360 params +--- Hyperparameters --- +activation_type=signsq binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 eggroll_load=checkpoints_n2_100k/ckpt_step0100000.pt eggroll_lr=1e-06 eggroll_rank=8 eggroll_sigma=0.0001 embed_dim=384 fitness_shaping=rank fp_storage=True iterations=10 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=0 max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 num_heads=8 num_kv_heads=4 num_layers=10 pop_size=4096 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=E7_eggroll_8xH100_1775387185 scale_fp8=False seed=42 seq_len_schedule=False sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1 train_seq_len=1024 use_int8_kernel=False use_triton_kernel=False val_batch_size=524288 val_loss_every=10 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-EGGROLL | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 pop:4096 sigma:0.0001 rank:8 +step:0/10 val_loss:2.8153 val_bpb:1.5689 train_time:0ms +step:1/10 fitness:-3.0060 t:35171ms step_avg:35171.2ms +step:2/10 fitness:-2.9829 t:71107ms step_avg:35553.4ms +step:3/10 fitness:-3.0019 t:107162ms step_avg:35720.7ms +step:4/10 fitness:-2.9579 t:144328ms step_avg:36082.0ms +step:5/10 fitness:-3.0162 t:180393ms step_avg:36078.5ms +step:6/10 fitness:-2.9989 t:216005ms step_avg:36000.8ms +step:7/10 fitness:-3.0115 t:252089ms step_avg:36012.8ms +step:8/10 fitness:-3.0022 t:287147ms step_avg:35893.3ms +step:9/10 fitness:-3.0482 t:322782ms step_avg:35864.6ms +step:10/10 fitness:-2.9566 t:358676ms step_avg:35867.6ms +step:10/10 val_loss:2.9309 val_bpb:1.6333 train_time:358684ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_all-layers-perturbed.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_all-layers-perturbed.txt new file mode 100644 index 0000000000..2fc5ac31a5 --- /dev/null +++ 
b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_all-layers-perturbed.txt
@@ -0,0 +1,1506 @@
+"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge.
+Ciprian-Florin Ifrim - April 2026
+
+Based on the Binary-Weight-Network script, extended to full XNOR-Net:
+both weights AND activations are binarized using sign() with per-group/per-token
+scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016).
+
+Architecture: U-Net transformer with skip connections that provide error correction
+for the information loss inherent in binary quantization of both weights and activations.
+"""
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import sys
+import time
+import lzma
+from pathlib import Path
+try:
+    import brotli
+    HAS_BROTLI = True
+except ImportError:
+    HAS_BROTLI = False
+try:
+    import zstandard as zstd
+    HAS_ZSTD = True
+except ImportError:
+    HAS_ZSTD = False
+try:
+    from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS
+    HAS_GRAM_NS = True
+except ImportError:
+    HAS_GRAM_NS = False
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+def _get_flash_attn_func():
+    """GPU-aware FlashAttention selection.
+    H100:  FA3 -> FA2 -> SDPA
+    5090:  FA4 -> FA2 -> SDPA
+    Other: FA2 -> SDPA
+    """
+    if torch.cuda.is_available():
+        sm = torch.cuda.get_device_capability()
+        # Hopper (SM 90) — prefer FA3
+        if sm[0] == 9 and sm[1] == 0:
+            try:
+                from flash_attn_interface import flash_attn_func
+                return flash_attn_func
+            except ImportError:
+                pass
+        # Blackwell (SM 100/120) — try FA4
+        if sm[0] >= 10:
+            try:
+                from flash_attn.cute import flash_attn_func
+                return flash_attn_func
+            except ImportError:
+                pass
+        # FA2 — universal fallback with CUDA kernels
+        try:
+            from flash_attn import flash_attn_func
+            return flash_attn_func
+        except ImportError:
+            pass
+    # PyTorch native SDPA — always works, no install needed
+    def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs):
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+        o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True)
+        return o.transpose(1, 2)
+    return _sdpa_flash_attn_func
+flash_attn_func = _get_flash_attn_func()
+
+# ---------------------------------------------------------------------------
+# Hyperparameters
+# ---------------------------------------------------------------------------
+def _e(k, d, t=str):
+    v = os.environ.get(k, str(d))
+    if t == bool: return bool(int(v))
+    return t(v)
+
+class Hyperparameters:
+    # Data
+    data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
+    run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}")
+    seed = _e("SEED", 1337, int)
+    # Compile
+    compile_mode = _e("COMPILE_MODE", "default")
+    # Eval
+    val_batch_size = _e("VAL_BATCH_SIZE", 524288, int)
+    val_loss_every = _e("VAL_LOSS_EVERY", 500, int)
+    train_log_every = _e("TRAIN_LOG_EVERY", 10, int)
+    sliding_eval = _e("SLIDING_EVAL", 0, bool)
+    sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int)
sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = 
np.concatenate([bits, np.zeros(pad, dtype=np.uint8)])
+    groups = bits.reshape(-1, 8)
+    packed = np.zeros(len(groups), dtype=np.uint8)
+    for i in range(8):
+        packed |= groups[:, i] << i
+    return packed.tobytes(), n
+
+def unpack_binary(data: bytes, n: int) -> Tensor:
+    packed = np.frombuffer(data, dtype=np.uint8)
+    bits = np.zeros((len(packed), 8), dtype=np.int8)
+    for i in range(8):
+        bits[:, i] = (packed >> i) & 1
+    flat = bits.reshape(-1)[:n]
+    return torch.from_numpy(flat.astype(np.int8) * 2 - 1)
+
+# ---------------------------------------------------------------------------
+# State dict serialization
+# ---------------------------------------------------------------------------
+def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]:
+    """Binary pack large 2D weight matrices, BF16/FP8 for everything else."""
+    quantized = {}
+    stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().float().contiguous()
+        t_orig_shape = list(t.shape)
+        if t.ndim == 3:
+            t = t.reshape(t.shape[0], -1)
+        is_binary = (
+            t.ndim == 2 and t.numel() > 65_536
+            and "tok_emb" not in name and "lm_head" not in name
+            and "embed_proj" not in name
+        )
+        if is_binary:
+            pad = (group_size - t.shape[1] % group_size) % group_size
+            t_padded = F.pad(t, (0, pad)) if pad > 0 else t
+            t_grouped = t_padded.reshape(-1, group_size)
+            scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8)
+            q = torch.where(t_grouped >= 0,
+                            torch.ones_like(t_grouped, dtype=torch.int8),
+                            -torch.ones_like(t_grouped, dtype=torch.int8))
+            packed_bytes, n_bits = pack_binary(q)
+            if scale_fp8:
+                scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1)
+                scale_bytes_per = 1
+            else:
+                scale_stored = scale.bfloat16().squeeze(-1)
+                scale_bytes_per = 2
+            quantized[name] = {
+                "type": "binary", "packed": packed_bytes,
+                "scale": scale_stored,
+                "shape": list(t.shape), "padded_cols": t_padded.shape[1],
+                "group_size": group_size, "n_bits": n_bits,
+                "orig_shape": t_orig_shape,
+            }
+            stats["binary_params"] += t.numel()
+            stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per
+        elif fp_storage and t.ndim == 2:
+            quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)}
+            stats["fp_params"] += t.numel()
+            stats["fp_bytes"] += t.numel()
+        else:
+            quantized[name] = {"type": "bf16", "data": t.bfloat16()}
+            stats["fp_params"] += t.numel()
+            stats["fp_bytes"] += t.numel() * 2
+    return quantized, stats
+
+def deq_sd(quantized: dict, target_dtype=torch.bfloat16):
+    """Reconstruct state dict from quantized representation."""
+    out = {}
+    for name, entry in quantized.items():
+        if entry["type"] == "binary":
+            q = unpack_binary(entry["packed"], entry["n_bits"])
+            q = q.float().reshape(-1, entry["group_size"])
+            scale = entry["scale"].float().unsqueeze(-1)
+            # No shrinkage correction: binary q_absmean = 1.0 always
+            t = (q * scale).reshape(-1, entry["padded_cols"])
+            shape = entry["shape"]
+            result = t[:shape[0], :shape[1]].to(target_dtype)
+            orig = entry.get("orig_shape")
+            out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous()
+        elif entry["type"] == "fp8":
+            out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous()
+        else:
+            out[name] = entry["data"].to(target_dtype).contiguous()
+    return out
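+# Minimal sanity-check sketch (illustrative addition; never called in the
+# training path): pack_binary/unpack_binary must roundtrip {-1, +1} exactly,
+# since the packed bits plus per-group scales are the entire binary payload.
+def _check_pack_roundtrip(n: int = 4096) -> None:
+    q = torch.randint(0, 2, (n,), dtype=torch.int8) * 2 - 1  # random {-1, +1}
+    data, n_bits = pack_binary(q)
+    assert n_bits == n
+    assert torch.equal(unpack_binary(data, n_bits), q)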
+# ---------------------------------------------------------------------------
+# Diagnostics
+# ---------------------------------------------------------------------------
+_prev_committed: dict = {}
+def churn_fn(model: nn.Module, group_size: int = 128):
+    global _prev_committed
+    total = flipped = 0
+    with torch.no_grad():
+        for name, p in model.named_parameters():
+            if p.ndim == 2 and "weight" in name and p.shape[0] > 1:
+                w = p.detach().float().reshape(-1, group_size)
+                q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy()
+                if name in _prev_committed:
+                    flipped += int(np.sum(q != _prev_committed[name]))
+                total += q.size
+                _prev_committed[name] = q
+    return flipped / max(total, 1)
+
+# ---------------------------------------------------------------------------
+# EGGROLL: Evolution-Guided optimization with low-rank perturbations
+# ---------------------------------------------------------------------------
+def get_binary_params(model):
+    """Get list of (name, param) for parameters that will be binarized."""
+    params = []
+    for name, p in model.named_parameters():
+        if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+            params.append((name, p))
+    return params
+
+def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Apply rank-r perturbation to binary params using deterministic RNG."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove previously applied perturbation (subtract it)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
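+# Reference sketch of the batched "(diag(f) @ A).T @ B" form of the update
+# below (illustrative; compute_eggroll_update regenerates the same rank-1
+# factors seed-by-seed instead of materializing A [N, m] and B [N, n]):
+#   sum_i f_i * outer(a_i, b_i) == (diag(f) @ A).T @ B == (A * f[:, None]).T @ B
+def _eggroll_update_batched(A: Tensor, B: Tensor, f: Tensor) -> Tensor:
+    return (A * f[:, None]).T @ B  # [m, n]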
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    Regenerates each perturbation from its seed and accumulates
+    sum_i f_i * a_i b_i^T / (sqrt(rank) * N), then steps by lr / sigma
+    (seed-by-seed equivalent of the batched (diag(f) @ A).T @ B form)."""
+    N = len(seeds)
+    for name, p in binary_params:
+        m, n = p.shape
+        update = torch.zeros_like(p.data)
+        for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+            rng = torch.Generator(device=device)
+            rng.manual_seed(seed)
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+        p.data.add_(update, alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    # e.g. N=2, all_f=[3, 1, 0, 2] -> ranks [3, 1, 0, 2] -> [0.5, -1/6, -0.5, 1/6]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # Assumes the standard fineweb .bin shard layout: 256-int32 header with the
+    # token count at header[2], followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    """Sequential token stream over sorted shard files, wrapping at the end."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+    if not fp_storage:
+        return w
+    w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+    return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+        super().__init__(in_features, out_features, bias=bias)
+        self.fp_storage = fp_storage
+    def forward(self, x: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+        super().__init__(num_embeddings, embedding_dim)
+        self.fp_storage
= fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. + """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. 
+ returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc = y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + 
ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). + """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * 
(1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_neg = base_model(x, y) + 
fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
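+    # Why this is lossless: MLP hidden units commute with any permutation P,
+    #   down @ act(up @ x) == (down @ P.T) @ act((P @ up) @ x)
+    # so reordering down_proj columns together with the matching up_proj rows
+    # leaves every output unchanged, while clustering similar sign patterns
+    # gives the entropy coder longer runs of repeated bytes in the packed blob.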
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +loading pretrained weights from checkpoints_n2_100k/ckpt_step0100000.pt +pretrained weights loaded +EGGROLL binary params: 40 layers, 115343360 params +--- Hyperparameters --- +activation_type=signsq binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=10 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 eggroll_load=checkpoints_n2_100k/ckpt_step0100000.pt eggroll_lr=1e-05 eggroll_rank=8 eggroll_sigma=0.0001 embed_dim=384 fitness_shaping=rank fp_storage=True iterations=100 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=0 max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 num_heads=8 num_kv_heads=4 num_layers=10 pop_size=4096 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=E6_eggroll_8xH100_1775386511 scale_fp8=False seed=42 seq_len_schedule=False sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1 train_seq_len=1024 use_int8_kernel=False use_triton_kernel=False val_batch_size=524288 val_loss_every=10 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-EGGROLL | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 pop:4096 sigma:0.0001 rank:8 +step:0/100 val_loss:2.8153 val_bpb:1.5689 train_time:0ms +step:1/100 fitness:-3.0060 t:35043ms step_avg:35043.1ms +step:2/100 fitness:-3.0587 t:70815ms step_avg:35407.3ms +step:3/100 fitness:-3.1098 t:106725ms step_avg:35575.2ms +step:4/100 fitness:-3.0883 t:141596ms step_avg:35399.0ms +step:5/100 fitness:-3.2052 t:177266ms step_avg:35453.1ms +step:6/100 fitness:-3.1968 t:213010ms step_avg:35501.6ms +step:7/100 fitness:-3.2209 t:248773ms step_avg:35539.0ms +step:8/100 fitness:-3.2207 t:284709ms step_avg:35588.6ms +step:9/100 fitness:-3.2825 t:320220ms step_avg:35580.0ms +step:10/100 fitness:-3.1952 t:355306ms step_avg:35530.6ms +step:10/100 val_loss:3.2174 val_bpb:1.7930 train_time:355314ms +step:11/100 fitness:-3.1612 t:389848ms step_avg:35440.7ms +step:12/100 fitness:-3.2401 t:425054ms step_avg:35421.2ms +step:13/100 fitness:-3.2494 t:460109ms step_avg:35393.0ms +step:14/100 fitness:-3.2711 t:495081ms step_avg:35362.9ms +step:15/100 fitness:-3.2671 t:530287ms 
step_avg:35352.5ms +step:16/100 fitness:-3.3667 t:565458ms step_avg:35341.1ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_last-layer-only.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_last-layer-only.txt new file mode 100644 index 0000000000..e42e53fee6 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_last-layer-only.txt @@ -0,0 +1,1518 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
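+    Tries the fastest kernel available for the detected GPU, falling back down
+    the chain on ImportError: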
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) 
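+    # One alpha scale is stored per group of this many weights: larger groups
+    # shrink the artifact (fewer scales) at the cost of coarser quantization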
+ # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + eggroll_layers = _e("EGGROLL_LAYERS", 0, int) # 0 = all layers, N = last N layers only + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = 
t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# EGGROLL: Evolution-Guided optimization with low-rank perturbations +# --------------------------------------------------------------------------- +def get_binary_params(model, num_layers=10, last_n_layers=0): + """Get list of (name, param) for parameters that will be perturbed. 
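+    A parameter qualifies under the same filter q_sd uses for binary packing:
+    large 2D weight matrices, excluding embeddings, embed projections, and lm_head.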
+    last_n_layers=0 means all layers, N means only the last N blocks."""
+    params = []
+    for name, p in model.named_parameters():
+        if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+            if last_n_layers > 0:
+                # Only include params from last N blocks
+                match = False
+                for layer_idx in range(num_layers - last_n_layers, num_layers):
+                    if f"blocks.{layer_idx}." in name:
+                        match = True
+                        break
+                if not match:
+                    continue
+            params.append((name, p))
+    return params
+
+def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Apply rank-r perturbation to binary params using deterministic RNG."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove a previously applied perturbation (apply it with opposite sign)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    Regenerates each perturbation from its seed and accumulates the
+    fitness-weighted sum; for rank-1 this could be batched into a single
+    (diag(f) @ A).T @ B matmul, but the seed-replay loop keeps memory flat."""
+    N = len(seeds)
+    for name, p in binary_params:
+        m, n = p.shape
+        update = torch.zeros_like(p.data)
+        for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+            rng = torch.Generator(device=device)
+            rng.manual_seed(seed)
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+        p.data.add_(update, alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # (reconstructed: this loader's body was lost in the capture; the lines
+    # below assume the standard FineWeb .bin shard layout, i.e. 256
+    # little-endian int32 header words with header[2] = token count,
+    # followed by uint16 tokens)
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    """Infinite token stream that cycles through the shard files."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = 
self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x, y + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. 
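+    Correctness: with both operands in {-1, +1} int8, torch._int_mm accumulates
+    the exact binary dot product in int32; multiplying by beta (per token) and
+    alpha (per row) recovers y ~= (sign(x) * beta) @ (sign(W) * alpha).T.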
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
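+    Example (group_size=4): a weight group [0.5, -1.5, 0.2, -0.8] has
+    alpha = mean(|W|) = 0.75, is stored as the sign bits [+1, -1, +1, -1]
+    plus one scale, and dequantizes to [0.75, -0.75, 0.75, -0.75].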
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model, num_layers=args.num_layers, + last_n_layers=args.eggroll_layers) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - 
warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16): + loss_neg = base_model(x, y) + fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
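+    # Toy illustration (hypothetical 3-row sort key, 4 columns): sign patterns
+    #   [+,-,+] [+,+,+] [+,-,+] [-,-,-]  ->  keys 101, 111, 101, 000
+    # argsort ascending gives [-,-,-] [+,-,+] [+,-,+] [+,+,+], so identical
+    # bit columns sit adjacent in the packed stream and the LZMA/Brotli/zstd
+    # stage sees longer repeated runs.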
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +loading pretrained weights from checkpoints_n2_100k/ckpt_step0100000.pt +pretrained weights loaded +EGGROLL binary params: 12 layers, 34603008 params +--- Hyperparameters --- +activation_type=signsq binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 eggroll_layers=3 eggroll_load=checkpoints_n2_100k/ckpt_step0100000.pt eggroll_lr=1e-06 eggroll_rank=8 eggroll_sigma=0.0001 embed_dim=384 fitness_shaping=rank fp_storage=True iterations=10 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=0 max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 num_heads=8 num_kv_heads=4 num_layers=10 pop_size=4096 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=E8_eggroll-nlayers_8xH100_1775387953 scale_fp8=False seed=42 seq_len_schedule=False sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=False sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1 train_seq_len=1024 use_int8_kernel=False use_triton_kernel=False val_batch_size=524288 val_loss_every=10 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-EGGROLL | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 pop:4096 sigma:0.0001 rank:8 +step:0/10 val_loss:2.8153 val_bpb:1.5689 train_time:1ms +step:1/10 fitness:-2.9963 t:17251ms step_avg:17250.9ms +step:2/10 fitness:-2.9740 t:34143ms step_avg:17071.5ms +step:3/10 fitness:-2.9935 t:51657ms step_avg:17219.1ms +step:4/10 fitness:-2.9454 t:68481ms step_avg:17120.3ms +step:5/10 fitness:-3.0062 t:85298ms step_avg:17059.7ms +step:6/10 fitness:-2.9882 t:102081ms step_avg:17013.5ms +step:7/10 fitness:-3.0000 t:118871ms step_avg:16981.5ms +step:8/10 fitness:-2.9899 t:135844ms step_avg:16980.5ms +step:9/10 fitness:-3.0370 t:152662ms step_avg:16962.5ms +step:10/10 fitness:-2.9455 t:169385ms step_avg:16938.5ms +step:10/10 val_loss:2.9232 val_bpb:1.6290 train_time:169393ms +compression: lzma=15.88MB brotli=15.81MB zstd=18.26MB +using: brotli +artifact:15.81MB binary:115343360(15319040B) fp:1226832(1248416B) code:69955 +budget:15883303/16000000 (15.88/16.00MB) FITS +final_xnor_roundtrip val_loss:2.9329 val_bpb:1.6344 +temp_scaling optimal_T:1.00 
time:123ms diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_lora.txt b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_lora.txt new file mode 100644 index 0000000000..c992f43615 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/log_eggroll_lora.txt @@ -0,0 +1,1517 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. + H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + 
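# (each knob below reads an env override via the _e helper above, e.g.
+    # VAL_LOSS_EVERY=250 doubles the eval frequency)
+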
val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + eggroll_layers = _e("EGGROLL_LAYERS", 0, int) # 0 = all layers, N = last N layers only + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# 
--------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = 
result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous()
+        elif entry["type"] == "fp8":
+            out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous()
+        else:
+            out[name] = entry["data"].to(target_dtype).contiguous()
+    return out
+
+# ---------------------------------------------------------------------------
+# Diagnostics
+# ---------------------------------------------------------------------------
+_prev_committed: dict = {}
+def churn_fn(model: nn.Module, group_size: int = 128):
+    # Fraction of committed sign bits that flipped since the last call.
+    global _prev_committed
+    total = flipped = 0
+    with torch.no_grad():
+        for name, p in model.named_parameters():
+            if p.ndim == 2 and "weight" in name and p.shape[0] > 1:
+                w = p.detach().float().reshape(-1, group_size)
+                q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy()
+                if name in _prev_committed:
+                    flipped += int(np.sum(q != _prev_committed[name]))
+                total += q.size
+                _prev_committed[name] = q
+    return flipped / max(total, 1)
+
+# ---------------------------------------------------------------------------
+# EGGROLL: Evolution-Guided optimization with low-rank perturbations
+# ---------------------------------------------------------------------------
+def get_binary_params(model, num_layers=10, last_n_layers=0):
+    """Get list of (name, param) for parameters that will be perturbed.
+    last_n_layers=0 means all layers, N means only the last N blocks."""
+    params = []
+    for name, p in model.named_parameters():
+        if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+            if last_n_layers > 0:
+                # Only include params from last N blocks
+                match = False
+                for layer_idx in range(num_layers - last_n_layers, num_layers):
+                    if f"blocks.{layer_idx}." in name:
+                        match = True
+                        break
+                if not match:
+                    continue
+            params.append((name, p))
+    return params
+
+def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Apply rank-r perturbation to binary params using deterministic RNG."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove previously applied perturbation (subtract it)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    Regenerates each perturbation from its seed and accumulates the
+    fitness-weighted outer products (ES gradient estimate:
+    g = (1 / (N * sigma)) * sum_i f_i * E_i, with E_i = sum_r a_r b_r^T / sqrt(rank))."""
+    N = len(seeds)
+    for name, p in binary_params:
+        m, n = p.shape
+        update = torch.zeros_like(p.data)
+        for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+            rng = torch.Generator(device=device)
+            rng.manual_seed(seed)
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+        # Ascent step: p += lr * g
+        p.data.add_(update, alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # NOTE: reconstructed body -- assumes the standard shard layout for these
+    # records (256 "<i4" header words with the token count at header[2],
+    # followed by "<u2" tokens); only the header_bytes line survived verbatim.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    with open(file, "rb") as f:
+        f.seek(header_bytes)
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    # NOTE: __init__/_advance_file reconstructed from their call sites; take() is original.
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+    # FP8 QAT via STE: forward sees the FP8-rounded weight, backward sees identity.
+    if not fp_storage:
+        return w
+    w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+    return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+        super().__init__(in_features, out_features, bias=bias)
+        self.fp_storage = fp_storage
+    def forward(self, x: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+        super().__init__(num_embeddings, embedding_dim)
+        self.fp_storage
= fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. + """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. 
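+    Illustrative example: with K=32, an all-positive row packs to a single
+    word with every bit set (0xFFFFFFFF, i.e. -1 as a signed int32); an
+    all-negative row packs to 0.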
+ returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc = y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + 
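# x and w are saved at full precision: backward() re-binarizes them with the
+        # cached alpha/beta to form the STE gradients.
+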
ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). + """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model, num_layers=args.num_layers, + last_n_layers=args.eggroll_layers) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - 
warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16): + loss_neg = base_model(x, y) + fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
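+    # Minimal sanity sketch (a hypothetical helper, defined but never called):
+    # it checks the invariance the comment above relies on. Permuting the MLP
+    # hidden dimension in both matrices leaves the MLP output unchanged, because
+    # the elementwise activation commutes with the permutation.
+    def _signsort_invariance_demo():
+        w_up, w_down = torch.randn(64, 16), torch.randn(16, 64)  # [hidden, dim], [dim, hidden]
+        xin = torch.randn(4, 16)
+        perm = torch.randperm(64)
+        y0 = torch.relu(xin @ w_up.T).square() @ w_down.T
+        y1 = torch.relu(xin @ w_up[perm].T).square() @ w_down[:, perm].T
+        assert torch.allclose(y0, y1, atol=1e-5)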
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +loading pretrained weights from checkpoints_n2_100k/ckpt_step0100000.pt +pretrained weights loaded +EGGROLL binary params: 4 layers, 11534336 params +--- Hyperparameters --- +activation_type=signsq binarize_activations=2 checkpoint_dir=./checkpoints checkpoint_every=0 churn_log_every=0 compile_mode=default data_path=./data/datasets/fineweb10B_sp1024 eggroll_layers=1 eggroll_load=checkpoints_n2_100k/ckpt_step0100000.pt eggroll_lr=1e-06 eggroll_rank=8 eggroll_sigma=0.0001 embed_dim=384 fitness_shaping=rank fp_storage=True iterations=10 logit_softcap=10.0 lr_schedule=cosine lr_warmup_steps=0 max_wallclock_seconds=0.0 mlp_mult=4 model_dim=1024 num_heads=8 num_kv_heads=4 num_layers=10 pop_size=4096 qk_gain_init=2.25 rope_base=5000.0 rope_type=yarn run_id=E9_eggroll-nlayers_8xH100_1775388646 scale_fp8=False seed=42 seq_len_schedule=False sign_compress_reg=0.0 sliding_batch_size=512 sliding_eval=False sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=False tie_embeddings=1 tied_embed_init_std=0.005 tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_batch_tokens=65536 train_log_every=1 train_seq_len=1024 use_int8_kernel=False use_triton_kernel=False val_batch_size=524288 val_loss_every=10 vocab_size=1024 warmdown_fraction=0.3 xnor_group_size=256 yarn_max_len=2048 +XNOR-EGGROLL | params:117618768 L:10 d:1024 h:8 kv:4 mlp:4x act:signsq xnor_act:2 g:256 ws:8 pop:4096 sigma:0.0001 rank:8 +step:0/10 val_loss:2.8153 val_bpb:1.5689 train_time:0ms +step:1/10 fitness:-2.9516 t:12325ms step_avg:12325.1ms +step:2/10 fitness:-2.9244 t:24164ms step_avg:12082.1ms +step:3/10 fitness:-2.9409 t:36016ms step_avg:12005.5ms +step:4/10 fitness:-2.8975 t:48305ms step_avg:12076.2ms +step:5/10 fitness:-2.9571 t:60244ms step_avg:12048.8ms +step:6/10 fitness:-2.9360 t:72090ms step_avg:12014.9ms +step:7/10 fitness:-2.9518 t:84000ms step_avg:12000.0ms +step:8/10 fitness:-2.9419 t:95917ms step_avg:11989.7ms +step:9/10 fitness:-2.9835 t:107843ms step_avg:11982.5ms +step:10/10 fitness:-2.8931 t:119758ms step_avg:11975.8ms +step:10/10 val_loss:2.8866 val_bpb:1.6086 train_time:119766ms +compression: lzma=15.88MB brotli=15.81MB zstd=18.26MB +using: brotli +artifact:15.81MB binary:115343360(15319040B) fp:1226832(1248416B) code:69955 +budget:15882843/16000000 (15.88/16.00MB) FITS +final_xnor_roundtrip val_loss:2.8958 val_bpb:1.6137 diff --git 
a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor.sh b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor.sh new file mode 100644 index 0000000000..ef203b46f8 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# XNOR-Net LLM + +export OMP_NUM_THREADS=1 +export PYTHONWARNINGS="ignore::UserWarning:torch._inductor" + +# --- Data --- +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +export TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model +export VOCAB_SIZE=1024 + +# --- Architecture --- +export NUM_LAYERS=10 +export MODEL_DIM=1024 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=4 +export EMBED_DIM=384 +export ACTIVATION=signsq +export SMEAR=0 +export ATTN_RES=2 # 0 = Standard U-Net with skip connections, 1 = AttnRes sub-layer, 2 = AttnRes pass-level + +# --- XNOR --- +export XNOR_GROUP_SIZE=256 +export BINARIZE_ACTIVATIONS=2 # 1 = full XNOR, 0 = weight-only binary (BWN), 2 = XNOR besides MLP +export USE_INT8_KERNEL=0 +export USE_TRITON_KERNEL=1 + +# --- Attention --- +export ROPE_TYPE=yarn +export YARN_MAX_LEN=2048 +export ROPE_BASE=5000 +export QK_GAIN_INIT=2.25 + +# --- Logits --- +export LOGIT_SOFTCAP=10 +export SOFTCAP_TYPE=poly +export TIE_EMBEDDINGS=1 +export FP_STORAGE=FP8 +export SCALE_STORAGE=BF16 + +# --- Optimizer --- +export MATRIX_OPTIMIZER=muon +export USE_GRAM_NS=0 +export LR_SCHEDULE=cosine +export MATRIX_LR=0.008 +export SCALAR_LR=0.01 +export TIED_EMBED_LR=0.05 +export HEAD_LR=0.02 +export ADAM_LR=0.02 +export ADAM_WD=0.04 +export MUON_WD=0.04 +export MUON_BACKEND_STEPS=3 +export MUON_MOMENTUM=0.80 +export MUON_MOMENTUM_WARMUP_START=0.80 +export MUON_MOMENTUM_WARMUP_STEPS=200 +export WARMDOWN_FRACTION=0.3 +export SIGN_COMPRESS_REG=0.0 +export GRAD_CLIP_NORM=1.0 + +# --- Schedule --- +export SEQ_LEN_SCHEDULE=1 +export TRAIN_BATCH_TOKENS=65536 +export TRAIN_SEQ_LEN=1024 +export LR_WARMUP_STEPS=5 +export MAX_WALLCLOCK_SECONDS=0 +export ITERATIONS=100000 + +# --- Eval --- +export VAL_LOSS_EVERY=0 +export TRAIN_LOG_EVERY=1000 +export CHURN_LOG_EVERY=0 +export VAL_MAX_TOKENS=0 +export TEMP_SCALING=1 +export SLIDING_EVAL=1 +export SLIDING_EVAL_STRIDE=16 +export SLIDING_BATCH_SIZE=512 + +# --- EMA / Checkpointing --- +export EMA=0 +export EMA_START_FRACTION=0.0 +export EMA_DECAY=0.995 +export SEED=42 +export COMPILE_MODE=default +export CHECKPOINT_EVERY=10000 + +# --- Run --- +export RUN_ID=N3_full-xnor_8xH100_$(date +%s) +torchrun --standalone --nproc_per_node=8 train_gpt_cuda_xnor.py diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor_eggroll.sh b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor_eggroll.sh new file mode 100644 index 0000000000..d42febca62 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/run_cuda_xnor_eggroll.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# XNOR-Net EGGROLL — gradient-free fine-tuning from pretrained checkpoint +export OMP_NUM_THREADS=1 + +# --- Data --- +export DATA_PATH=./data/datasets/fineweb10B_sp1024 +export TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model +export VOCAB_SIZE=1024 + +# --- Architecture (must match checkpoint) --- +export NUM_LAYERS=10 +export MODEL_DIM=1024 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=4 +export EMBED_DIM=384 +export ACTIVATION=signsq +export SMEAR=0 + +# --- XNOR --- +export XNOR_GROUP_SIZE=256 +export 
BINARIZE_ACTIVATIONS=2 +export USE_INT8_KERNEL=0 +export USE_TRITON_KERNEL=0 + +# --- Attention --- +export ROPE_TYPE=yarn +export YARN_MAX_LEN=2048 +export ROPE_BASE=5000 +export QK_GAIN_INIT=2.25 + +# --- Logits --- +export LOGIT_SOFTCAP=10 +export SOFTCAP_TYPE=poly +export TIE_EMBEDDINGS=1 +export FP_STORAGE=FP8 +export SCALE_STORAGE=BF16 + +# --- EGGROLL --- +export EGGROLL_LOAD=checkpoints_n2_100k/ckpt_step0100000.pt +export EGGROLL_LORA_RANK=4 +export EGGROLL_LAYERS=0 +export POP_SIZE=16384 +export EGGROLL_SIGMA=0.01 +export EGGROLL_LR=0.001 +export EGGROLL_RANK=8 +export FITNESS_SHAPING=rank +export LR_SCHEDULE=cosine +export WARMDOWN_FRACTION=0.3 + +# --- Schedule --- +export SEQ_LEN_SCHEDULE=0 +export TRAIN_BATCH_TOKENS=65536 +export TRAIN_SEQ_LEN=1024 +export LR_WARMUP_STEPS=0 +export MAX_WALLCLOCK_SECONDS=0 +export ITERATIONS=50 + +# --- Eval --- +export VAL_LOSS_EVERY=10 +export TRAIN_LOG_EVERY=1 +export CHURN_LOG_EVERY=0 +export VAL_MAX_TOKENS=0 +export TEMP_SCALING=0 +export SLIDING_EVAL=0 +export SLIDING_EVAL_STRIDE=16 +export SLIDING_BATCH_SIZE=512 + +# --- Checkpointing --- +export SEED=42 +export COMPILE_MODE=default +export CHECKPOINT_EVERY=0 + +# --- Run --- +export RUN_ID=E11_eggroll-lora_8xH100_$(date +%s) +torchrun --standalone --nproc_per_node=8 train_gpt_cuda_xnor_eggroll_lora.py \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor.py b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor.py new file mode 100644 index 0000000000..60ad2be960 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor.py @@ -0,0 +1,1690 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim — April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +This is the REFERENCE implementation using STE-simulated XNOR via F.linear. +The Triton INT8×INT8 kernel version comes later for H100 deployment. + +Architecture: U-Net transformer with skip connections — provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + attn_res = _e("ATTN_RES", 0, int) # 0=disabled, 1=sub-layer, 
2=pass-level + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # Optimizer + matrix_lr = _e("MATRIX_LR", 0.04, float) + scalar_lr = _e("SCALAR_LR", 0.02, float) + tied_embed_lr = _e("TIED_EMBED_LR", 0.02, float) + embed_lr = _e("EMBED_LR", 0.6, float) + head_lr = _e("HEAD_LR", 0.02, float) + matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") + muon_momentum = _e("MUON_MOMENTUM", 0.95, float) + muon_backend_steps = _e("MUON_BACKEND_STEPS", 3, int) + muon_wd = _e("MUON_WD", 0.0, float) + use_gram_ns = _e("USE_GRAM_NS", 0, bool) + muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) + muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) + adam_lr = _e("ADAM_LR", 0.05, float) + adam_wd = _e("ADAM_WD", 0.05, float) + beta1 = _e("BETA1", 0.9, float) + beta2 = _e("BETA2", 0.95, float) + adam_eps = _e("ADAM_EPS", 1e-8, float) + grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) + # EMA — compatible with binary (no zero-weight collapse unlike ternary) + ema = _e("EMA", 0, bool) + ema_decay = _e("EMA_DECAY", 0.995, float) + ema_start_fraction = _e("EMA_START_FRACTION", 0.60, float) + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight 
matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# 
--------------------------------------------------------------------------- +def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +def gram_ns_orth(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Gram Newton-Schulz with Polar Express coefficients. 5 steps, restart at step 2. + Fully unrolled for torch.compile. bfloat16 throughout.""" + X = G.bfloat16() + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + X = X / (X.norm() + eps) + n = X.size(0) + I = torch.eye(n, device=X.device, dtype=X.dtype) + + # Phase 1: steps 0-1 + R = X @ X.T + # Step 0: a=4.0848, b=-6.8946, c=2.9270 + Z = -6.8946 * R + 2.9270 * (R @ R) + Q = Z + 4.0848 * I + RZ = 4.0848 * R + R @ Z + R = 4.0848 * RZ + Z @ RZ + # Step 1: a=3.9505, b=-6.3029, c=2.6377 + Z = -6.3029 * R + 2.6377 * (R @ R) + Q = 3.9505 * Q + Q @ Z + # No R update (next is reset) + + # Reset + X = Q @ X + R = X @ X.T + + # Phase 2: steps 2-4 + # Step 2: a=3.7418, b=-5.5913, c=2.3037 + Z = -5.5913 * R + 2.3037 * (R @ R) + Q = Z + 3.7418 * I + RZ = 3.7418 * R + R @ Z + R = 3.7418 * RZ + Z @ RZ + # Step 3: a=2.8769, b=-3.1427, c=1.2046 + Z = -3.1427 * R + 1.2046 * (R @ R) + Q = 2.8769 * Q + Q @ Z + RZ = 2.8769 * R + R @ Z + R = 2.8769 * RZ + Z @ RZ + # Step 4: a=2.8366, b=-3.0525, c=1.2012 + Z = -3.0525 * R + 1.2012 * (R @ R) + Q = 2.8366 * Q + Q @ Z + + X = Q @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, wd=0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() + g = ns_orth(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("wd", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------------------------------------------------------------- +# Data loading +# 
---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    tokens = np.fromfile(file, dtype=np.uint16, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int16))  # vocab <= 8192 fits in int16
+
+class TokenStream:
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(p) for p in glob.glob(pattern))
+        assert self.files, f"no token shards match {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(self.files[0])
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x, y
+
+# ---------------------------------------------------------------------------
+# Model components
+# ---------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, eps=None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+    if not fp_storage:
+        return w
+    w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+    return (w_sim - w).detach() + w
+
+class QATLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+        super().__init__(in_features, out_features, bias=bias)
+        self.fp_storage = fp_storage
+    def forward(self, x: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+class QATEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+        super().__init__(num_embeddings, embedding_dim)
+        self.fp_storage = fp_storage
+    def forward(self, input: Tensor) -> Tensor:
+        w = apply_qat_ste(self.weight, self.fp_storage)
+        return F.embedding(input, w, self.padding_idx, self.max_norm,
+                           self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+# ---------------------------------------------------------------------------
+# XNOR-Net Linear Layers
+# ---------------------------------------------------------------------------
+_INT8_KERNEL = False  # set from Hyperparameters.use_int8_kernel in main()
+
+class _INT8XNORFn(torch.autograd.Function):
+    """INT8 tensor core forward, BF16 backward.
+    Uses per-row alpha (not per-group) so we get a single torch._int_mm call.
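+    Sketch of what forward() computes, with alpha and beta as defined there:
+        y[m, n] = beta[m] * alpha[n] * sum_k sign(x[m, k]) * sign(w[n, k])
+    torch._int_mm accumulates the sign products exactly in int32, so the only
+    approximation relative to the per-group STE path is the coarser per-row alpha.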
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
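+    Example (hypothetical shapes; 1024 -> 4096 matches the 4x MLP up projection):
+        lin = XNORLinear(1024, 4096, group_size=256, binarize_act=True)
+        y = lin(x)  # x: [B, T, 1024] -> y: [B, T, 4096], from sign(x) and sign(W)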
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Attention Residuals (AttnRes) — depth-wise attention over prior outputs +# From: "Attention Residuals" (Kimi Team, 2026) +# --------------------------------------------------------------------------- +class AttnRes(nn.Module): + def __init__(self, dim, n_queries): + super().__init__() + # Zero-init -> uniform weights at start -> equivalent to standard residual + self.queries = nn.ParameterList([ + nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + for _ in range(n_queries) + ]) + self.key_norm_weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def attend(self, stored, query_idx): + w = self.queries[query_idx] + logits = [] + for v in stored: + k = F.rms_norm(v, (v.size(-1),)) * self.key_norm_weight.to(v.dtype) + logits.append(torch.einsum('d, b t d -> b t', w.to(v.dtype), k)) + logits = torch.stack(logits, dim=0) + weights = F.softmax(logits, dim=0) + out = torch.zeros_like(stored[0]) + for i, v in enumerate(stored): + out = out + weights[i].unsqueeze(-1) * v + return out + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + 
rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False, attn_res=0): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + self.attn_res_mode = attn_res + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Attention Residuals (replaces U-Net skips when enabled) + self.attn_res = None + if attn_res == 2: # pass-level + self.attn_res = AttnRes(model_dim, n_queries=num_layers) + elif attn_res == 1: # sub-layer + self.attn_res = AttnRes(model_dim, n_queries=num_layers * 2) + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if 
tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.attn_res_mode == 2: + # Pass-level AttnRes: attend over all previous block outputs + stored = [x] + for i, block in enumerate(self.blocks): + x = self.attn_res.attend(stored, i) + x = block(x, x0) + stored.append(x) + elif self.attn_res_mode == 1: + # Sub-layer AttnRes: attend before each attn and mlp sub-layer + stored = [x] + q_idx = 0 + for i, block in enumerate(self.blocks): + # Attention sub-layer + x = self.attn_res.attend(stored, q_idx) + mix = block.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + attn_out = block.attn_scale.to(dtype=x.dtype) * block.attn(block.attn_norm(x)) + stored.append(attn_out) + q_idx += 1 + # MLP sub-layer + x = self.attn_res.attend(stored, q_idx) + mlp_out = block.mlp_scale.to(dtype=x.dtype) * block.mlp(block.mlp_norm(x)) + stored.append(mlp_out) + q_idx += 1 + if block.smear is not None: + x = block.smear(self.attn_res.attend(stored, q_idx - 1)) + else: + # Standard U-Net (mode 0) + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = 
int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), 
device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + 
torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + if args.matrix_optimizer != "adamw": + global ns_orth + if args.use_gram_ns: + ns_orth = torch.compile(gram_ns_orth) + else: + ns_orth = torch.compile(ns_orth) + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, attn_res=args.attn_res, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else 
None) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, + find_unused_parameters=not args.tie_embeddings, + static_graph=args.tie_embeddings, + gradient_as_bucket_view=True) if distributed else compiled_model + + # --- Optimizers --- + _excl = {"tok_emb.weight", "lm_head.weight"} + all_other = [(n, p) for n, p in base_model.named_parameters() if not any(eh in n for eh in _excl)] + matrix_params = [p for n, p in all_other if p.ndim == 2 and not any(pat in n for pat in CTP)] + scalar_params = [p for n, p in all_other if p.ndim < 2 or any(pat in n for pat in CTP)] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + opt_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + + if args.matrix_optimizer == "adamw": + opt_muon = torch.optim.AdamW( + [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) + else: + opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, wd=args.muon_wd) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_tok, opt_muon, opt_scalar, opt_head] + + # --- EMA --- + ema_model = None + _ema_started = False + _ema_steps = 0 + if args.ema: + ema_model = copy.deepcopy(base_model) + for p in ema_model.parameters(): p.requires_grad_(False) + + # --- Checkpoint resume --- + _resume_step, _resume_ms, _resume_untied = 0, 0.0, False + _resume_ema_started, _resume_ema_steps = False, 0 + if args.checkpoint_every > 0: + ckpt = _latest_checkpoint(args.checkpoint_dir) + if ckpt: + log0(f"resuming from {ckpt}") + _resume_step, _resume_ms, _resume_untied, _resume_ema_started, _resume_ema_steps = \ + load_checkpoint(ckpt, base_model, optimizers, device, ema_model) + log0(f"resumed at step {_resume_step} ({_resume_ms:.0f}ms)") + else: + log0("no checkpoint, starting fresh") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-Net LLM | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} ga:{grad_accum_steps}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + # Cosine decay from 1.0 to 0 over the full run, with optional warmup + if max_wallclock_ms is None: + total = args.iterations + warmup = 
min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + # Linear warmdown (original) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Compiler warmup --- + warmup_seq_lens = seq_len_stages if args.seq_len_schedule else [args.train_seq_len] + if len(warmup_seq_lens) > 0: + _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + _os = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws, sl in enumerate(warmup_seq_lens): + zero_grad_all() + for mi in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, sl, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) + (loss * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + log0(f"warmup:{ws+1}/{len(warmup_seq_lens)} seq_len={sl}") + base_model.load_state_dict(_ms, strict=True) + for o, s in zip(optimizers, _os): o.load_state_dict(s) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # --- Main loop --- + training_time_ms = _resume_ms + stop_after_step = None + _untied = _resume_untied + _ema_started = _resume_ema_started + _ema_steps = _resume_ema_steps + train_loss = torch.zeros((), device=device) + torch.cuda.synchronize() + t0 = time.perf_counter() + step = _resume_step + steps_this_session = 0 + + if max_wallclock_ms is None: + if step >= args.iterations: + stop_after_step = step + elif training_time_ms >= max_wallclock_ms: + stop_after_step = step + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Muon momentum warmup + if args.matrix_optimizer != "adam": + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for g in opt_muon.param_groups: + g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + + zero_grad_all() + train_loss.zero_() + current_seq_len = get_seq_len(step, elapsed_ms) + for micro in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = micro == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, grad_accum_steps) + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + # Compression regularizer: penalize sign entropy within groups + if args.sign_compress_reg > 0 and micro == 0: + reg = torch.tensor(0.0, device=device) + for p in base_model.parameters(): + if p.ndim == 2 and p.numel() > 65536: + g = args.xnor_group_size + pg = p.reshape(-1, g) + # soft sign: tanh(k*w) approximates sign(w) but has gradient + soft = torch.tanh(pg * 10.0) + # mean per group: +1/-1 = all same sign (compressible), 0 = balanced + reg = reg + (1.0 - soft.mean(dim=-1).abs().mean()) + loss = loss + args.sign_compress_reg * reg + train_loss.add_(loss.detach()) + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for opt in optimizers: + for g in opt.param_groups: g["lr"] = g["base_lr"] * scale + opt.step() + zero_grad_all() + + # EMA update + if ema_model is not None: + if not _ema_started: + if max_wallclock_ms is not None: + should_start = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms + else: + should_start = step >= int(args.iterations * args.ema_start_fraction) + if should_start: + _ema_started = True + _ema_steps = 0 + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.copy_(bp.data) + log0(f"step:{step} ema_started") + if _ema_started: + _ema_steps += 1 + decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) + with torch.no_grad(): + for ep, bp in zip(ema_model.parameters(), base_model.parameters()): + ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) + + step += 1 + steps_this_session += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Checkpoint + if master_process and args.checkpoint_every > 0 and step % args.checkpoint_every == 0: + save_checkpoint(args.checkpoint_dir, step, base_model, optimizers, + approx_ms, _untied, ema_model, _ema_started, _ema_steps) + log0(f"step:{step} checkpoint saved") + + if args.train_log_every > 0 and step % args.train_log_every == 0: + log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + if args.churn_log_every > 0 and step % args.churn_log_every == 0: + log0(f"step:{step} churn:{churn_fn(base_model, args.xnor_group_size):.4f}") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and 
step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. + if master_process: + sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = 
zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + if ema_model is not None: + ema_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch._dynamo.reset() + torch.cuda.empty_cache() + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_full.py b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_full.py new file mode 100644 index 0000000000..fee4b71bbd --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_full.py @@ -0,0 +1,1481 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. 
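+
+Schematically, each binarized linear layer computes (an illustrative sketch;
+the actual layers below add per-group padding, STE gradients, and optional
+fused kernels):
+
+    alpha = mean(|W|) over each group of `xnor_group_size` weights   # weight scale
+    beta  = mean(|x|) over the last dim, per token                   # activation scale
+    y     = (sign(x) * beta) @ (sign(W) * alpha).T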
+""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. + H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from 
XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) + # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# 
--------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# 
---------------------------------------------------------------------------
+# EGGROLL: Evolution-Guided optimization with low-rank perturbations
+# ---------------------------------------------------------------------------
+def get_binary_params(model):
+    """Get list of (name, param) for parameters that will be binarized."""
+    params = []
+    for name, p in model.named_parameters():
+        if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+            params.append((name, p))
+    return params
+
+def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Apply rank-r perturbation to binary params using deterministic RNG."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove previously applied perturbation (subtract it)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    Regenerates each perturbation from its seed and accumulates the
+    fitness-weighted sum (a naive loop; the rank-1 factors could equivalently
+    be batched as (diag(f) @ A).T @ B)."""
+    N = len(seeds)
+    for name, p in binary_params:
+        m, n = p.shape
+        update = torch.zeros_like(p.data)
+        for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+            rng = torch.Generator(device=device)
+            rng.manual_seed(seed)
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+        p.data.add_(update, alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # Shard layout: 256 little-endian int32 header words (header[2] = token
+    # count) followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    """Sequential stream over sorted shard files, wrapping around at the end."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device 
= rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x, y + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. 
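+
+    Why an INT8 matmul suffices (illustrative example): with entries in
+    {-1, +1}, a binary dot product is an ordinary integer matmul, e.g.
+    [1, -1, 1, 1] . [1, 1, -1, 1] = 1 - 1 - 1 + 1 = 0, so torch._int_mm on
+    the sign matrices reproduces it exactly; the INT32 result is then
+    rescaled by beta (per token) and alpha (per output row).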
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
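+
+    Worked example for one weight group (illustrative numbers, group_size=4):
+
+        w     = [ 0.10, -0.30, 0.20, -0.40 ]
+        alpha = mean(|w|) = 0.25
+        w_bin = sign(w) * alpha = [ +0.25, -0.25, +0.25, -0.25 ]
+
+    Only the sign bits (1 bit per weight) plus one alpha per group survive at
+    save time; q_sd()/deq_sd() implement the matching serialization.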
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * 
(1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_neg = base_model(x, y) + 
fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
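+     # Example: if three hidden columns have sign patterns (+,+,-), (+,-,+), (+,+,-),
+     # the sort makes the two identical columns adjacent, so the packed sign
+     # bitstream gains long repeats that LZMA/Brotli/zstd exploit.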
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_lora.py b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_lora.py new file mode 100644 index 0000000000..b0610ad3e1 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_lora.py @@ -0,0 +1,1546 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
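+     Picks the fastest attention kernel that imports successfully on the detected GPU: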
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) 
+ # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + eggroll_layers = _e("EGGROLL_LAYERS", 0, int) # 0 = all layers, N = last N layers only + eggroll_lora_rank = _e("EGGROLL_LORA_RANK", 0, int) # 0 = direct perturbation, >0 = LoRA rank + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % 
group_size) % group_size
+             t_padded = F.pad(t, (0, pad)) if pad > 0 else t
+             t_grouped = t_padded.reshape(-1, group_size)
+             scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8)
+             q = torch.where(t_grouped >= 0,
+                             torch.ones_like(t_grouped, dtype=torch.int8),
+                             -torch.ones_like(t_grouped, dtype=torch.int8))
+             packed_bytes, n_bits = pack_binary(q)
+             if scale_fp8:
+                 scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1)
+                 scale_bytes_per = 1
+             else:
+                 scale_stored = scale.bfloat16().squeeze(-1)
+                 scale_bytes_per = 2
+             quantized[name] = {
+                 "type": "binary", "packed": packed_bytes,
+                 "scale": scale_stored,
+                 "shape": list(t.shape), "padded_cols": t_padded.shape[1],
+                 "group_size": group_size, "n_bits": n_bits,
+                 "orig_shape": t_orig_shape,
+             }
+             stats["binary_params"] += t.numel()
+             stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per
+         elif fp_storage and t.ndim == 2:
+             quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)}
+             stats["fp_params"] += t.numel()
+             stats["fp_bytes"] += t.numel()
+         else:
+             quantized[name] = {"type": "bf16", "data": t.bfloat16()}
+             stats["fp_params"] += t.numel()
+             stats["fp_bytes"] += t.numel() * 2
+     return quantized, stats
+
+ def deq_sd(quantized: dict, target_dtype=torch.bfloat16):
+     """Reconstruct state dict from quantized representation."""
+     out = {}
+     for name, entry in quantized.items():
+         if entry["type"] == "binary":
+             q = unpack_binary(entry["packed"], entry["n_bits"])
+             q = q.float().reshape(-1, entry["group_size"])
+             scale = entry["scale"].float().unsqueeze(-1)
+             # No shrinkage correction: binary q_absmean = 1.0 always
+             t = (q * scale).reshape(-1, entry["padded_cols"])
+             shape = entry["shape"]
+             result = t[:shape[0], :shape[1]].to(target_dtype)
+             orig = entry.get("orig_shape")
+             out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous()
+         elif entry["type"] == "fp8":
+             out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous()
+         else:
+             out[name] = entry["data"].to(target_dtype).contiguous()
+     return out
+
+ # ---------------------------------------------------------------------------
+ # Diagnostics
+ # ---------------------------------------------------------------------------
+ _prev_committed: dict = {}
+ def churn_fn(model: nn.Module, group_size: int = 128):
+     global _prev_committed
+     total = flipped = 0
+     with torch.no_grad():
+         for name, p in model.named_parameters():
+             if p.ndim == 2 and "weight" in name and p.shape[0] > 1:
+                 w = p.detach().float().reshape(-1, group_size)
+                 q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy()
+                 if name in _prev_committed:
+                     flipped += int(np.sum(q != _prev_committed[name]))
+                 total += q.size
+                 _prev_committed[name] = q
+     return flipped / max(total, 1)
+
+ # ---------------------------------------------------------------------------
+ # EGGROLL: Evolution-Guided optimization with low-rank perturbations
+ # ---------------------------------------------------------------------------
+ def get_binary_params(model, num_layers=10, last_n_layers=0):
+     """Get list of (name, param) for parameters that will be perturbed.
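+     Selection mirrors q_sd's binary-packing criterion: 2D tensors with more
+     than 65536 elements, excluding tok_emb, lm_head, and embed_proj weights.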
+     last_n_layers=0 means all layers, N means only the last N blocks."""
+     params = []
+     for name, p in model.named_parameters():
+         if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+             if last_n_layers > 0:
+                 # Only include params from last N blocks
+                 match = False
+                 for layer_idx in range(num_layers - last_n_layers, num_layers):
+                     if f"blocks.{layer_idx}." in name:
+                         match = True
+                         break
+                 if not match:
+                     continue
+             params.append((name, p))
+     return params
+
+ def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+     """Apply rank-r perturbation to binary params using deterministic RNG."""
+     rng = torch.Generator(device=device)
+     rng.manual_seed(seed)
+     for name, p in binary_params:
+         m, n = p.shape
+         for r in range(rank):
+             a = torch.randn(m, device=device, generator=rng)
+             b = torch.randn(n, device=device, generator=rng)
+             p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+ def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+     """Remove previously applied perturbation (subtract it)."""
+     apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+ def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+     """Compute and apply the EGGROLL weight update: the fitness-weighted average
+     of the unit-scale perturbations, applied with step size lr / sigma.
+     Each rank-r perturbation is regenerated from its seed, so none are stored."""
+     N = len(seeds)
+     for name, p in binary_params:
+         m, n = p.shape
+         update = torch.zeros_like(p.data)
+         for i, (seed, f) in enumerate(zip(seeds, fitnesses)):
+             # Regenerate the exact (a, b) factors that apply_perturbation drew
+             rng = torch.Generator(device=device)
+             rng.manual_seed(seed)
+             for r in range(rank):
+                 a = torch.randn(m, device=device, generator=rng)
+                 b = torch.randn(n, device=device, generator=rng)
+                 update.add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+         p.data.add_(update, alpha=lr / sigma)
+
+ def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+     """Shape fitnesses from antithetic pairs."""
+     N = len(fitnesses_pos)
+     if method == "sign":
+         return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+     # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+     all_f = list(fitnesses_pos) + list(fitnesses_neg)
+     # argsort(argsort(x)) maps each value to its 0..2N-1 rank
+     ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+     ranked = ranked / (2 * N - 1) - 0.5
+     # Return shaped fitness for positive perturbation (subtract negative's)
+     shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+     return shaped
+
+ # --- LoRA-based EGGROLL ---
+ def create_lora_params(binary_params, lora_rank, device="cuda"):
+     """Create LoRA A and B matrices for each binary param, initialized to zero."""
+     lora = {}
+     lora_list = []  # flat list of (name, param) for perturbation
+     for name, p in binary_params:
+         m, n = p.shape
+         A = torch.zeros(m, lora_rank, device=device, dtype=torch.float32)
+         B = torch.zeros(lora_rank, n, device=device, dtype=torch.float32)
+         lora[name] = (A, B)
+         lora_list.append((name + ".lora_A", A))
+         lora_list.append((name + ".lora_B", B))
+     return lora, lora_list
+
+ def merge_lora(binary_params, lora, scale=1.0):
+     """Merge LoRA into base weights: W += A @ B * scale."""
+     for name, p in binary_params:
+         A, B = lora[name]
+         p.data.add_(A @ B, alpha=scale)
+
+ def unmerge_lora(binary_params, lora, scale=1.0):
+     """Remove LoRA from base weights: W -= A @ B * scale."""
+     merge_lora(binary_params, lora, scale=-scale)
+
+ # ---------------------------------------------------------------------------
+ # Data loading
+ # ---------------------------------------------------------------------------
+ def ld_shard(file: Path) -> Tensor:
+     header_bytes = 256 * np.dtype("<i4").itemsize  # shard header: 256 int32 words
+     with open(file, "rb") as f:
+         f.seek(header_bytes)
+         tokens = np.frombuffer(f.read(), dtype="<u2")  # token ids stored as uint16 (vocab <= 8192)
+     return torch.from_numpy(tokens.astype(np.int32))
+
+ class TokenStream:
+     """Sequential token stream over sorted shard files, wrapping around at the end."""
+     def __init__(self, pattern: str):
+         self.files = sorted(glob.glob(pattern))
+         assert self.files, f"No files: {pattern}"
+         self.file_idx = 0
+         self.tokens = ld_shard(Path(self.files[0]))
+         self.pos = 0
+     def _advance_file(self):
+         self.file_idx = (self.file_idx + 1) % len(self.files)
+         self.tokens = ld_shard(Path(self.files[self.file_idx]))
+         self.pos = 0
+     def take(self, n: int) -> Tensor:
+         chunks = []
+         remaining = n
+         while remaining > 0:
+             avail = self.tokens.numel() - self.pos
+             if avail <= 0:
+                 self._advance_file()
+                 continue
+             k = min(remaining, avail)
+             chunks.append(self.tokens[self.pos:self.pos + k])
+             self.pos += k
+             remaining -= k
+         return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+ class DistributedTokenLoader:
+     def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+         self.rank, self.world_size, self.device = rank, world_size, device
+         self.stream = TokenStream(pattern)
+     def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+         local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+         per_rank_span = local_tokens + 1
+         chunk = self.stream.take(per_rank_span * self.world_size)
+         start = self.rank * per_rank_span
+         local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64)
+         x = local[:-1].reshape(-1, seq_len)
+         y = local[1:].reshape(-1, seq_len)
+         return x, y
+
+ # ---------------------------------------------------------------------------
+ # Model components
+ # ---------------------------------------------------------------------------
+ class RMSNorm(nn.Module):
+     def __init__(self, eps=None):
+         super().__init__()
+         self.eps = eps
+     def forward(self, x: Tensor) -> Tensor:
+         return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+ def apply_qat_ste(w: Tensor, fp_storage) -> Tensor:
+     if not fp_storage:
+         return w
+     w_sim = w.to(torch.float8_e4m3fn).to(w.dtype)
+     return (w_sim - w).detach() + w
+
+ class QATLinear(nn.Linear):
+     def __init__(self, in_features, out_features, bias=False, fp_storage=False):
+         super().__init__(in_features, out_features, bias=bias)
+         self.fp_storage = fp_storage
+     def forward(self, x: Tensor) -> Tensor:
+         w = apply_qat_ste(self.weight, self.fp_storage)
+         return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+ class QATEmbedding(nn.Embedding):
+     def __init__(self, num_embeddings, embedding_dim, fp_storage=False):
+         super().__init__(num_embeddings, embedding_dim)
+         self.fp_storage = fp_storage
+     def forward(self, input: Tensor) -> Tensor:
+         w = apply_qat_ste(self.weight, self.fp_storage)
+         return F.embedding(input, w, self.padding_idx, self.max_norm,
+                            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+ # ---------------------------------------------------------------------------
+ # XNOR-Net Linear Layers
+ # ---------------------------------------------------------------------------
+ _INT8_KERNEL = False  # set from Hyperparameters.use_int8_kernel in main()
+
+ class _INT8XNORFn(torch.autograd.Function):
+     """INT8 tensor core forward, BF16 backward.
+     Uses per-row alpha (not per-group) so we get a single torch._int_mm call.
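+     With ±1 operands every int8 product is ±1, so the INT32 accumulator is
+     exact (|dot| <= K) and the binary matmul has no overflow or rounding.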
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
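+
+    A shape-only usage sketch (illustrative sizes, not the exact training config;
+    bf16 input, as under the training autocast):
+        layer = XNORLinear(1024, 4096, group_size=256, binarize_act=True)
+        y = layer(torch.randn(2, 128, 1024, dtype=torch.bfloat16))  # -> [2, 128, 4096]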
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model, num_layers=args.num_layers, + last_n_layers=args.eggroll_layers) + log0(f"EGGROLL binary params: {len(binary_params)} tensors, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # LoRA mode: perturb small LoRA matrices instead of full weights + lora = None + if args.eggroll_lora_rank > 0: + lora, lora_list = create_lora_params(binary_params, args.eggroll_lora_rank, device) + perturb_params = lora_list + lora_total = sum(p.numel() for _, p in lora_list) + log0(f"EGGROLL LoRA: rank={args.eggroll_lora_rank}, " + f"{len(lora_list)} tensors, {lora_total} params") + else: + perturb_params = binary_params + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if 
args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + if lora is not None: + merge_lora(binary_params, lora) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + if lora is not None: + unmerge_lora(binary_params, lora) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(perturb_params, 
seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + if lora is not None: + merge_lora(binary_params, lora) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + if lora is not None: + unmerge_lora(binary_params, lora) + remove_perturbation(perturb_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(perturb_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + if lora is not None: + merge_lora(binary_params, lora) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_neg = base_model(x, y) + fitnesses_neg.append(-loss_neg.item()) + if lora is not None: + unmerge_lora(binary_params, lora) + remove_perturbation(perturb_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update LoRA or direct weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(perturb_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Merge LoRA into base weights for serialization --- + if lora is not None: + merge_lora(binary_params, lora) + log0("LoRA merged into base weights for serialization") + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
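+    # Why this is exact in full precision (sketch, relu2 path): permuting the rows
+    # of W_up permutes the hidden activations h = act(W_up @ x) to h[perm]; RMSNorm,
+    # the activation, and the per-token beta are permutation-equivariant; and
+    # W_down[:, perm] @ h[perm] == W_down @ h, so logits are unchanged. Per-group
+    # alphas are recomputed from the permuted weights in q_sd below, so only the
+    # quantization grouping shifts, not the model function.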
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_nlayers.py b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_nlayers.py new file mode 100644 index 0000000000..2de3a3cf2c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_118M_XNOR-Net_FP8_1024D_10L/train_gpt_cuda_xnor_eggroll_nlayers.py @@ -0,0 +1,1493 @@ +"""XNOR-Net LLM training script for OpenAI's Parameter Golf Challenge. +Ciprian-Florin Ifrim - April 2026 + +Based on the Binary-Weight-Network script, extended to full XNOR-Net: +both weights AND activations are binarized using sign() with per-group/per-token +scaling factors derived from the XNOR-Net paper (Rastegari et al. 2016). + +Architecture: U-Net transformer with skip connections that provides error correction +for the information loss inherent in binary quantization of both weights and activations. +""" + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import lzma +from pathlib import Path +try: + import brotli + HAS_BROTLI = True +except ImportError: + HAS_BROTLI = False +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +try: + from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS + HAS_GRAM_NS = True +except ImportError: + HAS_GRAM_NS = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False +def _get_flash_attn_func(): + """GPU-aware FlashAttention selection. 
+ H100: FA3 -> FA2 -> SDPA + 5090: FA4 -> FA2 -> SDPA + Other: FA2 -> SDPA + """ + if torch.cuda.is_available(): + sm = torch.cuda.get_device_capability() + # Hopper (SM 90) — prefer FA3 + if sm[0] == 9 and sm[1] == 0: + try: + from flash_attn_interface import flash_attn_func + return flash_attn_func + except ImportError: + pass + # Blackwell (SM 100/120) — try FA4 + if sm[0] >= 10: + try: + from flash_attn.cute import flash_attn_func + return flash_attn_func + except ImportError: + pass + # FA2 — universal fallback with CUDA kernels + try: + from flash_attn import flash_attn_func + return flash_attn_func + except ImportError: + pass + # PyTorch native SDPA — always works, no install needed + def _sdpa_flash_attn_func(q, k, v, causal=False, **kwargs): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + o = F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=True) + return o.transpose(1, 2) + return _sdpa_flash_attn_func +flash_attn_func = _get_flash_attn_func() + +# --------------------------------------------------------------------------- +# Hyperparameters +# --------------------------------------------------------------------------- +def _e(k, d, t=str): + v = os.environ.get(k, str(d)) + if t == bool: return bool(int(v)) + return t(v) + +class Hyperparameters: + # Data + data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", f"xnor_{int(time.time())}") + seed = _e("SEED", 1337, int) + # Compile + compile_mode = _e("COMPILE_MODE", "default") + # Eval + val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) + val_loss_every = _e("VAL_LOSS_EVERY", 500, int) + train_log_every = _e("TRAIN_LOG_EVERY", 10, int) + sliding_eval = _e("SLIDING_EVAL", 0, bool) + sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 16, int) + sliding_batch_size = _e("SLIDING_BATCH_SIZE", 512, int) + temp_scaling = _e("TEMP_SCALING", 0, bool) + # Training schedule + iterations = _e("ITERATIONS", 2000, int) + warmdown_fraction = _e("WARMDOWN_FRACTION", 0.15, float) + lr_warmup_steps = _e("LR_WARMUP_STEPS", 5, int) + lr_schedule = _e("LR_SCHEDULE", "linear") # "linear" (current) or "cosine" + seq_len_schedule = _e("SEQ_LEN_SCHEDULE", 0, bool) # ramp seq_len from 128 to train_seq_len + train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) + train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) + max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) + # Architecture — from XNOR-UNet spec + vocab_size = _e("VOCAB_SIZE", 8192, int) + num_layers = _e("NUM_LAYERS", 15, int) + num_kv_heads = _e("NUM_KV_HEADS", 4, int) + model_dim = _e("MODEL_DIM", 768, int) + num_heads = _e("NUM_HEADS", 8, int) + mlp_mult = _e("MLP_MULT", 4, int) + tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) + embed_dim = _e("EMBED_DIM", 312, int) + activation_type = _e("ACTIVATION", "relu2") + # Attention + rope_base = _e("ROPE_BASE", 5000.0, float) + rope_type = _e("ROPE_TYPE", "yarn") + yarn_max_len = _e("YARN_MAX_LEN", 2048, int) + qk_gain_init = _e("QK_GAIN_INIT", 2.25, float) + # Logits + logit_softcap = _e("LOGIT_SOFTCAP", 10.0, float) + softcap_type = _e("SOFTCAP_TYPE", "poly") + tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) + # Architecture extras + smear = _e("SMEAR", 1, bool) + # XNOR quantization + xnor_group_size = _e("XNOR_GROUP_SIZE", 128, int) 
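+    # Must be a multiple of 32 and divide K of each binarized matmul for the packed
+    # Triton path (pack_signs and _TritonXNORFn assert this); q_sd pads columns at
+    # save time. The best runs described in the README used 256.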
+ # Controls activation binarization: 0 = none (BWN), 1 = all (full XNOR), 2 = skip MLP down proj + binarize_activations = _e("BINARIZE_ACTIVATIONS", 1, int) + # Use INT8 tensor cores for binary matmul (speed test) + use_int8_kernel = _e("USE_INT8_KERNEL", 0, bool) + # Use Triton XNOR+POPCOUNT kernel (true 1-bit matmul) + use_triton_kernel = _e("USE_TRITON_KERNEL", 0, bool) + # FP8 QAT for non-binary params + _fp_raw = os.environ.get("FP_STORAGE", "FP8") + fp_storage = True if _fp_raw == "FP8" else False + # Scale storage for binary weight groups: FP8 or BF16 + _scale_raw = os.environ.get("SCALE_STORAGE", "BF16") + scale_fp8 = True if _scale_raw == "FP8" else False + # Compression regularizer: penalizes sign entropy within groups + sign_compress_reg = _e("SIGN_COMPRESS_REG", 0.0, float) + # EGGROLL + pop_size = _e("POP_SIZE", 256, int) # population size (antithetic pairs = pop_size/2) + eggroll_sigma = _e("EGGROLL_SIGMA", 0.01, float) # perturbation scale + eggroll_lr = _e("EGGROLL_LR", 0.001, float) # learning rate + eggroll_rank = _e("EGGROLL_RANK", 1, int) # rank of perturbation matrices + fitness_shaping = _e("FITNESS_SHAPING", "rank") # "rank" or "sign" + eggroll_load = _e("EGGROLL_LOAD", "") # path to pretrained .ptz to load + eggroll_layers = _e("EGGROLL_LAYERS", 0, int) # 0 = all layers, N = last N layers only + # Checkpointing + checkpoint_every = _e("CHECKPOINT_EVERY", 0, int) + checkpoint_dir = _e("CHECKPOINT_DIR", "./checkpoints") + # Diagnostics + churn_log_every = _e("CHURN_LOG_EVERY", 500, int) + +# Scalar/low-dim parameter names (excluded from Muon, use Adam) +CTP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain", "skip_weight", + "skip_weights", "vocab_bias", "smear") + +# --------------------------------------------------------------------------- +# Binary packing (1 bit per weight, lossless) +# --------------------------------------------------------------------------- +def pack_binary(q: Tensor) -> tuple[bytes, int]: + bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) + n = len(bits) + pad = (8 - n % 8) % 8 + if pad: + bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) + groups = bits.reshape(-1, 8) + packed = np.zeros(len(groups), dtype=np.uint8) + for i in range(8): + packed |= groups[:, i] << i + return packed.tobytes(), n + +def unpack_binary(data: bytes, n: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8) + bits = np.zeros((len(packed), 8), dtype=np.int8) + for i in range(8): + bits[:, i] = (packed >> i) & 1 + flat = bits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) * 2 - 1) + +# --------------------------------------------------------------------------- +# State dict serialization +# --------------------------------------------------------------------------- +def q_sd(state_dict: dict, group_size: int = 128, fp_storage=False, scale_fp8=False) -> tuple[dict, dict]: + """Binary pack large 2D weight matrices, BF16/FP8 for everything else.""" + quantized = {} + stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + t_orig_shape = list(t.shape) + if t.ndim == 3: + t = t.reshape(t.shape[0], -1) + is_binary = ( + t.ndim == 2 and t.numel() > 65_536 + and "tok_emb" not in name and "lm_head" not in name + and "embed_proj" not in name + ) + if is_binary: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = 
t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) + q = torch.where(t_grouped >= 0, + torch.ones_like(t_grouped, dtype=torch.int8), + -torch.ones_like(t_grouped, dtype=torch.int8)) + packed_bytes, n_bits = pack_binary(q) + if scale_fp8: + scale_stored = scale.to(torch.float8_e4m3fn).squeeze(-1) + scale_bytes_per = 1 + else: + scale_stored = scale.bfloat16().squeeze(-1) + scale_bytes_per = 2 + quantized[name] = { + "type": "binary", "packed": packed_bytes, + "scale": scale_stored, + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_bits": n_bits, + "orig_shape": t_orig_shape, + } + stats["binary_params"] += t.numel() + stats["binary_bytes"] += len(packed_bytes) + scale.numel() * scale_bytes_per + elif fp_storage and t.ndim == 2: + quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() + else: + quantized[name] = {"type": "bf16", "data": t.bfloat16()} + stats["fp_params"] += t.numel() + stats["fp_bytes"] += t.numel() * 2 + return quantized, stats + +def deq_sd(quantized: dict, target_dtype=torch.bfloat16): + """Reconstruct state dict from quantized representation.""" + out = {} + for name, entry in quantized.items(): + if entry["type"] == "binary": + q = unpack_binary(entry["packed"], entry["n_bits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + # No shrinkage correction: binary q_absmean = 1.0 always + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + result = t[:shape[0], :shape[1]].to(target_dtype) + orig = entry.get("orig_shape") + out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() + elif entry["type"] == "fp8": + out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- +_prev_committed: dict = {} +def churn_fn(model: nn.Module, group_size: int = 128): + global _prev_committed + total = flipped = 0 + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.ndim == 2 and "weight" in name and p.shape[0] > 1: + w = p.detach().float().reshape(-1, group_size) + q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() + if name in _prev_committed: + flipped += int(np.sum(q != _prev_committed[name])) + total += q.size + _prev_committed[name] = q + return flipped / max(total, 1) + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Muon optimizer +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# EGGROLL: Evolution-Guided optimization with low-rank perturbations +# --------------------------------------------------------------------------- +def get_binary_params(model, num_layers=10, last_n_layers=0): + """Get list of (name, param) for parameters that will be perturbed. 
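+    Selection mirrors the binarization filter in q_sd: 2-D, more than 65536
+    elements, excluding tok_emb / lm_head / embed_proj.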
+    last_n_layers=0 means all layers, N means only the last N blocks."""
+    params = []
+    for name, p in model.named_parameters():
+        if p.ndim == 2 and p.numel() > 65536 and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name:
+            if last_n_layers > 0:
+                # Only include params from last N blocks
+                match = False
+                for layer_idx in range(num_layers - last_n_layers, num_layers):
+                    if f"blocks.{layer_idx}." in name:
+                        match = True
+                        break
+                if not match:
+                    continue
+            params.append((name, p))
+    return params
+
+def apply_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Apply rank-r perturbation to binary params using deterministic RNG.
+    One generator per seed, consumed across ALL params in order; the update
+    must replay the draws in exactly this order to reconstruct the directions."""
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    for name, p in binary_params:
+        m, n = p.shape
+        for r in range(rank):
+            a = torch.randn(m, device=device, generator=rng)
+            b = torch.randn(n, device=device, generator=rng)
+            p.data.add_(a.outer(b), alpha=sign * sigma / math.sqrt(rank))
+
+def remove_perturbation(binary_params, seed, sigma, rank, sign=1.0, device="cuda"):
+    """Remove previously applied perturbation (subtract it)."""
+    apply_perturbation(binary_params, seed, sigma, rank, sign=-sign, device=device)
+
+def compute_eggroll_update(binary_params, seeds, fitnesses, sigma, rank, lr, device="cuda"):
+    """Compute and apply the EGGROLL weight update.
+    The outer loop is over seeds so each generator is consumed across params in
+    the same order as apply_perturbation; re-seeding per param would replay the
+    first param's draws and update every later matrix with the wrong direction."""
+    N = len(seeds)
+    updates = {name: torch.zeros_like(p.data) for name, p in binary_params}
+    for seed, f in zip(seeds, fitnesses):
+        rng = torch.Generator(device=device)
+        rng.manual_seed(seed)
+        for name, p in binary_params:
+            m, n = p.shape
+            for r in range(rank):
+                a = torch.randn(m, device=device, generator=rng)
+                b = torch.randn(n, device=device, generator=rng)
+                updates[name].add_(a.outer(b), alpha=f / (math.sqrt(rank) * N))
+    for name, p in binary_params:
+        p.data.add_(updates[name], alpha=lr / sigma)
+
+def shape_fitnesses(fitnesses_pos, fitnesses_neg, method="rank"):
+    """Shape fitnesses from antithetic pairs."""
+    N = len(fitnesses_pos)
+    if method == "sign":
+        return [float(fp > fn) - float(fn > fp) for fp, fn in zip(fitnesses_pos, fitnesses_neg)]
+    # Rank-based shaping: combine all, rank, normalize to [-0.5, 0.5]
+    all_f = list(fitnesses_pos) + list(fitnesses_neg)
+    ranked = torch.argsort(torch.argsort(torch.tensor(all_f, dtype=torch.float32))).float()
+    ranked = ranked / (2 * N - 1) - 0.5
+    # Return shaped fitness for positive perturbation (subtract negative's)
+    shaped = [(ranked[i] - ranked[N + i]).item() for i in range(N)]
+    return shaped
+
+# ---------------------------------------------------------------------------
+# Data loading
+# ---------------------------------------------------------------------------
+def ld_shard(file: Path) -> Tensor:
+    # Shard layout (standard fineweb .bin): 256 int32 header with the token
+    # count at index 2, followed by the uint16 token payload.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.uint16))
+
+class TokenStream:
+    """Sequential token stream over sorted shard files, wrapping at the end."""
+    def __init__(self, pattern: str):
+        self.files = sorted(glob.glob(pattern))
+        assert self.files, f"No files: {pattern}"
+        self.file_idx = 0
+        self.tokens = ld_shard(Path(self.files[0]))
+        self.pos = 0
+    def _advance_file(self):
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = ld_shard(Path(self.files[self.file_idx]))
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos:self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = 
self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x, y + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def apply_qat_ste(w: Tensor, fp_storage) -> Tensor: + if not fp_storage: + return w + w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) + return (w_sim - w).detach() + w + +class QATLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=False, fp_storage=False): + super().__init__(in_features, out_features, bias=bias) + self.fp_storage = fp_storage + def forward(self, x: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +class QATEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, fp_storage=False): + super().__init__(num_embeddings, embedding_dim) + self.fp_storage = fp_storage + def forward(self, input: Tensor) -> Tensor: + w = apply_qat_ste(self.weight, self.fp_storage) + return F.embedding(input, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + +# --------------------------------------------------------------------------- +# XNOR-Net Linear Layers +# --------------------------------------------------------------------------- +_INT8_KERNEL = False # set from Hyperparameters.use_int8_kernel in main() + +class _INT8XNORFn(torch.autograd.Function): + """INT8 tensor core forward, BF16 backward. + Uses per-row alpha (not per-group) so we get a single torch._int_mm call. 
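+    Enabled via USE_INT8_KERNEL (a speed experiment); the coarser per-row alpha
+    means it is not bit-identical to the per-group STE path.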
+ """ + @staticmethod + def forward(ctx, x, w): + # Per-row weight scale (single matmul) + w_alpha = w.detach().abs().mean(dim=-1).clamp_(min=1e-8) # [N] + x_beta = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) # [..., 1] + + w_int8 = w.detach().sign().to(torch.int8) # [N, K] + x_int8 = x.detach().sign().to(torch.int8) # [..., K] + + # Flatten to 2D for torch._int_mm(A[M,K], B[K,N]) -> [M,N] + orig_shape = x.shape[:-1] + M, K, N = orig_shape.numel(), x.shape[-1], w.shape[0] + x_2d = x_int8.reshape(M, K) + wt = w_int8.t().contiguous() # [K, N] + raw = torch._int_mm(x_2d, wt) # [M, N] INT32 + + y = raw.to(x.dtype) * x_beta.reshape(M, 1) * w_alpha.unsqueeze(0) + ctx.save_for_backward(x, w, w_alpha, x_beta) + return y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta = ctx.saved_tensors + dtype = grad_output.dtype + + w_binary = (w.sign() * w_alpha.unsqueeze(-1)).to(dtype) # [N, K] + x_binary = (x.sign() * x_beta).to(dtype) # [..., K] + + # STE: gradient flows through sign() as identity + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w + +# --------------------------------------------------------------------------- +# Triton XNOR+POPCOUNT kernel — true 1-bit matmul +# --------------------------------------------------------------------------- +_TRITON_KERNEL = False # set from Hyperparameters in main() +_SCALE_FP8_STE = False # simulate FP8 scale quantization during training + +def pack_signs(x: Tensor) -> Tensor: + """Pack sign bits into int32. +1 -> bit 1, -1 -> bit 0. + x: [..., K] any dtype. K must be divisible by 32. + returns: [..., K//32] int32 + """ + K = x.shape[-1] + assert K % 32 == 0, f"K={K} must be divisible by 32" + bits = (x > 0).view(*x.shape[:-1], K // 32, 32) + shifts = torch.arange(32, device=x.device, dtype=torch.int32) + return (bits.to(torch.int32) << shifts).sum(dim=-1) + +if HAS_TRITON: + @triton.jit + def _xnor_matmul_fwd_kernel( + X_ptr, W_ptr, Y_ptr, alpha_ptr, beta_ptr, + M, N, K_packed, GROUP_PACKED: tl.constexpr, num_groups, + stride_xm, stride_wn, stride_yn, stride_alpha_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + num_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_n + pid_n = pid % num_n + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_n < N + + y_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.bfloat16) + + x_base = X_ptr + offs_m * stride_xm + w_base = W_ptr + offs_n * stride_wn + + # Loop over groups, each with GROUP_PACKED packed int32s + for g in range(num_groups): + group_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + k_start = g * GROUP_PACKED + for k in range(GROUP_PACKED): + x = tl.load(x_base + k_start + k, mask=mask_m, other=0) + w = tl.load(w_base + k_start + k, mask=mask_n, other=0) + # Cast broadcast XOR to int32 to prevent int64 promotion in popc + diff = (x[:, None] ^ w[None, :]).to(tl.int32) + group_acc += tl.extra.cuda.libdevice.popc(diff) + + # diff_bits = popc(XOR), dot = GROUP_SIZE - 2 * diff_bits + group_dot = (GROUP_PACKED * 32 - 2 * group_acc).to(tl.bfloat16) + + # Per-group alpha: alpha[n, g] + alpha_g = tl.load(alpha_ptr + offs_n * stride_alpha_n + g, mask=mask_n, other=1.0).to(tl.bfloat16) + y_acc += group_dot * alpha_g[None, :] + + # Per-token beta + beta = tl.load(beta_ptr + offs_m, mask=mask_m, other=1.0).to(tl.bfloat16) + y_acc 
= y_acc * beta[:, None] + + # Store + y_ptrs = Y_ptr + offs_m[:, None] * stride_yn + offs_n[None, :] + tl.store(y_ptrs, y_acc, mask=mask_m[:, None] & mask_n[None, :]) + +class _TritonXNORFn(torch.autograd.Function): + """Triton XNOR+POPCOUNT forward with per-group alpha, BF16 backward.""" + @staticmethod + def forward(ctx, x, w, group_size): + N, K = w.shape + assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}" + assert K % 32 == 0, f"K={K} must be divisible by 32" + num_groups = K // group_size + GROUP_PACKED = group_size // 32 + + # Per-group weight scale: [N, num_groups] + w_alpha = w.detach().reshape(N, num_groups, group_size).abs().mean(dim=-1).clamp_(min=1e-8) + if _SCALE_FP8_STE: + w_alpha = w_alpha.to(torch.float8_e4m3fn).to(w_alpha.dtype) + # Per-token activation scale: [..., 1] + x_beta_kd = x.detach().abs().mean(dim=-1, keepdim=True).clamp_(min=1e-8) + + # Pack signs into int32 + x_packed = pack_signs(x.detach()) # [..., K//32] + w_packed = pack_signs(w.detach()) # [N, K//32] + + orig_shape = x.shape[:-1] + M = orig_shape.numel() + K_packed = K // 32 + + x_2d = x_packed.reshape(M, K_packed).contiguous() + w_alpha_c = w_alpha.contiguous() # [N, num_groups] + beta_flat = x_beta_kd.reshape(M).contiguous() + + Y = torch.empty(M, N, device=x.device, dtype=torch.bfloat16) + + BLOCK_M, BLOCK_N = 64, 64 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + _xnor_matmul_fwd_kernel[grid]( + x_2d, w_packed, Y, w_alpha_c, beta_flat, + M, N, K_packed, GROUP_PACKED, num_groups, + x_2d.stride(0), w_packed.stride(0), Y.stride(0), w_alpha_c.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + ctx.save_for_backward(x, w, w_alpha, x_beta_kd) + ctx.group_size = group_size + return Y.reshape(*orig_shape, N) + + @staticmethod + def backward(ctx, grad_output): + x, w, w_alpha, x_beta_kd = ctx.saved_tensors + dtype = grad_output.dtype + group_size = ctx.group_size + + # Reconstruct per-group binary weights for backward + N, K = w.shape + num_groups = K // group_size + w_binary = (w.sign().reshape(N, num_groups, group_size) * + w_alpha.unsqueeze(-1)).reshape(N, K).to(dtype) + x_binary = (x.sign() * x_beta_kd).to(dtype) + + grad_x = grad_output @ w_binary + M = grad_output.shape[:-1].numel() + grad_w = grad_output.reshape(M, -1).t() @ x_binary.reshape(M, -1) + return grad_x, grad_w, None # None for group_size + +class XNORLinear(nn.Linear): + """XNOR-Net linear layer: binarizes BOTH weights AND activations. + + Forward pass: + 1. Weight binarization: W_bin = sign(W) * alpha, alpha = mean(|W|) per group + 2. Activation binarization: X_bin = sign(X) * beta, beta = mean(|X|) per token + 3. Output = F.linear(X_bin, W_bin) + + Backward pass: Straight-Through Estimator (STE) — gradients flow through as + if the binarization was identity. This matches the XNOR-Net paper's training + algorithm (Algorithm 1). + + When binarize_act=False, this degrades to Binary-Weight-Network (BWN). 
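+
+    Worked example (group_size=256, in_features=1024): each weight row holds
+    four groups, dequantized as concat over g of sign(w_g) * alpha_g, so a
+    1024x1024 projection stores 1,048,576 sign bits plus 4,096 group scales.
+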
+ """ + def __init__(self, in_features, out_features, bias=False, group_size=128, + binarize_act=True): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self.binarize_act = binarize_act + + def forward(self, x: Tensor) -> Tensor: + # Triton XNOR+POPCOUNT path (full XNOR only, per-row alpha) + if _TRITON_KERNEL and self.binarize_act: + y = _TritonXNORFn.apply(x, self.weight, self.group_size) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + # INT8 tensor core path (full XNOR only, per-row alpha) + if _INT8_KERNEL and self.binarize_act: + y = _INT8XNORFn.apply(x, self.weight) + if self.bias is not None: + y = y + self.bias.to(y.dtype) + return y + + # Standard STE path (per-group alpha) + w = self.weight.bfloat16() + g = self.group_size + + # --- Weight binarization (STE) --- + w_g = w.reshape(-1, g) + alpha = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) + # Simulate FP8 scale quantization so model compensates for roundtrip error + if _SCALE_FP8_STE: + alpha_q = alpha.to(torch.float8_e4m3fn).to(alpha.dtype) + alpha = alpha + (alpha_q - alpha).detach() + w_sign = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) + w_binary = w + ((w_sign * alpha).reshape(w.shape) - w).detach() + + # --- Activation binarization (STE) --- + if self.binarize_act: + beta = x.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8) + x_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)) + x_binary = x + ((x_sign * beta) - x).detach() + else: + x_binary = x + + return F.linear(x_binary, w_binary, + self.bias.to(x.dtype) if self.bias is not None else None) + + +class NormedXNORLinear(XNORLinear): + """XNOR linear with RMSNorm on input. + Used for output projections receiving un-normalized activations (attention out, MLP down). + The RMSNorm ensures zero-mean input before binarization — critical for sign() quality + (XNOR-Net paper: B-A-C-P block order, Table 3b: 30.3% vs 44.2%). 
+ """ + def forward(self, x: Tensor) -> Tensor: + return super().forward(F.rms_norm(x, (x.size(-1),))) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: + param.data = param.data.float() + +# --------------------------------------------------------------------------- +# Rotary embeddings +# --------------------------------------------------------------------------- +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, rope_type="rope", yarn_max_len=4096, train_seq_len=1024): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if rope_type == "yarn": + scale = train_seq_len / yarn_max_len + freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) + ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) + inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) + # Precompute for max seq_len, slice at runtime + t = torch.arange(train_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos", freqs.cos()[None, :, None, :], persistent=False) + self.register_buffer("_sin", freqs.sin()[None, :, None, :], persistent=False) + + def forward(self, seq_len, device, dtype): + return self._cos[:, :seq_len].to(dtype=dtype), self._sin[:, :seq_len].to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + group_size=128, rope_type="rope", yarn_max_len=4096, + train_seq_len=1024, binarize_act=True): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Fused QKV — single XNOR matmul + self.c_qkv = XNORLinear(dim, self.q_size + 2 * self.kv_size, bias=False, + group_size=group_size, binarize_act=binarize_act) + # Output projection with RMSNorm before binarization (B-A-C-P order) + self.proj = NormedXNORLinear(dim, dim, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, + rope_type=rope_type, yarn_max_len=yarn_max_len, + train_seq_len=train_seq_len) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + qkv_out = self.c_qkv(x) + q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + # QK normalization (BF16, stays float — non-negotiable for attention quality) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # FlashAttention-3 (BF16 — attention scores/softmax/V are NEVER binarized) + y = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(), causal=True) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def __init__(self, dim, mlp_mult, group_size=128, activation="relu2", + binarize_act=True, binarize_down=True): + super().__init__() + hidden = mlp_mult * dim + self.activation = activation + if activation == "swiglu": + self.gate_up = XNORLinear(dim, hidden * 2, bias=False, + group_size=group_size, binarize_act=binarize_act) + else: + self.fc = XNORLinear(dim, hidden, bias=False, + group_size=group_size, binarize_act=binarize_act) + self.proj = NormedXNORLinear(hidden, dim, bias=False, + group_size=group_size, binarize_act=binarize_down) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "swiglu": + gu = self.gate_up(x) + gate, up = gu.chunk(2, dim=-1) + return self.proj(F.silu(gate) * up) + elif self.activation == "signsq": + h = self.fc(x) + return self.proj(h * h.abs()) + else: # relu2 + return self.proj(torch.relu(self.fc(x)).square()) + +# --------------------------------------------------------------------------- +# SmearModule +# --------------------------------------------------------------------------- +class SmearModule(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + cumsum = x.cumsum(dim=1) + counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) + smeared = cumsum / counts + gate = torch.tanh(self.gate.to(dtype=x.dtype)) + return x + gate * (smeared - x) + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size=128, activation="relu2", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, group_size, rope_type, + yarn_max_len, train_seq_len, binarize_act) + self.mlp = MLP(dim, mlp_mult, group_size, activation, binarize_act, binarize_down) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearModule(dim) if smear else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0] * x + mix[1] * x0 + n = self.attn_norm(x) + x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) + x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) + if self.smear is not None: + x = self.smear(x) + return x + +# --------------------------------------------------------------------------- +# GPT with U-Net skip connections +# --------------------------------------------------------------------------- +class 
GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, tied_embed_init_std, logit_softcap, + rope_base, qk_gain_init, group_size=128, activation="relu2", + embed_dim=0, fp_storage=False, softcap_type="poly", + rope_type="rope", yarn_max_len=4096, train_seq_len=1024, + binarize_act=True, binarize_down=True, smear=False): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.softcap_type = softcap_type + self.fp_storage = fp_storage + self.embed_dim = embed_dim if embed_dim > 0 else model_dim + + # Embedding (FP8 QAT — NEVER binarized, confirmed 0.11 bpb RT gap) + self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) + # Projection from embed_dim to model_dim (and reverse for tied logits) + self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, + fp_storage=fp_storage) if self.embed_dim != model_dim else None + + # U-Net structure + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + # Skip weights initialized to ONES (not zeros — zeros costs 0.010 bpb) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Transformer blocks + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, group_size, activation, rope_type, yarn_max_len, + train_seq_len, binarize_act, binarize_down, smear) + for _ in range(num_layers) + ]) + self.final_norm = RMSNorm() + + # Logit head (tied to embedding) + self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) + self.lm_head._zero_init = True + if tie_embeddings: + self.lm_head.weight.requires_grad_(False) + self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) + + self._init_weights(tied_embed_init_std) + + def _init_weights(self, tied_embed_init_std): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) + for module in self.modules(): + if isinstance(module, XNORLinear) and not getattr(module, "_zero_init", False): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + proj = self.embed_proj_rev(x) if self.embed_proj_rev is not None else x + return F.linear(proj, self.tok_emb.weight.to(x.dtype)) + return self.lm_head(x) + + def _softcap(self, logits: Tensor) -> Tensor: + s = self.logit_softcap + if self.softcap_type == "tanh": + return s * torch.tanh(logits / s) + # Poly softcap — fuses with torch.compile (tanh does not) + x_sc = torch.clamp(logits / s, -2.0, 2.0) + x2 = x_sc * x_sc + return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction="mean", temperature=1.0): + x = self.tok_emb(input_ids) + x = x.float() + if self.embed_proj is not None: + x = self.embed_proj(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # U-Net encoder + skips = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + # U-Net decoder with skip 
connections + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() + x = self.blocks[bi](x, x0) + + x_normed = self.final_norm(x) + x_flat = x_normed.reshape(-1, x_normed.size(-1)) + targets = target_ids.reshape(-1) + logits = self._softcap(self._compute_logits(x_flat)) + + if temperature != 1.0: + logits = logits / temperature + if reduction == "none": + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) + + # Fused CE + Z-loss (logsumexp² regularization keeps STE gradients sharp) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) + return (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- +def build_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() + if max_tok > 0: tok = tok[:max_tok + 1] + u = ((tok.numel() - 1) // seq_len) * seq_len + return tok[:u + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature=1.0): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_start in range(seq_start, seq_end, local_batch_seqs): + batch_end = min(batch_start + local_batch_seqs, seq_end) + raw_start = batch_start * args.train_seq_len + raw_end = batch_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch_loss = model(x, y, temperature=temperature).detach() + n = 
float(y.numel()) + loss_sum += batch_loss.to(torch.float64) * n + token_count += n + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) + tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, temperature=1.0): + seq_len = args.train_seq_len + batch_size = args.sliding_batch_size + total_tokens = val_tokens.numel() - 1 + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_starts = list(range(0, total_tokens - seq_len, stride)) + my_starts = all_starts[rank::world_size] + model.eval() + with torch.inference_mode(): + for i in range(0, len(my_starts), batch_size): + batch_starts = my_starts[i:i + batch_size] + starts_t = torch.tensor(batch_starts, dtype=torch.int64) + offsets = torch.arange(seq_len + 1, dtype=torch.int64) + indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) + local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local_batch[:, :-1], local_batch[:, 1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() + for b, start in enumerate(batch_starts): + score_from = 0 if start == 0 else seq_len - stride + scored = per_token_loss[b, score_from:] + sx, sy = x[b, score_from:], y[b, score_from:] + loss_sum += scored.to(torch.float64).sum() + token_count += scored.numel() + tok_bytes = base_bytes_lut[sy].to(torch.int16) + tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) + byte_count += tok_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, token_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + val_loss = loss_sum / token_count + bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) + model.train() + return float(val_loss.item()), float(bpb) + +def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + best_t, best_loss = 1.0, float("inf") + for t in [0.85, 0.90, 0.95, 1.00, 1.05, 1.10]: + loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, + calibration_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, temperature=t) + if loss < best_loss: + best_loss = loss + best_t = t + return best_t + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- +def _ckpt_path(d, step): return os.path.join(d, f"ckpt_step{step:07d}.pt") + +def _latest_checkpoint(d): + if not os.path.isdir(d): return None + ckpts = sorted(glob.glob(os.path.join(d, "ckpt_step*.pt"))) + return ckpts[-1] if 
ckpts else None + +def save_checkpoint(d, step, model, optimizers, ms, _untied, ema_model, _ema_started, _ema_steps): + os.makedirs(d, exist_ok=True) + path = _ckpt_path(d, step) + tmp = path + ".tmp" + torch.save({ + "step": step, "training_time_ms": ms, "model": model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "rng_cpu": torch.get_rng_state(), "rng_cuda": torch.cuda.get_rng_state(), + "flags": {"_untied": _untied, "_ema_started": _ema_started, "_ema_steps": _ema_steps}, + "ema_model": ema_model.state_dict() if ema_model is not None else None, + }, tmp) + os.replace(tmp, path) + +def load_checkpoint(path, model, optimizers, device, ema_model=None): + p = torch.load(path, map_location=device, weights_only=False) + model.load_state_dict(p["model"], strict=True) + for opt, sd in zip(optimizers, p["optimizers"]): + opt.load_state_dict(sd) + if ema_model is not None and p.get("ema_model") is not None: + ema_model.load_state_dict(p["ema_model"], strict=True) + torch.set_rng_state(p["rng_cpu"].cpu()) + torch.cuda.set_rng_state(p["rng_cuda"].cpu()) + f = p.get("flags", {}) + return p["step"], p["training_time_ms"], f.get("_untied", False), f.get("_ema_started", False), f.get("_ema_steps", 0) + +# --------------------------------------------------------------------------- +# Main training +# --------------------------------------------------------------------------- +def main(): + args = Hyperparameters() + code = Path(__file__).read_text(encoding="utf-8") + + global _INT8_KERNEL + _INT8_KERNEL = args.use_int8_kernel + + global _TRITON_KERNEL + _TRITON_KERNEL = args.use_triton_kernel and HAS_TRITON + + global _SCALE_FP8_STE + _SCALE_FP8_STE = args.scale_fp8 + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = world_size > 1 + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + os.makedirs("logs/cuda/xnor/", exist_ok=True) + logfile = f"logs/cuda/xnor/{args.run_id}.txt" if master_process else None + if master_process: print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = ld_val(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts(sp, args.vocab_size, device) + + # --- Model --- + # BINARIZE_ACTIVATIONS: 0=BWN, 1=full XNOR, 2=XNOR but BWN on MLP down proj + _ba = args.binarize_activations + _binarize_act = _ba >= 1 + _binarize_down = _ba == 1 + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + group_size=args.xnor_group_size, activation=args.activation_type, + embed_dim=args.embed_dim, fp_storage=args.fp_storage, + softcap_type=args.softcap_type, rope_type=args.rope_type, + yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, + binarize_act=_binarize_act, binarize_down=_binarize_down, + smear=args.smear, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, nn.Linear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.tie_embeddings: + base_model.lm_head.weight.requires_grad_(False) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.recompile_limit = 16 + # No DDP needed for EGGROLL - each GPU evaluates different perturbations + # No torch.compile - forward only, perturbations modify weights in-place + + # --- Load pretrained weights --- + if args.eggroll_load: + log0(f"loading pretrained weights from {args.eggroll_load}") + if args.eggroll_load.endswith(".ptz"): + # Compressed artifact + with open(args.eggroll_load, "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + else: + # Raw checkpoint + ckpt = torch.load(args.eggroll_load, map_location="cpu", weights_only=False) + sd = ckpt["model"] if "model" in ckpt else ckpt + base_model.load_state_dict(sd, strict=False) + log0("pretrained weights loaded") + + # --- EGGROLL setup --- + # No gradients needed - forward only + for p in base_model.parameters(): + p.requires_grad_(False) + binary_params = get_binary_params(base_model, num_layers=args.num_layers, + last_n_layers=args.eggroll_layers) + log0(f"EGGROLL binary params: {len(binary_params)} layers, " + f"{sum(p.numel() for _,p in binary_params)} params") + + # --- Log --- + log0("--- Hyperparameters ---", console=False) + log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) + if not a.startswith("_") and a not in ("train_files","val_files") + and not callable(getattr(args,a))), console=False) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"XNOR-EGGROLL | params:{n_params} L:{args.num_layers} d:{args.model_dim} " + f"h:{args.num_heads} kv:{args.num_kv_heads} mlp:{args.mlp_mult}x " + f"act:{args.activation_type} xnor_act:{args.binarize_activations} " + f"g:{args.xnor_group_size} ws:{world_size} pop:{args.pop_size} " + f"sigma:{args.eggroll_sigma} rank:{args.eggroll_rank}") + + # --- Data --- + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + import math + if args.lr_schedule == "cosine": + if max_wallclock_ms is None: + total = args.iterations + warmup = min(args.lr_warmup_steps, total) + if step < warmup: + return step / max(warmup, 1) + progress = (step - warmup) / max(total - warmup, 1) + else: + warmup_ms = args.lr_warmup_steps * (elapsed_ms / max(step, 1)) if step > 0 else 1000.0 + if elapsed_ms < warmup_ms: + return elapsed_ms / max(warmup_ms, 1e-9) + progress = (elapsed_ms - 
warmup_ms) / max(max_wallclock_ms - warmup_ms, 1e-9) + return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) + if args.warmdown_fraction <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) + return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 + warmdown_ms = max_wallclock_ms * args.warmdown_fraction + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # --- Seq len scheduling --- + seq_len_stages = [args.train_seq_len] + if args.seq_len_schedule: + seq_len_stages = [] + sl = 128 + while sl <= args.train_seq_len: + seq_len_stages.append(sl) + sl *= 2 + if seq_len_stages[-1] != args.train_seq_len: + seq_len_stages.append(args.train_seq_len) + log0(f"seq_len_schedule: {seq_len_stages}") + + def get_seq_len(step, elapsed_ms): + if len(seq_len_stages) <= 1: + return seq_len_stages[0] + n = len(seq_len_stages) + if max_wallclock_ms is not None: + stage_idx = min(int(elapsed_ms / (max_wallclock_ms / n)), n - 1) + else: + stage_idx = min(int(step / (args.iterations / n)), n - 1) + return seq_len_stages[stage_idx] + + # --- Main EGGROLL loop --- + training_time_ms = 0.0 + stop_after_step = None + pop_per_gpu = args.pop_size // world_size + assert pop_per_gpu % 2 == 0, "pop_size / world_size must be even for antithetic pairs" + pairs_per_gpu = pop_per_gpu // 2 + + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + base_model.eval() + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + lr_scale = lr_mul(step, elapsed_ms) + current_seq_len = get_seq_len(step, elapsed_ms) + + # Get data batch (same for all perturbations) + x, y = train_loader.next_batch(args.train_batch_tokens, current_seq_len, 1) + + # Evaluate perturbations (antithetic pairs) + fitnesses_pos = [] + fitnesses_neg = [] + local_seeds = [] + + for i in range(pairs_per_gpu): + global_idx = step * (args.pop_size // 2) + rank * pairs_per_gpu + i + seed = global_idx * 137 + 42 # deterministic seed + + # Positive perturbation + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss_pos = base_model(x, y) + fitnesses_pos.append(-loss_pos.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=+1.0, device=device) + + # Negative perturbation (antithetic) + apply_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + with torch.no_grad(): + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16): + loss_neg = base_model(x, y) + fitnesses_neg.append(-loss_neg.item()) + remove_perturbation(binary_params, seed, args.eggroll_sigma, args.eggroll_rank, sign=-1.0, device=device) + + local_seeds.append(seed) + + # Gather fitnesses across GPUs + if distributed: + all_pos = [None] * world_size + all_neg = [None] * world_size + all_seeds_list = [None] * world_size + dist.all_gather_object(all_pos, fitnesses_pos) + dist.all_gather_object(all_neg, fitnesses_neg) + dist.all_gather_object(all_seeds_list, local_seeds) + fitnesses_pos = [f for sublist in all_pos for f in sublist] + fitnesses_neg = [f for sublist in all_neg for f in sublist] + all_seeds = [s for sublist in all_seeds_list for s in sublist] + else: + all_seeds = local_seeds + + # Shape fitnesses + shaped = shape_fitnesses(fitnesses_pos, fitnesses_neg, method=args.fitness_shaping) + + # Update weights + lr = args.eggroll_lr * lr_scale + compute_eggroll_update(binary_params, all_seeds, shaped, args.eggroll_sigma, + args.eggroll_rank, lr, device=device) + + step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.train_log_every > 0 and step % args.train_log_every == 0: + mean_fitness = sum(fitnesses_pos) / len(fitnesses_pos) + log0(f"step:{step}/{args.iterations} fitness:{mean_fitness:.4f} " + f"t:{approx_ms:.0f}ms step_avg:{approx_ms/step:.1f}ms") + + # Wallclock cap + if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: + reached_cap = approx_ms >= max_wallclock_ms + if distributed: + cap_t = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) + reached_cap = bool(cap_t.item()) + if reached_cap: + stop_after_step = step + + # --- Sign-sort permutation for compression --- + # Permute MLP hidden dimension so same-sign weights cluster together. + # Sort columns of down_proj by their sign pattern (lexicographic on sign bits). + # Apply matching row permutation to up_proj. Model output is identical. 
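+    # Toy example: columns of w_down with sign patterns (+,-), (-,-), (+,+)
+    # get keys 0b10, 0b00, 0b11, so perm = [1, 0, 2] and identical or similar
+    # columns land next to each other -- longer byte runs for the entropy coder.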
+ if master_process: + sd = base_model.state_dict() + for layer_idx in range(args.num_layers): + # Find the MLP weight pairs + down_key = f"blocks.{layer_idx}.mlp.proj.weight" + if args.activation_type == "swiglu": + up_key = f"blocks.{layer_idx}.mlp.gate_up.weight" + else: + up_key = f"blocks.{layer_idx}.mlp.fc.weight" + if down_key not in sd or up_key not in sd: + continue + w_down = sd[down_key] # [dim, hidden] + w_up = sd[up_key] # [hidden, dim] (or [hidden*2, dim] for swiglu) + hidden = w_down.shape[1] + # Sort hidden dim columns of w_down by sign pattern + signs = (w_down >= 0).to(torch.int8) # [dim, hidden] + # Pack sign columns into sortable keys (use first 64 rows as sort key) + n_key_rows = min(64, signs.shape[0]) + sort_keys = torch.zeros(hidden, dtype=torch.int64, device=w_down.device) + for r in range(n_key_rows): + sort_keys += signs[r, :].long() << (n_key_rows - 1 - r) + perm = sort_keys.argsort() + # Apply permutation + sd[down_key] = w_down[:, perm] + if args.activation_type == "swiglu": + half = w_up.shape[0] // 2 + sd[up_key] = torch.cat([w_up[:half][perm], w_up[half:][perm]], dim=0) + else: + sd[up_key] = w_up[perm] + if base_model.tie_embeddings: + sd.pop("lm_head.weight", None) + q_obj, q_stats = q_sd(sd, group_size=args.xnor_group_size, + fp_storage=args.fp_storage, scale_fp8=args.scale_fp8) + buf = io.BytesIO() + torch.save(q_obj, buf) + raw_bytes = buf.getvalue() + # Compare LZMA vs Brotli, keep smallest + lzma_blob = lzma.compress(raw_bytes, preset=9) + best_blob, best_method = lzma_blob, "lzma" + comp_log = f"compression: lzma={len(lzma_blob)/1e6:.2f}MB" + if HAS_BROTLI: + brotli_blob = brotli.compress(raw_bytes, quality=11) + comp_log += f" brotli={len(brotli_blob)/1e6:.2f}MB" + if len(brotli_blob) < len(best_blob): + best_blob, best_method = brotli_blob, "brotli" + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + zstd_blob = cctx.compress(raw_bytes) + comp_log += f" zstd={len(zstd_blob)/1e6:.2f}MB" + if len(zstd_blob) < len(best_blob): + best_blob, best_method = zstd_blob, "zstd" + log0(comp_log) + log0(f"using: {best_method}") + with open("final_model.xnor.ptz", "wb") as f: + # Method byte: 0=lzma, 1=brotli, 2=zstd + method_id = {"lzma": 0, "brotli": 1, "zstd": 2}[best_method] + f.write(bytes([method_id])) + f.write(best_blob) + artifact_bytes = 1 + len(best_blob) + code_bytes = len(code.encode("utf-8")) + total = artifact_bytes + code_bytes + log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) " + f"fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") + log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) " + f"{'FITS' if total <= 16000000 else 'OVER'}") + + # --- Roundtrip eval --- + if distributed: dist.barrier() + with open("final_model.xnor.ptz", "rb") as f: + raw = f.read() + method_byte = raw[0] + compressed = raw[1:] + if method_byte == 2 and HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + elif method_byte == 1 and HAS_BROTLI: + decompressed = brotli.decompress(compressed) + else: + decompressed = lzma.decompress(compressed) + loaded = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False) + base_model.load_state_dict(deq_sd(loaded), strict=False) + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_xnor_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f}") + + opt_temp = 1.0 + if args.temp_scaling: + torch.cuda.synchronize() + t_temp = time.perf_counter() + calibration_tokens = train_loader.stream.take(65536).to(device) + opt_temp = find_temp(args, base_model, rank, world_size, device, 1, + calibration_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"temp_scaling optimal_T:{opt_temp:.2f} time:{1000*(time.perf_counter()-t_temp):.0f}ms") + + if args.sliding_eval: + torch.cuda.synchronize() + t_sl = time.perf_counter() + sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, 1, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, stride=args.sliding_eval_stride, + temperature=opt_temp) + torch.cuda.synchronize() + log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " + f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) time:{1000*(time.perf_counter()-t_sl):.0f}ms") + + if distributed: dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file