diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index 83ec13e..ae3695d 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -60,7 +60,9 @@ jobs:
         run: make download-model MODEL=tiny
       - name: Install and discover destinations
         run: |
-          xcodebuild -downloadAllPlatforms
+          if [[ "${{ matrix.run-config['name'] }}" != "macOS" ]]; then
+            xcodebuild -downloadPlatform ${{ matrix.run-config['name'] }}
+          fi
           echo "Destinations for testing:"
           xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
       - name: Boot Simulator and Wait
diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 41e3727..8c640f8 100644
--- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -33,8 +33,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
-        "version" : "0.1.7"
+        "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
+        "version" : "0.1.8"
       }
     }
   ],
diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 2a182fb..19320c5 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -110,7 +110,6 @@ struct ContentView: View {
         MenuItem(name: "Stream", image: "waveform.badge.mic"),
     ]
 
-
     private var isStreamMode: Bool {
         self.selectedCategoryId == menu.first(where: { $0.name == "Stream" })?.id
     }
@@ -202,7 +201,7 @@ struct ContentView: View {
         .toolbar(content: {
             ToolbarItem {
                 Button {
-                    if (!enableEagerDecoding) {
+                    if !enableEagerDecoding {
                         let fullTranscript = formatSegments(confirmedSegments + unconfirmedSegments, withTimestamps: enableTimestamps).joined(separator: "\n")
                         #if os(iOS)
                         UIPasteboard.general.string = fullTranscript
@@ -956,9 +955,7 @@ struct ContentView: View {
             localModels = WhisperKit.formatModelFiles(localModels)
 
             for model in localModels {
-                if !availableModels.contains(model),
-                   !disabledModels.contains(model)
-                {
+                if !availableModels.contains(model) {
                     availableModels.append(model)
                 }
             }
@@ -967,12 +964,17 @@ struct ContentView: View {
             print("Previously selected model: \(selectedModel)")
 
             Task {
-                let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName)
-                for model in remoteModels {
-                    if !availableModels.contains(model),
-                       !disabledModels.contains(model)
-                    {
-                        availableModels.append(model)
+                let remoteModelSupport = await WhisperKit.recommendedRemoteModels()
+                await MainActor.run {
+                    for model in remoteModelSupport.supported {
+                        if !availableModels.contains(model) {
+                            availableModels.append(model)
+                        }
+                    }
+                    for model in remoteModelSupport.disabled {
+                        if !disabledModels.contains(model) {
+                            disabledModels.append(model)
+                        }
                     }
                 }
            }
@@ -1644,7 +1646,6 @@ struct ContentView: View {
             finalizeText()
         }
 
-
         let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)
 
         return mergedResult
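Note on the ContentView change above: the throwing fetchAvailableModels call is replaced with recommendedRemoteModels(), which also reports models the current device cannot run, and the list mutations now happen on the main actor. A minimal sketch of the same merge logic outside SwiftUI (the helper name is illustrative, not part of the PR):

    import WhisperKit

    // Merge a fetched ModelSupport into existing UI-backing lists without
    // duplicates; in the app this merge runs inside MainActor.run.
    func mergedModelLists(_ support: ModelSupport, available: [String], disabled: [String]) -> (available: [String], disabled: [String]) {
        var available = available
        var disabled = disabled
        for model in support.supported where !available.contains(model) {
            available.append(model)
        }
        for model in support.disabled where !disabled.contains(model) {
            disabled.append(model)
        }
        return (available, disabled)
    }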
"https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", - "version" : "0.1.7" + "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08", + "version" : "0.1.8" } } ], diff --git a/Package.swift b/Package.swift index f3f111e..8bbea16 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), ], targets: [ diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 467bfd6..325d41a 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -81,8 +81,8 @@ open class VADAudioChunker: AudioChunking { // Typically this will be the full audio file, unless seek points are explicitly provided var startIndex = seekClipStart while startIndex < seekClipEnd - windowPadding { - let currentFrameLength = startIndex - seekClipStart - if startIndex >= currentFrameLength, startIndex < 0 { + let currentFrameLength = audioArray.count + guard startIndex >= 0 && startIndex < audioArray.count else { throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size") } diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index c3958cb..89edeab 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -95,7 +95,7 @@ public extension AudioProcessing { static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? { let currentFrameLength = audioArray.count - if startIndex >= currentFrameLength, startIndex < 0 { + guard startIndex >= 0 && startIndex < audioArray.count else { Logging.error("startIndex is outside the buffer size") return nil } diff --git a/Sources/WhisperKit/Core/Configurations.swift b/Sources/WhisperKit/Core/Configurations.swift index c7a38b3..7ff1a9e 100644 --- a/Sources/WhisperKit/Core/Configurations.swift +++ b/Sources/WhisperKit/Core/Configurations.swift @@ -60,8 +60,8 @@ open class WhisperKitConfig { prewarm: Bool? = nil, load: Bool? = nil, download: Bool = true, - useBackgroundDownloadSession: Bool = false - ) { + useBackgroundDownloadSession: Bool = false) + { self.model = model self.downloadBase = downloadBase self.modelRepo = modelRepo @@ -83,7 +83,6 @@ open class WhisperKitConfig { } } - /// Options for how to transcribe an audio file using WhisperKit. 
diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift
index 3e05132..0417adf 100644
--- a/Sources/WhisperKit/Core/Models.swift
+++ b/Sources/WhisperKit/Core/Models.swift
@@ -165,6 +165,114 @@ public struct ModelComputeOptions {
     }
 }
 
+public struct ModelSupport: Codable, Equatable {
+    public let `default`: String
+    public let supported: [String]
+    /// Computed on init of ModelSupportConfig
+    public var disabled: [String] = []
+
+    private enum CodingKeys: String, CodingKey {
+        case `default`, supported
+    }
+}
+
+public struct DeviceSupport: Codable {
+    public let identifiers: [String]
+    public var models: ModelSupport
+}
+
+public struct ModelSupportConfig: Codable {
+    public let repoName: String
+    public let repoVersion: String
+    public var deviceSupports: [DeviceSupport]
+    /// Computed on init
+    public private(set) var knownModels: [String]
+    public private(set) var defaultSupport: DeviceSupport
+
+    enum CodingKeys: String, CodingKey {
+        case repoName = "name"
+        case repoVersion = "version"
+        case deviceSupports = "device_support"
+    }
+
+    public init(from decoder: Swift.Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+        let repoName = try container.decode(String.self, forKey: .repoName)
+        let repoVersion = try container.decode(String.self, forKey: .repoVersion)
+        let deviceSupports = try container.decode([DeviceSupport].self, forKey: .deviceSupports)
+
+        self.init(repoName: repoName, repoVersion: repoVersion, deviceSupports: deviceSupports)
+    }
+
+    public init(repoName: String, repoVersion: String, deviceSupports: [DeviceSupport], includeFallback: Bool = true) {
+        self.repoName = repoName
+        self.repoVersion = repoVersion
+
+        if includeFallback {
+            self.deviceSupports = Self.mergeDeviceSupport(remote: deviceSupports, fallback: Constants.fallbackModelSupportConfig.deviceSupports)
+            self.knownModels = self.deviceSupports.flatMap { $0.models.supported }.orderedSet
+        } else {
+            self.deviceSupports = deviceSupports
+            self.knownModels = deviceSupports.flatMap { $0.models.supported }.orderedSet
+        }
+
+        // Add default device support with all models supported for unknown devices
+        self.defaultSupport = DeviceSupport(
+            identifiers: [],
+            models: ModelSupport(
+                default: "openai_whisper-base",
+                supported: self.knownModels
+            )
+        )
+
+        computeDisabledModels()
+    }
+
+    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+    public func modelSupport(for deviceIdentifier: String = WhisperKit.deviceName()) -> ModelSupport {
+        for support in deviceSupports {
+            if support.identifiers.contains(where: { deviceIdentifier.hasPrefix($0) }) {
+                return support.models
+            }
+        }
+
+        Logging.info("No device support found for \(deviceIdentifier), using default")
+        return defaultSupport.models
+    }
+
+    private mutating func computeDisabledModels() {
+        for i in 0..<deviceSupports.count {
+            let disabledModels = knownModels.filter { !deviceSupports[i].models.supported.contains($0) }
+            self.deviceSupports[i].models.disabled = disabledModels
+        }
+    }
+
+    private static func mergeDeviceSupport(remote: [DeviceSupport], fallback: [DeviceSupport]) -> [DeviceSupport] {
+        var mergedSupports: [DeviceSupport] = []
+        let remoteIdentifiers = Set(remote.flatMap { $0.identifiers })
+
+        // Add remote device supports, merging with fallback if identifiers overlap
+        for remoteSupport in remote {
+            if let fallbackSupport = fallback.first(where: { $0.identifiers.contains(where: remoteSupport.identifiers.contains) }) {
+                let mergedModels = ModelSupport(
+                    default: remoteSupport.models.default,
+                    supported: (remoteSupport.models.supported + fallbackSupport.models.supported).orderedSet
+                )
+                mergedSupports.append(DeviceSupport(identifiers: remoteSupport.identifiers, models: mergedModels))
+            } else {
+                mergedSupports.append(remoteSupport)
+            }
+        }
+
+        // Add fallback device supports that don't overlap with remote
+        for fallbackSupport in fallback where !fallbackSupport.identifiers.contains(where: remoteIdentifiers.contains) {
+            mergedSupports.append(fallbackSupport)
+        }
+
+        return mergedSupports
+    }
+}
+
 // MARK: - Chunking
 
 public struct AudioChunk {
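The custom Decodable path above only reads name, version, and device_support, then delegates to the designated initializer so that knownModels, defaultSupport, and each device's disabled list are recomputed (with the compiled-in fallback merged in by default). A small decoding sketch with a hypothetical single-device payload:

    import Foundation
    import WhisperKit

    let json = Data("""
    {
        "name": "example-repo",
        "version": "0.2",
        "device_support": [
            {
                "identifiers": ["iPhone14"],
                "models": { "default": "openai_whisper-base", "supported": ["openai_whisper-base"] }
            }
        ]
    }
    """.utf8)

    let config = try JSONDecoder().decode(ModelSupportConfig.self, from: json)
    // "disabled" never appears in the JSON; it is derived on init from knownModels.
    print(config.modelSupport(for: "iPhone14,3").disabled)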
@@ -1346,4 +1454,141 @@ public enum Constants {
     public static let defaultLanguageCode: String = "en"
 
     public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
+
+    public static let fallbackModelSupportConfig: ModelSupportConfig = {
+        var config = ModelSupportConfig(
+            repoName: "whisperkit-coreml-fallback",
+            repoVersion: "0.2",
+            deviceSupports: [
+                DeviceSupport(
+                    identifiers: ["iPhone11", "iPhone12", "Watch7", "Watch8"],
+                    models: ModelSupport(
+                        default: "openai_whisper-tiny",
+                        supported: [
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: ["iPhone13", "iPad13,18", "iPad13,1"],
+                    models: ModelSupport(
+                        default: "openai_whisper-base",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: ["iPhone14", "iPhone15", "iPhone16", "iPhone17", "iPad14,1", "iPad14,2"],
+                    models: ModelSupport(
+                        default: "openai_whisper-base",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v2_turbo_955MB",
+                            "openai_whisper-large-v3_947MB",
+                            "openai_whisper-large-v3_turbo_954MB",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "distil-whisper_distil-large-v3_turbo_600MB",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                            "openai_whisper-large-v3-v20240930_turbo_632MB",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: [
+                        "Mac13",
+                        "iMac21",
+                        "MacBookAir10,1",
+                        "MacBookPro17",
+                        "MacBookPro18",
+                        "Macmini9",
+                        "iPad13,16",
+                        "iPad13,4",
+                        "iPad13,8",
+                    ],
+                    models: ModelSupport(
+                        default: "openai_whisper-large-v3-v20240930",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v3",
+                            "openai_whisper-large-v3_947MB",
+                            "distil-whisper_distil-large-v3",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "openai_whisper-large-v3-v20240930",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: [
+                        "Mac14",
+                        "Mac15",
+                        "Mac16",
+                        "iPad14,3",
+                        "iPad14,4",
+                        "iPad14,5",
+                        "iPad14,6",
+                        "iPad14,8",
+                        "iPad14,9",
+                        "iPad14,10",
+                        "iPad14,11",
+                        "iPad16",
+                    ],
+                    models: ModelSupport(
+                        default: "openai_whisper-large-v3-v20240930",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v2_turbo",
+                            "openai_whisper-large-v2_turbo_955MB",
+                            "openai_whisper-large-v3",
+                            "openai_whisper-large-v3_947MB",
+                            "openai_whisper-large-v3_turbo",
+                            "openai_whisper-large-v3_turbo_954MB",
+                            "distil-whisper_distil-large-v3",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "distil-whisper_distil-large-v3_turbo",
+                            "distil-whisper_distil-large-v3_turbo_600MB",
+                            "openai_whisper-large-v3-v20240930",
+                            "openai_whisper-large-v3-v20240930_turbo",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                            "openai_whisper-large-v3-v20240930_turbo_632MB",
+                        ]
+                    )
+                ),
+            ],
+            includeFallback: false
+        )
+
+        return config
+    }()
+
+    public static let knownModels: [String] = fallbackModelSupportConfig.deviceSupports.flatMap { $0.models.supported }.orderedSet
 }
"iPad14,8", + "iPad14,9", + "iPad14,10", + "iPad14,11", + "iPad16", + ], + models: ModelSupport( + default: "openai_whisper-large-v3-v20240930", + supported: [ + "openai_whisper-tiny", + "openai_whisper-tiny.en", + "openai_whisper-base", + "openai_whisper-base.en", + "openai_whisper-small", + "openai_whisper-small.en", + "openai_whisper-large-v2", + "openai_whisper-large-v2_949MB", + "openai_whisper-large-v2_turbo", + "openai_whisper-large-v2_turbo_955MB", + "openai_whisper-large-v3", + "openai_whisper-large-v3_947MB", + "openai_whisper-large-v3_turbo", + "openai_whisper-large-v3_turbo_954MB", + "distil-whisper_distil-large-v3", + "distil-whisper_distil-large-v3_594MB", + "distil-whisper_distil-large-v3_turbo", + "distil-whisper_distil-large-v3_turbo_600MB", + "openai_whisper-large-v3-v20240930", + "openai_whisper-large-v3-v20240930_turbo", + "openai_whisper-large-v3-v20240930_626MB", + "openai_whisper-large-v3-v20240930_turbo_632MB", + ] + ) + ), + ], + includeFallback: false + ) + + return config + }() + + public static let knownModels: [String] = fallbackModelSupportConfig.deviceSupports.flatMap { $0.models.supported }.orderedSet } diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 8713510..21ef7b1 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -12,6 +12,9 @@ import UIKit #elseif canImport(AppKit) import AppKit #endif +#if canImport(Darwin) +import Darwin +#endif // MARK: - Extensions @@ -37,6 +40,21 @@ public extension Array where Element == TranscriptionSegment { } } +extension Array where Element: Hashable { + /// Returns an array with duplicates removed, preserving the original order. + var orderedSet: [Element] { + var seen = Set() + return self.filter { element in + if seen.contains(element) { + return false + } else { + seen.insert(element) + return true + } + } + } +} + extension MLMultiArray { /// Calculate the linear offset by summing the products of each dimension’s index with the dimension’s stride. /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript) @@ -123,24 +141,23 @@ public extension MLComputeUnits { } } -#if os(macOS) +#if os(macOS) || targetEnvironment(simulator) // From: https://stackoverflow.com/a/71726663 -public extension Process { - static func stringFromTerminal(command: String) -> String { - let task = Process() - let pipe = Pipe() - task.standardOutput = pipe - task.launchPath = "/bin/bash" - task.arguments = ["-c", "sysctl -n " + command] - task.launch() - return String(bytes: pipe.fileHandleForReading.availableData, encoding: .utf8) ?? 
"" - } - - static let processor = stringFromTerminal(command: "machdep.cpu.brand_string") - static let cores = stringFromTerminal(command: "machdep.cpu.core_count") - static let threads = stringFromTerminal(command: "machdep.cpu.thread_count") - static let vendor = stringFromTerminal(command: "machdep.cpu.vendor") - static let family = stringFromTerminal(command: "machdep.cpu.family") +public extension ProcessInfo { + static func stringFromSysctl(named name: String) -> String { + var size: size_t = 0 + sysctlbyname(name, nil, &size, nil, 0) + var machineModel = [CChar](repeating: 0, count: Int(size)) + sysctlbyname(name, &machineModel, &size, nil, 0) + return String(cString: machineModel) + } + + static let processor = stringFromSysctl(named: "machdep.cpu.brand_string") + static let cores = stringFromSysctl(named: "machdep.cpu.core_count") + static let threads = stringFromSysctl(named: "machdep.cpu.thread_count") + static let vendor = stringFromSysctl(named: "machdep.cpu.vendor") + static let family = stringFromSysctl(named: "machdep.cpu.family") + static let hwModel = stringFromSysctl(named: "hw.model") } #endif @@ -419,68 +436,19 @@ func detectVariant(logitsDim: Int, encoderDim: Int) -> ModelVariant { return modelVariant } -public func modelSupport(for deviceName: String) -> (default: String, disabled: [String]) { - switch deviceName { - case let model where model.hasPrefix("iPhone11"), // A12 - let model where model.hasPrefix("iPhone12"), // A13 - let model where model.hasPrefix("Watch7"): // Series 9 and Ultra 2 - return ("openai_whisper-base", ["openai_whisper-small", - "openai_whisper-small.en", - "openai_whisper-large-v2", - "openai_whisper-large-v2_949MB", - "openai_whisper-large-v2_turbo", - "openai_whisper-large-v2_turbo_955MB", - "openai_whisper-large-v3", - "openai_whisper-large-v3_947MB", - "openai_whisper-large-v3_turbo", - "openai_whisper-large-v3_turbo_954MB", - "distil-whisper_distil-large-v3", - "distil-whisper_distil-large-v3_594MB", - "distil-whisper_distil-large-v3_turbo_600MB", - "distil-whisper_distil-large-v3_turbo"]) - - case let model where model.hasPrefix("iPhone13"): // A14 - return ("openai_whisper-base", ["openai_whisper-large-v2", - "openai_whisper-large-v2_turbo", - "openai_whisper-large-v2_turbo_955MB", - "openai_whisper-large-v3", - "openai_whisper-large-v3_turbo", - "openai_whisper-large-v3_turbo_954MB", - "distil-whisper_distil-large-v3_turbo_600MB", - "distil-whisper_distil-large-v3_turbo"]) - - case let model where model.hasPrefix("iPhone14"), // A15 - let model where model.hasPrefix("iPhone15"), // A16 - let model where model.hasPrefix("iPhone16"): // A17 - return ("openai_whisper-base", ["openai_whisper-large-v2", - "openai_whisper-large-v2_turbo", - "openai_whisper-large-v3", - "openai_whisper-large-v3_turbo"]) - - // Fall through to macOS checks - default: - break - } - - #if os(macOS) - if deviceName.hasPrefix("arm64") { - if Process.processor.contains("Apple M1") { - // Disable turbo variants for M1 - return ("openai_whisper-base", ["openai_whisper-large-v2_turbo", - "openai_whisper-large-v2_turbo_955MB", - "openai_whisper-large-v3_turbo", - "openai_whisper-large-v3_turbo_954MB", - "distil-whisper_distil-large-v3_turbo_600MB", - "distil-whisper_distil-large-v3_turbo"]) - } else { - // Enable all variants for M2 or M3, none disabled - return ("openai_whisper-base", []) - } - } - #endif +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public func modelSupport(for deviceName: String, from config: ModelSupportConfig? 
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index c1b66d5..db8d07e 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -113,22 +113,14 @@ open class WhisperKit {
             load: load,
             download: download,
             useBackgroundDownloadSession: useBackgroundDownloadSession
-            )
+        )
         try await self.init(config)
     }
 
     // MARK: - Model Loading
 
-    public static func recommendedModels() -> (default: String, disabled: [String]) {
-        let deviceName = Self.deviceName()
-        Logging.debug("Running on \(deviceName)")
-
-        let defaultModel = modelSupport(for: deviceName).default
-        let disabledModels = modelSupport(for: deviceName).disabled
-        return (defaultModel, disabledModels)
-    }
-
     public static func deviceName() -> String {
+        #if !os(macOS) && !targetEnvironment(simulator)
         var utsname = utsname()
         uname(&utsname)
         let deviceName = withUnsafePointer(to: &utsname.machine) {
@@ -136,14 +128,52 @@ open class WhisperKit {
                 String(cString: $0)
             }
         }
+        #else
+        let deviceName = ProcessInfo.hwModel
+        #endif
         return deviceName
     }
 
-    public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["openai_*", "distil-whisper_*"]) async throws -> [String] {
+    public static func recommendedModels() -> ModelSupport {
+        let deviceName = Self.deviceName()
+        Logging.debug("Running on \(deviceName)")
+        return modelSupport(for: deviceName)
+    }
+
+    public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupport {
+        let deviceName = Self.deviceName()
+        let config = await Self.fetchModelSupportConfig(from: repo)
+        return modelSupport(for: deviceName, from: config)
+    }
+
+    public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupportConfig {
         let hubApi = HubApi()
-        let modelFiles = try await hubApi.getFilenames(from: repo, matching: matching)
+        var modelSupportConfig = Constants.fallbackModelSupportConfig
 
-        return formatModelFiles(modelFiles)
+        do {
+            // Try to decode config.json into ModelSupportConfig
+            let configUrl = try await hubApi.snapshot(from: repo, matching: "config*")
+            let decoder = JSONDecoder()
+            let jsonData = try Data(contentsOf: configUrl.appendingPathComponent("config.json"))
+            modelSupportConfig = try decoder.decode(ModelSupportConfig.self, from: jsonData)
+        } catch {
+            // Allow this to fail gracefully as it uses fallback config by default
+            Logging.error(error)
+        }
+
+        return modelSupportConfig
+    }
+
+    public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"]) async throws -> [String] {
+        let modelSupportConfig = await fetchModelSupportConfig(from: repo)
+        let supportedModels = modelSupportConfig.modelSupport().supported
+        var filteredSupportSet: Set<String> = []
+        for glob in matching {
+            filteredSupportSet = filteredSupportSet.union(supportedModels.matching(glob: glob))
+        }
+        let filteredSupport = Array(filteredSupportSet)
+
+        return formatModelFiles(filteredSupport)
     }
 
     public static func formatModelFiles(_ modelFiles: [String]) -> [String] {
@@ -243,13 +273,14 @@ open class WhisperKit {
         modelFolder: String?,
         download: Bool
     ) async throws {
-        // Determine the model variant to use
-        let modelVariant = model ?? WhisperKit.recommendedModels().default
-
         // If a local model folder is provided, use it; otherwise, download the model
         if let folder = modelFolder {
             self.modelFolder = URL(fileURLWithPath: folder)
         } else if download {
+            // Determine the model variant to use
+            let modelSupport = await WhisperKit.recommendedRemoteModels()
+            let modelVariant = model ?? modelSupport.default
+
             let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
             do {
                 self.modelFolder = try await Self.download(
@@ -741,14 +772,14 @@ open class WhisperKit {
         let isChunkable = audioArray.count > WhisperKit.windowSamples
         switch (isChunkable, decodeOptions?.chunkingStrategy) {
             case (true, .vad):
-            // We have some audio that will require multiple windows and a strategy to chunk them
-            let vad = decodeOptions?.voiceActivityDetector ?? EnergyVAD()
-            let chunker = VADAudioChunker(vad: vad)
-            let audioChunks: [AudioChunk] = try await chunker.chunkAll(
-                audioArray: audioArray,
-                maxChunkLength: WhisperKit.windowSamples,
-                decodeOptions: decodeOptions
-            )
+                // We have some audio that will require multiple windows and a strategy to chunk them
+                let vad = decodeOptions?.voiceActivityDetector ?? EnergyVAD()
+                let chunker = VADAudioChunker(vad: vad)
+                let audioChunks: [AudioChunk] = try await chunker.chunkAll(
+                    audioArray: audioArray,
+                    maxChunkLength: WhisperKit.windowSamples,
+                    decodeOptions: decodeOptions
+                )
 
                 // Reset the seek times since we've already chunked the audio
                 var chunkedOptions = decodeOptions
diff --git a/Tests/WhisperKitTests/RegressionTests.swift b/Tests/WhisperKitTests/RegressionTests.swift
index b2bd9f9..724eec0 100644
--- a/Tests/WhisperKitTests/RegressionTests.swift
+++ b/Tests/WhisperKitTests/RegressionTests.swift
@@ -149,7 +149,7 @@ final class RegressionTests: XCTestCase {
         let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date())
 
         #if os(macOS) && arch(arm64)
-        currentDevice = Process.processor
+        currentDevice = ProcessInfo.processor
         #endif
 
         do {
"openai_whisper-large-v2_turbo", + "openai_whisper-large-v2_turbo_955MB", + "openai_whisper-large-v3", + "openai_whisper-large-v3_947MB", + "openai_whisper-large-v3_turbo", + "openai_whisper-large-v3_turbo_954MB", + "distil-whisper_distil-large-v3", + "distil-whisper_distil-large-v3_594MB", + "distil-whisper_distil-large-v3_turbo", + "distil-whisper_distil-large-v3_turbo_600MB", + "openai_whisper-large-v3-v20240930", + "openai_whisper-large-v3-v20240930_turbo", + "openai_whisper-large-v3-v20240930_626MB", + "openai_whisper-large-v3-v20240930_turbo_632MB" + ] + } + } + ] +} diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 709e4ec..5a0ca5c 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -32,6 +32,160 @@ final class UnitTests: XCTestCase { ) } + // MARK: - Config Tests + + func testModelSupportConfigFallback() { + let fallbackRepoConfig = Constants.fallbackModelSupportConfig + XCTAssertEqual(fallbackRepoConfig.repoName, "whisperkit-coreml-fallback") + XCTAssertEqual(fallbackRepoConfig.repoVersion, "0.2") + XCTAssertGreaterThanOrEqual(fallbackRepoConfig.deviceSupports.count, 5) + + // Test that all device supports have their disabled models set except devices that should support all known models + for deviceSupport in fallbackRepoConfig.deviceSupports where !Constants.knownModels.allSatisfy(deviceSupport.models.supported.contains) { + let modelSupport = deviceSupport.models.supported + let knownModels = Constants.knownModels + + // Ensure that the disabled models list is not empty + XCTAssertFalse(deviceSupport.models.disabled.isEmpty, + "Disabled models should be set for \(deviceSupport.identifiers), found missing model(s): \(modelSupport.filter { knownModels.contains($0) })") + } + + // Test that default device support has all known models as supported and none disabled + let defaultSupport = fallbackRepoConfig.defaultSupport + XCTAssertEqual(defaultSupport.identifiers, []) + XCTAssertEqual(defaultSupport.models.supported.sorted(), Constants.knownModels.sorted()) + } + + func testModelSupportConfigFromJson() throws { + let configFilePath = try XCTUnwrap( + Bundle.module.path(forResource: "config", ofType: "json"), + "Config file not found" + ) + + let jsonData = try Data(contentsOf: URL(fileURLWithPath: configFilePath)) + let decoder = JSONDecoder() + let loadedConfig = try decoder.decode(ModelSupportConfig.self, from: jsonData) + + // Compare loaded config with fallback config + XCTAssertEqual(loadedConfig.repoName, "whisperkit-coreml") + XCTAssertEqual(loadedConfig.repoVersion, Constants.fallbackModelSupportConfig.repoVersion) + XCTAssertEqual(loadedConfig.deviceSupports.count, Constants.fallbackModelSupportConfig.deviceSupports.count) + + // Compare device supports + for (loadedDeviceSupport, fallbackDeviceSupport) in zip(loadedConfig.deviceSupports, Constants.fallbackModelSupportConfig.deviceSupports) { + XCTAssertEqual(loadedDeviceSupport.identifiers, fallbackDeviceSupport.identifiers) + XCTAssertEqual(loadedDeviceSupport.models.default, fallbackDeviceSupport.models.default) + XCTAssertEqual(Set(loadedDeviceSupport.models.supported), Set(fallbackDeviceSupport.models.supported)) + XCTAssertEqual(Set(loadedDeviceSupport.models.disabled), Set(fallbackDeviceSupport.models.disabled)) + } + } + + func testModelSupportConfigCorrectness() throws { + let config = Constants.fallbackModelSupportConfig + + // Test if a model exists in config for one device but not others, it is disabled + let 
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index 709e4ec..5a0ca5c 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -32,6 +32,160 @@ final class UnitTests: XCTestCase {
         )
     }
 
+    // MARK: - Config Tests
+
+    func testModelSupportConfigFallback() {
+        let fallbackRepoConfig = Constants.fallbackModelSupportConfig
+        XCTAssertEqual(fallbackRepoConfig.repoName, "whisperkit-coreml-fallback")
+        XCTAssertEqual(fallbackRepoConfig.repoVersion, "0.2")
+        XCTAssertGreaterThanOrEqual(fallbackRepoConfig.deviceSupports.count, 5)
+
+        // Test that all device supports have their disabled models set except devices that should support all known models
+        for deviceSupport in fallbackRepoConfig.deviceSupports where !Constants.knownModels.allSatisfy(deviceSupport.models.supported.contains) {
+            let modelSupport = deviceSupport.models.supported
+            let knownModels = Constants.knownModels
+
+            // Ensure that the disabled models list is not empty
+            XCTAssertFalse(deviceSupport.models.disabled.isEmpty,
+                           "Disabled models should be set for \(deviceSupport.identifiers), found missing model(s): \(knownModels.filter { !modelSupport.contains($0) })")
+        }
+
+        // Test that default device support has all known models as supported and none disabled
+        let defaultSupport = fallbackRepoConfig.defaultSupport
+        XCTAssertEqual(defaultSupport.identifiers, [])
+        XCTAssertEqual(defaultSupport.models.supported.sorted(), Constants.knownModels.sorted())
+    }
+
+    func testModelSupportConfigFromJson() throws {
+        let configFilePath = try XCTUnwrap(
+            Bundle.module.path(forResource: "config", ofType: "json"),
+            "Config file not found"
+        )
+
+        let jsonData = try Data(contentsOf: URL(fileURLWithPath: configFilePath))
+        let decoder = JSONDecoder()
+        let loadedConfig = try decoder.decode(ModelSupportConfig.self, from: jsonData)
+
+        // Compare loaded config with fallback config
+        XCTAssertEqual(loadedConfig.repoName, "whisperkit-coreml")
+        XCTAssertEqual(loadedConfig.repoVersion, Constants.fallbackModelSupportConfig.repoVersion)
+        XCTAssertEqual(loadedConfig.deviceSupports.count, Constants.fallbackModelSupportConfig.deviceSupports.count)
+
+        // Compare device supports
+        for (loadedDeviceSupport, fallbackDeviceSupport) in zip(loadedConfig.deviceSupports, Constants.fallbackModelSupportConfig.deviceSupports) {
+            XCTAssertEqual(loadedDeviceSupport.identifiers, fallbackDeviceSupport.identifiers)
+            XCTAssertEqual(loadedDeviceSupport.models.default, fallbackDeviceSupport.models.default)
+            XCTAssertEqual(Set(loadedDeviceSupport.models.supported), Set(fallbackDeviceSupport.models.supported))
+            XCTAssertEqual(Set(loadedDeviceSupport.models.disabled), Set(fallbackDeviceSupport.models.disabled))
+        }
+    }
+
+    func testModelSupportConfigCorrectness() throws {
+        let config = Constants.fallbackModelSupportConfig
+
+        // Test that a model which exists in the config for one device but not others is disabled for the latter
+        let iPhone13Models = config.modelSupport(for: "iPhone13,1")
+        let iPhone14Models = config.modelSupport(for: "iPhone14,3")
+
+        XCTAssertFalse(iPhone13Models.supported.contains("openai_whisper-large-v3_947MB"))
+        XCTAssertTrue(iPhone13Models.disabled.contains("openai_whisper-large-v3_947MB"))
+        XCTAssertTrue(iPhone14Models.supported.contains("openai_whisper-large-v3_947MB"))
+
+        // Test that devices sharing an identifier prefix are matched to the appropriate support entry when those entries differ
+        let iPad14A15Model = config.modelSupport(for: "iPad14,1")
+        let iPad14M2Model = config.modelSupport(for: "iPad14,4")
+
+        XCTAssertFalse(iPad14A15Model.supported.contains("openai_whisper-large-v3-v20240930_turbo"))
+        XCTAssertTrue(iPad14A15Model.disabled.contains("openai_whisper-large-v3-v20240930_turbo"))
+        XCTAssertTrue(iPad14M2Model.supported.contains("openai_whisper-large-v3-v20240930_turbo"))
+
+        // Test that a model which exists in a remote repo but not in the fallback config is disabled for all devices except the default
+        let newModel = "some_new_model"
+        let newDevice = "some_new_device"
+        let newDeviceSupport = config.deviceSupports + [DeviceSupport(
+            identifiers: [newDevice],
+            models: ModelSupport(
+                default: "openai_whisper-base",
+                supported: [
+                    "some_new_model",
+                ]
+            )
+        )]
+
+        let newConfig = ModelSupportConfig(
+            repoName: config.repoName,
+            repoVersion: config.repoVersion,
+            deviceSupports: newDeviceSupport
+        )
+
+        XCTAssertEqual(Set(newConfig.knownModels), Set(newDeviceSupport.flatMap { $0.models.supported }))
+        for deviceSupport in newConfig.deviceSupports where !deviceSupport.identifiers.allSatisfy([newDevice].contains) {
+            XCTAssertFalse(deviceSupport.models.supported.contains(newModel))
+            XCTAssertTrue(deviceSupport.models.disabled.contains(newModel))
+        }
+
+        // Test that a model which does not exist in a remote repo but does in the fallback config is disabled
+        // This will not prevent use of the model if already downloaded, but will enable the remote config to disable specific models
+        let knownLocalModel = Constants.fallbackModelSupportConfig.modelSupport(for: "iPhone13,1").supported.first!
+        let remoteModel = "remote_model"
+        let remoteConfig = ModelSupportConfig(
+            repoName: "test",
+            repoVersion: "test",
+            deviceSupports: [DeviceSupport(
+                identifiers: ["test_device"],
+                models: ModelSupport(
+                    default: remoteModel,
+                    supported: [remoteModel]
+                )
+            )]
+        )
+
+        // Helper method returns supported model
+        let modelSupport = remoteConfig.modelSupport(for: "test_device").supported
+        let disabledModels = remoteConfig.modelSupport(for: "test_device").disabled
+        XCTAssertTrue(modelSupport.contains(remoteModel))
+        XCTAssertTrue(disabledModels.contains(knownLocalModel))
+
+        // Direct access has it disabled
+        for deviceSupport in remoteConfig.deviceSupports where deviceSupport.identifiers.contains("test_device") {
+            XCTAssertTrue(deviceSupport.models.supported.contains(remoteModel))
+            XCTAssertFalse(deviceSupport.models.disabled.contains(remoteModel))
+            XCTAssertFalse(deviceSupport.models.supported.contains(knownLocalModel))
+            XCTAssertTrue(deviceSupport.models.disabled.contains(knownLocalModel))
+        }
+    }
+
+    func testModelSupportConfigFetch() async throws {
+        // Make sure remote repo config loads successfully from HF
+        let modelRepoConfig = await WhisperKit.fetchModelSupportConfig()
+
+        XCTAssertFalse(modelRepoConfig.deviceSupports.isEmpty, "Should have device supports")
+        XCTAssertFalse(modelRepoConfig.knownModels.isEmpty, "Should have known models")
+
+        XCTAssertGreaterThanOrEqual(modelRepoConfig.deviceSupports.count, Constants.fallbackModelSupportConfig.deviceSupports.count, "Remote config should have at least as many devices as fallback")
+
+        // Verify that known models in the remote config include all known models from fallback
+        let remoteKnownModels = Set(modelRepoConfig.knownModels)
+        let fallbackKnownModels = Set(Constants.fallbackModelSupportConfig.knownModels)
+        XCTAssertTrue(remoteKnownModels.isSuperset(of: fallbackKnownModels), "Remote known models should include all fallback known models")
+
+        // Test an unknown device to ensure it falls back to default support
+        let unknownDeviceSupport = modelRepoConfig.modelSupport(for: "unknown_device")
+        XCTAssertEqual(unknownDeviceSupport.supported, modelRepoConfig.defaultSupport.models.supported, "Unknown device should use default support")
+    }
+
+    func testRecommendedModels() async {
+        let asyncRemoteModels = await WhisperKit.recommendedRemoteModels()
+        let defaultModels = WhisperKit.recommendedModels()
+
+        // Remote models should not be nil or empty
+        XCTAssertNotNil(asyncRemoteModels, "Remote models should not be nil")
+        XCTAssertFalse(asyncRemoteModels.default.isEmpty, "Remote model name should not be empty")
+
+        // Default models should not be nil or empty
+        XCTAssertNotNil(defaultModels, "Default models should not be nil")
+        XCTAssertFalse(defaultModels.default.isEmpty, "Default model name should not be empty")
+    }
+
     // MARK: - Audio Tests
 
     func testAudioFileLoading() throws {
@@ -951,6 +1105,18 @@ final class UnitTests: XCTestCase {
         XCTAssertEqual("endoftext|>".trimmingSpecialTokenCharacters(), "endoftext")
     }
 
+    func testDeviceName() {
+        let deviceName = WhisperKit.deviceName()
+        XCTAssertFalse(deviceName.isEmpty, "Device name should not be empty")
+        XCTAssertTrue(deviceName.contains(","), "Device name should contain a comma, found \(deviceName)")
+    }
+
+    func testOrderedSet() {
+        let testArray = ["model1", "model2", "model1", "model3", "model2"]
+        let uniqueArray = testArray.orderedSet
+        XCTAssertEqual(uniqueArray, ["model1", "model2", "model3"], "Ordered set should contain unique elements in order")
+    }
+
     // MARK: - LogitsFilter Tests
 
     func testSuppressTokensFilter() throws {