diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index f0e8219..85fc33f 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -308,45 +308,55 @@ public extension TextDecoding {
         return kvCache
     }
 
-    static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
-                              valueTensor: MLMultiArray, valueSlice: MLMultiArray,
-                              insertAtIndex index: Int)
-    {
-        let tensorShape = keyTensor.shape.map { $0.intValue }
-        let sliceShape = keySlice.shape.map { $0.intValue }
-        let sliceStrides = keySlice.strides.map { $0.intValue } // same for val
-        let bytesPerSample = MemoryLayout<FloatType>.size
-
-        keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in
-            keySlice.withUnsafeBytes { keySlicePointer in
-                valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in
-                    valueSlice.withUnsafeBytes { valueSlicePointer in
-                        // Assuming batch size is always 1
-                        DispatchQueue.concurrentPerform(iterations: tensorShape[1]) { j in
-                            // Slice size is 3 for prefill and 1 for decode loops
-                            for k in 0..<sliceShape[3] {
-                                // Equivalent to:
-                                // `tensor[0, j, 0, k + index] = slice[0, j, 0, k + index]`
-                                let keyDestIndex = j * keyTargetStrides[1] + (index + k) * keyTargetStrides[3]
-                                let keyDest = keyTensorPointer.baseAddress! + keyDestIndex * bytesPerSample
-
-                                let keySliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
-                                let keySlice = keySlicePointer.baseAddress! + keySliceIndex * bytesPerSample
-                                memcpy(keyDest, keySlice, bytesPerSample)
-
-                                let valDestIndex = j * valueTargetStrides[1] + (index + k) * valueTargetStrides[3]
-                                let valDest = valueTensorPointer.baseAddress! + valDestIndex * bytesPerSample
-
-                                let valSliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
-                                let valSlice = valueSlicePointer.baseAddress! + valSliceIndex * bytesPerSample
-                                memcpy(valDest, valSlice, bytesPerSample)
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
+    static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
+                              valueTensor: MLMultiArray, valueSlice: MLMultiArray,
+                              insertAtIndex index: Int)
+    {
+        let tensorShape = keyTensor.shape.map { $0.intValue }
+        let sliceShape = keySlice.shape.map { $0.intValue }
+
+        // Dimensions for index mapping, assuming batch size is always 1
+        // and a contiguous [1, hiddenDim, 1, seqLength] layout
+        let hiddenDim = tensorShape[1]
+        let seqLength = tensorShape[3]
+        let sliceLength = sliceShape[3]
+
+        // Copy the current tensor contents into flat arrays for safe concurrent access
+        var keyData = [FloatType](repeating: 0, count: keyTensor.count)
+        var valueData = [FloatType](repeating: 0, count: valueTensor.count)
+        memcpy(&keyData, keyTensor.dataPointer, keyTensor.count * MemoryLayout<FloatType>.size)
+        memcpy(&valueData, valueTensor.dataPointer, valueTensor.count * MemoryLayout<FloatType>.size)
+
+        let sliceKeyPtr = keySlice.dataPointer.assumingMemoryBound(to: FloatType.self)
+        let sliceValuePtr = valueSlice.dataPointer.assumingMemoryBound(to: FloatType.self)
+
+        // Concurrent processing across the hidden dimension; each iteration writes a
+        // disjoint set of indices, so the arrays are accessed through unsafe buffer
+        // pointers to avoid exclusivity violations on the shared storage
+        keyData.withUnsafeMutableBufferPointer { keyBuffer in
+            valueData.withUnsafeMutableBufferPointer { valueBuffer in
+                DispatchQueue.concurrentPerform(iterations: hiddenDim) { j in
+                    // Slice size is 3 for prefill and 1 for decode loops
+                    for k in 0..<sliceLength {
+                        let targetSeqPos = index + k
+                        guard targetSeqPos < seqLength else { continue }
+
+                        // Map the 4D indices to linear indices, equivalent to:
+                        // `tensor[0, j, 0, index + k] = slice[0, j, 0, k]`
+                        let flatTensorIndex = j * seqLength + targetSeqPos
+                        let flatSliceIndex = j * sliceLength + k
+
+                        keyBuffer[flatTensorIndex] = sliceKeyPtr[flatSliceIndex]
+                        valueBuffer[flatTensorIndex] = sliceValuePtr[flatSliceIndex]
+                    }
+                }
+            }
+        }
+
+        // Copy the updated data back into the cache tensors
+        memcpy(keyTensor.dataPointer, &keyData, keyTensor.count * MemoryLayout<FloatType>.size)
+        memcpy(valueTensor.dataPointer, &valueData, valueTensor.count * MemoryLayout<FloatType>.size)
+    }
 
     static func updateAlignmentWeights(
         alignmentTensor: MLMultiArray,