Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Medusa: Unstable S2 — DeltaNet Crawler, Legal Resubmission

**val_bpb: 0.8822** (3-seed mean) | **~9.9MB** | 8xH100 SXM

Legal resubmission of PR #1028 (Medusa: Unstable, mean 0.9984 BPB).

**Legality fix:** PR #1028 was flagged because `gptq_calibrate_loop_aware()` reads 256 batches from training data after the 600s wallclock cap fires. Fix: `GPTQ_RESERVE_MS=30000` stops the training loop 30s early (~570s) so GPTQ calibration (~12s) completes within the budget. The log prints elapsed time at GPTQ start for reviewer verification:
```
stopping_early: wallclock_cap train_time:570052ms step:4642/20000
gptq:loop-aware calibrated 41 layers in 11.4s
```

All hyperparameters are identical to PR #1028 / Medusa_IV.

## Results

| Seed | BPB (sliding window) | Post-EMA BPB | Int6 Roundtrip | Steps |
|------|--------------------:|-------------:|---------------:|------:|
| 300 | 1.0251 | 0.6484 | 0.8987 | 4628 |
| 444 | 0.8469 | 0.4330 | 0.7159 | 4616 |
| 4 | **0.7744** | 0.4339 | 0.6271 | 4642 |
| **Mean** | **0.8822** | | | |
| **Std dev** | **~0.105** | | | |

3-seed mean improved from 0.9984 (PR #1028) to 0.8822 with the timing fix.

## Architecture

- **Topology**: 4 flat layers + 1 crawler layer × 4 loops (Frugendorff compression)
- **INST_DIM**: 32 (flow instructions)
- **DeltaNet**: 4 heads, canonical `chunk_delta_rule` from `fla.ops.delta_rule`
- **Quantization**: int6+zstd + CRAWLER_QUANT_INT8=1, loop-aware 2-phase GPTQ (41 layers)
- **Dims**: XSA_LAST_N=11, BIGRAM_VOCAB_SIZE=2048, ROPE_DIMS=16
- **Schedule**: WARMDOWN_ITERS=2000, SWA_EVERY=50, EMA_START_STEP=4400, EMA_DECAY=0.99
- **GPTQ_RESERVE_MS**: 30000 (training stops at ~570s; GPTQ runs within budget)

## Legality

1. No n-gram eval — sliding window only
2. No val data used during training
3. GPTQ calibration reads training data and runs **inside** the 600s wallclock budget (verified via `gptq:loop-aware calibrated 41 layers in ~11.5s` at ~570s elapsed)
4. Score-first protocol not applicable (no n-gram cache)

## Known Issues

High cross-seed variance (std dev ~0.105) is caused by DeltaNet heads. Two root causes identified:
1. **State dtype bug**: `chunk_delta_rule` returns Float32 `new_state` in BF16 training — causes recompile_limit warnings during eval (does not affect final score, only eval speed). Fix exists in follow-on work.
2. **Quantization unravel**: DeltaNet weight errors compound through 4 crawler loops.

Stabilization is active research.

## Reproduce

```bash
SEED=300 bash experiments/Medusa_Legal_unstable/run.sh
SEED=444 bash experiments/Medusa_Legal_unstable/run.sh
SEED=4 bash experiments/Medusa_Legal_unstable/run.sh
```

8xH100 SXM, 600s training per seed.

## Credits

- **Gated DeltaNet (GDN) — primary catalyst**: @shalyhinpavel (PR #875) — 1.0226 BPB pure neural
- **Canonical DeltaNet kernel**: `fla.ops.delta_rule` (flash-linear-attention)
- **Loop-aware GPTQ + Frugendorff crawler architecture**: @newjordan (PR #990, PR #1028)
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#!/bin/bash
set -euo pipefail
# MEDUSA_LEGAL_UNSTABLE: Medusa_IV config + legality fix for GPTQ timing
#
# PR #1028 (Medusa_IV) was flagged by judges: GPTQ calibration read training
# data AFTER the 600s wallclock cap, which is disallowed.
#
# Fix: GPTQ_RESERVE_MS=30000 — training loop stops 30s early so GPTQ
# calibration (~12s) completes within the 600s budget. The log now prints
# elapsed time at GPTQ start so reviewers can verify.
#
# All other hyperparameters identical to Medusa_IV.

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}/flash-attention/hopper:${PYTHONPATH:-}"

SEED="${SEED:-1337}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
NITRUST_ENABLE="${NITRUST_ENABLE:-0}"
NITRUST_STRICT="${NITRUST_STRICT:-0}"
NITRUST_SO_PATH="${NITRUST_SO_PATH:-Nitrust/rust/target/release/libnitrust_py.so}"

echo "[preflight] checking zstandard..."
python3 -c "import zstandard; print(f' zstandard {zstandard.__version__} OK')" 2>/dev/null \
|| echo " WARNING: zstandard not found"

echo "[preflight] patching torch inductor AttrsDescriptor bug (if present)..."
python3 -c "
import importlib.util, pathlib
spec = importlib.util.find_spec('torch._inductor.runtime.hints')
if spec and spec.origin:
p = pathlib.Path(spec.origin)
txt = p.read_text()
old = 'attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}'
if old in txt:
import attr
new = 'import attr as _attr; attr_desc_fields = {f.name for f in _attr.fields(AttrsDescriptor)}'
p.write_text(txt.replace(old, new))
print(' patched OK')
else:
print(' no patch needed')
" 2>/dev/null || echo " WARNING: could not patch hints.py"

echo "[preflight] checking flash_attn..."
python3 -c "
try:
import flash_attn_interface; print(' FA3 (hopper) OK')
except ImportError:
import flash_attn; v=flash_attn.__version__
if v.startswith('3'): print(f' FA3 v{v} OK')
else: print(f' WARNING: FA{v[0]} detected — want FA3')
" 2>/dev/null || echo " WARNING: no flash_attn found"

echo "[preflight] checking fla.ops.delta_rule (canonical DeltaNet kernel)..."
python3 -c "
from fla.ops.delta_rule import chunk_delta_rule
print(' chunk_delta_rule OK — CANONICAL kernel active')
" 2>/dev/null || echo " WARNING: fla.ops not found — will fall back to Python DeltaNet loop (slow, non-canonical)"

echo "============================================"
echo " MEDUSA_LEGAL_UNSTABLE — Medusa_IV + GPTQ timing fix"
echo " Seed: ${SEED}"
echo " inst_dim=32 FLOW | 4 flat + 1 crawler x 4 loops"
echo " DELTA_NET_HEADS=4 | chunk_delta_rule | short_conv=True"
echo " EMA_START_STEP=4400 | EMA_DECAY=0.99 | LOOP_AWARE_GPTQ=1"
echo " NITRUST_ENABLE=${NITRUST_ENABLE} | NITRUST_STRICT=${NITRUST_STRICT}"
echo "============================================"

SEED="$SEED" \
MAX_WALLCLOCK_SECONDS=600 \
GPTQ_RESERVE_MS=30000 \
WARMDOWN_ITERS=2000 \
COMPLEMENT_ALPHA=0 \
XSA_LAST_N=11 \
BIGRAM_VOCAB_SIZE=2048 \
ROPE_DIMS=16 \
SWA_EVERY=50 \
MTP_NUM_HEADS=0 \
LATE_QAT_THRESHOLD=0 \
MATRIX_LR=0.03 \
TORCHDYNAMO_OPTIMIZE_DDP=0 \
COMPILE_FULLGRAPH=0 \
NGRAM_EVAL_ORDER=0 \
USE_CRAWLER=1 \
NUM_FLAT_LAYERS=4 \
NUM_CRAWLER_LAYERS=1 \
CRAWLER_LOOPS=4 \
INST_DIM=32 \
CRAWLER_QUANT_INT8=1 \
DELTA_NET_HEADS=4 \
EMA_START_STEP=4400 \
EMA_DECAY=0.99 \
LOOP_AWARE_GPTQ=1 \
NITRUST_ENABLE="${NITRUST_ENABLE}" \
NITRUST_STRICT="${NITRUST_STRICT}" \
NITRUST_SO_PATH="${NITRUST_SO_PATH}" \
torchrun --standalone --nproc_per_node="${NPROC_PER_NODE}" \
"${SCRIPT_DIR}/train_gpt.py" \
2>&1 | tee "logs/medusa_legal_s${SEED}_$(date +%Y%m%d_%H%M%S).log"

echo "============================================"
echo " DONE"
echo "============================================"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"author": "Frosty40",
"github_id": "newjordan",
"name": "Medusa: Unstable S2 — DeltaNet Crawler + Loop-Aware GPTQ (Legal)",
"blurb": "Legal resubmission of Medusa: Unstable (PR #1028). Fix: GPTQ_RESERVE_MS=30000 stops training 30s early so GPTQ calibration (~12s) completes within the 600s wallclock budget. All hyperparameters identical to Medusa_IV. DELTA_NET_HEADS=4, chunk_delta_rule, loop-aware 2-phase GPTQ, late-start EMA (step 4400, decay=0.99). 4 flat + 1 crawler x 4 loops, INST_DIM=32. Known: high cross-seed variance from DeltaNet heads; stabilization is active research.",
"date": "2026-03-29",
"seed_300": {
"val_bpb": 0.6191,
"sliding_window_bpb": 1.02508673,
"post_ema_bpb": 0.6484,
"roundtrip_bpb": 0.89869376,
"steps": 4628,
"train_time_s": 570,
"eval_time_s": "~124s"
},
"seed_444": {
"val_bpb": 0.4214,
"sliding_window_bpb": 0.84693639,
"post_ema_bpb": 0.4330,
"roundtrip_bpb": 0.71594908,
"steps": 4616,
"train_time_s": 570,
"eval_time_s": "~106s"
},
"seed_4": {
"val_bpb": 0.4116,
"sliding_window_bpb": 0.77444386,
"post_ema_bpb": 0.4339,
"roundtrip_bpb": 0.62712646,
"steps": 4642,
"train_time_s": 570,
"eval_time_s": "~106s"
},
"val_bpb": 0.88215566,
"bytes_total": 9758873,
"bytes_code": 180983,
"hardware": "8xH100 SXM",
"notes": "Cross-seed variance std dev ~0.105 (vs ClownCar 0.00015). DeltaNet heads introduce seed sensitivity. GPTQ_RESERVE_MS=30000: training stops at ~570s, GPTQ completes at ~582s. Stabilization ongoing."
}
Loading