diff --git a/curriculum/css/components.css b/curriculum/css/components.css new file mode 100644 index 0000000000..4a21768b13 --- /dev/null +++ b/curriculum/css/components.css @@ -0,0 +1,697 @@ +/* ============================================================ + Shared Components + ============================================================ */ + +/* --- Layout Shell --- */ + +.app-layout { + display: flex; + height: 100vh; + overflow: hidden; +} + +/* --- Sidebar --- */ + +.sidebar { + width: var(--sidebar-width); + min-width: var(--sidebar-width); + height: 100vh; + background: var(--pg-bg-elevated); + border-right: 1px solid var(--pg-border); + display: flex; + flex-direction: column; + overflow: hidden; + transition: width var(--transition-base), min-width var(--transition-base); + z-index: 100; +} + +.sidebar.collapsed { + width: 56px; + min-width: 56px; +} + +.sidebar-header { + padding: var(--space-md) var(--space-lg); + border-bottom: 1px solid var(--pg-border); + display: flex; + align-items: center; + gap: var(--space-sm); + min-height: var(--topbar-height); +} + +.sidebar-logo { + font-size: var(--font-size-md); + font-weight: 800; + color: var(--pg-primary); + white-space: nowrap; + letter-spacing: -0.03em; +} + +.sidebar-logo span { + color: var(--pg-accent); +} + +.sidebar-nav { + flex: 1; + overflow-y: auto; + padding: var(--space-sm) 0; +} + +.sidebar-footer { + padding: var(--space-sm) var(--space-md); + border-top: 1px solid var(--pg-border); + display: flex; + align-items: center; + justify-content: space-between; +} + +/* --- Nav Items --- */ + +.nav-item { + display: flex; + align-items: center; + gap: var(--space-sm); + padding: var(--space-sm) var(--space-lg); + cursor: pointer; + transition: background var(--transition-fast), color var(--transition-fast); + color: var(--pg-text-muted); + font-size: var(--font-size-sm); + user-select: none; + border-left: 3px solid transparent; +} + +.nav-item:hover { + background: var(--pg-bg-hover); + color: var(--pg-text); +} + +.nav-item.active { + background: var(--pg-bg-hover); + color: var(--pg-primary); + border-left-color: var(--pg-primary); +} + +.nav-item .nav-icon { + width: 20px; + text-align: center; + flex-shrink: 0; + font-size: var(--font-size-sm); +} + +.nav-item .nav-label { + flex: 1; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.nav-item .nav-badge { + font-size: var(--font-size-xs); + padding: 1px 6px; + border-radius: var(--radius-full); + background: var(--pg-bg-card); + color: var(--pg-text-dim); + font-weight: 600; +} + +.nav-item.complete .nav-badge { + background: var(--pg-lime); + color: #0b0b11; +} + +.nav-item.in-progress .nav-badge { + background: var(--pg-amber); + color: #0b0b11; +} + +/* Nested week items */ +.nav-weeks { + display: none; + padding-left: var(--space-lg); +} + +.nav-weeks.expanded { + display: block; +} + +.nav-week { + padding: var(--space-xs) var(--space-md); + font-size: var(--font-size-xs); + color: var(--pg-text-dim); + cursor: pointer; + display: flex; + align-items: center; + gap: var(--space-xs); + border-left: 2px solid transparent; +} + +.nav-week:hover { + color: var(--pg-text); + background: var(--pg-bg-hover); +} + +.nav-week.active { + color: var(--pg-accent); + border-left-color: var(--pg-accent); +} + +/* --- Main Content Area --- */ + +.main-area { + flex: 1; + display: flex; + flex-direction: column; + overflow: hidden; + min-width: 0; +} + +.topbar { + height: var(--topbar-height); + min-height: var(--topbar-height); + padding: 0 var(--space-lg); + display: flex; + align-items: center; + justify-content: space-between; + border-bottom: 1px solid var(--pg-border); + background: var(--pg-bg-elevated); + gap: var(--space-md); +} + +.breadcrumb { + display: flex; + align-items: center; + gap: var(--space-xs); + font-size: var(--font-size-sm); + color: var(--pg-text-muted); +} + +.breadcrumb a { + color: var(--pg-text-muted); +} + +.breadcrumb a:hover { + color: var(--pg-text); +} + +.breadcrumb .sep { + color: var(--pg-text-dim); + margin: 0 2px; +} + +.topbar-actions { + display: flex; + align-items: center; + gap: var(--space-sm); +} + +.content-scroll { + flex: 1; + overflow-y: auto; + overflow-x: hidden; + padding: var(--space-xl); +} + +.content-inner { + max-width: var(--content-max-width); + margin: 0 auto; +} + +/* --- Theme Toggle --- */ + +.theme-toggle { + width: 40px; + height: 22px; + border-radius: var(--radius-full); + background: var(--pg-bg-card); + border: 1px solid var(--pg-border); + cursor: pointer; + position: relative; + transition: background var(--transition-fast); +} + +.theme-toggle::after { + content: ''; + position: absolute; + top: 2px; + left: 2px; + width: 16px; + height: 16px; + border-radius: 50%; + background: var(--pg-violet); + transition: transform var(--transition-fast); +} + +[data-theme="light"] .theme-toggle::after { + transform: translateX(18px); + background: var(--pg-amber); +} + +/* --- Cards --- */ + +.card { + background: var(--pg-bg-card); + border: 1px solid var(--pg-border); + border-radius: var(--radius-lg); + padding: var(--space-lg); + transition: border-color var(--transition-fast), box-shadow var(--transition-fast); +} + +.card:hover { + border-color: var(--pg-primary); + box-shadow: var(--shadow-glow); +} + +.card-header { + display: flex; + align-items: center; + justify-content: space-between; + margin-bottom: var(--space-md); +} + +.card-title { + font-size: var(--font-size-md); + font-weight: 700; + color: var(--pg-text); +} + +/* --- Progress Ring --- */ + +.progress-ring { + position: relative; + display: inline-flex; + align-items: center; + justify-content: center; +} + +.progress-ring svg { + transform: rotate(-90deg); +} + +.progress-ring .ring-bg { + stroke: var(--pg-border); + fill: none; +} + +.progress-ring .ring-fill { + fill: none; + stroke-linecap: round; + transition: stroke-dashoffset var(--transition-slow); +} + +.progress-ring .ring-text { + position: absolute; + font-size: var(--font-size-xs); + font-weight: 700; + color: var(--pg-text); +} + +/* --- Badges & Tags --- */ + +.badge { + display: inline-flex; + align-items: center; + padding: 2px 10px; + border-radius: var(--radius-full); + font-size: var(--font-size-xs); + font-weight: 700; + letter-spacing: 0.03em; + text-transform: uppercase; +} + +.badge-violet { background: rgba(108, 92, 231, 0.15); color: var(--pg-violet-soft); } +.badge-cyan { background: rgba(0, 206, 201, 0.15); color: var(--pg-cyan-soft); } +.badge-magenta { background: rgba(232, 67, 147, 0.15); color: var(--pg-magenta-soft); } +.badge-amber { background: rgba(253, 203, 110, 0.15);color: var(--pg-amber); } +.badge-lime { background: rgba(0, 184, 148, 0.15); color: var(--pg-lime-soft); } +.badge-red { background: rgba(214, 48, 49, 0.15); color: var(--pg-red-soft); } +.badge-blue { background: rgba(9, 132, 227, 0.15); color: var(--pg-blue-soft); } + +/* --- Buttons --- */ + +.btn { + display: inline-flex; + align-items: center; + gap: var(--space-xs); + padding: var(--space-sm) var(--space-md); + border-radius: var(--radius-md); + font-family: var(--font-mono); + font-size: var(--font-size-sm); + font-weight: 600; + cursor: pointer; + border: 1px solid transparent; + transition: all var(--transition-fast); + user-select: none; +} + +.btn-primary { + background: var(--pg-primary); + color: #fff; + border-color: var(--pg-primary); +} +.btn-primary:hover { + background: var(--pg-primary-soft); + border-color: var(--pg-primary-soft); +} + +.btn-ghost { + background: transparent; + color: var(--pg-text-muted); + border-color: var(--pg-border); +} +.btn-ghost:hover { + background: var(--pg-bg-hover); + color: var(--pg-text); + border-color: var(--pg-text-dim); +} + +.btn-icon { + width: 32px; + height: 32px; + padding: 0; + justify-content: center; + border-radius: var(--radius-md); + background: transparent; + color: var(--pg-text-muted); + border: 1px solid transparent; + cursor: pointer; + font-size: var(--font-size-base); + transition: all var(--transition-fast); +} +.btn-icon:hover { + background: var(--pg-bg-hover); + color: var(--pg-text); +} + +/* --- Checkboxes (custom) --- */ + +.check-item { + display: flex; + align-items: flex-start; + gap: var(--space-sm); + padding: var(--space-sm) var(--space-md); + border-radius: var(--radius-md); + transition: background var(--transition-fast); + cursor: pointer; +} + +.check-item:hover { + background: var(--pg-bg-hover); +} + +.check-item input[type="checkbox"] { + appearance: none; + -webkit-appearance: none; + width: 18px; + height: 18px; + min-width: 18px; + border: 2px solid var(--pg-border); + border-radius: var(--radius-sm); + background: var(--pg-bg-input); + cursor: pointer; + transition: all var(--transition-fast); + margin-top: 2px; + position: relative; +} + +.check-item input[type="checkbox"]:checked { + background: var(--pg-lime); + border-color: var(--pg-lime); +} + +.check-item input[type="checkbox"]:checked::after { + content: '\2713'; + position: absolute; + top: -1px; + left: 2px; + font-size: 12px; + color: #0b0b11; + font-weight: 900; +} + +.check-item.completed .check-label { + color: var(--pg-text-dim); + text-decoration: line-through; +} + +.check-label { + font-size: var(--font-size-sm); + color: var(--pg-text); + line-height: 1.4; +} + +/* --- Collapsible Sections --- */ + +.collapsible { + border: 1px solid var(--pg-border); + border-radius: var(--radius-md); + margin-bottom: var(--space-sm); + overflow: hidden; +} + +.collapsible-header { + display: flex; + align-items: center; + gap: var(--space-sm); + padding: var(--space-md) var(--space-lg); + cursor: pointer; + user-select: none; + transition: background var(--transition-fast); + background: var(--pg-bg-elevated); +} + +.collapsible-header:hover { + background: var(--pg-bg-hover); +} + +.collapsible-arrow { + transition: transform var(--transition-fast); + color: var(--pg-text-dim); + font-size: var(--font-size-sm); +} + +.collapsible.open .collapsible-arrow { + transform: rotate(90deg); +} + +.collapsible-title { + flex: 1; + font-size: var(--font-size-sm); + font-weight: 700; +} + +.collapsible-body { + display: none; + padding: var(--space-md) var(--space-lg); + border-top: 1px solid var(--pg-border); +} + +.collapsible.open .collapsible-body { + display: block; +} + +/* --- Section Headers --- */ + +.section-header { + display: flex; + align-items: center; + gap: var(--space-sm); + margin-bottom: var(--space-md); + margin-top: var(--space-xl); + padding-bottom: var(--space-sm); + border-bottom: 2px solid var(--pg-border); +} + +.section-header h3 { + font-size: var(--font-size-md); + color: var(--pg-text); +} + +.section-icon { + font-size: var(--font-size-md); +} + +/* --- Notes (contenteditable) --- */ + +.note-area { + background: var(--pg-bg-input); + border: 1px solid var(--pg-border); + border-radius: var(--radius-md); + padding: var(--space-md); + font-family: var(--font-mono); + font-size: var(--font-size-sm); + color: var(--pg-text); + min-height: 60px; + outline: none; + transition: border-color var(--transition-fast); + line-height: 1.5; +} + +.note-area:focus { + border-color: var(--pg-border-focus); +} + +.note-area:empty::before { + content: attr(data-placeholder); + color: var(--pg-text-dim); +} + +/* --- Stats Grid --- */ + +.stats-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); + gap: var(--space-md); +} + +.stat-card { + background: var(--pg-bg-card); + border: 1px solid var(--pg-border); + border-radius: var(--radius-lg); + padding: var(--space-lg); + text-align: center; +} + +.stat-value { + font-size: var(--font-size-2xl); + font-weight: 800; + color: var(--pg-primary); + line-height: 1; + margin-bottom: var(--space-xs); +} + +.stat-label { + font-size: var(--font-size-xs); + color: var(--pg-text-muted); + text-transform: uppercase; + letter-spacing: 0.05em; +} + +/* --- Unit Header --- */ + +.unit-header { + margin-bottom: var(--space-xl); + padding-bottom: var(--space-lg); + border-bottom: 2px solid var(--pg-border); +} + +.unit-header h1 { + margin-bottom: var(--space-sm); +} + +.unit-meta { + display: flex; + align-items: center; + gap: var(--space-md); + flex-wrap: wrap; +} + +/* --- Week Section --- */ + +.week-section { + margin-bottom: var(--space-2xl); +} + +.week-title { + font-size: var(--font-size-lg); + color: var(--pg-accent); + margin-bottom: var(--space-md); + display: flex; + align-items: center; + gap: var(--space-sm); +} + +/* --- Empty / Coming Soon --- */ + +.coming-soon { + text-align: center; + padding: var(--space-2xl); + color: var(--pg-text-dim); +} + +.coming-soon h2 { + color: var(--pg-text-muted); + margin-bottom: var(--space-md); +} + +.coming-soon .icon-lock { + font-size: 3rem; + margin-bottom: var(--space-md); + opacity: 0.5; +} + +/* --- Hex Grid (Dashboard) --- */ + +.hex-grid { + display: flex; + flex-wrap: wrap; + gap: var(--space-md); + justify-content: center; + margin: var(--space-xl) 0; +} + +.hex-card { + width: 140px; + height: 140px; + background: var(--pg-bg-card); + border: 2px solid var(--pg-border); + border-radius: var(--radius-xl); + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + cursor: pointer; + transition: all var(--transition-fast); + text-align: center; + padding: var(--space-sm); +} + +.hex-card:hover { + border-color: var(--pg-primary); + box-shadow: var(--shadow-glow); + transform: translateY(-2px); +} + +.hex-card.complete { + border-color: var(--pg-lime); +} + +.hex-card.in-progress { + border-color: var(--pg-amber); +} + +.hex-unit-num { + font-size: var(--font-size-2xl); + font-weight: 800; + color: var(--pg-primary); + line-height: 1; +} + +.hex-unit-name { + font-size: var(--font-size-xs); + color: var(--pg-text-muted); + margin-top: var(--space-xs); + line-height: 1.2; +} + +.hex-progress { + font-size: var(--font-size-xs); + color: var(--pg-text-dim); + margin-top: var(--space-xs); +} + +/* --- Responsive --- */ + +@media (max-width: 768px) { + .sidebar { + position: fixed; + left: -280px; + transition: left var(--transition-base); + z-index: 1000; + } + .sidebar.mobile-open { + left: 0; + } + .content-scroll { + padding: var(--space-md); + } +} diff --git a/curriculum/css/theme.css b/curriculum/css/theme.css new file mode 100644 index 0000000000..25e4125dfa --- /dev/null +++ b/curriculum/css/theme.css @@ -0,0 +1,234 @@ +/* ============================================================ + Parameter Golf Curriculum — Design System + Monospace-forward, iconic colors, light/dark mode + ============================================================ */ + +/* --- Mode: Dark (default) --- */ +:root { + --pg-bg: #0b0b11; + --pg-bg-elevated: #131320; + --pg-bg-card: #1a1a2e; + --pg-bg-hover: #22223a; + --pg-bg-input: #16162a; + --pg-text: #e8e8f0; + --pg-text-muted: #8888a8; + --pg-text-dim: #5c5c7a; + --pg-border: #2c2c48; + --pg-border-focus: #6c5ce7; + + /* Iconic color palette */ + --pg-violet: #6c5ce7; + --pg-violet-soft: #a29bfe; + --pg-cyan: #00cec9; + --pg-cyan-soft: #81ecec; + --pg-magenta: #e84393; + --pg-magenta-soft: #fd79a8; + --pg-amber: #fdcb6e; + --pg-amber-soft: #ffeaa7; + --pg-lime: #00b894; + --pg-lime-soft: #55efc4; + --pg-red: #d63031; + --pg-red-soft: #ff7675; + --pg-blue: #0984e3; + --pg-blue-soft: #74b9ff; + --pg-orange: #e17055; + --pg-orange-soft: #fab1a0; + + /* Semantic tokens */ + --pg-primary: var(--pg-violet); + --pg-primary-soft: var(--pg-violet-soft); + --pg-accent: var(--pg-cyan); + --pg-accent-soft: var(--pg-cyan-soft); + --pg-success: var(--pg-lime); + --pg-success-soft: var(--pg-lime-soft); + --pg-warning: var(--pg-amber); + --pg-warning-soft: var(--pg-amber-soft); + --pg-danger: var(--pg-red); + --pg-danger-soft: var(--pg-red-soft); + + /* Progress states */ + --state-locked: #2a2a40; + --state-not-started: var(--pg-text-dim); + --state-in-progress: var(--pg-amber); + --state-complete: var(--pg-lime); + + /* Typography */ + --font-mono: 'JetBrains Mono', 'Fira Code', 'SF Mono', 'Cascadia Code', 'Consolas', monospace; + --font-sans: 'Inter', system-ui, -apple-system, sans-serif; + --font-size-xs: 0.7rem; + --font-size-sm: 0.8rem; + --font-size-base: 0.9rem; + --font-size-md: 1rem; + --font-size-lg: 1.2rem; + --font-size-xl: 1.5rem; + --font-size-2xl: 2rem; + --line-height: 1.65; + + /* Spacing */ + --space-xs: 4px; + --space-sm: 8px; + --space-md: 16px; + --space-lg: 24px; + --space-xl: 32px; + --space-2xl: 48px; + + /* Radius */ + --radius-sm: 4px; + --radius-md: 8px; + --radius-lg: 12px; + --radius-xl: 16px; + --radius-full: 999px; + + /* Shadows */ + --shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.4); + --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.5); + --shadow-lg: 0 8px 30px rgba(0, 0, 0, 0.6); + --shadow-glow: 0 0 20px rgba(108, 92, 231, 0.15); + + /* Transitions */ + --transition-fast: 150ms ease; + --transition-base: 250ms ease; + --transition-slow: 400ms ease; + + /* Layout */ + --sidebar-width: 280px; + --topbar-height: 56px; + --content-max-width: 900px; +} + +/* --- Mode: Light --- */ +[data-theme="light"] { + --pg-bg: #f5f5fa; + --pg-bg-elevated: #ffffff; + --pg-bg-card: #ffffff; + --pg-bg-hover: #ededf5; + --pg-bg-input: #f0f0f8; + --pg-text: #1a1a2e; + --pg-text-muted: #6b6b88; + --pg-text-dim: #9999b0; + --pg-border: #d8d8e8; + --pg-border-focus: #6c5ce7; + + --pg-violet: #5b4cdb; + --pg-cyan: #00a8a8; + --pg-magenta: #d63384; + --pg-amber: #e6a800; + --pg-lime: #009b77; + --pg-red: #c0392b; + --pg-blue: #0770c2; + --pg-orange: #d35400; + + --state-locked: #d0d0dd; + --state-not-started: #aaaabc; + + --shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.08); + --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.1); + --shadow-lg: 0 8px 30px rgba(0, 0, 0, 0.12); + --shadow-glow: 0 0 20px rgba(108, 92, 231, 0.08); +} + +/* ============================================================ + Base Reset & Globals + ============================================================ */ + +*, *::before, *::after { + box-sizing: border-box; + margin: 0; + padding: 0; +} + +html { + font-size: 16px; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +body { + font-family: var(--font-mono); + font-size: var(--font-size-base); + line-height: var(--line-height); + color: var(--pg-text); + background: var(--pg-bg); + overflow: hidden; + height: 100vh; + transition: background var(--transition-base), color var(--transition-base); +} + +a { + color: var(--pg-accent); + text-decoration: none; + transition: color var(--transition-fast); +} +a:hover { + color: var(--pg-accent-soft); + text-decoration: underline; +} + +h1, h2, h3, h4, h5, h6 { + font-family: var(--font-mono); + font-weight: 700; + letter-spacing: -0.02em; + line-height: 1.3; +} + +h1 { font-size: var(--font-size-2xl); } +h2 { font-size: var(--font-size-xl); } +h3 { font-size: var(--font-size-lg); } +h4 { font-size: var(--font-size-md); } + +code, pre, kbd { + font-family: var(--font-mono); +} + +code { + background: var(--pg-bg-input); + padding: 2px 6px; + border-radius: var(--radius-sm); + font-size: 0.85em; + color: var(--pg-magenta); +} + +pre { + background: var(--pg-bg-input); + padding: var(--space-md); + border-radius: var(--radius-md); + overflow-x: auto; + border: 1px solid var(--pg-border); +} + +pre code { + background: none; + padding: 0; + color: var(--pg-text); +} + +kbd { + background: var(--pg-bg-elevated); + border: 1px solid var(--pg-border); + border-bottom-width: 2px; + border-radius: var(--radius-sm); + padding: 1px 6px; + font-size: 0.8em; + color: var(--pg-text-muted); +} + +::selection { + background: var(--pg-violet); + color: #fff; +} + +/* Scrollbar styling */ +::-webkit-scrollbar { + width: 8px; + height: 8px; +} +::-webkit-scrollbar-track { + background: transparent; +} +::-webkit-scrollbar-thumb { + background: var(--pg-border); + border-radius: var(--radius-full); +} +::-webkit-scrollbar-thumb:hover { + background: var(--pg-text-dim); +} diff --git a/curriculum/curriculum.md b/curriculum/curriculum.md new file mode 100644 index 0000000000..cea1fd4d07 --- /dev/null +++ b/curriculum/curriculum.md @@ -0,0 +1,507 @@ +# Parameter Golf Mastery Curriculum + +A semester-long (16-week) curriculum covering every discipline required to dominate the OpenAI Parameter Golf challenge. Synthesized from the pedagogical traditions of CMU (systems + optimization), Stanford (deep learning theory + scaling), MIT (information theory + compression), and UC Berkeley (architecture + distributed systems). + +--- + +## Prerequisites + +- Linear algebra (eigenvalues, SVD, matrix norms, positive definiteness) +- Probability and statistics (MLE, Bayesian inference, hypothesis testing) +- Calculus and optimization (gradients, Hessians, convexity, Lagrange multipliers) +- Python fluency, PyTorch basics +- Comfort with the Unix command line and Git + +--- + +## Unit 1: Foundations of Language Modeling (Weeks 1-2) + +### Week 1: Statistical Language Models and Information Theory + +**Topics** +- Language modeling as next-token prediction +- Cross-entropy loss, perplexity, and bits-per-byte (BPB) +- Shannon entropy, source coding theorem, and the connection between compression and prediction +- KL divergence, mutual information +- Why BPB is the right metric for tokenizer-agnostic evaluation + +**Readings** +- Shannon, "A Mathematical Theory of Communication" (1948), Sections I-III +- Jurafsky & Martin, *Speech and Language Processing*, Ch. 3 (N-grams and Language Models) +- MacKay, *Information Theory, Inference and Learning Algorithms*, Ch. 1-6 + +**Exercises** +- Implement a character-level n-gram model and compute BPB on a text corpus +- Derive the relationship between cross-entropy loss, perplexity, and bits-per-byte +- Prove that cross-entropy is minimized when the model distribution equals the true distribution + +### Week 2: The Transformer Architecture + +**Topics** +- Self-attention: queries, keys, values, scaled dot-product attention +- Multi-head attention and why it works (subspace decomposition) +- Position encodings: sinusoidal, learned, RoPE +- Layer normalization variants: LayerNorm, RMSNorm, pre-norm vs post-norm +- Feed-forward networks: expansion factor, activation functions (ReLU, GELU, SwiGLU, ReLU^2) +- Residual connections and signal propagation in deep networks +- Autoregressive generation and causal masking + +**Readings** +- Vaswani et al., "Attention Is All You Need" (2017) +- Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021) +- Zhang & Sennrich, "Root Mean Square Layer Normalization" (2019) +- Shazeer, "GLU Variants Improve Transformer" (2020) + +**Exercises** +- Implement a transformer decoder from scratch in PyTorch (no nn.Transformer) +- Implement RoPE and verify it produces correct relative position attention patterns +- Train a small (1-layer, 64-dim) language model on a toy corpus and verify convergence + +--- + +## Unit 2: Scaling Laws and the Parameter Golf Objective (Week 3) + +### Week 3: Neural Scaling Laws and L(N) Optimization + +**Topics** +- Kaplan et al. scaling laws: L(N), L(D), L(C) and their power-law relationships +- Chinchilla optimal: compute-optimal allocation between parameters and data +- The Parameter Golf objective as L(N) optimization: minimize loss given fixed N, unconstrained D and C +- Implications: at fixed N, how do depth, width, and architecture affect loss? +- Depth vs width tradeoffs: why deeper is more parameter-efficient (to a point) +- The role of the 10-minute training constraint as a soft compute bound + +**Readings** +- Kaplan et al., "Scaling Laws for Neural Language Models" (2020) +- Hoffmann et al., "Training Compute-Optimal Large Language Models" (Chinchilla, 2022) +- Tay et al., "Scale Efficiently: Insights from Pre-training and Fine-tuning Transformers" (2022) + +**Exercises** +- Fit power-law curves to training runs at 3-4 different model sizes and predict loss at a target size +- Experimentally determine: at 16MB, is it better to have 6 layers at 768d or 12 layers at 512d? +- Read the Parameter Golf leaderboard and categorize each submission's primary contribution axis (quantization, architecture, training, evaluation) + +--- + +## Unit 3: Efficient Architectures (Weeks 4-6) + +### Week 4: Grouped Query Attention, Multi-Query Attention, and KV-Cache Efficiency + +**Topics** +- Multi-query attention (Shazeer 2019): shared K/V heads +- Grouped query attention (GQA): interpolating between MHA and MQA +- Parameter cost analysis: how GQA reduces attention parameters +- Flash Attention: tiling, memory-efficient backward, IO complexity +- Flash Attention 2 and 3: Hopper-specific optimizations + +**Readings** +- Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" (2019) +- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models" (2023) +- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022) +- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023) +- Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024) + +**Exercises** +- Implement GQA from scratch and verify it matches standard MHA output when num_kv_heads = num_heads +- Profile memory usage of standard attention vs Flash Attention at sequence length 2048 +- Calculate exact parameter savings from GQA at various head ratios for the 512d/8-head baseline + +### Week 5: Parameter-Efficient Architecture Variants + +**Topics** +- Depth recurrence and the Universal Transformer (Dehghani et al.) + - Weight sharing across layers + - Per-layer conditioning: layer embeddings, adaptive computation time (ACT) + - Training stability with deep recurrence (gradient flow, normalization) +- Mixture of Experts (MoE) + - Sparse routing: top-k gating, expert capacity + - Parameter count vs active parameter count + - Load balancing losses +- Low-rank factorization + - LoRA and its variants + - Kronecker-factored layers +- U-Net / encoder-decoder skip connections in transformers + - The skip connection pattern used in the Parameter Golf baseline + - Learned skip weights + +**Readings** +- Dehghani et al., "Universal Transformers" (2019) +- Fedus et al., "Switch Transformers: Scaling to Trillion Parameter Models" (2022) +- Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models" (2021) +- Bao et al., "All Are Worth Words: A ViT Backbone for Diffusion Models" (U-Net skips in transformers, 2023) + +**Exercises** +- Implement a weight-shared transformer (3 blocks x 4 iterations) with layer index conditioning +- Compare parameter count and loss: 12-layer standard vs 3-block x 4-iteration recurrent at same width +- Implement and ablate different per-layer conditioning strategies: additive embedding, multiplicative gate, FiLM conditioning + +### Week 6: Specialized Modules for Small Models + +**Topics** +- BigramHash embeddings: hash-based n-gram features for small vocabularies + - Hash function design, collision analysis + - Learned scaling and projection +- SmearGate: temporal smoothing via per-dimension gating +- Value Embeddings (VE): re-injecting token identity at deep layers +- Cross-Sequence Attention (XSA): removing self-value projection + - Why this forces meaningful contextual attention + - GQA-aware implementation +- Partial RoPE: applying position encoding to a subset of dimensions +- Logit soft-capping: tanh-based logit range control + +**Readings** +- Parameter Golf PRs: #162 (BigramHash), #65 (SmearGate), #374 (VE128), #478 (XSA), #315 (Partial RoPE) +- Bai et al., "Transformers as Algorithms: Generalization and Stability in In-context Learning" (2023) - theoretical grounding for why removing self-attention improves generalization + +**Exercises** +- Implement BigramHash with configurable vocabulary size and embedding dimension +- Ablation study: add/remove each module (SmearGate, VE, XSA, Partial RoPE) individually and measure BPB delta +- Analyze the collision rate of BigramHash at different vocabulary sizes (1024, 2048, 3072, 4096) + +--- + +## Unit 4: Tokenization (Week 7) + +### Week 7: Tokenizers and Their Impact on BPB + +**Topics** +- Byte Pair Encoding (BPE): algorithm, vocabulary construction, merge rules +- SentencePiece: unigram model vs BPE mode +- The tokenizer-agnostic BPB metric: how token-level loss converts to byte-level compression +- Why vocabulary size matters in parameter-constrained settings + - Embedding table cost: 2 x vocab_size x model_dim (with tied embeddings: 1x) + - Small vocab (1024): cheaper embeddings, more tokens per document, longer sequences needed + - Large vocab (8192+): expensive embeddings, fewer tokens, each token carries more information +- Tied vs untied embeddings: parameter cost analysis +- The BigramHash approach as a middle ground: small vocab + learned n-gram features +- Byte-level tokenization and its tradeoffs + +**Readings** +- Sennrich et al., "Neural Machine Translation of Rare Words with Subword Units" (BPE, 2016) +- Kudo, "Subword Regularization: Improving Neural Network Translation Models with Multiple Subword Candidates" (2018) +- Kudo & Richardson, "SentencePiece: A simple and language independent subword tokenizer" (2018) + +**Exercises** +- Train BPE tokenizers at vocab sizes 512, 1024, 2048, 4096, 8192 on FineWeb +- For each: compute tokens-per-byte ratio and estimate the embedding parameter cost at 512d +- Find the vocab size that minimizes total model BPB under a 16MB budget constraint + +--- + +## Unit 5: Optimization (Weeks 8-10) + +### Week 8: Optimizers for Small Model Training + +**Topics** +- Adam and AdamW: momentum, adaptive learning rates, weight decay +- The Muon optimizer + - Newton-Schulz orthogonalization: what it does and why + - The zeropower_via_newtonschulz5 iteration: deriving the (a, b, c) coefficients + - Muon as "spectral steepest descent" for matrix parameters + - Momentum in Muon: Nesterov momentum, warmup +- Optimizer partitioning: different optimizers for different parameter types + - Matrix params (Muon), scalar params (Adam), embeddings (Adam with higher LR) +- Learning rate schedules + - Linear warmup + - Cosine decay, linear decay + - Wallclock-aware warmdown: adapting to variable step times +- Gradient clipping: global norm clipping, when and why + +**Readings** +- Kingma & Ba, "Adam: A Method for Stochastic Optimization" (2015) +- Loshchilov & Hutter, "Decoupled Weight Decay Regularization" (AdamW, 2019) +- Jordan, "Muon: An optimizer for hidden layers in neural networks" (2024), blog post + code +- Bernstein et al., "Old Optimizer, New Norm: An Anthology" (2024) + +**Exercises** +- Implement Muon from scratch, including the Newton-Schulz iteration +- Compare Adam vs Muon on the Parameter Golf baseline: plot train loss curves over 1000 steps +- Implement wallclock-aware warmdown and verify it adapts correctly when step times vary + +### Week 9: Distributed Training and Parallel Optimization + +**Topics** +- Data parallelism: gradient averaging across GPUs +- DistributedDataParallel (DDP) in PyTorch +- Gradient accumulation: simulating larger batch sizes +- All-reduce, reduce-scatter, all-gather: collective communication primitives +- NVLink and the communication topology of 8xH100 SXM +- The Parallel Muon strategy + - Parameter Banking: storing weights as 3D tensors for batched operations + - Overlapping communication with computation + - Async reduce-scatter during Adam steps, then Newton-Schulz on shards +- Scaling batch size: critical batch size, gradient noise scale + +**Readings** +- Li et al., "PyTorch Distributed: Experiences on Accelerating Data Parallel Training" (2020) +- McCandlish et al., "An Empirical Model of Large-Batch Training" (2018) +- NVIDIA, "NCCL Developer Guide" (collective operations reference) + +**Exercises** +- Profile a training step on 1 GPU vs 8 GPUs: measure communication overhead +- Implement the Parameter Banking pattern: reshape layer weights into 3D bank tensors +- Implement async reduce-scatter + all-gather with overlap and measure throughput improvement + +### Week 10: Weight Averaging and Ensemble Methods + +**Topics** +- Exponential Moving Average (EMA) + - Decay rate selection, Polyak averaging + - Why EMA helps: smoothing over loss surface noise +- Stochastic Weight Averaging (SWA) + - Collecting snapshots during late training + - SWA vs EMA: when each works better +- LAWA (Latest-k Weight Average) +- Combining EMA and SWA +- Connection to flat minima and generalization +- When to start averaging: the lr_scale threshold approach + +**Readings** +- Polyak & Juditsky, "Acceleration of Stochastic Approximation by Averaging" (1992) +- Izmailov et al., "Averaging Weights Leads to Wider Optima and Better Generalization" (SWA, 2018) +- Kaddour et al., "Stop Wasting My Time! Saving Days of ImageNet and BERT Training with Latest Weight Averaging" (LAWA, 2022) + +**Exercises** +- Implement EMA with configurable decay and compare final BPB: EMA vs no EMA +- Implement SWA triggered by lr_scale threshold and measure improvement +- Experiment: what EMA decay rate is optimal for the Parameter Golf training duration? + +--- + +## Unit 6: Quantization and Compression (Weeks 11-13) + +### Week 11: Post-Training Quantization Fundamentals + +**Topics** +- Fixed-point number representation: int8, int6, int5, int4 +- Per-tensor vs per-row vs per-channel quantization +- Symmetric vs asymmetric quantization +- Calibration: choosing scale and zero-point + - Min-max, percentile clipping, MSE-optimal +- The quantization error budget: how rounding errors accumulate through layers +- Round-to-nearest vs more sophisticated rounding + +**Readings** +- Jacob et al., "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (2018) +- Nagel et al., "A White Paper on Neural Network Quantization" (Qualcomm, 2021) +- Gholami et al., "A Survey of Quantization Methods for Efficient Neural Network Inference" (2021) + +**Exercises** +- Implement per-row int8 quantization with percentile clipping (reproduce the baseline's approach) +- Measure BPB degradation from int8, int6, int5, int4 uniform quantization on the trained baseline +- Plot reconstruction MSE vs bits-per-weight for different clipping percentiles + +### Week 12: Advanced Quantization (GPTQ, QAT) + +**Topics** +- GPTQ: Hessian-informed quantization + - The optimal brain quantization (OBQ) framework + - Hessian collection: H = X^T X from calibration data + - Cholesky decomposition and error compensation + - Column reordering for better error propagation + - Block-wise quantization for efficiency + - GPTQ-lite: diagonal Hessian approximation + - Full Hessian GPTQ: the complete algorithm +- Calibration data strategies + - Training data calibration (ruled out by Parameter Golf rules during eval) + - Autoregressive self-generated calibration: the model generates its own data +- Quantization-Aware Training (QAT) + - Straight-Through Estimator (STE): gradients through discontinuous rounding + - Fake quantization during training + - Late QAT: enabling STE only in the final training phase (when lr is low) + - Why late QAT works: the model learns to place weights near quantization grid points +- Mixed precision: different bit-widths for different layers or tensor types + +**Readings** +- Frantar et al., "GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers" (2023) +- Nagel et al., "Up or Down? Adaptive Rounding for Post-Training Quantization" (AdaRound, 2020) +- Bengio et al., "Estimating or Propagating Gradients Through Stochastic Neurons" (STE, 2013) + +**Exercises** +- Implement GPTQ with diagonal Hessian (GPTQ-lite) and compare vs uniform quantization +- Implement full Hessian GPTQ with Cholesky error compensation +- Implement late QAT with STE: add fake quantization to CastedLinear when lr_scale < threshold +- Implement AR self-generated calibration: generate sequences from the trained model for GPTQ + +### Week 13: Compression and Artifact Size Optimization + +**Topics** +- Entropy coding: Huffman, arithmetic coding, ANS +- General-purpose compression: zlib, zstd, lzma + - Why lzma compresses quantized weights better than zlib + - Compression level tradeoffs (speed vs ratio) +- The 16MB artifact budget: code bytes + compressed model bytes + - Strategies for minimizing code size + - Strategies for maximizing compressibility of quantized weights +- Selective pruning: removing low-impact quantized values + - Reconstruction error as pruning criterion + - Binary search for size target +- Ternary quantization: {-1, 0, +1} weights + - Extreme compression, extreme loss of precision + - When it can work (with enough parameters and proper scaling) +- 1-bit quantization: binary weights with learned scales + +**Readings** +- Zhu et al., "Trained Ternary Quantization" (2017) +- Rastegari et al., "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks" (2016) +- The LZMA SDK documentation and compression algorithm description + +**Exercises** +- Compare zlib-9, zstd-22, and lzma-9 compression ratios on the same quantized model checkpoint +- Implement selective pruning: sort quantized values by reconstruction error, prune to hit size target +- Compute: at int6 with lzma-9, how many raw parameters can fit in 16MB? What about int4? + +--- + +## Unit 7: Evaluation Methods (Week 14) + +### Week 14: Evaluation Strategies and Test-Time Compute + +**Topics** +- Standard autoregressive evaluation: fixed context, sequential BPB +- Sliding window evaluation + - Stride selection: trading compute for better context + - Why it improves BPB: each token sees maximum available context + - Implementation: scoring only "new" tokens to avoid double-counting +- Test-Time Training (TTT) + - Adapting model parameters on previously-evaluated tokens + - LoRA TTT: lightweight adaptation during evaluation + - Legal TTT in Parameter Golf: only train on tokens you've already scored + - Score-first TTT: evaluate, then adapt, ensuring no data leakage +- Long-context evaluation + - Extending eval sequence length beyond training length + - Position extrapolation: NTK-aware RoPE scaling, YaRN + - Memory and compute costs of long-context eval +- Evaluation time budget: fitting within 10 minutes on 8xH100 + +**Readings** +- Press et al., "Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization" (ALiBi, 2022) +- Sun et al., "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" (TTT, 2024) +- Bloc97, "NTK-Aware Scaled RoPE" (2023, online post) + +**Exercises** +- Implement sliding window evaluation with configurable stride +- Measure BPB improvement from sliding window eval at strides 32, 64, 128, 256 +- Implement a minimal LoRA TTT loop: fine-tune rank-4 LoRA on evaluated tokens, measure BPB delta +- Profile eval time: how many sliding window passes fit in 10 minutes on 1xH100? + +--- + +## Unit 8: Systems and Performance (Week 15) + +### Week 15: GPU Programming, Kernels, and Training Throughput + +**Topics** +- GPU architecture: SMs, warps, memory hierarchy (registers, shared memory, L2, HBM) +- The H100 Hopper architecture: TMA, warp specialization, FP8 tensor cores +- torch.compile: how it works, tracing, graph breaks, fullgraph mode +- Profiling training runs: torch.profiler, nsight systems, identifying bottlenecks +- Memory optimization + - Activation checkpointing / gradient checkpointing + - Mixed precision training: bf16 forward, fp32 optimizer state + - Memory-efficient attention (Flash Attention) vs standard attention memory cost +- Maximizing tokens per second + - Batch size tuning for GPU utilization + - Reducing Python overhead: compiled models, fused kernels + - Data loading: async prefetch, pinned memory, non-blocking transfers +- Custom CUDA kernels and Triton: when and why + - Fused operations (e.g., fused RMSNorm + linear) + - Megakernels: combining multiple operations into a single GPU launch + +**Readings** +- NVIDIA, "H100 Tensor Core GPU Architecture" whitepaper (2022) +- Ansel et al., "PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation" (2024) +- Tillet et al., "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations" (2019) + +**Exercises** +- Profile a training step with torch.profiler: identify the top 5 time-consuming operations +- Implement activation checkpointing for the transformer blocks and measure memory savings +- Write a Triton kernel for fused RMSNorm and benchmark against PyTorch's F.rms_norm +- Experiment: what is the maximum batch size (tokens/step) that fits in 80GB H100 HBM for the baseline model? + +--- + +## Unit 9: Integration and Competition Strategy (Week 16) + +### Week 16: Putting It All Together + +**Topics** +- The SOTA stack walkthrough: reading and understanding the current #1 submission line by line +- Contribution axes and diminishing returns analysis + - Where is the most remaining headroom? + - Quantization frontier: int6 to int5 to int4, marginal gains + - Architectural innovations: what hasn't been tried? + - Training efficiency: can we fit more steps in 600 seconds? + - Evaluation tricks: how much more can sliding window + TTT give? +- Experiment design for competition + - Ablation methodology: single-variable changes, controlled comparisons + - Statistical significance: Welch's t-test, p < 0.01, 3-seed validation + - The 0.005-nat improvement threshold +- Submission process and reproducibility + - Records folder structure + - Logging and artifact generation + - README documentation standards +- Risk management + - Compute cost estimation before launching 8xH100 runs + - Iterating cheaply on 1xH100 or local GPU before scaling + - Version control for experiments: branches, tags, logs + +**Readings** +- The full PR history of Parameter Golf: #1019, #549, #414, #287, #198, #162, #65 +- All README files in the records/track_10min_16mb/ directory +- The Parameter Golf FAQ and rules (challenge page + README) + +**Capstone Project** +- Starting from the current SOTA codebase, implement one novel improvement +- Run 3-seed validation on 8xH100 +- Prepare a complete submission: README, submission.json, train_gpt.py, train logs +- Present results with statistical analysis (mean, std, Welch's t-test) + +--- + +## Appendix A: Key Papers Reference List + +| Topic | Paper | Year | +|-------|-------|------| +| Transformer | Vaswani et al., "Attention Is All You Need" | 2017 | +| Scaling Laws | Kaplan et al., "Scaling Laws for Neural Language Models" | 2020 | +| Chinchilla | Hoffmann et al., "Training Compute-Optimal Large Language Models" | 2022 | +| RoPE | Su et al., "RoFormer" | 2021 | +| Flash Attention | Dao et al., "FlashAttention" | 2022 | +| Universal Transformer | Dehghani et al., "Universal Transformers" | 2019 | +| BPE | Sennrich et al., "Neural Machine Translation of Rare Words" | 2016 | +| Adam | Kingma & Ba, "Adam" | 2015 | +| SWA | Izmailov et al., "Averaging Weights Leads to Wider Optima" | 2018 | +| GPTQ | Frantar et al., "GPTQ" | 2023 | +| STE | Bengio et al., "Estimating or Propagating Gradients Through Stochastic Neurons" | 2013 | +| MoE | Fedus et al., "Switch Transformers" | 2022 | +| LoRA | Hu et al., "LoRA" | 2021 | +| TTT | Sun et al., "Learning to (Learn at Test Time)" | 2024 | + +## Appendix B: Recommended Tool Proficiency + +| Tool | Purpose | Priority | +|------|---------|----------| +| PyTorch | Model implementation, training loops | Critical | +| torch.compile | Kernel fusion, graph optimization | Critical | +| Flash Attention | Efficient attention computation | Critical | +| torchrun | Distributed training launcher | Critical | +| NCCL | GPU-to-GPU communication | High | +| SentencePiece | Tokenizer training and inference | High | +| torch.profiler | Performance analysis | High | +| nsight systems | GPU-level profiling | Medium | +| Triton | Custom GPU kernels | Medium | +| CUDA | Low-level GPU programming | Medium | +| lzma/zstd/zlib | Model compression | Medium | + +## Appendix C: Compute Planning + +| Phase | Hardware | Estimated Hours | Cost Estimate | +|-------|----------|-----------------|---------------| +| Weeks 1-6 exercises | Local GPU or 1xH100 | 20 hrs | $50 | +| Weeks 7-10 exercises | 1xH100 | 30 hrs | $75 | +| Weeks 11-13 exercises | 1xH100 | 20 hrs | $50 | +| Week 14 eval experiments | 1-8xH100 | 15 hrs | $75 | +| Week 15 profiling | 1xH100 | 10 hrs | $25 | +| Week 16 capstone | 8xH100 | 20 hrs | $400 | +| **Total** | | **~115 hrs** | **~$675** | diff --git a/curriculum/data/flashcards.json b/curriculum/data/flashcards.json new file mode 100644 index 0000000000..d024eba074 --- /dev/null +++ b/curriculum/data/flashcards.json @@ -0,0 +1,140 @@ +[ + { + "cardId": "u1-shannon-entropy", + "unitId": 1, + "front": "What is Shannon entropy H(X)?", + "back": "H(X) = -\u03a3 p(x) log\u2082 p(x) \u2014 the minimum average bits per symbol to encode source X" + }, + { + "cardId": "u1-cross-entropy", + "unitId": 1, + "front": "What is cross-entropy H(p, q)?", + "back": "H(p, q) = -\u03a3 p(x) log q(x) \u2014 average bits when using model q to encode source with true distribution p" + }, + { + "cardId": "u1-kl-divergence", + "unitId": 1, + "front": "What is KL divergence D_KL(p || q)?", + "back": "D_KL(p || q) = \u03a3 p(x) log(p(x)/q(x)) = H(p,q) - H(p) \u2265 0. Extra bits from using q instead of p." + }, + { + "cardId": "u1-perplexity", + "unitId": 1, + "front": "How is perplexity defined in terms of cross-entropy?", + "back": "PPL = e^(H(p,q)) in nats, or 2^(H) in bits. Lower perplexity = better model." + }, + { + "cardId": "u1-bpb-formula", + "unitId": 1, + "front": "How is BPB (bits-per-byte) computed from val_loss?", + "back": "BPB = (val_loss / ln(2)) \u00d7 (tokens / bytes). Tokenizer-agnostic compression metric." + }, + { + "cardId": "u1-compression-prediction", + "unitId": 1, + "front": "Why does compression = prediction?", + "back": "Source coding theorem: optimal compression needs H(X) bits/symbol. A model achieving H(p,q) = H(p) is a perfect predictor, yielding optimal compression." + }, + { + "cardId": "u1-attention-formula", + "unitId": 1, + "front": "What is the scaled dot-product attention formula?", + "back": "Attn(Q,K,V) = softmax(QK\u1d40 / \u221ad_k) \u00b7 V" + }, + { + "cardId": "u1-rope-property", + "unitId": 1, + "front": "What is the key property of RoPE (Rotary Position Embedding)?", + "back": "q_m\u1d40 k_n depends only on relative position (m-n), achieved by rotating Q/K vectors by position-dependent angles." + }, + { + "cardId": "u1-rmsnorm", + "unitId": 1, + "front": "How does RMSNorm differ from LayerNorm?", + "back": "RMSNorm: x / RMS(x), no mean centering or learnable bias. Cheaper and equally effective. RMS(x) = \u221a(mean(x\u00b2))." + }, + { + "cardId": "u1-relu-squared", + "unitId": 1, + "front": "What is the ReLU\u00b2 activation used in the baseline MLP?", + "back": "FFN(x) = W\u2082 \u00b7 (ReLU(W\u2081 x))\u00b2. Squaring after ReLU creates smoother, sparser activations." + }, + { + "cardId": "u1-residual-why", + "unitId": 1, + "front": "Why are residual connections critical in deep transformers?", + "back": "They create a direct gradient path from output to input, preventing vanishing gradients. Output = x + F(x) means gradients always have a +1 component." + }, + { + "cardId": "u1-causal-mask", + "unitId": 1, + "front": "What does the causal mask do in autoregressive attention?", + "back": "Sets attention weights to 0 for all positions j > i (future tokens). Token i can only attend to tokens 1..i. Enforced by masking QK\u1d40 with -\u221e before softmax." + }, + { + "cardId": "u2-scaling-law-ln", + "unitId": 2, + "front": "What is the Kaplan scaling law L(N)?", + "back": "L(N) = (N_c / N)^alpha_N where alpha_N \u2248 0.076. Loss decreases as a power law with parameter count." + }, + { + "cardId": "u2-chinchilla", + "unitId": 2, + "front": "What is the Chinchilla-optimal ratio of parameters to data?", + "back": "For fixed compute C: scale N \u221d C^0.5 and D \u221d C^0.5 (roughly equal scaling). Previous practice over-allocated to parameters." + }, + { + "cardId": "u2-pgolf-as-ln", + "unitId": 2, + "front": "Why is Parameter Golf an L(N) optimization problem?", + "back": "Fixed N (16MB artifact), unconstrained D and C (8B+ tokens, 10 min 8xH100). Goal: minimize loss at fixed parameter count." + }, + { + "cardId": "u2-depth-vs-width", + "unitId": 2, + "front": "At fixed param count, is deeper or wider generally better?", + "back": "Deeper (more layers, smaller dim) generally wins. Leaderboard evolved 9\u219211 layers. Depth recurrence takes this to its logical extreme." + }, + { + "cardId": "u3-gqa-ratio", + "unitId": 3, + "front": "In the baseline GQA-4 (8 heads, 4 KV heads), what are the KV parameter savings?", + "back": "KV dim = 4\u00d764 = 256 instead of 512. Saves 50% of K+V projection parameters per layer (262K vs 524K)." + }, + { + "cardId": "u3-flash-attn", + "unitId": 3, + "front": "How does Flash Attention reduce memory from O(T\u00b2) to O(T)?", + "back": "Tiles Q,K,V computation into SRAM-sized blocks. Computes partial softmax in tiles, accumulates without materializing the full T\u00d7T attention matrix in HBM." + }, + { + "cardId": "u3-universal-transformer", + "unitId": 3, + "front": "What is the Universal Transformer's key idea?", + "back": "Share weights across layers, run K blocks N times each = K\u00d7N effective depth for cost of K layers. Requires per-iteration conditioning." + }, + { + "cardId": "u3-bigramhash", + "unitId": 3, + "front": "What is BigramHash and why is it useful?", + "back": "Hash adjacent token pairs into an embedding table: hash(t[i-1], t[i]) = (36313*t[i] ^ 27191*t[i-1]) % vocab. Adds n-gram context cheaply for small vocabularies." + }, + { + "cardId": "u3-xsa", + "unitId": 3, + "front": "What does XSA (Cross-Sequence Attention) do?", + "back": "Removes each token's self-value projection from attention output: y = y - proj(y onto v). Forces meaningful cross-token attention. Zero parameter cost." + }, + { + "cardId": "u3-smeargate", + "unitId": 3, + "front": "How does SmearGate work?", + "back": "Per-dimension gate: out = (1-\u03c3(g))*x + \u03c3(g)*x_prev. Learnable temporal smoothing between positions. Init near zero (pass-through)." + }, + { + "cardId": "u3-partial-rope", + "unitId": 3, + "front": "What is Partial RoPE and what ratio does the SOTA use?", + "back": "Apply RoPE to only a subset of head dimensions: 16 out of 64 (25%). Remaining dims attend purely by content. Saves compute, surprisingly effective." + } +] diff --git a/curriculum/data/leaderboard.json b/curriculum/data/leaderboard.json new file mode 100644 index 0000000000..acb0f4ab93 --- /dev/null +++ b/curriculum/data/leaderboard.json @@ -0,0 +1,23 @@ +[ + { "rank": 1, "run": "11L AR Self-Gen GPTQ + XSA", "score": 1.1147, "author": "abaybektursun", "date": "2026-03-25", "pr": 1019 }, + { "rank": 2, "run": "LeakyReLU\u00b2 + Legal TTT + Parallel Muon", "score": 1.1194, "author": "abaybektursun", "date": "2026-03-23", "pr": 549 }, + { "rank": 3, "run": "11L EMA + GPTQ-lite + warmdown3500", "score": 1.1228, "author": "signalrush", "date": "2026-03-22", "pr": 374 }, + { "rank": 4, "run": "11L Partial RoPE + LN Scale + EMA + XSA4", "score": 1.1248, "author": "jfprincz", "date": "2026-03-21", "pr": 287 }, + { "rank": 5, "run": "11L XSA4 + EMA + Int6 MLP3x", "score": 1.1271, "author": "jfprincz", "date": "2026-03-20", "pr": 198 }, + { "rank": 6, "run": "11L Efficient Partial XSA", "score": 1.1307, "author": "unnir", "date": "2026-03-20", "pr": 198 }, + { "rank": 7, "run": "10L Int5-MLP + BigramHash(10240)", "score": 1.1428, "author": "thwu1", "date": "2026-03-20", "pr": null }, + { "rank": 8, "run": "Int6 MLP3x + SmearGate + BigramHash", "score": 1.1458, "author": "Raahil Shah", "date": "2026-03-20", "pr": null }, + { "rank": 9, "run": "11L MLP3x + Int6 QAT", "score": 1.1502, "author": "aruniyer", "date": "2026-03-20", "pr": null }, + { "rank": 10, "run": "SmearGate + OrthoInit + Muon WD", "score": 1.1556, "author": "aquariouseworkman", "date": "2026-03-19", "pr": null }, + { "rank": 11, "run": "Ternary Quantization", "score": 1.1570, "author": "Ciprian-Florin Ifrim", "date": "2026-03-24", "pr": null }, + { "rank": 12, "run": "10L Int6 QAT + Zstd MLP2.6x", "score": 1.1586, "author": "yahya010", "date": "2026-03-19", "pr": null }, + { "rank": 13, "run": "Mixed Quant + Sliding Window Eval", "score": 1.1630, "author": "aquariouseworkman", "date": "2026-03-19", "pr": null }, + { "rank": 14, "run": "Muon WD + 10 layer", "score": 1.1748, "author": "notapplica", "date": "2026-03-19", "pr": null }, + { "rank": 15, "run": "Sliding Window Eval", "score": 1.1925, "author": "Matthew Li", "date": "2026-03-19", "pr": null }, + { "rank": 16, "run": "LoRA TTT", "score": 1.1928, "author": "samacqua", "date": "2026-03-19", "pr": null }, + { "rank": 17, "run": "4k seq length", "score": 1.2014, "author": "Spokane Way", "date": "2026-03-19", "pr": null }, + { "rank": 18, "run": "2048 seq length", "score": 1.206, "author": "Spokane Way", "date": "2026-03-18", "pr": null }, + { "rank": 19, "run": "int6 mixed precision", "score": 1.2147, "author": "Nan Liu", "date": "2026-03-18", "pr": null }, + { "rank": 20, "run": "fp16 Embed", "score": 1.2197, "author": "Renier Velazco", "date": "2026-03-18", "pr": null }, + { "rank": 21, "run": "Naive Baseline", "score": 1.2244, "author": "Baseline", "date": "2026-03-18", "pr": null } +] diff --git a/curriculum/data/quizzes.json b/curriculum/data/quizzes.json new file mode 100644 index 0000000000..fbf09beb7d --- /dev/null +++ b/curriculum/data/quizzes.json @@ -0,0 +1,301 @@ +{ + "w1-info-theory": { + "weekId": 1, + "quizId": "w1-info-theory", + "title": "Week 1: Information Theory Foundations", + "questions": [ + { + "id": "q1", + "question": "What does Shannon entropy H(X) represent?", + "options": [ + "The maximum possible compression ratio", + "The minimum average bits per symbol to encode a source", + "The expected prediction accuracy of a model", + "The total information content of a dataset" + ], + "answer": 1, + "explanation": "Shannon entropy H(X) = -\u03a3 p(x) log\u2082 p(x) gives the theoretical minimum average bits per symbol needed to encode the source. This is the source coding theorem." + }, + { + "id": "q2", + "question": "If H(p, q) is cross-entropy and H(p) is entropy, what is D_KL(p || q)?", + "options": [ + "H(p) - H(p, q)", + "H(p, q) - H(p)", + "H(p) + H(p, q)", + "H(p, q) / H(p)" + ], + "answer": 1, + "explanation": "D_KL(p || q) = H(p, q) - H(p). It represents the extra bits needed when using model q instead of the true distribution p." + }, + { + "id": "q3", + "question": "Why is BPB (bits-per-byte) used instead of cross-entropy loss for the Parameter Golf leaderboard?", + "options": [ + "BPB is easier to compute", + "BPB is always lower than cross-entropy", + "BPB is tokenizer-agnostic, enabling fair comparison across different vocabularies", + "BPB accounts for model size in its calculation" + ], + "answer": 2, + "explanation": "Different tokenizers produce different numbers of tokens per document. BPB normalizes by bytes instead of tokens, so a 1024-vocab model and a 32K-vocab model are compared on the same scale." + }, + { + "id": "q4", + "question": "What is perplexity in terms of cross-entropy H(p,q) (in nats)?", + "options": [ + "2^H(p,q)", + "e^H(p,q)", + "log(H(p,q))", + "1 / H(p,q)" + ], + "answer": 1, + "explanation": "When cross-entropy is measured in nats (natural log), perplexity = e^H(p,q). When in bits (log base 2), perplexity = 2^H." + }, + { + "id": "q5", + "question": "The statement 'compression = prediction' means:", + "options": [ + "Compressed models are better predictors", + "A perfect next-token predictor achieves the optimal compression rate", + "You must compress the model before evaluating prediction quality", + "Prediction accuracy improves with smaller file sizes" + ], + "answer": 1, + "explanation": "The source coding theorem shows that a model achieving H(p,q) = H(p) is both a perfect predictor and an optimal compressor. Minimizing cross-entropy simultaneously optimizes both." + } + ] + }, + "w2-transformer": { + "weekId": 2, + "quizId": "w2-transformer", + "title": "Week 2: Transformer Architecture", + "questions": [ + { + "id": "q1", + "question": "In scaled dot-product attention, why do we divide by sqrt(d_k)?", + "options": [ + "To normalize the output to unit variance", + "To prevent dot products from growing large, keeping softmax in a well-behaved regime", + "To make the computation faster on GPUs", + "To ensure the attention weights sum to 1" + ], + "answer": 1, + "explanation": "When d_k is large, the dot products QK^T can have large magnitude, pushing softmax into regions with vanishingly small gradients. Scaling by 1/sqrt(d_k) keeps the variance of the dot products around 1." + }, + { + "id": "q2", + "question": "What is the key property of RoPE (Rotary Position Embedding)?", + "options": [ + "It uses learned position embeddings that are added to the input", + "It makes the attention score q_m^T k_n depend only on relative position (m-n)", + "It applies different frequencies to different layers", + "It eliminates the need for position information entirely" + ], + "answer": 1, + "explanation": "RoPE rotates query and key vectors by position-dependent angles. The inner product q_m^T k_n depends on the angle difference, which is proportional to (m-n), making it a relative position encoding." + }, + { + "id": "q3", + "question": "In the Parameter Golf baseline, what is the MLP activation function?", + "options": [ + "GELU", + "SwiGLU", + "ReLU^2 (squared ReLU)", + "Sigmoid" + ], + "answer": 2, + "explanation": "The baseline uses ReLU^2: FFN(x) = W_2 * (ReLU(W_1 x))^2. The SOTA uses LeakyReLU(0.5)^2." + }, + { + "id": "q4", + "question": "What does the causal mask enforce in autoregressive attention?", + "options": [ + "Each token can only attend to itself", + "Each token can attend to all other tokens", + "Token at position i can only attend to positions 1 through i", + "Attention weights are uniformly distributed" + ], + "answer": 2, + "explanation": "The causal mask sets attention weights to 0 (via -infinity before softmax) for all future positions j > i. This ensures the model can only use past context for prediction." + }, + { + "id": "q5", + "question": "Why are residual connections important in deep transformers?", + "options": [ + "They reduce the number of parameters", + "They speed up inference", + "They create direct gradient paths preventing vanishing gradients", + "They improve the tokenizer's vocabulary coverage" + ], + "answer": 2, + "explanation": "With y = x + F(x), the gradient dy/dx = 1 + dF/dx always has a +1 component. This prevents gradients from vanishing through many layers, enabling training of deep networks." + } + ] + }, + "w3-scaling-laws": { + "weekId": 3, + "quizId": "w3-scaling-laws", + "title": "Week 3: Scaling Laws & L(N)", + "questions": [ + { + "id": "q1", + "question": "In Kaplan et al. scaling laws, L(N) represents:", + "options": [ + "Loss as a function of training time", + "Loss as a function of parameter count, following a power law", + "Loss as a function of learning rate", + "Loss as a function of batch size" + ], + "answer": 1, + "explanation": "L(N) = (N_c / N)^alpha_N describes how loss decreases as a power law with increasing parameter count N, when data and compute are unconstrained." + }, + { + "id": "q2", + "question": "What did the Chinchilla paper (Hoffmann et al.) demonstrate?", + "options": [ + "Bigger models are always better", + "For a fixed compute budget, parameters and data should be scaled roughly equally", + "Small models cannot achieve good perplexity", + "Quantization is necessary for efficient training" + ], + "answer": 1, + "explanation": "Chinchilla showed that previous practice (GPT-3) over-allocated to parameters. For compute-optimal training, N and D should scale proportionally: N ~ C^0.5, D ~ C^0.5." + }, + { + "id": "q3", + "question": "In Parameter Golf, why is the L(N) framing useful but incomplete?", + "options": [ + "Because the dataset is too small", + "Because the architecture is fixed", + "Because N is measured in compressed bytes, not raw parameters, and architecture/quantization are free", + "Because there is no compute constraint" + ], + "answer": 2, + "explanation": "Standard scaling laws assume standard transformers. In Parameter Golf, N is measured as compressed artifact size, and techniques like quantization, weight sharing, and better compression shift the effective L(N) curve." + }, + { + "id": "q4", + "question": "At fixed parameter count, which generally produces lower loss?", + "options": [ + "Wider model (fewer layers, larger hidden dim)", + "Deeper model (more layers, smaller hidden dim)", + "They are always equivalent", + "It depends entirely on the dataset" + ], + "answer": 1, + "explanation": "Research and the Parameter Golf leaderboard show that deeper models are generally more parameter-efficient. The leaderboard evolved from 9 layers (baseline) to 11 layers, trading width for depth." + }, + { + "id": "q5", + "question": "The 10-minute training constraint on 8xH100 means:", + "options": [ + "You can only train for 10 steps", + "Faster architectures can fit more training steps, creating a training efficiency tradeoff", + "All submissions train for the same number of steps", + "The constraint only applies to evaluation" + ], + "answer": 1, + "explanation": "At ~87ms/step (SOTA), you get ~6900 steps in 600 seconds. A technique that improves loss-per-step but doubles step time may be net negative because you only get half as many steps." + } + ] + }, + "w4-attention": { + "weekId": 4, + "quizId": "w4-attention", + "title": "Week 4: Attention Efficiency", + "questions": [ + { + "id": "q1", + "question": "In GQA with 8 query heads and 4 KV heads, how many query heads share each KV head?", + "options": [ + "1", + "2", + "4", + "8" + ], + "answer": 1, + "explanation": "With 8 query heads and 4 KV groups, each KV head is shared by 8/4 = 2 query heads. This halves the KV parameter cost compared to full multi-head attention." + }, + { + "id": "q2", + "question": "What is the key advantage of Flash Attention over standard attention?", + "options": [ + "It approximates attention for faster computation", + "It computes exact attention with O(T) memory instead of O(T^2) by tiling in SRAM", + "It reduces the number of attention heads needed", + "It eliminates the need for the causal mask" + ], + "answer": 1, + "explanation": "Flash Attention tiles the Q, K, V computation into blocks that fit in SRAM, computing partial softmax results without materializing the full T x T attention matrix in HBM. It is exact, not approximate." + }, + { + "id": "q3", + "question": "For the 512d / 8-head baseline with GQA-4, how many parameters are in the K and V projections per layer?", + "options": [ + "512 x 512 x 2 = 524,288", + "256 x 512 x 2 = 262,144", + "64 x 512 x 2 = 65,536", + "512 x 512 = 262,144" + ], + "answer": 1, + "explanation": "With 4 KV heads at head_dim=64, kv_dim = 4 x 64 = 256. So K and V each need 256 x 512 = 131,072 params. Total K+V = 262,144 params per layer." + }, + { + "id": "q4", + "question": "Flash Attention 3 specifically targets which GPU architecture?", + "options": [ + "NVIDIA Ampere (A100)", + "NVIDIA Hopper (H100) with warp specialization and TMA", + "NVIDIA Blackwell (B200)", + "AMD MI300X" + ], + "answer": 1, + "explanation": "Flash Attention 3 uses Hopper-specific features: warp specialization (producer-consumer warps), Tensor Memory Accelerator (TMA), and FP8 accumulation for maximum throughput on H100 GPUs." + } + ] + }, + "w7-tokenization": { + "weekId": 7, + "quizId": "w7-tokenization", + "title": "Week 7: Tokenization", + "questions": [ + { + "id": "q1", + "question": "With tied embeddings at d=512, how many parameters does a 1024-token vocabulary cost?", + "options": [ + "524,288 (512 x 1024)", + "1,048,576 (512 x 1024 x 2)", + "262,144 (256 x 1024)", + "2,097,152 (512 x 4096)" + ], + "answer": 0, + "explanation": "With tied embeddings, the embedding table is shared with the output projection: V x d = 1024 x 512 = 524,288 params. Without tying, it would be 2x." + }, + { + "id": "q2", + "question": "Why does the SOTA use a 1024-token vocabulary + BigramHash instead of a larger vocabulary?", + "options": [ + "Larger vocabularies are not supported by SentencePiece", + "Small vocab saves embedding parameters, BigramHash adds n-gram context cheaply via hashing", + "Larger vocabularies always produce worse BPB", + "The rules prohibit vocabularies larger than 1024" + ], + "answer": 1, + "explanation": "A 1024-token vocab costs only 524K params (3% of budget). BigramHash adds 344K params for 3072 bigram entries at 112d. Together they approximate a much larger vocab at a fraction of the parameter cost." + }, + { + "id": "q3", + "question": "What makes BPB tokenizer-agnostic?", + "options": [ + "It only counts ASCII bytes", + "It normalizes by bytes instead of tokens, so different vocab sizes produce comparable scores", + "It uses a fixed reference tokenizer for all submissions", + "It ignores the tokenizer entirely" + ], + "answer": 1, + "explanation": "BPB = (bits_per_token) x (tokens/bytes). A large vocab has fewer tokens but more bits per token. Normalizing by bytes cancels out the tokenizer's effect on token count." + } + ] + } +} diff --git a/curriculum/index.html b/curriculum/index.html new file mode 100644 index 0000000000..d87b3e7c24 --- /dev/null +++ b/curriculum/index.html @@ -0,0 +1,278 @@ + + + + + + Parameter Golf Curriculum + + + + + + + + + + + + + + + + + + + + +
+
+ Loading curriculum... +
+ + + + + + + + diff --git a/curriculum/js/dashboard.js b/curriculum/js/dashboard.js new file mode 100644 index 0000000000..7f704e55f5 --- /dev/null +++ b/curriculum/js/dashboard.js @@ -0,0 +1,195 @@ +// ============================================================ +// Dashboard View +// ============================================================ + +import * as DB from './db.js'; +import { UNITS, navigate } from './router.js'; + +// Unit color mapping +const UNIT_COLORS = { + violet: '#6c5ce7', + cyan: '#00cec9', + magenta: '#e84393', + amber: '#fdcb6e', + lime: '#00b894', + red: '#d63031', + blue: '#0984e3', + orange: '#e17055', +}; + +function progressRingSVG(percent, size = 48, stroke = 4, color = '#6c5ce7') { + const r = (size - stroke) / 2; + const circ = 2 * Math.PI * r; + const offset = circ - (percent / 100) * circ; + return ` + + + + + `; +} + +export async function renderDashboard(container) { + const overall = await DB.getOverallProgress(); + const today = new Date().toISOString().slice(0, 10); + const dueCards = await DB.getCardsForReview(today); + const allCards = await DB.getAllFlashcards(); + + // Compute per-unit progress + const unitProgress = []; + for (const unit of UNITS) { + const prog = await DB.getUnitProgress(unit.id); + unitProgress.push({ ...unit, ...prog }); + } + + const completedUnits = unitProgress.filter(u => u.percent === 100).length; + const inProgressUnits = unitProgress.filter(u => u.percent > 0 && u.percent < 100).length; + + container.innerHTML = ` +
+
+

Parameter Golf Curriculum

+
+ 9 Units + 16 Weeks + ${overall.percent}% Complete +
+
+ + +
+
+
${overall.completed}
+
Items Done
+
+
+
${completedUnits}
+
Units Complete
+
+
+
${inProgressUnits}
+
In Progress
+
+
+
${dueCards.length}
+
Cards Due
+
+
+ + +
+ \u{1F5FA} +

Unit Map

+
+
+ ${unitProgress.map(u => { + const color = UNIT_COLORS[u.color] || UNIT_COLORS.violet; + const statusClass = u.percent === 100 ? 'complete' : u.percent > 0 ? 'in-progress' : ''; + return ` +
+
${u.id}
+
${u.name}
+
${u.total > 0 ? `${u.percent}%` : 'Not started'}
+
+ `; + }).join('')} +
+ + +
+ \u{1F4CA} +

Overall Progress

+
+
+
+
+ ${progressRingSVG(overall.percent, 80, 6, '#6c5ce7')} + ${overall.percent}% +
+
+
+
+
+

+ ${overall.completed} of ${overall.total} items completed +

+
+
+
+ + +
+ \u{1F4CB} +

Unit Breakdown

+
+
+ + + + + + + + + + + ${unitProgress.map(u => { + const color = UNIT_COLORS[u.color] || UNIT_COLORS.violet; + return ` + + + + + + + `; + }).join('')} + +
UnitWeeksProgressDone
+ ${u.icon} + ${u.name} + + ${u.weeks.length > 1 ? `${u.weeks[0]}-${u.weeks[u.weeks.length - 1]}` : u.weeks[0]} + +
+
+
+
+ ${u.total > 0 ? `${u.completed}/${u.total}` : '--'} +
+
+ + ${allCards.length > 0 ? ` + +
+ \u{1F0CF} +

Flashcard Review

+
+
+
+
+

${dueCards.length} cards due today

+

${allCards.length} total cards in your deck

+
+ +
+
+ ` : ''} +
+ `; + + // Attach click handlers for hex cards and table rows + container.querySelectorAll('[data-unit]').forEach(el => { + el.addEventListener('click', () => { + const unitId = el.dataset.unit; + navigate(`#/unit/${unitId}`); + }); + }); + + // Flashcard card click + document.getElementById('flashcard-dash-card')?.addEventListener('click', () => { + navigate('#/flashcards'); + }); +} diff --git a/curriculum/js/db.js b/curriculum/js/db.js new file mode 100644 index 0000000000..e2eae43aae --- /dev/null +++ b/curriculum/js/db.js @@ -0,0 +1,329 @@ +// ============================================================ +// IndexedDB Abstraction Layer — Parameter Golf Curriculum +// ============================================================ + +const DB_NAME = 'PGolfCurriculum'; +const DB_VERSION = 1; + +let _db = null; + +function openDB() { + return new Promise((resolve, reject) => { + if (_db) { resolve(_db); return; } + const req = indexedDB.open(DB_NAME, DB_VERSION); + req.onupgradeneeded = (e) => { + const db = e.target.result; + + if (!db.objectStoreNames.contains('progress')) { + const s = db.createObjectStore('progress', { keyPath: 'id' }); + s.createIndex('byUnit', 'unitId', { unique: false }); + s.createIndex('byWeek', 'weekId', { unique: false }); + s.createIndex('byCompleted', 'completed', { unique: false }); + } + + if (!db.objectStoreNames.contains('exercises')) { + db.createObjectStore('exercises', { keyPath: 'exerciseId' }); + } + + if (!db.objectStoreNames.contains('quizScores')) { + const qs = db.createObjectStore('quizScores', { keyPath: 'id', autoIncrement: true }); + qs.createIndex('byQuiz', ['weekId', 'quizId'], { unique: false }); + } + + if (!db.objectStoreNames.contains('flashcards')) { + const fc = db.createObjectStore('flashcards', { keyPath: 'cardId' }); + fc.createIndex('byNextReview', 'nextReview', { unique: false }); + fc.createIndex('byUnit', 'unitId', { unique: false }); + } + + if (!db.objectStoreNames.contains('sessions')) { + const ss = db.createObjectStore('sessions', { keyPath: 'id', autoIncrement: true }); + ss.createIndex('byDate', 'startedAt', { unique: false }); + } + + if (!db.objectStoreNames.contains('settings')) { + db.createObjectStore('settings', { keyPath: 'key' }); + } + }; + req.onsuccess = () => { _db = req.result; resolve(_db); }; + req.onerror = () => reject(req.error); + }); +} + +// Generic helpers +function tx(storeName, mode = 'readonly') { + return _db.transaction(storeName, mode).objectStore(storeName); +} + +function reqP(request) { + return new Promise((resolve, reject) => { + request.onsuccess = () => resolve(request.result); + request.onerror = () => reject(request.error); + }); +} + +function txP(storeName, mode, fn) { + return new Promise((resolve, reject) => { + const transaction = _db.transaction(storeName, mode); + const store = transaction.objectStore(storeName); + fn(store); + transaction.oncomplete = () => resolve(); + transaction.onerror = () => reject(transaction.error); + }); +} + +// ============================================================ +// Progress +// ============================================================ + +function progressId(unitId, weekId, itemType, itemId) { + return `${unitId}-${weekId}-${itemType}-${itemId}`; +} + +export async function setItemComplete(unitId, weekId, itemType, itemId, completed) { + await openDB(); + const id = progressId(unitId, weekId, itemType, itemId); + const store = tx('progress', 'readwrite'); + const existing = await reqP(store.get(id)); + const record = existing || { id, unitId, weekId, itemType, itemId, notes: '', timeSpentMs: 0 }; + record.completed = completed; + record.completedAt = completed ? new Date().toISOString() : null; + await reqP(tx('progress', 'readwrite').put(record)); +} + +export async function getItemProgress(unitId, weekId, itemType, itemId) { + await openDB(); + return reqP(tx('progress').get(progressId(unitId, weekId, itemType, itemId))); +} + +export async function saveNote(unitId, weekId, itemType, itemId, notes) { + await openDB(); + const id = progressId(unitId, weekId, itemType, itemId); + const store = tx('progress', 'readwrite'); + const existing = await reqP(store.get(id)); + const record = existing || { id, unitId, weekId, itemType, itemId, completed: false, completedAt: null, timeSpentMs: 0 }; + record.notes = notes; + await reqP(tx('progress', 'readwrite').put(record)); +} + +export async function getWeekProgress(unitId, weekId) { + await openDB(); + const all = await getAllFromIndex('progress', 'byUnit', unitId); + const weekItems = all.filter(r => r.weekId === weekId); + const completed = weekItems.filter(r => r.completed).length; + return { completed, total: weekItems.length, percent: weekItems.length ? Math.round((completed / weekItems.length) * 100) : 0 }; +} + +export async function getUnitProgress(unitId) { + await openDB(); + const all = await getAllFromIndex('progress', 'byUnit', unitId); + const completed = all.filter(r => r.completed).length; + return { completed, total: all.length, percent: all.length ? Math.round((completed / all.length) * 100) : 0 }; +} + +export async function getOverallProgress() { + await openDB(); + const all = await getAllFromStore('progress'); + const completed = all.filter(r => r.completed).length; + return { completed, total: all.length, percent: all.length ? Math.round((completed / all.length) * 100) : 0 }; +} + +// ============================================================ +// Exercises +// ============================================================ + +export async function saveExerciseCode(exerciseId, code, language = 'python') { + await openDB(); + const store = tx('exercises', 'readwrite'); + const existing = await reqP(store.get(exerciseId)); + const record = existing || { exerciseId, unitId: 0, weekId: 0, completed: false, attempts: 0 }; + record.code = code; + record.language = language; + record.lastModified = new Date().toISOString(); + await reqP(tx('exercises', 'readwrite').put(record)); +} + +export async function getExerciseCode(exerciseId) { + await openDB(); + return reqP(tx('exercises').get(exerciseId)); +} + +export async function setExerciseComplete(exerciseId, completed) { + await openDB(); + const store = tx('exercises', 'readwrite'); + const existing = await reqP(store.get(exerciseId)); + if (existing) { + existing.completed = completed; + await reqP(tx('exercises', 'readwrite').put(existing)); + } +} + +// ============================================================ +// Quiz Scores +// ============================================================ + +export async function saveQuizAttempt(weekId, quizId, score, maxScore, answers) { + await openDB(); + const store = tx('quizScores', 'readwrite'); + await reqP(store.add({ + weekId, quizId, score, maxScore, answers, + takenAt: new Date().toISOString(), + timeSpentMs: 0 + })); +} + +export async function getQuizHistory(weekId, quizId) { + await openDB(); + const all = await getAllFromIndex('quizScores', 'byQuiz', [weekId, quizId]); + return all.sort((a, b) => b.takenAt.localeCompare(a.takenAt)); +} + +// ============================================================ +// Flashcards (SM-2) +// ============================================================ + +export async function initFlashcards(cards) { + await openDB(); + const store = tx('flashcards', 'readwrite'); + for (const card of cards) { + const existing = await reqP(store.get(card.cardId)); + if (!existing) { + await reqP(tx('flashcards', 'readwrite').put({ + ...card, + easeFactor: 2.5, + interval: 0, + repetitions: 0, + nextReview: new Date().toISOString().slice(0, 10), + lastReviewed: null + })); + } + } +} + +export async function getCardsForReview(today) { + await openDB(); + const all = await getAllFromStore('flashcards'); + return all.filter(c => c.nextReview <= today); +} + +export async function updateCardReview(cardId, quality) { + await openDB(); + const store = tx('flashcards', 'readwrite'); + const card = await reqP(store.get(cardId)); + if (!card) return; + + // SM-2 algorithm + if (quality < 3) { + card.repetitions = 0; + card.interval = 0; + } else { + if (card.repetitions === 0) card.interval = 1; + else if (card.repetitions === 1) card.interval = 6; + else card.interval = Math.round(card.interval * card.easeFactor); + card.repetitions++; + } + + card.easeFactor = Math.max(1.3, card.easeFactor + (0.1 - (5 - quality) * (0.08 + (5 - quality) * 0.02))); + const next = new Date(); + next.setDate(next.getDate() + card.interval); + card.nextReview = next.toISOString().slice(0, 10); + card.lastReviewed = new Date().toISOString(); + + await reqP(tx('flashcards', 'readwrite').put(card)); +} + +export async function getAllFlashcards() { + await openDB(); + return getAllFromStore('flashcards'); +} + +// ============================================================ +// Sessions +// ============================================================ + +let _currentSession = null; + +export async function startSession(unitId, weekId) { + await openDB(); + _currentSession = { unitId, weekId, startedAt: new Date().toISOString(), activeTimeMs: 0 }; +} + +export async function endSession() { + if (!_currentSession) return; + await openDB(); + _currentSession.endedAt = new Date().toISOString(); + const store = tx('sessions', 'readwrite'); + await reqP(store.add(_currentSession)); + _currentSession = null; +} + +export async function getTotalStudyTime() { + await openDB(); + const all = await getAllFromStore('sessions'); + return all.reduce((sum, s) => sum + (s.activeTimeMs || 0), 0); +} + +// ============================================================ +// Settings +// ============================================================ + +export async function setSetting(key, value) { + await openDB(); + await reqP(tx('settings', 'readwrite').put({ key, value })); +} + +export async function getSetting(key, defaultValue = null) { + await openDB(); + const result = await reqP(tx('settings').get(key)); + return result ? result.value : defaultValue; +} + +// ============================================================ +// Export / Import / Reset +// ============================================================ + +export async function exportAll() { + await openDB(); + const data = {}; + const storeNames = ['progress', 'exercises', 'quizScores', 'flashcards', 'sessions', 'settings']; + for (const name of storeNames) { + data[name] = await getAllFromStore(name); + } + return data; +} + +export async function importAll(data) { + await openDB(); + for (const [name, records] of Object.entries(data)) { + if (!_db.objectStoreNames.contains(name)) continue; + await txP(name, 'readwrite', (store) => { + store.clear(); + for (const record of records) { + store.put(record); + } + }); + } +} + +export async function clearAll() { + await openDB(); + const storeNames = ['progress', 'exercises', 'quizScores', 'flashcards', 'sessions', 'settings']; + for (const name of storeNames) { + await txP(name, 'readwrite', (store) => store.clear()); + } +} + +// ============================================================ +// Internal helpers +// ============================================================ + +function getAllFromStore(storeName) { + return reqP(tx(storeName).getAll()); +} + +function getAllFromIndex(storeName, indexName, key) { + return reqP(tx(storeName).index(indexName).getAll(key)); +} + +// Initialize DB on import +export const ready = openDB(); diff --git a/curriculum/js/progress.js b/curriculum/js/progress.js new file mode 100644 index 0000000000..8a9f0a3fd5 --- /dev/null +++ b/curriculum/js/progress.js @@ -0,0 +1,119 @@ +// ============================================================ +// Progress Computation & Unit Hydration +// ============================================================ + +import * as DB from './db.js'; +import { renderQuiz } from './quiz-engine.js'; + +// Hydrate a loaded unit: attach checkbox handlers, load saved state, notes +export async function hydrateUnit(unitId) { + // Hydrate all check-items (topics, readings, exercises) + const checkItems = document.querySelectorAll('.check-item'); + for (const item of checkItems) { + const checkbox = item.querySelector('input[type="checkbox"]'); + if (!checkbox) continue; + + const weekId = parseInt(item.closest('[data-week]')?.dataset.week || '0', 10); + const itemType = item.dataset.type || 'topic'; + const itemId = item.dataset.id || ''; + + // Load saved state + const progress = await DB.getItemProgress(unitId, weekId, itemType, itemId); + if (progress?.completed) { + checkbox.checked = true; + item.classList.add('completed'); + } + + // Attach change handler + checkbox.addEventListener('change', async () => { + const checked = checkbox.checked; + await DB.setItemComplete(unitId, weekId, itemType, itemId, checked); + item.classList.toggle('completed', checked); + updateSidebarProgress(); + }); + } + + // Hydrate all note areas + const noteAreas = document.querySelectorAll('.note-area'); + for (const area of noteAreas) { + const weekId = parseInt(area.closest('[data-week]')?.dataset.week || '0', 10); + const itemId = area.dataset.noteKey || ''; + const itemType = area.dataset.noteType || 'note'; + + // Load saved note + const progress = await DB.getItemProgress(unitId, weekId, itemType, itemId); + if (progress?.notes) { + area.textContent = progress.notes; + } + + // Auto-save on blur + let saveTimer; + area.addEventListener('input', () => { + clearTimeout(saveTimer); + saveTimer = setTimeout(async () => { + await DB.saveNote(unitId, weekId, itemType, itemId, area.textContent); + }, 500); + }); + } + + // Hydrate collapsible sections + const collapsibles = document.querySelectorAll('.collapsible'); + for (const c of collapsibles) { + const header = c.querySelector('.collapsible-header'); + if (header) { + header.addEventListener('click', () => { + c.classList.toggle('open'); + }); + } + } + + // Hydrate quiz containers + const quizContainers = document.querySelectorAll('[data-quiz]'); + for (const container of quizContainers) { + const quizId = container.dataset.quiz; + if (quizId) await renderQuiz(container, quizId); + } + + // Initialize flashcards for this unit + try { + const resp = await fetch('data/flashcards.json'); + const cards = await resp.json(); + const unitCards = cards.filter(c => c.unitId === unitId); + if (unitCards.length > 0) { + await DB.initFlashcards(unitCards); + } + } catch (e) { + // Flashcard data not available yet + } + + // Render KaTeX math + if (window.renderMathInElement) { + renderMathInElement(document.getElementById('content'), { + delimiters: [ + { left: '$$', right: '$$', display: true }, + { left: '$', right: '$', display: false }, + ], + throwOnError: false + }); + } +} + +// Update sidebar progress badges +export async function updateSidebarProgress() { + for (let unitId = 1; unitId <= 9; unitId++) { + const badge = document.querySelector(`[data-nav="unit-${unitId}"] .nav-badge`); + if (!badge) continue; + + const progress = await DB.getUnitProgress(unitId); + if (progress.total === 0) { + badge.textContent = '--'; + badge.closest('.nav-item')?.classList.remove('complete', 'in-progress'); + } else { + badge.textContent = `${progress.percent}%`; + const navItem = badge.closest('.nav-item'); + navItem?.classList.remove('complete', 'in-progress'); + if (progress.percent === 100) navItem?.classList.add('complete'); + else if (progress.percent > 0) navItem?.classList.add('in-progress'); + } + } +} diff --git a/curriculum/js/quiz-engine.js b/curriculum/js/quiz-engine.js new file mode 100644 index 0000000000..9d7d9e8d4b --- /dev/null +++ b/curriculum/js/quiz-engine.js @@ -0,0 +1,122 @@ +// ============================================================ +// Quiz Engine — Rendering, scoring, and persistence +// ============================================================ + +import * as DB from './db.js'; + +let _quizData = null; + +async function loadQuizData() { + if (_quizData) return _quizData; + try { + const resp = await fetch('data/quizzes.json'); + _quizData = await resp.json(); + } catch (e) { + _quizData = {}; + } + return _quizData; +} + +export async function renderQuiz(container, quizId) { + const data = await loadQuizData(); + const quiz = data[quizId]; + if (!quiz) { + container.innerHTML = '

Quiz not available yet.

'; + return; + } + + // Check for previous attempts + const history = await DB.getQuizHistory(quiz.weekId, quiz.quizId); + const bestScore = history.length > 0 ? Math.max(...history.map(h => h.score)) : null; + + container.innerHTML = ` +
+
+

${quiz.title}

+ ${bestScore !== null ? `Best: ${bestScore}/${quiz.questions.length}` : ''} +
+
+ ${quiz.questions.map((q, i) => ` +
+

+ ${i + 1}. ${q.question} +

+
+ ${q.options.map((opt, j) => ` + + `).join('')} +
+ +
+ `).join('')} +
+ + +
+ `; + + // Submit handler + document.getElementById(`quiz-submit-${quizId}`)?.addEventListener('click', async () => { + const answers = {}; + let score = 0; + + quiz.questions.forEach((q, i) => { + const selected = document.querySelector(`input[name="quiz-${quizId}-q${i}"]:checked`); + const selectedIdx = selected ? parseInt(selected.value, 10) : -1; + answers[q.id] = selectedIdx; + + const feedback = document.querySelector(`.quiz-feedback[data-qidx="${i}"]`); + const options = document.querySelectorAll(`label[data-qidx="${i}"]`); + + if (selectedIdx === q.answer) { + score++; + if (feedback) { + feedback.style.display = 'block'; + feedback.style.background = 'rgba(0, 184, 148, 0.1)'; + feedback.style.color = 'var(--pg-lime)'; + feedback.textContent = 'Correct! ' + q.explanation; + } + options.forEach((opt, j) => { + if (j === q.answer) opt.style.borderColor = 'var(--pg-lime)'; + }); + } else { + if (feedback) { + feedback.style.display = 'block'; + feedback.style.background = 'rgba(214, 48, 49, 0.1)'; + feedback.style.color = 'var(--pg-red-soft)'; + feedback.textContent = (selectedIdx === -1 ? 'Not answered. ' : 'Incorrect. ') + q.explanation; + } + options.forEach((opt, j) => { + if (j === q.answer) opt.style.borderColor = 'var(--pg-lime)'; + if (j === selectedIdx) opt.style.borderColor = 'var(--pg-red)'; + }); + } + }); + + // Save attempt + await DB.saveQuizAttempt(quiz.weekId, quiz.quizId, score, quiz.questions.length, answers); + + // Show result + const resultEl = document.getElementById(`quiz-result-${quizId}`); + if (resultEl) { + resultEl.style.display = 'block'; + const pct = Math.round((score / quiz.questions.length) * 100); + const color = pct >= 80 ? 'var(--pg-lime)' : pct >= 60 ? 'var(--pg-amber)' : 'var(--pg-red)'; + resultEl.innerHTML = ` +
+ ${score} / ${quiz.questions.length} (${pct}%) +
+ `; + } + + // Disable submit + const btn = document.getElementById(`quiz-submit-${quizId}`); + if (btn) { btn.disabled = true; btn.textContent = 'Submitted'; btn.style.opacity = '0.5'; } + + // Disable radio buttons + document.querySelectorAll(`input[name^="quiz-${quizId}"]`).forEach(r => r.disabled = true); + }); +} diff --git a/curriculum/js/router.js b/curriculum/js/router.js new file mode 100644 index 0000000000..368821f8bd --- /dev/null +++ b/curriculum/js/router.js @@ -0,0 +1,332 @@ +// ============================================================ +// Hash-based Client-side Router +// ============================================================ + +import * as DB from './db.js'; +import { renderDashboard } from './dashboard.js'; +import { hydrateUnit } from './progress.js'; + +const contentEl = () => document.getElementById('content'); +const breadcrumbEl = () => document.getElementById('breadcrumb'); + +// Unit metadata +export const UNITS = [ + { id: 1, name: 'Foundations of Language Modeling', weeks: [1, 2], color: 'violet', icon: '\u{1F4D6}' }, + { id: 2, name: 'Scaling Laws & L(N)', weeks: [3], color: 'cyan', icon: '\u{1F4CA}' }, + { id: 3, name: 'Efficient Architectures', weeks: [4, 5, 6], color: 'magenta', icon: '\u{1F9E9}' }, + { id: 4, name: 'Tokenization', weeks: [7], color: 'amber', icon: '\u{1F524}' }, + { id: 5, name: 'Optimization', weeks: [8, 9, 10], color: 'lime', icon: '\u{26A1}' }, + { id: 6, name: 'Quantization & Compression', weeks: [11, 12, 13], color: 'red', icon: '\u{1F5DC}' }, + { id: 7, name: 'Evaluation Methods', weeks: [14], color: 'blue', icon: '\u{1F3AF}' }, + { id: 8, name: 'Systems & Performance', weeks: [15], color: 'orange', icon: '\u{1F680}' }, + { id: 9, name: 'Integration & Strategy', weeks: [16], color: 'violet', icon: '\u{1F3C6}' }, +]; + +// Unit content cache +const _cache = {}; + +// Parse current hash into route object +export function parseRoute() { + const hash = window.location.hash.slice(1) || '/dashboard'; + const parts = hash.split('/').filter(Boolean); + + if (parts[0] === 'unit' && parts[1]) { + const unitId = parseInt(parts[1], 10); + const weekId = parts[2] === 'week' && parts[3] ? parseInt(parts[3], 10) : null; + return { view: 'unit', unitId, weekId }; + } + if (parts[0] === 'flashcards') return { view: 'flashcards' }; + if (parts[0] === 'settings') return { view: 'settings' }; + return { view: 'dashboard' }; +} + +// Navigate to a hash route +export function navigate(hash) { + window.location.hash = hash; +} + +// Set breadcrumb +function setBreadcrumb(parts) { + const el = breadcrumbEl(); + if (!el) return; + el.innerHTML = parts.map((p, i) => { + if (i < parts.length - 1) { + return `${p.label}/`; + } + return `${p.label}`; + }).join(''); +} + +// Update sidebar active state +function updateSidebarActive(route) { + document.querySelectorAll('.nav-item').forEach(el => el.classList.remove('active')); + document.querySelectorAll('.nav-week').forEach(el => el.classList.remove('active')); + document.querySelectorAll('.nav-weeks').forEach(el => el.classList.remove('expanded')); + + if (route.view === 'dashboard') { + document.querySelector('[data-nav="dashboard"]')?.classList.add('active'); + } else if (route.view === 'flashcards') { + document.querySelector('[data-nav="flashcards"]')?.classList.add('active'); + } else if (route.view === 'settings') { + document.querySelector('[data-nav="settings"]')?.classList.add('active'); + } else if (route.view === 'unit') { + const unitNav = document.querySelector(`[data-nav="unit-${route.unitId}"]`); + if (unitNav) { + unitNav.classList.add('active'); + const weeks = unitNav.nextElementSibling; + if (weeks) weeks.classList.add('expanded'); + } + if (route.weekId) { + document.querySelector(`[data-nav="week-${route.weekId}"]`)?.classList.add('active'); + } + } +} + +// Load unit content +async function loadUnit(unitId, weekId) { + const unit = UNITS.find(u => u.id === unitId); + if (!unit) { + contentEl().innerHTML = '
?

Unit not found

'; + return; + } + + const crumbs = [{ label: 'Dashboard', href: '#/dashboard' }]; + crumbs.push({ label: `Unit ${unitId}`, href: `#/unit/${unitId}` }); + if (weekId) crumbs.push({ label: `Week ${weekId}`, href: `#/unit/${unitId}/week/${weekId}` }); + setBreadcrumb(crumbs); + + // Try loading the unit HTML fragment + if (!_cache[unitId]) { + try { + const resp = await fetch(`units/unit-${unitId}.html`); + if (resp.ok) { + _cache[unitId] = await resp.text(); + } + } catch (e) { + // File not available + } + } + + if (_cache[unitId]) { + contentEl().innerHTML = _cache[unitId]; + await hydrateUnit(unitId); + // Scroll to week if specified + if (weekId) { + const weekEl = document.querySelector(`[data-week="${weekId}"]`); + if (weekEl) { + weekEl.scrollIntoView({ behavior: 'smooth', block: 'start' }); + } + } + } else { + // Coming soon placeholder + const weekList = unit.weeks.map(w => `
  • Week ${w}
  • `).join(''); + contentEl().innerHTML = ` +
    +
    ${unit.icon}
    +

    Unit ${unitId}: ${unit.name}

    +

    This unit is coming soon.

    +
    +

    Covers:

    + +
    +
    + `; + } +} + +// Load flashcards view +async function loadFlashcards() { + setBreadcrumb([ + { label: 'Dashboard', href: '#/dashboard' }, + { label: 'Flashcards', href: '#/flashcards' } + ]); + + const today = new Date().toISOString().slice(0, 10); + const due = await DB.getCardsForReview(today); + const all = await DB.getAllFlashcards(); + + if (all.length === 0) { + contentEl().innerHTML = ` +
    +
    \u{1F0CF}
    +

    Flashcards

    +

    No flashcards yet. Complete Unit 1 to unlock your first deck.

    +
    + `; + return; + } + + contentEl().innerHTML = ` +
    +
    +

    Flashcard Review

    +
    + ${due.length} due today + ${all.length} total cards +
    +
    +
    + ${due.length > 0 ? renderFlashcard(due, 0) : '

    No cards due for review today. Check back tomorrow.

    '} +
    +
    + `; + + if (due.length > 0) { + initFlashcardReview(due); + } +} + +function renderFlashcard(cards, idx) { + const card = cards[idx]; + return ` +
    +

    Card ${idx + 1} / ${cards.length}

    +
    ${card.front}
    + +

    Click to reveal

    + +
    + `; +} + +function initFlashcardReview(cards) { + let idx = 0; + let revealed = false; + + const area = document.getElementById('flashcard-area'); + if (!area) return; + + area.addEventListener('click', async (e) => { + const card = cards[idx]; + const quality = e.target.dataset?.quality; + + if (quality) { + await DB.updateCardReview(card.cardId, parseInt(quality, 10)); + idx++; + if (idx >= cards.length) { + area.innerHTML = '

    Review complete!

    All cards reviewed for today.

    '; + return; + } + revealed = false; + area.innerHTML = renderFlashcard(cards, idx); + return; + } + + if (!revealed) { + const back = document.getElementById('fc-back'); + const hint = document.getElementById('fc-hint'); + const rating = document.getElementById('fc-rating'); + if (back) back.style.display = 'block'; + if (hint) hint.style.display = 'none'; + if (rating) { rating.style.display = 'flex'; rating.style.justifyContent = 'center'; } + revealed = true; + } + }); +} + +// Load settings view +async function loadSettings() { + setBreadcrumb([ + { label: 'Dashboard', href: '#/dashboard' }, + { label: 'Settings', href: '#/settings' } + ]); + + const overall = await DB.getOverallProgress(); + contentEl().innerHTML = ` +
    +
    +

    Settings

    +
    +
    +
    +

    Data

    +

    ${overall.completed} items completed, ${overall.total} total tracked

    +
    + + + +
    + +
    +
    +
    + `; + + document.getElementById('btn-export')?.addEventListener('click', async () => { + const data = await DB.exportAll(); + const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' }); + const a = document.createElement('a'); + a.href = URL.createObjectURL(blob); + a.download = `pgolf-curriculum-backup-${new Date().toISOString().slice(0, 10)}.json`; + a.click(); + }); + + document.getElementById('btn-import')?.addEventListener('click', () => { + document.getElementById('import-file')?.click(); + }); + + document.getElementById('import-file')?.addEventListener('change', async (e) => { + const file = e.target.files[0]; + if (!file) return; + const text = await file.text(); + const data = JSON.parse(text); + await DB.importAll(data); + alert('Data imported successfully. Reloading...'); + window.location.reload(); + }); + + document.getElementById('btn-reset')?.addEventListener('click', async () => { + if (confirm('This will permanently delete all your progress, notes, and saved code. Continue?')) { + await DB.clearAll(); + alert('All data has been reset. Reloading...'); + window.location.reload(); + } + }); +} + +// Main route handler +async function handleRoute() { + const route = parseRoute(); + updateSidebarActive(route); + await DB.setSetting('lastVisited', window.location.hash); + + const scroll = document.querySelector('.content-scroll'); + if (scroll) scroll.scrollTop = 0; + + switch (route.view) { + case 'dashboard': + setBreadcrumb([{ label: 'Dashboard', href: '#/dashboard' }]); + await renderDashboard(contentEl()); + break; + case 'unit': + await loadUnit(route.unitId, route.weekId); + break; + case 'flashcards': + await loadFlashcards(); + break; + case 'settings': + await loadSettings(); + break; + default: + setBreadcrumb([{ label: 'Dashboard', href: '#/dashboard' }]); + await renderDashboard(contentEl()); + } +} + +// Initialize router +export async function initRouter() { + await DB.ready; + + // Restore last visited route if no hash + if (!window.location.hash) { + const last = await DB.getSetting('lastVisited', '#/dashboard'); + window.location.hash = last; + } + + window.addEventListener('hashchange', handleRoute); + await handleRoute(); +} diff --git a/curriculum/units/unit-1.html b/curriculum/units/unit-1.html new file mode 100644 index 0000000000..b4608c61ac --- /dev/null +++ b/curriculum/units/unit-1.html @@ -0,0 +1,480 @@ +
    + + +
    +

    Unit 1: Foundations of Language Modeling

    +
    + Weeks 1-2 + ~20 hours + Prerequisite: None +
    +

    + The mathematical and architectural foundations that every technique in this competition builds on. + Information theory tells us why compression equals prediction. The transformer is the machine that does it. +

    +
    + + +
    +
    + \u{25B6} + Week 1: Statistical Language Models and Information Theory +
    + + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Language modeling as next-token prediction
    +
    +
    + \u{25B8} + Details +
    +
    +

    + A language model assigns probabilities to sequences of tokens. Given a context + $x_1, x_2, \ldots, x_{t-1}$, the model predicts a distribution over the next token $x_t$. + The chain rule of probability decomposes any sequence probability as: +

    +

    + $$P(x_1, \ldots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1})$$ +

    +

    + This is the autoregressive factorization. The entire Parameter Golf challenge reduces to: + build the best next-token predictor that fits in 16MB. +

    +
    +
    +
    +
    + +
    + +
    +
    Cross-entropy loss, perplexity, and bits-per-byte (BPB)
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Cross-entropy measures how well our model $q$ approximates the true distribution $p$: +

    +

    + $$H(p, q) = -\sum_{x} p(x) \log q(x)$$ +

    +

    + In practice, we compute the average negative log-likelihood over tokens (in nats when using $\ln$). +

    +

    + Perplexity = $e^{H(p,q)}$ (or $2^{H}$ in bits). Lower is better. +

    +

    + Bits-per-byte (BPB) is the competition metric. It converts token-level loss to byte-level compression: +

    +

    + $$\text{BPB} = \frac{\text{val\_loss}}{\ln 2} \times \frac{\text{tokens}}{\text{bytes}}$$ +

    +

    + This makes the metric tokenizer-agnostic: a model with a 1024-token vocabulary and one with 32K tokens + are compared on the same byte-compression scale. +

    +
    +
    +
    +
    + +
    + +
    +
    Shannon entropy, source coding theorem, and compression = prediction
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Shannon entropy $H(X) = -\sum p(x) \log_2 p(x)$ is the theoretical minimum bits needed to encode a source. + The source coding theorem says you cannot compress below $H(X)$ bits per symbol on average. +

    +

    + The deep insight: optimal compression and optimal prediction are the same problem. + A model that perfectly predicts the next token achieves the best possible compression rate. + This is why the Parameter Golf metric (BPB) directly measures compression quality. +

    +
    +
    +
    +
    + +
    + +
    +
    KL divergence and mutual information
    +
    +
    + \u{25B8} + Details +
    +
    +

    + KL divergence measures the extra bits needed when using $q$ instead of $p$: +

    +

    + $$D_{KL}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)} = H(p, q) - H(p) \geq 0$$ +

    +

    + Since $H(p)$ is fixed, minimizing cross-entropy $H(p,q)$ is equivalent to minimizing $D_{KL}(p \| q)$. + Our training loss is literally measuring how far our model is from the true distribution. +

    +
    +
    +
    +
    + +
    + +
    Why BPB is the right metric for tokenizer-agnostic evaluation
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Shannon, "A Mathematical Theory of Communication" (1948), Sections I-III
    +
    + +
    + +
    Jurafsky & Martin, Speech and Language Processing, Ch. 3
    +
    + +
    + +
    MacKay, Information Theory, Inference and Learning Algorithms, Ch. 1-6
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    +
    Implement a character-level n-gram model and compute BPB on a text corpus
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Start with bigrams (n=2). Count all character pairs in a training text, normalize to get probabilities, + then compute $\text{BPB} = -\frac{1}{N} \sum \log_2 P(c_t | c_{t-1})$ on held-out text. + Use add-1 (Laplace) smoothing for unseen pairs. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Derive the relationship between cross-entropy loss, perplexity, and bits-per-byte
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Start from the definition of cross-entropy in nats. Convert to bits (divide by $\ln 2$). + Perplexity is $e^{H}$ in nats or $2^{H}$ in bits. For BPB, you need to account for the + tokens-to-bytes ratio of the tokenizer. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Prove that cross-entropy is minimized when the model distribution equals the true distribution
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Use the Gibbs' inequality: $H(p, q) = H(p) + D_{KL}(p \| q) \geq H(p)$, with equality + iff $q = p$ everywhere. Since $D_{KL} \geq 0$ (prove via Jensen's inequality on $\log$). +

    +
    +
    +
    +
    +
    + + +
    + \u{2606} +

    Week 1 Quiz

    +
    +
    + + +
    + \u{270E} +

    Week 1 Notes

    +
    +
    +
    + + +
    +
    + \u{25B6} + Week 2: The Transformer Architecture +
    + + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Self-attention: queries, keys, values, scaled dot-product attention
    +
    +
    + \u{25B8} + Details +
    +
    +

    + The core operation of the transformer. Input $X \in \mathbb{R}^{T \times d}$ is projected into + queries, keys, and values: +

    +

    + $$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$ +

    +

    + $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$ +

    +

    + The $\sqrt{d_k}$ scaling prevents the dot products from growing too large, keeping softmax + in a well-behaved regime. In the Parameter Golf baseline, $d_k = d_\text{model} / n_\text{heads} = 64$. +

    +
    +
    +
    +
    + +
    + +
    Multi-head attention and why it works (subspace decomposition)
    +
    + +
    + +
    +
    Position encodings: sinusoidal, learned, RoPE
    +
    +
    + \u{25B8} + Details +
    +
    +

    + RoPE (Rotary Position Embedding) is used in the Parameter Golf baseline. + It encodes position by rotating query and key vectors in 2D subspaces: +

    +

    + $$q_m = R_\Theta(m) \cdot W_q x_m, \quad k_n = R_\Theta(n) \cdot W_k x_n$$ +

    +

    + where $R_\Theta(m)$ applies rotation by angle $m \cdot \theta_i$ to the $i$-th dimension pair. + The attention score $q_m^\top k_n$ depends only on the relative position $m - n$. +

    +

    + The SOTA uses Partial RoPE: only 16 of 64 head dimensions get position encoding. + The other 48 dimensions attend purely by content similarity. +

    +
    +
    +
    +
    + +
    + +
    Layer normalization variants: LayerNorm, RMSNorm, pre-norm vs post-norm
    +
    + +
    + +
    +
    Feed-forward networks: expansion factor, activation functions (ReLU, GELU, SwiGLU, ReLU^2)
    +
    +
    + \u{25B8} + Details +
    +
    +

    + The baseline uses ReLU^2: $\text{FFN}(x) = W_2 \cdot (\text{ReLU}(W_1 x))^2$. + Squaring after ReLU creates a smoother, sparser activation pattern. + The SOTA uses LeakyReLU(0.5)^2 which allows small negative gradients to flow. +

    +

    + The MLP expansion factor in the baseline is 2x (hidden = 2 * model_dim). + The SOTA uses 3x (hidden = 3 * 512 = 1536) for more capacity per layer. +

    +
    +
    +
    +
    + +
    + +
    Residual connections and signal propagation in deep networks
    +
    + +
    + +
    Autoregressive generation and causal masking
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Vaswani et al., "Attention Is All You Need" (2017)
    +
    + +
    + +
    Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
    +
    + +
    + +
    Zhang & Sennrich, "Root Mean Square Layer Normalization" (2019)
    +
    + +
    + +
    Shazeer, "GLU Variants Improve Transformer" (2020)
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    +
    Implement a transformer decoder from scratch in PyTorch (no nn.Transformer)
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Build it layer by layer: Embedding -> PositionalEncoding -> + N x (MultiHeadAttention + FeedForward + LayerNorm + residuals) -> + Linear head. Use the causal mask in scaled_dot_product_attention. + Reference the baseline train_gpt.py for a clean implementation to compare against. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Implement RoPE and verify it produces correct relative position attention patterns
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Compare your implementation against the baseline's Rotary class and + apply_rotary_emb function. Verify that $q_m^\top k_n$ only depends on $m-n$ + by computing attention scores and checking the Toeplitz structure. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Train a small (1-layer, 64-dim) language model on a toy corpus and verify convergence
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Use a small text file (Shakespeare, Wikipedia excerpt). Tokenize with a character-level or + tiny BPE vocabulary. Train for ~1000 steps with Adam. Plot the loss curve. + The model should overfit the training set if it's small enough. That's fine, it proves your + architecture works. +

    +
    +
    +
    +
    +
    + + +
    + \u{2606} +

    Week 2 Quiz

    +
    +
    + + +
    + \u{270E} +

    Week 2 Notes

    +
    +
    +
    + +
    diff --git a/curriculum/units/unit-2.html b/curriculum/units/unit-2.html new file mode 100644 index 0000000000..dacfb62bfa --- /dev/null +++ b/curriculum/units/unit-2.html @@ -0,0 +1,274 @@ +
    + + +
    +

    Unit 2: Scaling Laws & L(N) Optimization

    +
    + Week 3 + ~10 hours + Prereq: Unit 1 +
    +

    + Scaling laws quantify how model performance improves with more parameters, data, and compute. + Parameter Golf is an L(N) optimization problem: minimize loss at a fixed parameter count. + Understanding these laws tells you where to spend your 16MB budget. +

    +
    + + +
    +
    + \u{25B6} + Week 3: Neural Scaling Laws and L(N) Optimization +
    + + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Kaplan et al. scaling laws: L(N), L(D), L(C) and their power-law relationships
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Kaplan et al. (2020) found that language model loss follows power laws: +

    +

    + $$L(N) = \left(\frac{N_c}{N}\right)^{\alpha_N}, \quad L(D) = \left(\frac{D_c}{D}\right)^{\alpha_D}, \quad L(C) = \left(\frac{C_c}{C}\right)^{\alpha_C}$$ +

    +

    + where $N$ = parameters, $D$ = dataset tokens, $C$ = compute (FLOPs). + The exponents are approximately $\alpha_N \approx 0.076$, $\alpha_D \approx 0.095$, $\alpha_C \approx 0.050$. + These are remarkably smooth and predictable across many orders of magnitude. +

    +

    + The key insight: loss is predictable from scale. You can extrapolate from small runs to predict large run performance. +

    +
    +
    +
    +
    + +
    + +
    +
    Chinchilla optimal: compute-optimal allocation between parameters and data
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Hoffmann et al. (2022) showed that for a fixed compute budget $C$, you should scale + parameters $N$ and data $D$ roughly equally: $N \propto C^{0.5}$ and $D \propto C^{0.5}$. + Previous practice (GPT-3) over-allocated to parameters and under-allocated to data. +

    +

    + For Parameter Golf, Chinchilla is relevant but not directly applicable: we're constrained on $N$ (16MB artifact), + not on $C$ or $D$. We want to overfit on compute to maximize what we extract from a fixed $N$. +

    +
    +
    +
    +
    + +
    + +
    +
    The Parameter Golf objective as L(N) optimization: minimize loss given fixed N
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Standard scaling laws assume you're training a standard transformer with optimal hyperparameters. + Parameter Golf breaks this assumption: the architecture is free, quantization is free, evaluation tricks are free. + You're optimizing a modified L(N) where N is measured in compressed bytes, not raw parameters. +

    +

    + This means techniques that improve the effective parameter count (weight sharing, quantization-aware training, better compression) + shift the L(N) curve leftward: getting the same loss with fewer bytes. +

    +
    +
    +
    +
    + +
    + +
    +
    Depth vs width tradeoffs: why deeper is more parameter-efficient (to a point)
    +
    +
    + \u{25B8} + Details +
    +
    +

    + At fixed parameter count, deeper models (more layers) generally outperform wider models (larger hidden dim). + This is because depth allows iterative refinement of representations, while width provides + more capacity per layer but with diminishing returns. +

    +

    + The Parameter Golf leaderboard confirms this: submissions evolved from 9 layers (baseline) to 11 layers, + even though adding layers costs parameters. The efficiency gain from deeper processing outweighs the parameter cost. +

    +

    + The extreme case: depth recurrence (weight sharing) takes this to its logical conclusion. + If depth is efficient, run the same layers multiple times and spend the freed parameters elsewhere. +

    +
    +
    +
    +
    + +
    + +
    Implications: at fixed N, how do depth, width, and architecture affect loss?
    +
    + +
    + +
    +
    The role of the 10-minute training constraint as a soft compute bound
    +
    +
    + \u{25B8} + Details +
    +
    +

    + While Parameter Golf is framed as L(N), the 10-minute 8xH100 cap introduces a + practical compute constraint. At ~87ms/step (SOTA), you get ~6900 steps. This means: +

    +
      +
    • Faster architectures can train more steps = better convergence
    • +
    • Training efficiency matters (Parallel Muon, Flash Attention 3)
    • +
    • The tradeoff: a technique that improves loss-per-step but slows step time may net negative
    • +
    • Evaluation is separate (10 min cap), so eval-time tricks have their own budget
    • +
    +
    +
    +
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Kaplan et al., "Scaling Laws for Neural Language Models" (2020)
    +
    + +
    + +
    Hoffmann et al., "Training Compute-Optimal Large Language Models" (Chinchilla, 2022)
    +
    + +
    + +
    Tay et al., "Scale Efficiently: Insights from Pre-training and Fine-tuning Transformers" (2022)
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    +
    Fit power-law curves to training runs at 3-4 different model sizes and predict loss at a target size
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Train the baseline at different configs (e.g., 3L/256d, 5L/384d, 7L/448d, 9L/512d). + Record final val_loss for each. Plot on a log-log scale (log N vs log L). + Fit $L = a \cdot N^{-b}$ using least squares on the log-log data. + Use scipy.optimize.curve_fit or manual linear regression on log values. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Experimentally determine: at 16MB, is it better to have 6 layers at 768d or 12 layers at 512d?
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Calculate exact parameter counts for both configs (including attention, MLP, embeddings, norms). + Verify both fit under 16MB after int8+zlib quantization. + Train each for the same wall-clock time and compare final BPB. + The result should show that depth wins at this scale, confirming the leaderboard trend. +

    +
    +
    +
    +
    +
    + +
    + +
    +
    Categorize each leaderboard submission's primary contribution axis (quantization, architecture, training, evaluation)
    +
    +
    + \u{25B8} + Hints & workspace +
    +
    +

    + Read each submission's README. For each, identify: what was the primary source of BPB improvement? + Common axes: quantization (int6, GPTQ, QAT), architecture (XSA, SmearGate, BigramHash, MLP width), + training (Muon, EMA/SWA, longer warmdown), evaluation (sliding window, TTT). + Look for patterns: which axis has diminishing returns? Which is underexplored? +

    +
    +
    +
    +
    +
    + + +
    + \u{2606} +

    Week 3 Quiz

    +
    +
    + + +
    + \u{270E} +

    Week 3 Notes

    +
    +
    +
    + +
    diff --git a/curriculum/units/unit-3.html b/curriculum/units/unit-3.html new file mode 100644 index 0000000000..420c169e66 --- /dev/null +++ b/curriculum/units/unit-3.html @@ -0,0 +1,506 @@ +
    + + +
    +

    Unit 3: Efficient Architectures

    +
    + Weeks 4-6 + ~30 hours + Prereq: Units 1-2 +
    +

    + The heart of the challenge: architectural techniques that extract maximum performance from a fixed parameter budget. + From attention efficiency (GQA, Flash Attention) to parameter sharing (depth recurrence, MoE) + to specialized modules invented for this competition (BigramHash, XSA, SmearGate). +

    +
    + + +
    +
    + \u{25B6} + Week 4: Grouped Query Attention, Multi-Query Attention, and KV-Cache Efficiency +
    + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Multi-query attention (Shazeer 2019): shared K/V heads
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Standard multi-head attention: each head has its own Q, K, V projections. + Multi-query attention (MQA): all heads share a single K and V projection. + This reduces KV parameters by a factor of $h$ (number of heads), saving significant parameters + in the attention block. +

    +

    + Parameter savings at 512d, 8 heads: standard attention K/V = $2 \times 512 \times 512 = 524K$ params. + MQA K/V = $2 \times 64 \times 512 = 65K$ params. That's an 8x reduction in KV parameters. +

    +
    +
    +
    +
    + +
    + +
    +
    Grouped query attention (GQA): interpolating between MHA and MQA
    +
    +
    + \u{25B8} + Details +
    +
    +

    + GQA uses $g$ KV head groups where $1 \leq g \leq h$. Each group of $h/g$ query heads + shares a single K/V head. The baseline uses $h=8, g=4$ (GQA-4), so every 2 query heads + share one KV head. This is the sweet spot: nearly the quality of full MHA at half the KV parameter cost. +

    +
    +
    +
    +
    + +
    + +
    Parameter cost analysis: how GQA reduces attention parameters
    +
    + +
    + +
    +
    Flash Attention: tiling, memory-efficient backward, IO complexity
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Standard attention materializes the $T \times T$ attention matrix in HBM, costing $O(T^2)$ memory. + Flash Attention tiles the computation: it loads blocks of Q, K, V into SRAM, computes partial + softmax outputs, and accumulates results without ever writing the full attention matrix. +

    +

    + Result: exact attention (no approximation), $O(T)$ memory, and ~2-4x faster due to reduced HBM IO. + Flash Attention 3 adds Hopper-specific optimizations (warp specialization, TMA, FP8 accumulation). +

    +
    +
    +
    +
    + +
    + +
    Flash Attention 2 and 3: Hopper-specific optimizations
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" (2019)
    +
    +
    + +
    Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models" (2023)
    +
    +
    + +
    Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022)
    +
    +
    + +
    Dao, "FlashAttention-2: Faster Attention with Better Parallelism" (2023)
    +
    +
    + +
    Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024)
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    +
    Implement GQA from scratch and verify it matches standard MHA when num_kv_heads = num_heads
    +
    +
    + \u{25B8} + Hints +
    +
    +

    + Key: when num_kv_heads < num_heads, repeat each KV head to match query heads using + repeat_interleave or reshape tricks. Pass enable_gqa=True to + F.scaled_dot_product_attention for the efficient path. +

    +
    +
    +
    +
    +
    + +
    + +
    Profile memory usage of standard attention vs Flash Attention at sequence length 2048
    +
    + +
    + +
    Calculate exact parameter savings from GQA at various head ratios for the 512d/8-head baseline
    +
    + + +
    + \u{2606} +

    Week 4 Quiz

    +
    +
    + +
    +
    + + +
    +
    + \u{25B6} + Week 5: Parameter-Efficient Architecture Variants +
    + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Depth recurrence and the Universal Transformer (Dehghani et al.)
    +
    +
    + \u{25B8} + Details +
    +
    +

    + The Universal Transformer applies the same transformer block repeatedly instead + of having unique weights per layer. With $K$ shared blocks run $N$ times each, + you get $K \times N$ effective depth for the parameter cost of $K$ layers. +

    +

    + Key challenges: +

    +
      +
    • Per-iteration conditioning: each pass needs to be distinct. Solutions: layer index embeddings, FiLM conditioning, learned gates
    • +
    • Training stability: gradients through many iterations can explode. Solutions: per-iteration RMSNorm, gradient clipping, careful LR
    • +
    • Wall-clock cost: more forward passes = slower steps = fewer total steps in 600s
    • +
    • Adaptive Computation Time (ACT): different tokens may need different iteration counts
    • +
    +

    + This is our primary research direction for the competition. +

    +
    +
    +
    +
    + +
    + +
    Weight sharing across layers: per-layer conditioning strategies
    +
    + +
    + +
    Adaptive Computation Time (ACT): variable iteration counts per token
    +
    + +
    + +
    +
    Mixture of Experts: sparse routing, top-k gating, load balancing
    +
    +
    + \u{25B8} + Details +
    +
    +

    + MoE replaces the dense MLP with $E$ expert MLPs and a gating network that routes each token + to its top-$k$ experts. Total parameters = $E \times$ MLP params, but active parameters per token = $k \times$ MLP params. + In Parameter Golf, MoE is tricky: all expert weights count toward the 16MB limit, but only $k$ are active. + The benefit is more total capacity per token, but the compression tax is steep. +

    +
    +
    +
    +
    + +
    + +
    Low-rank factorization: LoRA and its variants
    +
    + +
    + +
    +
    U-Net / encoder-decoder skip connections in transformers
    +
    +
    + \u{25B8} + Details +
    +
    +

    + The baseline splits layers into encoder (first half) and decoder (second half). + Encoder outputs are stored and added to corresponding decoder layers via learned skip weights: +

    +
    # In GPT.forward():
    +for i in range(num_encoder_layers):
    +    x = blocks[i](x, x0)
    +    skips.append(x)
    +for i in range(num_decoder_layers):
    +    x = x + skip_weights[i] * skips.pop()
    +    x = blocks[encoder_layers + i](x, x0)
    +

    + This gives the decoder access to both high-level (late encoder) and low-level (early encoder) features. +

    +
    +
    +
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Dehghani et al., "Universal Transformers" (2019)
    +
    +
    + +
    Fedus et al., "Switch Transformers: Scaling to Trillion Parameter Models" (2022)
    +
    +
    + +
    Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models" (2021)
    +
    +
    + +
    Bao et al., "All Are Worth Words: A ViT Backbone for Diffusion Models" (2023)
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    +
    Implement a weight-shared transformer (3 blocks x 4 iterations) with layer index conditioning
    +
    +
    + \u{25B8} + Hints +
    +
    +

    + Start from the baseline GPT class. Replace nn.ModuleList([Block() for _ in range(N)]) + with nn.ModuleList([Block() for _ in range(K)]) and loop each block $N/K$ times. + Add a learned nn.Embedding(total_iterations, model_dim) that's added to the input + at each iteration. Compare loss against a standard 12-layer model with the same total parameter count. +

    +
    +
    +
    +
    +
    + +
    + +
    Compare parameter count and loss: 12-layer standard vs 3-block x 4-iteration recurrent at same width
    +
    + +
    + +
    Implement and ablate different per-layer conditioning: additive embedding, multiplicative gate, FiLM
    +
    + +
    +
    + + +
    +
    + \u{25B6} + Week 6: Specialized Modules for Small Models +
    + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    BigramHash embeddings: hash-based n-gram features for small vocabularies
    +
    +
    + \u{25B8} + Details +
    +
    +

    + With a 1024-token vocabulary, the embedding table is small (1024 x 512 = 524K params). + BigramHash adds n-gram context by hashing adjacent token pairs into a separate embedding table: +

    +
    hash(t[i-1], t[i]) = (36313 * t[i] ^ 27191 * t[i-1]) % bigram_vocab_size
    +

    + The SOTA uses 3072 bigram entries at 112 dimensions, adding ~344K params. + The embeddings are zero-initialized so they don't disrupt early training. + A learned scale parameter controls their contribution. +

    +
    +
    +
    +
    + +
    + +
    +
    SmearGate: temporal smoothing via per-dimension gating
    +
    +
    + \u{25B8} + Details +
    +
    +
    gate = sigmoid(learned_gate)  # per-dimension, init to ~0
    +output = (1 - gate) * x + gate * x_prev
    +

    + Creates a learnable temporal smoothing between adjacent positions. + Initialized near zero (pass-through) so it doesn't disrupt initial training. + The model learns which dimensions benefit from position mixing. +

    +
    +
    +
    +
    + +
    + +
    Value Embeddings (VE): re-injecting token identity at deep layers
    +
    + +
    + +
    +
    Cross-Sequence Attention (XSA): removing self-value projection
    +
    +
    + \u{25B8} + Details +
    +
    +

    + XSA removes each token's self-value projection from its attention output. + This prevents the model from simply reading its own embedding and forces it to + attend meaningfully to other tokens in the context: +

    +
    # Subtract self-value projection
    +v_norm = F.normalize(v, dim=-1)
    +self_proj = (y * v_norm).sum(-1, keepdim=True) * v_norm
    +y = y - self_proj
    +

    + Zero parameter cost. Applied to all 11 layers in the SOTA. +

    +
    +
    +
    +
    + +
    + +
    Partial RoPE: applying position encoding to a subset of dimensions (16/64)
    +
    + +
    + +
    Logit soft-capping: tanh-based logit range control
    +
    + + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Parameter Golf PR #162: BigramHash (Raahil Shah)
    +
    +
    + +
    Parameter Golf PR #65: SmearGate (aquariouseworkman)
    +
    +
    + +
    Parameter Golf PR #478: XSA / Cross-Sequence Attention (gowtham0992)
    +
    +
    + +
    Parameter Golf PR #315: Partial RoPE + LN Scale (jfprincz)
    +
    +
    + +
    Parameter Golf PR #374: VE128 / Value Embeddings (unnir)
    +
    + + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    Implement BigramHash with configurable vocabulary size and embedding dimension
    +
    +
    + +
    Ablation study: add/remove each module individually and measure BPB delta
    +
    +
    + +
    Analyze BigramHash collision rates at vocabulary sizes 1024, 2048, 3072, 4096
    +
    + +
    +
    + +
    diff --git a/curriculum/units/unit-4.html b/curriculum/units/unit-4.html new file mode 100644 index 0000000000..42753e60f8 --- /dev/null +++ b/curriculum/units/unit-4.html @@ -0,0 +1,156 @@ +
    + +
    +

    Unit 4: Tokenization

    +
    + Week 7 + ~10 hours + Prereq: Units 1-2 +
    +

    + The tokenizer determines the fundamental unit of prediction. In a parameter-constrained setting, + vocabulary size directly trades off against model capacity. Understanding this tradeoff is critical. +

    +
    + +
    +
    + \u{25B6} + Week 7: Tokenizers and Their Impact on BPB +
    + +
    + \u{25CF} +

    Topics

    +
    + +
    + +
    +
    Byte Pair Encoding (BPE): algorithm, vocabulary construction, merge rules
    +
    +
    + \u{25B8} + Details +
    +
    +

    + BPE starts with individual bytes/characters and iteratively merges the most frequent pair + into a new token. After $V$ merges, you have a vocabulary of size $V + |\text{base}|$. + Each merge rule is deterministic: given the same text, the same tokenization results. + The baseline uses a 1024-token SentencePiece BPE vocabulary trained on FineWeb. +

    +
    +
    +
    +
    + +
    + +
    SentencePiece: unigram model vs BPE mode
    +
    + +
    + +
    +
    The tokenizer-agnostic BPB metric: how token-level loss converts to byte-level compression
    +
    +
    + \u{25B8} + Details +
    +
    +

    + $\text{BPB} = \frac{\text{bits\_per\_token}}{\text{bytes\_per\_token}} = \frac{\text{val\_loss} / \ln 2}{\text{bytes} / \text{tokens}}$ +

    +

    + A 1024-token vocab has ~2.5 bytes/token. A 32K vocab might have ~4.5 bytes/token. + The larger vocab model has lower bits/token (more info per token) but also higher bytes/token. + BPB normalizes this, making comparison fair. +

    +
    +
    +
    +
    + +
    + +
    +
    Why vocabulary size matters in parameter-constrained settings
    +
    +
    + \u{25B8} + Details +
    +
    +

    + Embedding table cost with tied embeddings: $V \times d$ params. + At $d = 512$: 1024 vocab = 524K params (3% of budget), 8192 vocab = 4.2M params (25% of budget). + Larger vocab captures more per token but leaves less room for the transformer body. + The SOTA uses 1024 + BigramHash(3072) as a compromise: tiny base vocab with learned bigram features. +

    +
    +
    +
    +
    + +
    + +
    Tied vs untied embeddings: parameter cost analysis
    +
    + +
    + +
    The BigramHash approach as a middle ground: small vocab + learned n-gram features
    +
    + +
    + +
    Byte-level tokenization and its tradeoffs
    +
    + +
    + \u{25A0} +

    Readings

    +
    + +
    + +
    Sennrich et al., "Neural Machine Translation of Rare Words with Subword Units" (BPE, 2016)
    +
    +
    + +
    Kudo, "Subword Regularization" (2018)
    +
    +
    + +
    Kudo & Richardson, "SentencePiece: A simple and language independent subword tokenizer" (2018)
    +
    + +
    + \u{25B2} +

    Exercises

    +
    + +
    + +
    Train BPE tokenizers at vocab sizes 512, 1024, 2048, 4096, 8192 on FineWeb
    +
    +
    + +
    For each vocab size: compute tokens-per-byte ratio and estimate embedding parameter cost at 512d
    +
    +
    + +
    Find the vocab size that minimizes total model BPB under a 16MB budget constraint
    +
    + +
    + \u{2606} +

    Week 7 Quiz

    +
    +
    + +
    +
    +
    diff --git a/curriculum/units/unit-5.html b/curriculum/units/unit-5.html new file mode 100644 index 0000000000..ec093533f3 --- /dev/null +++ b/curriculum/units/unit-5.html @@ -0,0 +1,132 @@ +
    + +
    +

    Unit 5: Optimization

    +
    + Weeks 8-10 + ~30 hours + Prereq: Units 1-3 +
    +

    + How you train matters as much as what you train. The Muon optimizer, distributed training with + Parallel Muon, and weight averaging (EMA/SWA) are critical components of the SOTA stack. +

    +
    + + +
    +
    \u{25B6} Week 8: Optimizers for Small Model Training
    + +
    \u{25CF}

    Topics

    + +
    Adam and AdamW: momentum, adaptive learning rates, weight decay
    +
    + +
    +
    The Muon optimizer: Newton-Schulz orthogonalization
    +
    +
    \u{25B8}Details
    +
    +

    + Muon applies Newton-Schulz iteration to orthogonalize gradient matrices before applying them. + The iteration uses fixed coefficients $(a, b, c) = (3.4445, -4.7750, 2.0315)$: +

    +
    X = G / ||G||
    +for _ in range(5):
    +    A = X @ X.T
    +    B = b*A + c*(A@A)
    +    X = a*X + B@X
    +

    + This converges to the orthogonal factor of $G$, effectively normalizing the gradient matrix. + Muon is used for all matrix-shaped parameters (attention, MLP weights). + Scalar/vector params and embeddings use Adam. +

    +
    +
    +
    +
    +
    Optimizer partitioning: Muon for matrices, Adam for scalars, separate LRs for embeddings
    +
    Learning rate schedules: warmup, cosine decay, wallclock-aware warmdown
    +
    Gradient clipping: global norm clipping, when and why
    + +
    \u{25A0}

    Readings

    +
    Kingma & Ba, "Adam: A Method for Stochastic Optimization" (2015)
    +
    Loshchilov & Hutter, "Decoupled Weight Decay Regularization" (AdamW, 2019)
    +
    Jordan, "Muon: An optimizer for hidden layers in neural networks" (2024)
    + +
    \u{25B2}

    Exercises

    +
    Implement Muon from scratch, including the Newton-Schulz iteration
    +
    Compare Adam vs Muon on the baseline: plot train loss curves over 1000 steps
    +
    Implement wallclock-aware warmdown and verify it adapts to variable step times
    + +
    +
    + + +
    +
    \u{25B6} Week 9: Distributed Training and Parallel Optimization
    + +
    \u{25CF}

    Topics

    +
    Data parallelism and DistributedDataParallel (DDP) in PyTorch
    +
    Gradient accumulation: simulating larger batch sizes
    +
    All-reduce, reduce-scatter, all-gather: collective communication primitives
    +
    NVLink and the communication topology of 8xH100 SXM
    +
    + +
    +
    The Parallel Muon strategy: Parameter Banking + overlapped communication
    +
    +
    \u{25B8}Details
    +
    +

    + Parameter Banking stores all attention/MLP weights as contiguous 3D tensors (e.g., qo_bank[2*L, d, d]). + This enables batched Newton-Schulz and overlapped communication: +

    +
      +
    1. Launch async reduce-scatter on bank gradients (largest first)
    2. +
    3. While waiting: run Adam steps on embeddings/scalars
    4. +
    5. Wait for reduce-scatter, apply Newton-Schulz on shards
    6. +
    7. Launch async all-gather to broadcast updated weights
    8. +
    +
    +
    +
    +
    +
    Scaling batch size: critical batch size, gradient noise scale
    + +
    \u{25A0}

    Readings

    +
    Li et al., "PyTorch Distributed: Experiences on Accelerating Data Parallel Training" (2020)
    +
    McCandlish et al., "An Empirical Model of Large-Batch Training" (2018)
    + +
    \u{25B2}

    Exercises

    +
    Profile a training step on 1 GPU vs 8 GPUs: measure communication overhead
    +
    Implement Parameter Banking: reshape layer weights into 3D bank tensors
    +
    Implement async reduce-scatter + all-gather with overlap and measure throughput
    + +
    +
    + + +
    +
    \u{25B6} Week 10: Weight Averaging and Ensemble Methods
    + +
    \u{25CF}

    Topics

    +
    Exponential Moving Average (EMA): decay rate selection, Polyak averaging
    +
    Stochastic Weight Averaging (SWA): snapshot collection in late training
    +
    LAWA (Latest-k Weight Average): averaging last K snapshots
    +
    Combining EMA and SWA: the SOTA uses EMA(0.997) + SWA(every 50 steps)
    +
    Connection to flat minima and generalization
    + +
    \u{25A0}

    Readings

    +
    Polyak & Juditsky, "Acceleration of Stochastic Approximation by Averaging" (1992)
    +
    Izmailov et al., "Averaging Weights Leads to Wider Optima and Better Generalization" (2018)
    +
    Kaddour et al., "Stop Wasting My Time! Saving Days of Training with LAWA" (2022)
    + +
    \u{25B2}

    Exercises

    +
    Implement EMA with configurable decay and compare final BPB: EMA vs no EMA
    +
    Implement SWA triggered by lr_scale threshold and measure improvement
    +
    Experiment: what EMA decay rate is optimal for the Parameter Golf training duration?
    + +
    +
    +
    diff --git a/curriculum/units/unit-6.html b/curriculum/units/unit-6.html new file mode 100644 index 0000000000..2eeaafc216 --- /dev/null +++ b/curriculum/units/unit-6.html @@ -0,0 +1,143 @@ +
    + +
    +

    Unit 6: Quantization & Compression

    +
    + Weeks 11-13 + ~30 hours + Prereq: Units 1-3, 5 +
    +

    + The bridge between training and the 16MB artifact limit. Post-training quantization (int8 to int6), + GPTQ with Hessian-informed rounding, quantization-aware training, and compression algorithms. + This unit is where bytes are won and lost. +

    +
    + + +
    +
    \u{25B6} Week 11: Post-Training Quantization Fundamentals
    + +
    \u{25CF}

    Topics

    + +
    Fixed-point number representation: int8, int6, int5, int4
    +
    + +
    +
    Per-tensor vs per-row vs per-channel quantization
    +
    +
    \u{25B8}Details
    +
    +

    + Per-tensor: one scale for the entire weight matrix. Cheapest metadata but worst accuracy. + Per-row: one scale per row (output channel). The baseline uses this for 2D tensors. + Per-channel: one scale per column (input channel). More metadata but better for asymmetric distributions. +

    +

    + The SOTA uses per-row int6 for large matrices and per-tensor int8 for embeddings. + Scale metadata: fp16 per row = 2 bytes/row. For a 512x512 matrix: 1024 bytes of scale data. +

    +
    +
    +
    +
    +
    Symmetric vs asymmetric quantization
    +
    Calibration: min-max, percentile clipping, MSE-optimal
    +
    The quantization error budget: how rounding errors accumulate through layers
    +
    Round-to-nearest vs more sophisticated rounding strategies
    + +
    \u{25A0}

    Readings

    +
    Jacob et al., "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (2018)
    +
    Nagel et al., "A White Paper on Neural Network Quantization" (Qualcomm, 2021)
    +
    Gholami et al., "A Survey of Quantization Methods for Efficient Neural Network Inference" (2021)
    + +
    \u{25B2}

    Exercises

    +
    Implement per-row int8 quantization with percentile clipping (reproduce the baseline)
    +
    Measure BPB degradation from int8, int6, int5, int4 uniform quantization
    +
    Plot reconstruction MSE vs bits-per-weight for different clipping percentiles
    + +
    +
    + + +
    +
    \u{25B6} Week 12: Advanced Quantization (GPTQ, QAT)
    + +
    \u{25CF}

    Topics

    + +
    + +
    +
    GPTQ: Hessian-informed quantization with Cholesky error compensation
    +
    +
    \u{25B8}Details
    +
    +

    + GPTQ collects Hessians $H = X^TX$ from calibration data, then quantizes column by column. + For each column, the quantization error is redistributed to unquantized columns using the inverse Hessian. + Columns with large $H_{ii}^{-1}$ (important columns) get better precision. + Block-wise processing (128 columns at a time) makes it efficient. +

    +
    +
    +
    +
    +
    GPTQ-lite: diagonal Hessian approximation vs full Hessian
    +
    + +
    +
    Autoregressive self-generated calibration: the model generates its own GPTQ data
    +
    +
    \u{25B8}Details
    +
    +

    + The SOTA generates 64 sequences of 2048 tokens at temperature 0.8 as GPTQ calibration data. + No validation or training data is accessed during quantization. This is legal under the rules + and provides representative activations for Hessian collection. +

    +
    +
    +
    +
    +
    Quantization-Aware Training (QAT): Straight-Through Estimator (STE)
    +
    Late QAT: enabling STE only when lr_scale < 0.15
    +
    Mixed precision: different bit-widths for different layers/tensor types
    + +
    \u{25A0}

    Readings

    +
    Frantar et al., "GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers" (2023)
    +
    Nagel et al., "Up or Down? Adaptive Rounding for Post-Training Quantization" (2020)
    +
    Bengio et al., "Estimating or Propagating Gradients Through Stochastic Neurons" (STE, 2013)
    + +
    \u{25B2}

    Exercises

    +
    Implement GPTQ with diagonal Hessian (GPTQ-lite) and compare vs uniform quantization
    +
    Implement full Hessian GPTQ with Cholesky error compensation
    +
    Implement late QAT with STE in CastedLinear when lr_scale < threshold
    +
    Implement AR self-generated calibration: generate sequences from trained model for GPTQ
    + +
    +
    + + +
    +
    \u{25B6} Week 13: Compression and Artifact Size Optimization
    + +
    \u{25CF}

    Topics

    +
    Entropy coding: Huffman, arithmetic coding, ANS
    +
    General-purpose compression: zlib, zstd, lzma and why lzma wins for quantized weights
    +
    The 16MB artifact budget: code bytes + compressed model bytes strategies
    +
    Selective pruning: removing low-impact quantized values by reconstruction error
    +
    Ternary quantization: {-1, 0, +1} weights with learned scales
    +
    1-bit quantization: binary weights
    + +
    \u{25A0}

    Readings

    +
    Zhu et al., "Trained Ternary Quantization" (2017)
    +
    Rastegari et al., "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks" (2016)
    + +
    \u{25B2}

    Exercises

    +
    Compare zlib-9, zstd-22, and lzma-9 compression ratios on a quantized checkpoint
    +
    Implement selective pruning: sort values by reconstruction error, prune to hit size target
    +
    Compute: at int6 with lzma-9, how many raw params fit in 16MB? What about int4?
    + +
    +
    +
    diff --git a/curriculum/units/unit-7.html b/curriculum/units/unit-7.html new file mode 100644 index 0000000000..30521a6d2c --- /dev/null +++ b/curriculum/units/unit-7.html @@ -0,0 +1,72 @@ +
    + +
    +

    Unit 7: Evaluation Methods

    +
    + Week 14 + ~10 hours + Prereq: Units 1-3 +
    +

    + How you evaluate matters. Sliding window evaluation, test-time training, and long-context extrapolation + can improve BPB without changing the model at all. The rules allow 10 minutes for evaluation, a separate budget from training. +

    +
    + +
    +
    \u{25B6} Week 14: Evaluation Strategies and Test-Time Compute
    + +
    \u{25CF}

    Topics

    + +
    Standard autoregressive evaluation: fixed context, sequential BPB
    +
    + +
    +
    Sliding window evaluation: stride selection, scoring "new" tokens only
    +
    +
    \u{25B8}Details
    +
    +

    + Instead of evaluating in non-overlapping seq_len chunks, slide the window by stride < seq_len. + Each token gets scored with the maximum available context (up to seq_len preceding tokens). + Only score the "new" tokens in each window to avoid double-counting. Stride=64 is common. + Tradeoff: smaller stride = more passes = better BPB but more eval compute. +

    +
    +
    +
    +
    +
    + +
    +
    Test-Time Training (TTT): adapting model on previously-evaluated tokens
    +
    +
    \u{25B8}Details
    +
    +

    + Legal TTT in Parameter Golf: you can only train on tokens you've already scored. + Score-first approach: evaluate a chunk, record losses, then fine-tune (e.g., LoRA rank-4) + on those same tokens. The next chunk benefits from the adapted model. + The SOTA dropped TTT because it was neutral on their stack, but it showed promise in earlier submissions. +

    +
    +
    +
    +
    +
    Long-context evaluation: NTK-aware RoPE scaling, YaRN, position extrapolation
    +
    Evaluation time budget: fitting within 10 minutes on 8xH100
    + +
    \u{25A0}

    Readings

    +
    Press et al., "Train Short, Test Long: ALiBi" (2022)
    +
    Sun et al., "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" (2024)
    +
    Bloc97, "NTK-Aware Scaled RoPE" (2023)
    + +
    \u{25B2}

    Exercises

    +
    Implement sliding window evaluation with configurable stride
    +
    Measure BPB improvement from sliding window at strides 32, 64, 128, 256
    +
    Implement minimal LoRA TTT loop: fine-tune rank-4 LoRA on evaluated tokens
    +
    Profile eval time: how many sliding window passes fit in 10 minutes on 1xH100?
    + +
    +
    +
    diff --git a/curriculum/units/unit-8.html b/curriculum/units/unit-8.html new file mode 100644 index 0000000000..69072e6d3c --- /dev/null +++ b/curriculum/units/unit-8.html @@ -0,0 +1,63 @@ +
    + +
    +

    Unit 8: Systems & Performance

    +
    + Week 15 + ~10 hours + Prereq: Units 5, 9 +
    +

    + Every millisecond per step matters: faster steps = more training in 600 seconds = lower loss. + GPU architecture, kernel fusion, memory optimization, and profiling are the systems-level skills + that turn a good algorithm into a competitive submission. +

    +
    + +
    +
    \u{25B6} Week 15: GPU Programming, Kernels, and Training Throughput
    + +
    \u{25CF}

    Topics

    + +
    GPU architecture: SMs, warps, memory hierarchy (registers, shared memory, L2, HBM)
    +
    + +
    +
    The H100 Hopper architecture: TMA, warp specialization, FP8 tensor cores
    +
    +
    \u{25B8}Details
    +
    +

    + H100 SXM: 80GB HBM3 at 3.35 TB/s, 132 SMs, 989 TFLOPS BF16 tensor core. + Key Hopper features for Parameter Golf: +

    +
      +
    • TMA: Tensor Memory Accelerator handles async data movement, freeing SMs for compute
    • +
    • Warp specialization: producer warps load data while consumer warps compute (Flash Attention 3 uses this)
    • +
    • FP8 tensor cores: 2x throughput vs BF16 for matmuls
    • +
    • NVLink 4.0: 900 GB/s bidirectional between GPUs in 8xH100 SXM
    • +
    +
    +
    +
    +
    +
    torch.compile: tracing, graph breaks, fullgraph mode
    +
    Profiling training runs: torch.profiler, nsight systems, identifying bottlenecks
    +
    Memory optimization: activation checkpointing, mixed precision, Flash Attention memory savings
    +
    Maximizing tokens/second: batch size tuning, compiled models, fused kernels, async data loading
    +
    Custom CUDA kernels and Triton: fused operations, megakernels
    + +
    \u{25A0}

    Readings

    +
    NVIDIA, "H100 Tensor Core GPU Architecture" whitepaper (2022)
    +
    Ansel et al., "PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation" (2024)
    +
    Tillet et al., "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations" (2019)
    + +
    \u{25B2}

    Exercises

    +
    Profile a training step with torch.profiler: identify the top 5 time-consuming operations
    +
    Implement activation checkpointing and measure memory savings
    +
    Write a Triton kernel for fused RMSNorm and benchmark against F.rms_norm
    +
    Find the maximum batch size that fits in 80GB H100 HBM for the baseline model
    + +
    +
    +
    diff --git a/curriculum/units/unit-9.html b/curriculum/units/unit-9.html new file mode 100644 index 0000000000..8c38c74c5c --- /dev/null +++ b/curriculum/units/unit-9.html @@ -0,0 +1,89 @@ +
    + +
    +

    Unit 9: Integration & Competition Strategy

    +
    + Week 16 + ~15 hours + Capstone +
    +

    + Everything comes together. Read the SOTA line by line, identify remaining headroom, + design experiments, validate with statistical rigor, and prepare a submission. +

    +
    + +
    +
    \u{25B6} Week 16: Putting It All Together
    + +
    \u{25CF}

    Topics

    + +
    The SOTA stack walkthrough: reading and understanding the #1 submission line by line
    +
    + +
    +
    Contribution axes and diminishing returns analysis
    +
    +
    \u{25B8}Details
    +
    +

    + Where is the most remaining headroom? +

    +
      +
    • Quantization: int6 GPTQ is near optimal. int5/int4 have diminishing returns.
    • +
    • Architecture: Depth recurrence, state-space models are unexplored.
    • +
    • Training efficiency: Parallel Muon is already highly optimized.
    • +
    • Evaluation: Sliding window + TTT have room, but SOTA dropped TTT.
    • +
    +
    +
    +
    +
    +
    + +
    +
    Experiment design: ablation methodology, statistical significance
    +
    +
    \u{25B8}Details
    +
    +

    + The submission process requires: +

    +
      +
    • Beat SOTA by at least 0.005 nats
    • +
    • Provide enough runs to show this at p < 0.01
    • +
    • Typically: 3-seed validation with Welch's t-test
    • +
    • Must reproduce on 8xH100 SXM in under 10 minutes
    • +
    +
    +
    +
    +
    +
    Submission process: records folder structure, logging, artifact generation
    +
    Risk management: compute cost estimation, iterating cheaply before scaling
    + +
    \u{25A0}

    Readings

    +
    PR #1019: Current SOTA (AR Self-Gen GPTQ + XSA + BigramHash)
    +
    PR #549: LeakyReLU^2 + Legal TTT + Parallel Muon
    +
    PR #414: Earlier SOTA stack
    +
    All README files in records/track_10min_16mb/
    +
    Parameter Golf FAQ and rules (challenge page + repo README)
    + +
    \u{25B2}

    Capstone Project

    + +
    +

    Final Submission

    +

    + Starting from the current SOTA codebase, implement one novel improvement. This is the culmination + of everything in the curriculum. +

    + +
    Implement one novel improvement on the SOTA stack
    +
    Run 3-seed validation on 8xH100
    +
    Prepare complete submission: README.md, submission.json, train_gpt.py, train logs
    +
    Present results with statistical analysis (mean, std, Welch's t-test)
    +
    + +
    +
    +
    diff --git a/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/README.md b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/README.md new file mode 100644 index 0000000000..660978744e --- /dev/null +++ b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/README.md @@ -0,0 +1,109 @@ +# Record: SLOT-32 + Partial Depth Recurrence (Layers 4,5) — val_bpb 0.7736 + +**val_bpb: 0.7736** (3-seed mean, std 0.0026) | **~15.71 MB** | 8xH100 SXM (Vast.ai), 600s + +**Beats current SOTA (PR #1313, 0.8637 BPB) by 0.0901 BPB.** + +## 3-Seed Results + +| Seed | Steps | ms/step | Sliding BPB | **SLOT-32 BPB** | Artifact | +|------|-------|---------|-------------|----------------|----------| +| 42 | 4,929 | 121.7 | 1.1259 | **0.7732** | 15,656,490 | +| 1337 | 4,935 | 121.6 | 1.1257 | **0.7764** | 15,725,938 | +| 314 | 4,938 | 121.5 | 1.1255 | **0.7713** | 15,733,118 | +| **Mean** | | | **1.1257** | **0.7736** | | + +## Key Techniques + +| Technique | BPB Impact | Source | +|-----------|-----------|--------| +| **SLOT-32** (per-sample delta + logit bias, 32 AdamW steps) | **-0.352** | arXiv:2505.12392v2, PR #1229 | +| Partial depth recurrence (layers 4,5 repeated) | -0.005 | This work + PR #1204, PR #1260 | +| Per-iteration conditioning (iter_embed + iter_gate) | Novel | This work | +| XSA all 11 layers | -0.002 | PR #1176 | +| QK-Gain 4.0 | -0.003 | PR #1125 | +| VRL (Value Residual Learning) | -0.002 | arXiv:2410.17897, PR #175 | +| BigramHash 1024x128 | Input-level | PR #162 | +| EMA(0.997) + SWA(every 50) | Weight averaging | PR #401 | +| Late QAT (STE at scale < 0.15) | Quant-aware training | PR #286 | +| int6 + LZMA | Compression | PR #160, PR #535 | + +## SLOT-32 Configuration + +The primary contribution over prior SLOT submissions is tuning to 32 steps with higher learning rate: + +| Parameter | PR #1303 (0.9462) | PR #1313 (0.8637) | **This work (0.7736)** | +|-----------|-------------------|-------------------|----------------------| +| SLOT_STEPS | 16 | 24 | **32** | +| SLOT_LR | 0.008 | 0.012 | **0.015** | +| SLOT_LR_MIN | 0.0008 | 0.001 | **0.001** | +| EVAL_STRIDE | 64 | 96 | **96** | + +- Hidden delta: [bsz, 1, 512] + logit bias: [bsz, 1, 1024] +- 32 AdamW steps, cosine LR 0.015 -> 0.001, weight_decay=1e-8 +- Scored-position masking: last stride=96 tokens per non-first window +- Model weights frozen, delta optimized through detached hidden states +- Eval time: ~304s (within 10-min eval budget) + +## Partial Depth Recurrence + +Virtual 13-layer network from 11 unique blocks by repeating layers 4 and 5: + +``` +virtual_layers = [0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10] + ^ ^ (repeated) +``` + +Per-iteration conditioning ensures repeated passes are distinct: +```python +gate = sigmoid(iter_gate[i]) # learned, starts near 0 +x = x + gate * iter_embed[i] # additive conditioning +x = blocks[layer_idx](x, x0) # same weights, different input +``` + +Active from step 0 (static graph for torch.compile fullgraph=True). + +## Compliance + +- Score-first SLOT (frozen model, torch.no_grad() hidden states) +- No external data access during eval +- No n-gram cache, no two-pass rescoring, no warmstart between windows +- Self-contained (no network calls) +- All seeds: training 600s, eval ~304s, total ~904s (within combined budget) +- All artifacts under 16MB + +## Reproduction + +```bash +pip install sentencepiece huggingface_hub +python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 + +RECUR_LAYERS=4,5 SLOT_STEPS=32 SLOT_LR=0.015 SLOT_LR_MIN=0.001 \ + EVAL_STRIDE=96 SEED=42 \ + torchrun --standalone --nproc_per_node=8 train_gpt_slot_recurrence.py +``` + +## Lineage + +``` +PR #1019 (Merged SOTA, 1.1147 BPB) + +-- PR #1303 (SLOT-16 + VRL + XSA-11, 0.9462 BPB) + +-- PR #1313 (SLOT-24 tuning, 0.8637 BPB) + +-- This work: + +-- SLOT-32 tuning (32 steps, LR=0.015) + +-- Partial depth recurrence (layers 4,5) + +-- Per-iteration conditioning (iter_embed + iter_gate) +``` + +## Credits + +- SLOT: Hu et al. arXiv:2505.12392v2, PR #1176 (@bigbag), PR #1229 (@resouer) +- SLOT-24 tuning: PR #1313 +- Depth recurrence: PR #1204 (@msisovic), PR #1260 (@dexhunter) +- Base architecture: PR #1303 (@resouer), PR #1019 (@abaybektursun) +- QK-Gain: PR #1125 (@bigbag) +- VRL: arXiv:2410.17897, PR #175 (@anthony-maio) + +## Author + +Arnell Milhouse (@GitGeeks) diff --git a/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/results_summary.md b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/results_summary.md new file mode 100644 index 0000000000..fb403a7978 --- /dev/null +++ b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/results_summary.md @@ -0,0 +1,109 @@ +# Depth Recurrence — Results Summary + +**Project:** Parameter Golf Challenge (OpenAI) +**PR:** openai/parameter-golf#1278 +**Date:** April 3, 2026 +**Author:** GitGeeks (milhouse) + +--- + +## What We're Doing + +Weight-shared depth recurrence: instead of N unique transformer blocks, share K blocks and iterate them multiple times. This gives more effective depth per parameter, directly optimizing L(N). + +This is listed as an OpenAI "Request for PR" technique. No prior submission has made it work. + +--- + +## Phase 1 Results: GPU Throughput (H100 SXM) + +**Verdict: GREEN LIGHT — recurrence is essentially free on H100.** + +| Config | Unique Layers | Repeats | Eff Depth | Params | step_avg | Overhead | +|--------|--------------|---------|-----------|--------|----------|----------| +| Baseline | 9 | 1 | 9 | 17.1M | 518ms | 1.00x | +| 3x3 | 3 | 3 | 9 | 6.0M | 508ms | **0.98x** | +| 3x5 | 3 | 5 | 15 | 6.1M | 550ms | **1.06x** | +| 4x5 | 4 | 5 | 20 | 7.9M | 693ms | **1.34x** | + +Key finding: 3x3 is actually **faster** than baseline because fewer parameters = less memory bandwidth. Even 4x5 (20 effective layers) is only 1.34x overhead. + +--- + +## MLX Validation Results (Local) + +| Config | Params | Eff Depth | val_bpb | Compressed Size | +|--------|--------|-----------|---------|-----------------| +| Baseline (9 unique layers) | 17.1M | 9 | 3.2273 | 5.10 MB | +| Recurrence 3L x 3R | 6.0M | 9 | 3.2264 | 1.87 MB | +| **Recurrence 3L x 7R** | **6.1M** | **21** | **3.2134** | **1.89 MB** | + +- 3x3 matches baseline at 2.8x fewer params +- 3x7 **beats baseline by 0.014 BPB** at same param count, 2.3x deeper +- Deeper recurrence converges faster at every step count + +--- + +## Parameter Budget Analysis (16MB artifact limit) + +| Config | Params | Estimated Artifact | Eff Depth | Notes | +|--------|--------|-------------------|-----------|-------| +| Current SOTA (11L, 512d) | 27M | ~15.9 MB | 11 | Near ceiling | +| 4 unique x 5 reps, 768d | 22.6M | ~13.3 MB | 20 | 1.5x wider, 1.8x deeper | +| 3 unique x 7 reps, 768d | 17.3M | ~10.2 MB | 21 | Most depth-efficient | + +Both recurrence configs fit within 16MB while delivering ~2x the effective depth at greater width. + +--- + +## Wallclock Projections (10 min on 8xH100) + +| Config | step_avg (1xH100) | Est. steps in 600s | Effective depth | +|--------|-------------------|-------------------|-----------------| +| Baseline 9L | 518ms | ~1158 | 9 | +| Rec 3x3 | 508ms | ~1181 | 9 | +| Rec 3x5 | 550ms | ~1091 | 15 | +| Rec 4x5 | 693ms | ~866 | 20 | + +Even at 1.34x overhead, the 4x5 config gets 866 steps with 20 effective layers. The deeper model converges faster per step, so fewer total steps are needed. + +--- + +## Implementation + +### Per-iteration conditioning +```python +# Before each block call at effective layer i: +gate = sigmoid(iter_gate[i]) # starts near 0 (init: -2.0) +x = x + gate * iter_embed[i] # additive conditioning +x = blocks[i % num_unique](x, x0) # shared block with cycling +``` + +### U-Net skip connections +Adapted for effective depth: encoder = first half of effective layers, decoder = second half with reversed skips. + +### Files +- `train_gpt_mlx_recurrence.py` — MLX prototype (validated) +- `train_gpt_recurrence.py` — CUDA port (Phase 1 tested on H100) + +--- + +## What's Next + +| Phase | Description | Cost | Status | +|-------|-------------|------|--------| +| 1 | GPU throughput gate | ~$1.50 | **DONE** | +| 2 | Architecture search (wider/deeper configs) | ~$20 | Next | +| 3 | SOTA stack integration (GPTQ+XSA+BigramHash) | ~$80 | Planned | +| 4 | Final submission (3-seed validation) | ~$60 | Planned | + +**Budget:** ~$163 estimated of $500 grant + +--- + +## Competition Context + +- **Current SOTA:** 1.1147 BPB (PR #1019, abaybektursun) +- **Baseline:** 1.2244 BPB +- **Challenge:** Lowest val_bpb within 16MB artifact, 10 min on 8xH100 +- **Deadline:** April 30, 2026 diff --git a/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/submission.json b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/submission.json new file mode 100644 index 0000000000..edb1d5d5f7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/submission.json @@ -0,0 +1,17 @@ +{ + "name": "SLOT-32 + Partial Depth Recurrence (Layers 4,5) + XSA-11", + "author": "Arnell Milhouse", + "github_id": "GitGeeks", + "val_loss_mean": 1.3063, + "val_bpb_mean": 0.7736, + "val_bpb_std": 0.0026, + "seeds": { + "42": {"val_loss": 1.3055, "val_bpb": 0.7732, "artifact_bytes": 15656490}, + "1337": {"val_loss": 1.3109, "val_bpb": 0.7764, "artifact_bytes": 15725938}, + "314": {"val_loss": 1.3024, "val_bpb": 0.7713, "artifact_bytes": 15733118} + }, + "technique_summary": "SLOT-32 (32 AdamW steps, LR=0.015, cosine->0.001) + partial depth recurrence (layers 4,5) + per-iteration conditioning + XSA-11 + QK-Gain 4.0 + VRL + BigramHash + EMA/SWA + Late QAT + int6+LZMA", + "hardware": "8xH100 SXM (Vast.ai)", + "training_time_s": 600, + "eval_time_s": 304 +} diff --git a/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/train_gpt.py b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/train_gpt.py new file mode 100644 index 0000000000..5b25eb4ae6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-02_DepthRecurrence_WeightShared/train_gpt.py @@ -0,0 +1,1510 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path +import lzma +_COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 4.0)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 1024)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on last 11 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "1"))) + slot_steps = int(os.environ.get("SLOT_STEPS", 16)) + slot_lr = float(os.environ.get("SLOT_LR", 0.008)) + slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008)) + slot_batch_seqs = int(os.environ.get("SLOT_BATCH_SEQS", 32)) + recur_layers = os.environ.get("RECUR_LAYERS", "") # e.g., "4,5" + recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + recur_layers: str = "", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + # Partial depth recurrence: repeat specified layers once to increase effective depth + self.recur_layer_set = {int(x) for x in recur_layers.split(",") if x.strip()} if recur_layers else set() + # Build virtual layer schedule: e.g., [0,1,2,3,4,5,4,5,6,7,8,9,10] for recur_layers="4,5" + self.virtual_layers = [] + for i in range(num_layers): + self.virtual_layers.append(i) + if i in self.recur_layer_set: + self.virtual_layers.append(i) # repeat + self.recurrence_active = False # activated by training loop after recur_start_step + # Per-iteration conditioning for repeated layers only + num_recur_positions = sum(1 for idx, v in enumerate(self.virtual_layers) if self.virtual_layers[:idx].count(v) > 0) + if num_recur_positions > 0: + self.iter_embed = nn.Parameter(torch.randn(num_recur_positions, model_dim) * 0.02) + self.iter_gate = nn.Parameter(torch.full((num_recur_positions, model_dim), -2.0)) + else: + self.iter_embed = None + self.iter_gate = None + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def _run_layers(self, x: Tensor, x0: Tensor, input_ids: Tensor) -> Tensor: + ve_cache: dict = {} + v_first: Tensor | None = None + # Determine schedule: virtual (with repeats) or standard + if self.recurrence_active and self.recur_layer_set: + schedule = self.virtual_layers + else: + schedule = list(range(len(self.blocks))) + eff_depth = len(schedule) + n_enc = eff_depth // 2 + n_dec = eff_depth - n_enc + skips: list[Tensor] = [] + recur_idx = 0 # index into iter_embed/iter_gate + seen_counts: dict[int, int] = {} + # Encoder half + for eff_i in range(n_enc): + bi = schedule[eff_i] + seen_counts[bi] = seen_counts.get(bi, 0) + 1 + is_repeat = seen_counts[bi] > 1 + # Per-iteration conditioning on repeated layers + if is_repeat and self.iter_embed is not None: + gate = torch.sigmoid(self.iter_gate[recur_idx].to(dtype=x.dtype))[None, None, :] + x = x + gate * self.iter_embed[recur_idx].to(dtype=x.dtype)[None, None, :] + recur_idx += 1 + ve = self._get_ve(bi, input_ids, ve_cache) + x, v_raw = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if eff_i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + # Decoder half + for i in range(n_dec): + eff_i = n_enc + i + bi = schedule[eff_i] + seen_counts[bi] = seen_counts.get(bi, 0) + 1 + is_repeat = seen_counts[bi] > 1 + if i < len(skips) and skips: + skip_idx = min(i, self.skip_weights.size(0) - 1) if self.skip_weights.size(0) > 0 else -1 + if skip_idx >= 0: + x = x + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + else: + skips.pop() + if is_repeat and self.iter_embed is not None: + gate = torch.sigmoid(self.iter_gate[recur_idx].to(dtype=x.dtype))[None, None, :] + x = x + gate * self.iter_embed[recur_idx].to(dtype=x.dtype)[None, None, :] + recur_idx += 1 + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + return x + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + x = self._run_layers(x, x0, input_ids) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_hidden(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + x = self._run_layers(x, x0, input_ids) + return self.final_norm(x) + def compute_logits(self, hidden: Tensor) -> Tensor: + if self.tie_embeddings: + logits_proj = F.linear(hidden, self.tok_emb.weight) + else: + logits_proj = self.lm_head(hidden) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def eval_val_slot( + args: Hyperparameters, + base_model: nn.Module, + rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + stride = args.eval_stride if args.eval_stride > 0 else 64 + seq_s = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tok = val_tokens.numel() - 1 + ws_list = list(range(0, total_tok, stride)) + ws_list = [ws for ws in ws_list if min(ws + seq_s, total_tok) - ws >= 1] + my_ws = ws_list[rank::world_size] + if args.tie_embeddings: + proj_w = base_model.tok_emb.weight.detach().float() + else: + proj_w = base_model.lm_head.weight.detach().float() + softcap = base_model.logit_softcap + compiled_hidden = torch.compile(base_model.forward_hidden, dynamic=False, fullgraph=False) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_ws), args.slot_batch_seqs): + bws = my_ws[bi:bi + args.slot_batch_seqs] + bsz = len(bws) + xb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + yb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + wlens = [] + for i, ws in enumerate(bws): + wend = min(ws + seq_s, total_tok) + wlen = wend - ws + wlens.append(wlen) + xb_cpu[i, :wlen] = val_tokens[ws:wend] + yb_cpu[i, :wlen] = val_tokens[ws + 1:wend + 1] + xb = xb_cpu.to(device=device, non_blocking=True) + yb = yb_cpu.to(device=device, non_blocking=True) + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + hidden = compiled_hidden(xb) + hidden_f = hidden.detach().float() + mask = torch.zeros(bsz, seq_s, device=device) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mask[i, s:wlen] = 1.0 + valid_count = mask.sum() + if valid_count == 0: + continue + delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True) + logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True) + slot_opt = torch.optim.AdamW([delta, logit_bias], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5) + targets_flat = yb.reshape(-1) + for step_i in range(args.slot_steps): + lr_t = args.slot_lr_min + 0.5 * (args.slot_lr - args.slot_lr_min) * (1 + math.cos(math.pi * step_i / args.slot_steps)) + for pg in slot_opt.param_groups: + pg['lr'] = lr_t + slot_opt.zero_grad() + h = hidden_f + delta + lp = F.linear(h, proj_w) + logit_bias + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + slot_loss = (nll * mask).sum() / valid_count + slot_loss.backward() + slot_opt.step() + with torch.no_grad(): + h = hidden_f + delta.detach() + lp = F.linear(h, proj_w) + logit_bias.detach() + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + chunk_nll = nll[i, s:wlen] + loss_sum += chunk_nll.sum().to(torch.float64) + token_count += float(wlen - s) + prev_ids = xb[i, s:wlen] + tgt_ids = yb[i, s:wlen] + tb = base_bytes_lut[tgt_ids].to(torch.float64) + tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64) + byte_sum += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_sum.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + recur_layers=args.recur_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + if base_model.iter_embed is not None: + scalar_params.append(base_model.iter_embed) + scalar_params.append(base_model.iter_gate) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + if args.recur_layers: + log0(f"recurrence:layers={args.recur_layers} start_step={args.recur_start_step} virtual_depth={len(base_model.virtual_layers)} schedule={base_model.virtual_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + if args.recur_layers and step >= args.recur_start_step and not base_model.recurrence_active: + base_model.recurrence_active = True + log0(f"recurrence:activated step:{step} layers:{args.recur_layers} virtual_depth:{len(base_model.virtual_layers)}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + recur_layers=args.recur_layers, + ).to(device).bfloat16() + if args.recur_layers: + eval_model.recurrence_active = True # always active at eval time + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if args.slot_enabled: + torch._dynamo.reset() + torch.cuda.synchronize() + t_slot = time.perf_counter() + slot_val_loss, slot_val_bpb = eval_val_slot( + args, eval_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_slot val_loss:{slot_val_loss:.4f} val_bpb:{slot_val_bpb:.4f} " + f"steps:{args.slot_steps} lr:{args.slot_lr} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms" + ) + log0(f"final_slot_exact val_loss:{slot_val_loss:.8f} val_bpb:{slot_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_mlx_recurrence.py b/train_gpt_mlx_recurrence.py new file mode 100644 index 0000000000..ed3fab5ed6 --- /dev/null +++ b/train_gpt_mlx_recurrence.py @@ -0,0 +1,1143 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + # Force MLX to materialize the graph after every sub-batch, preventing lazy + # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. + # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). + mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Depth recurrence: K unique blocks, each repeated N times = K*N effective depth. + # num_layers is now the number of *unique* blocks (K). + # num_repeats is how many times each block is applied (N). + # Effective depth = num_layers * num_repeats. + num_repeats: int = int(os.environ.get("NUM_REPEATS", 1)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: + return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) + + +def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + # Background on Muon: https://kellerjordan.github.io/posts/muon/ + a, b, c = 3.4445, -4.7750, 2.0315 + x = g.astype(mx.float32) + x = x / (mx.sqrt(mx.sum(x * x)) + eps) + transposed = x.shape[0] > x.shape[1] + if transposed: + x = x.T + for _ in range(steps): + a_mat = x @ x.T + b_mat = b * a_mat + c * (a_mat @ a_mat) + x = a * x + b_mat @ x + if transposed: + x = x.T + return x.astype(g.dtype) + + +def load_data_shard(path: Path) -> np.ndarray: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + if self.file_idx == 0: + self.epoch += 1 + if self.log_fn is not None: + self.log_fn( + f"WARNING: starting epoch:{self.epoch} " + f"dataset:{self.dataset_name} train_shards:{len(self.files)}" + ) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> np.ndarray: + chunks: list[np.ndarray] = [] + left = n + while left > 0: + if self.pos >= self.tokens.size: + self.next_file() + k = min(left, int(self.tokens.size - self.pos)) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + left -= k + return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) + + +class TokenLoader: + def __init__( + self, + pattern: str, + log_fn: Callable[[str], None] | None = None, + dataset_name: str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + return x @ self.weight.astype(x.dtype).T + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. + def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim) + self.c_k = CastedLinear(dim, kv_dim) + self.c_v = CastedLinear(dim, kv_dim) + self.proj = CastedLinear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = dim * mlp_mult + self.fc = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + x = nn.relu(self.fc(x)) + return self.proj(x * x) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array) -> mx.array: + mix = self.resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + # Depth-recurrent GPT: K unique blocks repeated N times each = K*N effective layers. + # Per-iteration conditioning via learned iteration embeddings (added to x before each block call). + # U-Net skip connections adapted for the recurrent structure. + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float, num_repeats: int = 1): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + + self.tok_emb = nn.Embedding(vocab_size, dim) + + # K unique blocks + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + + # Effective depth = num_layers * num_repeats + effective_depth = num_layers * num_repeats + + # U-Net skip connections over effective depth + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + + # Per-iteration conditioning: learned embedding for each effective layer + # Small init so conditioning starts near zero and grows during training + self.iter_embed = mx.random.normal((effective_depth, dim), dtype=mx.float32) * 0.02 + + # Per-iteration gate: controls how much of the iteration embedding is added + # Initialized to small negative value so sigmoid(gate) starts near 0 + self.iter_gate = mx.full((effective_depth, dim), -2.0, dtype=mx.float32) + + self.final_norm = RMSNormNoWeight() + + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) + self.tok_emb.weight = ( + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + ).astype(COMPUTE_DTYPE) + + def softcap(self, logits: mx.array) -> mx.array: + c = self.logit_softcap + return c * mx.tanh(logits / c) + + def __call__(self, input_ids: mx.array) -> mx.array: + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + num_blocks = len(self.blocks) + + # Flatten the iteration schedule: block 0 rep 0, block 1 rep 0, ..., block K-1 rep 0, block 0 rep 1, ... + effective_depth = num_blocks * self.num_repeats + + # Encoder half: first effective_depth//2 iterations + for eff_i in range(self.num_encoder_layers): + block_idx = eff_i % num_blocks + # Per-iteration conditioning: gated additive embedding + gate = mx.sigmoid(self.iter_gate[eff_i].astype(x.dtype))[None, None, :] + cond = self.iter_embed[eff_i].astype(x.dtype)[None, None, :] + x = x + gate * cond + x = self.blocks[block_idx](x, x0) + skips.append(x) + + # Decoder half: remaining iterations, with skip connections + for i in range(self.num_decoder_layers): + eff_i = self.num_encoder_layers + i + block_idx = eff_i % num_blocks + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + gate = mx.sigmoid(self.iter_gate[eff_i].astype(x.dtype))[None, None, :] + cond = self.iter_embed[eff_i].astype(x.dtype)[None, None, :] + x = x + gate * cond + x = self.blocks[block_idx](x, x0) + + return self.final_norm(x) + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful + # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This preserves the high-level optimization behavior even though MLX internals differ. + def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k in ("skip_weights", "iter_embed", "iter_gate") or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 2D float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_array(arr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. + dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + if args.mlx_eager_eval: + mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, + log_fn: Callable[[str], None] | None = None, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + val_batch_seqs = val_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) + total_loss_sum = 0.0 + total_tokens = 0.0 + total_bytes = 0.0 + for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, args.train_seq_len) + y_np = chunk[1:].reshape(-1, args.train_seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + batch_loss = compiled_loss(x, y).astype(mx.float32) + mx.eval(batch_loss) + total_loss_sum += float(batch_loss.item()) * chunk_token_count + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + total_tokens += chunk_token_count + total_bytes += float(bytes_np.astype(np.float64).sum()) + if log_fn is not None and total_batches > 1 and ( + batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 + ): + log_fn(f"val_progress:{batch_idx}/{total_batches}") + val_loss = total_loss_sum / total_tokens + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # ============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + num_repeats=args.num_repeats, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. + n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"repeats:{args.num_repeats} effective_depth:{args.num_layers * args.num_repeats} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + train_time_ms += 1000.0 * (time.perf_counter() - t0) + # Validation always scans the same fixed full validation split. + val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + if args.mlx_eager_eval: + mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # ============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main() diff --git a/train_gpt_recurrence.py b/train_gpt_recurrence.py new file mode 100644 index 0000000000..e1eb309ec5 --- /dev/null +++ b/train_gpt_recurrence.py @@ -0,0 +1,1173 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + # Depth recurrence: each unique layer is repeated num_repeats times. + # effective_depth = num_layers * num_repeats. + num_repeats = int(os.environ.get("NUM_REPEATS", 1)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,iter_embed,iter_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_repeats: int = 1, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.num_unique_layers = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + + # Effective depth = num_layers (unique) * num_repeats + effective_depth = num_layers * num_repeats + self.effective_depth = effective_depth + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # K unique blocks (shared across repeats) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for _ in range(num_layers) + ] + ) + + # Per-iteration conditioning (only when recurrence is active) + if num_repeats > 1: + self.iter_embed = nn.Parameter(torch.randn(effective_depth, model_dim) * 0.02) + self.iter_gate = nn.Parameter(torch.full((effective_depth, model_dim), -2.0)) + else: + self.iter_embed = None + self.iter_gate = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _apply_iter_conditioning(self, x: Tensor, eff_i: int) -> Tensor: + if self.iter_embed is not None and self.iter_gate is not None: + gate = torch.sigmoid(self.iter_gate[eff_i].to(dtype=x.dtype))[None, None, :] + cond = self.iter_embed[eff_i].to(dtype=x.dtype)[None, None, :] + x = x + gate * cond + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + num_blocks = self.num_unique_layers + + # Encoder half: first effective_depth//2 iterations + for eff_i in range(self.num_encoder_layers): + block_idx = eff_i % num_blocks + x = self._apply_iter_conditioning(x, eff_i) + x = self.blocks[block_idx](x, x0) + skips.append(x) + + # Decoder half: remaining iterations, with skip connections + for i in range(self.num_decoder_layers): + eff_i = self.num_encoder_layers + i + block_idx = eff_i % num_blocks + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._apply_iter_conditioning(x, eff_i) + x = self.blocks[block_idx](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_repeats=args.num_repeats, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # Depth recurrence conditioning params go into Adam (scalar) group + if base_model.iter_embed is not None: + scalar_params.append(base_model.iter_embed) + if base_model.iter_gate is not None: + scalar_params.append(base_model.iter_gate) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0( + f"model_params:{n_params} vocab_size:{args.vocab_size} " + f"layers:{args.num_layers} repeats:{args.num_repeats} " + f"effective_depth:{base_model.effective_depth} dim:{args.model_dim} " + f"heads:{args.num_heads} kv_heads:{args.num_kv_heads}" + ) + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt_slot_recurrence.py b/train_gpt_slot_recurrence.py new file mode 100644 index 0000000000..b2099c4be7 --- /dev/null +++ b/train_gpt_slot_recurrence.py @@ -0,0 +1,1516 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path +import lzma +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 4.0)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 1024)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on last 11 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "1"))) + slot_steps = int(os.environ.get("SLOT_STEPS", 16)) + slot_lr = float(os.environ.get("SLOT_LR", 0.008)) + slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008)) + slot_batch_seqs = int(os.environ.get("SLOT_BATCH_SEQS", 32)) + slot_warmstart = float(os.environ.get("SLOT_WARMSTART", 0.0)) # 0=disabled; 0.85=warmstart from prev batch + # Partial depth recurrence: repeat specified layers once more in the forward pass. + # virtual_layers = [0,1,2,3,4,5,4,5,6,7,8,9,10] when recur_layers="4,5" + recur_layers = os.environ.get("RECUR_LAYERS", "") # e.g., "4,5" + recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + recur_layers: str = "", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + # Build virtual layer schedule for partial recurrence + self._recur_layer_set = {int(x) for x in recur_layers.split(",") if x.strip()} if recur_layers else set() + self._flat_schedule = list(range(num_layers)) + self._recur_schedule = [] + for i in range(num_layers): + self._recur_schedule.append(i) + if i in self._recur_layer_set: + self._recur_schedule.append(i) # repeat + # If recur_layers specified, always use recurrence (no delayed start — keeps graph static for torch.compile) + self.recurrence_active = bool(self._recur_layer_set) + self._active_schedule = self._recur_schedule if self.recurrence_active else self._flat_schedule + eff_depth = len(self._active_schedule) + self.num_encoder_layers = eff_depth // 2 + self.num_decoder_layers = eff_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Per-iteration conditioning for recurrent positions + num_recur_pos = len(self._recur_schedule) - num_layers # number of repeated positions + if num_recur_pos > 0: + self.iter_embed = nn.Parameter(torch.randn(num_recur_pos, model_dim) * 0.02) + self.iter_gate = nn.Parameter(torch.full((num_recur_pos, model_dim), -2.0)) + else: + self.iter_embed = None + self.iter_gate = None + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def _run_layers(self, x: Tensor, x0: Tensor, input_ids: Tensor) -> Tensor: + schedule = self._active_schedule + eff_depth = len(schedule) + enc_layers = eff_depth // 2 + dec_layers = eff_depth - enc_layers + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + recur_idx = 0 # tracks which recurrent position we're at + seen_counts: dict[int, int] = {} + for eff_i in range(enc_layers): + vi = schedule[eff_i] + seen_counts[vi] = seen_counts.get(vi, 0) + 1 + is_repeat = seen_counts[vi] > 1 + if is_repeat and self.iter_embed is not None: + gate = torch.sigmoid(self.iter_gate[recur_idx].to(dtype=x.dtype))[None, None, :] + x = x + gate * self.iter_embed[recur_idx].to(dtype=x.dtype)[None, None, :] + recur_idx += 1 + ve = self._get_ve(vi, input_ids, ve_cache) + x, v_raw = self.blocks[vi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if eff_i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + skip_idx = 0 + for i in range(dec_layers): + eff_i = enc_layers + i + vi = schedule[eff_i] + seen_counts[vi] = seen_counts.get(vi, 0) + 1 + is_repeat = seen_counts[vi] > 1 + if skips and skip_idx < self.num_skip_weights: + x = x + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + skip_idx += 1 + elif skips: + skips.pop() + if is_repeat and self.iter_embed is not None: + gate = torch.sigmoid(self.iter_gate[recur_idx].to(dtype=x.dtype))[None, None, :] + x = x + gate * self.iter_embed[recur_idx].to(dtype=x.dtype)[None, None, :] + recur_idx += 1 + ve = self._get_ve(vi, input_ids, ve_cache) + x, _ = self.blocks[vi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + return self.final_norm(x) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + x = self._run_layers(x, x0, input_ids) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_hidden(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + return self._run_layers(x, x0, input_ids) + def compute_logits(self, hidden: Tensor) -> Tensor: + if self.tie_embeddings: + logits_proj = F.linear(hidden, self.tok_emb.weight) + else: + logits_proj = self.lm_head(hidden) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def eval_val_slot( + args: Hyperparameters, + base_model: nn.Module, + rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + stride = args.eval_stride if args.eval_stride > 0 else 64 + seq_s = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tok = val_tokens.numel() - 1 + ws_list = list(range(0, total_tok, stride)) + ws_list = [ws for ws in ws_list if min(ws + seq_s, total_tok) - ws >= 1] + my_ws = ws_list[rank::world_size] + if args.tie_embeddings: + proj_w = base_model.tok_emb.weight.detach().float() + else: + proj_w = base_model.lm_head.weight.detach().float() + softcap = base_model.logit_softcap + compiled_hidden = torch.compile(base_model.forward_hidden, dynamic=False, fullgraph=False) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + warmstart_alpha = args.slot_warmstart + prev_delta = None + prev_bias = None + base_model.eval() + for bi in range(0, len(my_ws), args.slot_batch_seqs): + bws = my_ws[bi:bi + args.slot_batch_seqs] + bsz = len(bws) + xb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + yb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + wlens = [] + for i, ws in enumerate(bws): + wend = min(ws + seq_s, total_tok) + wlen = wend - ws + wlens.append(wlen) + xb_cpu[i, :wlen] = val_tokens[ws:wend] + yb_cpu[i, :wlen] = val_tokens[ws + 1:wend + 1] + xb = xb_cpu.to(device=device, non_blocking=True) + yb = yb_cpu.to(device=device, non_blocking=True) + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + hidden = compiled_hidden(xb) + hidden_f = hidden.detach().float() + mask = torch.zeros(bsz, seq_s, device=device) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mask[i, s:wlen] = 1.0 + valid_count = mask.sum() + if valid_count == 0: + continue + if warmstart_alpha > 0 and prev_delta is not None and prev_delta.size(0) == bsz: + delta = (warmstart_alpha * prev_delta.detach().clone()).requires_grad_(True) + logit_bias = (warmstart_alpha * prev_bias.detach().clone()).requires_grad_(True) + else: + delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True) + logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True) + slot_opt = torch.optim.AdamW([delta, logit_bias], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5) + targets_flat = yb.reshape(-1) + for step_i in range(args.slot_steps): + lr_t = args.slot_lr_min + 0.5 * (args.slot_lr - args.slot_lr_min) * (1 + math.cos(math.pi * step_i / args.slot_steps)) + for pg in slot_opt.param_groups: + pg['lr'] = lr_t + slot_opt.zero_grad() + h = hidden_f + delta + lp = F.linear(h, proj_w) + logit_bias + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + slot_loss = (nll * mask).sum() / valid_count + slot_loss.backward() + slot_opt.step() + if warmstart_alpha > 0: + prev_delta = delta.detach() + prev_bias = logit_bias.detach() + with torch.no_grad(): + h = hidden_f + delta.detach() + lp = F.linear(h, proj_w) + logit_bias.detach() + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + chunk_nll = nll[i, s:wlen] + loss_sum += chunk_nll.sum().to(torch.float64) + token_count += float(wlen - s) + prev_ids = xb[i, s:wlen] + tgt_ids = yb[i, s:wlen] + tb = base_bytes_lut[tgt_ids].to(torch.float64) + tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64) + byte_sum += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_sum.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + recur_layers=args.recur_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.iter_embed is not None: + scalar_params.append(base_model.iter_embed) + if base_model.iter_gate is not None: + scalar_params.append(base_model.iter_gate) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + recur_layers=args.recur_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if args.slot_enabled: + torch._dynamo.reset() + torch.cuda.synchronize() + t_slot = time.perf_counter() + slot_val_loss, slot_val_bpb = eval_val_slot( + args, eval_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_slot val_loss:{slot_val_loss:.4f} val_bpb:{slot_val_bpb:.4f} " + f"steps:{args.slot_steps} lr:{args.slot_lr} warmstart:{args.slot_warmstart} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms" + ) + log0(f"final_slot_exact val_loss:{slot_val_loss:.8f} val_bpb:{slot_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_sota_slot.py b/train_gpt_sota_slot.py new file mode 100644 index 0000000000..8f94a1f8aa --- /dev/null +++ b/train_gpt_sota_slot.py @@ -0,0 +1,1457 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path +import lzma +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 4.0)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 1024)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on last 11 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "1"))) + slot_steps = int(os.environ.get("SLOT_STEPS", 16)) + slot_lr = float(os.environ.get("SLOT_LR", 0.008)) + slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008)) + slot_batch_seqs = int(os.environ.get("SLOT_BATCH_SEQS", 32)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_hidden(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + return self.final_norm(x) + def compute_logits(self, hidden: Tensor) -> Tensor: + if self.tie_embeddings: + logits_proj = F.linear(hidden, self.tok_emb.weight) + else: + logits_proj = self.lm_head(hidden) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def eval_val_slot( + args: Hyperparameters, + base_model: nn.Module, + rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + stride = args.eval_stride if args.eval_stride > 0 else 64 + seq_s = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + total_tok = val_tokens.numel() - 1 + ws_list = list(range(0, total_tok, stride)) + ws_list = [ws for ws in ws_list if min(ws + seq_s, total_tok) - ws >= 1] + my_ws = ws_list[rank::world_size] + if args.tie_embeddings: + proj_w = base_model.tok_emb.weight.detach().float() + else: + proj_w = base_model.lm_head.weight.detach().float() + softcap = base_model.logit_softcap + compiled_hidden = torch.compile(base_model.forward_hidden, dynamic=False, fullgraph=False) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_ws), args.slot_batch_seqs): + bws = my_ws[bi:bi + args.slot_batch_seqs] + bsz = len(bws) + xb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + yb_cpu = torch.zeros(bsz, seq_s, dtype=torch.int64) + wlens = [] + for i, ws in enumerate(bws): + wend = min(ws + seq_s, total_tok) + wlen = wend - ws + wlens.append(wlen) + xb_cpu[i, :wlen] = val_tokens[ws:wend] + yb_cpu[i, :wlen] = val_tokens[ws + 1:wend + 1] + xb = xb_cpu.to(device=device, non_blocking=True) + yb = yb_cpu.to(device=device, non_blocking=True) + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + hidden = compiled_hidden(xb) + hidden_f = hidden.detach().float() + mask = torch.zeros(bsz, seq_s, device=device) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mask[i, s:wlen] = 1.0 + valid_count = mask.sum() + if valid_count == 0: + continue + delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True) + logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True) + slot_opt = torch.optim.AdamW([delta, logit_bias], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5) + targets_flat = yb.reshape(-1) + for step_i in range(args.slot_steps): + lr_t = args.slot_lr_min + 0.5 * (args.slot_lr - args.slot_lr_min) * (1 + math.cos(math.pi * step_i / args.slot_steps)) + for pg in slot_opt.param_groups: + pg['lr'] = lr_t + slot_opt.zero_grad() + h = hidden_f + delta + lp = F.linear(h, proj_w) + logit_bias + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + slot_loss = (nll * mask).sum() / valid_count + slot_loss.backward() + slot_opt.step() + with torch.no_grad(): + h = hidden_f + delta.detach() + lp = F.linear(h, proj_w) + logit_bias.detach() + lg = softcap * torch.tanh(lp / softcap) + nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), targets_flat, reduction="none").reshape(bsz, seq_s) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + chunk_nll = nll[i, s:wlen] + loss_sum += chunk_nll.sum().to(torch.float64) + token_count += float(wlen - s) + prev_ids = xb[i, s:wlen] + tgt_ids = yb[i, s:wlen] + tb = base_bytes_lut[tgt_ids].to(torch.float64) + tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64) + byte_sum += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_sum.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if args.slot_enabled: + torch._dynamo.reset() + torch.cuda.synchronize() + t_slot = time.perf_counter() + slot_val_loss, slot_val_bpb = eval_val_slot( + args, eval_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_slot val_loss:{slot_val_loss:.4f} val_bpb:{slot_val_bpb:.4f} " + f"steps:{args.slot_steps} lr:{args.slot_lr} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms" + ) + log0(f"final_slot_exact val_loss:{slot_val_loss:.8f} val_bpb:{slot_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main()