Commit 4ef384e

Prepare for swift-transformers upgrade, add hooks for TranscribeTask overrides (#367)
Co-authored-by: chen <[email protected]>
1 parent 0926e52 · commit 4ef384e

7 files changed (+118, -19 lines)

.github/workflows/pre-release-tests.yml renamed to .github/workflows/release-tests.yml

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-name: Pre-Release Tests
+name: Release Tests
 
 on:
   push:

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ jobs:
         run: |
           echo "Simulators on runner:"
           xcrun simctl list
-          if [[ "${{ matrix.run-config['name'] }}" != "macOS" ]]; then
+          if [[ "${{ matrix.run-config['name'] }}" == "visionOS" ]]; then
             xcodebuild -downloadPlatform ${{ matrix.run-config['name'] }}
           fi
           echo "Runtimes for testing:"

README.md

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 
 # WhisperKit
 
-[![Tests](https://github.com/argmaxinc/whisperkit/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/argmaxinc/whisperkit/actions/workflows/pre-release-tests.yml)
+[![Tests](https://github.com/argmaxinc/whisperkit/actions/workflows/release-tests.yml/badge.svg)](https://github.com/argmaxinc/whisperkit/actions/workflows/release-tests.yml)
 [![License](https://img.shields.io/github/license/argmaxinc/whisperkit?logo=github&logoColor=969da4&label=License&labelColor=353a41&color=32d058)](LICENSE.md)
 [![Supported Swift Version](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dswift-versions&labelColor=353a41&color=32d058)](https://swiftpackageindex.com/argmaxinc/WhisperKit) [![Supported Platforms](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dplatforms&labelColor=353a41&color=32d058)](https://swiftpackageindex.com/argmaxinc/WhisperKit)
 [![Discord](https://img.shields.io/discord/1171912382512115722?style=flat&logo=discord&logoColor=969da4&label=Discord&labelColor=353a41&color=32d058&link=https%3A%2F%2Fdiscord.gg%2FG5F5GZGecC)](https://discord.gg/G5F5GZGecC)

Sources/WhisperKit/Core/TextDecoder.swift

Lines changed: 2 additions & 2 deletions

@@ -355,8 +355,8 @@ public extension TextDecoding {
     {
         // Prefilling kv cache data requires non-nil task and language tokens, set defaults if not provided
         // Task tokens are remapped to 0->transcribe and 1->translate for the prefill lookup table
-        let task = MLMultiArray.from([taskToken == tokenizer.specialTokens.transcribeToken ? 0 : 1])
-        let lang = MLMultiArray.from([languageToken])
+        let task = try MLMultiArray.from([taskToken == tokenizer.specialTokens.transcribeToken ? 0 : 1])
+        let lang = try MLMultiArray.from([languageToken])
         guard let prefillOutput = try await self.prefillKVCache(withTask: task, andLanguage: lang) else {
             Logging.error("Unable to prefill cache")
             return prefilledDecoderInputs

Sources/WhisperKit/Core/TranscribeTask.swift

Lines changed: 51 additions & 12 deletions

@@ -5,19 +5,20 @@ import CoreML
 import Foundation
 
 /// Responsible for transcribing audio chunk to text using the provided models and configurations.
-final class TranscribeTask {
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+open class TranscribeTask {
     private var timings: TranscriptionTimings
     private let progress: Progress
     private let audioEncoder: any AudioEncoding
     private let featureExtractor: any FeatureExtracting
     private let segmentSeeker: any SegmentSeeking
     private let textDecoder: any TextDecoding
-    private let tokenizer: any WhisperTokenizer
     private let audioProcessor: any AudioProcessing
 
+    public private(set) var tokenizer: any WhisperTokenizer
     public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
 
-    init(
+    public init(
         currentTimings: TranscriptionTimings,
         progress: Progress?,
         audioProcessor: (any AudioProcessing)? = nil,

@@ -37,7 +38,23 @@ final class TranscribeTask {
         self.tokenizer = tokenizer
     }
 
-    func run(
+    /// Hook for subclasses to launch work that can run alongside the main decoder pipeline.
+    open func windowPreprocess(
+        for paddedAudio: any AudioProcessorOutputType,
+        seek: Int,
+        segmentSize: Int
+    ) async {}
+
+    /// Hook for subclasses to finalize side work and optionally replace the segments for the current window.
+    open func windowPostProcess(
+        seek: Int,
+        segmentSize: Int,
+        originalSegments: [TranscriptionSegment]
+    ) async -> [TranscriptionSegment] {
+        originalSegments
+    }
+
+    public func run(
         audioArray: [Float],
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil

@@ -61,7 +78,6 @@
         // These accumulate across windows
         var allSegments: [TranscriptionSegment] = []
         var allTokens: [Int] = []
-        var transcription = ""
 
         let startDecoderInit = CFAbsoluteTimeGetCurrent()
         var decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])

@@ -107,6 +123,7 @@
 
         let windowSamples = featureExtractor.windowSamples ?? Constants.defaultWindowSamples
         while seek < seekClipEnd - windowPadding {
+            let windowSeek = seek
             // calculate new encoder segment features
             let timeOffset = Float(seek) / Float(WhisperKit.sampleRate)
             let segmentSize = min(windowSamples, contentFrames - seek, seekClipEnd - seek)

@@ -119,6 +136,7 @@
             guard let audioSamples = audioProcessor.padOrTrim(fromArray: clipAudioSamples, startAt: 0, toLength: windowSamples) else {
                 throw WhisperError.transcriptionFailed("Audio samples are nil")
             }
+            await windowPreprocess(for: audioSamples, seek: windowSeek, segmentSize: segmentSize)
             let processTime = Date().timeIntervalSince(audioProcessingStart)
             timings.audioProcessing += processTime
             timings.totalAudioProcessingRuns += 1

@@ -222,24 +240,30 @@
                 seek = min(seek, maxSeekOffset)
             }
 
-            guard let currentSegments = currentSegments else {
+            guard let currentSegments else {
                 // No current segment found, skip to next window
                 continue
             }
 
+            let processedSegments = await windowPostProcess(
+                seek: windowSeek,
+                segmentSize: segmentSize,
+                originalSegments: currentSegments
+            )
+
             if options.verbose {
-                let lines = TranscriptionUtilities.formatSegments(currentSegments)
+                let lines = TranscriptionUtilities.formatSegments(processedSegments)
                 Logging.debug("Segments for window:")
                 for line in lines {
                     Logging.debug(line)
                 }
             }
 
-            segmentDiscoveryCallback?(currentSegments)
+            segmentDiscoveryCallback?(processedSegments)
 
             // add them to the `allSegments` list
-            allSegments.append(contentsOf: currentSegments)
-            let allCurrentTokens = currentSegments.flatMap { $0.tokens }
+            allSegments.append(contentsOf: processedSegments)
+            let allCurrentTokens = processedSegments.flatMap { $0.tokens }
             allTokens.append(contentsOf: allCurrentTokens)
 
             timings.decodingWindowing += Date().timeIntervalSince(windowingStart)

@@ -364,8 +388,23 @@
         timings.decodingLoop = CFAbsoluteTimeGetCurrent() - startDecodeLoopTime
         timings.fullPipeline = CFAbsoluteTimeGetCurrent() - timings.pipelineStart
 
-        let wordTokens = allTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
-        transcription = tokenizer.decode(tokens: wordTokens).trimmingCharacters(in: .whitespaces)
+        let transcriptionResult = finalizeTranscriptionResult(
+            tokens: allTokens,
+            segments: allSegments,
+            language: detectedLanguage,
+            timings: timings
+        )
+        return transcriptionResult
+    }
+
+    open func finalizeTranscriptionResult(
+        tokens: [Int],
+        segments allSegments: [TranscriptionSegment],
+        language detectedLanguage: String?,
+        timings: TranscriptionTimings
+    ) -> TranscriptionResult {
+        let wordTokens = tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
+        let transcription = tokenizer.decode(tokens: wordTokens).trimmingCharacters(in: .whitespaces)
         return TranscriptionResult(
             text: transcription,
             segments: allSegments,
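
Taken together, these changes turn `TranscribeTask` into an extension point: it is now `open`, its `tokenizer` is readable from outside, and `windowPreprocess`, `windowPostProcess`, and `finalizeTranscriptionResult` can be overridden. A minimal sketch of a subclass using the two window hooks follows; the `LoggingTranscribeTask` name and the empty-segment filtering are illustrative, not part of this commit.

```swift
import Foundation
import WhisperKit

// Hypothetical subclass; only the override points themselves come from this commit.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class LoggingTranscribeTask: TranscribeTask {
    // Runs before each window is decoded; a place to kick off side work.
    override func windowPreprocess(
        for paddedAudio: any AudioProcessorOutputType,
        seek: Int,
        segmentSize: Int
    ) async {
        print("Window starting at sample \(seek), size \(segmentSize)")
    }

    // Runs after each window; may rewrite the segments before they are accumulated
    // and passed to segmentDiscoveryCallback.
    override func windowPostProcess(
        seek: Int,
        segmentSize: Int,
        originalSegments: [TranscriptionSegment]
    ) async -> [TranscriptionSegment] {
        // Illustrative transformation: drop segments that decoded to no tokens.
        originalSegments.filter { !$0.tokens.isEmpty }
    }
}
```

Because `run` now routes each window's segments through `windowPostProcess` before appending them to `allSegments` and invoking `segmentDiscoveryCallback`, a subclass can rewrite or drop segments per window without reimplementing the decoding loop.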

Sources/WhisperKit/Core/WhisperKit.swift

Lines changed: 24 additions & 2 deletions

@@ -6,7 +6,6 @@ import AVFoundation
 import CoreML
 import Foundation
 import Hub
-import TensorUtils
 import Tokenizers
 
 open class WhisperKit {

@@ -957,6 +956,29 @@ open class WhisperKit {
         return transcribeResults
     }
 
+    /// Setup the `TranscribeTask` used for decoding. Subclasses may override to provide custom behavior.
+    open func setupTranscribeTask(
+        currentTimings: TranscriptionTimings,
+        progress: Progress,
+        audioProcessor: any AudioProcessing,
+        audioEncoder: any AudioEncoding,
+        featureExtractor: any FeatureExtracting,
+        segmentSeeker: any SegmentSeeking,
+        textDecoder: any TextDecoding,
+        tokenizer: any WhisperTokenizer
+    ) -> TranscribeTask {
+        TranscribeTask(
+            currentTimings: currentTimings,
+            progress: progress,
+            audioProcessor: audioProcessor,
+            audioEncoder: audioEncoder,
+            featureExtractor: featureExtractor,
+            segmentSeeker: segmentSeeker,
+            textDecoder: textDecoder,
+            tokenizer: tokenizer
+        )
+    }
+
     /// Runs the transcription task on a single audio sample array asynchronously with custom segment callback.
     /// - Returns: An array of `TranscriptionResult`.
     /// - Throws: An error if the transcription fails or if the tokenizer is unavailable.

@@ -983,7 +1005,7 @@
         progress.totalUnitCount = max(1, progress.totalUnitCount)
         progress.addChild(childProgress, withPendingUnitCount: 1)
 
-        let transcribeTask = TranscribeTask(
+        let transcribeTask = setupTranscribeTask(
            currentTimings: currentTimings,
            progress: childProgress,
            audioProcessor: audioProcessor,
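
On the `WhisperKit` side, `transcribe` now obtains its task from `setupTranscribeTask` instead of constructing `TranscribeTask` directly, so a subclass can inject a custom task. A minimal sketch, reusing the hypothetical `LoggingTranscribeTask` from the previous example:

```swift
import Foundation
import WhisperKit

// Hypothetical WhisperKit subclass; setupTranscribeTask is the hook added by this commit.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class CustomWhisperKit: WhisperKit {
    override func setupTranscribeTask(
        currentTimings: TranscriptionTimings,
        progress: Progress,
        audioProcessor: any AudioProcessing,
        audioEncoder: any AudioEncoding,
        featureExtractor: any FeatureExtracting,
        segmentSeeker: any SegmentSeeking,
        textDecoder: any TextDecoding,
        tokenizer: any WhisperTokenizer
    ) -> TranscribeTask {
        // Return the hook-aware subclass so its window callbacks run during decoding.
        LoggingTranscribeTask(
            currentTimings: currentTimings,
            progress: progress,
            audioProcessor: audioProcessor,
            audioEncoder: audioEncoder,
            featureExtractor: featureExtractor,
            segmentSeeker: segmentSeeker,
            textDecoder: textDecoder,
            tokenizer: tokenizer
        )
    }
}
```

Constructed like a regular `WhisperKit` instance, this subclass routes every transcription window through the custom hooks above.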

Sources/WhisperKit/Utilities/Extensions+Internal.swift

Lines changed: 38 additions & 0 deletions

@@ -4,6 +4,24 @@
 import AVFoundation
 import CoreML
 
+extension MLMultiArray {
+    /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
+    static func from(_ array: [Int], dims: Int = 1) throws -> MLMultiArray {
+        var shape = Array(repeating: 1, count: dims)
+        shape[shape.count - 1] = array.count
+        /// Examples:
+        /// dims=1 : [arr.count]
+        /// dims=2 : [1, arr.count]
+        ///
+        let output = try MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
+        let pointer = UnsafeMutablePointer<Int32>(OpaquePointer(output.dataPointer))
+        for (i, item) in array.enumerated() {
+            pointer[i] = Int32(item)
+        }
+        return output
+    }
+}
+
 extension Array {
     func batched(into size: Int) -> [[Element]] {
         return stride(from: 0, to: count, by: size).map {

@@ -35,6 +53,26 @@ extension Array where Element: Hashable {
     }
 }
 
+extension String {
+    /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Tokenizers/Decoder.swift#L267-L275
+    func trimmingFromEnd(character: Character = " ", upto: Int) -> String {
+        var result = self
+        var trimmed = 0
+        while trimmed < upto && result.last == character {
+            result.removeLast()
+            trimmed += 1
+        }
+        return result
+    }
+}
+
+extension [String] {
+    /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Hub/HubApi.swift#L983-L987
+    func matching(glob: String) -> [String] {
+        filter { fnmatch(glob, $0, 0) == 0 }
+    }
+}
+
 extension AVAudioPCMBuffer {
     /// Converts the buffer to a float array
     func asFloatArray() throws -> [Float] {
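
The new helpers are declared in an internal extensions file, so they are only visible inside the WhisperKit module. The throwing `MLMultiArray.from` appears to stand in for the helper previously imported from swift-transformers' `TensorUtils` (dropped in `WhisperKit.swift` above), which is why the `TextDecoder.swift` call sites now use `try`. A small sketch of how the helpers behave, with illustrative values:

```swift
import CoreML
import Foundation

// Sketch only: these extensions are internal, so this would live inside the package.
func helperExamples() throws {
    // MLMultiArray.from(_:dims:) throws, so call sites adopt `try`.
    let task = try MLMultiArray.from([0])               // shape [1], Int32 storage
    let lang = try MLMultiArray.from([50259], dims: 2)  // shape [1, 1]
    print(task.shape, lang.shape)

    // trimmingFromEnd removes at most `upto` trailing occurrences of `character`.
    let trimmed = "hello   ".trimmingFromEnd(character: " ", upto: 2)
    print(trimmed)  // "hello " (one trailing space remains)

    // matching(glob:) keeps only the strings that satisfy an fnmatch-style pattern.
    let files = ["config.json", "model.mlmodelc", "tokenizer.json"]
    print(files.matching(glob: "*.json"))  // ["config.json", "tokenizer.json"]
}
```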
