
Commit d191654

ZachNagengast, a2they, EduardoPach authored
Allow protocol defined types for model inputs and outputs (#281)
* Freeze more enums
* Audio input length from CoreML metadata
* Add arbitrary length audio
* Backwards compatible generic model io
* Support generic io for model inputs and outputs
* Add speed factor to timing report
* Use actor for early stop checks for better concurrency safety
* Add io type protocol handling and tests
* Formatting
* Fix timestamp token filter logic and tests
* Run unit tests on any branch in PR
* Upload test failure results

---------

Co-authored-by: Andrey Leonov <[email protected]>
Co-authored-by: Eduardo Pacheco <[email protected]>
1 parent 3bc936a commit d191654

15 files changed: +746 −253 lines changed

.github/workflows/development-tests.yml

−1

@@ -2,7 +2,6 @@ name: Development Tests
 
 on:
   pull_request:
-    branches: ["main"]
   pull_request_review:
     types: [submitted]
   workflow_dispatch:

.github/workflows/unit-tests.yml

+11

@@ -75,8 +75,19 @@ jobs:
           sleep 15
           xcrun simctl list devices
       - name: Build and Test - ${{ matrix.run-config['name'] }}
+        id: test-step
         if: ${{ matrix.run-config['condition'] == true }}
+        continue-on-error: true
         run: |
           set -o pipefail
           xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
           xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
+
+      - name: Upload Test Results
+        if: failure() && steps.test-step.outcome == 'failure'
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-results-${{ matrix.run-config['name'] }}
+          path: |
+            ~/Library/Developer/Xcode/DerivedData/**/Logs/Test/*.xcresult
+          retention-days: 5

Sources/WhisperKit/Core/Audio/AudioProcessor.swift

+6 −3

@@ -93,7 +93,7 @@ public extension AudioProcessing {
     }
 
     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        guard startIndex >= 0 && startIndex < audioArray.count else {
+        guard startIndex >= 0, startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil
         }
@@ -178,7 +178,6 @@ public class AudioProcessor: NSObject, AudioProcessing {
     }
 
     public var audioBufferCallback: (([Float]) -> Void)?
-    public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz
     public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz
 
     // MARK: - Loading and conversion
@@ -229,7 +228,11 @@ public class AudioProcessor: NSObject, AudioProcessing {
         guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else {
             throw WhisperError.loadAudioFailed("Unable to create audio buffer")
         }
-        try audioFile.read(into: buffer, frameCount: frameCount)
+        do {
+            try audioFile.read(into: buffer, frameCount: frameCount)
+        } catch {
+            throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)")
+        }
         outputBuffer = buffer
     } else {
         // Audio needs resampling to 16khz
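
Note on these changes: an out-of-range startIndex in padOrTrimAudio logs an error and returns nil, and a failed AVAudioFile read now surfaces as WhisperError.loadAudioFailed rather than a raw Foundation error. A minimal usage sketch of the guard behavior (hypothetical values; assumes WhisperKit is imported as a package dependency):

import CoreML
import WhisperKit

// One second of silence at 16 kHz (illustrative input).
let samples = [Float](repeating: 0.0, count: 16_000)

// Valid start index: pads or trims the segment to the default 480_000-sample window.
let padded = AudioProcessor.padOrTrimAudio(fromArray: samples, startAt: 0)

// Start index beyond the buffer: logs "startIndex is outside the buffer size" and returns nil.
let outOfRange = AudioProcessor.padOrTrimAudio(fromArray: samples, startAt: 20_000)

print(padded?.count ?? 0, outOfRange == nil) // expected: 480000 true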

Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift

+1 −1

@@ -117,7 +117,7 @@ open class VoiceActivityDetector {
         }
     }
 
-    // MARK - Utility
+    // MARK: - Utility
 
     func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
         let nonSilentChunks = calculateActiveChunks(in: waveform)

Sources/WhisperKit/Core/AudioEncoder.swift

+15 −3

@@ -3,17 +3,22 @@
 
 import CoreML
 
+public protocol AudioEncoderOutputType {}
+extension MLMultiArray: AudioEncoderOutputType {}
+
 /// AudioEncoding protocol defines the requirements for an audio encoding implementation.
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public protocol AudioEncoding {
     /// The size of the embedding produced by the encoder.
     var embedSize: Int? { get }
 
     /// Encodes the given audio features asynchronously.
     /// - Parameter features: The audio features to be encoded.
-    /// - Returns: An optional `MLMultiArray` containing the encoded features.
-    func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray?
+    /// - Returns: An optional tensor containing the encoded features.
+    func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)?
 }
 
+/// Backwards-compatible AudioEncoder implementation
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public class AudioEncoder: AudioEncoding, WhisperMLModel {
     public var model: MLModel?
@@ -36,8 +41,15 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel {
 
     public init() {}
 
+    public func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
+        guard let features = features as? MLMultiArray else {
+            throw WhisperError.audioProcessingFailed("AudioEncoder input must be MLMultiArray")
+        }
+
+        return try await encodeFeatures(features)
+    }
+
     public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? {
-        // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000)
        guard let model else {
             throw WhisperError.modelsUnavailable()
         }
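
The AudioEncoderOutputType and FeatureExtractorOutputType marker protocols are what allow model inputs and outputs other than MLMultiArray, while the shipped AudioEncoder keeps the old MLMultiArray path. A rough sketch of a custom conformance under the new protocol (the FloatTensor type and the 384 embed size are illustrative, not part of WhisperKit):

import CoreML
import WhisperKit

// Hypothetical output type conforming to the new marker protocol.
struct FloatTensor: AudioEncoderOutputType {
    var data: [Float]
    var shape: [Int]
}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class CustomAudioEncoder: AudioEncoding {
    var embedSize: Int? { 384 } // illustrative embedding size

    func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
        // A custom encoder can accept any FeatureExtractorOutputType it understands;
        // here only the built-in MLMultiArray conformance is handled.
        guard let mel = features as? MLMultiArray else {
            throw WhisperError.audioProcessingFailed("Unsupported feature type")
        }
        // ... run a custom model on `mel` here; return an illustrative tensor.
        return FloatTensor(data: [Float](repeating: 0, count: mel.count), shape: [1, 384])
    }
}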

Sources/WhisperKit/Core/FeatureExtractor.swift

+16 −1

@@ -7,9 +7,15 @@ import CoreGraphics
 import CoreML
 import Foundation
 
+public protocol FeatureExtractorOutputType {}
+extension MLMultiArray: FeatureExtractorOutputType {}
+
 public protocol FeatureExtracting {
+    associatedtype OutputType: FeatureExtractorOutputType
+
     var melCount: Int? { get }
-    func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray?
+    var windowSamples: Int? { get }
+    func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> OutputType?
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
@@ -26,6 +32,14 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         return shape[1]
     }
 
+    public var windowSamples: Int? {
+        guard let inputDescription = model?.modelDescription.inputDescriptionsByName["audio"] else { return nil }
+        guard inputDescription.type == .multiArray else { return nil }
+        guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
+        let shape = shapeConstraint.shape.map { $0.intValue }
+        return shape[0] // The audio input is a 1D array
+    }
+
     public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
         guard let model else {
             throw WhisperError.modelsUnavailable()
@@ -40,4 +54,5 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         let output = MelSpectrogramOutput(features: outputFeatures)
         return output.melspectrogramFeatures
     }
+
 }
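
windowSamples is read from the Core ML model's "audio" input shape, which is what lets the pipeline handle encoders compiled for windows other than 30 seconds; Constants.defaultWindowSamples (added below in Models.swift) is the fallback when no model is loaded. A small sketch of how a caller might combine the two (the helper function is illustrative, not WhisperKit API):

import WhisperKit

// Choose the audio window length from model metadata when available,
// otherwise fall back to the 30s / 16kHz Whisper default.
func effectiveWindowSeconds(for featureExtractor: FeatureExtractor) -> Double {
    let samples = featureExtractor.windowSamples ?? Constants.defaultWindowSamples
    return Double(samples) / Double(WhisperKit.sampleRate) // 480_000 / 16_000 = 30.0 by default
}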

Sources/WhisperKit/Core/Models.swift

+11 −1

@@ -49,6 +49,7 @@ public extension WhisperMLModel {
 
 // MARK: - Whisper Models
 
+@frozen
 public enum ModelVariant: CustomStringConvertible, CaseIterable {
     case tiny
     case tinyEn
@@ -100,6 +101,7 @@ public enum ModelVariant: CustomStringConvertible, CaseIterable {
     }
 }
 
+@frozen
 public enum ModelState: CustomStringConvertible {
     case unloading
     case unloaded
@@ -282,6 +284,7 @@ public struct AudioChunk {
 
 // MARK: - Decoding
 
+@frozen
 public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
     case transcribe
     case translate
@@ -296,7 +299,7 @@ public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
     }
 }
 
-public struct DecodingInputs {
+open class DecodingInputs {
     public var initialPrompt: [Int]
     public var inputIds: MLMultiArray
     public var cacheLength: MLMultiArray
@@ -355,6 +358,7 @@ public struct DecodingCache {
     }
 }
 
+@frozen
 public enum ChunkingStrategy: String, Codable, CaseIterable {
     case none
     case vad
@@ -444,6 +448,7 @@ public struct DecodingResult {
     }
 }
 
+@frozen
 public enum WhisperError: Error, LocalizedError, Equatable {
     case tokenizerUnavailable(String = "Tokenizer is unavailable")
     case modelsUnavailable(String = "Models are unavailable")
@@ -575,6 +580,7 @@ public struct TranscriptionResult: Codable {
         Total Tokens: \(totalTokens)
         Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s
         Real Time Factor: \(String(format: "%.3f", rtf))
+        Speed Factor: \(String(format: "%.3f", 1.0 / rtf))
         Fallbacks: \(timings.totalDecodingFallbacks)
         """)
     }
@@ -647,6 +653,7 @@ public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: Mode
 public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void
 
 /// Represents the different states of the transcription process.
+@frozen
 public enum TranscriptionState: CustomStringConvertible {
     /// The audio is being converted to the required format for transcription
     case convertingAudio
@@ -1372,6 +1379,7 @@ extension WhisperTokenizerWrapper {
 
 // MARK: Constants
 
+@frozen
 public enum Constants {
     enum Logging {
         static let subsystem = "com.argmax.whisperkit"
@@ -1502,6 +1510,8 @@ public enum Constants {
 
     public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
 
+    public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models
+
     public static let fallbackModelSupportConfig: ModelSupportConfig = {
         var config = ModelSupportConfig(
             repoName: "whisperkit-coreml-fallback",
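
On the timing change above: the report now prints a speed factor next to the real-time factor, and the two are reciprocals. A quick illustration with hypothetical numbers:

// Hypothetical timings: 60s of audio transcribed in 15s.
let audioDuration = 60.0
let processingTime = 15.0
let rtf = processingTime / audioDuration // 0.250
let speedFactor = 1.0 / rtf              // 4.000, i.e. 4x faster than real time
print(String(format: "Real Time Factor: %.3f, Speed Factor: %.3f", rtf, speedFactor))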

Sources/WhisperKit/Core/Text/LogitsFilter.swift

+2 −1

@@ -75,8 +75,9 @@ open class TimestampRulesFilter: LogitsFiltering {
 
     public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
         guard let sampleBegin = sampleBegin(for: tokens),
-              sampleBegin > tokens.count
+              sampleBegin <= tokens.count
         else {
+            // Early return if we are still prefilling the prompt
             return logits
         }
 
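The corrected guard applies the timestamp rules only once decoding has produced at least sampleBegin tokens; while the prompt is still being prefilled, the logits pass through unchanged. The previous sampleBegin > tokens.count condition effectively inverted that check. An illustration of the fixed condition with hypothetical values:

let sampleBegin = 3 // decoding begins after 3 prefill/prompt tokens (illustrative)
for tokenCount in [1, 2, 3, 5] {
    let applyTimestampRules = sampleBegin <= tokenCount
    // tokenCount 1, 2 -> still prefilling, logits returned unchanged
    // tokenCount 3, 5 -> past the prefill, timestamp rules are enforced
    print(tokenCount, applyTimestampRules)
}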