Allow protocol defined types for model inputs and outputs #281

Merged · 3 commits · Dec 19, 2024
1 change: 0 additions & 1 deletion .github/workflows/development-tests.yml
@@ -2,7 +2,6 @@ name: Development Tests
 
 on:
   pull_request:
-    branches: ["main"]
   pull_request_review:
     types: [submitted]
   workflow_dispatch:
11 changes: 11 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -75,8 +75,19 @@ jobs:
           sleep 15
           xcrun simctl list devices
       - name: Build and Test - ${{ matrix.run-config['name'] }}
+        id: test-step
         if: ${{ matrix.run-config['condition'] == true }}
+        continue-on-error: true
         run: |
           set -o pipefail
           xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
           xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
+
+      - name: Upload Test Results
+        if: failure() && steps.test-step.outcome == 'failure'
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-results-${{ matrix.run-config['name'] }}
+          path: |
+            ~/Library/Developer/Xcode/DerivedData/**/Logs/Test/*.xcresult
+          retention-days: 5
9 changes: 6 additions & 3 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,7 +93,7 @@ public extension AudioProcessing {
     }
 
     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        guard startIndex >= 0 && startIndex < audioArray.count else {
+        guard startIndex >= 0, startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil
         }
@@ -178,7 +178,6 @@ public class AudioProcessor: NSObject, AudioProcessing {
     }
 
     public var audioBufferCallback: (([Float]) -> Void)?
-    public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz
     public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz
 
     // MARK: - Loading and conversion
@@ -229,7 +228,11 @@
         guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else {
             throw WhisperError.loadAudioFailed("Unable to create audio buffer")
         }
-        try audioFile.read(into: buffer, frameCount: frameCount)
+        do {
+            try audioFile.read(into: buffer, frameCount: frameCount)
+        } catch {
+            throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)")
+        }
         outputBuffer = buffer
     } else {
         // Audio needs resampling to 16khz
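For callers, the effect of this change is that a corrupt or unreadable file now surfaces as a typed WhisperError instead of an untyped AVFoundation error. A minimal sketch of handling it at the call site, assuming the static AudioProcessor.loadAudio(fromPath:) entry point (names may differ slightly in your checkout):

    import AVFoundation
    import WhisperKit

    func loadBufferOrReport(path: String) {
        do {
            // Internally reaches audioFile.read(into:frameCount:), which is now
            // wrapped in do/catch and rethrown as WhisperError.loadAudioFailed.
            let buffer = try AudioProcessor.loadAudio(fromPath: path)
            print("Loaded \(buffer.frameLength) frames")
        } catch WhisperError.loadAudioFailed(let message) {
            print("Audio load failed: \(message)")
        } catch {
            print("Unexpected error: \(error)")
        }
    }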
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift
@@ -117,7 +117,7 @@ open class VoiceActivityDetector {
         }
     }
 
-    // MARK - Utility
+    // MARK: - Utility
 
     func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
         let nonSilentChunks = calculateActiveChunks(in: waveform)
18 changes: 15 additions & 3 deletions Sources/WhisperKit/Core/AudioEncoder.swift
@@ -3,17 +3,22 @@
 
 import CoreML
 
+public protocol AudioEncoderOutputType {}
+extension MLMultiArray: AudioEncoderOutputType {}
+
 /// AudioEncoding protocol defines the requirements for an audio encoding implementation.
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public protocol AudioEncoding {
     /// The size of the embedding produced by the encoder.
     var embedSize: Int? { get }
 
     /// Encodes the given audio features asynchronously.
     /// - Parameter features: The audio features to be encoded.
-    /// - Returns: An optional `MLMultiArray` containing the encoded features.
-    func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray?
+    /// - Returns: An optional tensor containing the encoded features.
+    func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)?
 }
 
+/// Backwards-compatible AudioEncoder implementation
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public class AudioEncoder: AudioEncoding, WhisperMLModel {
     public var model: MLModel?
@@ -36,8 +41,15 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel {
 
     public init() {}
 
+    public func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
+        guard let features = features as? MLMultiArray else {
+            throw WhisperError.audioProcessingFailed("AudioEncoder input must be MLMultiArray")
+        }
+
+        return try await encodeFeatures(features)
+    }
+
     public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? {
         // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000)
         guard let model else {
             throw WhisperError.modelsUnavailable()
         }
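This is the core of the PR: encodeFeatures is now typed against the erased protocol types, so an alternative backend can flow its own tensors through the pipeline while the MLMultiArray-based AudioEncoder above stays source-compatible. A sketch of what a custom conformance could look like (MyTensor and the passthrough body are hypothetical, purely to show the shape of the conformance):

    import CoreML
    import WhisperKit

    // Hypothetical backend-specific tensor opting into both protocols.
    struct MyTensor: FeatureExtractorOutputType, AudioEncoderOutputType {
        var data: [Float]
    }

    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
    final class MyAudioEncoder: AudioEncoding {
        var embedSize: Int? { 384 } // report the backend's embedding width

        func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
            guard let features = features as? MyTensor else {
                throw WhisperError.audioProcessingFailed("MyAudioEncoder input must be MyTensor")
            }
            // A real backend would run its encoder here; identity is a stand-in.
            return MyTensor(data: features.data)
        }
    }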
17 changes: 16 additions & 1 deletion Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -7,9 +7,15 @@ import CoreGraphics
 import CoreML
 import Foundation
 
+public protocol FeatureExtractorOutputType {}
+extension MLMultiArray: FeatureExtractorOutputType {}
+
 public protocol FeatureExtracting {
+    associatedtype OutputType: FeatureExtractorOutputType
+
     var melCount: Int? { get }
-    func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray?
+    var windowSamples: Int? { get }
+    func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> OutputType?
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
@@ -26,6 +32,14 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         return shape[1]
     }
 
+    public var windowSamples: Int? {
+        guard let inputDescription = model?.modelDescription.inputDescriptionsByName["audio"] else { return nil }
+        guard inputDescription.type == .multiArray else { return nil }
+        guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
+        let shape = shapeConstraint.shape.map { $0.intValue }
+        return shape[0] // The audio input is a 1D array
+    }
+
     public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
         guard let model else {
             throw WhisperError.modelsUnavailable()
@@ -40,4 +54,5 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         let output = MelSpectrogramOutput(features: outputFeatures)
         return output.melspectrogramFeatures
     }
+
 }
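On the input side, FeatureExtracting now carries an associatedtype, so a replacement extractor can produce its own tensor type and hand it straight to a matching encoder. Continuing the hypothetical MyTensor sketch from the encoder section above (not shipped code):

    import CoreML

    final class MyFeatureExtractor: FeatureExtracting {
        typealias OutputType = MyTensor

        let melCount: Int? = 80
        // Matches Constants.defaultWindowSamples: 30s of audio at 16kHz.
        let windowSamples: Int? = 480_000

        func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MyTensor? {
            // A real implementation computes a log-mel spectrogram; zeros are a stand-in.
            return MyTensor(data: [Float](repeating: 0, count: 80 * 3000))
        }
    }

The new windowSamples requirement lets the pipeline ask the extractor how many audio samples one window consumes instead of hard-coding 480,000; the CoreML implementation above reads it from the model's input shape, and Constants.defaultWindowSamples (added below) serves as the fallback.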
12 changes: 11 additions & 1 deletion Sources/WhisperKit/Core/Models.swift
@@ -49,6 +49,7 @@ public extension WhisperMLModel {
 
 // MARK: - Whisper Models
 
+@frozen
 public enum ModelVariant: CustomStringConvertible, CaseIterable {
     case tiny
     case tinyEn
@@ -100,6 +101,7 @@ public enum ModelVariant: CustomStringConvertible, CaseIterable {
     }
 }
 
+@frozen
 public enum ModelState: CustomStringConvertible {
     case unloading
     case unloaded
@@ -282,6 +284,7 @@ public struct AudioChunk {
 
 // MARK: - Decoding
 
+@frozen
 public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
     case transcribe
     case translate
@@ -296,7 +299,7 @@ public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
     }
 }
 
-public struct DecodingInputs {
+open class DecodingInputs {
     public var initialPrompt: [Int]
     public var inputIds: MLMultiArray
     public var cacheLength: MLMultiArray
@@ -355,6 +358,7 @@ public struct DecodingCache {
     }
 }
 
+@frozen
 public enum ChunkingStrategy: String, Codable, CaseIterable {
     case none
     case vad
@@ -444,6 +448,7 @@ public struct DecodingResult {
     }
 }
 
+@frozen
 public enum WhisperError: Error, LocalizedError, Equatable {
     case tokenizerUnavailable(String = "Tokenizer is unavailable")
     case modelsUnavailable(String = "Models are unavailable")
@@ -575,6 +580,7 @@ public struct TranscriptionResult: Codable {
     Total Tokens: \(totalTokens)
     Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s
     Real Time Factor: \(String(format: "%.3f", rtf))
+    Speed Factor: \(String(format: "%.3f", 1.0 / rtf))
     Fallbacks: \(timings.totalDecodingFallbacks)
     """)
 }
@@ -647,6 +653,7 @@ public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: Mode
 public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void
 
 /// Represents the different states of the transcription process.
+@frozen
 public enum TranscriptionState: CustomStringConvertible {
     /// The audio is being converted to the required format for transcription
     case convertingAudio
@@ -1372,6 +1379,7 @@ extension WhisperTokenizerWrapper {
 
 // MARK: Constants
 
+@frozen
 public enum Constants {
     enum Logging {
         static let subsystem = "com.argmax.whisperkit"
@@ -1502,6 +1510,8 @@ public enum Constants {
 
     public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
 
+    public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models
+
     public static let fallbackModelSupportConfig: ModelSupportConfig = {
         var config = ModelSupportConfig(
             repoName: "whisperkit-coreml-fallback",
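Two side notes on the Models.swift changes. DecodingInputs moving from struct to open class gives it reference semantics, so in-place updates to the prefilled arrays are visible to every holder and the type can be subclassed. The @frozen annotations promise these enums will not gain cases in an ABI-compatible release, so when the library is built with library evolution enabled, clients can switch exhaustively without an @unknown default branch, e.g.:

    func label(for strategy: ChunkingStrategy) -> String {
        switch strategy { // exhaustive: @frozen rules out future unknown cases
        case .none: return "single window"
        case .vad: return "voice-activity-detected chunks"
        }
    }

The added Speed Factor line is simply the reciprocal of the real-time factor: an RTF of 0.250 is reported as a speed factor of 4.000.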
3 changes: 2 additions & 1 deletion Sources/WhisperKit/Core/Text/LogitsFilter.swift
@@ -75,8 +75,9 @@ open class TimestampRulesFilter: LogitsFiltering {
 
     public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
         guard let sampleBegin = sampleBegin(for: tokens),
-              sampleBegin > tokens.count
+              sampleBegin <= tokens.count
         else {
+            // Early return if we are still prefilling the prompt
            return logits
         }
 
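The inverted guard is the actual bug fix: with the old condition the timestamp rules ran only while tokens.count was still below sampleBegin (i.e., during prompt prefill) and were skipped once real decoding began. A toy trace of the new condition (values illustrative, assuming a prompt that prefills three tokens):

    let sampleBegin = 3
    for tokenCount in 1...4 {
        // Old: rules applied when sampleBegin > tokenCount (prefill only).
        // New: rules applied once sampleBegin <= tokenCount (actual decoding).
        print("tokens: \(tokenCount) -> rules applied: \(sampleBegin <= tokenCount)")
    }
    // tokens: 1 -> rules applied: false
    // tokens: 2 -> rules applied: false
    // tokens: 3 -> rules applied: true
    // tokens: 4 -> rules applied: true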