
Commit e3e21d4

ZachNagengast, keleftheriou, and a2they authored

Example app VAD default + memory reduction (#217)

* Release memory when transcribing single files
* Add method to load from file into float array iteratively
  - Reduces peak memory by doing the array conversion while loading in chunks so the array copy size is lower
  - Previously copied the entire buffer, which spiked the memory 2x
* Fix leak
* Use VAD by default in examples
* Fix VAD thread issue
* Fix unused warning
* Revert change to early stop callback
* Fix warnings
  - Optional CLI commands are deprecated
  - @_disfavoredOverload required @available to prevent infinite loop
* PR review - simplify early stop test logic
* Cleanup from review

Co-authored-by: keleftheriou <[email protected]>
Co-authored-by: Andrey Leonov <[email protected]>

1 parent bfb1316 commit e3e21d4
File tree

9 files changed: +117 -55 lines changed

Examples/WhisperAX/WhisperAX/Views/ContentView.swift (+8 -3)

@@ -51,7 +51,8 @@ struct ContentView: View {
     @AppStorage("silenceThreshold") private var silenceThreshold: Double = 0.3
     @AppStorage("useVAD") private var useVAD: Bool = true
     @AppStorage("tokenConfirmationsNeeded") private var tokenConfirmationsNeeded: Double = 2
-    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .none
+    @AppStorage("concurrentWorkerCount") private var concurrentWorkerCount: Int = 4
+    @AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .vad
     @AppStorage("encoderComputeUnits") private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
     @AppStorage("decoderComputeUnits") private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine

@@ -1269,12 +1270,15 @@ struct ContentView: View {

     func transcribeCurrentFile(path: String) async throws {
         // Load and convert buffer in a limited scope
+        Logging.debug("Loading audio file: \(path)")
+        let loadingStart = Date()
         let audioFileSamples = try await Task {
             try autoreleasepool {
-                let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
-                return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
+                return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
             }
         }.value
+        Logging.debug("Loaded audio file in \(Date().timeIntervalSince(loadingStart)) seconds")

         let transcription = try await transcribeAudioSamples(audioFileSamples)

@@ -1316,6 +1320,7 @@ struct ContentView: View {
             withoutTimestamps: !enableTimestamps,
             wordTimestamps: true,
             clipTimestamps: seekClip,
+            concurrentWorkerCount: concurrentWorkerCount,
             chunkingStrategy: chunkingStrategy
         )
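Net effect of these three hunks: the example app now defaults to VAD chunking with four concurrent workers, and file loading goes through the new chunked loader. A minimal sketch of the resulting decode options, assuming the parameter order visible in the call site above (all other DecodingOptions parameters keep their defaults):

    import WhisperKit

    // Sketch only: mirrors the example app's new defaults.
    let options = DecodingOptions(
        wordTimestamps: true,
        concurrentWorkerCount: 4,   // new @AppStorage default
        chunkingStrategy: .vad      // VAD is now the default strategy
    )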

Sources/WhisperKit/Core/Audio/AudioChunker.swift (-1)

@@ -81,7 +81,6 @@ open class VADAudioChunker: AudioChunking {
         // Typically this will be the full audio file, unless seek points are explicitly provided
         var startIndex = seekClipStart
         while startIndex < seekClipEnd - windowPadding {
-            let currentFrameLength = audioArray.count
             guard startIndex >= 0 && startIndex < audioArray.count else {
                 throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
             }

Sources/WhisperKit/Core/Audio/AudioProcessor.swift (+58 -9)

@@ -93,8 +93,6 @@ public extension AudioProcessing {
     }

     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        let currentFrameLength = audioArray.count
-
         guard startIndex >= 0 && startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil

@@ -197,7 +195,15 @@ public class AudioProcessor: NSObject, AudioProcessing {

         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
+    }

+    public static func loadAudio(
+        fromFile audioFile: AVAudioFile,
+        startTime: Double? = 0,
+        endTime: Double? = nil,
+        maxReadFrameSize: AVAudioFrameCount? = nil
+    ) throws -> AVAudioPCMBuffer {
         let sampleRate = audioFile.fileFormat.sampleRate
         let channelCount = audioFile.fileFormat.channelCount
         let frameLength = AVAudioFrameCount(audioFile.length)

@@ -243,13 +249,56 @@ public class AudioProcessor: NSObject, AudioProcessing {
         }
     }

+    public static func loadAudioAsFloatArray(
+        fromPath audioFilePath: String,
+        startTime: Double? = 0,
+        endTime: Double? = nil
+    ) throws -> [Float] {
+        guard FileManager.default.fileExists(atPath: audioFilePath) else {
+            throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
+        }
+
+        let audioFileURL = URL(fileURLWithPath: audioFilePath)
+        let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        let inputSampleRate = audioFile.fileFormat.sampleRate
+        let inputFrameCount = AVAudioFrameCount(audioFile.length)
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
+
+        let start = startTime ?? 0
+        let end = min(endTime ?? inputDuration, inputDuration)
+
+        // Load 10m of audio at a time to reduce peak memory while converting
+        // Particularly impactful for large audio files
+        let chunkDuration: Double = 60 * 10
+        var currentTime = start
+        var result: [Float] = []
+
+        while currentTime < end {
+            let chunkEnd = min(currentTime + chunkDuration, end)
+
+            try autoreleasepool {
+                let buffer = try loadAudio(
+                    fromFile: audioFile,
+                    startTime: currentTime,
+                    endTime: chunkEnd
+                )
+
+                let floatArray = Self.convertBufferToArray(buffer: buffer)
+                result.append(contentsOf: floatArray)
+            }
+
+            currentTime = chunkEnd
+        }
+
+        return result
+    }
+
     public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
         await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
             for (index, audioPath) in audioPaths.enumerated() {
                 taskGroup.addTask {
                     do {
-                        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-                        let audio = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
                         return [(index: index, result: .success(audio))]
                     } catch {
                         return [(index: index, result: .failure(error))]
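This new loader is the heart of the memory reduction: rather than materializing the whole file as one AVAudioPCMBuffer and then copying it into a [Float] (holding both allocations at peak), it converts ten-minute slices inside an autoreleasepool so only one slice's buffer is alive at a time. A hypothetical call site (the path is illustrative):

    import WhisperKit

    // Sketch: load a long recording without the ~2x peak of a whole-file buffer copy.
    let samples: [Float] = try AudioProcessor.loadAudioAsFloatArray(
        fromPath: "/tmp/long_recording.wav", // hypothetical file
        startTime: 0,
        endTime: nil // nil reads to the end of the file
    )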
@@ -280,10 +329,10 @@ public class AudioProcessor: NSObject, AudioProcessing {
         frameCount: AVAudioFrameCount? = nil,
         maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
     ) -> AVAudioPCMBuffer? {
-        let inputFormat = audioFile.fileFormat
+        let inputSampleRate = audioFile.fileFormat.sampleRate
         let inputStartFrame = audioFile.framePosition
         let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
-        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
         let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)

         guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {

@@ -305,8 +354,8 @@ public class AudioProcessor: NSObject, AudioProcessing {
         let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
         let framesToRead = min(remainingFrames, maxReadFrameSize)

-        let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
-        let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate
+        let currentPositionInSeconds = Double(audioFile.framePosition) / inputSampleRate
+        let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputSampleRate
         Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s")

         do {

@@ -644,7 +693,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
             &propertySize,
             &name
         )
-        if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
+        if status == noErr, let deviceNameCF = name?.takeRetainedValue() as String? {
            deviceName = deviceNameCF
        }
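The takeRetainedValue() switch is the "Fix leak" item from the commit message: the Core Audio property call hands back a CFString the caller owns, so the +1 retain must be consumed. A minimal sketch of the Unmanaged semantics, using a locally created value in place of the property API:

    import Foundation

    // passRetained hands out a +1 reference, like the device-name property above.
    let unmanaged: Unmanaged<CFString> = Unmanaged.passRetained("MacBook Pro Microphone" as CFString)
    let owned: CFString = unmanaged.takeRetainedValue() // consumes the +1, so ARC can release it
    // takeUnretainedValue() would leave that +1 unbalanced and leak the string.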

Sources/WhisperKit/Core/TextDecoder.swift (+4 -2)

@@ -591,9 +591,11 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
         var hasAlignment = false
         var isFirstTokenLogProbTooLow = false
         let windowUUID = UUID()
-        DispatchQueue.global().async { [weak self] in
+        Task { [weak self] in
             guard let self = self else { return }
-            self.shouldEarlyStop[windowUUID] = false
+            await MainActor.run {
+                self.shouldEarlyStop[windowUUID] = false
+            }
         }
         for tokenIndex in prefilledIndex..<loopCount {
             let loopStart = Date()
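This addresses the "Fix VAD thread issue" item: the shouldEarlyStop dictionary was written from a global dispatch queue while the decoding loop read it, so the write now hops onto the main actor, serializing access. The same pattern in isolation, with a hypothetical flag store standing in for shouldEarlyStop:

    import Foundation

    // Sketch: route all mutations of shared state through the main actor.
    final class EarlyStopFlags {
        @MainActor private var flags: [UUID: Bool] = [:] // hypothetical store

        func reset(_ id: UUID) {
            Task { [weak self] in
                guard let self else { return }
                await MainActor.run { self.flags[id] = false } // serialized write
            }
        }
    }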

Sources/WhisperKit/Core/Utils.swift (+10 -1)

@@ -205,6 +205,14 @@ public extension String {
 }

 extension AVAudioPCMBuffer {
+    /// Converts the buffer to a float array
+    func asFloatArray() throws -> [Float] {
+        guard let data = floatChannelData?.pointee else {
+            throw WhisperError.audioProcessingFailed("Error converting audio, missing floatChannelData")
+        }
+        return Array(UnsafeBufferPointer(start: data, count: Int(frameLength)))
+    }
+
     /// Appends the contents of another buffer to the current buffer
     func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
         return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength)

@@ -446,8 +454,9 @@ public func modelSupport(for deviceName: String, from config: ModelSupportConfig
 /// Deprecated
 @available(*, deprecated, message: "Subject to removal in a future version. Use modelSupport(for:from:) -> ModelSupport instead.")
 @_disfavoredOverload
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
-    let modelSupport = modelSupport(for: deviceName, from: config)
+    let modelSupport: ModelSupport = modelSupport(for: deviceName, from: config)
     return (modelSupport.default, modelSupport.disabled)
 }
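The new asFloatArray() helper reads the first channel (floatChannelData?.pointee) up to frameLength, throwing instead of crashing when the buffer carries no float data. A hypothetical pairing with the existing loader:

    import WhisperKit

    // Sketch: convert a loaded PCM buffer to raw samples via the new extension.
    let buffer = try AudioProcessor.loadAudio(fromPath: "audio.wav") // hypothetical file
    let samples: [Float] = try buffer.asFloatArray()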

Sources/WhisperKit/Core/WhisperKit.swift (+26 -23)

@@ -446,7 +446,8 @@ open class WhisperKit {
     open func detectLanguage(
         audioPath: String
     ) async throws -> (language: String, langProbs: [String: Float]) {
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
+        // Only need the first 30s for language detection
+        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath, endTime: 30.0)
         let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
         return try await detectLangauge(audioArray: audioArray)
     }

@@ -721,15 +722,17 @@ open class WhisperKit {
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
         // Process input audio file into audio samples
-        let loadAudioStart = Date()
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-        let loadTime = Date().timeIntervalSince(loadAudioStart)
+        let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
+            let convertAudioStart = Date()
+            defer {
+                let convertTime = Date().timeIntervalSince(convertAudioStart)
+                currentTimings.audioLoading = convertTime
+                Logging.debug("Audio loading and convert time: \(convertTime)")
+                logCurrentMemoryUsage("Audio Loading and Convert")
+            }

-        let convertAudioStart = Date()
-        let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
-        let convertTime = Date().timeIntervalSince(convertAudioStart)
-        currentTimings.audioLoading = loadTime + convertTime
-        Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+            return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
+        }

         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,

@@ -837,23 +840,23 @@ open class WhisperKit {
             throw WhisperError.tokenizerUnavailable()
         }

-        let childProgress = Progress()
-        progress.totalUnitCount += 1
-        progress.addChild(childProgress, withPendingUnitCount: 1)
-
-        let transcribeTask = TranscribeTask(
-            currentTimings: currentTimings,
-            progress: childProgress,
-            audioEncoder: audioEncoder,
-            featureExtractor: featureExtractor,
-            segmentSeeker: segmentSeeker,
-            textDecoder: textDecoder,
-            tokenizer: tokenizer
-        )
-
         do {
             try Task.checkCancellation()

+            let childProgress = Progress()
+            progress.totalUnitCount += 1
+            progress.addChild(childProgress, withPendingUnitCount: 1)
+
+            let transcribeTask = TranscribeTask(
+                currentTimings: currentTimings,
+                progress: childProgress,
+                audioEncoder: audioEncoder,
+                featureExtractor: featureExtractor,
+                segmentSeeker: segmentSeeker,
+                textDecoder: textDecoder,
+                tokenizer: tokenizer
+            )
+
             let transcribeTaskResult = try await transcribeTask.run(
                 audioArray: audioArray,
                 decodeOptions: decodeOptions,
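The reworked loading path records its timing inside a defer block, so currentTimings.audioLoading is set and memory usage logged even when loadAudioAsFloatArray throws. The same pattern as a generic helper (a hypothetical timed function, not part of this diff; Logging is the package's logger):

    // Sketch: defer guarantees the timing log runs on success and on throw alike.
    func timed<T>(_ label: String, _ work: () throws -> T) rethrows -> T {
        let start = Date()
        defer { Logging.debug("\(label): \(Date().timeIntervalSince(start))s") }
        return try work()
    }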

Sources/WhisperKitCLI/CLIArguments.swift (+4 -4)

@@ -103,9 +103,9 @@ struct CLIArguments: ParsableArguments {
     @Flag(help: "Simulate streaming transcription using the input audio file")
     var streamSimulated: Bool = false

-    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited")
-    var concurrentWorkerCount: Int = 0
+    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited. Default: 4")
+    var concurrentWorkerCount: Int = 4

-    @Option(help: "Chunking strategy for audio processing, `nil` means no chunking, `vad` means using voice activity detection")
-    var chunkingStrategy: String? = nil
+    @Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
+    var chunkingStrategy: String = "vad"
 }

Sources/WhisperKitCLI/TranscribeCLI.swift (+3 -11)

@@ -38,10 +38,8 @@ struct TranscribeCLI: AsyncParsableCommand {
             cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 }
         }

-        if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-            if ChunkingStrategy(rawValue: chunkingStrategyRaw) == nil {
-                throw ValidationError("Wrong chunking strategy \"\(chunkingStrategyRaw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
-            }
+        if ChunkingStrategy(rawValue: cliArguments.chunkingStrategy) == nil {
+            throw ValidationError("Wrong chunking strategy \"\(cliArguments.chunkingStrategy)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
         }
     }

@@ -318,12 +316,6 @@ struct TranscribeCLI: AsyncParsableCommand {
     }

     private func decodingOptions(task: DecodingTask) -> DecodingOptions {
-        let chunkingStrategy: ChunkingStrategy? =
-            if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-                ChunkingStrategy(rawValue: chunkingStrategyRaw)
-            } else {
-                nil
-            }
         return DecodingOptions(
             verbose: cliArguments.verbose,
             task: task,

@@ -344,7 +336,7 @@ struct TranscribeCLI: AsyncParsableCommand {
             firstTokenLogProbThreshold: cliArguments.firstTokenLogProbThreshold,
             noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6,
             concurrentWorkerCount: cliArguments.concurrentWorkerCount,
-            chunkingStrategy: chunkingStrategy
+            chunkingStrategy: ChunkingStrategy(rawValue: cliArguments.chunkingStrategy)
         )
     }
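With the option now a non-optional String, both validation and option building lean on ChunkingStrategy(rawValue:), which returns nil for unrecognized input. Illustratively, assuming the raw values named in the help text:

    // Sketch: rawValue mapping assumed from the help text and validation above.
    _ = ChunkingStrategy(rawValue: "vad")  // .vad
    _ = ChunkingStrategy(rawValue: "none") // .none
    _ = ChunkingStrategy(rawValue: "nope") // nil, so validate() throws ValidationError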

Tests/WhisperKitTests/UnitTests.swift (+4 -1)

@@ -548,9 +548,11 @@ final class UnitTests: XCTestCase {
     }

     func testDecodingEarlyStopping() async throws {
+        let earlyStopTokenCount = 10
         let options = DecodingOptions()
         let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
-            false
+            // Stop after only 10 tokens (full test audio contains 16)
+            return progress.tokens.count <= earlyStopTokenCount
         }

         let result = try await XCTUnwrapAsync(

@@ -576,6 +578,7 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(resultWithWait)
         let tokenCountWithWait = resultWithWait.segments.flatMap { $0.tokens }.count
         let decodingTimePerTokenWithWait = resultWithWait.timings.decodingLoop / Double(tokenCountWithWait)
+        Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)")

         // Assert that the decoding predictions per token are not slower with the waiting
         XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting")
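The contract implied by the test: the callback returns false to request an early stop and true to continue, so the rewritten closure stops decoding only once more than earlyStopTokenCount tokens have been produced, rather than unconditionally on the first token:

    // Sketch of the callback contract: `false` stops the decoding loop early.
    let stopAfterTen: TranscriptionCallback = { progress in
        progress.tokens.count <= 10 // becomes false after the 11th token
    }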
