diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index f0e8219..85fc33f 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -308,45 +308,49 @@ public extension TextDecoding { return kvCache } - static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray, - valueTensor: MLMultiArray, valueSlice: MLMultiArray, - insertAtIndex index: Int) - { - let tensorShape = keyTensor.shape.map { $0.intValue } - let sliceShape = keySlice.shape.map { $0.intValue } - let sliceStrides = keySlice.strides.map { $0.intValue } // same for val - let bytesPerSample = MemoryLayout<FloatType>.size - - keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in - keySlice.withUnsafeBytes { keySlicePointer in - valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in - valueSlice.withUnsafeBytes { valueSlicePointer in - // Assuming batch size is always 1 - DispatchQueue.concurrentPerform(iterations: tensorShape[1]) { j in - // Slice size is 3 for prefill and 1 for decode loops - for k in 0..<sliceShape[3] { - // Equivalent to: - // `tensor[0, j, 0, k + index] = slice[0, j, 0, k + index]` - let keyDestIndex = j * keyTargetStrides[1] + (index + k) * keyTargetStrides[3] - let keyDest = keyTensorPointer.baseAddress! + keyDestIndex * bytesPerSample - - let keySliceIndex = j * sliceStrides[1] + k * sliceStrides[3] - let keySlice = keySlicePointer.baseAddress! + keySliceIndex * bytesPerSample - memcpy(keyDest, keySlice, bytesPerSample) - - let valDestIndex = j * valueTargetStrides[1] + (index + k) * valueTargetStrides[3] - let valDest = valueTensorPointer.baseAddress! + valDestIndex * bytesPerSample - - let valSliceIndex = j * sliceStrides[1] + k * sliceStrides[3] - let valSlice = valueSlicePointer.baseAddress! + valSliceIndex * bytesPerSample - memcpy(valDest, valSlice, bytesPerSample) - } - } - } - } - } - } - } + static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray, + valueTensor: MLMultiArray, valueSlice: MLMultiArray, + insertAtIndex index: Int) + { + let tensorShape = keyTensor.shape.map { $0.intValue } + let sliceShape = keySlice.shape.map { $0.intValue } + + // Create flat arrays for safe concurrent access + var keyData = [FloatType](repeating: 0, count: keyTensor.count) + var valueData = [FloatType](repeating: 0, count: valueTensor.count) + + // Get current tensor data + memcpy(&keyData, keyTensor.dataPointer, keyTensor.count * MemoryLayout<FloatType>.size) + memcpy(&valueData, valueTensor.dataPointer, valueTensor.count * MemoryLayout<FloatType>.size) + + // Calculate dimensions for index mapping + let seqLength = tensorShape[3] + let hiddenDim = tensorShape[1] + + // Concurrent processing across hidden dimension + DispatchQueue.concurrentPerform(iterations: hiddenDim) { j in + for k in 0..<sliceShape[3] { + // Calculate linear indices + let targetSeqPos = index + k + guard targetSeqPos < seqLength else { continue } + + // Map 4D indices [0, j, 0, index+k] to linear index + let flatKeyIndex = j * seqLength + targetSeqPos + let flatSliceIndex = j * sliceShape[3] + k + + // Copy from slice to tensor + let sliceKeyPtr = keySlice.dataPointer.assumingMemoryBound(to: FloatType.self) + let sliceValuePtr = valueSlice.dataPointer.assumingMemoryBound(to: FloatType.self) + + keyData[flatKeyIndex] = sliceKeyPtr[flatSliceIndex] + valueData[flatKeyIndex] = sliceValuePtr[flatSliceIndex] + } + } + + // Copy data back to tensors + memcpy(keyTensor.dataPointer, &keyData, keyTensor.count * MemoryLayout<FloatType>.size) + memcpy(valueTensor.dataPointer, &valueData, valueTensor.count * MemoryLayout<FloatType>.size) + } static func updateAlignmentWeights( alignmentTensor: MLMultiArray,