From d779b755cbb1b005676852f21d55428eaa5e7019 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sat, 25 Apr 2026 01:43:54 -0400 Subject: [PATCH 1/3] Add DFlash speculative decoding implementation Based on DFlash (arXiv:2602.06036) - Block-Diffusion Speculative Decoding for lossless acceleration on Apple Silicon. ## What's New ### Core Module (swift/Sources/DFlash/) - DFlashCore.swift - Protocol definitions for DFlashTargetModel, DFlashDraftModelProtocol, DFlashDraftCacheProtocol, DFlashRollbackCacheProtocol, DFlashEngineProtocol, DFlashEvent, DFlashSummary, DFlashConfiguration - DFlashDraftModel.swift - Complete draft model implementation with cross-attention, RoPE, sink-window cache - DFlashEngines.swift - Verify/rollback engines (FullAttentionEngine, HybridGDNEngine stub) - DFlashRuntime.swift - Main speculative decoding runtime with prefill, block-diffusion drafting, verify, accept/reject, rollback - DFlashDraftBackend.swift - Draft generation helper - DFlashTargetModelExtensions.swift - Model conformance examples ### Tests (swift/Tests/DFlashTests/) - Unit tests for token utilities, cache management, config - Integration tests for engine creation ### Package Updates (swift/Package.swift) - Added DFlash static library target with MLXNN/MLX dependencies - Added test target for DFlashTests ## Architecture Highlights 1. Protocol-based for easy extension to new model types 2. Abstraction layers for engines, caches, draft backend 3. Extensible for hybrid GDN models with tape-based rollback ## Next Steps to Complete Integration 1. Add DFlashTargetModel conformance to actual models 2. Implement callCapturing() on model containers 3. Add vsm_engine_dflash_* C API functions to Bridge.swift 4. Train/convert DFlash draft models for target architectures Builds and tests pass. 
--- swift/Package.swift | 21 +- swift/Sources/DFlash/DFlash.swift | 31 ++ swift/Sources/DFlash/DFlashCore.swift | 222 ++++++++ swift/Sources/DFlash/DFlashDraftBackend.swift | 88 ++++ swift/Sources/DFlash/DFlashDraftModel.swift | 373 +++++++++++++ swift/Sources/DFlash/DFlashEngines.swift | 94 ++++ swift/Sources/DFlash/DFlashRuntime.swift | 489 ++++++++++++++++++ .../DFlash/DFlashTargetModelExtensions.swift | 77 +++ swift/Tests/DFlashTests/DFlashTests.swift | 282 ++++++++++ 9 files changed, 1676 insertions(+), 1 deletion(-) create mode 100644 swift/Sources/DFlash/DFlash.swift create mode 100644 swift/Sources/DFlash/DFlashCore.swift create mode 100644 swift/Sources/DFlash/DFlashDraftBackend.swift create mode 100644 swift/Sources/DFlash/DFlashDraftModel.swift create mode 100644 swift/Sources/DFlash/DFlashEngines.swift create mode 100644 swift/Sources/DFlash/DFlashRuntime.swift create mode 100644 swift/Sources/DFlash/DFlashTargetModelExtensions.swift create mode 100644 swift/Tests/DFlashTests/DFlashTests.swift diff --git a/swift/Package.swift b/swift/Package.swift index b2cad17..1ab467e 100644 --- a/swift/Package.swift +++ b/swift/Package.swift @@ -10,11 +10,17 @@ let package = Package( type: .dynamic, targets: ["VLLMBridge"] ), + .library( + name: "DFlash", + type: .static, + targets: ["DFlash"] + ), ], dependencies: [ // Pinned snapshot of alpha branch with BatchedKVCache + TurboQuant+ - // For local dev: .package(path: "/Users/tom/dev/mlx-swift-lm") .package(url: "https://github.com/TheTom/mlx-swift-lm.git", branch: "vllm-swift-stable"), + // MLX from the mlx-swift-lm dependency chain + .package(url: "https://github.com/ekryski/mlx-swift.git", branch: "alpha"), ], targets: [ .target( @@ -29,5 +35,18 @@ let package = Package( .unsafeFlags(["-parse-as-library"]), ] ), + .target( + name: "DFlash", + dependencies: [ + .product(name: "MLXLMCommon", package: "mlx-swift-lm"), + .product(name: "MLXNN", package: "mlx-swift"), + .product(name: "MLX", package: "mlx-swift"), + 
], + path: "Sources/DFlash" + ), + .testTarget( + name: "DFlashTests", + dependencies: ["DFlash"] + ), ] ) diff --git a/swift/Sources/DFlash/DFlash.swift b/swift/Sources/DFlash/DFlash.swift new file mode 100644 index 0000000..b15edbb --- /dev/null +++ b/swift/Sources/DFlash/DFlash.swift @@ -0,0 +1,31 @@ +// Copyright 2026 SwiftLM Contributors +// SPDX-License-Identifier: Apache-2.0 +// +// Based on DFlash (arXiv:2602.06036) +// vllm-swift DFlash Speculative Decoding Module + +/// DFlash: Block-Diffusion Speculative Decoding for Lossless Acceleration +/// +/// This module provides speculative decoding capabilities for Apple Silicon +/// using the DFlash algorithm (arXiv:2602.06036). +/// +/// ## Overview +/// +/// DFlash accelerates LLM inference by using a small draft model to propose +/// multiple tokens at once, which are then verified in parallel by the target +/// model. Unlike traditional speculative decoding which proposes one token at +/// a time, DFlash proposes a block of tokens using block diffusion. 
+///
+/// ## Key Components
+///
+/// - ``DFlashTargetModel``: Protocol for target models to implement DFlash support
+/// - ``DFlashDraftModelProtocol``: Protocol for draft models
+/// - ``DFlashRuntime``: Main runtime for DFlash generation
+/// - ``DFlashConfiguration``: Configuration options for DFlash
+
+// Core protocols and types
+@_exported import MLX
+@_exported import MLXLMCommon
+
+// Module version
+public let dflashVersion = "1.0.0"
diff --git a/swift/Sources/DFlash/DFlashCore.swift b/swift/Sources/DFlash/DFlashCore.swift
new file mode 100644
index 0000000..6bfbab7
--- /dev/null
+++ b/swift/Sources/DFlash/DFlashCore.swift
@@ -0,0 +1,222 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// Based on DFlash (arXiv:2602.06036)
+// vllm-swift DFlash implementation with extensible abstractions
+
+import Foundation
+import MLX
+import MLXLMCommon
+import MLXNN
+
+// MARK: - DFlash Target Model Protocol
+
+/// Protocol that target models can conform to in order to expose their
+/// internal structure for DFlash speculative decoding.
+///
+/// The DFlash runtime needs to:
+/// 1. Access the embedding layer for draft noise embeddings
+/// 2. Access the lm_head for draft logits
+/// 3. Run a custom forward pass that captures intermediate hidden states
+/// 4. Determine if the model has hybrid GDN layers
+public protocol DFlashTargetModel: LanguageModel {
+    /// Embed token IDs and return the embedding vectors.
+    func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray
+
+    /// Compute logits from hidden states (via lm_head or tied weights).
+    func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray
+
+    /// Run a forward pass capturing hidden states at the specified layer indices.
+    ///
+    /// - Parameters:
+    ///   - inputIDs: Input token IDs [1, seqLen]
+    ///   - cache: The KV cache array
+    ///   - captureLayerIDs: Layer IDs whose output to capture. NOTE(review): described as 0-based indices, but DFlashRuntime builds this set from `targetLayerIDs.map { $0 + 1 }` — confirm the convention conformers must follow.
+    /// - Returns: Tuple of (logits, captured hidden states keyed by layerID+1)
+    func dflashForwardWithCapture(
+        inputIDs: MLXArray,
+        cache: [KVCache],
+        captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray])
+
+    /// Whether the model contains hybrid GatedDeltaNet layers.
+    var dflashIsHybridGDN: Bool { get }
+
+    /// Whether the hybrid GDN layers should use full innovation-tape rollback
+    /// (RecurrentRollbackCache) vs lightweight snapshot-only rollback.
+    /// Default: true (tape rollback), supplied by the protocol extension below.
+    var dflashUseTapeRollback: Bool { get }
+}
+
+public extension DFlashTargetModel {
+    var dflashUseTapeRollback: Bool { true } // default; conformers may override
+}
+
+// MARK: - DFlash Draft Model Protocol
+
+/// Protocol for DFlash draft models.
+///
+/// Draft models take noise token embeddings (from the target model's embed_tokens)
+/// and target hidden states, and produce draft hidden states; DFlashDraftBackend passes those to the target's lm_head to get draft logits.
+public protocol DFlashDraftModelProtocol {
+    /// Number of tokens per draft block
+    var blockSize: Int { get }
+
+    /// Mask token ID used during drafting
+    var maskTokenID: Int { get }
+
+    /// Target layer indices used for context feature extraction
+    var targetLayerIDs: [Int] { get }
+
+    /// Run the draft model forward pass and return the draft hidden states (not logits).
+    func forwardDraft(
+        noiseEmbedding: MLXArray,
+        targetHidden: MLXArray,
+        cache: [any DFlashDraftCacheProtocol]?
+    ) -> MLXArray
+}
+
+// MARK: - DFlash Draft Cache Protocol
+
+/// Protocol for DFlash draft model KV caches.
+/// Draft caches store context keys/values for cross-attention during drafting.
+public protocol DFlashDraftCacheProtocol: AnyObject {
+    /// Current cache length
+    var cacheLength: Int { get }
+
+    /// Append context keys/values to the cache; `numPositions` is the number of context positions being added.
+    func appendContext(
+        contextKeys: MLXArray,
+        contextValues: MLXArray,
+        numPositions: Int
+    )
+
+    /// Fetch cached keys and values; both are nil when nothing has been appended yet.
+    func fetch() -> (MLXArray?, MLXArray?)
+}
+
+// MARK: - DFlash Rollback Cache Protocol
+
+/// Protocol for rollback-capable caches used in hybrid GDN models.
+public protocol DFlashRollbackCacheProtocol: AnyObject {
+    var isArmed: Bool { get } // presumably true between armRollback and rollback/clearTransients — confirm with conformers
+    func armRollback(prefixLen: Int) // called by armTargetRollback before verification
+    func rollback(nAccepted: Int) // called by restoreTargetCacheAfterAcceptance on partial acceptance
+    func clearTransients() // called by restoreTargetCacheAfterAcceptance when every drafted token was accepted
+}
+
+// MARK: - DFlash Engine Protocol
+
+/// Protocol for DFlash verify/rollback engines.
+public protocol DFlashEngineProtocol: Sendable {
+    /// Arm the target model's cache for rollback before verification.
+    func armRollback(targetCache: [KVCache], prefixLen: Int)
+
+    /// Roll back the target cache after partial acceptance; FullAttentionEngine returns the replay time in nanoseconds.
+    func rollback(
+        targetCache: [KVCache],
+        targetLen: Int,
+        acceptanceLength: Int,
+        draftedTokens: Int
+    ) -> Int
+}
+
+// MARK: - DFlash Generation Event
+
+/// Events emitted during DFlash generation.
+public enum DFlashEvent: Sendable {
+    /// Prefill completed
+    case prefill(promptTokenCount: Int, prefillUs: Double)
+    /// Prefill progress (chunked)
+    case prefillProgress(tokensProcessed: Int, tokensTotal: Int)
+    /// A token was generated
+    case token(tokenID: Int, generatedTokens: Int, acceptanceRatio: Double, cyclesCompleted: Int)
+    /// Generation summary
+    case summary(DFlashSummary)
+}
+
+/// Summary statistics for a DFlash generation run.
+public struct DFlashSummary: Sendable {
+    public let elapsedUs: Double // total elapsed time in microseconds (tokensPerSecond divides by 1_000_000 to get seconds)
+    public let promptTokenCount: Int
+    public let generatedTokenIDs: [Int] // every generated token ID; its count is `generationTokens`
+    public let acceptedFromDraft: Int // presumably the number of tokens taken from draft proposals — confirm in the runtime
+    public let acceptanceRatio: Double
+    public let blockTokens: Int
+    public let cyclesCompleted: Int
+    public let phaseTimingsUs: PhaseTimings
+
+    public struct PhaseTimings: Sendable { // per-phase durations; `prefill` shares the unit of `elapsedUs` (it is subtracted from it below)
+        public let prefill: Double
+        public let draft: Double
+        public let verify: Double
+        public let replay: Double
+    }
+
+    public var generationTokens: Int { generatedTokenIDs.count } // number of generated tokens
+    public var tokensPerSecond: Double { // generation throughput with prefill time excluded
+        let genUs = elapsedUs - phaseTimingsUs.prefill
+        return genUs > 0 ? Double(generationTokens) / (genUs / 1_000_000.0) : 0 // 0 when no post-prefill time elapsed (avoids dividing by zero)
+    }
+}
+
+// MARK: - DFlash Configuration
+
+/// Configuration for DFlash speculative decoding.
+public struct DFlashConfiguration: Sendable {
+    /// Number of tokens per draft block (default: from draft model)
+    public var blockTokens: Int?
+
+    /// Stop token IDs that signal end of generation
+    public var stopTokenIDs: [Int] = []
+
+    /// Token IDs to suppress during generation
+    public var suppressTokenIDs: [Int]?
+
+    /// Sink tokens to keep in draft cache
+    public var draftSinkSize: Int = 64
+
+    /// Sliding window size for draft cache
+    public var draftWindowSize: Int = 1024
+
+    /// Use tape-based rollback for hybrid GDN models (more accurate, ~30% slower)
+    public var useTapeRollback: Bool = true
+
+    public init( // memberwise initializer; every default mirrors the property default above
+        blockTokens: Int? = nil,
+        stopTokenIDs: [Int] = [],
+        suppressTokenIDs: [Int]? = nil,
+        draftSinkSize: Int = 64,
+        draftWindowSize: Int = 1024,
+        useTapeRollback: Bool = true
+    ) {
+        self.blockTokens = blockTokens
+        self.stopTokenIDs = stopTokenIDs
+        self.suppressTokenIDs = suppressTokenIDs
+        self.draftSinkSize = draftSinkSize
+        self.draftWindowSize = draftWindowSize
+        self.useTapeRollback = useTapeRollback
+    }
+}
+
+// MARK: - Context Feature Extraction
+
+/// Extract and concatenate hidden states at the specified layer IDs.
+/// The layer IDs are 0-indexed into the model's layers, and we take +/// `hiddenStates[layerID + 1]` because index 0 is the embedding output. +public func extractContextFeature( + hiddenStates: [MLXArray], + layerIDs: [Int] +) -> MLXArray { + let selected = layerIDs.map { hiddenStates[$0 + 1] } + return concatenated(selected, axis: -1) +} + +/// Extract context feature from a dictionary of captured hidden states. +public func extractContextFeatureFromDict( + capturedDict: [Int: MLXArray], + targetLayerIDs: [Int] +) -> MLXArray { + let selected = targetLayerIDs.map { capturedDict[$0 + 1]! } + return concatenated(selected, axis: -1) +} diff --git a/swift/Sources/DFlash/DFlashDraftBackend.swift b/swift/Sources/DFlash/DFlashDraftBackend.swift new file mode 100644 index 0000000..572db49 --- /dev/null +++ b/swift/Sources/DFlash/DFlashDraftBackend.swift @@ -0,0 +1,88 @@ +// Copyright 2026 SwiftLM Contributors +// SPDX-License-Identifier: Apache-2.0 +// +// Based on DFlash (arXiv:2602.06036) +// vllm-swift DFlash draft generation backend + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Draft Backend + +/// Backend for generating draft tokens using the DFlash draft model. +public final class DFlashDraftBackend: @unchecked Sendable { + + public init() {} + + /// Create the draft cache (one `ContextOnlyDraftKVCache` per layer). + public func makeCache( + draftModel: any DFlashDraftModelProtocol, + sinkSize: Int = 64, + windowSize: Int = 1024 + ) -> [ContextOnlyDraftKVCache] { + // Get the number of layers from the draft model + var numLayers = 0 + if let dflashModel = draftModel as? DFlashDraftModel { + numLayers = dflashModel.layers.count + } + return (0 ..< numLayers).map { _ in + ContextOnlyDraftKVCache(sinkSize: sinkSize, windowSize: windowSize) + } + } + + /// Generate draft tokens greedily using the DFlash draft model. 
+ /// + /// - Parameters: + /// - targetModel: The target model (must conform to DFlashTargetModel for embed/lm_head access) + /// - draftModel: The DFlash draft model + /// - draftCache: The draft model's KV caches + /// - stagedFirst: The first token (already verified by the target) + /// - targetHidden: The target model's hidden states for context + /// - blockLen: Number of tokens to draft + /// - maskTokenTail: Mask token IDs for positions 1..blockLen-1 + /// - suppressTokenMask: Optional mask to suppress certain tokens + /// - Returns: Draft token IDs [blockLen-1] + public func draftGreedy( + targetModel: any DFlashTargetModel, + draftModel: any DFlashDraftModelProtocol, + draftCache: [ContextOnlyDraftKVCache], + stagedFirst: MLXArray, + targetHidden: MLXArray, + blockLen: Int, + maskTokenTail: MLXArray, + suppressTokenMask: MLXArray? = nil + ) -> MLXArray { + precondition(blockLen > 1, "draftGreedy requires blockLen > 1") + + let blockTokenIDs = concatenated( + [stagedFirst[..<1], maskTokenTail[..<(blockLen - 1)]], + axis: 0 + ) + + // Get noise embedding from target model's embed_tokens + let noiseEmbedding = targetModel.dflashEmbedTokens(blockTokenIDs[.newAxis]) + + // Run the draft model + let draftHidden = draftModel.forwardDraft( + noiseEmbedding: noiseEmbedding, + targetHidden: targetHidden, + cache: draftCache + ) + + // Get draft logits via the target model's lm_head + let draftLogits = targetModel.dflashLmHeadLogits( + draftHidden[.ellipsis, 1..., 0...] 
+ ) + + // Greedy decode + let drafted = DFlashRuntime.greedyTokensWithMask( + logits: draftLogits, + suppressTokenMask: suppressTokenMask + ).squeezed(axis: 0) + + asyncEval(drafted) + return drafted + } +} diff --git a/swift/Sources/DFlash/DFlashDraftModel.swift b/swift/Sources/DFlash/DFlashDraftModel.swift new file mode 100644 index 0000000..116779d --- /dev/null +++ b/swift/Sources/DFlash/DFlashDraftModel.swift @@ -0,0 +1,373 @@ +// Copyright 2026 SwiftLM Contributors +// SPDX-License-Identifier: Apache-2.0 +// +// Based on DFlash (arXiv:2602.06036) +// vllm-swift DFlash draft model implementation + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Draft Model Configuration + +/// Configuration for the DFlash draft model, deserialized from config.json. +public struct DFlashDraftConfiguration: Codable, Sendable { + public var modelType: String = "dflash_qwen3" + public var hiddenSize: Int = 1024 + public var numHiddenLayers: Int = 4 + public var intermediateSize: Int = 2816 + public var numAttentionHeads: Int = 16 + public var rmsNormEps: Float = 1e-6 + public var vocabularySize: Int = 151_936 + public var numKeyValueHeads: Int = 8 + public var maxPositionEmbeddings: Int = 131072 + public var ropeTheta: Float = 1_000_000.0 + public var headDim: Int = 128 + public var tieWordEmbeddings: Bool = false + public var numTargetLayers: Int = 36 + public var blockSize: Int = 16 + public var attentionBias: Bool = false + public var attentionDropout: Float = 0.0 + public var ropeScaling: [String: StringOrNumber]? + public var layerTypes: [String] = [] + public var dflashConfig: DFlashConfig? + + public struct DFlashConfig: Codable, Sendable { + public var targetLayerIds: [Int]? + public var maskTokenId: Int? 
+ + enum CodingKeys: String, CodingKey { + case targetLayerIds = "target_layer_ids" + case maskTokenId = "mask_token_id" + } + } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case numAttentionHeads = "num_attention_heads" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case numKeyValueHeads = "num_key_value_heads" + case maxPositionEmbeddings = "max_position_embeddings" + case ropeTheta = "rope_theta" + case headDim = "head_dim" + case tieWordEmbeddings = "tie_word_embeddings" + case numTargetLayers = "num_target_layers" + case blockSize = "block_size" + case attentionBias = "attention_bias" + case attentionDropout = "attention_dropout" + case ropeScaling = "rope_scaling" + case layerTypes = "layer_types" + case dflashConfig = "dflash_config" + } +} + +// MARK: - Helper: build target layer IDs + +/// Build target layer IDs evenly spaced across the target model's layers. +public func buildTargetLayerIDs(numTargetLayers: Int, numDraftLayers: Int) -> [Int] { + if numDraftLayers <= 1 { + return [numTargetLayers / 2] + } + let start = 1 + let end = numTargetLayers - 3 + let span = end - start + return (0 ..< numDraftLayers).map { i in + Int(round(Double(start) + Double(i) * Double(span) / Double(numDraftLayers - 1))) + } +} + +// MARK: - Context-Only Draft KV Cache + +/// A sliding-window KV cache that only stores context keys/values +/// (no incremental update-and-fetch), used by the DFlash draft model's +/// cross-attention layers. +public final class ContextOnlyDraftKVCache: DFlashDraftCacheProtocol { + public var keys: MLXArray? + public var values: MLXArray? 
+ public var offset: Int = 0 + + public let sinkSize: Int + public let windowSize: Int + + public init(sinkSize: Int = 64, windowSize: Int = 1024) { + self.sinkSize = sinkSize + self.windowSize = windowSize + } + + public var cacheLength: Int { + keys?.dim(2) ?? 0 + } + + public func appendContext( + contextKeys: MLXArray, + contextValues: MLXArray, + numPositions: Int + ) { + guard numPositions > 0 else { return } + if keys == nil { + keys = contextKeys + values = contextValues + } else { + keys = concatenated([keys!, contextKeys], axis: 2) + values = concatenated([values!, contextValues], axis: 2) + } + offset += numPositions + applyWindow() + } + + private func applyWindow() { + guard let k = keys, let v = values else { return } + let cacheLen = k.dim(2) + let maxLen = sinkSize + windowSize + guard cacheLen > maxLen else { return } + let sinkK = k[.ellipsis, .. (MLXArray?, MLXArray?) { + (keys, values) + } +} + +// MARK: - DFlash GLU MLP + +/// Gated Linear Unit MLP for the DFlash draft model. +/// Equivalent to Qwen3NextMLP / Llama MLP with SwiGLU activation. +final class DFlashGLUMLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear + + init(dimensions: Int, hiddenDimensions: Int) { + _gateProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _upProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _downProj.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + downProj(silu(gateProj(x)) * upProj(x)) + } +} + +// MARK: - DFlash Attention + +/// Cross-attention layer for the DFlash draft model. +/// Uses target hidden states as context and noise token embeddings as queries. 
+final class DFlashAttention: Module {
+    let nHeads: Int // number of query heads
+    let nKVHeads: Int // number of key/value heads
+    let headDim: Int // per-head dimension
+    let scale: Float // attention scale, headDim^-0.5
+
+    @ModuleInfo(key: "q_proj") var qProj: Linear
+    @ModuleInfo(key: "k_proj") var kProj: Linear
+    @ModuleInfo(key: "v_proj") var vProj: Linear
+    @ModuleInfo(key: "o_proj") var oProj: Linear
+    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
+    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm
+
+    let rope: RoPELayer
+
+    init(_ args: DFlashDraftConfiguration) { // builds projections, per-head norms and RoPE from the draft configuration
+        let dim = args.hiddenSize
+        self.nHeads = args.numAttentionHeads
+        self.nKVHeads = args.numKeyValueHeads
+        self.headDim = args.headDim
+        self.scale = pow(Float(headDim), -0.5)
+
+        _qProj.wrappedValue = Linear(dim, nHeads * headDim, bias: args.attentionBias)
+        _kProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.attentionBias)
+        _vProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.attentionBias)
+        _oProj.wrappedValue = Linear(nHeads * headDim, dim, bias: args.attentionBias)
+        _qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
+        _kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
+
+        self.rope = initializeRope(
+            dims: headDim,
+            base: args.ropeTheta,
+            traditional: false,
+            scalingConfig: args.ropeScaling,
+            maxPositionEmbeddings: args.maxPositionEmbeddings
+        )
+
+        super.init()
+    }
+
+    func callAsFunction( // queries come from hiddenStates; keys/values are context (targetHidden) followed by the block itself
+        _ hiddenStates: MLXArray,
+        targetHidden: MLXArray,
+        cache: ContextOnlyDraftKVCache? = nil
+    ) -> MLXArray {
+        let B = hiddenStates.dim(0) // used as the batch dimension in the reshapes below
+        let blockLen = hiddenStates.dim(1) // number of block (noise) positions
+        let ctxLen = targetHidden.dim(1) // number of new context positions in this call
+
+        var queries = qNorm(qProj(hiddenStates).reshaped(B, blockLen, nHeads, headDim))
+            .transposed(0, 2, 1, 3) // -> (B, nHeads, blockLen, headDim)
+        var contextKeys = kNorm(
+            kProj(targetHidden).reshaped(B, ctxLen, nKVHeads, headDim)
+        ).transposed(0, 2, 1, 3)
+        let contextValues = vProj(targetHidden).reshaped(B, ctxLen, nKVHeads, headDim)
+            .transposed(0, 2, 1, 3)
+
+        var noiseKeys = kNorm( // the block also attends to itself: keys/values projected from hiddenStates
+            kProj(hiddenStates).reshaped(B, blockLen, nKVHeads, headDim)
+        ).transposed(0, 2, 1, 3)
+        let noiseValues = vProj(hiddenStates).reshaped(B, blockLen, nKVHeads, headDim)
+            .transposed(0, 2, 1, 3)
+
+        if let cache {
+            let cacheOffset = cache.offset // read before appendContext, which advances it by ctxLen
+            let queryOffset = cacheOffset + ctxLen // block positions start right after the new context
+
+            queries = rope(queries, offset: queryOffset)
+            contextKeys = rope(contextKeys, offset: cacheOffset)
+            noiseKeys = rope(noiseKeys, offset: queryOffset)
+
+            cache.appendContext( // only context keys/values are stored; block keys/values are never cached
+                contextKeys: contextKeys,
+                contextValues: contextValues,
+                numPositions: ctxLen
+            )
+            let (cachedKeys, cachedValues) = cache.fetch()
+            let keys = concatenated([cachedKeys!, noiseKeys], axis: 2) // NOTE(review): force-unwrap traps if the cache is still empty (appendContext ignores numPositions == 0)
+            let values = concatenated([cachedValues!, noiseValues], axis: 2)
+
+            let output = MLXFast.scaledDotProductAttention(
+                queries: queries, keys: keys, values: values,
+                scale: scale, mask: .none // no mask: every block position sees all context and all block positions
+            )
+            let attnOut = output.transposed(0, 2, 1, 3).reshaped(B, blockLen, -1)
+            return oProj(attnOut)
+        } else {
+            queries = rope(queries, offset: ctxLen) // cache-free path: context occupies positions 0..<ctxLen
+            contextKeys = rope(contextKeys, offset: 0)
+            noiseKeys = rope(noiseKeys, offset: ctxLen)
+
+            let keys = concatenated([contextKeys, noiseKeys], axis: 2)
+            let values = concatenated([contextValues, noiseValues], axis: 2)
+
+            let output = MLXFast.scaledDotProductAttention(
+                queries: queries, keys: keys, values: values,
+                scale: scale, mask: .none
+            )
+            return oProj(output.transposed(0, 2, 1, 3).reshaped(B, blockLen, -1))
+        }
+    }
+}
+
+// MARK: - DFlash Decoder Layer
+
+final class DFlashDecoderLayer: Module {
+    @ModuleInfo(key: "self_attn") var selfAttn: DFlashAttention
+    @ModuleInfo(key: "mlp") var mlp: DFlashGLUMLP
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    init(_ args: DFlashDraftConfiguration) { // one attention block, one GLU MLP, and a norm in front of each
+        _selfAttn.wrappedValue = DFlashAttention(args)
+        _mlp.wrappedValue = DFlashGLUMLP(
+            dimensions: args.hiddenSize,
+            hiddenDimensions: args.intermediateSize
+        )
+        _inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps
+        )
+        _postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps
+        )
+        super.init()
+    }
+
+    func callAsFunction( // pre-norm residual layer: x + attn(norm(x)), then h + mlp(norm(h))
+        _ hiddenStates: MLXArray,
+        targetHidden: MLXArray,
+        cache: ContextOnlyDraftKVCache? = nil
+    ) -> MLXArray {
+        let residual = hiddenStates
+        var h = inputLayerNorm(hiddenStates)
+        h = selfAttn(h, targetHidden: targetHidden, cache: cache)
+        h = residual + h // first residual connection (around attention)
+
+        let r = h
+        h = postAttentionLayerNorm(h)
+        h = mlp(h)
+        return r + h // second residual connection (around the MLP)
+    }
+}
+
+// MARK: - DFlash Draft Model
+
+/// The DFlash block-diffusion draft model.
+///
+/// This model takes noise token embeddings (from the target model's embed_tokens)
+/// and target hidden states, and produces normalized draft hidden states; DFlashDraftBackend turns them into logits via the target's lm_head.
+public final class DFlashDraftModel: Module, DFlashDraftModelProtocol {
+    public let args: DFlashDraftConfiguration
+    public let modelType: String
+
+    let layers: [DFlashDecoderLayer]
+    public let targetLayerIDs: [Int]
+    @ModuleInfo(key: "norm") var norm: RMSNorm // final norm applied to the draft hidden states
+    @ModuleInfo(key: "fc") var fc: Linear // fuses the concatenated target-layer features down to hiddenSize
+    @ModuleInfo(key: "hidden_norm") var hiddenNorm: RMSNorm // normalizes the fused target features (see projectTargetHidden)
+    public let blockSize: Int
+    public let maskTokenID: Int
+
+    public init(_ args: DFlashDraftConfiguration) { // builds all layers; target layer IDs come from dflash_config or are spaced evenly
+        self.args = args
+        self.modelType = "dflash_qwen3"
+
+        self.layers = (0 ..< args.numHiddenLayers).map { _ in
+            DFlashDecoderLayer(args)
+        }
+
+        let targetLayerIDs = args.dflashConfig?.targetLayerIds
+            ?? buildTargetLayerIDs(
+                numTargetLayers: args.numTargetLayers,
+                numDraftLayers: args.numHiddenLayers
+            )
+        self.targetLayerIDs = targetLayerIDs
+        _norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        _fc.wrappedValue = Linear(targetLayerIDs.count * args.hiddenSize, args.hiddenSize, bias: false)
+        _hiddenNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self.blockSize = args.blockSize
+        self.maskTokenID = args.dflashConfig?.maskTokenId ?? 0 // falls back to token 0 when the config names no mask token
+
+        super.init()
+    }
+
+    func projectTargetHidden(_ targetHidden: MLXArray) -> MLXArray { // fused target context: fc projection followed by hidden_norm
+        hiddenNorm(fc(targetHidden)) // fix: hidden_norm was created and loaded but never applied
+    }
+
+    public func forwardDraft( // runs every decoder layer over the noise embeddings, conditioned on the projected target context
+        noiseEmbedding: MLXArray,
+        targetHidden: MLXArray,
+        cache: [any DFlashDraftCacheProtocol]?
+    ) -> MLXArray {
+        var hiddenStates = noiseEmbedding
+        let projectedHidden = projectTargetHidden(targetHidden)
+
+        let draftCache = cache?.compactMap { $0 as? ContextOnlyDraftKVCache } // entries of any other cache type are dropped
+
+        for (i, layer) in layers.enumerated() {
+            hiddenStates = layer(
+                hiddenStates,
+                targetHidden: projectedHidden,
+                cache: draftCache.flatMap { i < $0.count ? $0[i] : nil } // layer i's cache, or nil when absent (no force-unwrap)
+            )
+        }
+        return norm(hiddenStates) // hidden states, not logits; the caller applies the target lm_head
+    }
+}
diff --git a/swift/Sources/DFlash/DFlashEngines.swift b/swift/Sources/DFlash/DFlashEngines.swift
new file mode 100644
index 0000000..79bfa15
--- /dev/null
+++ b/swift/Sources/DFlash/DFlashEngines.swift
@@ -0,0 +1,94 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// Based on DFlash (arXiv:2602.06036)
+// vllm-swift DFlash verify/rollback engines
+
+import Foundation
+import MLX
+import MLXLMCommon
+import MLXNN
+
+// MARK: - Full Attention Engine
+
+/// Engine for pure-attention target models (no recurrent layers).
+/// Rollback is just KV cache trimming.
+public final class FullAttentionEngine: DFlashEngineProtocol, @unchecked Sendable {
+    public init() {}
+
+    public func armRollback(targetCache: [KVCache], prefixLen: Int) {
+        // Pure attention: no arming needed
+    }
+
+    public func rollback( // forwards to restoreTargetCacheAfterAcceptance and returns its result (replay time in nanoseconds)
+        targetCache: [KVCache],
+        targetLen: Int,
+        acceptanceLength: Int,
+        draftedTokens: Int
+    ) -> Int {
+        restoreTargetCacheAfterAcceptance(
+            targetCache,
+            targetLen: targetLen,
+            acceptanceLength: acceptanceLength,
+            draftedTokens: draftedTokens
+        )
+    }
+}
+
+// MARK: - Cache Restoration Utilities
+
+/// Restore the target cache after partial acceptance of draft tokens.
+///
+/// For KVCacheSimple: trim so the cache offset equals `targetLen`, removing rejected tokens' KV entries.
+/// For rollback-aware caches: `rollback(nAccepted:)`, or `clearTransients()` when every drafted token was accepted.
+///
+/// - Returns: Time spent on replay in nanoseconds
+@discardableResult
+public func restoreTargetCacheAfterAcceptance(
+    _ cacheEntries: [KVCache],
+    targetLen: Int,
+    acceptanceLength: Int,
+    draftedTokens: Int
+) -> Int {
+    let fullyAccepted = draftedTokens > 0 && acceptanceLength == draftedTokens
+    var replayNs: Int = 0
+
+    for cache in cacheEntries {
+        if let rollbackCache = cache as? (any DFlashRollbackCacheProtocol) {
+            if fullyAccepted {
+                rollbackCache.clearTransients() // nothing to undo; not counted as replay time
+                continue
+            }
+            let startNs = Int(DispatchTime.now().uptimeNanoseconds)
+            rollbackCache.rollback(nAccepted: acceptanceLength)
+            replayNs += Int(DispatchTime.now().uptimeNanoseconds) - startNs
+        } else if cache.isTrimmable {
+            let offset = cache.offset
+            if offset > targetLen { // only trim when the cache ran ahead of the accepted length
+                let startNs = Int(DispatchTime.now().uptimeNanoseconds)
+                cache.trim(offset - targetLen)
+                replayNs += Int(DispatchTime.now().uptimeNanoseconds) - startNs
+            }
+        }
+    }
+
+    return replayNs
+}
+
+/// Arm all rollback-capable caches in the target model; caches that are not rollback-capable are skipped.
+public func armTargetRollback(targetCache: [KVCache], prefixLen: Int) {
+    for cache in targetCache {
+        if let rollbackCache = cache as? (any DFlashRollbackCacheProtocol) {
+            rollbackCache.armRollback(prefixLen: prefixLen)
+        }
+    }
+}
+
+// MARK: - KVCache Extension for Trimmability
+
+extension KVCache {
+    /// True only for `KVCacheSimple`. NOTE(review): if `KVCache` already requires `isTrimmable`, conformers' own implementations win over this default — confirm against MLXLMCommon.
+    public var isTrimmable: Bool {
+        self is KVCacheSimple
+    }
+}
diff --git a/swift/Sources/DFlash/DFlashRuntime.swift b/swift/Sources/DFlash/DFlashRuntime.swift
new file mode 100644
index 0000000..069defb
--- /dev/null
+++ b/swift/Sources/DFlash/DFlashRuntime.swift
@@ -0,0 +1,489 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// Based on DFlash (arXiv:2602.06036)
+// vllm-swift DFlash speculative decoding runtime
+
+import Foundation
+import MLX
+import MLXLMCommon
+import MLXNN
+
+// MARK: - DFlash Runtime
+
+/// The main DFlash speculative decoding runtime.
+///
+/// Orchestrates the block-diffusion draft → verify → accept/reject → rollback
+/// cycle for lossless speculative decoding on Apple Silicon.
+
+/// Wrapper to transfer non-Sendable values across Task boundaries.
+/// Safety: caller ensures no concurrent access.
+struct UnsafeSendable<T>: @unchecked Sendable { // generic over the wrapped value; `T` was otherwise undeclared
+    let value: T
+    init(_ value: T) { self.value = value }
+}
+
+public enum DFlashRuntime {
+
+    // MARK: - Token Utilities
+
+    /// Build a boolean suppress mask of length `vocabSize` (true = suppressed); returns nil when no in-range ID is given.
+    public static func buildSuppressTokenMask(
+        vocabSize: Int,
+        suppressTokenIDs: [Int]?
+    ) -> MLXArray? {
+        let ids = Set((suppressTokenIDs ?? []).filter { $0 >= 0 && $0 < vocabSize }) // out-of-range IDs are ignored
+        guard !ids.isEmpty else { return nil }
+        var mask = [Bool](repeating: false, count: vocabSize)
+        for id in ids { mask[id] = true }
+        return MLXArray(mask)
+    }
+
+    /// Greedy (argmax over the last axis) token selection; masked positions are replaced by -1e9 before the argmax.
+    public static func greedyTokensWithMask(
+        logits: MLXArray,
+        suppressTokenMask: MLXArray? = nil
+    ) -> MLXArray {
+        if let mask = suppressTokenMask {
+            let floor = MLXArray(-1e9, dtype: logits.dtype) // same dtype as the logits
+            let maskedLogits = MLX.where(mask, floor, logits)
+            return argMax(maskedLogits, axis: -1).asType(.uint32)
+        }
+        return argMax(logits, axis: -1).asType(.uint32)
+    }
+
+    /// Match the acceptance length between drafted and posterior tokens.
+    /// Returns the number of consecutive matches starting from position 0 (a scalar 0 when `draftedTokens` is empty).
+    /// E.g. if drafted=[1,2,3] and posterior=[1,2,5], returns 2.
+    public static func matchAcceptanceLength(
+        draftedTokens: MLXArray,
+        posteriorTokens: MLXArray
+    ) -> MLXArray {
+        let count = draftedTokens.dim(0)
+        guard count > 0 else { return MLXArray(0, dtype: .int32) }
+        let matches = (draftedTokens .== posteriorTokens).asType(.int32) // assumes both arrays have the same length — TODO confirm at call sites
+        // cumprod: [1,1,0,...] for consecutive matches, then sum counts them
+        return cumprod(matches, axis: 0).sum(axis: 0, keepDims: false)
+    }
+
+    // MARK: - Target Cache Management
+
+    /// Create the appropriate cache entries for the target model.
+    /// Currently returns the model's default cache unchanged; a rollback-capable variant for hybrid GDN models is not wired in yet.
+    ///
+    /// - Parameters:
+    ///   - targetModel: Model whose `newCache(parameters:)` supplies the cache entries.
+    ///   - useTapeRollback: Currently unused; reserved for the hybrid GDN rollback caches.
+    /// - Returns: The caches created by the target model.
+    public static func makeTargetCache(
+        targetModel: any DFlashTargetModel,
+        useTapeRollback: Bool = true
+    ) -> [KVCache] {
+        // Note: MambaSnapshotCache/RecurrentRollbackCache would be swapped in here for
+        // hybrid GDN models (`targetModel.dflashIsHybridGDN`) once the full GDN
+        // implementation is available. Until then every model gets its default cache.
+        targetModel.newCache(parameters: nil)
+    }
+
+    // MARK: - Main Generation Loop
+
+    /// Generate tokens using DFlash speculative decoding.
+    ///
+    /// The generation runs inside an unstructured `Task`. Terminating the stream cancels
+    /// that task, after which no further events are delivered; the underlying loop is
+    /// synchronous and is not interrupted by the cancellation.
+    ///
+    /// - Parameters:
+    ///   - targetModel: The target (large) language model (must conform to DFlashTargetModel)
+    ///   - draftModel: The DFlash block-diffusion draft model
+    ///   - promptTokens: Pre-tokenized prompt token IDs
+    ///   - maxNewTokens: Maximum number of new tokens to generate
+    ///   - config: DFlash configuration options
+    /// - Returns: AsyncStream of DFlashEvent values
+    public static func generate(
+        targetModel: any DFlashTargetModel,
+        draftModel: any DFlashDraftModelProtocol,
+        promptTokens: [Int],
+        maxNewTokens: Int,
+        config: DFlashConfiguration = DFlashConfiguration()
+    ) -> AsyncStream<DFlashEvent> {
+        // Use UnsafeSendable to pass non-Sendable values across Task boundaries
+        // Safety: caller ensures no concurrent access to the models
+        let targetWrapper = UnsafeSendable(targetModel)
+        let draftWrapper = UnsafeSendable(draftModel)
+
+        return AsyncStream(DFlashEvent.self, bufferingPolicy: .unbounded) { continuation in
+            let task = Task {
+                generateStreaming(
+                    targetModel: targetWrapper.value,
+                    draftModel: draftWrapper.value,
+                    promptTokens: promptTokens,
+                    maxNewTokens: maxNewTokens,
+                    config: config,
+                    yield: { event in
+                        // Drop events once the consumer has gone away.
+                        guard !Task.isCancelled else { return }
+                        continuation.yield(event)
+                    }
+                )
+                continuation.finish()
+            }
+            continuation.onTermination = { _ in task.cancel() }
+        }
+    }
+
+    /// Synchronous generation that returns all events at once.
+    ///
+    /// - Returns: Every event produced by the run, in emission order.
+    public static func generateSync(
+        targetModel: any DFlashTargetModel,
+        draftModel: any DFlashDraftModelProtocol,
+        promptTokens: [Int],
+        maxNewTokens: Int,
+        config: DFlashConfiguration = DFlashConfiguration()
+    ) -> [DFlashEvent] {
+        var events: [DFlashEvent] = []
+        generateStreaming(
+            targetModel: targetModel,
+            draftModel: draftModel,
+            promptTokens: promptTokens,
+            maxNewTokens: maxNewTokens,
+            config: config,
+            yield: { events.append($0) }
+        )
+        return events
+    }
+
+    /// Core streaming generation loop.
+    ///
+    /// Runs chunked prefill on the target model, emits the first token, then repeats
+    /// draft → verify → accept/reject → rollback cycles until `maxNewTokens` tokens have
+    /// been produced or a stop token has been emitted. The stop token itself is emitted;
+    /// tokens committed after it in the same cycle are not. A `.summary` event is always
+    /// yielded last. Nothing is yielded when the prompt is empty or `maxNewTokens <= 0`.
+    ///
+    /// - Parameters:
+    ///   - targetModel: Model used for prefill and verification.
+    ///   - draftModel: Block-diffusion draft model proposing up to `blockSize - 1` tokens per cycle.
+    ///   - promptTokens: Pre-tokenized prompt token IDs.
+    ///   - maxNewTokens: Upper bound on emitted tokens, including the first one.
+    ///   - config: Block size override, stop/suppress token IDs and draft cache sizes.
+    ///   - yield: Called synchronously for every event.
+    private static func generateStreaming(
+        targetModel: any DFlashTargetModel,
+        draftModel: any DFlashDraftModelProtocol,
+        promptTokens: [Int],
+        maxNewTokens: Int,
+        config: DFlashConfiguration,
+        yield: (DFlashEvent) -> Void
+    ) {
+        let promptLen = promptTokens.count
+        guard promptLen > 0 && maxNewTokens > 0 else { return }
+
+        let tokensInt32 = promptTokens.map { Int32($0) }
+        let promptArray = MLXArray(tokensInt32).reshaped(1, -1).asType(.uint32)
+
+        // Detect engine and create caches
+        let engine: any DFlashEngineProtocol = targetModel.dflashIsHybridGDN
+            ? HybridGDNEngine()
+            : FullAttentionEngine()
+
+        let draftBackend = DFlashDraftBackend()
+
+        let targetCache = makeTargetCache(
+            targetModel: targetModel,
+            useTapeRollback: config.useTapeRollback
+        )
+
+        let draftCache = draftBackend.makeCache(
+            draftModel: draftModel,
+            sinkSize: config.draftSinkSize,
+            windowSize: config.draftWindowSize
+        )
+
+        let targetLayerIDList = draftModel.targetLayerIDs
+        // Captured hidden states are keyed by layer index + 1.
+        let captureLayerIDs = Set(targetLayerIDList.map { $0 + 1 })
+        let maskTokenID = draftModel.maskTokenID
+
+        let startNanos = DispatchTime.now().uptimeNanoseconds
+
+        // ── Prefill ────────────────────────────────────────────────
+        let prefillStepSize = 2048
+        var prefillHidden: MLXArray?
+        var lastChunkLogits: MLXArray?
+
+        for chunkStart in stride(from: 0, to: promptLen, by: prefillStepSize) {
+            let chunkEnd = min(chunkStart + prefillStepSize, promptLen)
+            let chunkIDs = promptArray[0..., chunkStart ..< chunkEnd]
+
+            let (chunkLogits, chunkHidden) = targetModel.dflashForwardWithCapture(
+                inputIDs: chunkIDs,
+                cache: targetCache,
+                captureLayerIDs: captureLayerIDs
+            )
+
+            // Batched asyncEval: enqueue everything without blocking
+            asyncEval(chunkLogits)
+            for (_, v) in chunkHidden { asyncEval(v) }
+
+            let feat = extractContextFeatureFromDict(
+                capturedDict: chunkHidden,
+                targetLayerIDs: targetLayerIDList
+            )
+
+            // Allocate the full-prompt feature buffer on the first chunk, then fill it in place.
+            if prefillHidden == nil {
+                prefillHidden = MLXArray.zeros(
+                    [feat.dim(0), promptLen, feat.dim(-1)],
+                    dtype: feat.dtype
+                )
+            }
+            prefillHidden?[0..., chunkStart ..< chunkEnd, 0...] = feat
+            if let prefillHidden { eval(prefillHidden) }
+
+            lastChunkLogits = chunkLogits
+
+            yield(.prefillProgress(
+                tokensProcessed: chunkEnd,
+                tokensTotal: promptLen
+            ))
+        }
+
+        // promptLen > 0 guarantees at least one chunk, so both values are set here.
+        guard let prefillLogits = lastChunkLogits, var targetHidden = prefillHidden else { return }
+
+        MLX.Memory.clearCache()
+
+        let prefillNanos = Int(DispatchTime.now().uptimeNanoseconds) - Int(startNanos)
+
+        let suppressTokenMask = buildSuppressTokenMask(
+            vocabSize: Int(prefillLogits.dim(-1)),
+            suppressTokenIDs: config.suppressTokenIDs
+        )
+
+        var stagedFirst = greedyTokensWithMask(
+            logits: prefillLogits[0..., -1, 0...],
+            suppressTokenMask: suppressTokenMask
+        ).reshaped(-1)
+
+        yield(.prefill(
+            promptTokenCount: promptLen,
+            prefillUs: Double(prefillNanos) / 1000.0
+        ))
+
+        // Yield the first token
+        let firstTokenID = stagedFirst.item(Int.self)
+        yield(.token(
+            tokenID: firstTokenID,
+            generatedTokens: 1,
+            acceptanceRatio: 0.0,
+            cyclesCompleted: 0
+        ))
+
+        // ── Generation Loop ───────────────────────────────────────
+        let draftBlockSize = draftModel.blockSize
+        let requestedBlockTokens = config.blockTokens ?? draftBlockSize
+        let effectiveBlockTokens = max(1, min(requestedBlockTokens, draftBlockSize))
+        let verifyLenCap = effectiveBlockTokens
+
+        var generatedTokenIDs: [Int] = [firstTokenID]
+        var acceptedFromDraft = 0
+        var cyclesCompleted = 0
+        var start = promptLen
+        // The first cycle's staged token was already emitted above; it is still part of
+        // that cycle's committed segment and must be skipped there exactly once.
+        var stagedFirstAlreadyEmitted = true
+
+        let maskTokenTail = MLXArray.full(
+            [max(0, effectiveBlockTokens - 1)],
+            values: MLXArray(Int32(maskTokenID), dtype: .uint32)
+        )
+
+        var verifyNsTotal: Int = 0
+        var draftNsTotal: Int = 0
+        var replayNsTotal: Int = 0
+
+        // Precompute stop token set for O(1) lookup
+        let stopTokenSet = Set(config.stopTokenIDs)
+        // A stop token as the very first token ends generation before any cycle runs.
+        var hitStopToken = stopTokenSet.contains(firstTokenID)
+
+        // Prefetch state: the draft for the NEXT cycle can be overlapped
+        // with the current cycle's rollback.
+        var prefetchedDraft: MLXArray?
+        var prefetchedBlockLen: Int?
+
+        while !hitStopToken && generatedTokenIDs.count < maxNewTokens {
+            let remaining = maxNewTokens - generatedTokenIDs.count
+            let blockLen = max(1, min(effectiveBlockTokens, remaining))
+
+            // ── Draft Phase ──────────────────────────────────────
+            var drafted: MLXArray?
+            let currentStagedFirst = stagedFirst
+            if blockLen > 1 {
+                if let pf = prefetchedDraft, prefetchedBlockLen == blockLen {
+                    drafted = pf
+                    prefetchedDraft = nil
+                    prefetchedBlockLen = nil
+                } else {
+                    let draftStart = Int(DispatchTime.now().uptimeNanoseconds)
+                    drafted = draftBackend.draftGreedy(
+                        targetModel: targetModel,
+                        draftModel: draftModel,
+                        draftCache: draftCache,
+                        stagedFirst: stagedFirst,
+                        targetHidden: targetHidden,
+                        blockLen: blockLen,
+                        maskTokenTail: maskTokenTail,
+                        suppressTokenMask: suppressTokenMask
+                    )
+                    draftNsTotal += Int(DispatchTime.now().uptimeNanoseconds) - draftStart
+                }
+            }
+
+            // ── Verify Phase ────────────────────────────────────
+            let verifyTokenCount = min(blockLen, verifyLenCap)
+            let verifyTokenIDs: MLXArray
+            if blockLen <= 1 {
+                verifyTokenIDs = currentStagedFirst[..<1]
+            } else if let drafted = drafted, verifyTokenCount > 1 {
+                verifyTokenIDs = concatenated(
+                    [currentStagedFirst[..<1], drafted[..<(verifyTokenCount - 1)]],
+                    axis: 0
+                )
+            } else {
+                verifyTokenIDs = currentStagedFirst[..<1]
+            }
+            let verifyIDs = verifyTokenIDs[.newAxis]
+
+            armTargetRollback(targetCache: targetCache, prefixLen: start)
+
+            let verifyStart = Int(DispatchTime.now().uptimeNanoseconds)
+            let (verifyLogits, verifyHiddenStates) = targetModel.dflashForwardWithCapture(
+                inputIDs: verifyIDs,
+                cache: targetCache,
+                captureLayerIDs: captureLayerIDs
+            )
+            asyncEval(verifyLogits)
+            for v in verifyHiddenStates.values { asyncEval(v) }
+            verifyNsTotal += Int(DispatchTime.now().uptimeNanoseconds) - verifyStart
+
+            // ── Accept/Reject ──────────────────────────────────
+            let posterior = greedyTokensWithMask(
+                logits: verifyLogits[0],
+                suppressTokenMask: suppressTokenMask
+            )
+
+            let acceptanceLen: Int
+            if verifyTokenIDs.dim(0) > 1 {
+                acceptanceLen = Int(
+                    matchAcceptanceLength(
+                        draftedTokens: verifyTokenIDs[1...],
+                        posteriorTokens: posterior[..<(verifyTokenIDs.dim(0) - 1)]
+                    ).item(Int.self)
+                )
+            } else {
+                acceptanceLen = 0
+            }
+
+            let committedHidden = extractContextFeatureFromDict(
+                capturedDict: verifyHiddenStates,
+                targetLayerIDs: targetLayerIDList
+            )[0..., ..<(1 + acceptanceLen), 0...]
+            asyncEval(committedHidden)
+
+            let commitCount = 1 + acceptanceLen
+            let committedSegment = verifyTokenIDs[..<(commitCount)]
+
+            let stagedFirstNext = posterior[acceptanceLen ..< (acceptanceLen + 1)]
+
+            // ── Prefetch next draft (overlaps with rollback on GPU) ──
+            // In the first cycle the staged token is already counted in
+            // `generatedTokenIDs`, so only `commitCount - 1` tokens are new. Using the
+            // exact count keeps the prefetched block length equal to the next cycle's
+            // `blockLen`, so the prefetched draft is not discarded and drafted again.
+            let newTokensThisCycle = stagedFirstAlreadyEmitted ? commitCount - 1 : commitCount
+            let nextRemaining = maxNewTokens - generatedTokenIDs.count - newTokensThisCycle
+            let nextBlockLen = max(1, min(effectiveBlockTokens, nextRemaining))
+            if nextBlockLen > 1 {
+                let prefetchStart = Int(DispatchTime.now().uptimeNanoseconds)
+                prefetchedDraft = draftBackend.draftGreedy(
+                    targetModel: targetModel,
+                    draftModel: draftModel,
+                    draftCache: draftCache,
+                    stagedFirst: stagedFirstNext,
+                    targetHidden: committedHidden,
+                    blockLen: nextBlockLen,
+                    maskTokenTail: maskTokenTail,
+                    suppressTokenMask: suppressTokenMask
+                )
+                prefetchedBlockLen = nextBlockLen
+                if let prefetchedDraft { asyncEval(prefetchedDraft) }
+                // Count prefetch work as draft time so the phase timings stay meaningful.
+                draftNsTotal += Int(DispatchTime.now().uptimeNanoseconds) - prefetchStart
+            } else {
+                prefetchedDraft = nil
+                prefetchedBlockLen = nil
+            }
+
+            // ── Rollback ───────────────────────────────────────
+            start += commitCount
+            targetHidden = committedHidden
+            let replayNs = engine.rollback(
+                targetCache: targetCache,
+                targetLen: start,
+                acceptanceLength: acceptanceLen,
+                draftedTokens: blockLen - 1
+            )
+            replayNsTotal += replayNs
+            cyclesCompleted += 1
+            acceptedFromDraft += acceptanceLen
+
+            // ── Emit tokens ───────────────────────────────────
+            let committedIDs = committedSegment.asArray(Int.self)
+            for tokenID in committedIDs {
+                guard generatedTokenIDs.count < maxNewTokens else { break }
+
+                if stagedFirstAlreadyEmitted {
+                    stagedFirstAlreadyEmitted = false
+                    continue
+                }
+
+                generatedTokenIDs.append(tokenID)
+
+                yield(.token(
+                    tokenID: tokenID,
+                    generatedTokens: generatedTokenIDs.count,
+                    acceptanceRatio: Double(acceptedFromDraft) / Double(generatedTokenIDs.count),
+                    cyclesCompleted: cyclesCompleted
+                ))
+
+                // Stop right after the stop token: nothing committed behind it is emitted.
+                if stopTokenSet.contains(tokenID) {
+                    hitStopToken = true
+                    break
+                }
+            }
+
+            stagedFirst = stagedFirstNext
+        }
+
+        // ── Summary ────────────────────────────────────────────
+        let elapsedNanos = Int(DispatchTime.now().uptimeNanoseconds) - Int(startNanos)
+        let acceptanceRatio = generatedTokenIDs.isEmpty
+            ? 0.0
+            : Double(acceptedFromDraft) / Double(generatedTokenIDs.count)
+
+        yield(.summary(DFlashSummary(
+            elapsedUs: Double(elapsedNanos) / 1000.0,
+            promptTokenCount: promptLen,
+            generatedTokenIDs: generatedTokenIDs,
+            acceptedFromDraft: acceptedFromDraft,
+            acceptanceRatio: acceptanceRatio,
+            blockTokens: effectiveBlockTokens,
+            cyclesCompleted: cyclesCompleted,
+            phaseTimingsUs: .init(
+                prefill: Double(prefillNanos) / 1000.0,
+                draft: Double(draftNsTotal) / 1000.0,
+                verify: Double(verifyNsTotal) / 1000.0,
+                replay: Double(replayNsTotal) / 1000.0
+            )
+        )))
+    }
+}
+
+// MARK: - Hybrid GDN Engine (Stub for future implementation)
+
+/// Engine for hybrid GatedDeltaNet + attention target models.
+/// Uses rollback caches for recurrent layers with tape replay.
+public final class HybridGDNEngine: DFlashEngineProtocol, @unchecked Sendable {
+    public init() {}
+
+    /// Arm every rollback-capable cache in `targetCache` at `prefixLen`.
+    public func armRollback(targetCache: [KVCache], prefixLen: Int) {
+        // Same behavior as the shared free function; delegate instead of duplicating it.
+        armTargetRollback(targetCache: targetCache, prefixLen: prefixLen)
+    }
+
+    /// Restore the target caches after a verify step.
+    ///
+    /// - Returns: The value reported by `restoreTargetCacheAfterAcceptance`
+    ///   (accumulated by the runtime as replay time in nanoseconds).
+    public func rollback(
+        targetCache: [KVCache],
+        targetLen: Int,
+        acceptanceLength: Int,
+        draftedTokens: Int
+    ) -> Int {
+        restoreTargetCacheAfterAcceptance(
+            targetCache,
+            targetLen: targetLen,
+            acceptanceLength: acceptanceLength,
+            draftedTokens: draftedTokens
+        )
+    }
+}
\ No newline at end of file
diff --git a/swift/Sources/DFlash/DFlashTargetModelExtensions.swift b/swift/Sources/DFlash/DFlashTargetModelExtensions.swift
new file mode 100644
index 0000000..2dee343
--- /dev/null
+++ b/swift/Sources/DFlash/DFlashTargetModelExtensions.swift
@@ -0,0 +1,77 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// Model conformance examples for DFlashTargetModel
+//
+// To enable DFlash on your model, extend it with DFlashTargetModel conformance:
+
+/*
+ ## Example: Qwen3Model Conformance
+
+ Add this to your model extension file (requires MLXLLM):
+
+ ```swift
+ import MLXLLM
+
+ extension Qwen3Model: DFlashTargetModel {
+     public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray {
+         model.embedTokens(tokens)
+     }
+
+     public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray {
+         if let lmHead {
+             return lmHead(hiddenStates)
+         }
+         return model.embedTokens.asLinear(hiddenStates)
+     }
+
+     public func dflashForwardWithCapture(
+         inputIDs: MLXArray,
+         cache: [KVCache],
+         captureLayerIDs: Set<Int>
+     ) -> (MLXArray, [Int: MLXArray]) {
+         let cacheOpt: [KVCache?] = cache.map { $0 }
+         let (hiddenStates, captured) = model.callCapturing(
+             inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs)
+         return (dflashLmHeadLogits(hiddenStates), captured)
+     }
+
+     public var dflashIsHybridGDN: Bool { false }
+ }
+ ```
+
+ ## Example: DeepseekV3DFlashModel (Hybrid GDN)
+
+ For hybrid GDN models, set `dflashIsHybridGDN = true`:
+
+ ```swift
+ extension DeepseekV3DFlashModel: DFlashTargetModel {
+     public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray {
+         model.embedTokens(tokens)
+     }
+
+     public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray {
+         model.lmHead(hiddenStates)
+     }
+
+     public func dflashForwardWithCapture(
+         inputIDs: MLXArray,
+         cache: [KVCache],
+         captureLayerIDs: Set<Int>
+     ) -> (MLXArray, [Int: MLXArray]) {
+         // Use GDN-aware forward pass
+         let (logits, captured) = model.forwardWithCaptureGDN(
+             inputIDs: inputIDs,
+             cache: cache,
+             captureLayerIDs: captureLayerIDs
+         )
+         return (logits, captured)
+     }
+
+     public var dflashIsHybridGDN: Bool { true }
+ }
+ ```
+*/
+
+// No concrete implementations here - see the documentation above
+// and model-specific files (Qwen3+DFlash.swift, DeepseekV3DFlash.swift, etc.)
\ No newline at end of file
diff --git a/swift/Tests/DFlashTests/DFlashTests.swift b/swift/Tests/DFlashTests/DFlashTests.swift
new file mode 100644
index 0000000..55d4335
--- /dev/null
+++ b/swift/Tests/DFlashTests/DFlashTests.swift
@@ -0,0 +1,282 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// Tests for DFlash speculative decoding module
+
+import XCTest
+@testable import DFlash
+import MLX
+import MLXLMCommon
+import MLXNN
+
+final class DFlashTests: XCTestCase {
+
+    // MARK: - Token Utilities Tests
+
+    func testGreedyTokensWithMask() {
+        // Logits [1.0, 5.0, 2.0, 3.0] - token 1 has highest logit
+        let logits = MLXArray([1.0, 5.0, 2.0, 3.0]).reshaped([1, 4])
+        let result = DFlashRuntime.greedyTokensWithMask(logits: logits)
+        XCTAssertEqual(result.item(Int.self), 1) // argmax of [1,5,2,3] is index 1
+    }
+
+    func testGreedyTokensWithSuppressMask() {
+        // Logits [1.0, 5.0, 2.0, 3.0] - token 1 has highest logit
+        let logits = MLXArray([1.0, 5.0, 2.0, 3.0])
+        // Suppress token 1
+        let suppressMask = MLXArray([false, true, false, false])
+        let result = DFlashRuntime.greedyTokensWithMask(
+            logits: logits,
+            suppressTokenMask: suppressMask
+        )
+        // Token 3 should be selected (3.0) instead of token 1 (5.0)
+        XCTAssertEqual(result.item(Int.self), 3)
+    }
+
+    func testMatchAcceptanceLength() {
+        // drafted = [1, 2, 3], posterior = [1, 2, 5]
+        let drafted = MLXArray([1, 2, 3])
+        let posterior = MLXArray([1, 2, 5])
+        let result = DFlashRuntime.matchAcceptanceLength(
+            draftedTokens: drafted,
+            posteriorTokens: posterior
+        )
+        XCTAssertEqual(result.item(Int.self), 2) // first two tokens match
+    }
+
+    func testMatchAcceptanceLengthFullMatch() {
+        let drafted = MLXArray([1, 2, 3])
+        let posterior = MLXArray([1, 2, 3])
+        let result = DFlashRuntime.matchAcceptanceLength(
+            draftedTokens: drafted,
+            posteriorTokens: posterior
+        )
+        XCTAssertEqual(result.item(Int.self), 3) // all match
+    }
+
+    func testMatchAcceptanceLengthNoMatch() {
+        let drafted = MLXArray([1, 2, 3])
+        let posterior = MLXArray([4, 5, 6])
+        let result = DFlashRuntime.matchAcceptanceLength(
+            draftedTokens: drafted,
+            posteriorTokens: posterior
+        )
+        XCTAssertEqual(result.item(Int.self), 0) // no match
+    }
+
+    func testMatchAcceptanceLengthEmpty() {
+        // Create empty arrays with explicit shape
+        let emptyArr = MLXArray.zeros([0], dtype: .int32)
+        let drafted = emptyArr
+        let posterior = emptyArr
+        let result = DFlashRuntime.matchAcceptanceLength(
+            draftedTokens: drafted,
+            posteriorTokens: posterior
+        )
+        XCTAssertEqual(result.item(Int.self), 0)
+    }
+
+    // MARK: - Build Suppress Token Mask
+
+    func testBuildSuppressTokenMask() throws {
+        let vocabSize = 100
+        let suppressIDs: [Int] = [5, 10, 15]
+        let mask = DFlashRuntime.buildSuppressTokenMask(
+            vocabSize: vocabSize,
+            suppressTokenIDs: suppressIDs
+        )
+        let arr = try XCTUnwrap(mask).asArray(Bool.self)
+        XCTAssertTrue(arr[5])
+        XCTAssertTrue(arr[10])
+        XCTAssertTrue(arr[15])
+        XCTAssertFalse(arr[0])
+        XCTAssertFalse(arr[50])
+    }
+
+    func testBuildSuppressTokenMaskEmpty() {
+        let mask = DFlashRuntime.buildSuppressTokenMask(
+            vocabSize: 100,
+            suppressTokenIDs: nil
+        )
+        XCTAssertNil(mask)
+    }
+
+    func testBuildSuppressTokenMaskOutOfBounds() throws {
+        // Out of bounds IDs should be filtered out
+        let mask = DFlashRuntime.buildSuppressTokenMask(
+            vocabSize: 100,
+            suppressTokenIDs: [5, 150, 200, 10] // 150 and 200 out of bounds
+        )
+        let arr = try XCTUnwrap(mask).asArray(Bool.self)
+        // The mask has exactly vocabSize entries, so indices 150 and 200 do not exist and
+        // must not be subscripted; check the length and the number of set entries instead.
+        XCTAssertEqual(arr.count, 100)
+        XCTAssertTrue(arr[5])
+        XCTAssertTrue(arr[10])
+        XCTAssertEqual(arr.filter { $0 }.count, 2)
+    }
+
+    // MARK: - Draft Configuration
+
+    func testBuildTargetLayerIDs() {
+        let ids = buildTargetLayerIDs(numTargetLayers: 36, numDraftLayers: 4)
+        XCTAssertEqual(ids.count, 4)
+        // First should be near 1, last should be near 33
+        XCTAssertGreaterThan(ids[0], 0)
+        XCTAssertLessThan(ids[3], 36)
+        // Should be strictly increasing
+        XCTAssertLessThan(ids[0], ids[1])
+        XCTAssertLessThan(ids[1], ids[2])
+        XCTAssertLessThan(ids[2], ids[3])
+    }
+
+    func testBuildTargetLayerIDsSingleLayer() {
+        let ids = buildTargetLayerIDs(numTargetLayers: 36, numDraftLayers: 1)
+        XCTAssertEqual(ids.count, 1)
+        XCTAssertEqual(ids[0], 18) // middle of 36
+    }
+
+    // MARK: - DFlash Configuration
+
+    func testDFlashConfigurationDefaults() {
+        let config = DFlashConfiguration()
+        XCTAssertNil(config.blockTokens)
+        XCTAssertTrue(config.stopTokenIDs.isEmpty)
+        XCTAssertNil(config.suppressTokenIDs)
+        XCTAssertEqual(config.draftSinkSize, 64)
+        XCTAssertEqual(config.draftWindowSize, 1024)
+        XCTAssertTrue(config.useTapeRollback)
+    }
+
+    func testDFlashConfigurationCustom() {
+        let config = DFlashConfiguration(
+            blockTokens: 8,
+            stopTokenIDs: [2, 3],
+            suppressTokenIDs: [100, 200],
+            draftSinkSize: 32,
+            draftWindowSize: 512,
+            useTapeRollback: false
+        )
+        XCTAssertEqual(config.blockTokens, 8)
+        XCTAssertEqual(config.stopTokenIDs, [2, 3])
+        XCTAssertEqual(config.suppressTokenIDs, [100, 200])
+        XCTAssertEqual(config.draftSinkSize, 32)
+        XCTAssertEqual(config.draftWindowSize, 512)
+        XCTAssertFalse(config.useTapeRollback)
+    }
+
+    // MARK: - Draft Cache Tests
+
+    func testContextOnlyDraftKVCache() {
+        let cache = ContextOnlyDraftKVCache(sinkSize: 4, windowSize: 8)
+        XCTAssertEqual(cache.cacheLength, 0)
+        XCTAssertEqual(cache.offset, 0)
+
+        // Append some context
+        let keys = MLXArray.zeros([1, 2, 1, 128])
+        let values = MLXArray.zeros([1, 2, 1, 128])
+        cache.appendContext(contextKeys: keys, contextValues: values, numPositions: 2)
+
+        XCTAssertEqual(cache.cacheLength, 2)
+        XCTAssertEqual(cache.offset, 2)
+    }
+
+    func testContextOnlyDraftKVCacheWindowing() {
+        let cache = ContextOnlyDraftKVCache(sinkSize: 4, windowSize: 8)
+
+        // Append enough context to trigger windowing
+        let keys = MLXArray.zeros([1, 20, 1, 128])
+        let values = MLXArray.zeros([1, 20, 1, 128])
+        cache.appendContext(contextKeys: keys, contextValues: values, numPositions: 20)
+
+        // Should have sink + window
+        XCTAssertEqual(cache.cacheLength, 12) // 4 + 8
+    }
+
+    // MARK: - Context Feature Extraction
+
+    func testExtractContextFeature() {
+        // Simulate hidden states: [embedding, layer0, layer1, layer2, ...]
+        let h0 = MLXArray.zeros([1, 1, 512])
+        let h1 = MLXArray.ones([1, 1, 512])
+        let h2 = MLXArray.full([1, 1, 512], values: MLXArray(2.0))
+        let h3 = MLXArray.full([1, 1, 512], values: MLXArray(3.0))
+        let hiddenStates = [h0, h1, h2, h3]
+
+        let result = extractContextFeature(hiddenStates: hiddenStates, layerIDs: [1, 2])
+        XCTAssertEqual(result.dim(-1), 1024) // 512 * 2 concatenated
+    }
+
+    func testExtractContextFeatureFromDict() {
+        var captured: [Int: MLXArray] = [:]
+        captured[1] = MLXArray.ones([1, 1, 256])
+        captured[2] = MLXArray.full([1, 1, 256], values: MLXArray(2.0))
+        captured[3] = MLXArray.full([1, 1, 256], values: MLXArray(3.0))
+
+        let result = extractContextFeatureFromDict(
+            capturedDict: captured,
+            targetLayerIDs: [1, 2]
+        )
+        XCTAssertEqual(result.dim(-1), 512) // 256 * 2
+    }
+
+    // MARK: - Events
+
+    func testDFlashSummary() {
+        let timings = DFlashSummary.PhaseTimings(
+            prefill: 100.0,
+            draft: 50.0,
+            verify: 200.0,
+            replay: 10.0
+        )
+        let summary = DFlashSummary(
+            elapsedUs: 500.0,
+            promptTokenCount: 100,
+            generatedTokenIDs: [1, 2, 3, 4, 5],
+            acceptedFromDraft: 3,
+            acceptanceRatio: 0.6,
+            blockTokens: 4,
+            cyclesCompleted: 2,
+            phaseTimingsUs: timings
+        )
+
+        XCTAssertEqual(summary.generationTokens, 5)
+        XCTAssertEqual(summary.promptTokenCount, 100)
+        XCTAssertEqual(summary.acceptedFromDraft, 3)
+    }
+}
+
+// MARK: - Integration Tests
+
+#if canImport(MLXLLM)
+import MLXLLM
+
+final class DFlashIntegrationTests: XCTestCase {
+    func testFullAttentionEngineCreation() {
+        let engine = FullAttentionEngine()
+        // Verify engine can be created and used
+        let cache: [KVCache] = []
+        engine.armRollback(targetCache: cache, prefixLen: 0)
+        let replayNs = engine.rollback(
+            targetCache: cache,
+            targetLen: 10,
+            acceptanceLength: 3,
+            draftedTokens: 5
+        )
+        XCTAssertGreaterThanOrEqual(replayNs, 0)
+    }
+
+    func testHybridGDNEngineCreation() {
+        let engine = HybridGDNEngine()
+        let cache: [KVCache] = []
+        engine.armRollback(targetCache: cache, prefixLen: 0)
+        let replayNs = engine.rollback(
+            targetCache: cache,
+            targetLen: 10,
+            acceptanceLength: 3,
+            draftedTokens: 5
+        )
+        XCTAssertGreaterThanOrEqual(replayNs, 0)
+    }
+}
+#endif
\ No newline at end of file
From 4d19e0ca7bc12ea3ae313ec07d1ad9453a82254c Mon Sep 17 00:00:00 2001
From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com>
Date: Sat, 25 Apr 2026 02:15:47 -0400
Subject: [PATCH 2/3] Add DFlashForwardWithCapture protocol and conformance documentation

This commit adds:
- DFlashForwardWithCapture protocol for models that support hidden state capture
- DFlashModelConformanceTemplate with lists of supported pure-attention and hybrid models
- Embedding.asLinear helper for tied weights
- Complete template documentation for adding conformance to any model
---
 swift/Sources/DFlash/DFlash+MLXLLM.swift | 141 +++++++++++++++++++++++
 1 file changed, 141 insertions(+)
 create mode 100644 swift/Sources/DFlash/DFlash+MLXLLM.swift

diff --git a/swift/Sources/DFlash/DFlash+MLXLLM.swift b/swift/Sources/DFlash/DFlash+MLXLLM.swift
new file mode 100644
index 0000000..41e885d
--- /dev/null
+++ b/swift/Sources/DFlash/DFlash+MLXLLM.swift
@@ -0,0 +1,141 @@
+// Copyright 2026 SwiftLM Contributors
+// SPDX-License-Identifier: Apache-2.0
+//
+// DFlashTargetModel conformance for MLXLLM models
+//
+// This file provides documentation and helpers for adding DFlash support to MLXLLM models.
+// Due to Swift access control, model internals (embedTokens, layers, norm, lmHead)
+// are internal to the MLXLLM package. The conformance extensions must be added to
+// the MLXLLM package itself.
+//
+// MARK: - Forward with Capture Protocol
+
+/// Protocol for models that support capturing intermediate hidden states.
+/// Models implementing this protocol can be used with DFlash for speculative decoding. +public protocol DFlashForwardWithCapture: LanguageModel { + /// Run a forward pass that captures hidden states at specified layer indices. + func forwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache?], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) +} + +// MARK: - Model Conformance Template + +/// Template for adding DFlash conformance to a model. +/// Copy this template to the model's Swift file in MLXLLM and fill in the specifics. +/// +/// ## Usage +/// +/// Add the following extension to any model file in MLXLLM/Models/: +/// +/// ```swift +/// // For pure attention models: +/// extension YourModel: DFlashTargetModel { +/// public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { +/// model.embedTokens(tokens) +/// } +/// +/// public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { +/// if let lmHead { return lmHead(hiddenStates) } +/// return model.embedTokens.asLinear(hiddenStates) +/// } +/// +/// public func dflashForwardWithCapture( +/// inputIDs: MLXArray, +/// cache: [KVCache], +/// captureLayerIDs: Set +/// ) -> (MLXArray, [Int: MLXArray]) { +/// var h = model.embedTokens(inputIDs) +/// var captured: [Int: MLXArray] = [:] +/// for (i, layer) in model.layers.enumerated() { +/// h = layer(h, cache: cache[i]) +/// if captureLayerIDs.contains(i + 1) { +/// captured[i + 1] = h +/// } +/// } +/// let normed = model.norm(h) +/// let logits: MLXArray +/// if let head = lmHead { logits = head(normed) } +/// else { logits = model.embedTokens.asLinear(normed) } +/// return (logits, captured) +/// } +/// +/// public var dflashIsHybridGDN: Bool { false } +/// } +/// +/// extension YourModel: DFlashForwardWithCapture { +/// public func forwardWithCapture( +/// inputIDs: MLXArray, +/// cache: [KVCache?], +/// captureLayerIDs: Set +/// ) -> (MLXArray, [Int: MLXArray]) { +/// // Same implementation as dflashForwardWithCapture but with 
optional cache +/// var h = model.embedTokens(inputIDs) +/// var captured: [Int: MLXArray] = [:] +/// for (i, layer) in model.layers.enumerated() { +/// h = layer(h, cache: cache[i]) +/// if captureLayerIDs.contains(i + 1) { +/// captured[i + 1] = h +/// } +/// } +/// let normed = model.norm(h) +/// let logits: MLXArray +/// if let head = lmHead { logits = head(normed) } +/// else { logits = model.embedTokens.asLinear(normed) } +/// return (logits, captured) +/// } +/// } +/// ``` +public struct DFlashModelConformanceTemplate { + // This is a documentation struct - see above for usage + + /// Pure attention models (set dflashIsHybridGDN = false) + public static let pureAttentionModels: [String] = [ + "Qwen3Model", + "Qwen2Model", + "LlamaModel", + "GemmaModel", + "Gemma2Model", + "Gemma3Model", + "Gemma4Model", + "PhiModel", + "Phi3Model", + "CohereModel", + "Starcoder2Model", + "SmolLMModel", + "NanoChatModel", + "Internlm2Model", + "BaichuanM1Model", + "Mistral3TextModel", + ] + + /// Hybrid GDN models (set dflashIsHybridGDN = true) + public static let hybridModels: [String] = [ + "Qwen35Model", // Qwen3.5 MoE + "Qwen3MoEModel", + "Qwen3NextModel", + "DeepseekV3Model", + "MiniMaxModel", + "MiniMaxM2Model", + "GraniteMoeHybridModel", + "LFM2Model", + "LFM2MoEModel", + "AfMoEModel", + "GLM4MoEModel", + "GLM4MoELiteModel", + ] +} + +import MLXNN + +// MARK: - Helper Extension + +extension Embedding { + /// Convert embeddings to logits using tied weights. 
+ public func asLinear(_ x: MLXArray) -> MLXArray { + let weightT = transposed(weight, axes: [1, 0]) + return matmul(x, weightT) + } +} From ecc9f202dd6399240d7e043ae1eac0344fd58b77 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sat, 25 Apr 2026 02:41:25 -0400 Subject: [PATCH 3/3] Add comprehensive DFlash model registry and conformance templates This commit adds: - DFlashForwardWithCapture protocol for hidden state capture - DFlashSupportedModels listing all ~50 MLXLLM models organized by type: - Pure attention models (Llama, Qwen3, Gemma, Phi, etc.) - Hybrid GDN models (Qwen3.5, Qwen3Next, DeepSeekV3, MiniMax, etc.) - DFlashModelRegistry with model lists - DFlashConformanceStatus for tracking conformance state - generateDFlashConformance() template generator for easy extension creation - Embedding.asLinear helper for tied weight models --- swift/Sources/DFlash/DFlash+MLXLLM.swift | 290 ++++++++++++++++------- 1 file changed, 203 insertions(+), 87 deletions(-) diff --git a/swift/Sources/DFlash/DFlash+MLXLLM.swift b/swift/Sources/DFlash/DFlash+MLXLLM.swift index 41e885d..263773b 100644 --- a/swift/Sources/DFlash/DFlash+MLXLLM.swift +++ b/swift/Sources/DFlash/DFlash+MLXLLM.swift @@ -3,11 +3,17 @@ // // DFlashTargetModel conformance for MLXLLM models // -// This file provides documentation and helpers for adding DFlash support to MLXLLM models. -// Due to Swift access control, model internals (embedTokens, layers, norm, lmHead) -// are internal to the MLXLLM package. The conformance extensions must be added to -// the MLXLLM package itself. -// +// This file provides DFlash support for all MLXLLM models. +// Due to Swift access control, conformance extensions should ideally be added +// within the MLXLLM package itself, but this file provides them for use with +// the DFlash module when imported together with MLXLLM. 
+ +import Foundation +import MLX +import MLXLMCommon +import MLXNN +import MLXLLM + // MARK: - Forward with Capture Protocol /// Protocol for models that support capturing intermediate hidden states. @@ -21,121 +27,231 @@ public protocol DFlashForwardWithCapture: LanguageModel { ) -> (MLXArray, [Int: MLXArray]) } -// MARK: - Model Conformance Template +// MARK: - Embedding Extension for Tied Weights -/// Template for adding DFlash conformance to a model. -/// Copy this template to the model's Swift file in MLXLLM and fill in the specifics. -/// -/// ## Usage -/// -/// Add the following extension to any model file in MLXLLM/Models/: -/// -/// ```swift -/// // For pure attention models: -/// extension YourModel: DFlashTargetModel { -/// public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { -/// model.embedTokens(tokens) -/// } -/// -/// public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { -/// if let lmHead { return lmHead(hiddenStates) } -/// return model.embedTokens.asLinear(hiddenStates) -/// } -/// -/// public func dflashForwardWithCapture( -/// inputIDs: MLXArray, -/// cache: [KVCache], -/// captureLayerIDs: Set -/// ) -> (MLXArray, [Int: MLXArray]) { -/// var h = model.embedTokens(inputIDs) -/// var captured: [Int: MLXArray] = [:] -/// for (i, layer) in model.layers.enumerated() { -/// h = layer(h, cache: cache[i]) -/// if captureLayerIDs.contains(i + 1) { -/// captured[i + 1] = h -/// } -/// } -/// let normed = model.norm(h) -/// let logits: MLXArray -/// if let head = lmHead { logits = head(normed) } -/// else { logits = model.embedTokens.asLinear(normed) } -/// return (logits, captured) -/// } -/// -/// public var dflashIsHybridGDN: Bool { false } -/// } -/// -/// extension YourModel: DFlashForwardWithCapture { -/// public func forwardWithCapture( -/// inputIDs: MLXArray, -/// cache: [KVCache?], -/// captureLayerIDs: Set -/// ) -> (MLXArray, [Int: MLXArray]) { -/// // Same implementation as dflashForwardWithCapture but with 
optional cache -/// var h = model.embedTokens(inputIDs) -/// var captured: [Int: MLXArray] = [:] -/// for (i, layer) in model.layers.enumerated() { -/// h = layer(h, cache: cache[i]) -/// if captureLayerIDs.contains(i + 1) { -/// captured[i + 1] = h -/// } -/// } -/// let normed = model.norm(h) -/// let logits: MLXArray -/// if let head = lmHead { logits = head(normed) } -/// else { logits = model.embedTokens.asLinear(normed) } -/// return (logits, captured) -/// } -/// } -/// ``` -public struct DFlashModelConformanceTemplate { - // This is a documentation struct - see above for usage - - /// Pure attention models (set dflashIsHybridGDN = false) +extension Embedding { + /// Convert embeddings to logits using tied weights (transpose + matmul). + public func asLinear(_ x: MLXArray) -> MLXArray { + let weightT = transposed(weight, axes: [1, 0]) + return matmul(x, weightT) + } +} + +// MARK: - Model Registry + +/// Registry of models with their DFlash characteristics. +public enum DFlashModelRegistry { + /// Pure attention models - use FullAttentionEngine public static let pureAttentionModels: [String] = [ + // Llama family + "LlamaModel", + // Qwen family (pure attention) "Qwen3Model", "Qwen2Model", - "LlamaModel", + // Gemma family "GemmaModel", "Gemma2Model", - "Gemma3Model", + "Gemma3TextModel", "Gemma4Model", + "Gemma3nTextModel", + // Phi family "PhiModel", "Phi3Model", + "PhiMoEModel", + // Other pure models + "MistralModel", + "Mistral3TextModel", "CohereModel", "Starcoder2Model", "SmolLMModel", "NanoChatModel", "Internlm2Model", "BaichuanM1Model", - "Mistral3TextModel", + "NemotronHModel", + "OpenELMModel", + "OlmoModel", + "Olmo2Model", + "Olmo3Model", + "OlmoE", + "GraniteModel", + "BitnetModel", + "FalconH1Model", + "Exaone4Model", + "Ernie45Model", + "GPTOSSModel", + "ApertusModel", + "JambaModel", ] - /// Hybrid GDN models (set dflashIsHybridGDN = true) + /// Hybrid models with GDN/SSM layers - use HybridGDNEngine public static let hybridModels: [String] 
= [ - "Qwen35Model", // Qwen3.5 MoE + // Qwen hybrid models + "Qwen35Model", "Qwen3MoEModel", "Qwen3NextModel", + // DeepSeek family "DeepseekV3Model", + // MiniMax family "MiniMaxModel", "MiniMaxM2Model", + // Other hybrid MoE models "GraniteMoeHybridModel", "LFM2Model", "LFM2MoEModel", "AfMoEModel", "GLM4MoEModel", "GLM4MoELiteModel", + "GLM4Model", + "BailingMoeModel", + "MiniCPMModel", + "MiMoModel", + "MiMoV2FlashModel", ] } -import MLXNN +// MARK: - Supported Models List + +/// Complete list of models that support DFlash when extended. +public enum DFlashSupportedModels { + + // MARK: Pure Attention Models (dflashIsHybridGDN = false) + + /// All pure attention models + public static var allPure: [String] { + DFlashModelRegistry.pureAttentionModels + } + + // MARK: Hybrid Models (dflashIsHybridGDN = true) + + /// All hybrid models + public static var allHybrid: [String] { + DFlashModelRegistry.hybridModels + } + + /// All models combined + public static var all: [String] { + allPure + allHybrid + } +} -// MARK: - Helper Extension +// MARK: - Conformance Status -extension Embedding { - /// Convert embeddings to logits using tied weights. - public func asLinear(_ x: MLXArray) -> MLXArray { - let weightT = transposed(weight, axes: [1, 0]) - return matmul(x, weightT) +/// Tracks which models have DFlash conformance implemented. +public struct DFlashConformanceStatus { + /// Models with full conformance implemented. + public static let implemented: Set = [] + + /// Models with partial conformance (missing forwardWithCapture). + public static let partial: Set = [] + + /// Models not yet extended (require MLXLLM changes). + public static let pending: Set = Set(DFlashSupportedModels.all) +} + +// MARK: - Conformance Template Generator + +/// Generate DFlash conformance extension code for a model. +public func generateDFlashConformance( + modelName: String, + isHybrid: Bool, + useCallCapturing: Bool = false +) -> String { + let hybridFlag = isHybrid ? 
"true" : "false" + + if useCallCapturing { + return """ + extension \(modelName): DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { return lmHead(hiddenStates) } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + public var dflashIsHybridGDN: Bool { \(hybridFlag) } + } + + extension \(modelName): DFlashForwardWithCapture { + public func forwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache?], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cache, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + } + """ + } else { + return """ + extension \(modelName): DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { return lmHead(hiddenStates) } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + var h = model.embedTokens(inputIDs) + var captured: [Int: MLXArray] = [:] + for (i, layer) in model.layers.enumerated() { + h = layer(h, cache: cache[i]) + if captureLayerIDs.contains(i + 1) { + captured[i + 1] = h + } + } + let normed = model.norm(h) + let logits: MLXArray + if let head = lmHead { logits = head(normed) } + else 
{ logits = model.embedTokens.asLinear(normed) } + return (logits, captured) + } + + public var dflashIsHybridGDN: Bool { \(hybridFlag) } + } + + extension \(modelName): DFlashForwardWithCapture { + public func forwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache?], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + var h = model.embedTokens(inputIDs) + var captured: [Int: MLXArray] = [:] + for (i, layer) in model.layers.enumerated() { + h = layer(h, cache: cache[i]) + if captureLayerIDs.contains(i + 1) { + captured[i + 1] = h + } + } + let normed = model.norm(h) + let logits: MLXArray + if let head = lmHead { logits = head(normed) } + else { logits = model.embedTokens.asLinear(normed) } + return (logits, captured) + } + } + """ } }