Commit bfb1316

Add model support config fetching from model repo (#216)

* Add model support config fetching from model repo
* Fix audio start index error handling
* Formatting
* Fix CI + watchOS build: the new GitHub runner image does not include visionOS, so instead of downloading all platforms, the platform is now specified from the test matrix
* Fix typo
* Use dispatch group for sync recommendedModels
* Remove sync remote model fetching
* Formatting and cleanup from review

Co-authored-by: 1amageek <[email protected]>

1 parent c2f1b57 · commit bfb1316

File tree: 14 files changed, +676 −128 lines

.github/workflows/unit-tests.yml

+3 −1

```diff
@@ -60,7 +60,9 @@ jobs:
         run: make download-model MODEL=tiny
       - name: Install and discover destinations
         run: |
-          xcodebuild -downloadAllPlatforms
+          if [[ "${{ matrix.run-config['name'] }}" != "macOS" ]]; then
+            xcodebuild -downloadPlatform ${{ matrix.run-config['name'] }}
+          fi
          echo "Destinations for testing:"
          xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
      - name: Boot Simulator and Wait
```

Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

+2 −2

```diff
@@ -33,8 +33,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
-        "version" : "0.1.7"
+        "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
+        "version" : "0.1.8"
       }
     }
   ],
```

Examples/WhisperAX/WhisperAX/Views/ContentView.swift

+13 −12

```diff
@@ -110,7 +110,6 @@ struct ContentView: View {
         MenuItem(name: "Stream", image: "waveform.badge.mic"),
     ]
 
-
     private var isStreamMode: Bool {
         self.selectedCategoryId == menu.first(where: { $0.name == "Stream" })?.id
     }
@@ -202,7 +201,7 @@ struct ContentView: View {
             .toolbar(content: {
                 ToolbarItem {
                     Button {
-                        if (!enableEagerDecoding) {
+                        if !enableEagerDecoding {
                             let fullTranscript = formatSegments(confirmedSegments + unconfirmedSegments, withTimestamps: enableTimestamps).joined(separator: "\n")
                             #if os(iOS)
                             UIPasteboard.general.string = fullTranscript
@@ -956,9 +955,7 @@ struct ContentView: View {
 
             localModels = WhisperKit.formatModelFiles(localModels)
             for model in localModels {
-                if !availableModels.contains(model),
-                   !disabledModels.contains(model)
-                {
+                if !availableModels.contains(model) {
                     availableModels.append(model)
                 }
             }
@@ -967,12 +964,17 @@ struct ContentView: View {
             print("Previously selected model: \(selectedModel)")
 
             Task {
-                let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName)
-                for model in remoteModels {
-                    if !availableModels.contains(model),
-                       !disabledModels.contains(model)
-                    {
-                        availableModels.append(model)
+                let remoteModelSupport = await WhisperKit.recommendedRemoteModels()
+                await MainActor.run {
+                    for model in remoteModelSupport.supported {
+                        if !availableModels.contains(model) {
+                            availableModels.append(model)
+                        }
+                    }
+                    for model in remoteModelSupport.disabled {
+                        if !disabledModels.contains(model) {
+                            disabledModels.append(model)
+                        }
                     }
                 }
             }
@@ -1644,7 +1646,6 @@ struct ContentView: View {
             finalizeText()
         }
 
-
        let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)
 
        return mergedResult
```
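The hunk above is the canonical call site for the new API: fetch the device-specific model support asynchronously, then update UI state on the main actor. For reference, a minimal sketch of the same pattern outside the example app, assuming only the `WhisperKit.recommendedRemoteModels()` call introduced in this PR (the view-model type and property names here are illustrative):

```swift
import SwiftUI
import WhisperKit

@MainActor
final class ModelPickerViewModel: ObservableObject {
    @Published var availableModels: [String] = []
    @Published var disabledModels: [String] = []

    func refreshRemoteModels() async {
        // Fetch the recommended ModelSupport for the current device from the model repo
        let support = await WhisperKit.recommendedRemoteModels()
        // Already on the main actor, so the published state can be updated directly
        for model in support.supported where !availableModels.contains(model) {
            availableModels.append(model)
        }
        for model in support.disabled where !disabledModels.contains(model) {
            disabledModels.append(model)
        }
    }
}
```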

Package.resolved

+2 −2

```diff
@@ -14,8 +14,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
-        "version" : "0.1.7"
+        "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
+        "version" : "0.1.8"
       }
     }
   ],
```

Package.swift

+1 −1

```diff
@@ -20,7 +20,7 @@ let package = Package(
         ),
     ],
     dependencies: [
-        .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"),
+        .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
         .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
     ],
     targets: [
```

Sources/WhisperKit/Core/Audio/AudioChunker.swift

+2 −2

```diff
@@ -81,8 +81,8 @@ open class VADAudioChunker: AudioChunking {
         // Typically this will be the full audio file, unless seek points are explicitly provided
         var startIndex = seekClipStart
         while startIndex < seekClipEnd - windowPadding {
-            let currentFrameLength = startIndex - seekClipStart
-            if startIndex >= currentFrameLength, startIndex < 0 {
+            let currentFrameLength = audioArray.count
+            guard startIndex >= 0 && startIndex < audioArray.count else {
                 throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
             }
 
```
Sources/WhisperKit/Core/Audio/AudioProcessor.swift

+1 −1

```diff
@@ -95,7 +95,7 @@ public extension AudioProcessing {
     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
         let currentFrameLength = audioArray.count
 
-        if startIndex >= currentFrameLength, startIndex < 0 {
+        guard startIndex >= 0 && startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil
         }
```
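Both audio hunks fix the same broken bounds check: the old condition `startIndex >= currentFrameLength, startIndex < 0` requires the index to be past the end of the buffer *and* negative at the same time, which can never hold, so out-of-range indices fell through to the slicing code. The replacement guard rejects anything outside `[0, count)` before the buffer is touched. A standalone sketch of the corrected pattern (the helper name is illustrative, not WhisperKit API):

```swift
// Slice a sample buffer only after validating the start index,
// mirroring the guard introduced in padOrTrimAudio above.
func safeSlice(of audioArray: [Float], from startIndex: Int, length: Int) -> [Float]? {
    guard startIndex >= 0 && startIndex < audioArray.count else {
        return nil // startIndex is outside the buffer size
    }
    let end = min(startIndex + length, audioArray.count)
    return Array(audioArray[startIndex..<end])
}
```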

Sources/WhisperKit/Core/Configurations.swift

+2 −3

```diff
@@ -60,8 +60,8 @@ open class WhisperKitConfig {
         prewarm: Bool? = nil,
         load: Bool? = nil,
         download: Bool = true,
-        useBackgroundDownloadSession: Bool = false
-    ) {
+        useBackgroundDownloadSession: Bool = false)
+    {
         self.model = model
         self.downloadBase = downloadBase
         self.modelRepo = modelRepo
@@ -83,7 +83,6 @@ open class WhisperKitConfig {
     }
 }
 
-
 /// Options for how to transcribe an audio file using WhisperKit.
 ///
 /// - Parameters:
```

Sources/WhisperKit/Core/Models.swift

+245

```diff
@@ -165,6 +165,114 @@ public struct ModelComputeOptions {
     }
 }
 
+public struct ModelSupport: Codable, Equatable {
+    public let `default`: String
+    public let supported: [String]
+    /// Computed on init of ModelRepoConfig
+    public var disabled: [String] = []
+
+    private enum CodingKeys: String, CodingKey {
+        case `default`, supported
+    }
+}
+
+public struct DeviceSupport: Codable {
+    public let identifiers: [String]
+    public var models: ModelSupport
+}
+
+public struct ModelSupportConfig: Codable {
+    public let repoName: String
+    public let repoVersion: String
+    public var deviceSupports: [DeviceSupport]
+    /// Computed on init
+    public private(set) var knownModels: [String]
+    public private(set) var defaultSupport: DeviceSupport
+
+    enum CodingKeys: String, CodingKey {
+        case repoName = "name"
+        case repoVersion = "version"
+        case deviceSupports = "device_support"
+    }
+
+    public init(from decoder: Swift.Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+        let repoName = try container.decode(String.self, forKey: .repoName)
+        let repoVersion = try container.decode(String.self, forKey: .repoVersion)
+        let deviceSupports = try container.decode([DeviceSupport].self, forKey: .deviceSupports)
+
+        self.init(repoName: repoName, repoVersion: repoVersion, deviceSupports: deviceSupports)
+    }
+
+    public init(repoName: String, repoVersion: String, deviceSupports: [DeviceSupport], includeFallback: Bool = true) {
+        self.repoName = repoName
+        self.repoVersion = repoVersion
+
+        if includeFallback {
+            self.deviceSupports = Self.mergeDeviceSupport(remote: deviceSupports, fallback: Constants.fallbackModelSupportConfig.deviceSupports)
+            self.knownModels = self.deviceSupports.flatMap { $0.models.supported }.orderedSet
+        } else {
+            self.deviceSupports = deviceSupports
+            self.knownModels = deviceSupports.flatMap { $0.models.supported }.orderedSet
+        }
+
+        // Add default device support with all models supported for unknown devices
+        self.defaultSupport = DeviceSupport(
+            identifiers: [],
+            models: ModelSupport(
+                default: "openai_whisper-base",
+                supported: self.knownModels
+            )
+        )
+
+        computeDisabledModels()
+    }
+
+    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+    public func modelSupport(for deviceIdentifier: String = WhisperKit.deviceName()) -> ModelSupport {
+        for support in deviceSupports {
+            if support.identifiers.contains(where: { deviceIdentifier.hasPrefix($0) }) {
+                return support.models
+            }
+        }
+
+        Logging.info("No device support found for \(deviceIdentifier), using default")
+        return defaultSupport.models
+    }
+
+    private mutating func computeDisabledModels() {
+        for i in 0..<deviceSupports.count {
+            let disabledModels = Set(knownModels).subtracting(deviceSupports[i].models.supported)
+            self.deviceSupports[i].models.disabled = Array(disabledModels)
+        }
+    }
+
+    private static func mergeDeviceSupport(remote: [DeviceSupport], fallback: [DeviceSupport]) -> [DeviceSupport] {
+        var mergedSupports: [DeviceSupport] = []
+        let remoteIdentifiers = Set(remote.flatMap { $0.identifiers })
+
+        // Add remote device supports, merging with fallback if identifiers overlap
+        for remoteSupport in remote {
+            if let fallbackSupport = fallback.first(where: { $0.identifiers.contains(where: remoteSupport.identifiers.contains) }) {
+                let mergedModels = ModelSupport(
+                    default: remoteSupport.models.default,
+                    supported: (remoteSupport.models.supported + fallbackSupport.models.supported).orderedSet
+                )
+                mergedSupports.append(DeviceSupport(identifiers: remoteSupport.identifiers, models: mergedModels))
+            } else {
+                mergedSupports.append(remoteSupport)
+            }
+        }
+
+        // Add fallback device supports that don't overlap with remote
+        for fallbackSupport in fallback where !fallbackSupport.identifiers.contains(where: remoteIdentifiers.contains) {
+            mergedSupports.append(fallbackSupport)
+        }
+
+        return mergedSupports
+    }
+}
+
 // MARK: - Chunking
 
 public struct AudioChunk {
@@ -1346,4 +1454,141 @@ public enum Constants {
     public static let defaultLanguageCode: String = "en"
 
     public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
+
+    public static let fallbackModelSupportConfig: ModelSupportConfig = {
+        var config = ModelSupportConfig(
+            repoName: "whisperkit-coreml-fallback",
+            repoVersion: "0.2",
+            deviceSupports: [
+                DeviceSupport(
+                    identifiers: ["iPhone11", "iPhone12", "Watch7", "Watch8"],
+                    models: ModelSupport(
+                        default: "openai_whisper-tiny",
+                        supported: [
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: ["iPhone13", "iPad13,18", "iPad13,1"],
+                    models: ModelSupport(
+                        default: "openai_whisper-base",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: ["iPhone14", "iPhone15", "iPhone16", "iPhone17", "iPad14,1", "iPad14,2"],
+                    models: ModelSupport(
+                        default: "openai_whisper-base",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v2_turbo_955MB",
+                            "openai_whisper-large-v3_947MB",
+                            "openai_whisper-large-v3_turbo_954MB",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "distil-whisper_distil-large-v3_turbo_600MB",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                            "openai_whisper-large-v3-v20240930_turbo_632MB",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: [
+                        "Mac13",
+                        "iMac21",
+                        "MacBookAir10,1",
+                        "MacBookPro17",
+                        "MacBookPro18",
+                        "Macmini9",
+                        "iPad13,16",
+                        "iPad13,4",
+                        "iPad13,8",
+                    ],
+                    models: ModelSupport(
+                        default: "openai_whisper-large-v3-v20240930",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v3",
+                            "openai_whisper-large-v3_947MB",
+                            "distil-whisper_distil-large-v3",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "openai_whisper-large-v3-v20240930",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                        ]
+                    )
+                ),
+                DeviceSupport(
+                    identifiers: [
+                        "Mac14",
+                        "Mac15",
+                        "Mac16",
+                        "iPad14,3",
+                        "iPad14,4",
+                        "iPad14,5",
+                        "iPad14,6",
+                        "iPad14,8",
+                        "iPad14,9",
+                        "iPad14,10",
+                        "iPad14,11",
+                        "iPad16",
+                    ],
+                    models: ModelSupport(
+                        default: "openai_whisper-large-v3-v20240930",
+                        supported: [
+                            "openai_whisper-tiny",
+                            "openai_whisper-tiny.en",
+                            "openai_whisper-base",
+                            "openai_whisper-base.en",
+                            "openai_whisper-small",
+                            "openai_whisper-small.en",
+                            "openai_whisper-large-v2",
+                            "openai_whisper-large-v2_949MB",
+                            "openai_whisper-large-v2_turbo",
+                            "openai_whisper-large-v2_turbo_955MB",
+                            "openai_whisper-large-v3",
+                            "openai_whisper-large-v3_947MB",
+                            "openai_whisper-large-v3_turbo",
+                            "openai_whisper-large-v3_turbo_954MB",
+                            "distil-whisper_distil-large-v3",
+                            "distil-whisper_distil-large-v3_594MB",
+                            "distil-whisper_distil-large-v3_turbo",
+                            "distil-whisper_distil-large-v3_turbo_600MB",
+                            "openai_whisper-large-v3-v20240930",
+                            "openai_whisper-large-v3-v20240930_turbo",
+                            "openai_whisper-large-v3-v20240930_626MB",
+                            "openai_whisper-large-v3-v20240930_turbo_632MB",
+                        ]
+                    )
+                ),
+            ],
+            includeFallback: false
+        )
+
+        return config
+    }()
+
+    public static let knownModels: [String] = fallbackModelSupportConfig.deviceSupports.flatMap { $0.models.supported }.orderedSet
 }
```
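The new `ModelSupportConfig` decodes the repo's JSON config (keys `name`, `version`, and `device_support` per the `CodingKeys` above), merges it with the bundled fallback, and resolves a device identifier to its model set by prefix match. A hedged sketch of decoding and querying it; the JSON literal is illustrative, not the actual repo config:

```swift
import Foundation
import WhisperKit

let json = """
{
  "name": "whisperkit-coreml",
  "version": "0.2",
  "device_support": [
    {
      "identifiers": ["iPhone14"],
      "models": {
        "default": "openai_whisper-base",
        "supported": ["openai_whisper-tiny", "openai_whisper-base"]
      }
    }
  ]
}
""".data(using: .utf8)!

// Decoding runs the custom init(from:), which merges these entries
// with Constants.fallbackModelSupportConfig (force-try for brevity).
let config = try! JSONDecoder().decode(ModelSupportConfig.self, from: json)

if #available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) {
    // Matching is by identifier prefix, so "iPhone14,5" hits the "iPhone14" entry.
    let support = config.modelSupport(for: "iPhone14,5")
    print(support.default)   // "openai_whisper-base"
    print(support.supported) // the remote list merged with the fallback list for this device
    print(support.disabled)  // known models not supported on this device
}
```

Unknown identifiers fall back to `defaultSupport`, which lists every known model with `openai_whisper-base` as the default.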
