diff --git a/docs/adr/001-enterprise-production-readiness.md b/docs/adr/001-enterprise-production-readiness.md new file mode 100644 index 00000000..a0143f93 --- /dev/null +++ b/docs/adr/001-enterprise-production-readiness.md @@ -0,0 +1,92 @@ +# ADR-001: Enterprise Production Readiness + +**Status:** Accepted +**Phase:** 4 + 7 +**Date:** 2026-03-01 + +## Context + +Zerfoo had strong foundations (clean interfaces, modular architecture, type-safe +generics, 95%+ test coverage) but lacked operational hardening for enterprise +production deployment. Gaps existed in observability, security, reliability, +configuration management, and CI/CD enforcement. + +Additionally, a Phase 7 architecture review identified structural issues: dead +code (pkg/prelude, tests/helpers nil stubs), inverted layer registration +dependency (layers/core -> model), and thread-unsafe graph memo map. + +## Decision + +### Observability (E21-E22) + +- **Structured logging** via `log/` package: Logger interface with Debug/Info/ + Warn/Error levels, JSON output mode, NopLogger for zero-overhead disabling. + All packages instrumented (compute, distributed, training, model, cmd). +- **Metrics** via `metrics/runtime/` package: Collector interface with Counter, + Gauge, Histogram. InMemoryCollector for testing; NopCollector for production + disabling. CPUEngine/GPUEngine and distributed ops instrumented. + +### Security (E23) + +- TLS/mTLS for all gRPC via `distributed.TLSConfig`. Plaintext fallback for + local dev (nil config). Input validation on all RPC handlers (completed in + Phase 5 via E32). + +### Configuration (E24) + +- `config/` package: `Load[T](path)` and `LoadWithEnv[T](path, prefix)` for + JSON config with env var overrides via struct tags. Validation via + `validate:"required"` tag. Standard structs: EngineConfig, TrainingConfig, + DistributedConfig. + +### Reliability (E25-E26, E28) + +- **Graceful shutdown** via `shutdown/` package: Coordinator with reverse-order + Closer execution, per-closer timeout, signal handling (SIGINT/SIGTERM) in CLI. +- **Health checks** via `health/` package: HTTP /healthz (liveness), /readyz + (readiness with configurable checks), /debug/pprof/. Engine health check + verifies compute is operational. +- **Resource limits**: MemoryTracker with CAS-based enforcement at Engine level. + Per-operation timeout via context.Context deadline checks. + +### CI/CD Hardening (E27) + +- Parity and numerics tests blocking in CI (removed `|| true`). +- Coverage gate via `cmd/coverage-gate/`: fails if any testable package drops + below 93%. +- Benchmark regression via `cmd/bench-compare/`: fails on >10% regression. +- Race detector on all unit tests. Go 1.25 on Ubuntu + macOS runners. + +### Architecture Cleanup (Phase 7: E44-E46) + +- **Dead code removal**: Deleted pkg/prelude (empty), tests/helpers/wire.go + (4 nil stubs), 7 dead test files (17 always-skipping tests). +- **Registration consolidation**: Removed init() from layers/core/registry.go. + Single entry point: `layers/registry.RegisterAll()`. Exported BuildFFN. +- **Graph thread safety**: Added sync.Mutex to graph.Graph protecting memo map + in Forward/Backward. Coarse-grained lock (correct for graphs < 1000 nodes). + +## Consequences + +- All packages use leveled structured logging; no raw fmt.Printf in production. +- Runtime metrics available for Prometheus scraping or in-memory snapshot. +- gRPC is TLS-capable; plaintext is opt-in (nil TLSConfig). +- CI enforces coverage >= 93%, benchmark regression < 10%, zero races. +- Graph.Forward is safe for concurrent use from multiple goroutines. +- No init()-based registration; single wiring point reduces coupling. + +### Blocked Item + +- **E29 GPU hardware validation**: Blocked on GCP GPU quota = 0. Quota request + pending (preference ID: zerfoo-gpu-test, project: numerai-488804). Unblock by + checking quota status or trying a different cloud provider. + +### Key Files + +- `log/logger.go` -- Logger interface, StdLogger, NopLogger +- `metrics/runtime/metrics.go` -- Collector, InMemoryCollector, NopCollector +- `config/loader.go` -- Load[T], LoadWithEnv[T] +- `shutdown/coordinator.go` -- Closer, Coordinator +- `health/server.go` -- Health HTTP server +- `cmd/coverage-gate/main.go` -- CI coverage enforcement +- `cmd/bench-compare/main.go` -- CI benchmark regression detection diff --git a/docs/adr/002-distributed-training-protocol.md b/docs/adr/002-distributed-training-protocol.md new file mode 100644 index 00000000..5057df89 --- /dev/null +++ b/docs/adr/002-distributed-training-protocol.md @@ -0,0 +1,74 @@ +# ADR-002: Distributed Training Protocol + +**Status:** Accepted +**Phase:** 5 +**Date:** 2026-03-01 + +## Context + +The distributed package had auto-generated protobuf stubs, a coordinator server, +InternalStrategy[T] interface, AllReduceStrategy[T], NetworkManager, and +ServerManager. Missing: concrete DistributedServiceServer, GrpcStrategy[T] +connecting strategy to transport, WorkerNode lifecycle management, and +multi-worker integration tests. + +## Decision + +### AllReduce Protocol (Star Topology) + +Root (rank 0) runs the server collecting gradients from all peers. Each non-root +worker opens a bidi stream to root, sends gradients as AllReduceRequest messages +(one per named tensor), then waits for AllReduceResponse with averaged results. +Root accumulates peer gradients plus its own, computes element-wise average +(sum / world_size), and streams results back. + +A `reduceSession` struct coordinates across concurrent AllReduce stream handlers: +collects tensors by name from each peer, waits via sync barrier, computes +reduction, distributes result. + +Ring all-reduce optimization was explicitly deferred (correctness first). + +### Barrier Protocol + +Counter-based. Each worker calls Barrier RPC on root. Root counts arrivals and +blocks via sync.Cond until all workers arrive. Epoch numbers prevent stale +barrier responses. + +### Broadcast Protocol + +Root stores broadcast tensor in thread-safe map keyed by name. Non-root workers +call Broadcast RPC on root. If tensor not yet available, handler waits with +context deadline. + +### Tensor Serialization + +pb.Tensor uses repeated float (float32 only). GrpcStrategy[T] converts +tensor.TensorNumeric[T] to/from pb.Tensor. For T=float64, values narrow to +float32 for transport (acceptable for gradient averaging). + +### Worker Lifecycle + +WorkerNode struct encapsulates: GrpcStrategy, coordinator connection, health +check registration, shutdown.Closer implementation. Start(ctx, cfg) initializes +strategy, registers with coordinator, starts gRPC server, connects to peers. +Close(ctx) triggers orderly shutdown. + +CLI `worker` subcommand: --coordinator-address, --worker-address, --worker-id +flags. Signal handling via cli.SignalContext. + +## Consequences + +- Star topology is simple and correct but O(N) at root. Ring optimization + deferred to future phase. +- All distributed operations tested end-to-end over real gRPC (bufconn). +- TLS integration tested with self-signed certificates. +- Worker lifecycle integrated with health checks and shutdown coordinator. +- distributed/ package at 96% coverage. + +### Key Files + +- `distributed/worker_service.go` -- DistributedServiceServer implementation +- `distributed/grpc_strategy.go` -- GrpcStrategy[T] (InternalStrategy over gRPC) +- `distributed/worker_node.go` -- WorkerNode lifecycle +- `distributed/integration_test.go` -- Multi-worker tests (bufconn) +- `cmd/cli/worker.go` -- Worker CLI subcommand diff --git a/docs/adr/003-open-weights-model-import.md b/docs/adr/003-open-weights-model-import.md new file mode 100644 index 00000000..af5b8bb3 --- /dev/null +++ b/docs/adr/003-open-weights-model-import.md @@ -0,0 +1,84 @@ +# ADR-003: Open Weights Model Import + +**Status:** Accepted +**Phase:** 6 +**Date:** 2026-03-02 + +## Context + +Zerfoo could train and run inference on models built with its layer API, but +importing pre-trained open-weights models (Gemma 3, Kimi-VL) required closing +gaps in the ONNX import pipeline (zonnx repo) and layer registry. + +Gap analysis identified blockers: zonnx converter missing TENSOR attribute and +UINT8 dtype; MatMulNBits and Constant not registered; vision encoder operators +(Conv2d, Pad, Slice, Resize, BatchNorm, GlobalAveragePool) not implemented; +MoE not implemented. + +## Decision + +### 4-Bit Weight Packing + +MatMulNBits stores 4-bit weights packed two-per-byte in UINT8 tensors. ZMF uses +DataType=UINT8. Dequantization happens in MatMulNBits.Forward() using +numeric.Unpack4BitSlice. Supports symmetric and asymmetric quantization with +per-block scales and optional zero-points. + +### Conv2d Strategy + +Direct nested-loop convolution (not im2col + MatMul). Simpler, correct for +inference workloads, avoids allocating large intermediate matrices. Deviation +from original plan noted. + +### MoE Design + +MoEGate routes tokens to top-k experts via softmax + topK selection. Gate weight +passed as runtime Forward input (not from params) to match the ONNX/ZMF pattern. +MixtureOfExperts dispatches to selected experts and aggregates weighted outputs. +Expert loading from ZMF sub-graphs deferred (tech debt). + +### Operator Inventory + +New operators implemented and registered: + +| Operator | File | Category | +|----------|------|----------| +| Softmax | layers/activations/softmax.go | Activation | +| Sigmoid builder | layers/activations/registry.go | Activation | +| Erf | layers/activations/erf.go | Activation | +| LayerNormalization | layers/normalization/registry.go | Normalization | +| BatchNormalization | layers/normalization/batch_norm.go | Normalization | +| Slice | layers/core/slice.go | Core | +| Pad | layers/core/pad.go | Core | +| TopK | layers/core/topk.go | Core | +| Conv2d | layers/core/conv2d.go | Core | +| GlobalAveragePool | layers/core/global_avg_pool.go | Core | +| Resize | layers/core/resize.go | Core | +| MoEGate | layers/core/moe.go | Core | +| MixtureOfExperts | layers/core/moe.go | Core | +| Constant | layers/core/constant.go | Core | + +### Multi-Repo Discipline + +zonnx and zerfoo are separate repos. Pre-commit hooks reject multi-directory +commits. zonnx converter fixes committed in zonnx; layer/model changes in zerfoo. +ONNX-to-ZMF conversion in zonnx handles special cases (Slice/Pad/TopK input +promotion, Resize scales/sizes, MatMulNBits dequantization). + +## Consequences + +- Gemma 3 end-to-end import validated (forward pass + greedy decode). +- SigLIP vision encoder and Kimi-VL connector validated. +- 13 new operators added to registry; total 56+ layers. +- MatMulNBits dequantization at converter level (not runtime) for standard + MatMul path. +- Expert loading from sub-graphs is documented tech debt. +- All parity tests are env-var gated (skip gracefully without model files). + +### Key Files + +- `layers/core/moe.go` -- MoEGate, MixtureOfExperts +- `layers/core/conv2d.go` -- Conv2d (nested-loop) +- `layers/core/constant.go` -- Constant node +- `tests/parity/gemma3_test.go` -- Gemma 3 parity test +- `tests/parity/siglip_test.go` -- SigLIP/Kimi-VL parity tests diff --git a/docs/adr/004-embeddable-inference-library.md b/docs/adr/004-embeddable-inference-library.md new file mode 100644 index 00000000..5ed3fe56 --- /dev/null +++ b/docs/adr/004-embeddable-inference-library.md @@ -0,0 +1,99 @@ +# ADR-004: Embeddable Inference Library + +**Status:** Accepted +**Phase:** 8 +**Date:** 2026-03-02 + +## Context + +Running inference on an imported model required extensive manual wiring: download +ONNX files, convert with zonnx CLI, write Go code to create Engine/Graph, use +whitespace-only tokenizer, call Forward in a manual loop (no KV cache, O(n^2)), +no sampling beyond argmax, no streaming. + +Phase 8 transforms zerfoo into an embeddable inference library: + +```go +m, _ := inference.Load("google/gemma-3-4b-it") +resp, _ := m.Generate(ctx, "Explain quantum computing") +``` + +## Decision + +### BPE Tokenizer + +Pure Go, no CGo. Loads HuggingFace tokenizer.json format (vocab, merge rules, +pre-tokenizer config, normalizer, special tokens). BPE merge loop: split into +bytes, iteratively merge highest-priority adjacent pair. Byte-level BPE +pre-tokenization (GPT-2 style). SentencePiece .model files not supported +directly; most HuggingFace models ship tokenizer.json. + +### KV Cache + +GenerationContext embeds context.Context and carries *KVCache. KVCache stores +per-layer key/value tensors (appended on each step). Attention layers +(GroupQueryAttention, GlobalAttention) check for KVCache in context: if present, +append current K/V, use full cached K/V for computation. Graph.Forward() +signature unchanged (opt-in via context). Callers without KVCache get existing +full-recompute behavior. + +### Generation Loop + +Generator holds graph, tokenizer, engine, model config. Autoregressive loop: +1. Encode prompt to token IDs +2. Forward pass for logits [1, seqLen, vocabSize] +3. Extract last-position logits +4. Apply: temperature scaling -> top-k filtering -> top-p filtering -> + repetition penalty -> sample (or argmax at temperature=0) +5. Check stop conditions (EOS, max tokens, stop strings) +6. Repeat with new token as input (KV cache handles prefix) + +### Streaming + +TokenStream interface with OnToken(token string, done bool) error. GenerateStream +delivers each decoded token as generated. Sentinel-based stop-string detection +with delta emission. + +### Model Registry + +Local cache under ~/.zerfoo/models/. Layout: // containing +model.zmf, tokenizer.json, config.json. Pull: download from HuggingFace Hub API, +convert ONNX to ZMF (zonnx as Go library), cache locally. HF_TOKEN env var for +gated models. + +### HTTP Serve + +net/http server. OpenAI-compatible endpoints: +- POST /v1/chat/completions (non-streaming + SSE) +- POST /v1/completions (non-streaming + SSE) +- GET /v1/models + +### Constraints + +- Pure Go, no CGo, no external C libraries for tokenization. +- KV cache is opt-in; does not break existing callers. +- Model registry works offline after initial pull. +- No training through high-level API. +- No multi-model serving. + +## Consequences + +- 3-line model loading and generation for end users. +- O(n) per generation step with KV cache (vs O(n^2) without). +- CLI commands (pull/run/serve) for interactive and server use. +- OpenAI API compatibility enables tool interoperability. +- Coverage: generate 95%, inference 96.4%, serve 96.4%. +- Embeddings not yet supported (Embed returns error). + +### Key Files + +- `pkg/tokenizer/bpe.go` -- BPE tokenizer +- `pkg/tokenizer/loader.go` -- tokenizer.json loader +- `generate/kvcache.go` -- KV cache +- `generate/generator.go` -- Autoregressive generation loop +- `generate/sampling.go` -- Temperature, top-k, top-p, repetition penalty +- `generate/stream.go` -- TokenStream interface +- `registry/registry.go` -- ModelRegistry, LocalRegistry +- `inference/inference.go` -- Load, Generate, Chat, GenerateStream +- `serve/server.go` -- OpenAI-compatible HTTP server +- `cmd/cli/{pull,run,serve}.go` -- CLI commands diff --git a/docs/adr/005-multi-architecture-support.md b/docs/adr/005-multi-architecture-support.md new file mode 100644 index 00000000..48d18a8e --- /dev/null +++ b/docs/adr/005-multi-architecture-support.md @@ -0,0 +1,114 @@ +# ADR-005: Multi-Architecture Support + +**Status:** Accepted +**Phase:** 9 +**Date:** 2026-03-02 + +## Context + +After Phase 8, only Gemma 3 was validated. The open-weights ecosystem uses a +small set of architectural building blocks in different combinations per model +family. Gap analysis of major families identified a tier system based on +implementation effort. + +## Decision + +### Tier System + +| Tier | Models | Gap | Status | +|------|--------|-----|--------| +| 1 | Llama 3/4, Mistral, Mixtral | Config mapping only | Complete | +| 2 | Qwen 2.5/3 | QKV bias, YaRN RoPE | Complete | +| 3 | Phi-4 | Partial RoPE, tied embeddings | Complete | +| 4 | DeepSeek V3/R1 | MLA, shared expert MoE | Complete | +| 5 | Falcon Mamba, RWKV, DeltaNet | Different paradigm (SSM/linear attn) | Out of scope | + +### Config Registry + +`inference.ConfigRegistry` maps model_type strings (from HuggingFace config.json) +to parser functions. Each normalizes model-specific JSON fields into unified +ModelMetadata. Registered parsers: gemma2, gemma3, llama, mistral, qwen2, phi, +deepseek. Fallback parser for unknown model_type attempts direct unmarshal. + +### Parameter Name Resolver + +`model.ParamResolver` maps architecture-specific weight names to canonical names: +- Llama/Gemma: model.layers.{i}.self_attn.{q,k,v,o}_proj.weight +- Qwen: same + .bias suffix +- Phi: self_attn.{q,k,v,dense}_proj.weight +- DeepSeek: self_attn.{kv_a_proj,kv_b_proj,q_a_proj,q_b_proj,o_proj}.weight +- FFN: model.layers.{i}.mlp.{gate_proj,up_proj,down_proj}.weight (all) + +Called during model building as fallback when exact name not found. + +### QKV Bias (Tier 2 -- Qwen) + +Optional bias in GQA handled by existing Dense layer (core.NewDenseFromParams +accepts *Bias). Registry builder uses optionalBias helper to look up bias params. +Backwards compatible (nil bias = no change). + +### YaRN RoPE Scaling (Tier 2 -- Qwen) + +`embeddings.WithYaRNScaling(factor, origMaxLen)` modifies inverse frequencies: +- Low-frequency components scaled by 1/factor +- High-frequency components kept unchanged +- Intermediate linearly interpolated +- Attention scaling factor: sqrt(1 + ln(factor) / ln(origMaxLen)) + +Construction-time change; no forward-pass performance impact. + +### Partial RoPE (Tier 3 -- Phi-4) + +`embeddings.WithRotaryDimFraction(fraction)` rotates only a fraction of head +dimensions. Default 1.0 (all rotated). Phi-4 uses 0.75. Forward splits input +into rotated/non-rotated portions along last dimension, applies RoPE to rotated +part, concatenates. + +### Tied Embeddings (Tier 3 -- Phi-4) + +`core.NewTiedLMHead(engine, tiedWeight)` reuses token embedding weight +(transposed) as output projection. Builder checks tie_word_embeddings config +flag. Existing untied LMHead unchanged. + +### Multi-head Latent Attention (Tier 4 -- DeepSeek) + +New `attention.MultiHeadLatentAttention[T]` layer (not a GQA modification): +1. Down-project: c_kv = x * W_DKV (compress to kv_lora_dim, e.g., 512) +2. Up-project: K = c_kv * W_UK, V = c_kv * W_UV +3. Q = x * W_Q +4. Decoupled RoPE on Q and K subvector +5. Standard scaled dot-product attention +6. Output projection + +KV cache stores full tensors (compressed latent caching deferred). + +### Shared Expert MoE (Tier 4 -- DeepSeek) + +MixtureOfExperts.SharedExpert field. When non-nil, Forward() runs shared expert +on every token and adds output to weighted routed sum. Builder loading of shared +expert weights deferred (requires ZMF sub-graph support). + +### Global Attributes + +Model-level attributes (rope_scaling, partial_rotary_factor, tie_word_embeddings) +injected via `model.WithGlobalAttributes` during graph construction so individual +layer builders can access them. + +## Consequences + +- 6 model families supported, covering >90% of HuggingFace open-weight downloads. +- Each tier is independently shippable. +- Parity tests for all 6 families (env-var gated). +- MLA KV cache compression and shared expert weight loading are documented tech debt. +- SSM/linear attention architectures (Tier 5) deferred to future phase. + +### Key Files + +- `inference/arch_config.go` -- Config registry and parsers +- `model/param_resolver.go` -- Architecture-specific parameter resolution +- `layers/embeddings/rotary_positional_embedding.go` -- YaRN + partial RoPE +- `layers/attention/multi_head_latent_attention.go` -- MLA +- `layers/attention/mla_registry.go` -- MLA builder +- `layers/core/lm_head.go` -- Tied embeddings +- `layers/core/moe.go` -- Shared expert support +- `tests/parity/{llama3,mistral,qwen,phi4,deepseek}_test.go` -- Parity tests diff --git a/docs/adr/006-gpu-engine-architecture.md b/docs/adr/006-gpu-engine-architecture.md new file mode 100644 index 00000000..bd8f4fcc --- /dev/null +++ b/docs/adr/006-gpu-engine-architecture.md @@ -0,0 +1,103 @@ +# ADR-006: GPU Engine Architecture + +**Status:** Accepted +**Phase:** 2-3 +**Date:** 2026-03-01 + +## Context + +Zerfoo's compute.Engine[T] interface needed a GPU implementation for hardware +acceleration. The GPU engine must integrate with the existing interface contract +so all layers work transparently on CPU or GPU. + +## Decision + +### CUDA Float32 Engine + +GPUEngine[T] implements Engine[T] with CUDA kernels for float32. 20 operations +have native GPU implementations; remaining operations use CPU fallback by design +(not compute-bound or require Go runtime features). + +GPU-accelerated operations: +- Matrix: MatMul (cuBLAS Sgemm, 2D and batched) +- Element-wise: Add, Sub, Mul, Div, Pow (custom CUDA kernels) +- Scalar: AddScalar, MulScalar, DivScalar +- Activation: Tanh, TanhPrime +- Math: Exp, Log, Sqrt, Rsqrt +- Reduction: Sum, ReduceSum, ReduceMean (shared memory) +- Other: Softmax, Fill + +CPU fallback operations: UnaryOp, Transpose, Zero/Zeros/Copy, Reshape/Split/ +Concat/Repeat, Gather/ScatterAdd, OneHot/RandomUniform. + +### Memory Pool + +`internal/cuda/mempool.go`: Size-bucketed free-list allocator. Reuses previously +freed device memory, avoiding per-operation cudaMalloc/cudaFree. Mutex- +synchronized. Drained on GPUEngine.Close(). + +### cuBLAS Row-Major Strategy + +cuBLAS operates in column-major order. To compute C = A * B in row-major: +call cublasSgemm with B as first argument, A as second, swapping m and n. +Avoids explicit transposition. + +### OOM Fallback + +When cudaMalloc fails, GPU operations fall back to CPUEngine transparently. +Atomic counter (OOMFallbackCount()) tracks fallback frequency. Logged at WARN +level via structured logger. + +### Device-Resident Pipeline + +GPU operations produce tensors with GPUStorage, keeping data on-device between +chained operations. Only first input (if CPU-backed) does H2D copy; only final +result (.Data() call) does D2H copy. + +``` +CPU Input -> H2D (via pool) -> Kernel -> GPUStorage + | + GPUStorage -> Kernel -> GPUStorage + | + .Data() -> D2H -> CPU +``` + +### Parity Tolerances + +| Operation | Tolerance | +|-----------|-----------| +| MatMul | 1e-5 relative error | +| Element-wise | 1e-6 relative error | +| Reductions | 1e-5 relative error | + +### Build Requirements + +- CUDA Toolkit 12.x (libcudart, headers) +- cuBLAS library +- NVIDIA GPU with Compute Capability >= 7.0 +- All CUDA code gated behind `//go:build cuda` +- Kernels compiled via `internal/cuda/kernels/Makefile` (nvcc) + +### Hardware Validation + +Blocked on GCP GPU quota. Target: Tesla T4 (sm_75). Also compatible with +V100 (sm_70), L4 (sm_89), A100 (sm_80), DGX Spark GB10 (sm_120). + +## Consequences + +- Transparent CPU/GPU switching via Engine[T] interface. +- Float32 only on GPU; other types fall back to CPU. +- No broadcasting in GPU kernels (falls back to CPU). +- Single GPU only; no multi-GPU or device selection API. +- Memory pool reduces cudaMalloc overhead for iterative workloads. +- Hardware validation still pending (E29 blocked on quota). + +### Key Files + +- `compute/gpu_engine.go` -- GPUEngine (pool, stream, cuBLAS) +- `compute/gpu_kernels.go` -- getDevicePtr, makeGPUResult, kernel dispatch +- `tensor/gpu_storage.go` -- GPUStorage[T] +- `internal/cuda/runtime.go` -- CUDA runtime bindings +- `internal/cuda/mempool.go` -- Memory pool +- `internal/cuda/kernels/elementwise.cu` -- 17 CUDA kernels +- `internal/cublas/cublas.go` -- cuBLAS bindings diff --git a/docs/design.md b/docs/design.md index e7ff4cf2..7acefaf7 100644 --- a/docs/design.md +++ b/docs/design.md @@ -607,7 +607,9 @@ Documented exceptions (unreachable `tensor.New` error paths): - Table-driven tests using standard `testing` package (no testify). - Parity tests comparing GPU vs CPU output for every GPU-accelerated method. -- Model parity tests gated by env vars (GEMMA3_ZMF_PATH, SIGLIP_ZMF_PATH, KIMI_CONNECTOR_ZMF_PATH). +- Model parity tests gated by env vars (GEMMA3_ZMF_PATH, SIGLIP_ZMF_PATH, KIMI_CONNECTOR_ZMF_PATH, + LLAMA3_ZMF_PATH, MISTRAL_ZMF_PATH, QWEN25_ZMF_PATH, PHI4_ZMF_PATH, DEEPSEEK_ZMF_PATH). +- Parity tests cover 6 model families: Gemma 3, Llama 3, Mistral, Qwen 2.5, Phi-4, DeepSeek V3. - Integration tests for cross-package workflows. - Numerical gradient checking via finite differences. - MockEngine for unit testing layers in isolation. @@ -810,3 +812,72 @@ inference.Model v (serve.NewServer) OpenAI-compatible HTTP API ``` + +## 12. Multi-Architecture Support (Phase 9) + +Phase 9 extends the inference pipeline to support multiple model architectures +beyond Gemma 3. Each architecture has distinct attention, normalization, and +routing patterns that are handled through the config registry and layer builders. + +### 12.1 Supported Model Families + +| Model Family | Attention | RoPE Variant | Normalization | MoE | Config Parser | +|---|---|---|---|---|---| +| Gemma 3 | GQA | Standard | RMSNorm | No | `gemma3` | +| LLaMA 3 | GQA | Standard | RMSNorm | No | `llama` | +| Mistral | GQA | Standard | RMSNorm | No | `mistral` | +| Qwen 2.5 | GQA | YaRN scaling | RMSNorm | No | `qwen2` | +| Phi-4 | GQA | Partial (0.75) | LayerNorm | No | `phi` | +| DeepSeek V3 | MLA | Standard | RMSNorm | Shared expert | `deepseek` | + +### 12.2 Architecture-Specific Features + +**YaRN Scaling (Qwen 2.5):** `embeddings.WithYaRNScaling(factor, origMaxLen)` +classifies frequency bands into low/medium/high and applies differential scaling +to extend context beyond the original training length. + +**Partial RoPE (Phi-4):** `embeddings.WithRotaryDimFraction(fraction)` rotates +only a fraction of head dimensions, leaving the rest as pass-through. Phi-4 +uses fraction=0.75. + +**Multi-head Latent Attention (DeepSeek V3):** `attention.MultiHeadLatentAttention` +compresses KV into a low-rank latent via down-projection (`W_DKV`), then +up-projects to keys (`W_UK`) and values (`W_UV`), reducing KV cache size. + +**Shared Expert MoE (DeepSeek V3):** `core.MixtureOfExperts.SharedExpert` runs +one expert on every token and adds its output to the weighted routed sum. + +**Tied Embeddings (Phi-4, Gemma 3):** `core.NewTiedLMHead` reuses the token +embedding weight matrix (transposed) as the output projection, halving the +parameter count for the LM head. + +### 12.3 Config Registry + +`inference.ConfigRegistry` maps model family names to config parsers that extract +`ModelMetadata` from `config.json`. Each parser reads architecture-specific fields +(e.g., `rope_scaling`, `partial_rotary_factor`, `n_shared_experts`) and maps them +to the common metadata struct. Global attributes (rope scaling, partial rotation) +are injected via `model.WithGlobalAttributes` during graph construction. + +### 12.4 Parameter Name Resolver + +`model.ParamResolver` maps architecture-specific weight names (e.g., Llama's +`q_proj.weight` vs DeepSeek's `kv_a_proj.weight`) to canonical names used by +Zerfoo layer builders. Called as a fallback during model building when exact +parameter names are not found. See [ADR-005](adr/005-multi-architecture-support.md). + +--- + +## 13. Architectural Decision Records + +Stable design decisions extracted from the implementation plan into self-contained +ADR files in `docs/adr/`. + +| ADR | Title | Phase | Key Decision | +|-----|-------|-------|-------------| +| [001](adr/001-enterprise-production-readiness.md) | Enterprise Production Readiness | 4+7 | Logging, metrics, config, health, CI gates, dead code removal, graph thread safety | +| [002](adr/002-distributed-training-protocol.md) | Distributed Training Protocol | 5 | Star-topology AllReduce, counter-based Barrier, WorkerNode lifecycle | +| [003](adr/003-open-weights-model-import.md) | Open Weights Model Import | 6 | 4-bit weights, Conv2d strategy, MoE design, 13 new operators | +| [004](adr/004-embeddable-inference-library.md) | Embeddable Inference Library | 8 | BPE tokenizer, KV cache, generation loop, sampling, streaming, serve | +| [005](adr/005-multi-architecture-support.md) | Multi-Architecture Support | 9 | Config registry, param resolver, YaRN, partial RoPE, MLA, shared MoE | +| [006](adr/006-gpu-engine-architecture.md) | GPU Engine Architecture | 2-3 | CUDA float32, memory pool, cuBLAS row-major, OOM fallback, parity tolerances | diff --git a/docs/plan.md b/docs/plan.md index 171d4f42..5a8c7858 100644 --- a/docs/plan.md +++ b/docs/plan.md @@ -7,30 +7,13 @@ Zerfoo is a Go-based ML framework with 40+ packages, a 34-method compute Engine[T] interface, CPU and CUDA GPU backends, gRPC-based distributed training, and comprehensive test coverage (95%+ across testable packages). - -The framework has strong foundations -- clean interfaces, modular architecture, -type-safe generics, and high test coverage -- but lacks the operational -hardening required for enterprise production deployment. This plan addresses -the gaps in observability, security, reliability, configuration management, -and CI/CD enforcement needed to reach production grade. +This plan tracked the gaps in observability, security, reliability, +configuration, CI/CD, model import, inference, and multi-architecture support +needed to reach production grade. Architecture, design, GPU details, operations, and troubleshooting are -documented in docs/design.md (the single reference document). - -### Objectives - -- O1: Add structured logging with configurable log levels across all packages. -- O2: Export runtime metrics (throughput, latency, memory, errors) via a - metrics interface suitable for Prometheus or similar backends. -- O3: Harden gRPC distributed services with TLS and mutual authentication. -- O4: Add file-based configuration loading with validation and env var overrides. -- O5: Implement graceful shutdown with resource cleanup across all components. -- O6: Add health check endpoints for readiness and liveness probes. -- O7: Make parity and numerics tests blocking in CI; add coverage gates. -- O8: Add benchmark regression detection to prevent performance degradation. -- O9: Add resource limits (memory caps, timeouts) to prevent unbounded allocation. -- O10: Validate GPU implementation on real NVIDIA hardware (blocked on GCP quota). -- O11: Create production deployment runbook and troubleshooting guide. +documented in docs/design.md (the single reference document). Stable design +decisions are extracted into docs/adr/ (see [ADR index](design.md#13-architectural-decision-records)). ### Non-Goals @@ -41,6 +24,7 @@ documented in docs/design.md (the single reference document). - Breaking changes to the Engine[T] or Node[T] interfaces. - Replacing gRPC with a different RPC framework. - Adding third-party test frameworks (testify, etc.). +- SSM/Mamba architectures (Falcon Mamba, RWKV, Jamba). ### Constraints and Assumptions @@ -49,57 +33,18 @@ documented in docs/design.md (the single reference document). - Pre-commit hook rejects commits spanning multiple directories. - All changes must pass golangci-lint, go vet, and gofmt. - Tests must pass with -race flag. -- No Docker Compose. Prefer DevSpace if orchestration is needed. - Table-driven tests using the standard testing package. -### Success Metrics - -| Metric | Target | Measurement | -|--------|--------|-------------| -| Structured logging | All packages use leveled logger | Grep for raw fmt.Print/log.Print in non-test code = 0 | -| Metrics export | Runtime metrics available via interface | Metrics interface has >= 10 counters/gauges | -| TLS coverage | All gRPC endpoints use TLS | No plaintext gRPC listeners in production config | -| Config loading | YAML/JSON config from file + env vars | Config loads from file, env overrides work | -| Graceful shutdown | All components clean up on SIGTERM | Integration test verifies orderly shutdown | -| Health checks | Readiness + liveness probes | HTTP endpoint returns status within 100ms | -| CI blocking tests | Parity + numerics tests block merges | CI fails on parity/numerics test failure | -| Benchmark gates | CI fails on > 10% regression | Benchmark comparison in CI workflow | -| Resource limits | Memory caps enforced | Allocation above limit returns error | -| Coverage gate | >= 95% enforced in CI | CI fails if coverage drops below threshold | - --- ## 2. Scope and Deliverables -### In Scope - -- Structured logging library (log levels, JSON output, context propagation). -- Metrics collection interface and default implementation. -- TLS/mTLS configuration for gRPC services. -- File-based configuration with validation and environment variable overrides. -- Graceful shutdown coordination across Engine, distributed workers, gRPC server. -- Health check HTTP endpoints. -- CI hardening: blocking parity tests, coverage gates, benchmark regression detection. -- Resource limit enforcement (memory, timeouts). -- Production documentation (deployment runbook, troubleshooting guide). -- GPU hardware validation (when quota available). - -### Out of Scope - -- Web UI or dashboard. -- Model serving HTTP API (inference server). -- Automatic device placement or tensor migration. -- Database or persistent storage integration. -- Container image building or Kubernetes manifests. - -### Deliverables - | ID | Description | Acceptance Criteria | |----|-------------|---------------------| | D1 | Structured logging | Logger interface with Debug/Info/Warn/Error levels; JSON output mode; all packages instrumented | | D2 | Metrics interface | Counters, gauges, histograms; default in-memory impl; export-ready | | D3 | gRPC TLS | TLS config struct; mTLS support; integration test with TLS | -| D4 | Config management | YAML/JSON loader; env var overrides; validation errors | +| D4 | Config management | JSON loader; env var overrides; validation errors | | D5 | Graceful shutdown | Context-based cancellation; cleanup ordering; integration test | | D6 | Health checks | HTTP /healthz and /readyz endpoints; configurable checks | | D7 | CI hardening | Blocking parity/numerics; coverage gate; benchmark gate | @@ -109,2648 +54,82 @@ documented in docs/design.md (the single reference document). --- -## 3. Checkable Work Breakdown - -### Completed Work (Phases 1-3) - -Phase 1 (Test Coverage), Phase 2 (GPU Engine), and Phase 3 (GPU Production -Readiness) are complete. Details in docs/design.md Section 7. - -Remaining blocked items from Phase 3: -- T15.1 GPU hardware validation -- BLOCKED on GCP GPU quota -- T20.1 Production benchmarks on T4 -- BLOCKED on T15.1 - ---- - -### Phase 4: Enterprise Production Readiness - -#### E21: Structured Logging - -Add a logging abstraction that supports leveled output, structured fields, -and JSON format. Instrument all packages that currently use raw fmt.Printf -or the distributed Logger interface. - -- [x] T21.1 Define Logger interface in a new `log` package Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: None - - Acceptance: Interface has Debug, Info, Warn, Error methods. Each accepts a message string and key-value fields. A NopLogger and a StdLogger (writing to io.Writer) are provided. JSON output mode is available via a constructor option. - - [x] S21.1.1 Create log/logger.go with Logger interface and Level type Est: 20m - - [x] S21.1.2 Implement StdLogger with level filtering and text/JSON output Est: 25m - - [x] S21.1.3 Implement NopLogger (zero-allocation no-op) Est: 5m - - [x] S21.1.4 Write unit tests for StdLogger (level filtering, JSON format, field rendering) Est: 20m - - [x] S21.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T21.2 Integrate Logger into compute package Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T21.1 - - Acceptance: CPUEngine and GPUEngine accept a Logger at construction. OOM fallback, stream errors, and pool operations log at appropriate levels. No raw fmt.Printf calls remain in compute/. - - [x] S21.2.1 Add Logger field to CPUEngine; log parallelFor errors at Warn Est: 15m - - [x] S21.2.2 Add Logger field to GPUEngine; log OOM fallback, pool stats, stream errors Est: 20m - - [x] S21.2.3 Update tests to verify log output in error scenarios Est: 15m - - [x] S21.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T21.3 Integrate Logger into distributed package Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T21.1 - - Acceptance: Replace existing distributed.Logger interface with log.Logger. All coordinator and worker components use leveled logging. Connection events logged at Info, errors at Error. - - [x] S21.3.1 Update distributed.ServerManager, coordinator to accept log.Logger Est: 15m - - [x] S21.3.2 Replace all fmt.Printf calls in distributed/ with logger calls Est: 15m - - [x] S21.3.3 Update tests to use StdLogger or NopLogger Est: 10m - - [x] S21.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T21.4 Integrate Logger into remaining packages Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T21.1 - - Acceptance: training/, model/, cmd/cli/ use log.Logger. No raw fmt.Printf in non-test production code. - - [x] S21.4.1 Add Logger to training.WorkflowConfig and optimizer constructors Est: 10m - - [x] S21.4.2 Add Logger to model package and cmd/cli framework Est: 10m - - [x] S21.4.3 Audit all packages for remaining fmt.Printf; replace with logger Est: 10m - - [x] S21.4.4 Run golangci-lint and go test -cover Est: 5m - -#### E22: Metrics Interface - -Add a metrics collection abstraction for runtime observability. The interface -must be backend-agnostic (usable with Prometheus, StatsD, or in-memory). - -- [x] T22.1 Define Metrics interface in a new `metrics/runtime` package Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: None - - Acceptance: Interface has Counter(name), Gauge(name), Histogram(name, buckets) methods. Each returns a typed metric with Inc/Set/Observe methods. A default in-memory implementation is provided for testing and local use. A NopMetrics implementation is provided for zero overhead when metrics are disabled. - - [x] S22.1.1 Create metrics/runtime/metrics.go with Collector interface Est: 20m - - [x] S22.1.2 Implement InMemoryCollector with thread-safe counters/gauges Est: 25m - - [x] S22.1.3 Implement NopCollector (zero-allocation no-op) Est: 5m - - [x] S22.1.4 Write unit tests for InMemoryCollector (concurrent access, snapshot) Est: 15m - - [x] S22.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T22.2 Instrument compute.Engine with metrics Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T22.1 - - Acceptance: CPUEngine and GPUEngine report: op_count (counter per operation type), op_duration_seconds (histogram), oom_fallback_total (counter), pool_hit_total / pool_miss_total (counters for GPU pool). - - [x] S22.2.1 Add Collector field to CPUEngine; instrument Add/MatMul/etc. with counters and timers Est: 20m - - [x] S22.2.2 Add Collector field to GPUEngine; instrument kernel dispatch, OOM, pool Est: 20m - - [x] S22.2.3 Write tests verifying metric increments after operations Est: 15m - - [x] S22.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T22.3 Instrument distributed package with metrics Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T22.1 - - Acceptance: Distributed workers report: allreduce_count (counter), allreduce_duration_seconds (histogram), barrier_count, broadcast_count, connection_errors_total. - - [x] S22.3.1 Add Collector to Strategy and coordinator Est: 15m - - [x] S22.3.2 Instrument AllReduceGradients, Barrier, BroadcastTensor Est: 10m - - [x] S22.3.3 Write tests verifying metrics after distributed operations Est: 10m - - [x] S22.3.4 Run golangci-lint and go test -cover Est: 5m - -#### E23: gRPC Security Hardening - -Add TLS and mutual authentication to all gRPC communication channels. - -- [x] T23.1 Add TLS configuration to gRPC server and client Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: None - - Acceptance: A TLSConfig struct supports: CA cert path, server cert/key paths, client cert/key paths for mTLS. ServerManager.Start() uses TLS credentials when TLSConfig is provided. Worker connections use TLS. Plaintext is still supported (for local development) when TLSConfig is nil. - - [x] S23.1.1 Create distributed/tlsconfig.go with TLSConfig struct and credential helpers Est: 20m - - [x] S23.1.2 Update ServerManager to accept TLSConfig and create TLS listener Est: 15m - - [x] S23.1.3 Update NetworkManager.ConnectToPeers to use TLS dial options Est: 15m - - [x] S23.1.4 Write integration test: server + client with self-signed TLS certs Est: 20m - - [x] S23.1.5 Write integration test: mTLS with client cert verification Est: 15m - - [x] S23.1.6 Run golangci-lint and go test -cover Est: 5m - -- [x] T23.2 Add input validation to distributed RPC handlers Owner: TBD Est: 30m Completed: 2026 03 01 via T32.5 Note: Implemented as part of Phase 5 E32 workerService. - - Dependencies: None - - Acceptance: All RPC handlers validate request fields (non-empty rank, valid tensor shapes, non-nil data). Invalid requests return gRPC InvalidArgument status. Tests verify each validation path. - - [x] S23.2.1 Add validation to AllReduce, Barrier, Broadcast RPC handlers Est: 15m - - [x] S23.2.2 Write tests for each validation error case Est: 10m - - [x] S23.2.3 Run golangci-lint and go test -cover Est: 5m - -#### E24: Configuration Management - -Add file-based configuration loading with validation and environment -variable overrides. Use encoding/json and os.Getenv from the standard library. - -- [x] T24.1 Create config package with file loader Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: None - - Acceptance: A config.Load[T](path string) function reads a JSON file into a struct. A config.LoadWithEnv[T](path, prefix string) function additionally applies environment variable overrides using the `env` struct tag. Validation errors list all invalid fields. Missing required fields produce clear error messages. - - [x] S24.1.1 Create config/loader.go with Load[T] function (JSON decoder) Est: 15m - - [x] S24.1.2 Implement env var override via struct tag reflection Est: 20m - - [x] S24.1.3 Implement validation via `validate:"required"` struct tag Est: 15m - - [x] S24.1.4 Write unit tests: valid config, missing file, invalid JSON, missing required, env override Est: 20m - - [x] S24.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T24.2 Define standard config structs for Engine and Training Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T24.1 - - Acceptance: EngineConfig (device type, memory limit, log level), TrainingConfig (batch size, learning rate, optimizer, epochs, checkpoint interval), DistributedConfig (coordinator address, TLS config, timeout). Each struct has JSON tags and validation tags. - - [x] S24.2.1 Define EngineConfig, TrainingConfig, DistributedConfig structs Est: 15m - - [x] S24.2.2 Write tests loading each config from JSON with env overrides Est: 10m - - [x] S24.2.3 Run golangci-lint and go test -cover Est: 5m - -#### E25: Graceful Shutdown - -Implement orderly shutdown coordination using context cancellation -and cleanup callbacks. - -- [x] T25.1 Add Closer interface and shutdown coordinator Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: A shutdown.Coordinator registers Closer instances in order. On Shutdown(ctx), it calls Close() on each in reverse registration order. If a Closer does not complete within the context deadline, it is skipped and logged. Integration test demonstrates orderly cleanup. - - [x] S25.1.1 Create shutdown/coordinator.go with Closer interface and Coordinator Est: 20m - - [x] S25.1.2 Implement reverse-order shutdown with timeout per closer Est: 15m - - [x] S25.1.3 Write tests: orderly shutdown, timeout on slow closer, empty coordinator Est: 15m - - [x] S25.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T25.2 Implement Closer for Engine and distributed components Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T25.1 - - Acceptance: GPUEngine.Close() drains memory pool and destroys CUDA handles. CPUEngine.Close() is a no-op (satisfies interface). Distributed Strategy.Shutdown() deregisters from coordinator and closes connections. All Close methods are idempotent. - - [x] S25.2.1 Make CPUEngine implement Closer (no-op Close) Est: 5m - - [x] S25.2.2 Verify GPUEngine.Close() is idempotent Est: 10m - - [x] S25.2.3 Make distributed Strategy implement Closer Est: 10m - - [x] S25.2.4 Write integration test: register Engine + Strategy, trigger shutdown Est: 15m - - [x] S25.2.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T25.3 Add signal handling to CLI commands Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T25.1, T25.2 - - Acceptance: cmd/zerfoo and cmd/zerfoo-predict catch SIGINT/SIGTERM, trigger shutdown coordinator, and exit cleanly. Integration test verifies signal handling. - - [x] S25.3.1 Add signal listener in cmd framework that cancels root context Est: 15m - - [x] S25.3.2 Wire shutdown coordinator into CLI lifecycle Est: 10m - - [x] S25.3.3 Write test verifying clean exit on SIGTERM Est: 10m - - [x] S25.3.4 Run golangci-lint and go test -cover Est: 5m - -#### E26: Health Checks - -Add health check endpoints for deployment probes (Kubernetes liveness -and readiness). - -- [x] T26.1 Create health check HTTP server Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T21.1 - - Acceptance: A health.Server exposes /healthz (liveness) and /readyz (readiness) HTTP endpoints. Each returns 200 OK with JSON body when healthy, 503 when unhealthy. Readiness checks are configurable (register check functions). Server starts on a configurable port. Logger is used for startup/error messages. - - [x] S26.1.1 Create health/server.go with Server struct and HTTP handlers Est: 15m - - [x] S26.1.2 Implement configurable readiness checks (func() error callbacks) Est: 10m - - [x] S26.1.3 Write tests: healthy response, unhealthy readiness, concurrent access Est: 15m - - [x] S26.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T26.2 Add engine health check Owner: TBD Est: 20m Completed: 2026 03 01 - - Dependencies: T26.1 - - Acceptance: A check function verifies Engine is operational (e.g., small tensor add succeeds). For GPU, additionally verify CUDA context is valid. Register as readiness check. - - [x] S26.2.1 Implement engine health check function Est: 10m - - [x] S26.2.2 Write test for healthy and unhealthy engine Est: 10m - - [x] S26.2.3 Run golangci-lint and go test -cover Est: 5m - -#### E27: CI/CD Hardening - -Make CI pipeline enforce quality gates strictly. - -- [x] T27.1 Make parity and numerics tests blocking Owner: TBD Est: 15m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: Remove `|| true` from parity and numerics test steps in .github/workflows/ci.yml. CI fails if any parity or numerics test fails. - - [x] S27.1.1 Update ci.yml: remove `|| true` from parity test step Est: 5m - - [x] S27.1.2 Update ci.yml: remove `|| true` from numerics test step Est: 5m - - [x] S27.1.3 Verify CI passes with current test suite Est: 5m - -- [x] T27.2 Add coverage gate to CI Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: CI step runs `go test -coverprofile=coverage.out ./...`, parses output, and fails if any testable package (excluding documented exceptions) drops below 93%. Coverage summary is posted as a CI artifact. - - [x] S27.2.1 Add coverage step to ci.yml that generates coverage.out Est: 10m - - [x] S27.2.2 Write a Go script (cmd/coverage-gate/main.go) that parses coverage.out and exits non-zero if below threshold Est: 20m - - [x] S27.2.3 Add tests for coverage-gate script Est: 10m - - [x] S27.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T27.3 Add benchmark regression detection Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: CI runs benchmarks on each PR. A Go script compares benchmark results against a baseline (stored in repo). CI fails if any benchmark regresses by more than 10%. Baseline is updated via a manual workflow dispatch. - - [x] S27.3.1 Add benchmark step to ci.yml (go test -bench=. -benchmem -count=3) Est: 10m - - [x] S27.3.2 Write cmd/bench-compare/main.go to parse benchstat output and enforce threshold Est: 25m - - [x] S27.3.3 Add baseline benchmark results file (benchmarks/baseline.txt) Est: 5m - - [x] S27.3.4 Add tests for bench-compare script Est: 10m - - [x] S27.3.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T27.4 Update CI Go version and add race detector Owner: TBD Est: 15m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: CI uses Go 1.25 (matching go.mod). Race detector runs on unit tests. Both Ubuntu and macOS runners are used. - - [x] S27.4.1 Update ci.yml go-version to match go.mod Est: 5m - - [x] S27.4.2 Add -race flag to unit test step Est: 5m - - [x] S27.4.3 Add macOS runner to test matrix Est: 5m - -#### E28: Resource Limits - -Add configurable resource limits to prevent unbounded allocation and -runaway operations. +## 3. Work Breakdown -- [x] T28.1 Add memory limit to Engine Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: Engine accepts a MaxMemoryBytes option. Tensor allocation that would exceed the limit returns an error instead of allocating. GPU engine tracks device memory usage. The limit is enforced at the Engine level, not the allocator level (so it applies to both CPU and GPU). - - [x] S28.1.1 Add MemoryTracker to compute package (atomic int64 tracking allocated bytes) Est: 15m - - [x] S28.1.2 Integrate MemoryTracker into tensor allocation (New, NewWithStorage) Est: 15m - - [x] S28.1.3 Add MaxMemoryBytes option to Engine constructors Est: 10m - - [x] S28.1.4 Write tests: allocation within limit succeeds, over limit returns error Est: 15m - - [x] S28.1.5 Run golangci-lint and go test -cover Est: 5m +### Phase 1: Test Coverage (Complete) -- [x] T28.2 Add per-operation timeout enforcement Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: None - - Acceptance: Engine respects context.Context deadlines. Long-running operations (MatMul, Softmax) check ctx.Done() periodically and return context.DeadlineExceeded if expired. GPU operations use CUDA stream synchronization with timeout. - - [x] S28.2.1 Add ctx.Done() checks in CPUEngine parallelFor loops Est: 15m - - [x] S28.2.2 Add stream sync timeout in GPUEngine operations Est: 10m - - [x] S28.2.3 Write tests: operation completes within deadline, times out correctly Est: 15m - - [x] S28.2.4 Run golangci-lint and go test -cover Est: 5m +30 of 33 packages at >= 95% statement coverage. Details in docs/design.md +Section 7. -#### E29: GPU Hardware Validation (Blocked) +### Phase 2: GPU Engine (Complete) -Validate all GPU code on real NVIDIA hardware. +CUDA float32 GPUEngine with memory pool, cuBLAS, OOM fallback. 20 GPU- +accelerated operations. See [ADR-006](adr/006-gpu-engine-architecture.md). -- [ ] T29.1 Create GCP T4 spot VM and validate GPU tests Owner: TBD Est: 1h **BLOCKED:** GCP GPU quota = 0. Quota increase request pending (preference ID: zerfoo-gpu-test). - - Dependencies: None - - Acceptance: `go test -tags cuda ./...` passes on real T4 hardware. Benchmark results captured and documented in docs/design.md. - - Unblock action: Check quota status via `gcloud beta quotas preferences describe zerfoo-gpu-test --project=numerai-488804`. If still denied, try a different GCP project or cloud provider. - - [ ] S29.1.1 Create n1-standard-4 spot VM with T4 GPU Est: 5m - - [ ] S29.1.2 Install CUDA Toolkit 12.x and Go 1.25, clone repo Est: 15m - - [ ] S29.1.3 Build with `go build -tags cuda ./...` and fix any build issues Est: 10m - - [ ] S29.1.4 Run `go test -tags cuda ./...` and capture output Est: 10m - - [ ] S29.1.5 Run benchmarks and save results Est: 5m - - [ ] S29.1.6 Delete VM immediately Est: 2m - - [ ] S29.1.7 Document results in docs/design.md Est: 10m +### Phase 3: GPU Production Readiness (Complete) -- [ ] T29.2 Run optimized benchmarks on T4 Owner: TBD Est: 30m **BLOCKED:** Depends on T29.1. - - Dependencies: T29.1 - - Acceptance: Benchmark results for MatMul (128/512/1024), Softmax, and chained attention ops documented with Phase 3 device-resident pipeline. - - [ ] S29.2.1 Run benchmarks with -benchmem and capture results Est: 10m - - [ ] S29.2.2 Update docs/design.md with benchmark table Est: 15m - - [ ] S29.2.3 Delete VM Est: 2m +Device-resident pipeline, stream pipelining, parity tests. Hardware validation +blocked (see E29 below). -#### E30: Production Documentation +### Phase 4: Enterprise Production Readiness (Complete) -Create operational documentation for production deployment. +Logging (E21), metrics (E22), TLS (E23), config (E24), shutdown (E25), health +(E26), CI hardening (E27), resource limits (E28), documentation (E30), +verification (E31). See [ADR-001](adr/001-enterprise-production-readiness.md). -- [x] T30.1 Write deployment runbook Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: E21, E23, E24, E25, E26 - - Acceptance: docs/runbook.md covers: system requirements, installation steps, configuration reference (all config fields documented), startup sequence, health check verification, log interpretation, common operational tasks (scale workers, update model, restart), shutdown procedure. - - [x] S30.1.1 Write system requirements and installation section Est: 15m - - [x] S30.1.2 Write configuration reference (all config structs documented) Est: 15m - - [x] S30.1.3 Write startup, health check, and shutdown sections Est: 15m - - [x] S30.1.4 Write common operational tasks Est: 15m +### Phase 5: Distributed Training Protocol (Complete) -- [x] T30.2 Write troubleshooting guide Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: E21, E22 - - Acceptance: docs/troubleshooting.md covers: common error messages with root causes and fixes, GPU-specific issues (CUDA not found, OOM, driver mismatch), distributed training issues (connection refused, timeout, split brain), performance diagnosis (how to identify bottlenecks, pprof usage). - - [x] S30.2.1 Document common error messages and fixes Est: 15m - - [x] S30.2.2 Document GPU troubleshooting Est: 10m - - [x] S30.2.3 Document distributed training troubleshooting Est: 10m - - [x] S30.2.4 Document performance diagnosis with pprof Est: 10m +Worker service (E32), gRPC strategy (E33), multi-worker integration tests (E34), +worker lifecycle + CLI (E35), verification (E36). distributed/ at 96% coverage. +See [ADR-002](adr/002-distributed-training-protocol.md). -- [x] T30.3 Add pprof endpoints to health server Owner: TBD Est: 20m Completed: 2026 03 01 - - Dependencies: T26.1 - - Acceptance: Health server registers net/http/pprof handlers. CPU profile, heap profile, goroutine dump available at /debug/pprof/*. - - [x] S30.3.1 Register pprof handlers in health.Server Est: 10m - - [x] S30.3.2 Write test verifying pprof endpoints respond Est: 10m - - [x] S30.3.3 Run golangci-lint and go test -cover Est: 5m +### Phase 6: Open Weights Model Import (Complete) -#### E31: Final Verification +Gemma 3 ONNX import (E37), core operators (E38), vision encoder operators (E39), +MoE (E40), Gemma 3 validation (E41), Kimi-VL validation (E42), verification +(E43). 13 new operators added. See [ADR-003](adr/003-open-weights-model-import.md). -Run the full quality gate suite after all enterprise features are implemented. +### Phase 7: Architecture Cleanup (Complete) -- [x] T31.1 Run full test suite with coverage Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: E21, E22, E23, E24, E25, E26, E27, E28 - - Acceptance: `go test ./... -cover` shows all packages at target coverage. `go test ./... -race` shows zero races. New packages (log, config, health, shutdown, metrics/runtime) are all at >= 95%. - - [x] S31.1.1 Run go test ./... -cover Est: 10m - - [x] S31.1.2 Run go test ./... -race Est: 10m - - [x] S31.1.3 Verify new packages meet 95% coverage Est: 10m +Dead code removal (E44), registration consolidation (E45), graph thread safety +(E46), documentation (E47), verification (E48). +See [ADR-001](adr/001-enterprise-production-readiness.md). -- [x] T31.2 Run linters and formatters Owner: TBD Est: 15m Completed: 2026 03 01 - - Dependencies: T31.1 - - Acceptance: golangci-lint 0 issues, go vet clean, gofmt clean. - - [x] S31.2.1 Run golangci-lint run ./... Est: 5m - - [x] S31.2.2 Run go vet ./... Est: 5m - - [x] S31.2.3 Run gofmt -l . and verify no files Est: 5m +### Phase 8: Embeddable Inference Library (Complete) -- [x] T31.3 Run integration smoke test Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T31.1 - - Acceptance: End-to-end test: load config from file, create Engine, run forward pass, verify health check, trigger graceful shutdown. All within a single test binary. - - [x] S31.3.1 Write integration test covering config -> engine -> health -> shutdown Est: 20m - - [x] S31.3.2 Run integration test Est: 5m - - [x] S31.3.3 Run golangci-lint Est: 5m +BPE tokenizer (E49), KV cache (E50), generation loop (E51), streaming (E52), +model registry (E53), high-level API (E54), CLI commands (E55), end-to-end +validation (E56). See [ADR-004](adr/004-embeddable-inference-library.md). ---- - -### Phase 5: Concrete Distributed Service Server - -#### Phase 5 Context - -Phase 4 enterprise production readiness is complete except for E29 (GPU -validation, blocked on GCP quota) and T23.2 (RPC input validation, skipped -because no concrete DistributedServiceServer implementation existed). - -The distributed package currently has: -- Auto-generated protobuf stubs for DistributedService (AllReduce bidi stream, - Barrier unary, Broadcast unary) in distributed/pb/dist.proto. -- Auto-generated protobuf stubs for Coordinator (RegisterWorker, - UnregisterWorker, Heartbeat, StartCheckpoint, EndCheckpoint) in - distributed/pb/coordinator.proto. -- A fully implemented Coordinator gRPC server - (distributed/coordinator/coordinator.go) with worker management, heartbeat - reaper, and checkpoint coordination. -- InternalStrategy[T] interface (distributed/interfaces.go) defining Init, - AllReduceGradients, Barrier, BroadcastTensor, Rank, Size, Shutdown. -- AllReduceStrategy[T] (distributed/all_reduce.go) implementing hierarchical - all-reduce using local + cross-node InternalStrategy instances. -- NetworkManager (distributed/network_manager.go) for establishing peer gRPC - client connections. -- ServerManager (distributed/network_manager.go) for gRPC server lifecycle - management (start, stop, graceful stop). -- TLS/mTLS configuration (distributed/tlsconfig.go). -- GrpcServer, ListenerFactory, Dialer, ServiceClientFactory type aliases - (distributed/interfaces.go). -- CoordinatorClient interface (distributed/interfaces.go). -- Comprehensive custom mock implementations for testing - (distributed/custom_mocks_test.go). - -What is missing: -1. A concrete DistributedServiceServer implementation -- the actual gRPC - handler that runs on each worker node and processes incoming AllReduce, - Barrier, and Broadcast RPCs from peers. -2. A GrpcStrategy[T] that implements InternalStrategy[T] using gRPC transport, - connecting the high-level AllReduceStrategy to the network layer. -3. A WorkerNode struct that ties together the server, strategy, coordinator - registration, health checks, and shutdown coordination. -4. Input validation on RPC handlers (the previously skipped T23.2). -5. Multi-worker integration tests proving distributed operations work - end-to-end over real gRPC connections. - -#### Phase 5 Objectives +### Phase 9: Multi-Architecture Support (Complete) -- P5-O1: Implement a concrete DistributedServiceServer with AllReduce, - Barrier, and Broadcast handlers including input validation. -- P5-O2: Implement GrpcStrategy[T] connecting InternalStrategy[T] to gRPC - transport. -- P5-O3: Create multi-worker integration tests proving correctness over - real gRPC connections (using bufconn for in-process testing). -- P5-O4: Implement worker lifecycle management (init, run, shutdown) - integrated with existing CLI, health checks, and shutdown coordinator. +Config parsing (E57), param resolver (E58), Llama/Mistral validation (E59), +QKV bias (E60), YaRN RoPE (E61), Qwen validation (E62), partial RoPE (E63), +tied embeddings (E64), Phi-4 validation (E65), MLA (E66), shared MoE (E67), +DeepSeek validation (E68), verification (E69). +See [ADR-005](adr/005-multi-architecture-support.md). -#### Phase 5 Non-Goals +### Blocked Items -- Ring all-reduce optimization. Use star topology (reduce to root, broadcast - from root) for correctness first. Ring optimization is a future Phase 6 - task. -- Gradient compression or sparsification. -- Fault-tolerant training with automatic recovery from worker failures. -- Dynamic worker join or leave during a training step. -- Multi-GPU per worker. +#### E29: GPU Hardware Validation -#### Phase 5 Design Decisions - -**AllReduce Protocol (Star Topology):** -The AllReduce bidi stream implements a star-topology reduce. Root (rank 0) -runs the server that collects gradients from all peers. Each non-root worker -opens a bidi stream to root, sends its gradients as AllReduceRequest messages -(one per named tensor), then waits for root to send back AllReduceResponse -messages with the averaged result. Root accumulates all peer gradients plus -its own local gradients, computes the element-wise average (sum / world_size), -and streams the result back to each peer. - -The server uses a reduceSession struct to coordinate across concurrent -AllReduce stream handlers. The session collects tensors by name from each -peer, waits for all peers via a sync barrier, computes the reduction, and -distributes the result. - -**Barrier Protocol:** -Barrier uses a simple counter-based approach. Each worker calls Barrier RPC -on the root. Root counts arrivals and blocks each caller until all workers -have arrived. Uses sync.Cond for efficient waiting. Each barrier has an -epoch number to prevent stale barrier responses. - -**Broadcast Protocol:** -Root sends a BroadcastRequest with the tensor. Non-root workers call -Broadcast RPC on root. Root returns the tensor in the BroadcastResponse. -Root stores the broadcast tensor in a thread-safe map keyed by name so -concurrent callers all receive the same data. - -**Tensor Serialization:** -The pb.Tensor message uses repeated float for data (float32 only). The -GrpcStrategy[T] converts tensor.TensorNumeric[T] to/from pb.Tensor. For -T=float32, this is a direct copy. For T=float64, values are narrowed to -float32 for transport (acceptable for gradient averaging where precision -loss is tolerable). +- [ ] T29.1 Create GCP T4 spot VM and validate GPU tests **BLOCKED:** GCP GPU quota = 0. + - Quota increase request pending (preference ID: zerfoo-gpu-test, project: numerai-488804). + - Unblock: `gcloud beta quotas preferences describe zerfoo-gpu-test --project=numerai-488804` + - Alternative: try a different GCP project or cloud provider. + - Steps: create n1-standard-4 spot VM with T4, install CUDA 12.x + Go 1.25, + `go test -tags cuda ./...`, capture benchmarks, delete VM immediately. +- [ ] T29.2 Run optimized benchmarks on T4 **BLOCKED:** Depends on T29.1. + - Benchmark MatMul (128/512/1024), Softmax, chained attention ops. + - Document results in docs/design.md. --- -#### E32: Worker Service (DistributedServiceServer) - -Implement the concrete gRPC service handler that runs on each worker node, -processing AllReduce, Barrier, and Broadcast RPCs from peers. - -- [x] T32.1 Create workerService struct with reduce session coordinator Owner: TBD Est: 1.5h Completed: 2026 03 01 - - Dependencies: None - - Acceptance: A workerService struct in distributed/worker_service.go implements pb.DistributedServiceServer. It embeds pb.UnimplementedDistributedServiceServer. Fields include rank (int32), worldSize (int32), logger (log.Logger), collector (metrics/runtime.Collector). A reduceSession struct coordinates all-reduce across concurrent streams: it collects tensors by name from each peer, uses a sync barrier (sync.Cond or channels) to wait for all peers, computes the element-wise sum, and distributes the result. Static interface assertion var _ pb.DistributedServiceServer = (*workerService)(nil) compiles. - - [x] S32.1.1 Create distributed/worker_service.go with workerService struct, constructor NewWorkerService(rank, worldSize int32, logger log.Logger) *workerService Est: 15m - - [x] S32.1.2 Implement reduceSession struct with Submit(peerRank int32, tensors map[string]*pb.Tensor) and WaitForResult() map[string]*pb.Tensor methods Est: 30m - - [x] S32.1.3 Implement NewReduceSession(worldSize int32) *reduceSession constructor Est: 10m - - [x] S32.1.4 Write unit tests for reduceSession: two peers submit, both get averaged result; timeout when one peer missing; concurrent submission safety Est: 30m - - [x] S32.1.5 Run golangci-lint and go test -cover on distributed/ Est: 5m - -- [x] T32.2 Implement AllReduce bidi stream handler Owner: TBD Est: 1.5h Completed: 2026 03 01 - - Dependencies: T32.1 - - Acceptance: workerService.AllReduce(stream) receives all AllReduceRequest messages from a peer until EOF, submits them to the active reduceSession, waits for the global result, and sends AllReduceResponse messages back on the stream. Root (rank 0) also contributes its own local tensors via a SetLocalTensors method. Multiple concurrent streams (one per non-root peer) are handled correctly. Metrics are recorded: allreduce_server_count (counter), allreduce_server_duration_seconds (histogram). - - [x] S32.2.1 Implement AllReduce method on workerService: recv loop, submit to session, wait, send loop Est: 30m - - [x] S32.2.2 Add SetLocalTensors(tensors map[string]*pb.Tensor) method for root to inject its own gradients Est: 15m - - [x] S32.2.3 Add NewSession() method to reset the reduce session for each training step Est: 10m - - [x] S32.2.4 Write unit tests using mock bidi streams: single peer, two peers, stream error mid-recv Est: 30m - - [x] S32.2.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T32.3 Implement Barrier handler Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T32.1 - - Acceptance: workerService.Barrier(ctx, req) increments an arrival counter for the current barrier epoch. When arrivals equal worldSize, all blocked callers are released and BarrierResponse is returned. If the context deadline expires before all peers arrive, the handler returns a DeadlineExceeded gRPC status. Barrier epoch increments after each completed barrier to prevent stale responses. - - [x] S32.3.1 Add barrierState struct to workerService with epoch int64, arrived int32, mu sync.Mutex, cond *sync.Cond Est: 15m - - [x] S32.3.2 Implement Barrier method: increment arrived, wait on cond, broadcast when all arrived Est: 15m - - [x] S32.3.3 Write unit tests: 3 concurrent callers all released, timeout when one missing, sequential barriers with epoch increment Est: 20m - - [x] S32.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T32.4 Implement Broadcast handler Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T32.1 - - Acceptance: workerService.Broadcast(ctx, req) stores the broadcast tensor in a thread-safe map keyed by name. Non-root workers call this RPC on root to retrieve the broadcast tensor. Root sets the tensor via a SetBroadcastTensor(name string, tensor *pb.Tensor) method before non-root workers call. If the tensor is not yet available, the handler waits (with context deadline) for it to be set. - - [x] S32.4.1 Add broadcastStore (sync.Map or mutex-guarded map) to workerService with wait channels Est: 15m - - [x] S32.4.2 Implement Broadcast method and SetBroadcastTensor method Est: 15m - - [x] S32.4.3 Write unit tests: set then retrieve, wait then set (concurrent), timeout Est: 15m - - [x] S32.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T32.5 Add input validation to all RPC handlers Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T32.2, T32.3, T32.4 - - Acceptance: AllReduce validates non-nil tensor, non-empty name, valid shape (all dimensions > 0, product matches data length). Barrier validates rank is in range [0, worldSize). Broadcast validates non-nil tensor, non-empty name, valid shape. Invalid requests return gRPC InvalidArgument status with descriptive message. This task completes the previously skipped T23.2. - - Risk: Must not break existing Coordinator RPC validation (already has validation in coordinator.go). - - [x] S32.5.1 Add validateTensor(t *pb.Tensor, fieldName string) error helper Est: 10m - - [x] S32.5.2 Add validation calls at the top of AllReduce, Barrier, Broadcast Est: 10m - - [x] S32.5.3 Write tests for each validation error case (nil tensor, empty name, shape mismatch, rank out of range) Est: 15m - - [x] S32.5.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T32.6 Run linters and verify coverage for E32 Owner: TBD Est: 15m Completed: 2026 03 01 - - Dependencies: T32.5 - - Acceptance: golangci-lint reports 0 issues on distributed/. go test -cover -race ./distributed/ shows >= 95% coverage on worker_service.go. go vet ./distributed/ clean. - - [x] S32.6.1 Run golangci-lint run ./distributed/ Est: 5m - - [x] S32.6.2 Run go test -cover -race ./distributed/ Est: 5m - - [x] S32.6.3 Fix any remaining lint or coverage gaps Est: 5m - -#### E33: gRPC Strategy (InternalStrategy[T] over gRPC) - -Implement GrpcStrategy[T] that connects the InternalStrategy[T] interface -to the gRPC transport layer, bridging the high-level AllReduceStrategy with -the concrete WorkerService. - -- [x] T33.1 Create GrpcStrategy[T] struct Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: E32 - - Acceptance: A GrpcStrategy[T] struct in distributed/grpc_strategy.go implements InternalStrategy[T]. Fields: rank int, size int, workerService *workerService, serverManager ServerManager, networkManager NetworkManager, peerClients []pb.DistributedServiceClient, peerConns []*grpc.ClientConn, coordinatorClient CoordinatorClient, coordinatorConn *grpc.ClientConn, logger log.Logger, collector metrics/runtime.Collector. Static interface assertion var _ InternalStrategy[float32] = (*GrpcStrategy[float32])(nil) compiles. - - [x] S33.1.1 Create distributed/grpc_strategy.go with struct and NewGrpcStrategy constructor Est: 20m - - [x] S33.1.2 Add tensor conversion helpers: tensorToProto(t *tensor.TensorNumeric[T]) *pb.Tensor and protoToTensor(p *pb.Tensor) (*tensor.TensorNumeric[T], error) Est: 20m - - [x] S33.1.3 Write unit tests for tensor conversion round-trip (float32, various shapes) Est: 15m - - [x] S33.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T33.2 Implement Init (register, connect, start server) Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: T33.1 - - Acceptance: GrpcStrategy.Init(rank, size, coordinatorAddress) registers the worker with the coordinator via RegisterWorker RPC, receives the assigned rank and peer addresses, starts the local gRPC server (workerService) via ServerManager, and connects to all peer workers via NetworkManager.ConnectToPeers. After Init, the strategy is ready for AllReduceGradients calls. - - [x] S33.2.1 Implement Init method: register with coordinator, start server, connect to peers Est: 30m - - [x] S33.2.2 Write unit tests with mock coordinator, ServerManager, and NetworkManager Est: 25m - - [x] S33.2.3 Run golangci-lint and go test -cover Est: 5m - -- [x] T33.3 Implement AllReduceGradients Owner: TBD Est: 1.5h Completed: 2026 03 01 - - Dependencies: T33.2 - - Acceptance: GrpcStrategy.AllReduceGradients(gradients) converts each gradient tensor to pb.Tensor, opens an AllReduce bidi stream to root (rank 0), sends all gradients, receives the averaged result, and converts back to tensor.TensorNumeric[T], updating the gradient map in place. If this worker IS root (rank 0): sets local tensors on workerService, creates a new reduce session, and waits for peers to complete the all-reduce. Metrics: allreduce_client_count, allreduce_client_duration_seconds. - - [x] S33.3.1 Implement AllReduceGradients for non-root workers: open stream to root, send gradients, recv result Est: 30m - - [x] S33.3.2 Implement AllReduceGradients for root worker: set local tensors, new session, wait for completion Est: 30m - - [x] S33.3.3 Write unit tests: non-root sends and receives (mock stream), root processes (mock peers) Est: 25m - - [x] S33.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T33.4 Implement Barrier and BroadcastTensor Owner: TBD Est: 45m Completed: 2026 03 01 - - Dependencies: T33.2 - - Acceptance: GrpcStrategy.Barrier() calls Barrier RPC on root (rank 0). Root calls its own workerService.Barrier locally. Non-root workers send BarrierRequest with their rank. GrpcStrategy.BroadcastTensor(t, rootRank) root converts tensor to proto and sets it on workerService via SetBroadcastTensor, then non-root workers call Broadcast RPC on root to retrieve it. After receiving, non-root workers update the tensor in place. - - [x] S33.4.1 Implement Barrier: non-root calls RPC on root, root calls local service Est: 15m - - [x] S33.4.2 Implement BroadcastTensor: root sets, non-root retrieves via RPC Est: 15m - - [x] S33.4.3 Write unit tests with mock clients Est: 15m - - [x] S33.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T33.5 Implement Shutdown Owner: TBD Est: 30m Completed: 2026 03 01 - - Dependencies: T33.2 - - Acceptance: GrpcStrategy.Shutdown() unregisters the worker from the coordinator (UnregisterWorker RPC), closes all peer connections (NetworkManager.CloseConnections), stops the gRPC server (ServerManager.GracefulStop), and closes the coordinator connection. All operations are idempotent via sync.Once. No panic on double-call. - - [x] S33.5.1 Implement Shutdown with sync.Once and ordered cleanup Est: 15m - - [x] S33.5.2 Write unit tests: single shutdown, double shutdown (idempotent), shutdown with failed unregister Est: 15m - - [x] S33.5.3 Run golangci-lint and go test -cover Est: 5m +## 4. Timeline Summary -- [x] T33.6 Run linters and verify coverage for E33 Owner: TBD Est: 15m Completed: 2026 03 01 - - Dependencies: T33.5 - - Acceptance: golangci-lint reports 0 issues on distributed/. go test -cover -race ./distributed/ shows >= 95% coverage on grpc_strategy.go. go vet clean. - - [x] S33.6.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S33.6.2 Fix any remaining issues Est: 5m - -#### E34: Multi-Worker Integration Tests - -Prove distributed operations work correctly over real gRPC connections -using in-process bufconn listeners (same pattern as coordinator tests). - -- [x] T34.1 In-process multi-worker AllReduce integration test Owner: TBD Est: 1.5h Completed: 2026 03 01 - - Dependencies: E33 - - Acceptance: A test starts a coordinator and 3 GrpcStrategy workers in the same process using bufconn. Each worker has different gradient tensors. After AllReduceGradients, all workers have identical averaged gradients. Mathematical correctness: if worker 0 has [1,2,3], worker 1 has [4,5,6], worker 2 has [7,8,9], all should get [4,5,6] after averaging. Test runs with -race flag. - - [x] S34.1.1 Create distributed/integration_test.go with bufconn test harness (start coordinator, create workers) Est: 30m - - [x] S34.1.2 Write TestMultiWorkerAllReduce with 3 workers and verify averaged gradients Est: 30m - - [x] S34.1.3 Write TestMultiWorkerAllReduce_SingleWorker edge case (world size = 1) Est: 15m - - [x] S34.1.4 Run with -race flag Est: 5m - -- [x] T34.2 In-process Barrier and Broadcast integration tests Owner: TBD Est: 1h Completed: 2026 03 01 - - Dependencies: T34.1 - - Acceptance: Barrier test: 3 workers call Barrier concurrently; all are released after the last worker arrives; timing proves no worker proceeds early. Broadcast test: root broadcasts tensor [10,20,30] to all workers; all non-root workers receive exact copy. - - [x] S34.2.1 Write TestMultiWorkerBarrier with 3 workers and timing verification Est: 20m - - [x] S34.2.2 Write TestMultiWorkerBroadcast from root to 2 non-root workers Est: 20m - - [x] S34.2.3 Run with -race flag Est: 5m - -- [x] T34.3 Error and edge case integration tests Owner: TBD Est: 45m Completed: 2026 03 01 Note: TestAllReduce_ContextCancellation implemented; S34.3.2 and S34.3.3 covered by existing tests - - Dependencies: T34.1 - - Acceptance: Test context cancellation during AllReduce (one worker cancels mid-stream, others get error). Test invalid inputs rejected over the wire (gRPC InvalidArgument status). Test single-worker mode (world size = 1, all ops are no-ops or self-reduces). - - [x] S34.3.1 Write TestAllReduce_ContextCancellation Est: 15m - - [x] S34.3.2 Write TestAllReduce_InvalidInput over gRPC Est: 15m - - [x] S34.3.3 Write TestSingleWorker (world size 1) Est: 10m - - [x] S34.3.4 Run with -race flag Est: 5m - -- [x] T34.4 TLS multi-worker integration test Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T34.1 - - Acceptance: Same as T34.1 but with TLS enabled using self-signed certificates (generated at test time). Verifies TLS handshake works for both coordinator and peer connections. Uses the existing TLSConfig from distributed/tlsconfig.go. - - [x] S34.4.1 Add TLS cert generation helper to test (reuse pattern from tlsconfig_test.go) Est: 10m - - [x] S34.4.2 Write TestMultiWorkerAllReduce_TLS with TLS-enabled coordinator and workers Est: 15m - - [x] S34.4.3 Run with -race flag Est: 5m - -- [x] T34.5 Run linters and verify coverage for E34 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T34.4 - - Acceptance: golangci-lint 0 issues. go test -cover -race ./distributed/... shows integration tests pass. go vet clean. - - [x] S34.5.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S34.5.2 Fix any remaining issues Est: 5m - -#### E35: Worker Lifecycle and CLI Integration - -Create a WorkerNode struct that ties together the distributed components -and integrate with the CLI, health checks, and shutdown coordinator. - -- [x] T35.1 Create WorkerNode struct Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: E33 - - Acceptance: A WorkerNode struct in distributed/worker_node.go encapsulates: GrpcStrategy (or AllReduceStrategy wrapping two GrpcStrategies), coordinator connection, health check registration, and shutdown.Closer implementation. WorkerNode.Start(ctx, cfg) initializes the strategy, registers with the coordinator, starts the gRPC server, connects to peers, and registers an engine health check. WorkerNode.Close(ctx) triggers orderly shutdown. WorkerNode can be registered with the shutdown.Coordinator from the shutdown package. - - [x] S35.1.1 Create distributed/worker_node.go with WorkerNode struct and constructor Est: 20m - - [x] S35.1.2 Implement Start method: init strategy, register health check Est: 20m - - [x] S35.1.3 Implement Close method satisfying shutdown.Closer Est: 10m - - [x] S35.1.4 Write unit tests: start/stop lifecycle, double close is safe Est: 15m - - [x] S35.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T35.2 Add worker CLI command Owner: TBD Est: 45m Completed: 2026 03 02 Note: Created in cmd/cli/worker.go and registered in cmd/zerfoo/main.go - - Dependencies: T35.1, T25.3 - - Acceptance: A `worker` subcommand in cmd/zerfoo starts a distributed training worker. Flags: --coordinator-address (required), --worker-address (required), --worker-id (defaults to hostname), --config (optional JSON config path). The command creates a WorkerNode, registers it with the shutdown coordinator, connects signal handling via cli.SignalContext, and blocks until SIGTERM/SIGINT. On signal, graceful shutdown is triggered. - - [x] S35.2.1 Create cmd/zerfoo/worker.go with worker command registration Est: 15m - - [x] S35.2.2 Implement worker command: parse flags, create WorkerNode, start, wait for signal Est: 20m - - [x] S35.2.3 Write test verifying command parses flags and creates worker (mock coordinator) Est: 15m - - [x] S35.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T35.3 End-to-end worker lifecycle integration test Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: T35.1, T35.2 - - Acceptance: Test starts a coordinator, starts 2 WorkerNodes, verifies both workers register successfully (coordinator reports 2 workers), runs a health check on each worker, then triggers shutdown. After shutdown, both workers have deregistered from the coordinator (coordinator reports 0 workers). Test runs with -race. - - [x] S35.3.1 Write TestWorkerNodeLifecycle in distributed/integration_test.go Est: 25m - - [x] S35.3.2 Verify health check integration (readiness check passes during run, fails after stop) Est: 15m - - [x] S35.3.3 Run with -race flag Est: 5m - -- [x] T35.4 Run linters and verify coverage for E35 Owner: TBD Est: 15m Completed: 2026 03 02 Note: distributed/ 96.0%, cmd/cli/ 91.4% - - Dependencies: T35.3 - - Acceptance: golangci-lint 0 issues. go test -cover -race ./distributed/... and ./cmd/zerfoo/... pass. go vet clean. - - [x] S35.4.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S35.4.2 Fix any remaining issues Est: 5m - -- [x] T35.5 Update plan and documentation Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T35.4 - - Acceptance: docs/plan.md has all Phase 5 tasks marked complete. docs/runbook.md has a new "Distributed Worker Setup" section. docs/troubleshooting.md updated if new error patterns were discovered. T23.2 marked as completed via T32.5. - - [x] S35.5.1 Update docs/plan.md: mark all Phase 5 tasks [x], update progress log Est: 10m - - [x] S35.5.2 Add "Distributed Worker Setup" section to docs/runbook.md Est: 10m - - [x] S35.5.3 Review and update docs/troubleshooting.md Est: 10m - -#### E36: Phase 5 Final Verification - -Run the full quality gate suite after all Phase 5 work is complete. - -- [x] T36.1 Run full test suite with coverage and race detector Owner: TBD Est: 30m Completed: 2026 03 02 Note: distributed/ 96.0% coverage, all tests pass with -race - - Dependencies: E32, E33, E34, E35 - - Acceptance: go test ./... -cover -race passes. distributed/ package coverage >= 95%. No new data races. All existing tests still pass (no regressions). - - [x] S36.1.1 Run go test ./... -cover -race Est: 15m - - [x] S36.1.2 Verify distributed/ package coverage >= 95% Est: 10m - - [x] S36.1.3 Fix any regressions Est: 5m - -- [x] T36.2 Run linters and verify CI compatibility Owner: TBD Est: 15m Completed: 2026 03 02 Note: golangci-lint 0 issues, go vet clean on all packages - - Dependencies: T36.1 - - Acceptance: golangci-lint run ./... reports 0 issues. go vet ./... clean. CI workflow (ci.yml) does not need changes (existing test commands cover new code). - - [x] S36.2.1 Run golangci-lint run ./... Est: 5m - - [x] S36.2.2 Run go vet ./... Est: 5m - - [x] S36.2.3 Verify ci.yml covers new code without changes Est: 5m - ---- - -### Phase 6: Open Weights Model Import Support - -#### Phase 6 Context - -Zerfoo can train and run inference on models built directly with its layer API. -Importing pre-trained open-weights models (Gemma 3, Kimi-VL) requires closing -gaps in the ONNX import pipeline (zonnx repo) and in the zerfoo layer registry. - -Gap analysis conducted on 2026 03 02 identified the following blockers: - -**Gemma 3 (4-bit quantized transformer, 18 layers, ONNX opset 21):** -- zonnx converter: AttributeProto_TENSOR case missing in convertAttribute() - blocks 7 Constant nodes from converting. -- zonnx converter: UINT8 dtype missing in convertTensorWithPath() blocks - MatMulNBits (126 instances) quantized weight tensors. -- MatMulNBits and Constant layers exist in zerfoo (layers/core/) but lack - registry builder functions and are not registered in layers/registry/. -- model/builder.go has no dispatch for "MatMulNBits" or "Constant" ZMF node types. - -**Kimi-VL-A3B (MoonLight language model + SigLIP vision encoder):** -- Vision encoder uses Conv2d, Pad, Slice, Resize, BatchNormalization, - GlobalAveragePool -- none implemented in zerfoo. -- Softmax exists in the compute engine but is not registered as a graph layer node. -- Standard LayerNormalization (with bias) is not registered (only Simplified and - Skip variants are). -- Slice, Pad, TopK, Erf are missing entirely. -- MoE (Mixture of Experts) gate routing and expert dispatch are not implemented. - -#### Phase 6 Objectives - -- P6-O1: Fix zonnx converter to handle TENSOR attributes and UINT8 dtype. -- P6-O2: Register MatMulNBits and Constant in zerfoo layer registry. -- P6-O3: Implement Softmax, Sigmoid, LayerNormalization, Slice, Pad, TopK, Erf. -- P6-O4: Implement Conv2d, GlobalAveragePool, BatchNormalization, Resize. -- P6-O5: Implement MixtureOfExperts layer for Kimi-VL language model. -- P6-O6: Validate Gemma 3 end-to-end with a forward pass integration test. -- P6-O7: Validate Kimi-VL vision encoder end-to-end with a forward pass test. - -#### Phase 6 Non-Goals - -- KV cache and autoregressive decoding (future phase; requires graph execution changes). -- Beam search or nucleus sampling strategies. -- Operator fusion or CUDA acceleration for new operators (correctness first). -- Model quantization at import time (only loading pre-quantized 4-bit weights). -- ZMF sub-graph support (MoE will hold expert tensors directly as a workaround). - -#### Phase 6 Design Decisions - -**4-bit weight packing:** MatMulNBits stores 4-bit weights packed two-per-byte -in UINT8 tensors. ZMF uses DataType=UINT8 for these. Dequantization happens in -MatMulNBits.Forward() which already uses numeric.Unpack4BitSlice internally. - -**Conv2d strategy:** Use im2col + MatMul for correctness. im2col reshapes input -patches into a 2D matrix that is multiplied by the flattened kernel matrix. -This reuses the existing MatMul implementation without a specialized kernel. - -**Multi-repo discipline:** zonnx and zerfoo are separate repos. Pre-commit hooks -reject multi-directory commits. All zonnx changes are committed in the zonnx -repo; all zerfoo layer/model changes are committed in the zerfoo repo. - ---- - -#### E37: Complete Gemma 3 ONNX Import - -Fix the zonnx converter and zerfoo layer registry to support all operators -used by Gemma 3 4B-IT quantized (ONNX opset 21). - -- [x] T37.1 Fix TENSOR attribute and UINT8 dtype in zonnx converter Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: zonnx/pkg/converter/converter.go (convertAttribute, convertTensorWithPath) - - Result: AttributeProto_TENSOR case added (line 615). UINT8 and INT8 dtype cases added - (lines 666-669). Tests in converter_test.go including TestConvertAttribute_Tensor_UINT8. - - [x] S37.1.1 Add AttributeProto_TENSOR case in convertAttribute() Completed: 2026 03 02 - - [x] S37.1.2 Add UINT8 and INT8 dtype cases in convertTensorWithPath() Completed: 2026 03 02 - - [x] S37.1.3 Write unit tests for TENSOR attribute and UINT8 dtype conversion Completed: 2026 03 02 - - [x] S37.1.4 Run golangci-lint and go test -cover in zonnx/pkg/converter/ Completed: 2026 03 02 - -- [x] T37.2 Add BuildConstant[T] to zerfoo and register Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: T37.1 - - Files: layers/core/constant.go, model/builder.go - - Result: Constant[T] in layers/core/constant.go with NewConstant, NewConstantFromData. - Constant nodes handled as special case in model/builder.go (buildConstantNode[T]). - Tests in constant_test.go. - - [x] S37.2.1 Add Constant[T] to layers/core/constant.go Completed: 2026 03 02 - - [x] S37.2.3 Add "Constant" case in model/builder.go Completed: 2026 03 02 - - [x] S37.2.4 Write unit tests for Constant layer build and forward Completed: 2026 03 02 - - [x] S37.2.5 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T37.3 Add BuildMatMulNBits[T] to zerfoo and register Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T37.1 - - Files: layers/core/matmul_nbits.go - - Result: MatMulNBits[T] with NewMatMulNBits, dequantizeWeights (cached). Supports - symmetric and asymmetric quantization. Tests in matmul_nbits_test.go and - integration/gemma3_quantized_test.go. - - [x] S37.3.1 Add MatMulNBits[T] to layers/core/matmul_nbits.go Completed: 2026 03 02 - - [x] S37.3.4 Write unit tests: build, forward pass, dequantization correctness Completed: 2026 03 02 - - [x] S37.3.5 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T37.4 Add zonnx converter handler for MatMulNBits Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: T37.1 - - File: zonnx/pkg/converter/converter.go - - Result: "MatMulNBits" case in ONNXToZMFWithPath. convertMatMulNBits() dequantizes - ONNX 4-bit packed weights to float32 and emits a standard MatMul node. dequantizeNBits() - handles per-block scales and optional zero-points. Tests in converter_test.go. - - [x] S37.4.1 Add "MatMulNBits" case in convertNode() Completed: 2026 03 02 - - [x] S37.4.2 Write unit test for MatMulNBits node conversion Completed: 2026 03 02 - - [x] S37.4.3 Run golangci-lint and go test -cover in zonnx/pkg/converter/ Completed: 2026 03 02 - -- [x] T37.5 Add zonnx importer builders for Constant and MatMulNBits Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T37.1, T37.4 - - Deviation: Separate importer stubs not needed. Constant nodes are stored as ZMF - parameters during ONNX-to-ZMF conversion. MatMulNBits nodes are dequantized to - standard MatMul during conversion. Both are fully handled by the converter layer. - - [x] S37.5.1-S37.5.5 Handled by converter approach Completed: 2026 03 02 - -- [x] T37.6 Gemma 3 ONNX import smoke test Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T37.1, T37.2, T37.3, T37.4, T37.5 - - Files: tests/parity/gemma3_test.go, integration/gemma3_quantized_test.go - - Deviation: Smoke test split across two files. TestGemma3ForwardPass loads a ZMF model - (env-gated by GEMMA3_ZMF_PATH). TestGemma3QuantizedInference exercises Constant + - MatMulNBits end-to-end in an integration test without requiring a model file. - - [x] S37.6.1 Write TestGemma3ForwardPass (skip if env var not set) Completed: 2026 03 02 - - [x] S37.6.3 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T37.7 Run linters and verify coverage for E37 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T37.6 - - Result: golangci-lint 0 issues in both zerfoo and zonnx. Package coverage thresholds met. - - [x] S37.7.1 Run golangci-lint in zerfoo Completed: 2026 03 02 - - [x] S37.7.2 Run golangci-lint in zonnx Completed: 2026 03 02 - - [x] S37.7.3 Verify coverage thresholds Completed: 2026 03 02 - -#### E38: Core Missing Operators - -Implement graph-level layer nodes for operators missing from the zerfoo registry. -These are needed for general transformer inference and as building blocks for VLMs. - -- [x] T38.1 Implement Softmax layer and register Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: None - - Files: layers/activations/softmax.go (new), layers/registry/registry.go, model/builder.go - - Acceptance: Softmax[T] struct with axis int attribute. Forward: for each slice along - axis subtract max (numerical stability), exponentiate, divide by sum. BuildSoftmax[T] - reads "axis" from node attributes (default -1). Register "Softmax" in RegisterAll[T]. - Add "Softmax" case in model/builder.go. Test: Softmax([[1,2,3],[4,5,6]], axis=1) matches - scipy.special.softmax reference (tolerance 1e-6). - - [x] S38.1.1 Create layers/activations/softmax.go with Softmax[T] and BuildSoftmax[T] Est: 20m - - [x] S38.1.2 Register "Softmax" in RegisterAll Est: 5m - - [x] S38.1.3 Write unit tests with numerical reference Est: 15m - - [x] S38.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.2 Implement Sigmoid layer and register Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: None - - Files: layers/activations/registry.go (BuildSigmoid added), layers/registry/registry.go - - Acceptance: BuildSigmoid[T] wrapping existing NewSigmoid. Register "Sigmoid". Tests pass. - - [x] S38.2.1 Add BuildSigmoid to layers/activations/registry.go Est: 10m - - [x] S38.2.2 Register "Sigmoid" in RegisterAll Est: 5m - - [x] S38.2.3 Write unit tests Est: 10m - - [x] S38.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.3 Implement standard LayerNormalization layer and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/normalization/registry.go (BuildLayerNormalization + resolveParam added) - - Acceptance: BuildLayerNormalization[T] reads epsilon, resolves scale/bias params via - multiple naming patterns, creates LayerNormalization with featureDim from param shape. - Register "LayerNormalization". Tests pass including forward pass verification. - - [x] S38.3.1 Add BuildLayerNormalization to layers/normalization/registry.go Est: 25m - - [x] S38.3.2 Register "LayerNormalization" in RegisterAll Est: 5m - - [x] S38.3.3 Write unit tests vs reference Est: 20m - - [x] S38.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.4 Implement Slice layer and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/slice.go (new), layers/registry/registry.go - - Acceptance: Slice[T] with starts/ends/axes/steps. Returns dense copy. BuildSlice[T]. - Register "Slice". Tests cover 1D/2D/negative indices/clamped end. - - [x] S38.4.1 Create layers/core/slice.go Est: 25m - - [x] S38.4.2 Register "Slice" in RegisterAll Est: 5m - - [x] S38.4.3 Write unit tests Est: 20m - - [x] S38.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.5 Implement Pad layer and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/pad.go (new), layers/registry/registry.go - - Acceptance: Pad[T] with pads []int64 and constantValue. BuildPad[T]. Register "Pad". - Tests cover 1D/2D/constant value/mismatch errors. - - [x] S38.5.1 Create layers/core/pad.go Est: 25m - - [x] S38.5.2 Register "Pad" in RegisterAll Est: 5m - - [x] S38.5.3 Write unit tests Est: 20m - - [x] S38.5.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.6 Implement TopK layer and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/topk.go (new), layers/registry/registry.go - - Acceptance: TopK[T] with k/axis/largest/sorted. Returns values only (not indices). - BuildTopK[T]. Register "TopK". Tests cover largest/smallest/large-k/builder paths. - - [x] S38.6.1 Create layers/core/topk.go Est: 30m - - [x] S38.6.2 Register "TopK" in RegisterAll Est: 5m - - [x] S38.6.3 Write unit tests Est: 20m - - [x] S38.6.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.7 Implement Erf layer and register Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: None - - Files: layers/activations/erf.go (new), layers/registry/registry.go - - Acceptance: NewErf[T tensor.Float] using math.Erf via BaseActivation. BuildErf[T]. - Register "Erf". Tests verify erf(0)=0, erf(1)~0.8427, erf(-1)~-0.8427. - - [x] S38.7.1 Create layers/activations/erf.go Est: 10m - - [x] S38.7.2 Register "Erf" in RegisterAll Est: 5m - - [x] S38.7.3 Write unit tests Est: 10m - - [x] S38.7.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T38.8 Add zonnx importer builders for E38 operators Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T38.1 through T38.7 - - Files: zonnx/pkg/converter/converter.go (Slice/Pad/TopK cases), - zonnx/pkg/importer/layers/{softmax,sigmoid,erf,layer_norm,slice,pad,topk}.go (new) - - Note: Slice/Pad/TopK needed converter special cases to promote input tensors to ZMF - attributes. Softmax/Sigmoid/Erf/LayerNorm work via the generic convertNode path. - - [x] S38.8.1 Create zonnx importer builders for Softmax, Sigmoid, LayerNorm Est: 30m - - [x] S38.8.2 Create zonnx importer builders for Slice, Pad Est: 20m - - [x] S38.8.3 Create zonnx importer builders for TopK, Erf Est: 20m - - [x] S38.8.4 Register all builders in zonnx importer registry via init() Est: 5m - - [x] S38.8.5 Write round-trip tests for each operator in converter_test.go Est: 20m - - [x] S38.8.6 Run golangci-lint and go test -cover in zonnx/ Est: 5m - -- [x] T38.9 Run linters and verify coverage for E38 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T38.1 through T38.7 (T38.8 pending) - - Acceptance: golangci-lint 0 issues in layers/activations/, layers/core/, - layers/normalization/, layers/registry/. go test -race ./... all pass. - - [x] S38.9.1 Run golangci-lint and go test -cover -race in all modified dirs Est: 10m - - [x] S38.9.2 Fix any remaining issues Est: 5m Note: fixed copyloopvar and SA9003 in pad/topk - -#### E39: Vision Encoder Operators - -Implement operators for the SigLIP vision encoder used in MoondreamV2 and -Kimi-VL. All operators use NCHW tensor format [N, C, H, W]. - -- [x] T39.1 Implement Conv2d layer and register Owner: TBD Est: 2h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/conv2d.go (new), layers/registry/registry.go, model/builder.go - - Acceptance: Conv2d[T] struct. Attributes: strides [2]int, pads [4]int - (top,left,bottom,right), dilations [2]int, groups int. Fields: kernel - [out_C, in_C/groups, kH, kW], bias [out_C] optional. Forward: im2col reshapes input - patches to [N*H_out*W_out, in_C*kH*kW]; multiply by flattened kernel - [in_C*kH*kW, out_C]; reshape to [N, out_C, H_out, W_out]. BuildConv2d[T] reads kernel - and bias from node initializers. Register "Conv". Test 1: [1,1,5,5] all-ones input, - [1,1,3,3] all-ones kernel, stride=1, pad=0 returns [1,1,3,3] where each value = 9.0. - Test 2: stride=2 halves spatial dims. Test 3: padding preserves spatial dims. - - Deviation: Used direct nested-loop convolution instead of im2col+MatMul to avoid - allocating a large intermediate matrix. Simpler and correct for inference workloads. - - [x] S39.1.1 Implement Conv2d Forward using nested loops with ops.Mul/ops.Add Est: 30m - - [x] S39.1.2 Implement BuildConv2d[T] reading strides/pads/dilations/group attributes Est: 20m - - [x] S39.1.3 Register "Conv" Est: 5m - - [x] S39.1.4 Write unit tests (table-driven: stride=1, stride=2, with-bias) Est: 25m - - [x] S39.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T39.2 Implement GlobalAveragePool layer and register Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/global_avg_pool.go (new), layers/registry/registry.go - - [x] S39.2.1 Create layers/core/global_avg_pool.go Est: 10m - - [x] S39.2.2 Register "GlobalAveragePool" Est: 5m - - [x] S39.2.3 Write unit tests Est: 10m - - [x] S39.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T39.3 Implement BatchNormalization layer (inference mode) and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/normalization/batch_norm.go (new), layers/registry/registry.go - - [x] S39.3.1 Create layers/normalization/batch_norm.go Est: 25m - - [x] S39.3.2 Register "BatchNormalization" Est: 5m - - [x] S39.3.3 Write unit tests (zero-mean, scale+bias, spatial dims) Est: 20m - - [x] S39.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T39.4 Implement Resize layer and register Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/resize.go (new), layers/registry/registry.go - - [x] S39.4.1 Create layers/core/resize.go (nearest neighbor) Est: 25m - - [x] S39.4.2 Register "Resize" Est: 5m - - [x] S39.4.3 Write unit tests (scales and sizes modes) Est: 15m - - [x] S39.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T39.5 Add zonnx importer builders for E39 operators Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T39.1 through T39.4 - - Files: zonnx/pkg/importer/layers/conv.go (new), global_avg_pool.go (new), - batch_norm.go (new), resize.go (new); zonnx/pkg/converter/converter.go - - [x] S39.5.1 Create importer stubs for Conv, GlobalAveragePool Est: 30m - - [x] S39.5.2 Create importer stubs for BatchNormalization, Resize Est: 30m - - [x] S39.5.3 Add Resize special case in converter (promote scales/sizes inputs) Est: 10m - - [x] S39.5.4 Fix converter to skip empty optional ONNX inputs Est: 5m - - [x] S39.5.5 Write round-trip tests for Resize (scales + sizes variants) Est: 20m - - [x] S39.5.6 Run golangci-lint and go test ./... Est: 5m - -- [x] T39.6 Run linters and verify coverage for E39 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T39.5 - - [x] S39.6.1 Run golangci-lint and go test -cover -race: 0 issues, all pass Est: 10m - -#### E40: Mixture of Experts - -Implement MixtureOfExperts[T] for Kimi-VL-A3B (MoonLight uses sparse MoE -with top-2 expert routing per token). - -- [x] T40.1 Implement MoE gate routing layer Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T38.1 (Softmax), T38.6 (TopK) - - Files: layers/core/moe.go (new) - - Deviation: gateWeight is passed as a runtime Forward input (not from params) to match - the ONNX/ZMF pattern used by Conv2d and BatchNorm. MoEGate.route() is a private method - called by both Forward and MixtureOfExperts. Returns [seqLen, topK] weight tensor. - - [x] S40.1.1 Implement MoEGate struct and Forward Completed: 2026 03 02 - - [x] S40.1.2 Add BuildMoEGate[T] and register "MoEGate" Completed: 2026 03 02 - - [x] S40.1.3 Write unit tests Completed: 2026 03 02 - - [x] S40.1.4 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T40.2 Implement MoE expert dispatch and aggregate Owner: TBD Est: 2h Completed: 2026 03 02 - - Dependencies: T40.1 - - Files: layers/core/moe.go (extended), layers/registry/registry.go - - Deviation: Experts are graph.Node[T] instances set at construction time. ZMF sub-graph - loading not yet supported; BuildMixtureOfExperts leaves experts=nil (documented as tech - debt). Test uses identityExpert and scale2Expert helper types. - - [x] S40.2.1 Implement MixtureOfExperts struct and Forward Completed: 2026 03 02 - - [x] S40.2.2 Add BuildMixtureOfExperts[T] with expert loading strategy Completed: 2026 03 02 - - [x] S40.2.3 Register "MixtureOfExperts" Completed: 2026 03 02 - - [x] S40.2.4 Write unit tests Completed: 2026 03 02 - - [x] S40.2.5 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T40.3 Add zonnx importer builders for MoE operators Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T40.1, T40.2 - - File: zonnx/pkg/importer/layers/moe.go (new) - - [x] S40.3.1 Create zonnx/pkg/importer/layers/moe.go Completed: 2026 03 02 - - [x] S40.3.3 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T40.4 Run linters and verify coverage for E40 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T40.3 - - Result: golangci-lint 0 issues. moe.go package-level coverage 93.8%; all functions >= 87%. - - [x] S40.4.1 Run golangci-lint and go test -cover -race Completed: 2026 03 02 - - [x] S40.4.2 Fix any remaining issues Completed: 2026 03 02 - -#### E41: Gemma 3 End-to-End Validation - -- [x] T41.1 Gemma 3 forward pass parity test Owner: TBD Est: 2h Completed: 2026 03 02 - - Dependencies: E37, E38 - - File: tests/parity/gemma3_test.go (created) - - Result: Skips when GEMMA3_ZMF_PATH not set. Asserts output shape [1,seqLen,V>=256000] - and no NaN or Inf. golangci-lint 0 issues. - - [x] S41.1.1 Create tests/parity/gemma3_test.go Completed: 2026 03 02 - - [x] S41.1.4 Run golangci-lint and go test -cover Completed: 2026 03 02 - -- [x] T41.2 Gemma 3 greedy decode smoke test Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T41.1 - - File: tests/parity/gemma3_test.go (extended) - - Result: Implements 5-step greedy decode loop in TestGemma3GreedyDecode; skips when - env var not set; asserts tokens in [0, vocabSize). - - [x] S41.2.1 Implement greedy decode loop in test Completed: 2026 03 02 - - [x] S41.2.3 Run golangci-lint and go test Completed: 2026 03 02 - -#### E42: Kimi-VL Vision Encoder Validation - -- [x] T42.1 SigLIP vision encoder forward pass test Owner: TBD Est: 2h Completed: 2026 03 02 - - Dependencies: E39 - - File: tests/parity/siglip_test.go (created) - - Result: TestSigLIPForwardPass skips when SIGLIP_ZMF_PATH not set. Asserts shape - [1, 196, embedDim] and no NaN or Inf. golangci-lint 0 issues. - - [x] S42.1.1 Create tests/parity/siglip_test.go Completed: 2026 03 02 - - [x] S42.1.3 Verify output and run golangci-lint Completed: 2026 03 02 - -- [x] T42.2 Kimi-VL connector forward pass test Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T42.1 - - File: tests/parity/siglip_test.go (extended) - - Result: TestKimiVLConnectorForwardPass skips when KIMI_CONNECTOR_ZMF_PATH not set. - Asserts shape [1, 196, lmDim] and no NaN or Inf. - - [x] S42.2.1 Implement connector test Completed: 2026 03 02 - - [x] S42.2.3 Run golangci-lint and go test Completed: 2026 03 02 - -#### E43: Phase 6 Final Verification - -- [x] T43.1 Run full test suite with coverage and race detector Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: E37, E38, E39, E40, E41, E42 - - Result: go test ./... -race passes in zerfoo (all 55 packages green). go test ./... - passes in zonnx (all packages green). No regressions. - - [x] S43.1.1 Run go test ./... -cover -race in zerfoo Completed: 2026 03 02 - - [x] S43.1.2 Run go test ./... -cover in zonnx Completed: 2026 03 02 - - [x] S43.1.3 Fix any regressions Completed: 2026 03 02 - -- [x] T43.2 Run linters across all modified directories Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T43.1 - - Result: golangci-lint 0 issues in zerfoo and zonnx. - - [x] S43.2.1 Run golangci-lint run ./... in zerfoo Completed: 2026 03 02 - - [x] S43.2.2 Run golangci-lint run ./... in zonnx Completed: 2026 03 02 - - [x] S43.2.3 Fix any remaining lint issues Completed: 2026 03 02 - -- [x] T43.3 Update documentation Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T43.2 - - Result: docs/plan.md Phase 6 tasks marked [x]. New operators added to registry - (Softmax, Sigmoid, Erf, LayerNormalization, Slice, Pad, TopK, Conv, - GlobalAveragePool, Resize, BatchNormalization, MoEGate, MixtureOfExperts). - - [x] S43.3.1 Update docs/plan.md Phase 6 tasks to [x] Completed: 2026 03 02 - ---- - -### Phase 7: Architecture Cleanup - -#### Phase 7 Context - -A comprehensive architecture review on 2026 03 02 identified structural issues -accumulated over Phases 1-6. The framework is functionally complete for open -weights model import, but internal code quality has drifted. This phase fixes -concrete issues found during the review without breaking existing APIs. - -Issues identified and their severity: - -1. **Dead code**: pkg/prelude/prelude.go is an empty package with no - declarations. tests/helpers/wire.go declares 4 interface variables (ImplZerfoo, - ImplNumerics, ImplPipeline, ImplPerf) that are all nil with no implementations. - Both add confusion and bloat. - -2. **Inverted layer registration dependency**: layers/core/registry.go has an - init() function that imports model and registers FFN for float16. This creates - layers -> model coupling (inverted direction). The correct pattern is used by - layers/registry/registry.go which sits above both packages. - -3. **Graph.memo is not thread-safe**: graph.Graph stores a memo map for caching - forward-pass activations. The map is reset on each Forward call with no mutex. - Concurrent Forward calls from different goroutines would race on this map. - This blocks safe use in serving scenarios where multiple requests run in - parallel. - -4. **model/ package is overloaded**: model/ contains 5 distinct concerns in one - package: global layer registry (registry.go), concrete Model[T] struct - (model.go), ZMF file I/O (zmf_loader.go, zmf_exporter.go), graph builder - from ZMF (builder.go), plugin registry with 6 component types - (model_registry.go). Splitting the layer registry into its own package would - reduce coupling and allow layers/core to register without importing model. - -5. **Stale plan references**: docs/plan.md Section 1 references docs/gpu.md - which no longer exists. Hand-off notes reference deleted files. Appendix - scorecards are outdated. - -#### Phase 7 Objectives - -- P7-O1: Remove dead code (pkg/prelude, tests/helpers nil stubs). -- P7-O2: Eliminate inverted layers/core -> model dependency. -- P7-O3: Add thread safety to graph.Graph for concurrent Forward calls. -- P7-O4: Update docs/plan.md and docs/design.md to reflect all changes. - -#### Phase 7 Non-Goals - -- Splitting model/ into multiple packages. The overloaded model package is - documented tech debt but splitting it is a large refactor that would touch - every import site. Defer to a future phase. -- Changing the Arithmetic[T] interface. Removing activation functions (Tanh, - Sigmoid, ReLU) from Arithmetic would break all 6 implementations and all - callers. The interface is locked per the project non-goals. -- Merging Sum and ReduceSum in the Engine[T] interface. The Engine interface is - locked per non-goals (no breaking changes to Engine[T] or Node[T]). -- Changing log.Logger field signature from ...string to ...any. Would break all - callers across 10+ packages. -- Removing data/ and features/ domain-specific packages. They are used by - training tests and the audacity project. -- Renaming model.ModelProvider to avoid collision with training.ModelProvider. - Would break exported API. - -#### Phase 7 Design Decisions - -**Layer registration consolidation strategy:** -Remove the init() in layers/core/registry.go entirely. Move the FFN float16 -registration into layers/registry/registry.go where all other registrations -live. This eliminates the inverted dependency without changing any public API. -The FFN registration will be typed to float32 like all other registrations in -RegisterAll (float16 registration was likely a mistake since the rest of the -wiring is float32). - -**Graph thread safety strategy:** -Add a sync.Mutex to graph.Graph protecting the memo map. Lock on each Forward -call (reset + full traversal) and each Backward call (reads memo). This is -coarse-grained but correct. Fine-grained per-node locking is premature -optimization for a graph that is typically small (< 1000 nodes). - -**Dead code removal strategy:** -Delete pkg/prelude/prelude.go entirely. Delete tests/helpers/wire.go and its -4 interface definitions. Verify no other file imports these packages. If -tests/helpers/ becomes empty, delete the directory. - ---- - -#### E44: Remove Dead Code - -Remove empty packages and nil stub files that add confusion without value. - -- [x] T44.1 Delete pkg/prelude package Owner: Claude Est: 15m Completed: 2026 03 02 - - Dependencies: None - - Acceptance: pkg/prelude/ directory is deleted. No file in the repo imports - "github.com/zerfoo/zerfoo/pkg/prelude". go build ./... succeeds. go test - ./... passes. - - [x] S44.1.1 Verify no imports of pkg/prelude exist in the repo Est: 2m - - [x] S44.1.2 Delete pkg/prelude/prelude.go Est: 2m - - [x] S44.1.3 Remove pkg/prelude/ directory Est: 1m - - [x] S44.1.4 Run go build ./... and go test ./... to verify no breakage Est: 5m - - [x] S44.1.5 Run golangci-lint Est: 5m - -- [x] T44.2 Delete tests/helpers/wire.go nil stubs and dead test files Owner: Claude Est: 15m Completed: 2026 03 02 - - Dependencies: None - - Acceptance: tests/helpers/ deleted. tests/numerics/ deleted (all 3 test files - were dead). 4 dead parity test files deleted. No file imports - "github.com/zerfoo/zerfoo/tests/helpers". go build ./... succeeds. - - Notes: 7 test files (17 test functions) were dead -- they all unconditionally - skipped because helpers.Impl* variables were nil with no implementations. - The working parity tests (gemma3_test.go, siglip_test.go) use env-var gating - and are unaffected. - - [x] S44.2.1 Check if any test file imports tests/helpers Est: 2m - - [x] S44.2.2 Delete tests/helpers/wire.go and 7 dead test files Est: 2m - - [x] S44.2.3 Delete tests/helpers/ and tests/numerics/ directories Est: 1m - - [x] S44.2.4 Run go build ./... and go test ./... to verify no breakage Est: 5m - - [x] S44.2.5 Run golangci-lint Est: 5m - -- [x] T44.3 Run linters and verify for E44 Owner: Claude Est: 10m Completed: 2026 03 02 - - Dependencies: T44.1, T44.2 - - Acceptance: golangci-lint 0 issues. go vet clean. go test ./... -race passes. - - [x] S44.3.1 Run golangci-lint run ./... Est: 5m - - [x] S44.3.2 Run go test ./... -race Est: 5m - -#### E45: Consolidate Layer Registration - -Eliminate the inverted layers/core -> model dependency by removing the init() -auto-registration in layers/core/registry.go and consolidating all -registrations into layers/registry/registry.go. - -- [x] T45.1 Move FFN registration from layers/core/registry.go init() to layers/registry Owner: Claude Est: 30m Completed: 2026 03 02 - - Dependencies: None - - Files: layers/core/registry.go (modify), layers/registry/registry.go (modify) - - Acceptance: layers/core/registry.go no longer has an init() function. - layers/core/registry.go no longer imports "github.com/zerfoo/zerfoo/model". - layers/registry/registry.go RegisterAll() includes an FFN registration - (model.RegisterLayer("FFN", core.BuildFFN[float32])). go build ./... succeeds. - All existing tests pass. The FFN layer is still available after RegisterAll(). - - Notes: Also exported buildFFN as BuildFFN and removed float16 import from - layers/core/registry.go. Updated registry_builders_test.go to use BuildFFN. - - [x] S45.1.1 Grep for float16.Float16 usage with FFN to verify nothing depends on float16 registration Est: 5m - - [x] S45.1.2 Add model.RegisterLayer("FFN", core.BuildFFN[float32]) to RegisterAll Est: 5m - - [x] S45.1.3 Remove init() function and model import from layers/core/registry.go Est: 5m - - [x] S45.1.4 Run go build ./... to verify compilation Est: 5m - - [x] S45.1.5 Run go test ./... -race to verify no regressions Est: 5m - - [x] S45.1.6 Run golangci-lint Est: 5m - -- [x] T45.2 Verify no other init()-based registrations exist in layers/ Owner: Claude Est: 15m Completed: 2026 03 02 - - Dependencies: T45.1 - - Acceptance: Grep for "func init()" in all layers/ files returns zero results. - The only layer registration entry point is layers/registry.RegisterAll(). - - [x] S45.2.1 Grep for func init() in layers/ directory tree Est: 5m - - [x] S45.2.2 If any found, move them to RegisterAll and remove Est: 5m - - [x] S45.2.3 Run go test ./... -race Est: 5m - -- [x] T45.3 Run linters and verify for E45 Owner: Claude Est: 10m Completed: 2026 03 02 - - Dependencies: T45.1, T45.2 - - Acceptance: golangci-lint 0 issues. go vet clean. go test ./... -race passes. - No init()-based registrations in layers/. - - [x] S45.3.1 Run golangci-lint run ./... Est: 5m - - [x] S45.3.2 Run go test ./... -race Est: 5m - -#### E46: Graph Thread Safety - -Add mutex protection to graph.Graph to allow concurrent Forward calls from -different goroutines without data races. - -- [x] T46.1 Add sync.Mutex to Graph struct and protect memo map Owner: Claude Est: 45m Completed: 2026 03 02 - - Dependencies: None - - Files: graph/graph.go, graph/concurrent_test.go (new) - - Acceptance: graph.Graph has a sync.Mutex field (mu). Forward() locks mu - before resetting memo and unlocks after the full traversal completes. - Backward() locks mu before reading memo and unlocks after completion. - go test ./graph/ -race passes. A new test spawns 8 goroutines each calling - Forward concurrently with different inputs and verifies no race and no panic. - - [x] S46.1.1 Add sync.Mutex field to Graph struct Est: 5m - - [x] S46.1.2 Add mu.Lock()/mu.Unlock() around memo reset and traversal in Forward Est: 10m - - [x] S46.1.3 Add mu.Lock()/mu.Unlock() around memo reads in Backward Est: 10m - - [x] S46.1.4 Write TestGraph_ConcurrentForward: 8 goroutines, different inputs, verify no race Est: 15m - - [x] S46.1.5 Run go test ./graph/ -race -cover Est: 5m - -- [x] T46.2 Run linters and verify for E46 Owner: Claude Est: 10m Completed: 2026 03 02 - - Dependencies: T46.1 - - Acceptance: golangci-lint 0 issues on graph/. go test ./... -race passes. - graph/ coverage remains >= 95%. - - [x] S46.2.1 Run golangci-lint run ./graph/ Est: 5m - - [x] S46.2.2 Run go test ./... -race Est: 5m - -#### E47: Documentation Update - -Update docs/plan.md and docs/design.md to reflect Phase 7 changes and correct -stale references. - -- [x] T47.1 Update docs/design.md with architecture improvements Owner: Claude Est: 30m Completed: 2026 03 02 - - Dependencies: E44, E45, E46 - - Acceptance: docs/design.md Section 10 (Known Limitations) updated: - remove item 9 (graph thread safety -- now fixed). Section 2.1 (Package - Layout) updated: remove pkg/prelude line. Add note about single registration - entry point (RegisterAll). Add concurrency note to Section 3.2 (Node/Graph). - - Notes: Also removed tests/helpers, tests/numerics references, updated CI - section, removed pkg/prelude from coverage exclusions. - - [x] S47.1.1 Remove pkg/prelude from package layout Est: 5m - - [x] S47.1.2 Update Known Limitations: remove graph thread-safety item Est: 5m - - [x] S47.1.3 Add concurrency note to Graph section Est: 5m - - [x] S47.1.4 Add registration consolidation note to Section 3.6 Est: 5m - - [x] S47.1.5 Review full document for other stale references Est: 10m - -- [x] T47.2 Update docs/plan.md metadata Owner: Claude Est: 15m Completed: 2026 03 02 - - Dependencies: E44, E45, E46 - - Acceptance: All Phase 7 tasks marked complete. Progress log updated. - - [x] S47.2.1 Mark all Phase 7 tasks complete with dates Est: 5m - - [x] S47.2.2 Update progress log Est: 5m - - [x] S47.2.3 Add progress log entry Est: 5m - -#### E48: Phase 7 Final Verification - -Run the full quality gate suite after all Phase 7 work is complete. - -- [x] T48.1 Run full test suite with coverage and race detector Owner: Claude Est: 15m Completed: 2026 03 02 - - Dependencies: E44, E45, E46, E47 - - Acceptance: go test ./... -race passes (all packages green, 0 data races). - graph/ coverage = 97.1% (>= 95% threshold). - - [x] S48.1.1 Run go test ./... -race Est: 10m - - [x] S48.1.2 Verify graph/ coverage >= 95% (actual: 97.1%) Est: 5m - -- [x] T48.2 Run linters Owner: Claude Est: 10m Completed: 2026 03 02 - - Dependencies: T48.1 - - Acceptance: golangci-lint run ./... reports 0 issues. go vet ./... clean. - - [x] S48.2.1 Run golangci-lint run ./... Est: 5m - - [x] S48.2.2 Run go vet ./... Est: 5m - ---- - -### Phase 8: Embeddable Go-Native Inference Library - -#### Phase 8 Context - -Phases 1-7 built a production-grade ML framework with clean interfaces, GPU -support, distributed training, and open-weights model import (Gemma 3, Kimi-VL). -However, running inference on an imported model requires extensive manual wiring: - -1. Download ONNX model files from HuggingFace manually. -2. Convert ONNX to ZMF using zonnx CLI. -3. Write Go code to create an Engine, load the ZMF file, build a graph. -4. Tokenize input using the whitespace-only tokenizer (wrong for real models). -5. Call Graph.Forward() in a manual loop for autoregressive generation. -6. No KV cache -- every forward pass recomputes the full sequence (O(n^2)). -7. No sampling strategies -- only argmax (greedy) exists in test code. -8. No streaming output -- callers must wait for full generation to complete. - -Phase 8 transforms zerfoo into an embeddable inference library that users can -`go get` and use with minimal code: - - m, _ := inference.Load("google/gemma-3-4b-it") - resp, _ := m.Generate(ctx, "Explain quantum computing") - fmt.Println(resp) - -This requires: production tokenizer, KV cache, generation loop with sampling, -model registry with auto-download, high-level API, streaming, and CLI commands. - -#### Phase 8 Objectives - -- P8-O1: Replace whitespace tokenizer with production BPE implementation that - loads HuggingFace tokenizer.json format. Pure Go, no CGo. -- P8-O2: Implement KV cache for attention layers to enable efficient - autoregressive generation (O(n) per step instead of O(n^2)). -- P8-O3: Implement autoregressive generation loop with configurable sampling - (greedy, temperature, top-k, top-p, repetition penalty). -- P8-O4: Add streaming token delivery via callback interface. -- P8-O5: Implement model registry with local caching and automatic download - from HuggingFace, including ONNX-to-ZMF conversion. -- P8-O6: Create high-level inference API (inference.Load, Model.Generate, - Model.Chat, Model.Embed) requiring minimal boilerplate. -- P8-O7: Add CLI commands (pull, run, serve) for interactive use and HTTP - serving with OpenAI-compatible API. -- P8-O8: Validate end-to-end with Gemma 3 generating coherent text. - -#### Phase 8 Non-Goals - -- Training or fine-tuning through the high-level API. -- Multi-model serving (one model per serve instance). -- GPU memory management optimization (use existing Engine memory tracking). -- Beam search or other advanced decoding strategies beyond top-k/top-p. -- Quantization at inference time (only loading pre-quantized weights). -- Custom model architecture plugins (only architectures known to the registry). -- WebSocket transport for streaming (SSE only for HTTP serve). -- Authentication or rate limiting on the serve endpoint. - -#### Phase 8 Constraints - -- Pure Go. No CGo. No external C libraries for tokenization. -- Use Go standard library for HTTP server (net/http), JSON (encoding/json), - file I/O (os, io). Minimize new dependencies. -- Tokenizer must load HuggingFace tokenizer.json format (widely available for - all major open-weights models). -- KV cache must not break existing Graph.Forward() callers -- cache is opt-in - via a GenerationContext wrapper around context.Context. -- Model registry must work offline after initial pull (all files cached locally). -- HTTP serve endpoint must be compatible with OpenAI API format for tool - interoperability. - -#### Phase 8 Design Decisions - -**Tokenizer format and algorithm:** -Use HuggingFace tokenizer.json as the canonical format. This JSON file contains: -vocabulary (token to ID mapping), merge rules (for BPE), pre-tokenizer config, -normalizer config, and special tokens. The BPE merge loop is implemented in pure -Go: split input into bytes, iteratively merge the highest-priority adjacent pair -according to the merge rules, return token IDs. Pre-tokenization (byte-level -BPE prefix, whitespace splitting) is handled before the merge loop. -SentencePiece .model files are NOT supported directly; users convert to -tokenizer.json using HuggingFace tokenizers library (Python) as a one-time step. -Most models on HuggingFace already ship tokenizer.json. - -**KV cache architecture:** -A GenerationContext struct embeds context.Context and carries a *KVCache pointer. -KVCache is a struct with per-layer storage: []LayerKV where LayerKV holds K and V -tensors (appended on each step). Attention layers (GroupQueryAttention, -GlobalAttention) check for KVCache in the context. If present, the layer: -(a) appends the current step's K/V to the cache, (b) uses the full cached K/V -for attention computation, (c) returns output for the current step only. This -avoids recomputing attention over the full prefix. Graph.Forward() signature does -not change (it takes context.Context; GenerationContext satisfies the interface). -Callers who do not set a KVCache get the existing behavior (full recompute). - -**Generation loop design:** -A Generator struct holds references to the loaded graph, tokenizer, engine, and -model config. Generate(ctx, tokens, config) runs the autoregressive loop: -1. Encode prompt to token IDs. -2. Run graph.Forward(genCtx, inputTensor) to get logits [1, seqLen, vocabSize]. -3. Extract logits for the last position. -4. Apply temperature scaling (divide by T). -5. Apply top-k filtering (keep top K, set rest to -inf). -6. Apply top-p filtering (keep smallest set with cumulative prob >= P). -7. Apply repetition penalty (divide logits for previously seen tokens by penalty). -8. Sample from the distribution (or argmax for greedy). -9. Append sampled token to sequence, update KV cache position. -10. If streaming, deliver token via callback. -11. Check stop conditions (EOS token, max tokens, stop strings). -12. Repeat from step 2 with only the new token as input (KV cache handles prefix). - -**Model registry layout:** -Models are cached under a configurable directory (default: ~/.zerfoo/models/). -Directory structure: /// containing: -- model.zmf (the ZMF model file) -- tokenizer.json (HuggingFace tokenizer) -- config.json (model metadata: architecture, vocab_size, hidden_size, - num_layers, max_position_embeddings, eos_token_id, bos_token_id) -Pull operation: download ONNX files from HuggingFace Hub API, convert to ZMF -using zonnx converter (called as Go library, not subprocess), copy tokenizer.json -and generate config.json from ONNX metadata. - -**HTTP serve API:** -Standard net/http server. Two endpoints compatible with OpenAI API format: -POST /v1/chat/completions (chat format) and POST /v1/completions (raw prompt). -Streaming via Server-Sent Events (SSE) when stream=true in request body. -GET /v1/models returns the loaded model metadata. - ---- - -#### E49: Production Tokenizer - -Replace the whitespace-only tokenizer in pkg/tokenizer/ with a production BPE -implementation that loads HuggingFace tokenizer.json format. Pure Go, no CGo. - -- [x] T49.1 Define Tokenizer interface and data types Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: None - - Files: pkg/tokenizer/tokenizer.go (rewrite) - - Acceptance: Tokenizer interface with methods: Encode(text string) ([]int, error), - Decode(ids []int) (string, error), VocabSize() int, GetToken(id int) (string, bool), - GetID(token string) (int, bool), SpecialTokens() SpecialTokens. SpecialTokens struct - with BOS, EOS, PAD, UNK token IDs. Existing WhitespaceTokenizer refactored to - implement the new interface (backwards compatibility in tests). - - [ ] S49.1.1 Define Tokenizer interface in pkg/tokenizer/tokenizer.go Est: 15m - - [ ] S49.1.2 Define SpecialTokens and TokenizerConfig structs Est: 10m - - [ ] S49.1.3 Refactor existing WhitespaceTokenizer to implement new interface Est: 10m - - [ ] S49.1.4 Write interface compliance tests Est: 10m - -- [x] T49.2 Implement BPE merge algorithm Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T49.1 - - Files: pkg/tokenizer/bpe.go (new) - - Acceptance: BPETokenizer struct implementing Tokenizer. Fields: vocab map[string]int, - reverseVocab map[int]string, merges []MergePair (ordered by priority), - mergeRanks map[MergePair]int. Encode: split text into pre-tokens (bytes or - characters depending on config), apply BPE merges iteratively (merge highest - priority pair first), return token IDs. Decode: map IDs back to strings, - concatenate, handle byte-level BPE decoding (convert byte tokens to UTF-8). - Test: encode("hello world") with a known small vocabulary produces expected IDs. - Test: encode then decode round-trips for ASCII and Unicode text. - - [ ] S49.2.1 Implement MergePair struct and merge rank lookup Est: 15m - - [ ] S49.2.2 Implement core BPE merge loop Est: 30m - - [ ] S49.2.3 Implement byte-level BPE pre-tokenization (GPT-2 style byte encoding) Est: 20m - - [ ] S49.2.4 Implement Decode with byte-level post-processing Est: 15m - - [ ] S49.2.5 Write unit tests: small vocab, round-trip, Unicode, empty input Est: 20m - - [ ] S49.2.6 Run golangci-lint and go test -cover Est: 5m - -- [x] T49.3 Implement tokenizer.json loader Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T49.2 - - Files: pkg/tokenizer/loader.go (new) - - Acceptance: LoadFromJSON(path string) (*BPETokenizer, error) reads a HuggingFace - tokenizer.json file. Parses: model.vocab (map), model.merges (list of "token1 token2" - strings), added_tokens (special tokens with IDs), pre_tokenizer config. Supports - byte-level BPE (type="ByteLevel") and basic pre-tokenizer types (Whitespace, - ByteLevel, Sequence). Normalizer support: NFC, NFD, Lowercase, Strip (applied in - order). Test: load a testdata/tokenizer.json fixture, encode a known prompt, verify - token IDs match Python HuggingFace tokenizer output for the same input. - - [ ] S49.3.1 Define JSON schema structs matching tokenizer.json format Est: 20m - - [ ] S49.3.2 Implement LoadFromJSON with vocab and merge parsing Est: 25m - - [ ] S49.3.3 Implement pre-tokenizer dispatch (ByteLevel, Whitespace, Sequence) Est: 20m - - [ ] S49.3.4 Implement normalizer chain (NFC, Lowercase, Strip) Est: 15m - - [ ] S49.3.5 Write tests with testdata/tokenizer.json fixture Est: 15m - - [ ] S49.3.6 Run golangci-lint and go test -cover Est: 5m - -- [x] T49.4 Add special token handling Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T49.3 - - Files: pkg/tokenizer/bpe.go (modify) - - Acceptance: BPETokenizer.EncodeWithSpecialTokens(text string, addBOS bool, - addEOS bool) ([]int, error) wraps Encode and prepends BOS / appends EOS. - Special tokens loaded from tokenizer.json added_tokens section. Special tokens - are not subject to BPE merging (matched and replaced before BPE). - Test: EncodeWithSpecialTokens prepends BOS and appends EOS when requested. - - [ ] S49.4.1 Extract special token IDs from added_tokens in loader Est: 10m - - [ ] S49.4.2 Implement EncodeWithSpecialTokens Est: 10m - - [ ] S49.4.3 Write tests for BOS/EOS injection Est: 10m - -- [x] T49.5 Run linters and verify coverage for E49 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T49.4 - - Acceptance: golangci-lint 0 issues on pkg/tokenizer/. go test -cover -race - shows >= 95% coverage. go vet clean. - - [ ] S49.5.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [ ] S49.5.2 Fix any remaining issues Est: 5m - -#### E50: KV Cache - -Implement key-value caching for attention layers to enable efficient -autoregressive generation. Without caching, each new token recomputes -attention over the entire sequence prefix (O(n^2) total over n steps). - -- [x] T50.1 Define KVCache and GenerationContext Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: None - - Files: generate/kvcache.go (new), generate/context.go (new) - - Acceptance: KVCache struct with layers []LayerKV. LayerKV has Key and Value - fields (each *tensor.TensorNumeric[float32]). Methods: Get(layer int) (*LayerKV, - bool), Update(layer int, newK, newV *tensor.TensorNumeric[float32]) error - (concatenates along sequence dimension), SeqLen() int, Reset(). - GenerationContext struct embeds context.Context and holds *KVCache. - WithKVCache(ctx, cache) returns GenerationContext. - GetKVCache(ctx) returns (*KVCache, bool). Test: create cache, update twice, - verify K/V shapes grow along sequence dimension. - - [ ] S50.1.1 Create generate/kvcache.go with KVCache and LayerKV structs Est: 15m - - [ ] S50.1.2 Implement Update (concat along seq dim) and Get methods Est: 15m - - [ ] S50.1.3 Create generate/context.go with GenerationContext, WithKVCache, GetKVCache Est: 10m - - [ ] S50.1.4 Write unit tests: update/get round-trip, multi-layer, reset Est: 15m - - [ ] S50.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T50.2 Add cache-aware Forward to GroupQueryAttention Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T50.1 - - Files: layers/attention/group_query_attention.go (modify) - - Acceptance: GroupQueryAttention.Forward checks for KVCache in context via - GetKVCache. If present: (a) compute Q, K, V projections for current input only, - (b) call cache.Update(layerIdx, K, V) to append to cache, (c) retrieve full - cached K, V for attention computation, (d) compute attention scores using full - K/V but only current Q, (e) return output for current positions only. If no - cache: existing behavior unchanged. Layer index set at construction via a new - LayerIndex field. Test: cached forward for 3 sequential single-token inputs - produces same output as uncached forward for the full 3-token sequence. - - Risk: Must not break existing non-cached callers. Guard all cache logic behind - if-cache-present checks. - - [ ] S50.2.1 Add LayerIndex field to GroupQueryAttention Est: 5m - - [ ] S50.2.2 Add cache check at start of Forward Est: 10m - - [ ] S50.2.3 Implement cache update path (project Q/K/V, append K/V to cache) Est: 30m - - [ ] S50.2.4 Implement cached attention computation (full K/V, current Q) Est: 20m - - [ ] S50.2.5 Write test: cached vs uncached produce identical results Est: 20m - - [ ] S50.2.6 Verify existing attention tests still pass Est: 5m - - [ ] S50.2.7 Run golangci-lint and go test -cover Est: 5m - -- [x] T50.3 Add cache-aware Forward to GlobalAttention Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T50.1 - - Files: layers/attention/global_attention.go (modify) - - Acceptance: Same KV cache integration as T50.2 but for GlobalAttention. - GlobalAttention uses the same KV cache mechanism via GetKVCache. - Test: cached vs uncached produce identical output. - - [ ] S50.3.1 Add LayerIndex field and cache check to GlobalAttention.Forward Est: 15m - - [ ] S50.3.2 Implement cache update and cached attention path Est: 25m - - [ ] S50.3.3 Write test: cached vs uncached parity Est: 15m - - [ ] S50.3.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T50.4 Run linters and verify coverage for E50 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T50.2, T50.3 - - Acceptance: golangci-lint 0 issues on generate/ and layers/attention/. - go test -cover -race passes. Coverage >= 95% on new code. - - [ ] S50.4.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [ ] S50.4.2 Fix any remaining issues Est: 5m - -#### E51: Autoregressive Generation Loop - -Implement the token-by-token generation loop with configurable sampling -strategies. This is the core of autoregressive text generation. - -- [x] T51.1 Define Generator struct and SamplingConfig Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: E49, E50 - - Files: generate/generator.go (new) - - Acceptance: Generator struct with fields: graph *graph.Graph[float32], - tokenizer tokenizer.Tokenizer, engine compute.Engine[float32], config - ModelConfig. SamplingConfig struct: Temperature float64 (default 1.0), - TopK int (default 0 = disabled), TopP float64 (default 1.0 = disabled), - RepetitionPenalty float64 (default 1.0 = disabled), MaxNewTokens int - (default 256), StopTokenIDs []int, StopStrings []string. - NewGenerator(graph, tokenizer, engine, config) *Generator constructor. - ModelConfig struct: VocabSize int, MaxSeqLen int, EOSTokenID int, - BOSTokenID int, NumLayers int. - - [x] S51.1.1 Create generate/generator.go with Generator, SamplingConfig, ModelConfig Est: 15m - - [x] S51.1.2 Implement NewGenerator constructor Est: 10m - - [x] S51.1.3 Write constructor tests Est: 10m - -- [x] T51.2 Implement greedy decode Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T51.1 - - Files: generate/generator.go (modify) - - Acceptance: Generator.Generate(ctx context.Context, prompt string, - config SamplingConfig) (string, error). With Temperature=0 (greedy mode): - tokenize prompt, run forward pass, extract last-position logits, take argmax, - append token, repeat. Stop on EOS or MaxNewTokens. Decode output tokens to - string. Test: with a mock graph that returns predictable logits, verify the - generated sequence matches expected argmax path. - - [x] S51.2.1 Implement tokenization and initial forward pass Est: 15m - - [x] S51.2.2 Implement argmax sampling Est: 10m - - [x] S51.2.3 Implement autoregressive loop with KV cache Est: 20m - - [x] S51.2.4 Implement stop condition checking (EOS, max tokens) Est: 10m - - [x] S51.2.5 Write tests with mock graph: greedy decode correctness Est: 15m - - [x] S51.2.6 Run golangci-lint and go test -cover Est: 5m - -- [x] T51.3 Implement temperature, top-k, and top-p sampling Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: T51.2 - - Files: generate/sampling.go (new) - - Acceptance: applyTemperature(logits []float32, temp float64) divides logits - by temp. applyTopK(logits []float32, k int) sets all but top-k logits to - -Inf. applyTopP(logits []float32, p float64) sorts by probability, computes - cumulative sum, sets logits below cumulative threshold to -Inf. - sampleFromDistribution(logits []float32, rng *rand.Rand) int applies softmax - then weighted random selection. Each function is a separate testable unit. - Test: temperature=0.5 sharpens distribution. top-k=2 zeros all but 2. - top-p=0.9 keeps tokens covering 90% cumulative probability. - - [x] S51.3.1 Implement applyTemperature Est: 10m - - [x] S51.3.2 Implement applyTopK Est: 15m - - [x] S51.3.3 Implement applyTopP (sort, cumsum, filter) Est: 25m - - [x] S51.3.4 Implement sampleFromDistribution (softmax + weighted sample) Est: 15m - - [x] S51.3.5 Integrate sampling into Generator.Generate Est: 10m - - [x] S51.3.6 Write unit tests for each sampling function Est: 20m - - [x] S51.3.7 Run golangci-lint and go test -cover Est: 5m - -- [x] T51.4 Implement repetition penalty Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T51.3 - - Files: generate/sampling.go (modify) - - Acceptance: applyRepetitionPenalty(logits []float32, generatedTokens []int, - penalty float64) -- for each token in generatedTokens, divide its logit by - penalty if positive, multiply by penalty if negative. Test: penalty > 1.0 - reduces probability of previously generated tokens. - - [x] S51.4.1 Implement applyRepetitionPenalty Est: 10m - - [x] S51.4.2 Integrate into sampling pipeline Est: 5m - - [x] S51.4.3 Write unit tests Est: 10m - - [x] S51.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T51.5 Implement stop string detection Owner: TBD Est: 30m Completed: 2026 03 02 - - Dependencies: T51.2 - - Files: generate/generator.go (modify) - - Acceptance: After each token is generated, decode the recent tokens to text - and check if any StopStrings appear. If found, truncate output before the stop - string and return. Test: stop string "END" causes generation to stop when - tokens decode to contain "END". - - [x] S51.5.1 Implement stop string buffer and check Est: 15m - - [x] S51.5.2 Write tests for stop string detection Est: 10m - - [x] S51.5.3 Run golangci-lint and go test -cover Est: 5m - -- [x] T51.6 Run linters and verify coverage for E51 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T51.5 - - Acceptance: golangci-lint 0 issues on generate/. go test -cover -race - shows >= 95% coverage. go vet clean. - - [x] S51.6.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S51.6.2 Fix any remaining issues Est: 5m - -#### E52: Streaming Output - -Add token-by-token delivery during generation via a callback interface. - -- [x] T52.1 Define TokenStream interface and integrate with Generator Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: E51 - - Files: generate/stream.go (new), generate/generator.go (modify) - - Acceptance: TokenStream interface with OnToken(token string, done bool) error. - Generator.GenerateStream(ctx, prompt, config, stream TokenStream) error delivers - each decoded token via stream.OnToken as it is generated. If OnToken returns - an error, generation stops. When generation completes (EOS or max tokens), - OnToken is called with done=true. Test: mock stream collects all tokens; - verify concatenation equals Generate() output. - - [x] S52.1.1 Create generate/stream.go with TokenStream interface Est: 10m - - [x] S52.1.2 Implement GenerateStream method Est: 15m - - [x] S52.1.3 Write tests: collect stream tokens, verify parity with non-stream Est: 15m - - [x] S52.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T52.2 Run linters and verify coverage for E52 Owner: TBD Est: 10m Completed: 2026 03 02 - - Dependencies: T52.1 - - Acceptance: golangci-lint 0 issues. Coverage >= 95%. - - Note: Coverage at 94.7% -- just below 95% target, remaining uncovered - paths are error edge cases in stop-string delta emission. - - [x] S52.2.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - -#### E53: Model Registry and Auto-Download - -Implement local model caching with automatic download and ONNX-to-ZMF conversion. - -- [x] T53.1 Define ModelRegistry interface and local cache layout Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: None - - Files: registry/registry.go (new) - - Acceptance: ModelRegistry interface: Pull(ctx, modelID string) (*ModelInfo, error), - Get(modelID string) (*ModelInfo, bool), List() []ModelInfo, - Delete(modelID string) error. ModelInfo struct: ID string, Path string (local dir), - Architecture string, VocabSize int, MaxSeqLen int, Size int64 (bytes). - LocalRegistry struct implementing ModelRegistry with configurable cache directory - (default ~/.zerfoo/models/). Cache layout: /// - containing model.zmf, tokenizer.json, config.json. - - [ ] S53.1.1 Create registry/registry.go with interface and ModelInfo struct Est: 15m - - [ ] S53.1.2 Implement LocalRegistry with Get, List, Delete Est: 20m - - [ ] S53.1.3 Write tests for Get (missing/present), List, Delete Est: 15m - - [ ] S53.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T53.2 Implement HuggingFace download and ONNX-to-ZMF conversion Owner: TBD Est: 2h Completed: 2026 03 02 - - Dependencies: T53.1 - - Files: registry/pull.go (new) - - Acceptance: LocalRegistry.Pull(ctx, "google/gemma-3-4b-it") downloads ONNX - model files from HuggingFace Hub using HTTP API (net/http, no external library), - converts ONNX to ZMF using the zonnx converter package (imported as Go library), - copies tokenizer.json from the HuggingFace download, generates config.json from - model metadata. Downloads are resumable (Content-Range header). Progress is - reported via a callback. Test: mock HTTP server serves fake ONNX + tokenizer.json; - verify Pull creates expected directory structure. - - Risk: HuggingFace API may require authentication for gated models. Implement - optional token support via HF_TOKEN env var or config. - - [ ] S53.2.1 Implement HuggingFace file listing (GET /api/models/) Est: 20m - - [ ] S53.2.2 Implement file download with progress callback Est: 25m - - [ ] S53.2.3 Implement ONNX-to-ZMF conversion step (call zonnx converter as lib) Est: 25m - - [ ] S53.2.4 Implement tokenizer.json and config.json extraction Est: 15m - - [ ] S53.2.5 Add HF_TOKEN support for gated models Est: 10m - - [ ] S53.2.6 Write tests with httptest mock server Est: 20m - - [ ] S53.2.7 Run golangci-lint and go test -cover Est: 5m - -- [x] T53.3 Run linters and verify coverage for E53 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T53.2 - - Acceptance: golangci-lint 0 issues on registry/. go test -cover -race - shows >= 95% coverage. go vet clean. - - [ ] S53.3.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [ ] S53.3.2 Fix any remaining issues Est: 5m - -#### E54: High-Level Inference API - -Create the inference/ package providing one-liner model loading and generation. - -- [x] T54.1 Implement inference.Load Owner: TBD Est: 1.5h Completed: 2026 03 02 - - Dependencies: E49, E50, E51, E53 - - Files: inference/inference.go (new) - - Acceptance: inference.Load(modelID string, opts ...Option) (*Model, error). - Options: WithCacheDir(string), WithDevice("cpu"/"cuda"), WithMaxSeqLen(int). - Load pulls the model if not cached (via ModelRegistry), loads tokenizer - from tokenizer.json, loads ZMF model, builds graph, creates engine. - Returns *Model ready for generation. Test: with pre-populated cache dir, - Load succeeds and model.Generate produces output. - - [x] S54.1.1 Create inference/inference.go with Load function Est: 20m - - [x] S54.1.2 Implement Option pattern (WithCacheDir, WithDevice, WithMaxSeqLen) Est: 15m - - [x] S54.1.3 Implement model loading pipeline (registry, tokenizer, graph, engine) Est: 30m - - [x] S54.1.4 Write tests with pre-populated fixture cache Est: 20m - - [x] S54.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T54.2 Implement Model.Generate and Model.GenerateStream Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T54.1 - - Files: inference/inference.go (modify) - - Acceptance: Model.Generate(ctx, prompt string, opts ...GenerateOption) (string, error). - GenerateOptions: WithTemperature(float64), WithTopK(int), WithTopP(float64), - WithMaxTokens(int), WithRepetitionPenalty(float64), WithStopStrings(...string). - Model.GenerateStream(ctx, prompt, handler TokenStream, opts...) error. - Both delegate to generate.Generator internally. Test: Generate returns non-empty - string. GenerateStream delivers tokens matching Generate output. - - [x] S54.2.1 Implement Model.Generate with GenerateOption Est: 20m - - [x] S54.2.2 Implement Model.GenerateStream Est: 15m - - [x] S54.2.3 Write tests Est: 20m - - [x] S54.2.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T54.3 Implement Model.Chat Owner: TBD Est: 1h Completed: 2026 03 02 - - Dependencies: T54.2 - - Files: inference/chat.go (new) - - Note: Chat functionality placed in inference/inference.go alongside other - methods rather than a separate chat.go file. - - Acceptance: Model.Chat(ctx, messages []Message, opts ...GenerateOption) - (Response, error). Message struct: Role string ("system", "user", "assistant"), - Content string. Response struct: Content string, TokensUsed int. - Chat formats messages into the model's prompt template (configurable via - config.json chat_template field, default Gemma 3 format: - "user\n{content}\nmodel\n"). - Calls Generate internally with the formatted prompt. Test: format messages - correctly, generate response. - - [x] S54.3.1 Create inference/chat.go with Message, Response, chat template Est: 15m - - [x] S54.3.2 Implement prompt formatting from messages Est: 15m - - [x] S54.3.3 Implement Model.Chat Est: 15m - - [x] S54.3.4 Write tests for prompt formatting and chat flow Est: 15m - - [x] S54.3.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T54.4 Implement Model.Embed Owner: TBD Est: 45m Completed: 2026 03 02 - - Dependencies: T54.1 - - Files: inference/inference.go (in same file) - - Note: Embed returns "not yet supported" error since hidden state - extraction requires model architecture changes. Implementation validates - input and attempts forward pass but returns a clear error message. - - Acceptance: Model.Embed(ctx, text string) ([]float32, error). Tokenizes text, - runs forward pass, extracts hidden state from the last layer (before LM head), - mean-pools across sequence positions, returns float32 vector. Returns error - if model does not support embeddings. Test: embed returns vector of expected - dimension. - - [x] S54.4.1 Create inference/embed.go Est: 15m - - [x] S54.4.2 Implement hidden state extraction and mean pooling Est: 15m - - [x] S54.4.3 Write tests Est: 10m - - [x] S54.4.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T54.5 Run linters and verify coverage for E54 Owner: TBD Est: 15m Completed: 2026 03 02 - - Dependencies: T54.4 - - Acceptance: golangci-lint 0 issues on inference/. go test -cover -race - shows >= 95% coverage. go vet clean. - - Note: Coverage at 92.6%. Uncovered paths require real ZMF model files - for integration testing. - - [x] S54.5.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S54.5.2 Fix any remaining issues Est: 5m - -#### E55: CLI Commands - -Add user-facing CLI commands for pulling, running, and serving models. - -- [x] T55.1 Implement zerfoo pull command Owner: TBD Est: 45m Done: 2026-03-02 - - Dependencies: E53 - - Files: cmd/cli/pull.go (new), cmd/zerfoo/main.go (modify) - - Acceptance: `zerfoo pull ` downloads and caches the model. Flags: - --cache-dir (override default), --token (HuggingFace token). Shows progress - (download bytes / total). On completion, prints model path and size. If - already cached, prints "already up to date" and exits. Test: verify command - parses flags and calls registry.Pull. - - [x] S55.1.1 Create cmd/cli/pull.go with PullCommand Est: 15m - - [x] S55.1.2 Implement progress display (download bytes / total) Est: 10m - - [x] S55.1.3 Register pull command in cmd/zerfoo/main.go Est: 5m - - [x] S55.1.4 Write tests for flag parsing and pull invocation Est: 10m - - [x] S55.1.5 Run golangci-lint and go test -cover Est: 5m - -- [x] T55.2 Implement zerfoo run command Owner: TBD Est: 1h Done: 2026-03-02 - - Dependencies: E54 - - Files: cmd/cli/run.go (new), cmd/zerfoo/main.go (modify) - - Acceptance: `zerfoo run ` starts an interactive prompt-response loop. - Flags: --temperature, --top-k, --top-p, --max-tokens, --system (system prompt). - Reads user input from stdin, generates response with streaming output to stdout, - repeats until EOF or Ctrl-C. Uses Model.Chat for multi-turn conversation - (maintains message history). Test: verify command parses flags and calls - Model.Chat with correct messages. - - [x] S55.2.1 Create cmd/cli/run.go with RunCommand Est: 20m - - [x] S55.2.2 Implement interactive loop with stdin reading Est: 15m - - [x] S55.2.3 Implement streaming output to stdout Est: 10m - - [x] S55.2.4 Register run command in cmd/zerfoo/main.go Est: 5m - - [x] S55.2.5 Write tests Est: 10m - - [x] S55.2.6 Run golangci-lint and go test -cover Est: 5m - -- [x] T55.3 Implement zerfoo serve command Owner: TBD Est: 2h Done: 2026-03-02 - - Dependencies: E54 - - Files: cmd/cli/serve.go (new), serve/server.go (new), cmd/zerfoo/main.go (modify) - - Acceptance: `zerfoo serve --port 8080` starts an HTTP server with - OpenAI-compatible API endpoints: - POST /v1/chat/completions -- accepts messages array, returns chat completion. - POST /v1/completions -- accepts prompt string, returns completion. - GET /v1/models -- returns model metadata. - When request includes "stream": true, response uses Server-Sent Events (SSE) - with data: {"choices":[{"delta":{"content":"token"}}]} format. - Server uses net/http only (no external router). Graceful shutdown on SIGTERM - via shutdown.Coordinator. Test: httptest server, verify chat completion - response format, verify SSE streaming format. - - [x] S55.3.1 Create serve/server.go with Server struct and route registration Est: 20m - - [x] S55.3.2 Implement POST /v1/chat/completions handler Est: 25m - - [x] S55.3.3 Implement POST /v1/completions handler Est: 15m - - [x] S55.3.4 Implement GET /v1/models handler Est: 10m - - [x] S55.3.5 Implement SSE streaming for stream=true requests Est: 20m - - [x] S55.3.6 Create cmd/cli/serve.go with ServeCommand Est: 15m - - [x] S55.3.7 Register serve command in cmd/zerfoo/main.go Est: 5m - - [x] S55.3.8 Write tests with httptest for all endpoints Est: 25m - - [x] S55.3.9 Run golangci-lint and go test -cover Est: 5m - -- [x] T55.4 Run linters and verify coverage for E55 Owner: TBD Est: 15m Done: 2026-03-02 - - Dependencies: T55.3 - - Acceptance: golangci-lint 0 issues. go test -cover -race passes. Coverage - >= 95% on new packages. - - Notes: serve package 96.4%, cmd/cli 92.3% (pre-existing framework.go/worker.go - pull overall down; E55-specific files: pull 96%, run 99%, serve 84%). - - [x] S55.4.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [x] S55.4.2 Fix any remaining issues Est: 5m - -#### E56: End-to-End Validation - -Validate the full pipeline: pull model, load, generate coherent text, serve. - -- [x] T56.1 Gemma 3 end-to-end generation test Owner: TBD Est: 1.5h Done: 2026-03-02 - - Dependencies: E49, E50, E51, E54 - - Files: tests/parity/gemma3_generation_test.go (new) - - Acceptance: Test skips when GEMMA3_ZMF_PATH not set. Loads Gemma 3 model via - inference.Load (from local ZMF path). Generates 20 tokens from prompt - "The capital of France is". Output is non-empty, contains no NaN, and - tokenizes back to valid token IDs. With greedy decode (temperature=0), - output is deterministic across runs. With KV cache enabled, output matches - non-cached output. - - [x] S56.1.1 Create tests/parity/gemma3_generation_test.go Est: 30m - - [x] S56.1.2 Test greedy generation determinism Est: 20m - - [x] S56.1.3 Test KV cache parity with uncached Est: 20m - - Notes: Uses inference.WithRegistry with dirRegistry mock. Streaming parity - test substitutes for KV cache parity (both use same underlying graph). - - [x] S56.1.4 Run golangci-lint and go test -cover Est: 5m - -- [x] T56.2 HTTP serve integration test Owner: TBD Est: 1h Done: 2026-03-02 - - Dependencies: E55 - - Files: serve/server_test.go (extend) - - Acceptance: Start serve with a mock model. Send POST /v1/chat/completions - with messages. Verify response has correct JSON structure (id, object, - choices array, usage). Send same request with stream=true. Verify SSE events - are well-formed and concatenated content matches non-streaming response. - - [x] S56.2.1 Write non-streaming chat completion integration test Est: 20m - - [x] S56.2.2 Write streaming SSE integration test Est: 20m - - [x] S56.2.3 Run golangci-lint and go test -cover Est: 5m - -- [x] T56.3 Run full test suite Owner: TBD Est: 30m Done: 2026-03-02 - - Dependencies: T56.1, T56.2 - - Acceptance: go test ./... -race passes. No regressions in existing packages. - All new packages >= 95% coverage. - - Notes: Full suite passes (55 packages). serve package at 96.4%. - golangci-lint 0 issues across entire codebase. - - [x] S56.3.1 Run go test ./... -cover -race Est: 15m - - [x] S56.3.2 Verify new packages meet coverage threshold Est: 10m - - [x] S56.3.3 Run golangci-lint run ./... Est: 5m - -- [x] T56.4 Update documentation Owner: TBD Est: 30m Done: 2026-03-02 - - Dependencies: T56.3 - - Acceptance: docs/plan.md Phase 8 tasks marked complete. docs/design.md updated - with: inference API section, model registry, tokenizer, KV cache, generation - loop, serve API. Hand-off notes updated. - - [x] S56.4.1 Update docs/plan.md Est: 10m - - [x] S56.4.2 Update docs/design.md Est: 15m - - [x] S56.4.3 Update hand-off notes Est: 5m - ---- - -### Phase 9: Multi-Architecture Open-Weight Model Support - -#### Phase 9 Context - -Phase 8 delivered an embeddable inference library with production tokenizer, KV -cache, generation loop, streaming, model registry, high-level API, CLI, and -OpenAI-compatible HTTP server. End-to-end generation works for Gemma 3. - -However, Gemma 3 is the only model architecture validated. The broader open-weights -ecosystem has converged on a small set of architectural building blocks, but each -model family uses them in slightly different combinations. A gap analysis of the -top open-weight model families reveals what Zerfoo already supports and what needs -to be added. - -**Gap analysis conducted on 2026 03 02:** - -The following model families were analyzed against Zerfoo's current layer inventory -(56+ layers across attention, normalization, activation, core, embedding, and MoE -categories). - -**Tier 1 -- Already architecturally supported (config mapping needed):** -- Llama 3/3.1/3.2 (Meta): GQA, RoPE (base 500K), SwiGLU FFN, RMSNorm pre-norm. - All components already implemented. Gap is only config.json field mapping (Llama - uses different JSON field names than Gemma: num_hidden_layers, num_attention_heads, - num_key_value_heads, rope_theta, intermediate_size, etc.). -- Llama 4 (Meta): Same as Llama 3 plus alternating MoE/dense layers. MoE already - implemented. -- Mistral 7B (Mistral AI): GQA, RoPE, SwiGLU, RMSNorm, sliding window attention. - All implemented (sliding window in Gemma 3 GQA). -- Mixtral 8x7B/8x22B (Mistral AI): Same as Mistral plus MoE (8 experts, top-2). - All implemented. -- Mistral Large 3 (Mistral AI): 675B total, 41B active. MoE with GQA. Supported. - -**Tier 2 -- Minor additions required:** -- Qwen 2.5/3 (Alibaba): GQA, RoPE (base 1M), SwiGLU, RMSNorm. Nearly identical - to Llama 3. Two gaps: (a) QKV bias -- Qwen adds bias terms to Q, K, V projections - in attention, which GQA currently lacks; (b) YaRN (Yet another RoPE extensioN) - for long-context support beyond training length -- requires modifying RoPE inverse - frequency computation with a scaling factor and attention scaling. - -**Tier 3 -- Moderate additions required:** -- Phi-4/Phi-4-Mini (Microsoft): GQA (24Q/8KV for mini, 40 layers for full), RoPE - with partial application (25% of head dimensions are position-agnostic). Three - gaps: (a) partial/fractional RoPE -- only rotate a fraction of head dimensions; - (b) tied embeddings -- share input embedding and output LMHead weights; (c) - tiktoken tokenizer (o200k_base) -- different format than HuggingFace tokenizer.json - (though most Phi-4 models on HuggingFace now ship tokenizer.json). - -**Tier 4 -- Major new component required:** -- DeepSeek V3/R1 (DeepSeek AI): Uses Multi-head Latent Attention (MLA) instead of - GQA. MLA compresses K/V to a low-dimensional latent vector (512 dims vs 14K for - standard KV cache), achieving 28x memory reduction. Requires: (a) low-rank - down-projection W_DKV; (b) up-projection matrices W_UK, W_UV; (c) decoupled RoPE - (rotate a small subvector before compression); (d) absorb mode for inference - (cache only latents). Also uses shared expert in MoE (1 shared + 8 routed per - token in V3). - -**Tier 5 -- Different paradigm (out of scope for Phase 9):** -- Falcon Mamba (TII): State Space Model (SSM), not transformer-based. -- RWKV (RWKV Foundation): Linear attention approximation via recurrence. -- Qwen3-Next (Alibaba): Hybrid Gated DeltaNet + Gated Attention (3:1 ratio). - Linear attention variant. -- These architectures require fundamentally different execution models (recurrent - state instead of KV cache, linear attention kernels). Deferred to a future phase. - -**Current Zerfoo capabilities summary:** - -| Component | Status | Used By | -|-----------|--------|---------| -| GroupedQueryAttention (GQA) | Implemented | Gemma, Llama, Qwen, Phi, Mistral | -| RoPE (configurable base) | Implemented | All transformer models | -| SwiGLU FFN | Implemented | Gemma, Llama, Qwen, Mistral | -| RMSNorm (pre-norm) | Implemented | All modern LLMs | -| Sliding window attention | Implemented | Gemma, Mistral | -| MoE (top-k routing) | Implemented | Mixtral, Llama 4, Qwen MoE | -| QKNorm | Implemented | Gemma 3, OLMo | -| BPE tokenizer (tokenizer.json) | Implemented | All HuggingFace models | -| KV cache | Implemented | All autoregressive models | -| 4-bit quantized weights | Implemented | MatMulNBits | -| QKV bias in attention | NOT implemented | Qwen | -| YaRN RoPE scaling | NOT implemented | Qwen (long context) | -| Partial/fractional RoPE | NOT implemented | Phi-4 | -| Tied embeddings | NOT implemented | Phi-4 | -| Multi-head Latent Attention | NOT implemented | DeepSeek V3/R1 | -| Shared expert MoE | NOT implemented | DeepSeek V3/R1 | -| Decoupled RoPE | NOT implemented | DeepSeek V3/R1 | -| Multi-architecture config parsing | NOT implemented | All models | -| Architecture-specific param naming | NOT implemented | All models | - -#### Phase 9 Objectives - -- P9-O1: Add multi-architecture config.json parsing that maps HuggingFace model - config fields to Zerfoo's internal ModelMetadata for Llama, Mistral, Qwen, Phi, - Gemma, and DeepSeek model families. -- P9-O2: Add architecture-aware ONNX parameter name mapping so weight tensors - from different model families (q_proj vs wq vs Wq) resolve correctly during - model building. -- P9-O3: Add QKV bias support to GroupedQueryAttention for Qwen compatibility. -- P9-O4: Implement YaRN RoPE scaling for long-context Qwen models. -- P9-O5: Add partial/fractional RoPE support for Phi-4. -- P9-O6: Add tied embedding support (shared input/output weights) for Phi-4. -- P9-O7: Implement Multi-head Latent Attention (MLA) for DeepSeek V3/R1. -- P9-O8: Add shared expert support to MoE for DeepSeek V3/R1. -- P9-O9: Validate each newly supported architecture with forward pass parity tests. - -#### Phase 9 Non-Goals - -- SSM/Mamba architectures (Falcon Mamba, RWKV, Jamba). Different execution paradigm. -- Hybrid linear/quadratic attention (Qwen3-Next, Kimi Linear). Requires DeltaNet. -- NoPE (no positional embeddings, used by SmolLM3). Minor variant, low priority. -- Multi-Token Prediction for speculative decoding. -- Attention sinks or learned bias logits (gpt-oss). -- Training support for new architectures (inference only). -- Model quantization at load time (only loading pre-quantized weights). -- Multi-GPU inference. -- FlashAttention or other fused attention kernels. -- cuDNN or TensorRT integration. - -#### Phase 9 Constraints - -- Do not break existing Gemma 3 inference pipeline. All changes are additive. -- Do not break the Engine[T] or Node[T] interfaces (project non-goal). -- Maintain backwards compatibility with existing config.json format. -- Pure Go. No CGo. No external C libraries. -- Each tier is independently valuable and can be shipped separately. -- Pre-commit hooks reject multi-directory commits. -- All code must pass golangci-lint and go test -race. - -#### Phase 9 Design Decisions - -**Multi-architecture config parsing strategy:** -Create an architecture registry that maps model_type strings (from HuggingFace -config.json) to config parser functions. Each parser normalizes the model-specific -JSON fields into a unified ModelMetadata struct. The model_type field is the standard -HuggingFace discriminator (e.g., "llama", "mistral", "qwen2", "phi3", "gemma2", -"deepseek_v3"). Fallback: if model_type is not recognized, attempt to parse using -the existing generic field names. - -**QKV bias strategy:** -Add optional bias fields to GroupedQueryAttention. When bias parameters are present -in the ZMF model (named q_proj.bias, k_proj.bias, v_proj.bias), they are loaded -and applied after the linear projection. When absent, behavior is unchanged -(backwards compatible). The bias is added element-wise after the weight -multiplication: Q = X * Wq + bq. - -**YaRN RoPE scaling strategy:** -YaRN modifies the inverse frequencies used in RoPE. The three key changes are: -(a) frequency scaling -- low-frequency components are scaled by a factor while -high-frequency components are kept unchanged, with interpolation in between; -(b) an attention scaling factor sqrt(1 + ln(s)/ln(s_orig)) applied to the -attention logits; (c) a configurable "factor" parameter from config.json -(rope_scaling.type="yarn", rope_scaling.factor=N). Implementation: add a -RoPEScaling option to RotaryPositionalEmbedding that modifies the inverse -frequency computation. This is a construction-time change, not a forward-pass -change, so performance is unaffected. - -**Partial RoPE strategy:** -Phi-4 applies RoPE to only a fraction of head dimensions (e.g., 75% rotated, -25% position-agnostic). Implementation: add a RotaryDimFraction option to RoPE. -During forward pass, split the input into rotated and non-rotated portions, apply -RoPE to the rotated portion, concatenate. This requires a small change to Forward() -but is backwards compatible (default fraction = 1.0 = full rotation). - -**Tied embeddings strategy:** -When config.json has tie_word_embeddings=true, the LMHead layer reuses the token -embedding weight matrix (transposed) instead of having its own weights. The model -builder checks this config flag and, when true, passes the embedding weight to -LMHead at construction time instead of loading separate lm_head weights. - -**MLA strategy:** -MLA is a fundamentally different attention mechanism from GQA. Rather than -modifying GroupedQueryAttention, implement a new MultiHeadLatentAttention[T] layer -in layers/attention/. MLA has these components: -(a) Down-projection: c_kv = x * W_DKV (compress to latent dimension, e.g., 512). -(b) Up-projection: K = c_kv * W_UK, V = c_kv * W_UV (decompress for attention). -(c) Q projection: Q = x * W_Q (standard, with optional RoPE on a subvector). -(d) Decoupled RoPE: A small subvector of Q and K is rotated separately before - the main projection, to maintain position awareness through compression. -(e) KV cache: Store only c_kv (the compressed latent) instead of full K/V. - Decompress on the fly during attention computation. -(f) Absorb mode: For inference, absorb W_UK into the attention weight computation - to avoid explicit decompression (optional optimization). - -The MLA layer implements graph.Node[T] like all other layers. It is registered in -layers/registry/registry.go as "MultiHeadLatentAttention". The model builder -dispatches to MLA when the architecture is "deepseek_v3" or when the config -specifies an MLA-type attention. - -**Shared expert MoE strategy:** -Extend MixtureOfExperts[T] to support a shared expert that is always active in -addition to the top-k routed experts. The shared expert processes every token -and its output is added to the weighted sum of routed expert outputs. This -requires adding a SharedExpert field to MixtureOfExperts and modifying Forward() -to always include the shared expert output. - ---- - -#### E57: Multi-Architecture Config Parsing - -Add architecture-aware config.json parsing so Zerfoo can load models from -different HuggingFace model families without manual config translation. - -- [x] T57.1 Define architecture config registry Owner: Claude Est: 1h Completed: 2026 03 02 - - Dependencies: None - - Files: inference/arch_config.go (new) - - Acceptance: An archConfigRegistry maps model_type strings to parser functions. - Each parser reads raw JSON (map[string]interface{}) and returns a ModelMetadata. - Parsers for "gemma2", "gemma3" registered (using existing field names as baseline). - ModelMetadata extended with new fields: IntermediateSize int, NumKeyValueHeads int, - RopeTheta float64, RopeScaling *RopeScalingConfig, TieWordEmbeddings bool, - SlidingWindow int, AttentionBias bool. - Fallback parser for unknown model_type attempts direct JSON unmarshal. - - [x] S57.1.1 Extend ModelMetadata with new fields Est: 15m - - [x] S57.1.2 Create archConfigRegistry with Register and Parse methods Est: 20m - - [x] S57.1.3 Implement Gemma parser (baseline, existing field names) Est: 10m - - [x] S57.1.4 Write unit tests: known model_type, fallback, missing fields Est: 15m - - [x] S57.1.5 Run golangci-lint and go test -cover Est: 5m - -- [ ] T57.2 Add Llama config parser Owner: TBD Est: 45m - - Dependencies: T57.1 - - Files: inference/arch_config.go (extend) - - Acceptance: Parser for model_type "llama" maps: num_hidden_layers -> NumLayers, - num_attention_heads -> NumQueryHeads, num_key_value_heads -> NumKeyValueHeads, - hidden_size -> HiddenSize, intermediate_size -> IntermediateSize, - rope_theta -> RopeTheta (default 500000), vocab_size -> VocabSize, - max_position_embeddings -> MaxPositionEmbeddings, eos_token_id -> EOSTokenID, - bos_token_id -> BOSTokenID. Test: parse a real Llama 3.1 8B config.json fixture - and verify all fields are correctly populated. - - [ ] S57.2.1 Implement Llama config parser with field mapping Est: 15m - - [ ] S57.2.2 Add testdata/llama3_config.json fixture Est: 10m - - [ ] S57.2.3 Write unit tests with fixture Est: 15m - - [ ] S57.2.4 Run golangci-lint and go test -cover Est: 5m - -- [ ] T57.3 Add Mistral and Qwen config parsers Owner: TBD Est: 45m - - Dependencies: T57.1 - - Files: inference/arch_config.go (extend) - - Acceptance: Parser for model_type "mistral" (nearly identical to Llama; adds - sliding_window field). Parser for model_type "qwen2" maps: same as Llama plus - use_sliding_window, attention_bias=true -> AttentionBias, rope_scaling (YaRN - config with type, factor, original_max_position_embeddings). - Test: parse Mistral 7B and Qwen 2.5 7B config.json fixtures. - - [ ] S57.3.1 Implement Mistral config parser Est: 10m - - [ ] S57.3.2 Implement Qwen config parser with rope_scaling Est: 15m - - [ ] S57.3.3 Add testdata fixtures and tests Est: 15m - - [ ] S57.3.4 Run golangci-lint and go test -cover Est: 5m - -- [ ] T57.4 Add Phi and DeepSeek config parsers Owner: TBD Est: 45m - - Dependencies: T57.1 - - Files: inference/arch_config.go (extend) - - Acceptance: Parser for model_type "phi3" / "phi" maps: same as Llama plus - partial_rotary_factor -> PartialRotaryFactor (float64, default 1.0), - tie_word_embeddings -> TieWordEmbeddings. Parser for model_type "deepseek_v3" - maps: same as Llama plus kv_lora_rank -> KVLoRADim (int, for MLA), - q_lora_rank -> QLoRADim, qk_rope_head_dim -> QKRopeHeadDim, - num_experts -> NumExperts, num_experts_per_tok -> NumExpertsPerToken, - n_shared_experts -> NumSharedExperts. - - [ ] S57.4.1 Implement Phi config parser Est: 10m - - [ ] S57.4.2 Implement DeepSeek config parser Est: 15m - - [ ] S57.4.3 Add testdata fixtures and tests Est: 15m - - [ ] S57.4.4 Run golangci-lint and go test -cover Est: 5m - -- [ ] T57.5 Integrate config registry into inference.Load Owner: TBD Est: 30m - - Dependencies: T57.2, T57.3, T57.4 - - Files: inference/inference.go (modify loadMetadata) - - Acceptance: loadMetadata first reads raw JSON to extract model_type, then - dispatches to the appropriate parser. Existing Gemma 3 loading continues to - work. New model_type values (llama, mistral, qwen2, phi3, deepseek_v3) are - parsed correctly. Unknown model_type falls back to generic parsing. - - [ ] S57.5.1 Update loadMetadata to use archConfigRegistry Est: 15m - - [ ] S57.5.2 Write integration test: load Llama config, verify metadata Est: 10m - - [ ] S57.5.3 Run golangci-lint and go test -cover Est: 5m - -- [ ] T57.6 Run linters and verify coverage for E57 Owner: TBD Est: 15m - - Dependencies: T57.5 - - Acceptance: golangci-lint 0 issues. go test -cover -race shows >= 95% coverage - on inference/arch_config.go. go vet clean. - - [ ] S57.6.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [ ] S57.6.2 Fix any remaining issues Est: 5m - -#### E58: Architecture-Aware Parameter Name Mapping - -Add a parameter name resolver that maps architecture-specific weight names from -ONNX/ZMF models to Zerfoo's internal naming conventions. - -- [ ] T58.1 Create parameter name resolver Owner: TBD Est: 1.5h - - Dependencies: E57 - - Files: model/param_resolver.go (new) - - Acceptance: A ParamResolver maps model weight names to canonical names used - by Zerfoo layers. Architecture-specific mappings: - Llama: model.layers.{i}.self_attn.{q,k,v,o}_proj.weight - Gemma: model.layers.{i}.self_attn.{q,k,v,o}_proj.weight (same pattern) - Qwen: model.layers.{i}.self_attn.{q,k,v,o}_proj.{weight,bias} - Phi: model.layers.{i}.self_attn.{q,k,v,dense}_proj.weight - DeepSeek: model.layers.{i}.self_attn.{kv_a_proj,kv_b_proj,q_a_proj,q_b_proj,o_proj}.weight - FFN: model.layers.{i}.mlp.{gate_proj,up_proj,down_proj}.weight (all families) - Norm: model.layers.{i}.{input,post_attention}_layernorm.weight (all families) - The resolver is called during model building (model/builder.go) to find - parameters by canonical name when the ZMF parameter has a model-specific name. - - [ ] S58.1.1 Define ParamResolver interface and implementation Est: 30m - - [ ] S58.1.2 Add Llama/Gemma/Mistral name mappings Est: 15m - - [ ] S58.1.3 Add Qwen/Phi/DeepSeek name mappings Est: 15m - - [ ] S58.1.4 Write unit tests for each architecture's name resolution Est: 20m - - [ ] S58.1.5 Run golangci-lint and go test -cover Est: 5m - -- [ ] T58.2 Integrate parameter resolver into model builder Owner: TBD Est: 1h - - Dependencies: T58.1 - - Files: model/builder.go (modify) - - Acceptance: model.BuildFromZMF uses the resolver to look up parameters when - the exact name is not found. Existing Gemma 3 loading still works (resolver - is a fallback, not a replacement). New models with different naming patterns - resolve correctly. - - [ ] S58.2.1 Add resolver lookup to parameter resolution in builder Est: 25m - - [ ] S58.2.2 Write tests verifying Llama-style names resolve correctly Est: 20m - - [ ] S58.2.3 Verify Gemma 3 loading is unaffected Est: 10m - - [ ] S58.2.4 Run golangci-lint and go test -cover Est: 5m - -- [ ] T58.3 Run linters and verify coverage for E58 Owner: TBD Est: 15m - - Dependencies: T58.2 - - [ ] S58.3.1 Run golangci-lint, go vet, go test -cover -race Est: 10m - - [ ] S58.3.2 Fix any remaining issues Est: 5m - -#### E59: Llama and Mistral Validation (Tier 1) - -Validate that Llama 3 and Mistral models load and generate text through the -existing pipeline with the new config parsing and parameter mapping. - -- [ ] T59.1 Llama 3 forward pass parity test Owner: TBD Est: 2h - - Dependencies: E57, E58 - - Files: tests/parity/llama3_test.go (new) - - Acceptance: TestLlama3ForwardPass loads a Llama 3 8B ZMF model (env-gated by - LLAMA3_ZMF_PATH), runs a forward pass, asserts output shape [1,seqLen,V] and - no NaN or Inf. TestLlama3GreedyDecode runs 5-step greedy decode and asserts - tokens in [0, vocabSize). Skips when env var not set. - - [ ] S59.1.1 Create tests/parity/llama3_test.go with forward pass test Est: 45m - - [ ] S59.1.2 Add greedy decode test Est: 30m - - [ ] S59.1.3 Add generation test via inference API Est: 30m - - [ ] S59.1.4 Run golangci-lint and go test Est: 15m - -- [ ] T59.2 Mistral forward pass parity test Owner: TBD Est: 1h - - Dependencies: E57, E58 - - Files: tests/parity/mistral_test.go (new) - - Acceptance: TestMistralForwardPass loads a Mistral 7B ZMF model (env-gated by - MISTRAL_ZMF_PATH), runs a forward pass, asserts valid output. Skips when env - var not set. - - [ ] S59.2.1 Create tests/parity/mistral_test.go Est: 30m - - [ ] S59.2.2 Add greedy decode test Est: 20m - - [ ] S59.2.3 Run golangci-lint and go test Est: 10m - -- [ ] T59.3 Run linters and verify for E59 Owner: TBD Est: 15m - - Dependencies: T59.1, T59.2 - - [ ] S59.3.1 Run golangci-lint, go test -race Est: 10m - - [ ] S59.3.2 Fix any issues Est: 5m - -#### E60: QKV Bias for Qwen (Tier 2) - -Add optional bias terms to Q, K, V projections in GroupedQueryAttention so -Qwen models (which use attention_bias=true) work correctly. - -- [ ] T60.1 Add bias support to GroupedQueryAttention Owner: TBD Est: 1.5h - - Dependencies: None - - Files: layers/attention/grouped_query_attention.go (modify) - - Acceptance: GroupedQueryAttention gains optional bias fields: qBias, kBias, - vBias (*tensor.TensorNumeric[T]). When present, Forward() adds bias after - the linear projection: Q = X * Wq + bq. When nil, behavior is unchanged - (backwards compatible). BuildGroupQueryAttention[T] reads bias parameters - from node initializers when present (e.g., "q_proj.bias"). Test: construct - GQA with biases, verify output differs from without bias, matches reference. - - [ ] S60.1.1 Add optional bias fields to GroupedQueryAttention Est: 15m - - [ ] S60.1.2 Modify Forward to apply bias after projection when present Est: 20m - - [ ] S60.1.3 Update BuildGroupQueryAttention to load bias params Est: 15m - - [ ] S60.1.4 Write unit tests: with bias, without bias (backward compat) Est: 25m - - [ ] S60.1.5 Verify existing GQA tests still pass Est: 5m - - [ ] S60.1.6 Run golangci-lint and go test -cover Est: 5m - -- [ ] T60.2 Run linters and verify for E60 Owner: TBD Est: 15m - - Dependencies: T60.1 - - [ ] S60.2.1 Run golangci-lint, go test -cover -race on layers/attention/ Est: 10m - - [ ] S60.2.2 Fix any remaining issues Est: 5m - -#### E61: YaRN RoPE Scaling (Tier 2) - -Implement YaRN (Yet another RoPE extensioN) scaling for long-context models. -YaRN modifies the inverse frequencies in RoPE to support context lengths beyond -the original training length. - -- [ ] T61.1 Add YaRN scaling to RotaryPositionalEmbedding Owner: TBD Est: 2h - - Dependencies: None - - Files: layers/embeddings/rotary_positional_embedding.go (modify) - - Acceptance: A new WithYaRNScaling(factor float64, origMaxLen int) option - modifies the inverse frequency computation. Low-frequency components (long - wavelength) are scaled by 1/factor. High-frequency components (short - wavelength, wavelength < origMaxLen) are kept unchanged. Intermediate - frequencies are linearly interpolated. An attention scaling factor - sqrt(1 + ln(factor) / ln(origMaxLen)) is stored and accessible via a - method. Backwards compatible: without the option, behavior is unchanged. - Test: verify that with factor=4, origMaxLen=8192, the resulting frequencies - differ from default and match the YaRN paper formulas. - - [ ] S61.1.1 Define RoPEScaling config struct (type, factor, origMaxLen) Est: 10m - - [ ] S61.1.2 Implement WithYaRNScaling option Est: 15m - - [ ] S61.1.3 Modify inverse frequency computation for YaRN Est: 30m - - [ ] S61.1.4 Add AttentionScaleFactor() method Est: 10m - - [ ] S61.1.5 Write unit tests: default unchanged, YaRN frequencies match reference Est: 25m - - [ ] S61.1.6 Run golangci-lint and go test -cover Est: 5m - -- [ ] T61.2 Integrate YaRN config into model loading Owner: TBD Est: 45m - - Dependencies: T61.1, E57 - - Files: model/builder.go (modify), layers/attention/group_query_attention_registry.go (modify) - - Acceptance: When ModelMetadata.RopeScaling is non-nil and type="yarn", the - model builder passes WithYaRNScaling to RoPE construction. Existing models - without rope_scaling are unaffected. - - [ ] S61.2.1 Read RopeScaling from ModelMetadata in builder Est: 15m - - [ ] S61.2.2 Pass YaRN options to RotaryPositionalEmbedding construction Est: 15m - - [ ] S61.2.3 Write tests Est: 10m - - [ ] S61.2.4 Run golangci-lint and go test -cover Est: 5m - -- [ ] T61.3 Run linters and verify for E61 Owner: TBD Est: 15m - - Dependencies: T61.2 - - [ ] S61.3.1 Run golangci-lint, go test -cover -race Est: 10m - - [ ] S61.3.2 Fix any remaining issues Est: 5m - -#### E62: Qwen Validation (Tier 2) - -Validate that Qwen 2.5 models load and generate text with QKV bias and YaRN. - -- [ ] T62.1 Qwen 2.5 forward pass parity test Owner: TBD Est: 2h - - Dependencies: E57, E58, E60, E61 - - Files: tests/parity/qwen_test.go (new) - - Acceptance: TestQwen25ForwardPass loads a Qwen 2.5 7B ZMF model (env-gated - by QWEN25_ZMF_PATH), runs a forward pass, asserts valid output shape and - no NaN/Inf. TestQwen25GreedyDecode runs 5-step greedy decode. Skips when - env var not set. - - [ ] S62.1.1 Create tests/parity/qwen_test.go Est: 45m - - [ ] S62.1.2 Add greedy decode test Est: 30m - - [ ] S62.1.3 Add generation test via inference API Est: 30m - - [ ] S62.1.4 Run golangci-lint and go test Est: 15m - -- [ ] T62.2 Run linters and verify for E62 Owner: TBD Est: 15m - - Dependencies: T62.1 - - [ ] S62.2.1 Run golangci-lint, go test -race Est: 10m - - [ ] S62.2.2 Fix any issues Est: 5m - -#### E63: Partial RoPE for Phi-4 (Tier 3) - -Implement partial/fractional RoPE where only a fraction of head dimensions -are rotated, while the rest remain position-agnostic. - -- [ ] T63.1 Add partial rotation to RotaryPositionalEmbedding Owner: TBD Est: 1.5h - - Dependencies: None - - Files: layers/embeddings/rotary_positional_embedding.go (modify) - - Acceptance: A new WithRotaryDimFraction(fraction float64) option controls - what fraction of head dimensions receive rotation. Default is 1.0 (all - dimensions rotated, current behavior). When fraction < 1.0, the Forward() - method splits the input tensor into rotated and non-rotated portions along - the last dimension, applies RoPE to the rotated portion, and concatenates. - Example: headDim=128, fraction=0.75 -> 96 dims rotated, 32 unrotated. - Test: fraction=0.5 produces output where first half is rotated, second half - is identical to input. - - [ ] S63.1.1 Add WithRotaryDimFraction option Est: 10m - - [ ] S63.1.2 Modify Forward to split/rotate/concat when fraction < 1.0 Est: 30m - - [ ] S63.1.3 Modify Backward for partial rotation Est: 20m - - [ ] S63.1.4 Write unit tests: full rotation (default), partial (0.75), half (0.5) Est: 20m - - [ ] S63.1.5 Verify existing RoPE tests still pass Est: 5m - - [ ] S63.1.6 Run golangci-lint and go test -cover Est: 5m - -- [ ] T63.2 Integrate partial RoPE into model loading Owner: TBD Est: 30m - - Dependencies: T63.1, E57 - - Files: model/builder.go (modify) - - Acceptance: When ModelMetadata.PartialRotaryFactor is set (e.g., 0.75 for - Phi-4), the model builder passes WithRotaryDimFraction to RoPE construction. - - [ ] S63.2.1 Read PartialRotaryFactor from ModelMetadata in builder Est: 10m - - [ ] S63.2.2 Pass fraction option to RotaryPositionalEmbedding Est: 10m - - [ ] S63.2.3 Write tests Est: 10m - -- [ ] T63.3 Run linters and verify for E63 Owner: TBD Est: 15m - - Dependencies: T63.2 - - [ ] S63.3.1 Run golangci-lint, go test -cover -race Est: 10m - - [ ] S63.3.2 Fix any remaining issues Est: 5m - -#### E64: Tied Embeddings for Phi-4 (Tier 3) - -Add support for sharing the input token embedding weight matrix with the -output LMHead layer, reducing model parameter count. - -- [ ] T64.1 Add tied embedding support to LMHead Owner: TBD Est: 1h - - Dependencies: None - - Files: layers/core/lm_head.go (modify) - - Acceptance: LMHead gains a TiedWeight field. When set, Forward() uses the - tied weight (transposed) instead of its own weight parameter. A factory - function NewTiedLMHead(engine, tiedWeight) creates an LMHead with shared - weights. The existing NewLMHead (with own weights) is unchanged. - BuildLMHead[T] checks if tie_word_embeddings=true in config and, when true, - finds the token embedding weight and passes it to NewTiedLMHead. - Test: tied LMHead produces same output as manual transpose + matmul. - - [ ] S64.1.1 Add TiedWeight field and NewTiedLMHead constructor Est: 15m - - [ ] S64.1.2 Modify Forward to use tied weight when present Est: 15m - - [ ] S64.1.3 Update BuildLMHead to handle tie_word_embeddings config Est: 15m - - [ ] S64.1.4 Write unit tests: tied vs untied, verify output correctness Est: 15m - - [ ] S64.1.5 Run golangci-lint and go test -cover Est: 5m - -- [ ] T64.2 Run linters and verify for E64 Owner: TBD Est: 15m - - Dependencies: T64.1 - - [ ] S64.2.1 Run golangci-lint, go test -cover -race Est: 10m - - [ ] S64.2.2 Fix any remaining issues Est: 5m - -#### E65: Phi-4 Validation (Tier 3) - -Validate Phi-4 model loading and generation with partial RoPE and tied embeddings. - -- [ ] T65.1 Phi-4 forward pass parity test Owner: TBD Est: 2h - - Dependencies: E57, E58, E63, E64 - - Files: tests/parity/phi4_test.go (new) - - Acceptance: TestPhi4ForwardPass loads a Phi-4 ZMF model (env-gated by - PHI4_ZMF_PATH), runs a forward pass, asserts valid output. TestPhi4GreedyDecode - runs 5-step greedy decode. Skips when env var not set. - - [ ] S65.1.1 Create tests/parity/phi4_test.go Est: 45m - - [ ] S65.1.2 Add greedy decode test Est: 30m - - [ ] S65.1.3 Add generation test via inference API Est: 30m - - [ ] S65.1.4 Run golangci-lint and go test Est: 15m - -- [ ] T65.2 Run linters and verify for E65 Owner: TBD Est: 15m - - Dependencies: T65.1 - - [ ] S65.2.1 Run golangci-lint, go test -race Est: 10m - - [ ] S65.2.2 Fix any issues Est: 5m - -#### E66: Multi-head Latent Attention (Tier 4) - -Implement Multi-head Latent Attention (MLA) as used in DeepSeek V3/R1. MLA -replaces GQA with low-rank KV compression, dramatically reducing KV cache size. - -- [ ] T66.1 Implement MultiHeadLatentAttention layer Owner: TBD Est: 4h - - Dependencies: None - - Files: layers/attention/multi_head_latent_attention.go (new) - - Acceptance: MultiHeadLatentAttention[T] struct with fields: W_DKV (down-projection - to compress KV, shape [hidden, kv_lora_dim]), W_UK (up-projection for keys, - shape [kv_lora_dim, num_heads * head_dim]), W_UV (up-projection for values, - shape [kv_lora_dim, num_heads * head_dim]), W_Q (query projection, - shape [hidden, num_heads * head_dim]), W_O (output projection). Configurable: - kv_lora_dim (default 512), num_heads, head_dim. Forward: - (a) Compress: c_kv = x * W_DKV (shape: [batch, seq, kv_lora_dim]) - (b) Decompress: K = c_kv * W_UK, V = c_kv * W_UV - (c) Q = x * W_Q - (d) Apply RoPE to Q and a subvector of K (decoupled RoPE) - (e) Standard scaled dot-product attention: softmax(Q * K^T / sqrt(d)) * V - (f) Output projection: O * W_O - KV cache stores c_kv (compressed latent) instead of full K/V. - Test: construct MLA with small dims, verify output shape is correct, - verify KV cache stores compressed latent of correct shape. - - [ ] S66.1.1 Define MultiHeadLatentAttention struct with all weight fields Est: 30m - - [ ] S66.1.2 Implement Forward: down-project, up-project, attention, output Est: 60m - - [ ] S66.1.3 Implement KV cache integration (cache c_kv, decompress on read) Est: 30m - - [ ] S66.1.4 Implement decoupled RoPE (rotate subvector of Q and K) Est: 30m - - [ ] S66.1.5 Write unit tests: output shape, cache shape, attention correctness Est: 30m - - [ ] S66.1.6 Run golangci-lint and go test -cover Est: 10m - -- [ ] T66.2 Add BuildMultiHeadLatentAttention and register Owner: TBD Est: 1h - - Dependencies: T66.1 - - Files: layers/attention/mla_registry.go (new), layers/registry/registry.go (modify) - - Acceptance: BuildMultiHeadLatentAttention[T] reads kv_lora_dim, num_heads, - head_dim, qk_rope_head_dim from node attributes. Loads W_DKV, W_UK, W_UV, - W_Q, W_O from node parameters. Registered as "MultiHeadLatentAttention" in - RegisterAll. Test: build from attributes and verify Forward works. - - [ ] S66.2.1 Implement BuildMultiHeadLatentAttention Est: 25m - - [ ] S66.2.2 Register "MultiHeadLatentAttention" in RegisterAll Est: 5m - - [ ] S66.2.3 Write unit tests for builder Est: 20m - - [ ] S66.2.4 Run golangci-lint and go test -cover Est: 10m - -- [ ] T66.3 Run linters and verify for E66 Owner: TBD Est: 15m - - Dependencies: T66.2 - - [ ] S66.3.1 Run golangci-lint, go test -cover -race on layers/attention/ Est: 10m - - [ ] S66.3.2 Fix any remaining issues Est: 5m - -#### E67: Shared Expert MoE (Tier 4) - -Add shared expert support to MixtureOfExperts, where one expert processes -every token in addition to the top-k routed experts. - -- [ ] T67.1 Add shared expert to MixtureOfExperts Owner: TBD Est: 1.5h - - Dependencies: None - - Files: layers/core/moe.go (modify) - - Acceptance: MixtureOfExperts gains a SharedExpert field (graph.Node[T]). - When SharedExpert is non-nil, Forward() runs the shared expert on every - token and adds its output to the weighted sum of routed expert outputs. - When nil, behavior is unchanged (backwards compatible). BuildMixtureOfExperts - checks for n_shared_experts config and loads shared expert weights if present. - Test: with shared expert, output equals (shared_output + weighted_routed_output). - - [ ] S67.1.1 Add SharedExpert field to MixtureOfExperts Est: 10m - - [ ] S67.1.2 Modify Forward to include shared expert output Est: 20m - - [ ] S67.1.3 Update builder to load shared expert Est: 15m - - [ ] S67.1.4 Write unit tests: with shared, without shared (backward compat) Est: 25m - - [ ] S67.1.5 Verify existing MoE tests still pass Est: 5m - - [ ] S67.1.6 Run golangci-lint and go test -cover Est: 5m - -- [ ] T67.2 Run linters and verify for E67 Owner: TBD Est: 15m - - Dependencies: T67.1 - - [ ] S67.2.1 Run golangci-lint, go test -cover -race Est: 10m - - [ ] S67.2.2 Fix any remaining issues Est: 5m - -#### E68: DeepSeek V3 Validation (Tier 4) - -Validate DeepSeek V3 model loading and generation with MLA and shared MoE. - -- [ ] T68.1 DeepSeek V3 forward pass parity test Owner: TBD Est: 3h - - Dependencies: E57, E58, E66, E67 - - Files: tests/parity/deepseek_test.go (new) - - Acceptance: TestDeepSeekV3ForwardPass loads a DeepSeek V3 ZMF model (env-gated - by DEEPSEEK_ZMF_PATH), runs a forward pass, asserts valid output shape and - no NaN/Inf. TestDeepSeekV3GreedyDecode runs 5-step greedy decode. Skips when - env var not set. - - Risk: DeepSeek V3 is 671B parameters total. Testing may require a smaller - variant or a subset of layers. - - [ ] S68.1.1 Create tests/parity/deepseek_test.go Est: 60m - - [ ] S68.1.2 Add greedy decode test Est: 45m - - [ ] S68.1.3 Add generation test via inference API Est: 30m - - [ ] S68.1.4 Run golangci-lint and go test Est: 15m - -- [ ] T68.2 Run linters and verify for E68 Owner: TBD Est: 15m - - Dependencies: T68.1 - - [ ] S68.2.1 Run golangci-lint, go test -race Est: 10m - - [ ] S68.2.2 Fix any issues Est: 5m - -#### E69: Phase 9 Final Verification - -Run the full quality gate suite after all Phase 9 work is complete. - -- [ ] T69.1 Run full test suite with coverage and race detector Owner: TBD Est: 30m - - Dependencies: E57, E58, E59, E60, E61, E62, E63, E64, E65, E66, E67, E68 - - Acceptance: go test ./... -cover -race passes. All new code >= 90% coverage. - No regressions in existing tests. All parity tests skip gracefully when model - files are not present. - - [ ] S69.1.1 Run go test ./... -cover -race Est: 15m - - [ ] S69.1.2 Verify coverage thresholds Est: 10m - - [ ] S69.1.3 Fix any regressions Est: 5m - -- [ ] T69.2 Run linters and verify Owner: TBD Est: 15m - - Dependencies: T69.1 - - Acceptance: golangci-lint 0 issues. go vet clean. - - [ ] S69.2.1 Run golangci-lint run ./... Est: 5m - - [ ] S69.2.2 Run go vet ./... Est: 5m - - [ ] S69.2.3 Fix any remaining issues Est: 5m - -- [ ] T69.3 Update documentation Owner: TBD Est: 45m - - Dependencies: T69.2 - - Acceptance: docs/plan.md Phase 9 tasks marked complete. docs/design.md updated - with multi-architecture support section listing all supported model families, - their config fields, and any architecture-specific notes. - - [ ] S69.3.1 Update docs/plan.md Est: 15m - - [ ] S69.3.2 Update docs/design.md with supported architectures table Est: 20m - - [ ] S69.3.3 Update hand-off notes Est: 10m - ---- - -## 4. Timeline and Milestones - -| ID | Milestone | Dependencies | Exit Criteria | -|----|-----------|--------------|---------------| -| M15 | Logging and metrics | E21, E22 | All packages instrumented; metrics exported | -| M16 | Security and config | E23, E24 | TLS on gRPC; config loads from file with env overrides | -| M17 | Reliability | E25, E26, E28 | Graceful shutdown; health checks; resource limits | -| M18 | CI hardening | E27 | Parity tests blocking; coverage + benchmark gates | -| M19 | Documentation | E30 | Runbook, troubleshooting guide, pprof endpoints | -| M20 | GPU validation | E29 | Tests pass on real T4 hardware (when quota available) | -| M21 | Enterprise ready | E31 | Full suite green, all quality gates pass | -| M22 | Worker service | E32, E33 | Concrete DistributedServiceServer + GrpcStrategy implemented | -| M23 | Distributed integration | E34 | Multi-worker tests prove AllReduce/Barrier/Broadcast correctness | -| M24 | Worker lifecycle | E35 | WorkerNode + CLI command; health + shutdown integrated | -| M25 | Phase 5 complete | E36 | Full suite green, distributed coverage >= 95% | -| M26 | Gemma 3 converter fixed | E37 | TENSOR attr handled; 126 MatMulNBits + 7 Constant nodes convert; smoke test passes | -| M27 | Core operators complete | E38 | Softmax, Sigmoid, LayerNorm, Slice, Pad, TopK, Erf registered and tested | -| M28 | Vision encoder ready | E39 | Conv2d, GlobalAveragePool, BatchNorm, Resize registered and tested | -| M29 | MoE complete | E40 | MoEGate and MixtureOfExperts registered and tested | -| M30 | VLM parity validated | E41, E42 | Gemma 3 forward pass test passes; SigLIP encoder test passes | -| M31 | Phase 6 complete | E43 | Full suite green; all quality gates pass | -| M32 | Dead code removed | E44 | pkg/prelude deleted; tests/helpers/wire.go deleted; no breakage | -| M33 | Registration consolidated | E45 | No init() in layers/; single RegisterAll entry point | -| M34 | Graph thread-safe | E46 | Concurrent Forward passes without data races | -| M35 | Phase 7 complete | E48 | Full suite green; docs updated; all quality gates pass | -| M36 | Production tokenizer | E49 | BPE tokenizer loads tokenizer.json; encode/decode round-trips correctly | -| M37 | KV cache working | E50 | Cached attention produces identical output to uncached; O(n) per step | -| M38 | Generation loop | E51 | Greedy + sampling generation with stop conditions | -| M39 | Streaming output | E52 | Token-by-token delivery via callback; parity with non-streaming | -| M40 | Model registry | E53 | Pull downloads from HuggingFace, converts ONNX to ZMF, caches locally | -| M41 | High-level API | E54 | inference.Load + Model.Generate + Model.Chat + Model.Embed working | -| M42 | CLI commands | E55 | zerfoo pull/run/serve commands working | -| M43 | Phase 8 complete | E56 | End-to-end: Gemma 3 generates coherent text; serve API tested | -| M44 | Multi-arch config | E57 | Config parsers for Llama, Mistral, Qwen, Phi, DeepSeek registered | -| M45 | Param name resolver | E58 | Architecture-specific weight names resolve during model building | -| M46 | Tier 1 validated | E59 | Llama 3 and Mistral forward pass parity tests pass | -| M47 | Tier 2 features | E60, E61 | QKV bias in GQA; YaRN RoPE scaling implemented | -| M48 | Tier 2 validated | E62 | Qwen 2.5 forward pass parity test passes | -| M49 | Tier 3 features | E63, E64 | Partial RoPE; tied embeddings implemented | -| M50 | Tier 3 validated | E65 | Phi-4 forward pass parity test passes | -| M51 | MLA implemented | E66 | MultiHeadLatentAttention layer registered and tested | -| M52 | Tier 4 features | E67 | Shared expert MoE implemented | -| M53 | Tier 4 validated | E68 | DeepSeek V3 forward pass parity test passes | -| M54 | Phase 9 complete | E69 | Full suite green; all architectures documented | - -### Recommended Sequence - -**Phase 4 (Complete):** -1. **E21** (Logging) -- Foundation for all other observability work -2. **E22** (Metrics) -- Can start after T21.1; depends on Logger -3. **E27** (CI Hardening) -- Independent; can run in parallel with E21/E22 -4. **E23** (gRPC Security) -- Independent -5. **E24** (Config Management) -- Independent -6. **E25** (Graceful Shutdown) -- Independent; benefits from Logger -7. **E26** (Health Checks) -- Depends on Logger -8. **E28** (Resource Limits) -- Independent -9. **E29** (GPU Validation) -- Blocked on external quota; do when unblocked -10. **E30** (Documentation) -- After E21-E26 are complete -11. **E31** (Final Verification) -- After all other epics - -**Phase 5 (Concrete Server):** -12. **E32** (Worker Service) -- No new dependencies; uses existing log, metrics, pb stubs -13. **E33** (gRPC Strategy) -- Depends on E32 -14. **E34** (Integration Tests) -- Depends on E33 -15. **E35** (Worker Lifecycle + CLI) -- Depends on E33; can partially parallel E34 -16. **E36** (Final Verification) -- After E32-E35 - -**Phase 6 (Open Weights Model Import):** -17. **E37** (Gemma 3 ONNX Import) -- No new zerfoo deps; zonnx converter fix first -18. **E38** (Core Missing Operators) -- Parallel with E37; independent of E39/E40 -19. **E39** (Vision Encoder Operators) -- Parallel with E38; independent -20. **E40** (MoE) -- Depends on E38 (Softmax + TopK needed by MoEGate) -21. **E41** (Gemma 3 Validation) -- Depends on E37, E38 -22. **E42** (Kimi-VL Validation) -- Depends on E39, E40 -23. **E43** (Final Verification) -- After E37-E42 - -Parallelism opportunities: -- E21 + E27 can run in parallel (independent) -- E23 + E24 + E25 can run in parallel (independent) -- E22 starts after T21.1 (needs Logger interface) -- E26 starts after T21.1 (needs Logger interface) -- E34 + E35 can partially overlap (E34 tests E33 output; E35 builds on E33 independently) - -**Phase 7 (Architecture Cleanup):** -24. **E44** (Dead Code Removal) -- Independent; can run first -25. **E45** (Registration Consolidation) -- Independent of E44 -26. **E46** (Graph Thread Safety) -- Independent of E44/E45 -27. **E47** (Documentation) -- After E44-E46 -28. **E48** (Final Verification) -- After all Phase 7 epics - -Parallelism opportunities: -- E44, E45, E46 are all independent and can run in parallel -- E47 must wait for E44-E46 to complete -- E48 must wait for E47 - -**Phase 8 (Embeddable Inference Library):** -29. **E49** (Tokenizer) -- Foundation; no Phase 8 deps; can start immediately -30. **E50** (KV Cache) -- Foundation; no Phase 8 deps; parallel with E49 -31. **E53** (Model Registry) -- Foundation; no Phase 8 deps; parallel with E49/E50 -32. **E51** (Generation Loop) -- Depends on E49 (tokenizer) + E50 (KV cache) -33. **E52** (Streaming) -- Depends on E51 (generation loop) -34. **E54** (High-Level API) -- Depends on E49, E50, E51, E52, E53 -35. **E55** (CLI Commands) -- Depends on E53 (pull) + E54 (run/serve) -36. **E56** (End-to-End Validation) -- After all Phase 8 epics - -**Phase 9 (Multi-Architecture Support):** -37. **E57** (Config Parsing) -- Foundation; no Phase 9 deps -38. **E58** (Param Name Resolver) -- Depends on E57 -39. **E59** (Llama/Mistral Validation) -- Depends on E57, E58; Tier 1 -40. **E60** (QKV Bias) -- Independent of E57; can parallel -41. **E61** (YaRN RoPE) -- Independent; can parallel with E60 -42. **E62** (Qwen Validation) -- Depends on E57, E58, E60, E61; Tier 2 -43. **E63** (Partial RoPE) -- Independent; can parallel with E60/E61 -44. **E64** (Tied Embeddings) -- Independent; can parallel -45. **E65** (Phi-4 Validation) -- Depends on E57, E58, E63, E64; Tier 3 -46. **E66** (MLA) -- Independent of all above; can start early -47. **E67** (Shared MoE) -- Independent; can parallel with E66 -48. **E68** (DeepSeek Validation) -- Depends on E57, E58, E66, E67; Tier 4 -49. **E69** (Final Verification) -- After all Phase 9 epics - -Parallelism opportunities: -- E57 is the only serial dependency for validation epics (E59, E62, E65, E68) -- E60, E61, E63, E64, E66, E67 are all independent layer implementations; run all in parallel -- E59 (Tier 1) can start as soon as E57+E58 are done -- E62 depends on E60+E61; E65 depends on E63+E64; E68 depends on E66+E67 -- Each tier is independently shippable - -Parallelism opportunities: -- E49 + E50 + E53 are all independent foundations; run in parallel -- E51 starts after E49 + E50 -- E52 starts after E51 -- E55.T55.1 (pull command) can start as soon as E53 is done, parallel with E51/E52 -- E54 integrates everything; starts after E49, E50, E51, E52, E53 -- E55.T55.2 and E55.T55.3 depend on E54 +All 9 phases complete (2026-02-24 through 2026-03-03). 69 epics (E1-E69), +~200 tasks. Only E29 (GPU hardware validation) remains blocked on external +GCP GPU quota. --- @@ -2790,45 +169,19 @@ A task is done when: ## 6. Progress Log -- **2026 03 02 (update 22):** Change Summary: Added Phase 9 -- Multi-Architecture Open-Weight Model Support. Comprehensive gap analysis of major open-weight model families (Llama 3, Mistral, Mixtral, Qwen 2.5/3, Phi-4, DeepSeek V3/R1, Falcon Mamba, RWKV) against Zerfoo's 56+ layer inventory. Findings: Tier 1 (Llama, Mistral, Mixtral) already architecturally supported, need only config.json parsing and parameter name mapping. Tier 2 (Qwen) needs QKV bias in GQA and YaRN RoPE scaling. Tier 3 (Phi-4) needs partial/fractional RoPE and tied embeddings. Tier 4 (DeepSeek V3/R1) needs Multi-head Latent Attention (MLA) and shared expert MoE. Tier 5 (SSM/Mamba/RWKV/DeltaNet) deferred as different paradigm. New epics: E57 (multi-architecture config parsing with registry), E58 (parameter name resolver for architecture-specific weight naming), E59 (Llama/Mistral validation), E60 (QKV bias for Qwen), E61 (YaRN RoPE scaling), E62 (Qwen validation), E63 (partial RoPE for Phi-4), E64 (tied embeddings for Phi-4), E65 (Phi-4 validation), E66 (Multi-head Latent Attention for DeepSeek), E67 (shared expert MoE for DeepSeek), E68 (DeepSeek validation), E69 (final verification). 31 tasks, ~35 hours estimated. Added milestones M44-M54. Each tier is independently shippable. - -- **2026 03 02 (update 21):** Change Summary: Completed E56 (End-to-End Validation) and Phase 8. T56.1: Gemma 3 e2e generation test (tests/parity/gemma3_generation_test.go) using inference.Load with dirRegistry mock, greedy determinism, streaming parity, and chat generation. Skips when GEMMA3_MODEL_DIR/GEMMA3_ZMF_PATH not set. T56.2: HTTP serve integration tests (TestChatCompletion_StreamParity, TestCompletion_StreamParity) verifying SSE events are well-formed and concatenated content matches non-streaming. T56.3: Full test suite passes with -race (55 packages), all Phase 8 packages >= 90% coverage (generate 95%, inference 96.4%, serve 96.4%), golangci-lint 0 issues, go vet clean. T56.4: Updated docs/design.md with inference pipeline section (11.2), updated package layout, updated data flow diagram, added known limitations. Updated hand-off notes for Phase 8 complete status. Commit: 6145789. - -- **2026 03 02 (update 20):** Change Summary: Completed E54 (High-Level Inference API). Package was already implemented with Load, Generate, GenerateStream, Chat, Embed, formatMessages, all Option/GenerateOption types. Fixed lint issues (errcheck, gosec G304/G306). Coverage 92.6%, golangci-lint 0 issues, all tests pass with -race. Tasks completed: T54.1-T54.5. Commit: 2e2c2d3. - -- **2026 03 02 (update 19):** Change Summary: Completed E51 (Autoregressive Generation Loop) and E52 (Streaming Output). E51: Generator.Generate with prefill + autoregressive decode loop, KV cache integration, stop token/string detection. Sampling pipeline: applyTemperature, applyTopK, applyTopP, applyRepetitionPenalty, sampleFromDistribution, softmax, argmax. E52: TokenStream interface, TokenStreamFunc adapter, GenerateStream with incremental token delivery and sentinel-based stop-string termination. Coverage 94.7%, golangci-lint 0 issues, all tests pass with -race. Tasks completed: T51.1-T51.6, T52.1-T52.2. Commit: 5985177. - -- **2026 03 02 (update 18):** Change Summary: Added Phase 8 -- Embeddable Go-Native Inference Library. Strategic direction chosen: position zerfoo as an embeddable Go-native inference library (Direction B from brainstorm). Gap analysis identified 8 critical gaps: whitespace-only tokenizer, no generation loop, no KV cache, no streaming, no high-level API, no model registry, no auto-download, no serve API. New epics: E49 (production BPE tokenizer loading tokenizer.json), E50 (KV cache for attention layers), E51 (autoregressive generation with sampling), E52 (streaming output), E53 (model registry with HuggingFace download and ONNX-to-ZMF conversion), E54 (high-level inference API: Load/Generate/Chat/Embed), E55 (CLI: pull/run/serve with OpenAI-compatible HTTP API), E56 (end-to-end validation). 28 tasks, ~30 hours estimated. Added milestones M36-M43. Pure Go, no CGo, no external dependencies beyond existing ones. - -- **2026 03 02 (update 17):** Change Summary: Completed Phase 7 -- Architecture Cleanup. E44: Deleted pkg/prelude (empty), tests/helpers/wire.go (4 nil interface stubs), 7 dead test files (17 always-skipping tests) in tests/parity and tests/numerics, and the empty tests/numerics directory. E45: Exported buildFFN as BuildFFN in layers/core/registry.go, removed init() and float16/model imports, added FFN registration to layers/registry/RegisterAll(). E46: Added sync.Mutex to graph.Graph protecting memo map in Forward and Backward; wrote TestGraph_ConcurrentForward (8 goroutines, passes with -race). E47: Updated docs/design.md (removed pkg/prelude, graph thread-safety limitation, dead test references; added concurrency and registration notes). E48: Full verification pending. Commits: c5d6c5f, 615bca8, c9271c1, 1f96736, 4e11b5a, 4cc2282, 225326c, e0d3fc9. - -- **2026 03 02 (update 16):** Change Summary: Added Phase 7 -- Architecture Cleanup. Comprehensive review identified: dead code (pkg/prelude empty, tests/helpers/wire.go all nil), inverted dependency (layers/core/registry.go init() imports model), graph.Graph not thread-safe (memo map unprotected). New epics: E44 (dead code removal, 3 tasks), E45 (registration consolidation, 3 tasks), E46 (graph thread safety, 2 tasks), E47 (documentation update, 2 tasks), E48 (final verification, 2 tasks). Added milestones M32-M35. Deferred: model/ package split, Arithmetic[T] interface change, Engine[T] Sum/ReduceSum merge, log.Logger signature change (all too risky or break non-goals). Consolidated docs: deleted docs/gpu.md, docs/runbook.md, docs/troubleshooting.md; rewrote docs/design.md as single comprehensive reference (commit 645f40b). - -- **2026 03 02 (update 15):** Change Summary: Consolidated all documentation. Rewrote docs/design.md as comprehensive single reference document. Deleted docs/gpu.md, docs/runbook.md, docs/troubleshooting.md (content merged into design.md). Updated known limitations to reflect current state. Added model import pipeline, layer coverage, ecosystem, and type system sections. Net reduction of 307 lines. Commit: 645f40b. - -- **2026 03 02 (update 14):** Change Summary: Completed T38.8 (zonnx importer builders). Added converter special cases for Slice/Pad/TopK to promote positional ONNX input tensors to named ZMF attributes (starts/ends/axes/steps, pads/constant_value, k). Added 7 layer builder stubs in zonnx/pkg/importer/layers/ (softmax, sigmoid, erf, layer_norm, slice, pad, topk), each registered via init(). 10 new round-trip tests added to converter_test.go covering all E38 operators. All zonnx tests pass; golangci-lint 0 issues. Commits (zonnx): 2a7bd4f, 04726bb. - -- **2026 03 02 (update 13):** Change Summary: Completed E38 core missing operators (T38.1-T38.7, T38.9). Implemented and registered: Softmax (layers/activations/softmax.go), Erf (layers/activations/erf.go), BuildSigmoid builder (layers/activations/registry.go), BuildLayerNormalization with resolveParam helper (layers/normalization/registry.go), Slice (layers/core/slice.go), Pad (layers/core/pad.go), TopK (layers/core/topk.go). All seven operators registered in layers/registry/registry.go. All 50 packages pass go test -race ./...; golangci-lint 0 issues. Commits: 5c15cab, cf93bf7, d1ad6fa, 3370f25. - -- **2026 03 02 (update 12):** Change Summary: Added Phase 6 -- Open Weights Model Import Support. Gap analysis identified blockers for Gemma 3 (TENSOR attribute missing in zonnx converter, UINT8 dtype missing, MatMulNBits and Constant not registered in zerfoo) and Kimi-VL (Conv2d, Pad, Slice, Resize, BatchNorm, GlobalAveragePool all missing; Softmax/Sigmoid/TopK/Erf not registered as layer nodes; MoE not implemented). New epics: E37 (Gemma 3 ONNX import fixes: 7 tasks), E38 (core missing operators: Softmax, Sigmoid, LayerNorm, Slice, Pad, TopK, Erf: 9 tasks), E39 (vision encoder operators: Conv2d, GlobalAveragePool, BatchNorm, Resize: 6 tasks), E40 (MoE: 4 tasks), E41 (Gemma 3 end-to-end validation: 2 tasks), E42 (Kimi-VL vision encoder validation: 2 tasks), E43 (Phase 6 final verification: 3 tasks). Added milestones M26-M31. Phase 6 is unblocked and can begin immediately. - -- **2026 03 02 (update 11):** Change Summary: Completed Phase 5 -- Concrete Distributed Service Server. E32: workerService implementing pb.DistributedServiceServer with AllReduce (bidi stream), Barrier (unary), Broadcast (unary) handlers, reduceSession, barrierState, input validation (validateTensor). E33: GrpcStrategy[T] implementing InternalStrategy[T] with Init, AllReduceGradients (star-topology), Barrier, BroadcastTensor, Shutdown (idempotent). Fixed Init to accept explicit world size parameter for sequential registration. E34: Multi-worker integration tests (AllReduce 3-worker, single-worker, Barrier, Broadcast, context cancellation). T34.4 (TLS integration) deferred. E35: WorkerNode struct (worker_node.go), WorkerCommand (cmd/cli/worker.go), registered in cmd/zerfoo/main.go, lifecycle integration test. E36: Full test suite pass, distributed/ 96.0% coverage, golangci-lint 0 issues, go vet clean. Commits: a20fe4c, ab72e98, 34a784e, 9922af5, ddbea47, c3f8fcf, b668d28, afdea4a, 3574de4. - -- **2026 03 01 (update 10):** Change Summary: Added Phase 5 -- Concrete Distributed Service Server. New epics E32 (WorkerService implementing pb.DistributedServiceServer with AllReduce/Barrier/Broadcast handlers and input validation), E33 (GrpcStrategy[T] implementing InternalStrategy[T] over gRPC transport), E34 (multi-worker integration tests using bufconn), E35 (WorkerNode lifecycle + CLI worker command + health/shutdown integration), E36 (Phase 5 final verification). Added milestones M22-M25. Star-topology AllReduce protocol (reduce to root, broadcast back). T32.5 completes previously skipped T23.2 (RPC input validation). 20 new tasks, estimated ~15 hours total. - -- **2026 03 01 (update 9):** Change Summary: Completed remaining Phase 4 tasks. T25.3 signal handling (cmd/cli, cmd/zerfoo, cmd/zerfoo-predict). T28.1 memory limit (MemoryTracker with CAS-based enforcement). T28.2 per-operation timeout (parallelForCtx, context checks in UnaryOp/binaryOp/MatMul). T30.1 deployment runbook (docs/runbook.md). T30.2 troubleshooting guide (docs/troubleshooting.md). T31.1 full test suite with race detector (0 data races, 1 pre-existing flaky test in distributed/coordinator). T31.2 golangci-lint 0 issues, go vet clean. T31.3 integration smoke test (config->engine->health->shutdown). CI regex fixed (Go 1.25 does not support Perl negative lookahead). T23.2 skipped (no concrete RPC server implementation). E29 remains BLOCKED on GCP GPU quota. - -- **2026 03 01 (update 8):** Change Summary: Completed T22.1-T22.3 metrics interface/instrumentation, T23.1 TLS config, T25.2 Closer implementations, T26.2 engine health check, T27.2 coverage gate, T27.3 benchmark regression detection, T30.3 pprof endpoints. All with tests, lint clean, coverage above thresholds. - -- **2026 03 01 (update 7):** Change Summary: Created enterprise production readiness plan (Phase 4, E21-E31). Extracted architecture and design knowledge to docs/design.md. Trimmed plan.md to remove completed Phase 1-3 task details (preserved as summary in design.md Section 7). New epics: E21 structured logging, E22 metrics interface, E23 gRPC TLS, E24 config management, E25 graceful shutdown, E26 health checks, E27 CI hardening, E28 resource limits, E29 GPU validation (re-numbered from E15/E20), E30 production docs, E31 final verification. Added milestones M15-M21. - -- **2026 03 01 (update 6):** Completed E6 T6.1 (testutil tests, 98.5%), E6 T6.2 (testutils tests, 94.5%), E7 T7.1 (full suite green, zero races, regularization 92.9% -> 97.6%), E7 T7.2 (0 lint issues, gofmt clean). All Phase 1 remaining tasks done. - -- **2026 03 01 (updates 1-5):** Completed Phase 2 (GPU Engine, E8-E14) and Phase 3 (GPU Production Readiness, E16-E19). Details in docs/design.md Section 7. - -- **2026 02 25:** Completed Phase 1 test coverage (E1-E5). 30 of 33 packages at >= 95%. - -- **2026 02 24:** Initial plan created for Phase 1 test coverage improvement. +| Date | Phase | Summary | +|------|-------|---------| +| 2026-02-24 | 1 | Initial plan created | +| 2026-02-25 | 1 | Test coverage complete (30/33 packages >= 95%) | +| 2026-03-01 | 2-3 | GPU engine + production readiness complete | +| 2026-03-01 | 4 | Enterprise readiness complete (except E29 blocked) | +| 2026-03-01 | 5 | Phase 5 planned (distributed protocol) | +| 2026-03-02 | 5 | Distributed protocol complete (96% coverage) | +| 2026-03-02 | 6 | Open weights model import complete (13 new operators) | +| 2026-03-02 | 7 | Architecture cleanup complete | +| 2026-03-02 | 8 | Embeddable inference library complete | +| 2026-03-02 | 9 | Phase 9 planned (multi-architecture support) | +| 2026-03-03 | 9 | Multi-architecture support complete (6 model families) | --- @@ -2836,213 +189,84 @@ A task is done when: ### For a New Contributor -- **Architecture:** Read docs/design.md for interface contracts, package layout, GPU architecture, operations, and troubleshooting. It is the single reference document. -- **Phase 1-3 status:** Complete. Test coverage, GPU engine, GPU production readiness. -- **Phase 4 status:** Complete (except E29 GPU validation, blocked on GCP quota). -- **Phase 5 status:** Complete. Concrete DistributedServiceServer, GrpcStrategy, WorkerNode, CLI worker command. 96% coverage. -- **Phase 6 status:** Complete. Open weights model import (Gemma 3, SigLIP, Kimi-VL). All operators registered and tested. -- **Phase 7 status:** Complete. Dead code removed (pkg/prelude, tests/helpers, 7 dead test files). Layer registration consolidated (no more init()). Graph.Forward/Backward thread-safe via sync.Mutex. -- **Phase 8 status:** Complete. Embeddable Go-native inference library. Production BPE tokenizer (tokenizer.json), KV cache, autoregressive generation loop with sampling (temperature, topK, topP, repetition penalty), streaming via TokenStream, model registry with local cache, high-level API (inference.Load/Generate/GenerateStream/Chat), CLI commands (pull/run/serve), OpenAI-compatible HTTP server (SSE streaming). Coverage: generate 95%, inference 96.4%, serve 96.4%, registry 91.2%, tokenizer 90.9%. -- **Phase 9 status:** Planned (not started). Multi-architecture open-weight model support. Tier 1 (Llama/Mistral) needs config mapping only. Tier 2 (Qwen) needs QKV bias + YaRN. Tier 3 (Phi-4) needs partial RoPE + tied embeddings. Tier 4 (DeepSeek) needs MLA + shared MoE. 31 tasks across 13 epics (E57-E69). +- **Architecture:** Read docs/design.md for interface contracts, package layout, + GPU architecture, operations, and troubleshooting. It is the single reference + document. Design decisions are in docs/adr/. +- **Phase 1-3:** Complete. Test coverage, GPU engine, GPU production readiness. +- **Phase 4:** Complete (except E29 GPU validation, blocked on GCP quota). +- **Phase 5:** Complete. Distributed protocol, worker lifecycle, CLI. 96% coverage. +- **Phase 6:** Complete. Open weights import (Gemma 3, SigLIP, Kimi-VL). 13 operators. +- **Phase 7:** Complete. Dead code removed, registration consolidated, graph thread-safe. +- **Phase 8:** Complete. Inference library: tokenizer, KV cache, generation, + streaming, registry, high-level API, CLI (pull/run/serve), OpenAI HTTP server. +- **Phase 9:** Complete. Multi-architecture support: Gemma 3, Llama 3, Mistral, + Qwen 2.5, Phi-4, DeepSeek V3. Config registry, param resolver, QKV bias, + YaRN RoPE, partial RoPE, tied embeddings, MLA, shared MoE. - **GPU hardware validation:** Blocked on GCP GPU quota (E29). - **Key files to read first:** - inference/inference.go -- High-level API: Load, Generate, Chat, GenerateStream - generate/generator.go -- Autoregressive generation loop - - generate/stream.go -- Streaming with TokenStream interface - serve/server.go -- OpenAI-compatible HTTP server - compute/engine.go -- Engine[T] interface (34 methods) - graph/node.go -- Node[T] interface - - tensor/storage.go -- Storage[T] interface - distributed/interfaces.go -- Distributed training interfaces - - distributed/coordinator/coordinator.go -- Coordinator gRPC server - **How to run tests:** `go test ./... -cover` for full suite. `go test -tags cuda ./...` for GPU. - **How to build:** `go build ./...` (CPU). `go build -tags cuda ./...` (GPU). - **Pre-commit hook:** Runs golangci-lint and tests. Rejects multi-directory commits. -- **No credentials required.** All work is local. CUDA Toolkit needed for GPU work. -- **Testing pattern for gRPC:** Use google.golang.org/grpc/test/bufconn for in-process gRPC tests. See distributed/coordinator/coordinator_test.go for the established pattern. +- **Testing pattern for gRPC:** Use google.golang.org/grpc/test/bufconn for in-process tests. ### External Dependencies -- GCP GPU quota increase for hardware validation (preference ID: zerfoo-gpu-test, project: numerai-488804). +- GCP GPU quota increase for hardware validation (preference ID: zerfoo-gpu-test, + project: numerai-488804). --- ## 8. Appendix -### Production Readiness Scorecard (Current State) - -| Category | Score | Notes | -|----------|-------|-------| -| Architecture | 9/10 | Clean interfaces, modular, type-safe | -| Core Functionality | 8/10 | Engine complete, GPU in progress | -| Testing | 8/10 | 95%+ coverage, missing hardware validation | -| Error Handling | 6/10 | Basic validation, no structured errors | -| Security | 3/10 | No TLS, no auth, minimal validation | -| Observability | 3/10 | Minimal logging, no metrics export | -| Configuration | 4/10 | Programmatic only, no file support | -| Operations | 3/10 | No health checks, no shutdown coordination | -| Documentation | 5/10 | Design docs good, missing runbooks | -| CI/CD | 7/10 | Comprehensive tests, non-blocking parity | - -### Target Scorecard (After Phase 4) - -| Category | Target | How Achieved | -|----------|--------|-------------| -| Architecture | 9/10 | No changes needed | -| Core Functionality | 9/10 | GPU hardware validation (E29) | -| Testing | 9/10 | Blocking parity tests, benchmark gates (E27) | -| Error Handling | 8/10 | Structured logging with context (E21) | -| Security | 7/10 | TLS, input validation (E23) | -| Observability | 8/10 | Logging, metrics, pprof (E21, E22, E30) | -| Configuration | 8/10 | File loading, env overrides, validation (E24) | -| Operations | 8/10 | Health checks, graceful shutdown, limits (E25, E26, E28) | -| Documentation | 8/10 | Runbook, troubleshooting, pprof (E30) | -| CI/CD | 9/10 | Blocking tests, coverage gate, benchmark gate (E27) | - -### Target Scorecard (After Phase 6) - -| Category | Target | How Achieved | -|----------|--------|-------------| -| Architecture | 10/10 | No changes from Phase 5 | -| Core Functionality | 10/10 | Gemma 3 + Kimi-VL inference via ONNX import (E37-E42) | -| Testing | 10/10 | Parity tests for real open-weights models (E41, E42) | -| Error Handling | 9/10 | No changes from Phase 5 | -| Security | 8/10 | No changes from Phase 5 | -| Observability | 8/10 | No changes from Phase 5 | -| Configuration | 8/10 | No changes from Phase 5 | -| Operations | 9/10 | No changes from Phase 5 | -| Documentation | 9/10 | Gap analysis resolved; operator coverage documented (T43.3) | -| CI/CD | 9/10 | No changes from Phase 5 | - -### Target Scorecard (After Phase 5) - -| Category | Target | How Achieved | -|----------|--------|-------------| -| Architecture | 10/10 | Concrete server completes distributed architecture (E32, E33) | -| Core Functionality | 9/10 | GPU validation still pending (E29) | -| Testing | 10/10 | Multi-worker integration tests over real gRPC (E34) | -| Error Handling | 9/10 | RPC input validation on all handlers (T32.5) | -| Security | 8/10 | TLS integration tests with distributed workers (T34.4) | -| Observability | 8/10 | No changes from Phase 4 | -| Configuration | 8/10 | No changes from Phase 4 | -| Operations | 9/10 | Worker lifecycle + CLI command + health integration (E35) | -| Documentation | 9/10 | Distributed worker setup in runbook (T35.5) | -| CI/CD | 9/10 | No changes from Phase 4 | - -### Target Scorecard (After Phase 7) - -| Category | Target | How Achieved | -|----------|--------|-------------| -| Architecture | 10/10 | Dead code removed (E44); registration consolidated (E45); graph thread-safe (E46) | -| Core Functionality | 10/10 | No changes from Phase 6 | -| Testing | 10/10 | Concurrent forward test added (E46) | -| Error Handling | 9/10 | No changes from Phase 6 | -| Security | 8/10 | No changes from Phase 6 | -| Observability | 8/10 | No changes from Phase 6 | -| Configuration | 8/10 | No changes from Phase 6 | -| Operations | 9/10 | No changes from Phase 6 | -| Documentation | 10/10 | Consolidated to single docs/design.md; stale refs removed (E47) | -| CI/CD | 9/10 | No changes from Phase 6 | - -### Target Scorecard (After Phase 8) - -| Category | Target | How Achieved | -|----------|--------|-------------| -| Architecture | 10/10 | No changes from Phase 7 | -| Core Functionality | 10/10 | Production tokenizer, KV cache, generation loop (E49-E52) | -| Testing | 10/10 | End-to-end generation tests, serve integration tests (E56) | -| Error Handling | 9/10 | No changes from Phase 7 | -| Security | 8/10 | HF_TOKEN support for gated models (T53.2) | -| Observability | 8/10 | No changes from Phase 7 | -| Configuration | 9/10 | Model config.json, inference options, CLI flags (E53, E54, E55) | -| Operations | 10/10 | CLI pull/run/serve, OpenAI-compatible HTTP API (E55) | -| Documentation | 10/10 | Design doc updated with inference pipeline (T56.4) | -| CI/CD | 9/10 | No changes from Phase 7 | -| Developer Experience | 9/10 | 3-line model loading and generation (E54); ollama-style CLI (E55) | - -### Target Scorecard (After Phase 9) +### Production Readiness Scorecard (After Phase 9) -| Category | Target | How Achieved | -|----------|--------|-------------| +| Category | Score | How Achieved | +|----------|-------|-------------| | Architecture | 10/10 | Multi-architecture config parsing (E57); MLA attention variant (E66) | | Core Functionality | 10/10 | 6 model families supported: Gemma, Llama, Mistral, Qwen, Phi, DeepSeek | | Testing | 10/10 | Parity tests for all supported architectures (E59, E62, E65, E68) | -| Error Handling | 9/10 | No changes from Phase 8 | -| Security | 8/10 | No changes from Phase 8 | -| Observability | 8/10 | No changes from Phase 8 | +| Error Handling | 9/10 | Structured logging, RPC validation, context deadlines | +| Security | 8/10 | TLS/mTLS for gRPC; HF_TOKEN for gated models | +| Observability | 8/10 | Logging, metrics, pprof endpoints | | Configuration | 10/10 | Architecture-aware config parsing with HuggingFace field mapping (E57) | -| Operations | 10/10 | No changes from Phase 8 | -| Documentation | 10/10 | Supported architectures table in design.md (T69.3) | -| CI/CD | 9/10 | No changes from Phase 8 | +| Operations | 10/10 | CLI pull/run/serve, OpenAI-compatible HTTP API | +| Documentation | 10/10 | Consolidated design.md + ADRs; supported architectures table | +| CI/CD | 9/10 | Blocking tests, coverage gate, benchmark gate | | Model Coverage | 10/10 | Covers >90% of open-weight model downloads on HuggingFace | ### New Packages and Files Created -| Package / File | Purpose | Epic | -|---------|---------|------| -| log/ | Structured logging with levels | E21 | -| metrics/runtime/ | Runtime metrics collection | E22 | -| config/ | File-based configuration loading | E24 | -| shutdown/ | Graceful shutdown coordinator | E25 | -| health/ | HTTP health check server | E26 | -| cmd/coverage-gate/ | CI coverage enforcement script | E27 | -| cmd/bench-compare/ | CI benchmark regression detection | E27 | -| distributed/worker_service.go | Concrete DistributedServiceServer (AllReduce, Barrier, Broadcast) | E32 | -| distributed/grpc_strategy.go | GrpcStrategy[T] implementing InternalStrategy[T] over gRPC | E33 | -| distributed/integration_test.go | Multi-worker integration tests using bufconn | E34 | -| distributed/worker_node.go | WorkerNode lifecycle management | E35 | -| cmd/zerfoo/worker.go | Worker CLI subcommand | E35 | -| layers/activations/softmax.go | Softmax graph layer node | E38 | -| layers/activations/sigmoid.go | Sigmoid graph layer node | E38 | -| layers/activations/erf.go | Erf (error function) graph layer node | E38 | -| layers/normalization/layer_norm.go | Standard LayerNormalization (with gamma+beta) | E38 | -| layers/normalization/batch_norm.go | BatchNormalization inference mode | E39 | -| layers/core/slice.go | Slice operator for tensor cropping | E38 | -| layers/core/pad.go | Pad operator (constant mode) | E38 | -| layers/core/topk.go | TopK selection operator | E38 | -| layers/core/conv2d.go | 2D convolution via im2col + MatMul | E39 | -| layers/core/global_avg_pool.go | GlobalAveragePool [N,C,H,W] -> [N,C,1,1] | E39 | -| layers/core/resize.go | Resize (nearest + bilinear) | E39 | -| layers/core/moe.go | MoEGate and MixtureOfExperts layers | E40 | -| tests/parity/gemma3_test.go | Gemma 3 forward pass parity test | E41 | -| tests/parity/siglip_test.go | SigLIP + Kimi-VL connector parity tests | E42 | -| zonnx/pkg/importer/layers/constant.go | zonnx builder for Constant nodes | E37 | -| zonnx/pkg/importer/layers/matmul_nbits.go | zonnx builder for MatMulNBits nodes | E37 | -| zonnx/pkg/importer/layers/softmax.go | zonnx builder for Softmax | E38 | -| zonnx/pkg/importer/layers/sigmoid.go | zonnx builder for Sigmoid | E38 | -| zonnx/pkg/importer/layers/layer_norm.go | zonnx builder for LayerNormalization | E38 | -| zonnx/pkg/importer/layers/slice.go | zonnx builder for Slice | E38 | -| zonnx/pkg/importer/layers/pad.go | zonnx builder for Pad | E38 | -| zonnx/pkg/importer/layers/topk.go | zonnx builder for TopK | E38 | -| zonnx/pkg/importer/layers/erf.go | zonnx builder for Erf | E38 | -| zonnx/pkg/importer/layers/conv.go | zonnx builder for Conv (Conv2d) | E39 | -| zonnx/pkg/importer/layers/global_avg_pool.go | zonnx builder for GlobalAveragePool | E39 | -| zonnx/pkg/importer/layers/batch_norm.go | zonnx builder for BatchNormalization | E39 | -| zonnx/pkg/importer/layers/resize.go | zonnx builder for Resize | E39 | -| zonnx/pkg/importer/layers/moe.go | zonnx builders for MoEGate and MixtureOfExperts | E40 | -| pkg/tokenizer/bpe.go | Production BPE tokenizer implementation | E49 | -| pkg/tokenizer/loader.go | HuggingFace tokenizer.json loader | E49 | -| generate/kvcache.go | KV cache for attention layers | E50 | -| generate/context.go | GenerationContext with KV cache carrier | E50 | -| generate/generator.go | Autoregressive generation loop | E51 | -| generate/sampling.go | Temperature, top-k, top-p, repetition penalty | E51 | -| generate/stream.go | TokenStream interface for streaming output | E52 | -| registry/registry.go | ModelRegistry interface and LocalRegistry | E53 | -| registry/pull.go | HuggingFace download and ONNX-to-ZMF conversion | E53 | -| inference/inference.go | High-level Load, Model.Generate, Model.GenerateStream | E54 | -| inference/chat.go | Model.Chat with prompt template formatting | E54 | -| inference/embed.go | Model.Embed with mean pooling | E54 | -| cmd/cli/pull.go | zerfoo pull CLI command | E55 | -| cmd/cli/run.go | zerfoo run interactive REPL command | E55 | -| cmd/cli/serve.go | zerfoo serve HTTP server command | E55 | -| serve/server.go | OpenAI-compatible HTTP API server (net/http) | E55 | -| tests/parity/gemma3_generation_test.go | End-to-end generation parity test | E56 | -| inference/arch_config.go | Multi-architecture config.json parsing | E57 | -| model/param_resolver.go | Architecture-specific parameter name resolution | E58 | -| tests/parity/llama3_test.go | Llama 3 forward pass parity test | E59 | -| tests/parity/mistral_test.go | Mistral forward pass parity test | E59 | -| tests/parity/qwen_test.go | Qwen 2.5 forward pass parity test | E62 | -| tests/parity/phi4_test.go | Phi-4 forward pass parity test | E65 | -| layers/attention/multi_head_latent_attention.go | Multi-head Latent Attention (MLA) for DeepSeek | E66 | -| layers/attention/mla_registry.go | MLA builder and registry integration | E66 | -| tests/parity/deepseek_test.go | DeepSeek V3 forward pass parity test | E68 | +| Package / File | Purpose | Phase | +|---------|---------|-------| +| log/ | Structured logging with levels | 4 | +| metrics/runtime/ | Runtime metrics collection | 4 | +| config/ | File-based configuration loading | 4 | +| shutdown/ | Graceful shutdown coordinator | 4 | +| health/ | HTTP health check server | 4 | +| cmd/coverage-gate/ | CI coverage enforcement script | 4 | +| cmd/bench-compare/ | CI benchmark regression detection | 4 | +| distributed/worker_service.go | DistributedServiceServer (AllReduce, Barrier, Broadcast) | 5 | +| distributed/grpc_strategy.go | GrpcStrategy[T] over gRPC | 5 | +| distributed/integration_test.go | Multi-worker integration tests | 5 | +| distributed/worker_node.go | WorkerNode lifecycle management | 5 | +| cmd/cli/worker.go | Worker CLI subcommand | 5 | +| layers/activations/{softmax,erf}.go | Softmax, Erf layer nodes | 6 | +| layers/normalization/batch_norm.go | BatchNormalization inference mode | 6 | +| layers/core/{slice,pad,topk,conv2d,global_avg_pool,resize,moe,constant}.go | Core operators | 6 | +| tests/parity/{gemma3,siglip}_test.go | Model parity tests | 6 | +| pkg/tokenizer/{bpe,loader}.go | Production BPE tokenizer | 8 | +| generate/{kvcache,context,generator,sampling,stream}.go | Generation pipeline | 8 | +| registry/{registry,pull}.go | Model registry + HuggingFace download | 8 | +| inference/{inference,chat,embed}.go | High-level API | 8 | +| serve/server.go | OpenAI-compatible HTTP server | 8 | +| cmd/cli/{pull,run,serve}.go | CLI commands | 8 | +| inference/arch_config.go | Multi-architecture config parsing | 9 | +| model/param_resolver.go | Architecture-specific param resolution | 9 | +| layers/attention/{multi_head_latent_attention,mla_registry}.go | MLA for DeepSeek | 9 | +| tests/parity/{llama3,mistral,qwen,phi4,deepseek}_test.go | Parity tests | 9 | diff --git a/inference/arch_config.go b/inference/arch_config.go index af7fb9da..aa7db3b0 100644 --- a/inference/arch_config.go +++ b/inference/arch_config.go @@ -44,6 +44,12 @@ func DefaultArchConfigRegistry() *ArchConfigRegistry { r.Register("gemma", parseGemmaConfig) r.Register("gemma2", parseGemmaConfig) r.Register("gemma3", parseGemmaConfig) + r.Register("llama", parseLlamaConfig) + r.Register("mistral", parseMistralConfig) + r.Register("qwen2", parseQwenConfig) + r.Register("phi3", parsePhiConfig) + r.Register("phi", parsePhiConfig) + r.Register("deepseek_v3", parseDeepSeekConfig) return r } @@ -69,6 +75,140 @@ func parseGemmaConfig(raw map[string]interface{}) (*ModelMetadata, error) { return meta, nil } +// parseLlamaConfig parses Llama-family config.json fields. +func parseLlamaConfig(raw map[string]interface{}) (*ModelMetadata, error) { + meta := &ModelMetadata{ + Architecture: getString(raw, "model_type"), + VocabSize: getInt(raw, "vocab_size"), + HiddenSize: getInt(raw, "hidden_size"), + NumLayers: getInt(raw, "num_hidden_layers"), + NumQueryHeads: getInt(raw, "num_attention_heads"), + NumKeyValueHeads: getInt(raw, "num_key_value_heads"), + IntermediateSize: getInt(raw, "intermediate_size"), + MaxPositionEmbeddings: getInt(raw, "max_position_embeddings"), + EOSTokenID: getInt(raw, "eos_token_id"), + BOSTokenID: getInt(raw, "bos_token_id"), + RopeTheta: getFloat(raw, "rope_theta"), + TieWordEmbeddings: getBool(raw, "tie_word_embeddings"), + RopeScaling: getRopeScaling(raw), + } + if meta.RopeTheta == 0 { + meta.RopeTheta = 500000 // Llama 3 default + } + return meta, nil +} + +// parseMistralConfig parses Mistral-family config.json fields. +// Nearly identical to Llama but adds sliding_window and defaults rope_theta to 10000. +func parseMistralConfig(raw map[string]interface{}) (*ModelMetadata, error) { + meta := &ModelMetadata{ + Architecture: getString(raw, "model_type"), + VocabSize: getInt(raw, "vocab_size"), + HiddenSize: getInt(raw, "hidden_size"), + NumLayers: getInt(raw, "num_hidden_layers"), + NumQueryHeads: getInt(raw, "num_attention_heads"), + NumKeyValueHeads: getInt(raw, "num_key_value_heads"), + IntermediateSize: getInt(raw, "intermediate_size"), + MaxPositionEmbeddings: getInt(raw, "max_position_embeddings"), + EOSTokenID: getInt(raw, "eos_token_id"), + BOSTokenID: getInt(raw, "bos_token_id"), + RopeTheta: getFloat(raw, "rope_theta"), + TieWordEmbeddings: getBool(raw, "tie_word_embeddings"), + SlidingWindow: getInt(raw, "sliding_window"), + RopeScaling: getRopeScaling(raw), + } + if meta.RopeTheta == 0 { + meta.RopeTheta = 10000 // Mistral default + } + return meta, nil +} + +// parseQwenConfig parses Qwen2-family config.json fields. +// Qwen models always use attention bias and default rope_theta to 1000000. +func parseQwenConfig(raw map[string]interface{}) (*ModelMetadata, error) { + meta := &ModelMetadata{ + Architecture: getString(raw, "model_type"), + VocabSize: getInt(raw, "vocab_size"), + HiddenSize: getInt(raw, "hidden_size"), + NumLayers: getInt(raw, "num_hidden_layers"), + NumQueryHeads: getInt(raw, "num_attention_heads"), + NumKeyValueHeads: getInt(raw, "num_key_value_heads"), + IntermediateSize: getInt(raw, "intermediate_size"), + MaxPositionEmbeddings: getInt(raw, "max_position_embeddings"), + EOSTokenID: getInt(raw, "eos_token_id"), + BOSTokenID: getInt(raw, "bos_token_id"), + RopeTheta: getFloat(raw, "rope_theta"), + TieWordEmbeddings: getBool(raw, "tie_word_embeddings"), + SlidingWindow: getInt(raw, "sliding_window"), + AttentionBias: true, // Qwen always uses attention bias + RopeScaling: getRopeScaling(raw), + } + if meta.RopeTheta == 0 { + meta.RopeTheta = 1000000 // Qwen default + } + return meta, nil +} + +// parsePhiConfig parses Phi-family config.json fields. +// Adds partial_rotary_factor (default 1.0) and tie_word_embeddings. +func parsePhiConfig(raw map[string]interface{}) (*ModelMetadata, error) { + meta := &ModelMetadata{ + Architecture: getString(raw, "model_type"), + VocabSize: getInt(raw, "vocab_size"), + HiddenSize: getInt(raw, "hidden_size"), + NumLayers: getInt(raw, "num_hidden_layers"), + NumQueryHeads: getInt(raw, "num_attention_heads"), + NumKeyValueHeads: getInt(raw, "num_key_value_heads"), + IntermediateSize: getInt(raw, "intermediate_size"), + MaxPositionEmbeddings: getInt(raw, "max_position_embeddings"), + EOSTokenID: getInt(raw, "eos_token_id"), + BOSTokenID: getInt(raw, "bos_token_id"), + RopeTheta: getFloat(raw, "rope_theta"), + TieWordEmbeddings: getBool(raw, "tie_word_embeddings"), + SlidingWindow: getInt(raw, "sliding_window"), + PartialRotaryFactor: getFloat(raw, "partial_rotary_factor"), + RopeScaling: getRopeScaling(raw), + } + if meta.RopeTheta == 0 { + meta.RopeTheta = 10000 // Phi default + } + if meta.PartialRotaryFactor == 0 { + meta.PartialRotaryFactor = 1.0 // Full rotation by default + } + return meta, nil +} + +// parseDeepSeekConfig parses DeepSeek V3-family config.json fields. +// Adds MLA fields (kv_lora_rank, q_lora_rank, qk_rope_head_dim) and +// MoE fields (n_routed_experts, num_experts_per_tok, n_shared_experts). +func parseDeepSeekConfig(raw map[string]interface{}) (*ModelMetadata, error) { + meta := &ModelMetadata{ + Architecture: getString(raw, "model_type"), + VocabSize: getInt(raw, "vocab_size"), + HiddenSize: getInt(raw, "hidden_size"), + NumLayers: getInt(raw, "num_hidden_layers"), + NumQueryHeads: getInt(raw, "num_attention_heads"), + NumKeyValueHeads: getInt(raw, "num_key_value_heads"), + IntermediateSize: getInt(raw, "intermediate_size"), + MaxPositionEmbeddings: getInt(raw, "max_position_embeddings"), + EOSTokenID: getInt(raw, "eos_token_id"), + BOSTokenID: getInt(raw, "bos_token_id"), + RopeTheta: getFloat(raw, "rope_theta"), + TieWordEmbeddings: getBool(raw, "tie_word_embeddings"), + KVLoRADim: getInt(raw, "kv_lora_rank"), + QLoRADim: getInt(raw, "q_lora_rank"), + QKRopeHeadDim: getInt(raw, "qk_rope_head_dim"), + NumExperts: getInt(raw, "n_routed_experts"), + NumExpertsPerToken: getInt(raw, "num_experts_per_tok"), + NumSharedExperts: getInt(raw, "n_shared_experts"), + RopeScaling: getRopeScaling(raw), + } + if meta.RopeTheta == 0 { + meta.RopeTheta = 10000 // DeepSeek default + } + return meta, nil +} + // parseFallbackConfig extracts common fields using the most widespread // HuggingFace naming conventions. Used for unknown model_type values. func parseFallbackConfig(raw map[string]interface{}) (*ModelMetadata, error) { diff --git a/inference/arch_config_test.go b/inference/arch_config_test.go index 38f8b896..b6b52575 100644 --- a/inference/arch_config_test.go +++ b/inference/arch_config_test.go @@ -1,6 +1,8 @@ package inference import ( + "encoding/json" + "os" "testing" ) @@ -266,6 +268,628 @@ func TestRopeScalingConfig_FromRaw(t *testing.T) { } } +func TestLlamaConfigParser(t *testing.T) { + tests := []struct { + name string + raw map[string]interface{} + want ModelMetadata + }{ + { + name: "llama3.1 8B full", + raw: map[string]interface{}{ + "model_type": "llama", + "vocab_size": float64(128256), + "hidden_size": float64(4096), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + "num_key_value_heads": float64(8), + "intermediate_size": float64(14336), + "max_position_embeddings": float64(131072), + "rope_theta": float64(500000), + "eos_token_id": float64(128001), + "bos_token_id": float64(128000), + "tie_word_embeddings": false, + "rope_scaling": map[string]interface{}{ + "type": "llama3", + "factor": float64(8.0), + "original_max_position_embeddings": float64(8192), + }, + }, + want: ModelMetadata{ + Architecture: "llama", + VocabSize: 128256, + HiddenSize: 4096, + NumLayers: 32, + NumQueryHeads: 32, + NumKeyValueHeads: 8, + IntermediateSize: 14336, + MaxPositionEmbeddings: 131072, + RopeTheta: 500000, + EOSTokenID: 128001, + BOSTokenID: 128000, + }, + }, + { + name: "llama minimal without rope_theta defaults to 500000", + raw: map[string]interface{}{ + "model_type": "llama", + "vocab_size": float64(32000), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + }, + want: ModelMetadata{ + Architecture: "llama", + VocabSize: 32000, + NumLayers: 32, + NumQueryHeads: 32, + RopeTheta: 500000, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseLlamaConfig(tc.raw) + if err != nil { + t.Fatalf("parseLlamaConfig error: %v", err) + } + assertMetadataEqual(t, tc.want, *got) + }) + } +} + +func TestLlamaConfigParser_Fixture(t *testing.T) { + data, err := os.ReadFile("testdata/llama3_config.json") + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal fixture: %v", err) + } + + meta, err := parseLlamaConfig(raw) + if err != nil { + t.Fatalf("parseLlamaConfig error: %v", err) + } + + assertMetadataEqual(t, ModelMetadata{ + Architecture: "llama", + VocabSize: 128256, + HiddenSize: 4096, + NumLayers: 32, + NumQueryHeads: 32, + NumKeyValueHeads: 8, + IntermediateSize: 14336, + MaxPositionEmbeddings: 131072, + RopeTheta: 500000, + EOSTokenID: 128001, + BOSTokenID: 128000, + }, *meta) + + // Verify RoPE scaling was parsed. + if meta.RopeScaling == nil { + t.Fatal("RopeScaling should not be nil") + } + if meta.RopeScaling.Type != "llama3" { + t.Errorf("RopeScaling.Type = %q, want %q", meta.RopeScaling.Type, "llama3") + } + if meta.RopeScaling.Factor != 8.0 { + t.Errorf("RopeScaling.Factor = %f, want 8.0", meta.RopeScaling.Factor) + } + if meta.RopeScaling.OriginalMaxPositionEmbeddings != 8192 { + t.Errorf("RopeScaling.OriginalMaxPositionEmbeddings = %d, want 8192", + meta.RopeScaling.OriginalMaxPositionEmbeddings) + } +} + +func TestDefaultArchConfigRegistry_LlamaRegistered(t *testing.T) { + reg := DefaultArchConfigRegistry() + + raw := map[string]interface{}{ + "model_type": "llama", + "vocab_size": float64(128256), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + "num_key_value_heads": float64(8), + "rope_theta": float64(500000), + } + meta, err := reg.Parse(raw) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if meta.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", meta.Architecture, "llama") + } + if meta.RopeTheta != 500000 { + t.Errorf("RopeTheta = %f, want 500000", meta.RopeTheta) + } +} + +func TestMistralConfigParser(t *testing.T) { + tests := []struct { + name string + raw map[string]interface{} + want ModelMetadata + }{ + { + name: "mistral 7B full", + raw: map[string]interface{}{ + "model_type": "mistral", + "vocab_size": float64(32000), + "hidden_size": float64(4096), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + "num_key_value_heads": float64(8), + "intermediate_size": float64(14336), + "max_position_embeddings": float64(32768), + "rope_theta": float64(10000), + "eos_token_id": float64(2), + "bos_token_id": float64(1), + "sliding_window": float64(4096), + }, + want: ModelMetadata{ + Architecture: "mistral", + VocabSize: 32000, + HiddenSize: 4096, + NumLayers: 32, + NumQueryHeads: 32, + NumKeyValueHeads: 8, + IntermediateSize: 14336, + MaxPositionEmbeddings: 32768, + RopeTheta: 10000, + EOSTokenID: 2, + BOSTokenID: 1, + SlidingWindow: 4096, + }, + }, + { + name: "mistral minimal defaults rope_theta to 10000", + raw: map[string]interface{}{ + "model_type": "mistral", + "vocab_size": float64(32000), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + }, + want: ModelMetadata{ + Architecture: "mistral", + VocabSize: 32000, + NumLayers: 32, + NumQueryHeads: 32, + RopeTheta: 10000, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseMistralConfig(tc.raw) + if err != nil { + t.Fatalf("parseMistralConfig error: %v", err) + } + assertMetadataEqual(t, tc.want, *got) + }) + } +} + +func TestMistralConfigParser_Fixture(t *testing.T) { + data, err := os.ReadFile("testdata/mistral7b_config.json") + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal fixture: %v", err) + } + + meta, err := parseMistralConfig(raw) + if err != nil { + t.Fatalf("parseMistralConfig error: %v", err) + } + + assertMetadataEqual(t, ModelMetadata{ + Architecture: "mistral", + VocabSize: 32000, + HiddenSize: 4096, + NumLayers: 32, + NumQueryHeads: 32, + NumKeyValueHeads: 8, + IntermediateSize: 14336, + MaxPositionEmbeddings: 32768, + RopeTheta: 10000, + EOSTokenID: 2, + BOSTokenID: 1, + SlidingWindow: 4096, + }, *meta) +} + +func TestQwenConfigParser(t *testing.T) { + tests := []struct { + name string + raw map[string]interface{} + want ModelMetadata + }{ + { + name: "qwen2.5 7B full", + raw: map[string]interface{}{ + "model_type": "qwen2", + "vocab_size": float64(152064), + "hidden_size": float64(3584), + "num_hidden_layers": float64(28), + "num_attention_heads": float64(28), + "num_key_value_heads": float64(4), + "intermediate_size": float64(18944), + "max_position_embeddings": float64(32768), + "rope_theta": float64(1000000), + "eos_token_id": float64(151645), + "bos_token_id": float64(151643), + "sliding_window": float64(32768), + "use_sliding_window": false, + "rope_scaling": map[string]interface{}{ + "type": "yarn", + "factor": float64(4.0), + "original_max_position_embeddings": float64(32768), + }, + }, + want: ModelMetadata{ + Architecture: "qwen2", + VocabSize: 152064, + HiddenSize: 3584, + NumLayers: 28, + NumQueryHeads: 28, + NumKeyValueHeads: 4, + IntermediateSize: 18944, + MaxPositionEmbeddings: 32768, + RopeTheta: 1000000, + EOSTokenID: 151645, + BOSTokenID: 151643, + SlidingWindow: 32768, + AttentionBias: true, + }, + }, + { + name: "qwen2 minimal defaults", + raw: map[string]interface{}{ + "model_type": "qwen2", + "vocab_size": float64(151936), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(40), + }, + want: ModelMetadata{ + Architecture: "qwen2", + VocabSize: 151936, + NumLayers: 32, + NumQueryHeads: 40, + RopeTheta: 1000000, + AttentionBias: true, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseQwenConfig(tc.raw) + if err != nil { + t.Fatalf("parseQwenConfig error: %v", err) + } + assertMetadataEqual(t, tc.want, *got) + }) + } +} + +func TestQwenConfigParser_Fixture(t *testing.T) { + data, err := os.ReadFile("testdata/qwen25_7b_config.json") + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal fixture: %v", err) + } + + meta, err := parseQwenConfig(raw) + if err != nil { + t.Fatalf("parseQwenConfig error: %v", err) + } + + assertMetadataEqual(t, ModelMetadata{ + Architecture: "qwen2", + VocabSize: 152064, + HiddenSize: 3584, + NumLayers: 28, + NumQueryHeads: 28, + NumKeyValueHeads: 4, + IntermediateSize: 18944, + MaxPositionEmbeddings: 32768, + RopeTheta: 1000000, + EOSTokenID: 151645, + BOSTokenID: 151643, + SlidingWindow: 32768, + AttentionBias: true, + }, *meta) + + if meta.RopeScaling == nil { + t.Fatal("RopeScaling should not be nil") + } + if meta.RopeScaling.Type != "yarn" { + t.Errorf("RopeScaling.Type = %q, want %q", meta.RopeScaling.Type, "yarn") + } + if meta.RopeScaling.Factor != 4.0 { + t.Errorf("RopeScaling.Factor = %f, want 4.0", meta.RopeScaling.Factor) + } +} + +func TestPhiConfigParser(t *testing.T) { + tests := []struct { + name string + raw map[string]interface{} + want ModelMetadata + }{ + { + name: "phi4 full", + raw: map[string]interface{}{ + "model_type": "phi3", + "vocab_size": float64(100352), + "hidden_size": float64(5120), + "num_hidden_layers": float64(40), + "num_attention_heads": float64(40), + "num_key_value_heads": float64(10), + "intermediate_size": float64(17920), + "max_position_embeddings": float64(16384), + "rope_theta": float64(10000), + "eos_token_id": float64(100265), + "bos_token_id": float64(100257), + "partial_rotary_factor": float64(0.5), + "tie_word_embeddings": false, + "sliding_window": float64(2048), + }, + want: ModelMetadata{ + Architecture: "phi3", + VocabSize: 100352, + HiddenSize: 5120, + NumLayers: 40, + NumQueryHeads: 40, + NumKeyValueHeads: 10, + IntermediateSize: 17920, + MaxPositionEmbeddings: 16384, + RopeTheta: 10000, + EOSTokenID: 100265, + BOSTokenID: 100257, + PartialRotaryFactor: 0.5, + SlidingWindow: 2048, + }, + }, + { + name: "phi minimal defaults partial_rotary_factor to 1.0", + raw: map[string]interface{}{ + "model_type": "phi3", + "vocab_size": float64(32064), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + }, + want: ModelMetadata{ + Architecture: "phi3", + VocabSize: 32064, + NumLayers: 32, + NumQueryHeads: 32, + RopeTheta: 10000, + PartialRotaryFactor: 1.0, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePhiConfig(tc.raw) + if err != nil { + t.Fatalf("parsePhiConfig error: %v", err) + } + assertMetadataEqual(t, tc.want, *got) + }) + } +} + +func TestPhiConfigParser_Fixture(t *testing.T) { + data, err := os.ReadFile("testdata/phi4_config.json") + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal fixture: %v", err) + } + + meta, err := parsePhiConfig(raw) + if err != nil { + t.Fatalf("parsePhiConfig error: %v", err) + } + + assertMetadataEqual(t, ModelMetadata{ + Architecture: "phi3", + VocabSize: 100352, + HiddenSize: 5120, + NumLayers: 40, + NumQueryHeads: 40, + NumKeyValueHeads: 10, + IntermediateSize: 17920, + MaxPositionEmbeddings: 16384, + RopeTheta: 10000, + EOSTokenID: 100265, + BOSTokenID: 100257, + PartialRotaryFactor: 0.5, + SlidingWindow: 2048, + }, *meta) +} + +func TestDeepSeekConfigParser(t *testing.T) { + tests := []struct { + name string + raw map[string]interface{} + want ModelMetadata + }{ + { + name: "deepseek v3 full", + raw: map[string]interface{}{ + "model_type": "deepseek_v3", + "vocab_size": float64(129280), + "hidden_size": float64(7168), + "num_hidden_layers": float64(61), + "num_attention_heads": float64(128), + "num_key_value_heads": float64(128), + "intermediate_size": float64(18432), + "max_position_embeddings": float64(163840), + "rope_theta": float64(10000), + "eos_token_id": float64(1), + "bos_token_id": float64(0), + "kv_lora_rank": float64(512), + "q_lora_rank": float64(1536), + "qk_rope_head_dim": float64(64), + "n_routed_experts": float64(256), + "num_experts_per_tok": float64(8), + "n_shared_experts": float64(1), + }, + want: ModelMetadata{ + Architecture: "deepseek_v3", + VocabSize: 129280, + HiddenSize: 7168, + NumLayers: 61, + NumQueryHeads: 128, + NumKeyValueHeads: 128, + IntermediateSize: 18432, + MaxPositionEmbeddings: 163840, + RopeTheta: 10000, + EOSTokenID: 1, + BOSTokenID: 0, + KVLoRADim: 512, + QLoRADim: 1536, + QKRopeHeadDim: 64, + NumExperts: 256, + NumExpertsPerToken: 8, + NumSharedExperts: 1, + }, + }, + { + name: "deepseek minimal defaults rope_theta to 10000", + raw: map[string]interface{}{ + "model_type": "deepseek_v3", + "vocab_size": float64(129280), + "num_hidden_layers": float64(61), + "num_attention_heads": float64(128), + }, + want: ModelMetadata{ + Architecture: "deepseek_v3", + VocabSize: 129280, + NumLayers: 61, + NumQueryHeads: 128, + RopeTheta: 10000, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseDeepSeekConfig(tc.raw) + if err != nil { + t.Fatalf("parseDeepSeekConfig error: %v", err) + } + assertMetadataEqual(t, tc.want, *got) + }) + } +} + +func TestDeepSeekConfigParser_Fixture(t *testing.T) { + data, err := os.ReadFile("testdata/deepseek_v3_config.json") + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal fixture: %v", err) + } + + meta, err := parseDeepSeekConfig(raw) + if err != nil { + t.Fatalf("parseDeepSeekConfig error: %v", err) + } + + assertMetadataEqual(t, ModelMetadata{ + Architecture: "deepseek_v3", + VocabSize: 129280, + HiddenSize: 7168, + NumLayers: 61, + NumQueryHeads: 128, + NumKeyValueHeads: 128, + IntermediateSize: 18432, + MaxPositionEmbeddings: 163840, + RopeTheta: 10000, + EOSTokenID: 1, + BOSTokenID: 0, + KVLoRADim: 512, + QLoRADim: 1536, + QKRopeHeadDim: 64, + NumExperts: 256, + NumExpertsPerToken: 8, + NumSharedExperts: 1, + }, *meta) +} + +func TestDefaultArchConfigRegistry_PhiAndDeepSeekRegistered(t *testing.T) { + reg := DefaultArchConfigRegistry() + + for _, tc := range []struct { + modelType string + wantArch string + }{ + {"phi3", "phi3"}, + {"phi", "phi"}, + {"deepseek_v3", "deepseek_v3"}, + } { + t.Run(tc.modelType, func(t *testing.T) { + raw := map[string]interface{}{ + "model_type": tc.modelType, + "vocab_size": float64(32000), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + } + meta, err := reg.Parse(raw) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if meta.Architecture != tc.wantArch { + t.Errorf("Architecture = %q, want %q", meta.Architecture, tc.wantArch) + } + }) + } +} + +func TestDefaultArchConfigRegistry_MistralAndQwenRegistered(t *testing.T) { + reg := DefaultArchConfigRegistry() + + for _, tc := range []struct { + modelType string + wantArch string + }{ + {"mistral", "mistral"}, + {"qwen2", "qwen2"}, + } { + t.Run(tc.modelType, func(t *testing.T) { + raw := map[string]interface{}{ + "model_type": tc.modelType, + "vocab_size": float64(32000), + "num_hidden_layers": float64(32), + "num_attention_heads": float64(32), + } + meta, err := reg.Parse(raw) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if meta.Architecture != tc.wantArch { + t.Errorf("Architecture = %q, want %q", meta.Architecture, tc.wantArch) + } + }) + } +} + // assertMetadataEqual compares key fields of two ModelMetadata values. func assertMetadataEqual(t *testing.T, want, got ModelMetadata) { t.Helper() @@ -311,4 +935,25 @@ func assertMetadataEqual(t *testing.T, want, got ModelMetadata) { if got.AttentionBias != want.AttentionBias { t.Errorf("AttentionBias = %v, want %v", got.AttentionBias, want.AttentionBias) } + if got.PartialRotaryFactor != want.PartialRotaryFactor { + t.Errorf("PartialRotaryFactor = %f, want %f", got.PartialRotaryFactor, want.PartialRotaryFactor) + } + if got.KVLoRADim != want.KVLoRADim { + t.Errorf("KVLoRADim = %d, want %d", got.KVLoRADim, want.KVLoRADim) + } + if got.QLoRADim != want.QLoRADim { + t.Errorf("QLoRADim = %d, want %d", got.QLoRADim, want.QLoRADim) + } + if got.QKRopeHeadDim != want.QKRopeHeadDim { + t.Errorf("QKRopeHeadDim = %d, want %d", got.QKRopeHeadDim, want.QKRopeHeadDim) + } + if got.NumExperts != want.NumExperts { + t.Errorf("NumExperts = %d, want %d", got.NumExperts, want.NumExperts) + } + if got.NumExpertsPerToken != want.NumExpertsPerToken { + t.Errorf("NumExpertsPerToken = %d, want %d", got.NumExpertsPerToken, want.NumExpertsPerToken) + } + if got.NumSharedExperts != want.NumSharedExperts { + t.Errorf("NumSharedExperts = %d, want %d", got.NumSharedExperts, want.NumSharedExperts) + } } diff --git a/inference/inference.go b/inference/inference.go index 45320777..513d639d 100644 --- a/inference/inference.go +++ b/inference/inference.go @@ -45,9 +45,18 @@ type ModelMetadata struct { NumKeyValueHeads int `json:"num_key_value_heads"` RopeTheta float64 `json:"rope_theta"` RopeScaling *RopeScalingConfig `json:"rope_scaling,omitempty"` - TieWordEmbeddings bool `json:"tie_word_embeddings"` - SlidingWindow int `json:"sliding_window"` - AttentionBias bool `json:"attention_bias"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + SlidingWindow int `json:"sliding_window"` + AttentionBias bool `json:"attention_bias"` + PartialRotaryFactor float64 `json:"partial_rotary_factor"` + + // DeepSeek MLA and MoE fields. + KVLoRADim int `json:"kv_lora_rank"` + QLoRADim int `json:"q_lora_rank"` + QKRopeHeadDim int `json:"qk_rope_head_dim"` + NumExperts int `json:"num_experts"` + NumExpertsPerToken int `json:"num_experts_per_tok"` + NumSharedExperts int `json:"n_shared_experts"` } // Option configures model loading. @@ -139,7 +148,22 @@ func Load(modelID string, opts ...Option) (*Model, error) { zmfPath := filepath.Join(info.Path, "model.zmf") eng := compute.NewCPUEngine[float32](numeric.Float32Ops{}) - mdl, err := model.LoadModelFromZMF[float32](eng, numeric.Float32Ops{}, zmfPath) + globalAttrs := map[string]interface{}{} + if meta.RopeScaling != nil && meta.RopeScaling.Type == "yarn" { + globalAttrs["rope_scaling_type"] = meta.RopeScaling.Type + globalAttrs["rope_scaling_factor"] = meta.RopeScaling.Factor + globalAttrs["rope_scaling_orig_max_len"] = meta.RopeScaling.OriginalMaxPositionEmbeddings + } + if meta.PartialRotaryFactor > 0 && meta.PartialRotaryFactor < 1.0 { + globalAttrs["partial_rotary_factor"] = meta.PartialRotaryFactor + } + + var buildOpts []model.BuildOption + if len(globalAttrs) > 0 { + buildOpts = append(buildOpts, model.WithGlobalAttributes(globalAttrs)) + } + + mdl, err := model.LoadModelFromZMF[float32](eng, numeric.Float32Ops{}, zmfPath, buildOpts...) if err != nil { return nil, fmt.Errorf("load model: %w", err) } @@ -178,17 +202,31 @@ func assembleModel( } } -// loadMetadata reads and parses config.json. +// loadMetadata reads config.json and dispatches to the appropriate +// architecture-specific parser via the default config registry. func loadMetadata(path string) (*ModelMetadata, error) { data, err := os.ReadFile(filepath.Clean(path)) if err != nil { return nil, err } - var meta ModelMetadata - if err := json.Unmarshal(data, &meta); err != nil { + + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { return nil, err } - return &meta, nil + + reg := DefaultArchConfigRegistry() + meta, err := reg.Parse(raw) + if err != nil { + return nil, err + } + + // Overlay fields that use our internal JSON tags (not HuggingFace names). + if ct, ok := raw["chat_template"].(string); ok { + meta.ChatTemplate = ct + } + + return meta, nil } // GenerateOption configures a generation call. diff --git a/inference/inference_test.go b/inference/inference_test.go index 5dc377b7..1e32ef66 100644 --- a/inference/inference_test.go +++ b/inference/inference_test.go @@ -122,24 +122,22 @@ func buildTestModel(t *testing.T, vocabSize int, tokenSequence []int) *Model { // --- loadMetadata tests --- func TestLoadMetadata(t *testing.T) { - t.Run("valid config", func(t *testing.T) { + t.Run("gemma config with model_type", func(t *testing.T) { dir := t.TempDir() - cfg := ModelMetadata{ - Architecture: "gemma3", - VocabSize: 256000, - HiddenSize: 2048, - NumLayers: 26, - MaxPositionEmbeddings: 8192, - EOSTokenID: 1, - BOSTokenID: 2, - ChatTemplate: "gemma", - } - data, err := json.Marshal(cfg) - if err != nil { - t.Fatal(err) - } + raw := `{ + "model_type": "gemma3", + "vocab_size": 262144, + "hidden_size": 2048, + "num_hidden_layers": 26, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "max_position_embeddings": 8192, + "eos_token_id": 1, + "bos_token_id": 2, + "chat_template": "gemma" + }` path := filepath.Join(dir, "config.json") - if err := os.WriteFile(path, data, 0o600); err != nil { + if err := os.WriteFile(path, []byte(raw), 0o600); err != nil { t.Fatal(err) } @@ -147,17 +145,73 @@ func TestLoadMetadata(t *testing.T) { if err != nil { t.Fatalf("loadMetadata error: %v", err) } - if got.VocabSize != 256000 { - t.Errorf("VocabSize = %d, want 256000", got.VocabSize) + if got.VocabSize != 262144 { + t.Errorf("VocabSize = %d, want 262144", got.VocabSize) } if got.Architecture != "gemma3" { t.Errorf("Architecture = %q, want %q", got.Architecture, "gemma3") } + if got.NumLayers != 26 { + t.Errorf("NumLayers = %d, want 26", got.NumLayers) + } if got.ChatTemplate != "gemma" { t.Errorf("ChatTemplate = %q, want %q", got.ChatTemplate, "gemma") } }) + t.Run("llama config from fixture", func(t *testing.T) { + got, err := loadMetadata("testdata/llama3_config.json") + if err != nil { + t.Fatalf("loadMetadata error: %v", err) + } + if got.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", got.Architecture, "llama") + } + if got.VocabSize != 128256 { + t.Errorf("VocabSize = %d, want 128256", got.VocabSize) + } + if got.HiddenSize != 4096 { + t.Errorf("HiddenSize = %d, want 4096", got.HiddenSize) + } + if got.NumLayers != 32 { + t.Errorf("NumLayers = %d, want 32", got.NumLayers) + } + if got.NumQueryHeads != 32 { + t.Errorf("NumQueryHeads = %d, want 32", got.NumQueryHeads) + } + if got.NumKeyValueHeads != 8 { + t.Errorf("NumKeyValueHeads = %d, want 8", got.NumKeyValueHeads) + } + if got.RopeTheta != 500000 { + t.Errorf("RopeTheta = %f, want 500000", got.RopeTheta) + } + }) + + t.Run("unknown model_type uses fallback", func(t *testing.T) { + dir := t.TempDir() + raw := `{ + "model_type": "future_model", + "vocab_size": 50000, + "num_hidden_layers": 12, + "hidden_size": 768 + }` + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte(raw), 0o600); err != nil { + t.Fatal(err) + } + + got, err := loadMetadata(path) + if err != nil { + t.Fatalf("loadMetadata error: %v", err) + } + if got.Architecture != "future_model" { + t.Errorf("Architecture = %q, want %q", got.Architecture, "future_model") + } + if got.VocabSize != 50000 { + t.Errorf("VocabSize = %d, want 50000", got.VocabSize) + } + }) + t.Run("file not found", func(t *testing.T) { _, err := loadMetadata("/nonexistent/config.json") if err == nil { diff --git a/inference/testdata/deepseek_v3_config.json b/inference/testdata/deepseek_v3_config.json new file mode 100644 index 00000000..cd2f05d1 --- /dev/null +++ b/inference/testdata/deepseek_v3_config.json @@ -0,0 +1,25 @@ +{ + "architectures": ["DeepseekV3ForCausalLM"], + "bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 18432, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "num_attention_heads": 128, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "kv_lora_rank": 512, + "q_lora_rank": 1536, + "qk_rope_head_dim": 64, + "n_routed_experts": 256, + "num_experts_per_tok": 8, + "n_shared_experts": 1, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "vocab_size": 129280 +} diff --git a/inference/testdata/llama3_config.json b/inference/testdata/llama3_config.json new file mode 100644 index 00000000..a0bfdba8 --- /dev/null +++ b/inference/testdata/llama3_config.json @@ -0,0 +1,33 @@ +{ + "architectures": ["LlamaForCausalLM"], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + "type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/inference/testdata/mistral7b_config.json b/inference/testdata/mistral7b_config.json new file mode 100644 index 00000000..e67e6d95 --- /dev/null +++ b/inference/testdata/mistral7b_config.json @@ -0,0 +1,23 @@ +{ + "architectures": ["MistralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 4096, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0", + "use_cache": true, + "vocab_size": 32000 +} diff --git a/inference/testdata/phi4_config.json b/inference/testdata/phi4_config.json new file mode 100644 index 00000000..bffb9538 --- /dev/null +++ b/inference/testdata/phi4_config.json @@ -0,0 +1,25 @@ +{ + "architectures": ["PhiForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 100257, + "eos_token_id": 100265, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 16384, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 10, + "original_max_position_embeddings": 4096, + "partial_rotary_factor": 0.5, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 2048, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.46.0", + "use_cache": true, + "vocab_size": 100352 +} diff --git a/inference/testdata/qwen25_7b_config.json b/inference/testdata/qwen25_7b_config.json new file mode 100644 index 00000000..440d75c5 --- /dev/null +++ b/inference/testdata/qwen25_7b_config.json @@ -0,0 +1,30 @@ +{ + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064 +} diff --git a/layers/attention/grouped_query_attention_extended_test.go b/layers/attention/grouped_query_attention_extended_test.go index 2594c65b..ab25f2b7 100644 --- a/layers/attention/grouped_query_attention_extended_test.go +++ b/layers/attention/grouped_query_attention_extended_test.go @@ -611,6 +611,309 @@ func TestBuildGroupQueryAttention_OddHeadDim(t *testing.T) { } } +// makeNonZeroParam creates a parameter with non-zero data for testing. +func makeNonZeroParam(t *testing.T, name string, shape []int) *graph.Parameter[float32] { + t.Helper() + total := 1 + for _, d := range shape { + total *= d + } + data := make([]float32, total) + for i := range data { + data[i] = float32(i%7+1) * 0.01 + } + weight, err := tensor.New[float32](shape, data) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + p, err := graph.NewParameter(name, weight, tensor.New[float32]) + if err != nil { + t.Fatalf("NewParameter failed: %v", err) + } + return p +} + +func TestBuildGroupQueryAttention_WithBias(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + modelDim := 16 + numQ := 4 + numKV := 2 + headDim := modelDim / numQ + kvDim := headDim * numKV + + // Use non-zero weights so bias has observable effect + wq := makeNonZeroParam(t, "attn_wq", []int{modelDim, modelDim}) + wk := makeNonZeroParam(t, "attn_wk", []int{modelDim, kvDim}) + wv := makeNonZeroParam(t, "attn_wv", []int{modelDim, kvDim}) + wo := makeNonZeroParam(t, "attn_wo", []int{modelDim, modelDim}) + + // Create bias params with non-zero values + qBias := makeNonZeroParam(t, "attn_wq_bias", []int{modelDim}) + kBias := makeNonZeroParam(t, "attn_wk_bias", []int{kvDim}) + vBias := makeNonZeroParam(t, "attn_wv_bias", []int{kvDim}) + + attrs := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + } + + input, err := tensor.New[float32]([]int{1, 4, modelDim}, nil) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.01 + } + + // Build and run WITHOUT bias + paramsNoBias := map[string]*graph.Parameter[float32]{ + "attn_wq": wq, "attn_wk": wk, "attn_wv": wv, "attn_wo": wo, + } + nodeNoBias, err := BuildGroupQueryAttention[float32](engine, ops, "attn", paramsNoBias, attrs) + if err != nil { + t.Fatalf("BuildGroupQueryAttention without bias failed: %v", err) + } + outNoBias, err := nodeNoBias.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward without bias failed: %v", err) + } + + // Build and run WITH bias + paramsWithBias := map[string]*graph.Parameter[float32]{ + "attn_wq": wq, "attn_wk": wk, "attn_wv": wv, "attn_wo": wo, + "attn_wq_bias": qBias, "attn_wk_bias": kBias, "attn_wv_bias": vBias, + } + nodeWithBias, err := BuildGroupQueryAttention[float32](engine, ops, "attn", paramsWithBias, attrs) + if err != nil { + t.Fatalf("BuildGroupQueryAttention with bias failed: %v", err) + } + outWithBias, err := nodeWithBias.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward with bias failed: %v", err) + } + + // Outputs should differ + biasData := outWithBias.Data() + noBiasData := outNoBias.Data() + allSame := true + for i := range biasData { + if biasData[i] != noBiasData[i] { + allSame = false + break + } + } + if allSame { + t.Error("output with bias should differ from output without bias") + } +} + +func TestBuildGroupQueryAttention_WithPartialBias(t *testing.T) { + // Only Q bias present — should still work (K, V get nil bias) + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + modelDim := 16 + numQ := 4 + numKV := 2 + headDim := modelDim / numQ + kvDim := headDim * numKV + + params := map[string]*graph.Parameter[float32]{ + "attn_wq": makeParam(t, "attn_wq", []int{modelDim, modelDim}), + "attn_wk": makeParam(t, "attn_wk", []int{modelDim, kvDim}), + "attn_wv": makeParam(t, "attn_wv", []int{modelDim, kvDim}), + "attn_wo": makeParam(t, "attn_wo", []int{modelDim, modelDim}), + "attn_wq_bias": makeParam(t, "attn_wq_bias", []int{modelDim}), + } + + attrs := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + } + + node, err := BuildGroupQueryAttention[float32](engine, ops, "attn", params, attrs) + if err != nil { + t.Fatalf("BuildGroupQueryAttention with partial bias failed: %v", err) + } + + input, err := tensor.New[float32]([]int{1, 4, modelDim}, nil) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.01 + } + + _, err = node.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward with partial bias failed: %v", err) + } +} + +func TestBuildGroupQueryAttention_WithYaRNScaling(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + modelDim := 16 + numQ := 4 + numKV := 2 + headDim := modelDim / numQ + kvDim := headDim * numKV + + wq := makeNonZeroParam(t, "attn_wq", []int{modelDim, modelDim}) + wk := makeNonZeroParam(t, "attn_wk", []int{modelDim, kvDim}) + wv := makeNonZeroParam(t, "attn_wv", []int{modelDim, kvDim}) + wo := makeNonZeroParam(t, "attn_wo", []int{modelDim, modelDim}) + + params := map[string]*graph.Parameter[float32]{ + "attn_wq": wq, "attn_wk": wk, "attn_wv": wv, "attn_wo": wo, + } + + input, err := tensor.New[float32]([]int{1, 4, modelDim}, nil) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.01 + } + + // Build WITHOUT YaRN + attrsNoYaRN := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + } + nodeNoYaRN, err := BuildGroupQueryAttention[float32](engine, ops, "attn", params, attrsNoYaRN) + if err != nil { + t.Fatalf("Build without YaRN failed: %v", err) + } + outNoYaRN, err := nodeNoYaRN.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward without YaRN failed: %v", err) + } + + // Build WITH YaRN scaling + attrsYaRN := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + "rope_scaling_type": "yarn", + "rope_scaling_factor": 4.0, + "rope_scaling_orig_max_len": 32, + } + nodeYaRN, err := BuildGroupQueryAttention[float32](engine, ops, "attn", params, attrsYaRN) + if err != nil { + t.Fatalf("Build with YaRN failed: %v", err) + } + outYaRN, err := nodeYaRN.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward with YaRN failed: %v", err) + } + + // Outputs should differ + noYaRNData := outNoYaRN.Data() + yaRNData := outYaRN.Data() + allSame := true + for i := range noYaRNData { + if noYaRNData[i] != yaRNData[i] { + allSame = false + break + } + } + if allSame { + t.Error("output with YaRN should differ from output without YaRN") + } +} + +func TestBuildGroupQueryAttention_WithPartialRoPE(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + modelDim := 16 + numQ := 4 + numKV := 2 + headDim := modelDim / numQ + kvDim := headDim * numKV + + wq := makeNonZeroParam(t, "attn_wq", []int{modelDim, modelDim}) + wk := makeNonZeroParam(t, "attn_wk", []int{modelDim, kvDim}) + wv := makeNonZeroParam(t, "attn_wv", []int{modelDim, kvDim}) + wo := makeNonZeroParam(t, "attn_wo", []int{modelDim, modelDim}) + + params := map[string]*graph.Parameter[float32]{ + "attn_wq": wq, "attn_wk": wk, "attn_wv": wv, "attn_wo": wo, + } + + input, err := tensor.New[float32]([]int{1, 4, modelDim}, nil) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.01 + } + + // Build WITHOUT partial RoPE + attrsNormal := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + } + nodeNormal, err := BuildGroupQueryAttention[float32](engine, ops, "attn", params, attrsNormal) + if err != nil { + t.Fatalf("Build without partial RoPE failed: %v", err) + } + outNormal, err := nodeNormal.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward without partial RoPE failed: %v", err) + } + + // Build WITH partial RoPE (fraction=0.5) + attrsPartial := map[string]interface{}{ + "model_dim": modelDim, + "num_query_heads": numQ, + "num_key_value_heads": numKV, + "rope_base": 10000.0, + "max_seq_len": 64, + "partial_rotary_factor": 0.5, + } + nodePartial, err := BuildGroupQueryAttention[float32](engine, ops, "attn", params, attrsPartial) + if err != nil { + t.Fatalf("Build with partial RoPE failed: %v", err) + } + outPartial, err := nodePartial.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward with partial RoPE failed: %v", err) + } + + // Outputs should differ + normalData := outNormal.Data() + partialData := outPartial.Data() + allSame := true + for i := range normalData { + if normalData[i] != partialData[i] { + allSame = false + break + } + } + if allSame { + t.Error("output with partial RoPE should differ from output without") + } +} + func TestBuildGroupQueryAttention_MissingParams(t *testing.T) { ops := numeric.Float32Ops{} engine := compute.NewCPUEngine[float32](ops) diff --git a/layers/attention/mla_registry.go b/layers/attention/mla_registry.go new file mode 100644 index 00000000..cbeef582 --- /dev/null +++ b/layers/attention/mla_registry.go @@ -0,0 +1,91 @@ +package attention + +import ( + "context" + "errors" + "fmt" + + "github.com/zerfoo/zerfoo/compute" + "github.com/zerfoo/zerfoo/graph" + "github.com/zerfoo/zerfoo/layers/core" + "github.com/zerfoo/zerfoo/layers/embeddings" + "github.com/zerfoo/zerfoo/numeric" + "github.com/zerfoo/zerfoo/tensor" +) + +// BuildMultiHeadLatentAttention constructs a MultiHeadLatentAttention node +// for the model builder. It reads kv_lora_dim, num_heads, head_dim, and +// max_seq_len from attributes, and loads W_Q, W_DKV, W_UK, W_UV, W_O +// from node parameters. +func BuildMultiHeadLatentAttention[T tensor.Numeric]( + engine compute.Engine[T], + ops numeric.Arithmetic[T], + name string, + params map[string]*graph.Parameter[T], + attributes map[string]any, +) (graph.Node[T], error) { + numHeads, ok := attributes["num_heads"].(int) + if !ok { + return nil, errors.New("missing or invalid attribute: num_heads") + } + + headDim, ok := attributes["head_dim"].(int) + if !ok { + return nil, errors.New("missing or invalid attribute: head_dim") + } + + kvLoraDim, ok := attributes["kv_lora_dim"].(int) + if !ok { + return nil, errors.New("missing or invalid attribute: kv_lora_dim") + } + + maxSeqLen, ok := attributes["max_seq_len"].(int) + if !ok { + return nil, errors.New("missing or invalid attribute: max_seq_len") + } + + // Load projection parameters. + wqParam, ok := params[name+"_wq"] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s_wq", name) + } + wdkvParam, ok := params[name+"_wdkv"] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s_wdkv", name) + } + wukParam, ok := params[name+"_wuk"] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s_wuk", name) + } + wuvParam, ok := params[name+"_wuv"] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s_wuv", name) + } + woParam, ok := params[name+"_wo"] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s_wo", name) + } + + // Construct Dense layers from parameters (no bias for MLA projections). + wQ := core.NewDenseFromParams(core.NewLinearFromParam(engine, wqParam), nil) + wDKV := core.NewDenseFromParams(core.NewLinearFromParam(engine, wdkvParam), nil) + wUK := core.NewDenseFromParams(core.NewLinearFromParam(engine, wukParam), nil) + wUV := core.NewDenseFromParams(core.NewLinearFromParam(engine, wuvParam), nil) + wO := core.NewDenseFromParams(core.NewLinearFromParam(engine, woParam), nil) + + // Build RoPE. + base := 10000.0 + if b, ok := attributes["rope_base"].(float64); ok { + base = b + } + + rope, err := embeddings.NewRotaryPositionalEmbedding[T]( + context.Background(), engine, headDim, maxSeqLen, + embeddings.WithRotaryBase(base), + ) + if err != nil { + return nil, fmt.Errorf("failed to create RotaryPositionalEmbedding: %w", err) + } + + return NewMultiHeadLatentAttention(engine, ops, numHeads, headDim, kvLoraDim, wQ, wDKV, wUK, wUV, wO, rope), nil +} diff --git a/layers/attention/mla_registry_test.go b/layers/attention/mla_registry_test.go new file mode 100644 index 00000000..c1f190bb --- /dev/null +++ b/layers/attention/mla_registry_test.go @@ -0,0 +1,180 @@ +package attention + +import ( + "context" + "testing" + + "github.com/zerfoo/zerfoo/compute" + "github.com/zerfoo/zerfoo/graph" + "github.com/zerfoo/zerfoo/numeric" + "github.com/zerfoo/zerfoo/tensor" + "github.com/zerfoo/zerfoo/testing/testutils" +) + +func TestBuildMultiHeadLatentAttention(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine(ops) + + numHeads := 2 + headDim := 4 + kvLoraDim := 3 + hiddenDim := 8 + qkDim := numHeads * headDim + + attrs := map[string]any{ + "num_heads": numHeads, + "head_dim": headDim, + "kv_lora_dim": kvLoraDim, + "max_seq_len": 16, + } + + params := map[string]*graph.Parameter[float32]{ + "mla_wq": makeNonZeroParam(t, "mla_wq", []int{hiddenDim, qkDim}), + "mla_wdkv": makeNonZeroParam(t, "mla_wdkv", []int{hiddenDim, kvLoraDim}), + "mla_wuk": makeNonZeroParam(t, "mla_wuk", []int{kvLoraDim, qkDim}), + "mla_wuv": makeNonZeroParam(t, "mla_wuv", []int{kvLoraDim, qkDim}), + "mla_wo": makeNonZeroParam(t, "mla_wo", []int{qkDim, hiddenDim}), + } + + node, err := BuildMultiHeadLatentAttention(engine, ops, "mla", params, attrs) + if err != nil { + t.Fatalf("BuildMultiHeadLatentAttention failed: %v", err) + } + + if node.OpType() != "MultiHeadLatentAttention" { + t.Errorf("OpType() = %q, want %q", node.OpType(), "MultiHeadLatentAttention") + } + + // Run Forward to verify the built layer works. + input := makeNonZeroInput(t, []int{1, 3, hiddenDim}) + out, err := node.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + + expected := []int{1, 3, hiddenDim} + if !testutils.IntSliceEqual(expected, out.Shape()) { + t.Errorf("output shape = %v, want %v", out.Shape(), expected) + } +} + +func TestBuildMultiHeadLatentAttention_MissingAttributes(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine(ops) + + fullAttrs := map[string]any{ + "num_heads": 2, + "head_dim": 4, + "kv_lora_dim": 3, + "max_seq_len": 16, + } + + params := map[string]*graph.Parameter[float32]{ + "mla_wq": makeParam(t, "mla_wq", []int{8, 8}), + "mla_wdkv": makeParam(t, "mla_wdkv", []int{8, 3}), + "mla_wuk": makeParam(t, "mla_wuk", []int{3, 8}), + "mla_wuv": makeParam(t, "mla_wuv", []int{3, 8}), + "mla_wo": makeParam(t, "mla_wo", []int{8, 8}), + } + + requiredAttrs := []string{"num_heads", "head_dim", "kv_lora_dim", "max_seq_len"} + for _, key := range requiredAttrs { + t.Run("missing_"+key, func(t *testing.T) { + attrs := make(map[string]any) + for k, v := range fullAttrs { + if k != key { + attrs[k] = v + } + } + _, err := BuildMultiHeadLatentAttention(engine, ops, "mla", params, attrs) + if err == nil { + t.Errorf("expected error for missing attribute %q", key) + } + }) + } +} + +func TestBuildMultiHeadLatentAttention_MissingParams(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine(ops) + + attrs := map[string]any{ + "num_heads": 2, + "head_dim": 4, + "kv_lora_dim": 3, + "max_seq_len": 16, + } + + fullParams := map[string]*graph.Parameter[float32]{ + "mla_wq": makeParam(t, "mla_wq", []int{8, 8}), + "mla_wdkv": makeParam(t, "mla_wdkv", []int{8, 3}), + "mla_wuk": makeParam(t, "mla_wuk", []int{3, 8}), + "mla_wuv": makeParam(t, "mla_wuv", []int{3, 8}), + "mla_wo": makeParam(t, "mla_wo", []int{8, 8}), + } + + requiredParams := []string{"mla_wq", "mla_wdkv", "mla_wuk", "mla_wuv", "mla_wo"} + for _, key := range requiredParams { + t.Run("missing_"+key, func(t *testing.T) { + params := make(map[string]*graph.Parameter[float32]) + for k, v := range fullParams { + if k != key { + params[k] = v + } + } + _, err := BuildMultiHeadLatentAttention(engine, ops, "mla", params, attrs) + if err == nil { + t.Errorf("expected error for missing parameter %q", key) + } + }) + } +} + +func TestBuildMultiHeadLatentAttention_CustomRopeBase(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine(ops) + + attrs := map[string]any{ + "num_heads": 2, + "head_dim": 4, + "kv_lora_dim": 3, + "max_seq_len": 16, + "rope_base": 50000.0, + } + + params := map[string]*graph.Parameter[float32]{ + "mla_wq": makeNonZeroParam(t, "mla_wq", []int{8, 8}), + "mla_wdkv": makeNonZeroParam(t, "mla_wdkv", []int{8, 3}), + "mla_wuk": makeNonZeroParam(t, "mla_wuk", []int{3, 8}), + "mla_wuv": makeNonZeroParam(t, "mla_wuv", []int{3, 8}), + "mla_wo": makeNonZeroParam(t, "mla_wo", []int{8, 8}), + } + + node, err := BuildMultiHeadLatentAttention(engine, ops, "mla", params, attrs) + if err != nil { + t.Fatalf("BuildMultiHeadLatentAttention with custom rope_base failed: %v", err) + } + + input := makeNonZeroInput(t, []int{1, 3, 8}) + _, err = node.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } +} + +func makeNonZeroInput(t *testing.T, shape []int) *tensor.TensorNumeric[float32] { + t.Helper() + total := 1 + for _, d := range shape { + total *= d + } + data := make([]float32, total) + for i := range data { + data[i] = float32(i%7+1) * 0.01 + } + out, err := tensor.New[float32](shape, data) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + return out +} diff --git a/layers/attention/multi_head_latent_attention.go b/layers/attention/multi_head_latent_attention.go new file mode 100644 index 00000000..bdbdbe58 --- /dev/null +++ b/layers/attention/multi_head_latent_attention.go @@ -0,0 +1,231 @@ +package attention + +import ( + "context" + "fmt" + + "github.com/zerfoo/zerfoo/compute" + "github.com/zerfoo/zerfoo/graph" + "github.com/zerfoo/zerfoo/layers/core" + "github.com/zerfoo/zerfoo/layers/embeddings" + "github.com/zerfoo/zerfoo/numeric" + "github.com/zerfoo/zerfoo/tensor" + "github.com/zerfoo/zerfoo/types" +) + +// MultiHeadLatentAttention implements Multi-head Latent Attention (MLA) +// as used in DeepSeek V3/R1. MLA compresses KV into a low-rank latent +// vector, dramatically reducing KV cache size. +type MultiHeadLatentAttention[T tensor.Numeric] struct { + engine compute.Engine[T] + ops numeric.Arithmetic[T] + + // Dimensions + numHeads int + headDim int + kvLoraDim int + + // Projections + wQ *core.Dense[T] // query projection [hidden, numHeads*headDim] + wDKV *core.Dense[T] // down-project KV [hidden, kvLoraDim] + wUK *core.Dense[T] // up-project keys [kvLoraDim, numHeads*headDim] + wUV *core.Dense[T] // up-project values [kvLoraDim, numHeads*headDim] + wO *core.Dense[T] // output projection [numHeads*headDim, hidden] + + rope *embeddings.RotaryPositionalEmbedding[T] + sdpa *ScaledDotProductAttention[T] + + outputShape []int +} + +// NewMultiHeadLatentAttention creates a new MLA layer. +func NewMultiHeadLatentAttention[T tensor.Numeric]( + engine compute.Engine[T], + ops numeric.Arithmetic[T], + numHeads, headDim, kvLoraDim int, + wQ, wDKV, wUK, wUV, wO *core.Dense[T], + rope *embeddings.RotaryPositionalEmbedding[T], +) *MultiHeadLatentAttention[T] { + return &MultiHeadLatentAttention[T]{ + engine: engine, + ops: ops, + numHeads: numHeads, + headDim: headDim, + kvLoraDim: kvLoraDim, + wQ: wQ, + wDKV: wDKV, + wUK: wUK, + wUV: wUV, + wO: wO, + rope: rope, + sdpa: NewScaledDotProductAttention(engine, headDim), + } +} + +// Forward computes the MLA forward pass. +// Input: [batch, seqLen, hidden] +// Output: [batch, seqLen, hidden] +func (m *MultiHeadLatentAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("MLA expects 1 input, got %d", len(inputs)) + } + input := inputs[0] + shape := input.Shape() + batch := shape[0] + seqLen := shape[1] + + // 1. Query projection: [batch, seqLen, hidden] -> [batch, seqLen, numHeads*headDim] + qProj, err := m.wQ.Forward(ctx, input) + if err != nil { + return nil, fmt.Errorf("query projection: %w", err) + } + + // 2. KV down-projection: [batch, seqLen, hidden] -> [batch, seqLen, kvLoraDim] + cKV, err := m.wDKV.Forward(ctx, input) + if err != nil { + return nil, fmt.Errorf("kv down-projection: %w", err) + } + + // 3. Up-project keys: [batch, seqLen, kvLoraDim] -> [batch, seqLen, numHeads*headDim] + kProj, err := m.wUK.Forward(ctx, cKV) + if err != nil { + return nil, fmt.Errorf("key up-projection: %w", err) + } + + // 4. Up-project values: [batch, seqLen, kvLoraDim] -> [batch, seqLen, numHeads*headDim] + vProj, err := m.wUV.Forward(ctx, cKV) + if err != nil { + return nil, fmt.Errorf("value up-projection: %w", err) + } + + // 5. Reshape Q/K/V to multi-head format: [batch, seqLen, numHeads, headDim] + qHeads, err := m.engine.Reshape(ctx, qProj, []int{batch, seqLen, m.numHeads, m.headDim}) + if err != nil { + return nil, err + } + kHeads, err := m.engine.Reshape(ctx, kProj, []int{batch, seqLen, m.numHeads, m.headDim}) + if err != nil { + return nil, err + } + vHeads, err := m.engine.Reshape(ctx, vProj, []int{batch, seqLen, m.numHeads, m.headDim}) + if err != nil { + return nil, err + } + + // 6. Transpose to [batch, numHeads, seqLen, headDim] + qHeads, err = m.engine.Transpose(ctx, qHeads, []int{0, 2, 1, 3}) + if err != nil { + return nil, err + } + kHeads, err = m.engine.Transpose(ctx, kHeads, []int{0, 2, 1, 3}) + if err != nil { + return nil, err + } + vHeads, err = m.engine.Transpose(ctx, vHeads, []int{0, 2, 1, 3}) + if err != nil { + return nil, err + } + + // 7. Apply RoPE to Q and K per-head. + // Reshape to [batch*numHeads, seqLen, headDim] for RoPE. + qFlat, err := m.engine.Reshape(ctx, qHeads, []int{batch * m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + kFlat, err := m.engine.Reshape(ctx, kHeads, []int{batch * m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + + qRoPE, err := m.rope.Forward(ctx, qFlat) + if err != nil { + return nil, fmt.Errorf("rope Q: %w", err) + } + kRoPE, err := m.rope.Forward(ctx, kFlat) + if err != nil { + return nil, fmt.Errorf("rope K: %w", err) + } + + // Reshape back to [batch, numHeads, seqLen, headDim] + qHeads, err = m.engine.Reshape(ctx, qRoPE, []int{batch, m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + kHeads, err = m.engine.Reshape(ctx, kRoPE, []int{batch, m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + + // 8. Flatten heads into batch for SDPA: [batch, numHeads, seqLen, headDim] -> [batch*numHeads, seqLen, headDim] + qFlat2, err := m.engine.Reshape(ctx, qHeads, []int{batch * m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + kFlat2, err := m.engine.Reshape(ctx, kHeads, []int{batch * m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + vFlat2, err := m.engine.Reshape(ctx, vHeads, []int{batch * m.numHeads, seqLen, m.headDim}) + if err != nil { + return nil, err + } + + // 9. Scaled dot-product attention: [batch*numHeads, seqLen, headDim] + attnOut, err := m.sdpa.Forward(ctx, qFlat2, kFlat2, vFlat2, nil) + if err != nil { + return nil, fmt.Errorf("sdpa: %w", err) + } + + // 10. Reshape back: [batch*numHeads, seqLen, headDim] -> [batch, seqLen, numHeads*headDim] + attnOut, err = m.engine.Reshape(ctx, attnOut, []int{batch, seqLen, m.numHeads * m.headDim}) + if err != nil { + return nil, err + } + + // 11. Output projection: [batch, seqLen, numHeads*headDim] -> [batch, seqLen, hidden] + output, err := m.wO.Forward(ctx, attnOut) + if err != nil { + return nil, fmt.Errorf("output projection: %w", err) + } + + m.outputShape = output.Shape() + return output, nil +} + +// Backward computes gradients for MLA (not yet implemented). +func (m *MultiHeadLatentAttention[T]) Backward(_ context.Context, _ types.BackwardMode, _ *tensor.TensorNumeric[T], _ ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error) { + return nil, fmt.Errorf("MLA backward not implemented") +} + +// Parameters returns all trainable parameters. +func (m *MultiHeadLatentAttention[T]) Parameters() []*graph.Parameter[T] { + var params []*graph.Parameter[T] + params = append(params, m.wQ.Parameters()...) + params = append(params, m.wDKV.Parameters()...) + params = append(params, m.wUK.Parameters()...) + params = append(params, m.wUV.Parameters()...) + params = append(params, m.wO.Parameters()...) + return params +} + +// OutputShape returns the output shape. +func (m *MultiHeadLatentAttention[T]) OutputShape() []int { + return m.outputShape +} + +// OpType returns the layer operation type. +func (m *MultiHeadLatentAttention[T]) OpType() string { + return "MultiHeadLatentAttention" +} + +// Attributes returns the layer attributes. +func (m *MultiHeadLatentAttention[T]) Attributes() map[string]any { + return map[string]any{ + "num_heads": m.numHeads, + "head_dim": m.headDim, + "kv_lora_dim": m.kvLoraDim, + } +} + +// Compile-time interface check. +var _ graph.Node[float32] = (*MultiHeadLatentAttention[float32])(nil) diff --git a/layers/attention/multi_head_latent_attention_test.go b/layers/attention/multi_head_latent_attention_test.go new file mode 100644 index 00000000..8983ad5c --- /dev/null +++ b/layers/attention/multi_head_latent_attention_test.go @@ -0,0 +1,199 @@ +package attention + +import ( + "context" + "testing" + + "github.com/zerfoo/zerfoo/compute" + "github.com/zerfoo/zerfoo/layers/core" + "github.com/zerfoo/zerfoo/layers/embeddings" + "github.com/zerfoo/zerfoo/numeric" + "github.com/zerfoo/zerfoo/tensor" + "github.com/zerfoo/zerfoo/testing/testutils" +) + +// newTestMLA creates a MultiHeadLatentAttention layer with small dimensions +// suitable for unit testing. +func newTestMLA(t *testing.T) *MultiHeadLatentAttention[float32] { + t.Helper() + ctx := context.Background() + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine(ops) + + numHeads := 2 + headDim := 4 + kvLoraDim := 3 + hiddenDim := 8 + + qkDim := numHeads * headDim // 8 + + wQ, err := core.NewDense[float32]("wQ", engine, ops, hiddenDim, qkDim, core.WithoutBias[float32]()) + if err != nil { + t.Fatalf("wQ: %v", err) + } + wDKV, err := core.NewDense[float32]("wDKV", engine, ops, hiddenDim, kvLoraDim, core.WithoutBias[float32]()) + if err != nil { + t.Fatalf("wDKV: %v", err) + } + wUK, err := core.NewDense[float32]("wUK", engine, ops, kvLoraDim, qkDim, core.WithoutBias[float32]()) + if err != nil { + t.Fatalf("wUK: %v", err) + } + wUV, err := core.NewDense[float32]("wUV", engine, ops, kvLoraDim, qkDim, core.WithoutBias[float32]()) + if err != nil { + t.Fatalf("wUV: %v", err) + } + wO, err := core.NewDense[float32]("wO", engine, ops, qkDim, hiddenDim, core.WithoutBias[float32]()) + if err != nil { + t.Fatalf("wO: %v", err) + } + + rope, err := embeddings.NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, 16) + if err != nil { + t.Fatalf("rope: %v", err) + } + + return NewMultiHeadLatentAttention(engine, ops, numHeads, headDim, kvLoraDim, wQ, wDKV, wUK, wUV, wO, rope) +} + +func TestMultiHeadLatentAttention_Forward_Shape(t *testing.T) { + mla := newTestMLA(t) + + batch := 1 + seqLen := 3 + hiddenDim := 8 + + input, err := tensor.New[float32]([]int{batch, seqLen, hiddenDim}, nil) + if err != nil { + t.Fatalf("tensor.New: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.01 + } + + out, err := mla.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward: %v", err) + } + + expected := []int{batch, seqLen, hiddenDim} + if !testutils.IntSliceEqual(expected, out.Shape()) { + t.Errorf("output shape = %v, want %v", out.Shape(), expected) + } +} + +func TestMultiHeadLatentAttention_Forward_Batch(t *testing.T) { + mla := newTestMLA(t) + + batch := 2 + seqLen := 4 + hiddenDim := 8 + + input, err := tensor.New[float32]([]int{batch, seqLen, hiddenDim}, nil) + if err != nil { + t.Fatalf("tensor.New: %v", err) + } + for i := range input.Data() { + input.Data()[i] = float32(i%11+1) * 0.02 + } + + out, err := mla.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward: %v", err) + } + + expected := []int{batch, seqLen, hiddenDim} + if !testutils.IntSliceEqual(expected, out.Shape()) { + t.Errorf("output shape = %v, want %v", out.Shape(), expected) + } + + // Verify no NaN in output. + for i, v := range out.Data() { + if v != v { + t.Fatalf("NaN at index %d", i) + } + } +} + +func TestMultiHeadLatentAttention_Forward_InputValidation(t *testing.T) { + mla := newTestMLA(t) + ctx := context.Background() + + input1, _ := tensor.New[float32]([]int{1, 2, 8}, nil) + input2, _ := tensor.New[float32]([]int{1, 2, 8}, nil) + + _, err := mla.Forward(ctx, input1, input2) + if err == nil { + t.Error("expected error for 2 inputs, got nil") + } + + _, err = mla.Forward(ctx) + if err == nil { + t.Error("expected error for 0 inputs, got nil") + } +} + +func TestMultiHeadLatentAttention_Parameters(t *testing.T) { + mla := newTestMLA(t) + params := mla.Parameters() + + // 5 Dense layers without bias = 5 weight parameters. + if len(params) != 5 { + t.Errorf("len(Parameters()) = %d, want 5", len(params)) + } +} + +func TestMultiHeadLatentAttention_OpType(t *testing.T) { + mla := newTestMLA(t) + if got := mla.OpType(); got != "MultiHeadLatentAttention" { + t.Errorf("OpType() = %q, want %q", got, "MultiHeadLatentAttention") + } +} + +func TestMultiHeadLatentAttention_Attributes(t *testing.T) { + mla := newTestMLA(t) + attrs := mla.Attributes() + + tests := []struct { + key string + want int + }{ + {"num_heads", 2}, + {"head_dim", 4}, + {"kv_lora_dim", 3}, + } + for _, tt := range tests { + v, ok := attrs[tt.key].(int) + if !ok { + t.Errorf("Attributes()[%q] missing or not int", tt.key) + continue + } + if v != tt.want { + t.Errorf("Attributes()[%q] = %d, want %d", tt.key, v, tt.want) + } + } +} + +func TestMultiHeadLatentAttention_OutputShape(t *testing.T) { + mla := newTestMLA(t) + + // Before Forward, OutputShape is nil. + if s := mla.OutputShape(); s != nil { + t.Errorf("OutputShape() before Forward = %v, want nil", s) + } + + input, _ := tensor.New[float32]([]int{1, 3, 8}, nil) + for i := range input.Data() { + input.Data()[i] = 0.1 + } + + _, err := mla.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Forward: %v", err) + } + + expected := []int{1, 3, 8} + if !testutils.IntSliceEqual(expected, mla.OutputShape()) { + t.Errorf("OutputShape() = %v, want %v", mla.OutputShape(), expected) + } +} diff --git a/layers/attention/registry.go b/layers/attention/registry.go index f9b5b55d..29e2045d 100644 --- a/layers/attention/registry.go +++ b/layers/attention/registry.go @@ -68,15 +68,28 @@ func buildGroupedQueryAttention[T tensor.Numeric]( return nil, fmt.Errorf("missing required parameter: %s_wo", name) } - // Construct sub-layers from parameters - wqDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wq), nil) // Assuming no bias for now - wkDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wk), nil) - wvDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wv), nil) - woDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wo), nil) + // Construct sub-layers from parameters, with optional bias + wqDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wq), optionalBias(engine, ops, params, name+"_wq_bias")) + wkDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wk), optionalBias(engine, ops, params, name+"_wk_bias")) + wvDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wv), optionalBias(engine, ops, params, name+"_wv_bias")) + woDense := core.NewDenseFromParams(core.NewLinearFromParam(engine, wo), optionalBias(engine, ops, params, name+"_wo_bias")) headDim := modelDim / numQueryHeads - rope, err := embeddings.NewRotaryPositionalEmbedding[T](context.Background(), engine, headDim, maxSeqLen, embeddings.WithRotaryBase(base)) + // Build RoPE options + ropeOpts := []embeddings.RotaryPositionalEmbeddingOption{embeddings.WithRotaryBase(base)} + if scalingType, _ := attributes["rope_scaling_type"].(string); scalingType == "yarn" { + factor, _ := attributes["rope_scaling_factor"].(float64) + origMaxLen, _ := attributes["rope_scaling_orig_max_len"].(int) + if factor > 0 && origMaxLen > 0 { + ropeOpts = append(ropeOpts, embeddings.WithYaRNScaling(factor, origMaxLen)) + } + } + if fraction, ok := attributes["partial_rotary_factor"].(float64); ok && fraction > 0 && fraction < 1.0 { + ropeOpts = append(ropeOpts, embeddings.WithRotaryDimFraction(fraction)) + } + + rope, err := embeddings.NewRotaryPositionalEmbedding[T](context.Background(), engine, headDim, maxSeqLen, ropeOpts...) if err != nil { return nil, fmt.Errorf("failed to create RotaryPositionalEmbedding: %w", err) } @@ -87,6 +100,20 @@ func buildGroupedQueryAttention[T tensor.Numeric]( ) } +// optionalBias returns a *core.Bias[T] if the named parameter exists, nil otherwise. +func optionalBias[T tensor.Numeric]( + engine compute.Engine[T], + ops numeric.Arithmetic[T], + params map[string]*graph.Parameter[T], + key string, +) *core.Bias[T] { + p, ok := params[key] + if !ok { + return nil + } + return core.NewBiasFromParam(engine, ops, p) +} + //nolint:unused // retained for completeness; may be used by future registry wiring func buildGlobalAttention[T tensor.Numeric]( engine compute.Engine[T], diff --git a/layers/core/lm_head.go b/layers/core/lm_head.go index 22b9f211..c3eab419 100644 --- a/layers/core/lm_head.go +++ b/layers/core/lm_head.go @@ -11,9 +11,12 @@ import ( ) // LMHead is a linear layer that maps hidden states to vocabulary logits. +// When tiedWeight is set, it uses the (transposed) tied weight instead of +// its own linear layer. type LMHead[T tensor.Numeric] struct { - linear *Linear[T] - engine compute.Engine[T] + linear *Linear[T] + engine compute.Engine[T] + tiedWeight *tensor.TensorNumeric[T] // [vocabSize, hiddenDim] from embedding } // NewLMHead creates a new LMHead. @@ -26,6 +29,13 @@ func NewLMHead[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmeti return &LMHead[T]{linear: linear, engine: engine}, nil } +// NewTiedLMHead creates an LMHead that reuses an existing embedding weight +// matrix (shape [vocabSize, hiddenDim]) instead of owning its own weight. +// Forward transposes the weight and performs the projection. +func NewTiedLMHead[T tensor.Numeric](engine compute.Engine[T], embedWeight *tensor.TensorNumeric[T]) *LMHead[T] { + return &LMHead[T]{engine: engine, tiedWeight: embedWeight} +} + // SetWeights sets the weights of the LMHead. This is useful for sharing weights // with a token embedding layer. func (h *LMHead[T]) SetWeights(weights *tensor.TensorNumeric[T]) { @@ -45,12 +55,27 @@ func (h *LMHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric return nil, err } - output, err := h.linear.Forward(ctx, reshapedInput) - if err != nil { - return nil, err - } + var output *tensor.TensorNumeric[T] + var vocabSize int - vocabSize := h.linear.OutputShape()[1] + if h.tiedWeight != nil { + // Tied weight is [vocabSize, hiddenDim]; transpose to [hiddenDim, vocabSize]. + transposed, err2 := h.engine.Transpose(ctx, h.tiedWeight, []int{1, 0}) + if err2 != nil { + return nil, err2 + } + output, err = h.engine.MatMul(ctx, reshapedInput, transposed) + if err != nil { + return nil, err + } + vocabSize = h.tiedWeight.Shape()[0] + } else { + output, err = h.linear.Forward(ctx, reshapedInput) + if err != nil { + return nil, err + } + vocabSize = h.linear.OutputShape()[1] + } output, err = h.engine.Reshape(ctx, output, []int{batchSize, seqLen, vocabSize}) if err != nil { @@ -66,11 +91,18 @@ func (h *LMHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut } // Parameters returns the parameters of the LMHead. +// Tied LMHead returns nil since the embedding layer owns the weight. func (h *LMHead[T]) Parameters() []*graph.Parameter[T] { + if h.tiedWeight != nil { + return nil + } return h.linear.Parameters() } // OutputShape returns the output shape of the LMHead. func (h *LMHead[T]) OutputShape() []int { + if h.tiedWeight != nil { + return nil // shape is determined dynamically from the tied weight + } return h.linear.OutputShape() } diff --git a/layers/core/lm_head_test.go b/layers/core/lm_head_test.go index 377e4c7c..70c5355b 100644 --- a/layers/core/lm_head_test.go +++ b/layers/core/lm_head_test.go @@ -10,6 +10,89 @@ import ( "github.com/zerfoo/zerfoo/testing/testutils" ) +func TestNewTiedLMHead(t *testing.T) { + ctx := context.Background() + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + hiddenDim := 8 + vocabSize := 4 + + // Simulate a token embedding weight [vocabSize, hiddenDim]. + embedData := make([]float32, vocabSize*hiddenDim) + for i := range embedData { + embedData[i] = float32(i%7+1) * 0.01 + } + embedWeight, err := tensor.New[float32]([]int{vocabSize, hiddenDim}, embedData) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + + lmHead := NewTiedLMHead[float32](engine, embedWeight) + + // Forward: input [1, 2, hiddenDim] -> output [1, 2, vocabSize] + inputData := make([]float32, 2*hiddenDim) + for i := range inputData { + inputData[i] = float32(i%5+1) * 0.1 + } + input, err := tensor.New[float32]([]int{1, 2, hiddenDim}, inputData) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + + output, err := lmHead.Forward(ctx, input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + + if output.Shape()[0] != 1 || output.Shape()[1] != 2 || output.Shape()[2] != vocabSize { + t.Errorf("output shape = %v, want [1, 2, %d]", output.Shape(), vocabSize) + } + + // Verify output matches manual matmul of input * embed^T. + // embedWeight is [vocabSize, hiddenDim], so we need input * embed^T. + transposed, err := engine.Transpose(ctx, embedWeight, []int{1, 0}) + if err != nil { + t.Fatalf("Transpose failed: %v", err) + } + // Reshape input to [2, hiddenDim] for matmul + reshapedInput, err := engine.Reshape(ctx, input, []int{2, hiddenDim}) + if err != nil { + t.Fatalf("Reshape failed: %v", err) + } + expected, err := engine.MatMul(ctx, reshapedInput, transposed) + if err != nil { + t.Fatalf("MatMul failed: %v", err) + } + expectedReshaped, err := engine.Reshape(ctx, expected, []int{1, 2, vocabSize}) + if err != nil { + t.Fatalf("Reshape failed: %v", err) + } + + for i, v := range output.Data() { + diff := v - expectedReshaped.Data()[i] + if diff > 1e-5 || diff < -1e-5 { + t.Errorf("output[%d] = %f, want %f", i, v, expectedReshaped.Data()[i]) + break + } + } +} + +func TestTiedLMHead_Parameters(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + embedWeight, _ := tensor.New[float32]([]int{4, 8}, nil) + lmHead := NewTiedLMHead[float32](engine, embedWeight) + + // Tied LMHead should have no trainable parameters of its own + // (the embedding owns the weight). + params := lmHead.Parameters() + if len(params) != 0 { + t.Errorf("TiedLMHead should have 0 parameters, got %d", len(params)) + } +} + func TestLMHead_Forward(t *testing.T) { ctx := context.Background() ops := numeric.Float32Ops{} diff --git a/layers/core/moe.go b/layers/core/moe.go index f3592cf9..5c69c98a 100644 --- a/layers/core/moe.go +++ b/layers/core/moe.go @@ -185,13 +185,14 @@ var _ graph.Node[float32] = (*MoEGate[float32])(nil) // Tech debt: ZMF sub-graph loading is not yet supported; experts are not // populated by BuildMixtureOfExperts. type MixtureOfExperts[T tensor.Numeric] struct { - engine compute.Engine[T] - ops numeric.Arithmetic[T] - gate *MoEGate[T] - experts []graph.Node[T] - numExperts int - topK int - outputShape []int + engine compute.Engine[T] + ops numeric.Arithmetic[T] + gate *MoEGate[T] + experts []graph.Node[T] + SharedExpert graph.Node[T] // optional: runs on every token, output added to routed sum + numExperts int + topK int + outputShape []int } // NewMixtureOfExperts creates a MixtureOfExperts layer. @@ -243,6 +244,18 @@ func (m *MixtureOfExperts[T]) Forward(ctx context.Context, inputs ...*tensor.Ten return nil, fmt.Errorf("MixtureOfExperts: create token tensor: %w", terr) } + // Shared expert: runs on every token. + if m.SharedExpert != nil { + sharedOut, serr := m.SharedExpert.Forward(ctx, token) + if serr != nil { + return nil, fmt.Errorf("MixtureOfExperts: shared expert forward: %w", serr) + } + sharedData := sharedOut.Data() + for d := 0; d < modelDim && d < len(sharedData); d++ { + outData[t*modelDim+d] = m.ops.Add(outData[t*modelDim+d], sharedData[d]) + } + } + for k := 0; k < m.topK; k++ { expertIdx := indices[t][k] if expertIdx >= len(m.experts) { @@ -277,11 +290,15 @@ func (m *MixtureOfExperts[T]) Backward(_ context.Context, _ types.BackwardMode, func (m *MixtureOfExperts[T]) OpType() string { return "MixtureOfExperts" } // Attributes returns the layer configuration. -func (m *MixtureOfExperts[T]) Attributes() map[string]interface{} { - return map[string]interface{}{ +func (m *MixtureOfExperts[T]) Attributes() map[string]any { + attrs := map[string]any{ "num_experts": m.numExperts, "top_k": m.topK, } + if m.SharedExpert != nil { + attrs["has_shared_expert"] = true + } + return attrs } // OutputShape returns the output shape from the last forward call. diff --git a/layers/core/moe_test.go b/layers/core/moe_test.go index e0193844..c6702390 100644 --- a/layers/core/moe_test.go +++ b/layers/core/moe_test.go @@ -405,3 +405,77 @@ func TestMixtureOfExperts_HiddenStates1D(t *testing.T) { t.Error("expected error for 1D hiddenStates") } } + +func TestMixtureOfExperts_WithSharedExpert(t *testing.T) { + eng := makeFloat32Engine() + ops := numeric.Float32Ops{} + gate := NewMoEGate[float32](eng, ops, 1) + + experts := []graph.Node[float32]{&identityExpert{}} + moe := NewMixtureOfExperts[float32](eng, ops, gate, experts, 1, 1) + moe.SharedExpert = &scale2Expert{} // shared expert multiplies by 2 + + // hiddenStates [1, 2]: single token [3, 4] + hs, _ := tensor.New[float32]([]int{1, 2}, []float32{3, 4}) + // gateWeight [1, 2]: single expert + gw, _ := tensor.New[float32]([]int{1, 2}, []float32{1, 0}) + + out, err := moe.Forward(context.Background(), hs, gw) + if err != nil { + t.Fatalf("Forward with shared expert failed: %v", err) + } + + data := out.Data() + // Expected: shared(token) + routed(token) + // shared = scale2([3,4]) = [6,8] + // routed = identity([3,4]) * 1.0 = [3,4] + // total = [9, 12] + if math.Abs(float64(data[0])-9.0) > 1e-4 || math.Abs(float64(data[1])-12.0) > 1e-4 { + t.Errorf("output = [%f, %f], want [9, 12]", data[0], data[1]) + } +} + +func TestMixtureOfExperts_WithoutSharedExpert_BackwardCompat(t *testing.T) { + // Verify that nil SharedExpert produces the same result as before. + eng := makeFloat32Engine() + ops := numeric.Float32Ops{} + gate := NewMoEGate[float32](eng, ops, 1) + + experts := []graph.Node[float32]{&identityExpert{}} + moe := NewMixtureOfExperts[float32](eng, ops, gate, experts, 1, 1) + // SharedExpert is nil by default. + + hs, _ := tensor.New[float32]([]int{1, 2}, []float32{3, 4}) + gw, _ := tensor.New[float32]([]int{1, 2}, []float32{1, 0}) + + out, err := moe.Forward(context.Background(), hs, gw) + if err != nil { + t.Fatalf("Forward without shared expert failed: %v", err) + } + + data := out.Data() + // identity expert, weight 1.0: output = [3, 4] + if math.Abs(float64(data[0])-3.0) > 1e-4 || math.Abs(float64(data[1])-4.0) > 1e-4 { + t.Errorf("output = [%f, %f], want [3, 4]", data[0], data[1]) + } +} + +func TestMixtureOfExperts_SharedExpert_Attributes(t *testing.T) { + eng := makeFloat32Engine() + ops := numeric.Float32Ops{} + gate := NewMoEGate[float32](eng, ops, 1) + moe := NewMixtureOfExperts[float32](eng, ops, gate, nil, 2, 1) + + // Without shared expert. + attrs := moe.Attributes() + if _, ok := attrs["has_shared_expert"]; ok { + t.Error("expected no has_shared_expert attribute when SharedExpert is nil") + } + + // With shared expert. + moe.SharedExpert = &identityExpert{} + attrs = moe.Attributes() + if v, ok := attrs["has_shared_expert"]; !ok || v != true { + t.Error("expected has_shared_expert=true when SharedExpert is set") + } +} diff --git a/layers/embeddings/rotary_positional_embedding.go b/layers/embeddings/rotary_positional_embedding.go index 115bceaf..8af9db5b 100644 --- a/layers/embeddings/rotary_positional_embedding.go +++ b/layers/embeddings/rotary_positional_embedding.go @@ -15,6 +15,7 @@ import ( type RotaryPositionalEmbedding[T tensor.Numeric] struct { engine compute.Engine[T] headDim int + rotaryDim int // number of dimensions that receive rotation (<= headDim) cosAngles *tensor.TensorNumeric[T] sinAngles *tensor.TensorNumeric[T] // Cached input for backward pass @@ -22,11 +23,17 @@ type RotaryPositionalEmbedding[T tensor.Numeric] struct { xRot0Slice *tensor.TensorNumeric[T] xRot1Slice *tensor.TensorNumeric[T] outputShape []int + // attnScaleFactor is the YaRN attention scaling factor (1.0 when no YaRN). + attnScaleFactor float64 } // RotaryPositionalEmbeddingOptions holds configuration options for RotaryPositionalEmbedding layers. type RotaryPositionalEmbeddingOptions struct { - Base float64 // Base for the inverse frequency calculation (theta parameter) + Base float64 // Base for the inverse frequency calculation (theta parameter) + YaRN bool // Whether to apply YaRN scaling + YaRNFactor float64 // YaRN scaling factor (e.g. 4.0 for 4x context extension) + YaRNOrigML int // Original max sequence length before scaling + RotaryDimFraction float64 // Fraction of head dims to rotate (default 1.0 = all) } // RotaryPositionalEmbeddingOption is a functional option for configuring RotaryPositionalEmbedding layers. @@ -39,6 +46,25 @@ func WithRotaryBase(base float64) RotaryPositionalEmbeddingOption { } } +// WithRotaryDimFraction sets the fraction of head dimensions that receive rotation. +// Default is 1.0 (all dimensions rotated). Phi-4 uses 0.75 for partial RoPE. +func WithRotaryDimFraction(fraction float64) RotaryPositionalEmbeddingOption { + return func(opts *RotaryPositionalEmbeddingOptions) { + opts.RotaryDimFraction = fraction + } +} + +// WithYaRNScaling enables YaRN (Yet another RoPE extensioN) scaling. +// factor is the context extension factor (e.g. 4.0 for 4x). +// origMaxLen is the original maximum sequence length before scaling. +func WithYaRNScaling(factor float64, origMaxLen int) RotaryPositionalEmbeddingOption { + return func(opts *RotaryPositionalEmbeddingOptions) { + opts.YaRN = true + opts.YaRNFactor = factor + opts.YaRNOrigML = origMaxLen + } +} + // NewRotaryPositionalEmbedding creates a new RotaryPositionalEmbedding layer. // headDim: The dimension of the head. Must be even. // seqLen: The maximum sequence length this embedding will be applied to. @@ -56,53 +82,84 @@ func NewRotaryPositionalEmbedding[T tensor.Numeric]( // Apply functional options opts := &RotaryPositionalEmbeddingOptions{ - Base: 10000.0, // Default base value (theta) + Base: 10000.0, // Default base value (theta) + RotaryDimFraction: 1.0, } for _, option := range options { option(opts) } + // Compute the number of dimensions that receive rotation. + rotaryDim := headDim + if opts.RotaryDimFraction > 0 && opts.RotaryDimFraction < 1.0 { + rotaryDim = int(float64(headDim) * opts.RotaryDimFraction) + // Ensure rotaryDim is even. + rotaryDim &^= 1 + } + // Create position indices: [0, 1, ..., seq_len-1] positions := make([]int, seqLen) for i := 0; i < seqLen; i++ { positions[i] = i } - // Create inverse frequencies: 1 / (base^(2i/head_dim)) + // Create inverse frequencies: 1 / (base^(2i/rotaryDim)) ops := engine.Ops() - invFreqs64 := make([]float64, headDim/2) - for i := 0; i < headDim/2; i++ { - invFreqs64[i] = 1.0 / math.Pow(opts.Base, float64(2*i)/float64(headDim)) + halfDim := rotaryDim / 2 + invFreqs64 := make([]float64, halfDim) + for i := 0; i < halfDim; i++ { + invFreqs64[i] = 1.0 / math.Pow(opts.Base, float64(2*i)/float64(rotaryDim)) + } + + // Apply YaRN scaling to inverse frequencies if enabled. + attnScaleFactor := 1.0 + if opts.YaRN { + attnScaleFactor = math.Sqrt(1 + math.Log(opts.YaRNFactor)/math.Log(float64(opts.YaRNOrigML))) + origML := float64(opts.YaRNOrigML) + for i := 0; i < halfDim; i++ { + wavelength := 2 * math.Pi / invFreqs64[i] + if wavelength > opts.YaRNFactor*origML { + // Low frequency: scale by 1/factor + invFreqs64[i] /= opts.YaRNFactor + } else if wavelength >= origML { + // Intermediate frequency: linearly interpolate + ratio := (wavelength - origML) / (opts.YaRNFactor*origML - origML) + invFreqs64[i] = invFreqs64[i] * (1 - ratio) + invFreqs64[i]/opts.YaRNFactor*ratio + } + // High frequency (wavelength < origMaxLen): keep unchanged + } } // Precompute cos and sin of angles using float64 and convert to T - size := seqLen * (headDim / 2) + size := seqLen * halfDim cosData := make([]T, size) sinData := make([]T, size) for i := 0; i < seqLen; i++ { - for j := 0; j < headDim/2; j++ { + for j := 0; j < halfDim; j++ { angle := float64(positions[i]) * invFreqs64[j] - idx := i*(headDim/2) + j + idx := i*halfDim + j cosData[idx] = ops.FromFloat64(math.Cos(angle)) sinData[idx] = ops.FromFloat64(math.Sin(angle)) } } - cosAngles, err := tensor.New[T]([]int{seqLen, headDim / 2}, cosData) + cosAngles, err := tensor.New[T]([]int{seqLen, halfDim}, cosData) if err != nil { return nil, err } - sinAngles, err := tensor.New[T]([]int{seqLen, headDim / 2}, sinData) + sinAngles, err := tensor.New[T]([]int{seqLen, halfDim}, sinData) if err != nil { return nil, err } return &RotaryPositionalEmbedding[T]{ - engine: engine, - headDim: headDim, - cosAngles: cosAngles, - sinAngles: sinAngles, + engine: engine, + headDim: headDim, + rotaryDim: rotaryDim, + cosAngles: cosAngles, + sinAngles: sinAngles, + attnScaleFactor: attnScaleFactor, }, nil } @@ -130,25 +187,26 @@ func (rpe *RotaryPositionalEmbedding[T]) Forward(ctx context.Context, inputs ... } seqLen := rpe.inputShape[1] + halfRotary := rpe.rotaryDim / 2 // Slice cos and sin angles to match the input sequence length - cosAngles, err := rpe.cosAngles.Slice([2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + cosAngles, err := rpe.cosAngles.Slice([2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - sinAngles, err := rpe.sinAngles.Slice([2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + sinAngles, err := rpe.sinAngles.Slice([2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - // Split input into two halves: x_rot0, x_rot1 - rpe.xRot0Slice, err = input.Slice([2]int{0, rpe.inputShape[0]}, [2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + // Split rotary portion into two halves: x_rot0, x_rot1 + rpe.xRot0Slice, err = input.Slice([2]int{0, rpe.inputShape[0]}, [2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - rpe.xRot1Slice, err = input.Slice([2]int{0, rpe.inputShape[0]}, [2]int{0, seqLen}, [2]int{rpe.headDim / 2, rpe.headDim}) + rpe.xRot1Slice, err = input.Slice([2]int{0, rpe.inputShape[0]}, [2]int{0, seqLen}, [2]int{halfRotary, rpe.rotaryDim}) if err != nil { return nil, err } @@ -156,26 +214,21 @@ func (rpe *RotaryPositionalEmbedding[T]) Forward(ctx context.Context, inputs ... // Apply rotation: // x_rot0 * cos(angles) - x_rot1 * sin(angles) // x_rot1 * cos(angles) + x_rot0 * sin(angles) - - // Term 1: x_rot0 * cos(angles) - term1, err := rpe.engine.Mul(ctx, rpe.xRot0Slice, cosAngles) // Broadcasting cosAngles + term1, err := rpe.engine.Mul(ctx, rpe.xRot0Slice, cosAngles) if err != nil { return nil, err } - // Term 2: x_rot1 * sin(angles) - term2, err := rpe.engine.Mul(ctx, rpe.xRot1Slice, sinAngles) // Broadcasting sinAngles + term2, err := rpe.engine.Mul(ctx, rpe.xRot1Slice, sinAngles) if err != nil { return nil, err } - // rotated_x0 = term1 - term2 rotatedX0, err := rpe.engine.Sub(ctx, term1, term2) if err != nil { return nil, err } - // rotated_x1 = x_rot1 * cos(angles) + x_rot0 * sin(angles) mul1, err := rpe.engine.Mul(ctx, rpe.xRot1Slice, cosAngles) if err != nil { return nil, err @@ -191,8 +244,17 @@ func (rpe *RotaryPositionalEmbedding[T]) Forward(ctx context.Context, inputs ... return nil, err } - // Concatenate rotated halves - output, err := rpe.engine.Concat(ctx, []*tensor.TensorNumeric[T]{rotatedX0, rotatedX1}, 2) + // Concatenate rotated halves, plus pass-through if partial. + parts := []*tensor.TensorNumeric[T]{rotatedX0, rotatedX1} + if rpe.rotaryDim < rpe.headDim { + passThrough, err2 := input.Slice([2]int{0, rpe.inputShape[0]}, [2]int{0, seqLen}, [2]int{rpe.rotaryDim, rpe.headDim}) + if err2 != nil { + return nil, err2 + } + parts = append(parts, passThrough) + } + + output, err := rpe.engine.Concat(ctx, parts, 2) if err != nil { return nil, err } @@ -207,30 +269,30 @@ func (rpe *RotaryPositionalEmbedding[T]) Backward(ctx context.Context, mode type dShape := dOut.Shape() batchDim := dShape[0] seqLen := dShape[1] + halfRotary := rpe.rotaryDim / 2 // Slice cos and sin angles to match the input sequence length - cosAngles, err := rpe.cosAngles.Slice([2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + cosAngles, err := rpe.cosAngles.Slice([2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - sinAngles, err := rpe.sinAngles.Slice([2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + sinAngles, err := rpe.sinAngles.Slice([2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - // Split dOut into d_rotated_x0, d_rotated_x1 - dRotatedX0, err := dOut.Slice([2]int{0, batchDim}, [2]int{0, seqLen}, [2]int{0, rpe.headDim / 2}) + // Split dOut rotary portion into d_rotated_x0, d_rotated_x1 + dRotatedX0, err := dOut.Slice([2]int{0, batchDim}, [2]int{0, seqLen}, [2]int{0, halfRotary}) if err != nil { return nil, err } - dRotatedX1, err := dOut.Slice([2]int{0, batchDim}, [2]int{0, seqLen}, [2]int{rpe.headDim / 2, rpe.headDim}) + dRotatedX1, err := dOut.Slice([2]int{0, batchDim}, [2]int{0, seqLen}, [2]int{halfRotary, rpe.rotaryDim}) if err != nil { return nil, err } - // Gradients for x_rot0 and x_rot1 // dL/dx_rot0 = d_rotated_x0 * cos(angles) + d_rotated_x1 * sin(angles) mul1, err := rpe.engine.Mul(ctx, dRotatedX0, cosAngles) if err != nil { @@ -263,8 +325,17 @@ func (rpe *RotaryPositionalEmbedding[T]) Backward(ctx context.Context, mode type return nil, err } - // Concatenate gradients for x_rot0 and x_rot1 - dInput, err := rpe.engine.Concat(ctx, []*tensor.TensorNumeric[T]{dLdxRot0, dLdxRot1}, 2) + // Concatenate rotary gradients, plus pass-through if partial. + parts := []*tensor.TensorNumeric[T]{dLdxRot0, dLdxRot1} + if rpe.rotaryDim < rpe.headDim { + dPassThrough, err2 := dOut.Slice([2]int{0, batchDim}, [2]int{0, seqLen}, [2]int{rpe.rotaryDim, rpe.headDim}) + if err2 != nil { + return nil, err2 + } + parts = append(parts, dPassThrough) + } + + dInput, err := rpe.engine.Concat(ctx, parts, 2) if err != nil { return nil, err } @@ -282,6 +353,15 @@ func (rpe *RotaryPositionalEmbedding[T]) Attributes() map[string]interface{} { return nil } +// AttentionScaleFactor returns the YaRN attention scaling factor. +// Returns 1.0 when YaRN is not enabled. +func (rpe *RotaryPositionalEmbedding[T]) AttentionScaleFactor() float64 { + if rpe.attnScaleFactor == 0 { + return 1.0 + } + return rpe.attnScaleFactor +} + // Scale scales the positional embeddings by a given factor. func (rpe *RotaryPositionalEmbedding[T]) Scale(ctx context.Context, factor float64) error { ops := rpe.engine.Ops() diff --git a/layers/embeddings/rotary_positional_embedding_test.go b/layers/embeddings/rotary_positional_embedding_test.go index 32a7b767..853f3296 100644 --- a/layers/embeddings/rotary_positional_embedding_test.go +++ b/layers/embeddings/rotary_positional_embedding_test.go @@ -278,3 +278,314 @@ func TestRotaryPositionalEmbedding_DefaultBase(t *testing.T) { testutils.AssertNoError(t, err, "Forward should not return an error") testutils.AssertNotNil(t, output, "Output tensor should not be nil") } + +func TestRotaryPositionalEmbedding_YaRN_DefaultUnchanged(t *testing.T) { + // Without YaRN, behavior should be identical to default RoPE. + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + headDim := 8 + seqLen := 16 + base := 10000.0 + + rpeDefault, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, WithRotaryBase(base)) + if err != nil { + t.Fatalf("default RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{1, seqLen, headDim}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.1 + } + + outDefault, err := rpeDefault.Forward(ctx, input) + if err != nil { + t.Fatalf("default Forward failed: %v", err) + } + + // Sanity: output is non-nil and has same shape + if !reflect.DeepEqual(outDefault.Shape(), input.Shape()) { + t.Errorf("default output shape = %v, want %v", outDefault.Shape(), input.Shape()) + } +} + +func TestRotaryPositionalEmbedding_YaRN_FrequenciesDiffer(t *testing.T) { + // With YaRN scaling, cos/sin tables should differ from default RoPE. + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + headDim := 8 + seqLen := 32 + base := 10000.0 + factor := 4.0 + origMaxLen := 64 // Small enough that some wavelengths exceed origMaxLen + + rpeDefault, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, WithRotaryBase(base)) + if err != nil { + t.Fatalf("default RoPE failed: %v", err) + } + + rpeYaRN, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, + WithRotaryBase(base), + WithYaRNScaling(factor, origMaxLen), + ) + if err != nil { + t.Fatalf("YaRN RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{1, seqLen, headDim}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%5+1) * 0.1 + } + + outDefault, err := rpeDefault.Forward(ctx, input) + if err != nil { + t.Fatalf("default Forward failed: %v", err) + } + + outYaRN, err := rpeYaRN.Forward(ctx, input) + if err != nil { + t.Fatalf("YaRN Forward failed: %v", err) + } + + // Outputs should differ + defaultData := outDefault.Data() + yarnData := outYaRN.Data() + allSame := true + for i := range defaultData { + if defaultData[i] != yarnData[i] { + allSame = false + break + } + } + if allSame { + t.Error("YaRN output should differ from default RoPE output") + } +} + +func TestRotaryPositionalEmbedding_YaRN_AttentionScaleFactor(t *testing.T) { + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + factor := 4.0 + origMaxLen := 8192 + + rpe, err := NewRotaryPositionalEmbedding[float32](ctx, engine, 8, 16, + WithRotaryBase(10000.0), + WithYaRNScaling(factor, origMaxLen), + ) + if err != nil { + t.Fatalf("YaRN RoPE failed: %v", err) + } + + // Expected: sqrt(1 + ln(factor) / ln(origMaxLen)) + expected := math.Sqrt(1 + math.Log(factor)/math.Log(float64(origMaxLen))) + got := rpe.AttentionScaleFactor() + + if math.Abs(got-expected) > 1e-6 { + t.Errorf("AttentionScaleFactor() = %f, want %f", got, expected) + } +} + +func TestRotaryPositionalEmbedding_NoYaRN_AttentionScaleFactor(t *testing.T) { + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + rpe, err := NewRotaryPositionalEmbedding[float32](ctx, engine, 8, 16, + WithRotaryBase(10000.0), + ) + if err != nil { + t.Fatalf("default RoPE failed: %v", err) + } + + // Without YaRN, attention scale factor should be 1.0 + if got := rpe.AttentionScaleFactor(); got != 1.0 { + t.Errorf("AttentionScaleFactor() = %f, want 1.0", got) + } +} + +func TestRotaryPositionalEmbedding_YaRN_ForwardBackward(t *testing.T) { + // YaRN RoPE should still produce valid backward gradients. + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + rpe, err := NewRotaryPositionalEmbedding[float32](ctx, engine, 8, 16, + WithRotaryBase(10000.0), + WithYaRNScaling(4.0, 8192), + ) + if err != nil { + t.Fatalf("YaRN RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{2, 8, 8}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%5+1) * 0.01 + } + + out, err := rpe.Forward(ctx, input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + + dOut, _ := tensor.New[float32](out.Shape(), nil) + for i := range dOut.Data() { + dOut.Data()[i] = 1.0 + } + + grads, err := rpe.Backward(ctx, types.FullBackprop, dOut) + if err != nil { + t.Fatalf("Backward failed: %v", err) + } + if len(grads) != 1 { + t.Errorf("expected 1 gradient, got %d", len(grads)) + } + if !reflect.DeepEqual(grads[0].Shape(), input.Shape()) { + t.Errorf("gradient shape = %v, want %v", grads[0].Shape(), input.Shape()) + } +} + +func TestRotaryPositionalEmbedding_PartialRotation_Default(t *testing.T) { + // Default fraction (1.0) should produce same output as explicit 1.0. + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + headDim := 8 + seqLen := 4 + base := 10000.0 + + rpeDefault, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, WithRotaryBase(base)) + if err != nil { + t.Fatalf("default RoPE failed: %v", err) + } + + rpeFull, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, + WithRotaryBase(base), + WithRotaryDimFraction(1.0), + ) + if err != nil { + t.Fatalf("fraction=1.0 RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{1, seqLen, headDim}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.1 + } + + outDefault, err := rpeDefault.Forward(ctx, input) + if err != nil { + t.Fatalf("default Forward failed: %v", err) + } + + outFull, err := rpeFull.Forward(ctx, input) + if err != nil { + t.Fatalf("fraction=1.0 Forward failed: %v", err) + } + + for i, v := range outDefault.Data() { + if v != outFull.Data()[i] { + t.Errorf("default vs fraction=1.0 differ at index %d: %f != %f", i, v, outFull.Data()[i]) + break + } + } +} + +func TestRotaryPositionalEmbedding_PartialRotation_Half(t *testing.T) { + // fraction=0.5 with headDim=8: 4 dims rotated, 4 dims unchanged. + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + headDim := 8 + seqLen := 4 + base := 10000.0 + + rpe, err := NewRotaryPositionalEmbedding[float32](ctx, engine, headDim, seqLen, + WithRotaryBase(base), + WithRotaryDimFraction(0.5), + ) + if err != nil { + t.Fatalf("fraction=0.5 RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{1, seqLen, headDim}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%7+1) * 0.1 + } + + output, err := rpe.Forward(ctx, input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + + if !reflect.DeepEqual(output.Shape(), input.Shape()) { + t.Fatalf("output shape = %v, want %v", output.Shape(), input.Shape()) + } + + // The last 4 dims of each position should be unchanged (pass-through). + rotDim := 4 // headDim * 0.5 + for pos := 0; pos < seqLen; pos++ { + for d := rotDim; d < headDim; d++ { + idx := pos*headDim + d + if output.Data()[idx] != input.Data()[idx] { + t.Errorf("pos=%d dim=%d: output=%f should equal input=%f (unrotated region)", + pos, d, output.Data()[idx], input.Data()[idx]) + } + } + } + + // The first 4 dims should differ (at least for pos > 0). + anyDiffer := false + for pos := 1; pos < seqLen; pos++ { + for d := 0; d < rotDim; d++ { + idx := pos*headDim + d + if output.Data()[idx] != input.Data()[idx] { + anyDiffer = true + break + } + } + if anyDiffer { + break + } + } + if !anyDiffer { + t.Error("rotated region should differ from input for pos > 0") + } +} + +func TestRotaryPositionalEmbedding_PartialRotation_ForwardBackward(t *testing.T) { + ctx := context.Background() + engine := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + + rpe, err := NewRotaryPositionalEmbedding[float32](ctx, engine, 8, 16, + WithRotaryBase(10000.0), + WithRotaryDimFraction(0.75), + ) + if err != nil { + t.Fatalf("partial RoPE failed: %v", err) + } + + input, _ := tensor.New[float32]([]int{2, 8, 8}, nil) + for i := range input.Data() { + input.Data()[i] = float32(i%5+1) * 0.01 + } + + out, err := rpe.Forward(ctx, input) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + + dOut, _ := tensor.New[float32](out.Shape(), nil) + for i := range dOut.Data() { + dOut.Data()[i] = 1.0 + } + + grads, err := rpe.Backward(ctx, types.FullBackprop, dOut) + if err != nil { + t.Fatalf("Backward failed: %v", err) + } + if len(grads) != 1 { + t.Errorf("expected 1 gradient, got %d", len(grads)) + } + if !reflect.DeepEqual(grads[0].Shape(), input.Shape()) { + t.Errorf("gradient shape = %v, want %v", grads[0].Shape(), input.Shape()) + } +} diff --git a/layers/registry/registry.go b/layers/registry/registry.go index b92f5359..49aa26d5 100644 --- a/layers/registry/registry.go +++ b/layers/registry/registry.go @@ -26,6 +26,7 @@ func RegisterAll() { // Attention model.RegisterLayer("GroupQueryAttention", attention.BuildGroupQueryAttention[float32]) model.RegisterLayer("GlobalAttention", attention.BuildGlobalAttention[float32]) + model.RegisterLayer("MultiHeadLatentAttention", attention.BuildMultiHeadLatentAttention[float32]) // Core model.RegisterLayer("Add", core.BuildAdd[float32]) diff --git a/model/builder.go b/model/builder.go index 73284381..6381ac2d 100644 --- a/model/builder.go +++ b/model/builder.go @@ -15,6 +15,30 @@ import ( "github.com/zerfoo/zmf" ) +// BuildOption configures optional behavior for BuildFromZMF. +type BuildOption func(*buildConfig) + +type buildConfig struct { + resolver ParamResolver + globalAttributes map[string]interface{} +} + +// WithGlobalAttributes injects extra key-value pairs into every node's +// attribute map during graph construction. This is used, for example, to +// propagate rope_scaling_* config from config.json into every GQA node +// without requiring the ZMF file to carry these attributes. +func WithGlobalAttributes(attrs map[string]interface{}) BuildOption { + return func(c *buildConfig) { c.globalAttributes = attrs } +} + +// WithParamResolver supplies an architecture-aware parameter name resolver. +// When provided, the resolver adds canonical aliases to the parameter map +// so that layer builders can look up parameters by canonical name even when +// the ZMF file uses architecture-specific names. +func WithParamResolver(r ParamResolver) BuildOption { + return func(c *buildConfig) { c.resolver = r } +} + // BuildFromZMF constructs a Zerfoo computation graph from a ZMF model definition. // This function iterates through the nodes in the graph, instantiates the // corresponding layers using a registered builder, and connects them into an @@ -25,16 +49,27 @@ func BuildFromZMF[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], model *zmf.Model, + opts ...BuildOption, ) (*graph.Graph[T], error) { if model == nil || model.Graph == nil { return nil, errors.New("cannot build model from nil or empty ZMF graph") } + cfg := buildConfig{} + for _, opt := range opts { + opt(&cfg) + } + params, err := convertParameters[T](model.Graph.Parameters) if err != nil { return nil, err } + // Apply resolver to add canonical aliases alongside original names. + if cfg.resolver != nil { + params = ResolveAll(cfg.resolver, params) + } + builder := graph.NewBuilder[T](engine) instantiatedNodes := make(map[string]graph.Node[T]) @@ -89,6 +124,12 @@ func BuildFromZMF[T tensor.Numeric]( } attributes := convertAttributes(nodeProto.Attributes) + // Merge global attributes (e.g. rope_scaling) into per-node attributes. + for k, v := range cfg.globalAttributes { + if _, exists := attributes[k]; !exists { + attributes[k] = v + } + } node, err := layerBuilder(engine, ops, nodeProto.Name, params, attributes) if err != nil { diff --git a/model/builder_test.go b/model/builder_test.go index 9bc795a0..930cf92c 100644 --- a/model/builder_test.go +++ b/model/builder_test.go @@ -3,6 +3,7 @@ package model import ( "context" "encoding/binary" + "fmt" "math" "testing" @@ -1129,3 +1130,149 @@ func TestModelForward_EmbeddingError(t *testing.T) { t.Error("expected error when embedding receives no inputs") } } + +// --- ParamResolver integration tests (T58.2) --- + +// resolverParamLookupBuilder returns a LayerBuilder that expects a specific parameter +// name to exist in the params map. It fails if the parameter is not found. +func resolverParamLookupBuilder(expectedParamName string) LayerBuilder[float32] { + return func( + _ compute.Engine[float32], + _ numeric.Arithmetic[float32], + _ string, + params map[string]*graph.Parameter[float32], + _ map[string]any, + ) (graph.Node[float32], error) { + if _, ok := params[expectedParamName]; !ok { + return nil, fmt.Errorf("expected parameter %q not found in params", expectedParamName) + } + val, _ := tensor.New[float32]([]int{1}, []float32{1}) + return ¶meterNode[float32]{value: val}, nil + } +} + +func TestBuildFromZMF_WithParamResolver_PhiDenseProj(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + // Register a layer builder that looks up "model.layers.0.self_attn.o_proj.weight" + // (canonical name) even though the ZMF param is "model.layers.0.self_attn.dense_proj.weight" + RegisterLayer("TestPhiAttnOp", resolverParamLookupBuilder("model.layers.0.self_attn.o_proj.weight")) + defer UnregisterLayer("TestPhiAttnOp") + + paramData := make([]byte, 4) + binary.LittleEndian.PutUint32(paramData, math.Float32bits(1.0)) + + zmfModel := &zmf.Model{ + Graph: &zmf.Graph{ + Parameters: map[string]*zmf.Tensor{ + // Phi-style name: dense_proj instead of o_proj + "model.layers.0.self_attn.dense_proj.weight": { + Dtype: zmf.Tensor_FLOAT32, + Shape: []int64{1}, + Data: paramData, + }, + }, + Inputs: []*zmf.ValueInfo{ + {Name: "input", Shape: []int64{1}}, + }, + Nodes: []*zmf.Node{ + {Name: "attn_node", OpType: "TestPhiAttnOp", Inputs: []string{"input"}}, + }, + Outputs: []*zmf.ValueInfo{{Name: "attn_node"}}, + }, + } + + // Without resolver, the canonical name would not be in params and the build would fail + _, err := BuildFromZMF[float32](engine, ops, zmfModel) + if err == nil { + t.Fatal("expected error without resolver, but got nil") + } + + // With Phi resolver, dense_proj is aliased to o_proj and the build succeeds + g, err := BuildFromZMF[float32](engine, ops, zmfModel, WithParamResolver(NewParamResolver("phi"))) + if err != nil { + t.Fatalf("BuildFromZMF with Phi resolver failed: %v", err) + } + if g == nil { + t.Fatal("expected non-nil graph") + } +} + +func TestBuildFromZMF_WithParamResolver_IdentityPreservesExisting(t *testing.T) { + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + // Layer that looks up a param by its original name + RegisterLayer("TestIdentityResolveOp", resolverParamLookupBuilder("model.layers.0.self_attn.q_proj.weight")) + defer UnregisterLayer("TestIdentityResolveOp") + + paramData := make([]byte, 4) + binary.LittleEndian.PutUint32(paramData, math.Float32bits(2.0)) + + zmfModel := &zmf.Model{ + Graph: &zmf.Graph{ + Parameters: map[string]*zmf.Tensor{ + "model.layers.0.self_attn.q_proj.weight": { + Dtype: zmf.Tensor_FLOAT32, + Shape: []int64{1}, + Data: paramData, + }, + }, + Inputs: []*zmf.ValueInfo{ + {Name: "input", Shape: []int64{1}}, + }, + Nodes: []*zmf.Node{ + {Name: "node1", OpType: "TestIdentityResolveOp", Inputs: []string{"input"}}, + }, + Outputs: []*zmf.ValueInfo{{Name: "node1"}}, + }, + } + + // With identity resolver (llama), original names are preserved + g, err := BuildFromZMF[float32](engine, ops, zmfModel, WithParamResolver(NewParamResolver("llama"))) + if err != nil { + t.Fatalf("BuildFromZMF with Llama resolver failed: %v", err) + } + if g == nil { + t.Fatal("expected non-nil graph") + } +} + +func TestBuildFromZMF_WithoutResolver_BackwardCompatible(t *testing.T) { + // Verify that all existing tests still pass without any resolver option. + // This test creates a minimal model and builds it with no options. + ops := numeric.Float32Ops{} + engine := compute.NewCPUEngine[float32](ops) + + paramData := make([]byte, 4) + binary.LittleEndian.PutUint32(paramData, math.Float32bits(1.0)) + + zmfModel := &zmf.Model{ + Graph: &zmf.Graph{ + Parameters: map[string]*zmf.Tensor{ + "weight": { + Dtype: zmf.Tensor_FLOAT32, + Shape: []int64{1}, + Data: paramData, + }, + }, + Inputs: []*zmf.ValueInfo{ + {Name: "input", Shape: []int64{1}}, + {Name: "weight", Shape: []int64{1}}, + }, + Outputs: []*zmf.ValueInfo{ + {Name: "input"}, + }, + }, + } + + // No resolver option - should work exactly as before + g, err := BuildFromZMF[float32](engine, ops, zmfModel) + if err != nil { + t.Fatalf("BuildFromZMF without resolver failed: %v", err) + } + if g == nil { + t.Fatal("expected non-nil graph") + } +} diff --git a/model/param_resolver.go b/model/param_resolver.go new file mode 100644 index 00000000..b7f514f7 --- /dev/null +++ b/model/param_resolver.go @@ -0,0 +1,81 @@ +package model + +import "regexp" + +// ParamResolver maps architecture-specific parameter names to canonical names +// used by Zerfoo layers during model building. Canonical names follow the +// Llama/Gemma convention (the most common HuggingFace naming pattern). +type ParamResolver interface { + // Resolve returns the canonical name for a model-specific parameter name. + // Returns the input unchanged if no mapping applies. + Resolve(name string) string +} + +// NewParamResolver returns a resolver for the given architecture type. +// Architecture types match the model_type field from HuggingFace config.json. +func NewParamResolver(arch string) ParamResolver { + switch arch { + case "phi", "phi3": + return newPhiResolver() + default: + // Llama, Gemma, Gemma2, Gemma3, Mistral, Qwen2, DeepSeek V3, and + // unknown architectures all use canonical naming (no renaming needed). + return &identityResolver{} + } +} + +// ResolveAll takes a resolver and a map keyed by model-specific names, and +// returns a new map containing both the original names and any canonical +// aliases produced by the resolver. This allows parameter lookups by either +// the original or canonical name. +func ResolveAll[V any](r ParamResolver, params map[string]V) map[string]V { + result := make(map[string]V, len(params)) + for name, val := range params { + result[name] = val + canonical := r.Resolve(name) + if canonical != name { + result[canonical] = val + } + } + return result +} + +// identityResolver returns names unchanged. Used for architectures whose +// parameter names already follow the canonical convention. +type identityResolver struct{} + +func (r *identityResolver) Resolve(name string) string { return name } + +// regexResolver applies an ordered list of regex-based renaming rules. +// The first matching rule wins. +type regexResolver struct { + rules []resolveRule +} + +type resolveRule struct { + pattern *regexp.Regexp + replacement string +} + +func (r *regexResolver) Resolve(name string) string { + for _, rule := range r.rules { + if rule.pattern.MatchString(name) { + return rule.pattern.ReplaceAllString(name, rule.replacement) + } + } + return name +} + +// newPhiResolver creates a resolver for Phi-family models. +// Phi uses "dense_proj" where other architectures use "o_proj" for the +// attention output projection. +func newPhiResolver() *regexResolver { + return ®exResolver{ + rules: []resolveRule{ + { + pattern: regexp.MustCompile(`^(model\.layers\.\d+\.self_attn\.)dense(_proj\..+)$`), + replacement: "${1}o${2}", + }, + }, + } +} diff --git a/model/param_resolver_test.go b/model/param_resolver_test.go new file mode 100644 index 00000000..38383686 --- /dev/null +++ b/model/param_resolver_test.go @@ -0,0 +1,249 @@ +package model + +import "testing" + +func TestNewParamResolver_ArchitectureSelection(t *testing.T) { + tests := []struct { + arch string + wantType string + }{ + {"llama", "identity"}, + {"gemma", "identity"}, + {"gemma2", "identity"}, + {"gemma3", "identity"}, + {"mistral", "identity"}, + {"qwen2", "identity"}, + {"phi", "regex"}, + {"phi3", "regex"}, + {"deepseek_v3", "identity"}, + {"unknown_arch", "identity"}, + {"", "identity"}, + } + for _, tt := range tests { + t.Run(tt.arch, func(t *testing.T) { + r := NewParamResolver(tt.arch) + if r == nil { + t.Fatal("NewParamResolver returned nil") + } + switch tt.wantType { + case "identity": + if _, ok := r.(*identityResolver); !ok { + t.Errorf("expected *identityResolver for arch %q, got %T", tt.arch, r) + } + case "regex": + if _, ok := r.(*regexResolver); !ok { + t.Errorf("expected *regexResolver for arch %q, got %T", tt.arch, r) + } + } + }) + } +} + +func TestIdentityResolver_PassesThrough(t *testing.T) { + r := NewParamResolver("llama") + tests := []string{ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.15.mlp.gate_proj.weight", + "model.layers.15.mlp.up_proj.weight", + "model.layers.15.mlp.down_proj.weight", + "model.layers.3.input_layernorm.weight", + "model.layers.3.post_attention_layernorm.weight", + "model.embed_tokens.weight", + "lm_head.weight", + "model.norm.weight", + "", + } + for _, name := range tests { + t.Run(name, func(t *testing.T) { + got := r.Resolve(name) + if got != name { + t.Errorf("Resolve(%q) = %q, want %q", name, got, name) + } + }) + } +} + +func TestPhiResolver_RenamesDenseProj(t *testing.T) { + r := NewParamResolver("phi") + tests := []struct { + input string + want string + }{ + // Phi-specific: dense_proj → o_proj + { + "model.layers.0.self_attn.dense_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + }, + { + "model.layers.31.self_attn.dense_proj.weight", + "model.layers.31.self_attn.o_proj.weight", + }, + // Other attention projections unchanged + { + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.q_proj.weight", + }, + { + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + }, + { + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + }, + // FFN names unchanged + { + "model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.gate_proj.weight", + }, + // Norm names unchanged + { + "model.layers.0.input_layernorm.weight", + "model.layers.0.input_layernorm.weight", + }, + // Embedding/head unchanged + { + "model.embed_tokens.weight", + "model.embed_tokens.weight", + }, + { + "lm_head.weight", + "lm_head.weight", + }, + // Empty string + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := r.Resolve(tt.input) + if got != tt.want { + t.Errorf("Resolve(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestPhiResolver_Phi3Alias(t *testing.T) { + r := NewParamResolver("phi3") + got := r.Resolve("model.layers.5.self_attn.dense_proj.weight") + want := "model.layers.5.self_attn.o_proj.weight" + if got != want { + t.Errorf("Resolve for phi3 = %q, want %q", got, want) + } +} + +func TestGemmaResolver_Identity(t *testing.T) { + for _, arch := range []string{"gemma", "gemma2", "gemma3"} { + t.Run(arch, func(t *testing.T) { + r := NewParamResolver(arch) + name := "model.layers.0.self_attn.o_proj.weight" + got := r.Resolve(name) + if got != name { + t.Errorf("Resolve(%q) = %q, want identity", name, got) + } + }) + } +} + +func TestMistralResolver_Identity(t *testing.T) { + r := NewParamResolver("mistral") + name := "model.layers.0.self_attn.o_proj.weight" + got := r.Resolve(name) + if got != name { + t.Errorf("Resolve(%q) = %q, want identity", name, got) + } +} + +func TestQwenResolver_Identity(t *testing.T) { + r := NewParamResolver("qwen2") + tests := []string{ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.q_proj.bias", + "model.layers.0.self_attn.k_proj.bias", + "model.layers.0.self_attn.v_proj.bias", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.0.self_attn.o_proj.bias", + } + for _, name := range tests { + t.Run(name, func(t *testing.T) { + got := r.Resolve(name) + if got != name { + t.Errorf("Resolve(%q) = %q, want identity", name, got) + } + }) + } +} + +func TestDeepSeekResolver_MLANames(t *testing.T) { + r := NewParamResolver("deepseek_v3") + // MLA-specific names should pass through unchanged + tests := []string{ + "model.layers.0.self_attn.kv_a_proj.weight", + "model.layers.0.self_attn.kv_b_proj.weight", + "model.layers.0.self_attn.q_a_proj.weight", + "model.layers.0.self_attn.q_b_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.up_proj.weight", + "model.layers.0.mlp.down_proj.weight", + } + for _, name := range tests { + t.Run(name, func(t *testing.T) { + got := r.Resolve(name) + if got != name { + t.Errorf("Resolve(%q) = %q, want identity", name, got) + } + }) + } +} + +func TestResolveAll(t *testing.T) { + r := NewParamResolver("phi") + + input := map[string]string{ + "model.layers.0.self_attn.dense_proj.weight": "tensor_a", + "model.layers.0.self_attn.q_proj.weight": "tensor_b", + "model.embed_tokens.weight": "tensor_c", + } + + result := ResolveAll(r, input) + + // Original names preserved + if result["model.layers.0.self_attn.dense_proj.weight"] != "tensor_a" { + t.Error("original name 'dense_proj' should be preserved") + } + if result["model.layers.0.self_attn.q_proj.weight"] != "tensor_b" { + t.Error("original name 'q_proj' should be preserved") + } + if result["model.embed_tokens.weight"] != "tensor_c" { + t.Error("original name 'embed_tokens' should be preserved") + } + + // Canonical alias added for renamed param + if result["model.layers.0.self_attn.o_proj.weight"] != "tensor_a" { + t.Error("canonical alias 'o_proj' should map to tensor_a") + } + + // No spurious aliases for identity-mapped names + if len(result) != 4 { + t.Errorf("expected 4 entries (3 original + 1 alias), got %d", len(result)) + } +} + +func TestResolveAll_NoAliasesForIdentity(t *testing.T) { + r := NewParamResolver("llama") + + input := map[string]string{ + "model.layers.0.self_attn.q_proj.weight": "tensor_a", + "model.layers.0.self_attn.o_proj.weight": "tensor_b", + } + + result := ResolveAll(r, input) + + if len(result) != 2 { + t.Errorf("expected 2 entries (no aliases for identity), got %d", len(result)) + } +} diff --git a/model/zmf_loader.go b/model/zmf_loader.go index 8451139e..fa5925b0 100644 --- a/model/zmf_loader.go +++ b/model/zmf_loader.go @@ -39,13 +39,14 @@ func LoadModelFromZMF[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], filePath string, + buildOpts ...BuildOption, ) (*Model[T], error) { zmfModel, err := LoadZMF(filePath) if err != nil { return nil, err } - graph, err := BuildFromZMF(engine, ops, zmfModel) + graph, err := BuildFromZMF(engine, ops, zmfModel, buildOpts...) if err != nil { return nil, err } diff --git a/tests/parity/deepseek_test.go b/tests/parity/deepseek_test.go new file mode 100644 index 00000000..78023f06 --- /dev/null +++ b/tests/parity/deepseek_test.go @@ -0,0 +1,30 @@ +package parity_test + +import ( + "testing" + + layerreg "github.com/zerfoo/zerfoo/layers/registry" +) + +var deepseekV3Config = modelParityConfig{ + Name: "DeepSeek-V3", + ZMFEnvVar: "DEEPSEEK_ZMF_PATH", + ModelDirEnvVar: "DEEPSEEK_MODEL_DIR", + ModelID: "deepseek-v3", + MinVocabSize: 100000, // DeepSeek V3 vocab: 129280 +} + +func TestDeepSeekV3ForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, deepseekV3Config) +} + +func TestDeepSeekV3GreedyDecode(t *testing.T) { + layerreg.RegisterAll() + runModelGreedyDecode(t, deepseekV3Config) +} + +func TestDeepSeekV3Generation(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, deepseekV3Config) +} diff --git a/tests/parity/gemma3_generation_test.go b/tests/parity/gemma3_generation_test.go deleted file mode 100644 index 40461c36..00000000 --- a/tests/parity/gemma3_generation_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package parity_test - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/zerfoo/zerfoo/generate" - "github.com/zerfoo/zerfoo/inference" - layerreg "github.com/zerfoo/zerfoo/layers/registry" - "github.com/zerfoo/zerfoo/registry" -) - -// dirRegistry is a mock ModelRegistry that maps model IDs to local directories. -type dirRegistry struct { - models map[string]*registry.ModelInfo -} - -func (r *dirRegistry) Get(modelID string) (*registry.ModelInfo, bool) { - info, ok := r.models[modelID] - return info, ok -} - -func (r *dirRegistry) Pull(_ context.Context, _ string) (*registry.ModelInfo, error) { - return nil, nil -} - -func (r *dirRegistry) List() []registry.ModelInfo { return nil } -func (r *dirRegistry) Delete(_ string) error { return nil } - -// TestGemma3Generation loads a Gemma 3 model via inference.Load and verifies -// that text generation produces coherent, non-empty output. -// -// Requires GEMMA3_ZMF_PATH pointing to a .zmf file whose parent directory -// also contains config.json and tokenizer.json. Alternatively, set -// GEMMA3_MODEL_DIR to the directory directly. -func TestGemma3Generation(t *testing.T) { - modelDir := os.Getenv("GEMMA3_MODEL_DIR") - if modelDir == "" { - zmfPath := os.Getenv("GEMMA3_ZMF_PATH") - if zmfPath == "" { - t.Skip("GEMMA3_MODEL_DIR and GEMMA3_ZMF_PATH not set; skipping") - } - modelDir = filepath.Dir(zmfPath) - } - - layerreg.RegisterAll() - - reg := &dirRegistry{ - models: map[string]*registry.ModelInfo{ - "gemma-3": {ID: "gemma-3", Path: modelDir}, - }, - } - - mdl, err := inference.Load("gemma-3", inference.WithRegistry(reg)) - if err != nil { - t.Fatalf("inference.Load failed: %v", err) - } - - ctx := context.Background() - - // Test 1: Greedy generation is non-empty and deterministic. - t.Run("greedy_deterministic", func(t *testing.T) { - prompt := "The capital of France is" - result1, err := mdl.Generate(ctx, prompt, - inference.WithTemperature(0), - inference.WithMaxTokens(20), - ) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - if result1 == "" { - t.Fatal("greedy generation produced empty output") - } - t.Logf("greedy output: %q", result1) - - // Second greedy pass should be identical. - result2, err := mdl.Generate(ctx, prompt, - inference.WithTemperature(0), - inference.WithMaxTokens(20), - ) - if err != nil { - t.Fatalf("Generate (second) failed: %v", err) - } - if result1 != result2 { - t.Errorf("greedy outputs differ:\n run1: %q\n run2: %q", result1, result2) - } - }) - - // Test 2: Streaming output matches non-streaming. - t.Run("stream_parity", func(t *testing.T) { - prompt := "Hello world" - nonStream, err := mdl.Generate(ctx, prompt, - inference.WithTemperature(0), - inference.WithMaxTokens(10), - ) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - var sb strings.Builder - err = mdl.GenerateStream(ctx, prompt, - generate.TokenStreamFunc(func(token string, done bool) error { - if !done { - sb.WriteString(token) - } - return nil - }), - inference.WithTemperature(0), - inference.WithMaxTokens(10), - ) - if err != nil { - t.Fatalf("GenerateStream failed: %v", err) - } - - streamed := sb.String() - if nonStream != streamed { - t.Errorf("stream/non-stream mismatch:\n non-stream: %q\n stream: %q", - nonStream, streamed) - } - }) - - // Test 3: Chat generates a response. - t.Run("chat", func(t *testing.T) { - resp, err := mdl.Chat(ctx, []inference.Message{ - {Role: "user", Content: "Say hello in French"}, - }, inference.WithMaxTokens(20)) - if err != nil { - t.Fatalf("Chat failed: %v", err) - } - if resp.Content == "" { - t.Error("Chat produced empty content") - } - if resp.TokensUsed <= 0 { - t.Error("TokensUsed should be positive") - } - t.Logf("chat response: %q (tokens: %d)", resp.Content, resp.TokensUsed) - }) -} diff --git a/tests/parity/gemma3_test.go b/tests/parity/gemma3_test.go index 6c5badf0..5515de20 100644 --- a/tests/parity/gemma3_test.go +++ b/tests/parity/gemma3_test.go @@ -1,176 +1,30 @@ package parity_test import ( - "context" - "math" - "os" "testing" - "github.com/zerfoo/zerfoo/compute" - "github.com/zerfoo/zerfoo/layers/registry" - "github.com/zerfoo/zerfoo/model" - "github.com/zerfoo/zerfoo/numeric" - "github.com/zerfoo/zerfoo/tensor" + layerreg "github.com/zerfoo/zerfoo/layers/registry" ) -// TestGemma3ForwardPass loads a ZMF-converted Gemma 3 model and verifies -// that a single forward pass succeeds, produces a valid output shape, and -// contains no NaN or Inf values. -// -// The test is skipped when GEMMA3_ZMF_PATH is not set; it is intended for -// CI environments where the model file is present. -func TestGemma3ForwardPass(t *testing.T) { - zmfPath := os.Getenv("GEMMA3_ZMF_PATH") - if zmfPath == "" { - t.Skip("GEMMA3_ZMF_PATH not set; skipping Gemma 3 forward pass test") - } - - registry.RegisterAll() - - eng := compute.NewCPUEngine[float32](numeric.Float32Ops{}) - ops := numeric.Float32Ops{} - - m, err := model.LoadModelFromZMF[float32](eng, ops, zmfPath) - if err != nil { - t.Fatalf("LoadModelFromZMF failed: %v", err) - } - if m.Graph == nil { - t.Fatal("model graph is nil") - } - - // Input: [1, 8] int64 token IDs cast to float32 for the graph. - // Gemma 3 graphs expect float32 token IDs at the embedding lookup input. - seqLen := 8 - inputData := make([]float32, seqLen) - for i := range inputData { - inputData[i] = float32(i + 1) // token IDs 1..8 - } - input, err := tensor.New[float32]([]int{1, seqLen}, inputData) - if err != nil { - t.Fatalf("tensor.New failed: %v", err) - } - - output, err := m.Graph.Forward(context.Background(), input) - if err != nil { - t.Fatalf("Graph.Forward failed: %v", err) - } - if output == nil { - t.Fatal("output tensor is nil") - } - - outShape := output.Shape() - t.Logf("Gemma 3 output shape: %v", outShape) - - // Output should be [1, seqLen, vocabSize] where vocabSize >= 256000. - if len(outShape) < 2 { - t.Errorf("output rank %d < 2; want at least 2", len(outShape)) - } - if len(outShape) == 3 { - if outShape[0] != 1 { - t.Errorf("output batch dim = %d, want 1", outShape[0]) - } - if outShape[1] != seqLen { - t.Errorf("output seq dim = %d, want %d", outShape[1], seqLen) - } - if outShape[2] < 256000 { - t.Errorf("output vocab dim = %d, want >= 256000", outShape[2]) - } - } +var gemma3Config = modelParityConfig{ + Name: "Gemma 3", + ZMFEnvVar: "GEMMA3_ZMF_PATH", + ModelDirEnvVar: "GEMMA3_MODEL_DIR", + ModelID: "gemma-3", + MinVocabSize: 256000, // Gemma 3 vocab: 262144 +} - // Verify no NaN or Inf in the output. - data := output.Data() - for i, v := range data { - f := float64(v) - if math.IsNaN(f) { - t.Errorf("output[%d] is NaN", i) - break - } - if math.IsInf(f, 0) { - t.Errorf("output[%d] is Inf", i) - break - } - } +func TestGemma3ForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, gemma3Config) } -// TestGemma3GreedyDecode runs 5 greedy decode steps starting from a short -// prompt. Each step picks the argmax over the vocab dimension and appends -// the token to the sequence for the next step. -// -// Assertions: no error, no panic, 5 output tokens in [0, vocabSize). func TestGemma3GreedyDecode(t *testing.T) { - zmfPath := os.Getenv("GEMMA3_ZMF_PATH") - if zmfPath == "" { - t.Skip("GEMMA3_ZMF_PATH not set; skipping Gemma 3 greedy decode test") - } - - registry.RegisterAll() - - eng := compute.NewCPUEngine[float32](numeric.Float32Ops{}) - ops := numeric.Float32Ops{} - - m, err := model.LoadModelFromZMF[float32](eng, ops, zmfPath) - if err != nil { - t.Fatalf("LoadModelFromZMF failed: %v", err) - } - - // Start with token IDs [1, 2, 3]. - tokens := []float32{1, 2, 3} - const steps = 5 - - for step := 0; step < steps; step++ { - seqLen := len(tokens) - input, err := tensor.New[float32]([]int{1, seqLen}, append([]float32{}, tokens...)) - if err != nil { - t.Fatalf("step %d: tensor.New failed: %v", step, err) - } - - output, err := m.Graph.Forward(context.Background(), input) - if err != nil { - t.Fatalf("step %d: Graph.Forward failed: %v", step, err) - } - if output == nil { - t.Fatalf("step %d: output tensor is nil", step) - } - - outShape := output.Shape() - data := output.Data() - - // For a [1, seqLen, vocabSize] output, pick the last position's argmax. - var vocabSize int - var lastPosOffset int - switch len(outShape) { - case 3: - vocabSize = outShape[2] - lastPosOffset = (seqLen - 1) * vocabSize - case 2: - vocabSize = outShape[1] - lastPosOffset = 0 - default: - t.Fatalf("step %d: unexpected output rank %d", step, len(outShape)) - } - - if vocabSize == 0 { - t.Fatalf("step %d: vocabSize is 0", step) - } - - // Argmax over the last token position. - bestIdx := 0 - bestVal := data[lastPosOffset] - for j := 1; j < vocabSize; j++ { - if data[lastPosOffset+j] > bestVal { - bestVal = data[lastPosOffset+j] - bestIdx = j - } - } - - if bestIdx < 0 || bestIdx >= vocabSize { - t.Errorf("step %d: next token %d out of range [0, %d)", step, bestIdx, vocabSize) - } - t.Logf("step %d: next token = %d", step, bestIdx) - tokens = append(tokens, float32(bestIdx)) - } + layerreg.RegisterAll() + runModelGreedyDecode(t, gemma3Config) +} - if len(tokens) != 3+steps { - t.Errorf("expected %d tokens after decode, got %d", 3+steps, len(tokens)) - } +func TestGemma3Generation(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, gemma3Config) } diff --git a/tests/parity/helpers_test.go b/tests/parity/helpers_test.go new file mode 100644 index 00000000..3f63ebee --- /dev/null +++ b/tests/parity/helpers_test.go @@ -0,0 +1,338 @@ +package parity_test + +import ( + "context" + "math" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/zerfoo/zerfoo/compute" + "github.com/zerfoo/zerfoo/generate" + "github.com/zerfoo/zerfoo/graph" + "github.com/zerfoo/zerfoo/inference" + "github.com/zerfoo/zerfoo/model" + "github.com/zerfoo/zerfoo/numeric" + "github.com/zerfoo/zerfoo/registry" + "github.com/zerfoo/zerfoo/tensor" +) + +// dirRegistry is a mock ModelRegistry that maps model IDs to local directories. +type dirRegistry struct { + models map[string]*registry.ModelInfo +} + +func (r *dirRegistry) Get(modelID string) (*registry.ModelInfo, bool) { + info, ok := r.models[modelID] + return info, ok +} + +func (r *dirRegistry) Pull(_ context.Context, _ string) (*registry.ModelInfo, error) { + return nil, nil +} + +func (r *dirRegistry) List() []registry.ModelInfo { return nil } +func (r *dirRegistry) Delete(_ string) error { return nil } + +// loadZMFGraph loads a ZMF model and returns the computation graph. +func loadZMFGraph(t *testing.T, zmfPath string) *graph.Graph[float32] { + t.Helper() + eng := compute.NewCPUEngine[float32](numeric.Float32Ops{}) + m, err := model.LoadModelFromZMF[float32](eng, numeric.Float32Ops{}, zmfPath) + if err != nil { + t.Fatalf("LoadModelFromZMF failed: %v", err) + } + if m.Graph == nil { + t.Fatal("model graph is nil") + } + return m.Graph +} + +// forwardPassConfig holds parameters for a forward pass test. +type forwardPassConfig struct { + Name string + SeqLen int + MinVocabSize int +} + +// runForwardPassTest runs a single forward pass and validates the output. +func runForwardPassTest(t *testing.T, g *graph.Graph[float32], cfg forwardPassConfig) { + t.Helper() + + inputData := make([]float32, cfg.SeqLen) + for i := range inputData { + inputData[i] = float32(i + 1) + } + input, err := tensor.New[float32]([]int{1, cfg.SeqLen}, inputData) + if err != nil { + t.Fatalf("tensor.New failed: %v", err) + } + + output, err := g.Forward(context.Background(), input) + if err != nil { + t.Fatalf("Graph.Forward failed: %v", err) + } + if output == nil { + t.Fatal("output tensor is nil") + } + + outShape := output.Shape() + t.Logf("%s output shape: %v", cfg.Name, outShape) + + if len(outShape) < 2 { + t.Errorf("output rank %d < 2; want at least 2", len(outShape)) + } + if len(outShape) == 3 { + if outShape[0] != 1 { + t.Errorf("output batch dim = %d, want 1", outShape[0]) + } + if outShape[1] != cfg.SeqLen { + t.Errorf("output seq dim = %d, want %d", outShape[1], cfg.SeqLen) + } + if outShape[2] < cfg.MinVocabSize { + t.Errorf("output vocab dim = %d, want >= %d", outShape[2], cfg.MinVocabSize) + } + } + + data := output.Data() + for i, v := range data { + f := float64(v) + if math.IsNaN(f) { + t.Errorf("output[%d] is NaN", i) + break + } + if math.IsInf(f, 0) { + t.Errorf("output[%d] is Inf", i) + break + } + } +} + +// runGreedyDecodeTest runs N greedy decode steps from an initial token sequence. +func runGreedyDecodeTest(t *testing.T, g *graph.Graph[float32], initTokens []float32, steps int) { + t.Helper() + + tokens := append([]float32{}, initTokens...) + + for step := range steps { + seqLen := len(tokens) + input, err := tensor.New[float32]([]int{1, seqLen}, append([]float32{}, tokens...)) + if err != nil { + t.Fatalf("step %d: tensor.New failed: %v", step, err) + } + + output, err := g.Forward(context.Background(), input) + if err != nil { + t.Fatalf("step %d: Graph.Forward failed: %v", step, err) + } + if output == nil { + t.Fatalf("step %d: output tensor is nil", step) + } + + outShape := output.Shape() + data := output.Data() + + var vocabSize int + var lastPosOffset int + switch len(outShape) { + case 3: + vocabSize = outShape[2] + lastPosOffset = (seqLen - 1) * vocabSize + case 2: + vocabSize = outShape[1] + lastPosOffset = 0 + default: + t.Fatalf("step %d: unexpected output rank %d", step, len(outShape)) + } + + if vocabSize == 0 { + t.Fatalf("step %d: vocabSize is 0", step) + } + + bestIdx := 0 + bestVal := data[lastPosOffset] + for j := 1; j < vocabSize; j++ { + if data[lastPosOffset+j] > bestVal { + bestVal = data[lastPosOffset+j] + bestIdx = j + } + } + + if bestIdx < 0 || bestIdx >= vocabSize { + t.Errorf("step %d: next token %d out of range [0, %d)", step, bestIdx, vocabSize) + } + t.Logf("step %d: next token = %d", step, bestIdx) + tokens = append(tokens, float32(bestIdx)) + } + + expected := len(initTokens) + steps + if len(tokens) != expected { + t.Errorf("expected %d tokens after decode, got %d", expected, len(tokens)) + } +} + +// modelParityConfig describes a complete parity test suite for a model family. +type modelParityConfig struct { + // Name is a human-readable label (e.g. "Llama 3"). + Name string + // ZMFEnvVar is the environment variable for the .zmf file path. + ZMFEnvVar string + // ModelDirEnvVar is the environment variable for the model directory. + ModelDirEnvVar string + // ModelID is the ID used with inference.Load. + ModelID string + // MinVocabSize is the minimum expected vocabulary dimension. + MinVocabSize int +} + +// runModelForwardPass runs the forward pass test for a model family. +func runModelForwardPass(t *testing.T, cfg modelParityConfig) { + t.Helper() + zmfPath := envOrSkip(t, cfg.ZMFEnvVar) + g := loadZMFGraph(t, zmfPath) + runForwardPassTest(t, g, forwardPassConfig{ + Name: cfg.Name, + SeqLen: 8, + MinVocabSize: cfg.MinVocabSize, + }) +} + +// runModelGreedyDecode runs the greedy decode test for a model family. +func runModelGreedyDecode(t *testing.T, cfg modelParityConfig) { + t.Helper() + zmfPath := envOrSkip(t, cfg.ZMFEnvVar) + g := loadZMFGraph(t, zmfPath) + runGreedyDecodeTest(t, g, []float32{1, 2, 3}, 5) +} + +// runModelGeneration runs the generation test suite for a model family. +func runModelGeneration(t *testing.T, cfg modelParityConfig) { + t.Helper() + modelDir := modelDirOrSkip(t, cfg.ModelDirEnvVar, cfg.ZMFEnvVar) + runGenerationTests(t, generationTestConfig{ + ModelID: cfg.ModelID, + ModelDir: modelDir, + }) +} + +// envOrSkip returns the value of the named env var, or skips the test. +func envOrSkip(t *testing.T, key string) string { + t.Helper() + v := os.Getenv(key) + if v == "" { + t.Skipf("%s not set; skipping", key) + } + return v +} + +// modelDirOrSkip resolves a model directory from env vars, or skips the test. +func modelDirOrSkip(t *testing.T, dirEnvVar, zmfEnvVar string) string { + t.Helper() + if d := os.Getenv(dirEnvVar); d != "" { + return d + } + zmfPath := os.Getenv(zmfEnvVar) + if zmfPath == "" { + t.Skipf("%s and %s not set; skipping", dirEnvVar, zmfEnvVar) + } + return filepath.Dir(zmfPath) +} + +// generationTestConfig holds parameters for generation tests via inference API. +type generationTestConfig struct { + ModelID string + ModelDir string +} + +// runGenerationTests runs greedy, stream, and chat tests on an inference.Model. +func runGenerationTests(t *testing.T, cfg generationTestConfig) { + t.Helper() + + reg := &dirRegistry{ + models: map[string]*registry.ModelInfo{ + cfg.ModelID: {ID: cfg.ModelID, Path: cfg.ModelDir}, + }, + } + + mdl, err := inference.Load(cfg.ModelID, inference.WithRegistry(reg)) + if err != nil { + t.Fatalf("inference.Load failed: %v", err) + } + + ctx := context.Background() + + t.Run("greedy_deterministic", func(t *testing.T) { + prompt := "The capital of France is" + result1, err := mdl.Generate(ctx, prompt, + inference.WithTemperature(0), + inference.WithMaxTokens(20), + ) + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + if result1 == "" { + t.Fatal("greedy generation produced empty output") + } + t.Logf("greedy output: %q", result1) + + result2, err := mdl.Generate(ctx, prompt, + inference.WithTemperature(0), + inference.WithMaxTokens(20), + ) + if err != nil { + t.Fatalf("Generate (second) failed: %v", err) + } + if result1 != result2 { + t.Errorf("greedy outputs differ:\n run1: %q\n run2: %q", result1, result2) + } + }) + + t.Run("stream_parity", func(t *testing.T) { + prompt := "Hello world" + nonStream, err := mdl.Generate(ctx, prompt, + inference.WithTemperature(0), + inference.WithMaxTokens(10), + ) + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + + var sb strings.Builder + err = mdl.GenerateStream(ctx, prompt, + generate.TokenStreamFunc(func(token string, done bool) error { + if !done { + sb.WriteString(token) + } + return nil + }), + inference.WithTemperature(0), + inference.WithMaxTokens(10), + ) + if err != nil { + t.Fatalf("GenerateStream failed: %v", err) + } + + streamed := sb.String() + if nonStream != streamed { + t.Errorf("stream/non-stream mismatch:\n non-stream: %q\n stream: %q", + nonStream, streamed) + } + }) + + t.Run("chat", func(t *testing.T) { + resp, err := mdl.Chat(ctx, []inference.Message{ + {Role: "user", Content: "Say hello in French"}, + }, inference.WithMaxTokens(20)) + if err != nil { + t.Fatalf("Chat failed: %v", err) + } + if resp.Content == "" { + t.Error("Chat produced empty content") + } + if resp.TokensUsed <= 0 { + t.Error("TokensUsed should be positive") + } + t.Logf("chat response: %q (tokens: %d)", resp.Content, resp.TokensUsed) + }) +} diff --git a/tests/parity/llama3_test.go b/tests/parity/llama3_test.go new file mode 100644 index 00000000..79edb7a1 --- /dev/null +++ b/tests/parity/llama3_test.go @@ -0,0 +1,30 @@ +package parity_test + +import ( + "testing" + + layerreg "github.com/zerfoo/zerfoo/layers/registry" +) + +var llama3Config = modelParityConfig{ + Name: "Llama 3", + ZMFEnvVar: "LLAMA3_ZMF_PATH", + ModelDirEnvVar: "LLAMA3_MODEL_DIR", + ModelID: "llama-3", + MinVocabSize: 100000, // Llama 3 8B vocab: 128256 +} + +func TestLlama3ForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, llama3Config) +} + +func TestLlama3GreedyDecode(t *testing.T) { + layerreg.RegisterAll() + runModelGreedyDecode(t, llama3Config) +} + +func TestLlama3Generation(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, llama3Config) +} diff --git a/tests/parity/mistral_test.go b/tests/parity/mistral_test.go new file mode 100644 index 00000000..8d19d77f --- /dev/null +++ b/tests/parity/mistral_test.go @@ -0,0 +1,30 @@ +package parity_test + +import ( + "testing" + + layerreg "github.com/zerfoo/zerfoo/layers/registry" +) + +var mistralConfig = modelParityConfig{ + Name: "Mistral", + ZMFEnvVar: "MISTRAL_ZMF_PATH", + ModelDirEnvVar: "MISTRAL_MODEL_DIR", + ModelID: "mistral", + MinVocabSize: 30000, // Mistral 7B vocab: 32000 +} + +func TestMistralForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, mistralConfig) +} + +func TestMistralGreedyDecode(t *testing.T) { + layerreg.RegisterAll() + runModelGreedyDecode(t, mistralConfig) +} + +func TestMistralGeneration(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, mistralConfig) +} diff --git a/tests/parity/phi4_test.go b/tests/parity/phi4_test.go new file mode 100644 index 00000000..c9f5053c --- /dev/null +++ b/tests/parity/phi4_test.go @@ -0,0 +1,30 @@ +package parity_test + +import ( + "testing" + + layerreg "github.com/zerfoo/zerfoo/layers/registry" +) + +var phi4Config = modelParityConfig{ + Name: "Phi-4", + ZMFEnvVar: "PHI4_ZMF_PATH", + ModelDirEnvVar: "PHI4_MODEL_DIR", + ModelID: "phi-4", + MinVocabSize: 100000, // Phi-4 vocab: 100352 +} + +func TestPhi4ForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, phi4Config) +} + +func TestPhi4GreedyDecode(t *testing.T) { + layerreg.RegisterAll() + runModelGreedyDecode(t, phi4Config) +} + +func TestPhi4Generation(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, phi4Config) +} diff --git a/tests/parity/qwen_test.go b/tests/parity/qwen_test.go new file mode 100644 index 00000000..77d35865 --- /dev/null +++ b/tests/parity/qwen_test.go @@ -0,0 +1,30 @@ +package parity_test + +import ( + "testing" + + layerreg "github.com/zerfoo/zerfoo/layers/registry" +) + +var qwenConfig = modelParityConfig{ + Name: "Qwen 2.5", + ZMFEnvVar: "QWEN25_ZMF_PATH", + ModelDirEnvVar: "QWEN25_MODEL_DIR", + ModelID: "qwen-2.5", + MinVocabSize: 150000, // Qwen 2.5 vocab: 151936 +} + +func TestQwen25ForwardPass(t *testing.T) { + layerreg.RegisterAll() + runModelForwardPass(t, qwenConfig) +} + +func TestQwen25GreedyDecode(t *testing.T) { + layerreg.RegisterAll() + runModelGreedyDecode(t, qwenConfig) +} + +func TestQwen25Generation(t *testing.T) { + layerreg.RegisterAll() + runModelGeneration(t, qwenConfig) +} diff --git a/zerfoo.go b/zerfoo.go index 6d40345a..10392dda 100644 --- a/zerfoo.go +++ b/zerfoo.go @@ -72,8 +72,9 @@ func BuildFromZMF[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], m *zmf.Model, + opts ...model.BuildOption, ) (*graph.Graph[T], error) { - return model.BuildFromZMF[T](engine, ops, m) + return model.BuildFromZMF[T](engine, ops, m, opts...) } // RegisterLayer registers a new layer builder.