Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model support config fetching from model repo #216

Merged
merged 8 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ jobs:
run: make download-model MODEL=tiny
- name: Install and discover destinations
run: |
xcodebuild -downloadAllPlatforms
if [[ "${{ matrix.run-config['name'] }}" != "macOS" ]]; then
xcodebuild -downloadPlatform ${{ matrix.run-config['name'] }}
fi
echo "Destinations for testing:"
xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
- name: Boot Simulator and Wait
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
"revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
"version" : "0.1.8"
}
}
],
Expand Down
25 changes: 13 additions & 12 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ struct ContentView: View {
MenuItem(name: "Stream", image: "waveform.badge.mic"),
]


private var isStreamMode: Bool {
self.selectedCategoryId == menu.first(where: { $0.name == "Stream" })?.id
}
Expand Down Expand Up @@ -202,7 +201,7 @@ struct ContentView: View {
.toolbar(content: {
ToolbarItem {
Button {
if (!enableEagerDecoding) {
if !enableEagerDecoding {
let fullTranscript = formatSegments(confirmedSegments + unconfirmedSegments, withTimestamps: enableTimestamps).joined(separator: "\n")
#if os(iOS)
UIPasteboard.general.string = fullTranscript
Expand Down Expand Up @@ -956,9 +955,7 @@ struct ContentView: View {

localModels = WhisperKit.formatModelFiles(localModels)
for model in localModels {
if !availableModels.contains(model),
!disabledModels.contains(model)
{
if !availableModels.contains(model) {
availableModels.append(model)
}
}
Expand All @@ -967,12 +964,17 @@ struct ContentView: View {
print("Previously selected model: \(selectedModel)")

Task {
let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName)
for model in remoteModels {
if !availableModels.contains(model),
!disabledModels.contains(model)
{
availableModels.append(model)
let remoteModelSupport = await WhisperKit.recommendedRemoteModels()
await MainActor.run {
for model in remoteModelSupport.supported {
if !availableModels.contains(model) {
availableModels.append(model)
}
}
for model in remoteModelSupport.disabled {
if !disabledModels.contains(model) {
disabledModels.append(model)
}
}
}
}
Expand Down Expand Up @@ -1644,7 +1646,6 @@ struct ContentView: View {
finalizeText()
}


let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)

return mergedResult
Expand Down
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
"revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
"version" : "0.1.8"
}
}
],
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
],
targets: [
Expand Down
4 changes: 2 additions & 2 deletions Sources/WhisperKit/Core/Audio/AudioChunker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ open class VADAudioChunker: AudioChunking {
// Typically this will be the full audio file, unless seek points are explicitly provided
var startIndex = seekClipStart
while startIndex < seekClipEnd - windowPadding {
let currentFrameLength = startIndex - seekClipStart
if startIndex >= currentFrameLength, startIndex < 0 {
let currentFrameLength = audioArray.count
guard startIndex >= 0 && startIndex < audioArray.count else {
throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
}

Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public extension AudioProcessing {
static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
let currentFrameLength = audioArray.count

if startIndex >= currentFrameLength, startIndex < 0 {
guard startIndex >= 0 && startIndex < audioArray.count else {
Logging.error("startIndex is outside the buffer size")
return nil
}
Expand Down
5 changes: 2 additions & 3 deletions Sources/WhisperKit/Core/Configurations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ open class WhisperKitConfig {
prewarm: Bool? = nil,
load: Bool? = nil,
download: Bool = true,
useBackgroundDownloadSession: Bool = false
) {
useBackgroundDownloadSession: Bool = false)
{
self.model = model
self.downloadBase = downloadBase
self.modelRepo = modelRepo
Expand All @@ -83,7 +83,6 @@ open class WhisperKitConfig {
}
}


/// Options for how to transcribe an audio file using WhisperKit.
///
/// - Parameters:
Expand Down
245 changes: 245 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,114 @@ public struct ModelComputeOptions {
}
}

/// The set of models available to a single device family: the variant to
/// pre-select by default plus every variant the device can run.
public struct ModelSupport: Codable, Equatable {
    /// Model variant that should be selected by default for this device.
    public let `default`: String
    /// All model variants this device family supports.
    public let supported: [String]
    /// Known models that are NOT in `supported` for this device.
    /// Not part of the JSON payload (see `CodingKeys`); computed on init of
    /// `ModelSupportConfig` via `computeDisabledModels()`.
    public var disabled: [String] = []

    // `disabled` is derived locally, so it is excluded from encoding/decoding.
    private enum CodingKeys: String, CodingKey {
        case `default`, supported
    }
}

/// Associates a group of device identifier prefixes (e.g. "iPhone14", "Mac14")
/// with the model support for those devices.
public struct DeviceSupport: Codable {
    /// Device identifier prefixes, matched via `hasPrefix` against the current
    /// device name in `ModelSupportConfig.modelSupport(for:)`.
    public let identifiers: [String]
    /// Model support for these devices; mutable so the derived `disabled`
    /// list can be filled in after decoding.
    public var models: ModelSupport
}

/// Repo-level model support configuration, typically decoded from a JSON file
/// in the model repository, describing which models each device family can run.
///
/// On init it computes `knownModels` (the union of all supported models),
/// a `defaultSupport` entry for unrecognized devices, and each entry's
/// `disabled` model list.
public struct ModelSupportConfig: Codable {
    /// Name of the model repo this config describes (JSON key: "name").
    public let repoName: String
    /// Version of the config payload (JSON key: "version").
    public let repoVersion: String
    /// Per-device-family support entries (JSON key: "device_support").
    public var deviceSupports: [DeviceSupport]
    /// Union of all supported models across `deviceSupports`, in first-seen order.
    /// Computed on init.
    public private(set) var knownModels: [String]
    /// Support entry returned for devices not matched by any `deviceSupports`
    /// entry; permits all `knownModels`. Computed on init.
    public private(set) var defaultSupport: DeviceSupport

    enum CodingKeys: String, CodingKey {
        case repoName = "name"
        case repoVersion = "version"
        case deviceSupports = "device_support"
    }

    /// Decodes the three JSON-backed properties, then delegates to the
    /// designated initializer so the derived properties are computed.
    public init(from decoder: Swift.Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        let repoName = try container.decode(String.self, forKey: .repoName)
        let repoVersion = try container.decode(String.self, forKey: .repoVersion)
        let deviceSupports = try container.decode([DeviceSupport].self, forKey: .deviceSupports)

        self.init(repoName: repoName, repoVersion: repoVersion, deviceSupports: deviceSupports)
    }

    /// Designated initializer.
    /// - Parameters:
    ///   - repoName: Name of the model repo.
    ///   - repoVersion: Version of the config payload.
    ///   - deviceSupports: Per-device support entries (e.g. decoded from remote JSON).
    ///   - includeFallback: When `true` (the default), merges `deviceSupports`
    ///     with `Constants.fallbackModelSupportConfig` so locally known
    ///     devices and models are preserved even if absent from the remote config.
    public init(repoName: String, repoVersion: String, deviceSupports: [DeviceSupport], includeFallback: Bool = true) {
        self.repoName = repoName
        self.repoVersion = repoVersion

        if includeFallback {
            self.deviceSupports = Self.mergeDeviceSupport(remote: deviceSupports, fallback: Constants.fallbackModelSupportConfig.deviceSupports)
            self.knownModels = self.deviceSupports.flatMap { $0.models.supported }.orderedSet
        } else {
            self.deviceSupports = deviceSupports
            self.knownModels = deviceSupports.flatMap { $0.models.supported }.orderedSet
        }

        // Add default device support with all models supported for unknown devices
        self.defaultSupport = DeviceSupport(
            identifiers: [],
            models: ModelSupport(
                default: "openai_whisper-base",
                supported: self.knownModels
            )
        )

        computeDisabledModels()
    }

    /// Returns the model support for the given device identifier, matching by
    /// prefix against each entry's `identifiers`. Falls back to
    /// `defaultSupport` (all known models) when no entry matches, and logs
    /// that the fallback was used.
    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
    public func modelSupport(for deviceIdentifier: String = WhisperKit.deviceName()) -> ModelSupport {
        for support in deviceSupports {
            if support.identifiers.contains(where: { deviceIdentifier.hasPrefix($0) }) {
                return support.models
            }
        }

        Logging.info("No device support found for \(deviceIdentifier), using default")
        return defaultSupport.models
    }

    /// Fills each entry's `disabled` list with the known models it does not
    /// support. Note: uses `Set` subtraction, so the order of `disabled` is
    /// unspecified.
    private mutating func computeDisabledModels() {
        for i in 0..<deviceSupports.count {
            let disabledModels = Set(knownModels).subtracting(deviceSupports[i].models.supported)
            self.deviceSupports[i].models.disabled = Array(disabledModels)
        }
    }

    /// Merges remote and fallback device supports. Entries whose identifier
    /// sets overlap get the union of their supported models (the remote
    /// `default` wins); fallback entries with no remote overlap are appended.
    private static func mergeDeviceSupport(remote: [DeviceSupport], fallback: [DeviceSupport]) -> [DeviceSupport] {
        var mergedSupports: [DeviceSupport] = []
        let remoteIdentifiers = Set(remote.flatMap { $0.identifiers })

        // Add remote device supports, merging with fallback if identifiers overlap
        for remoteSupport in remote {
            if let fallbackSupport = fallback.first(where: { $0.identifiers.contains(where: remoteSupport.identifiers.contains) }) {
                let mergedModels = ModelSupport(
                    default: remoteSupport.models.default,
                    supported: (remoteSupport.models.supported + fallbackSupport.models.supported).orderedSet
                )
                mergedSupports.append(DeviceSupport(identifiers: remoteSupport.identifiers, models: mergedModels))
            } else {
                mergedSupports.append(remoteSupport)
            }
        }

        // Add fallback device supports that don't overlap with remote
        for fallbackSupport in fallback where !fallbackSupport.identifiers.contains(where: remoteIdentifiers.contains) {
            mergedSupports.append(fallbackSupport)
        }

        return mergedSupports
    }
}

// MARK: - Chunking

public struct AudioChunk {
Expand Down Expand Up @@ -1346,4 +1454,141 @@ public enum Constants {
public static let defaultLanguageCode: String = "en"

public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate

/// Locally bundled model support config used when the remote config cannot be
/// fetched, and merged into remote configs (see
/// `ModelSupportConfig.init(repoName:repoVersion:deviceSupports:includeFallback:)`)
/// so known devices/models are never lost. Built with `includeFallback: false`
/// to avoid recursing into itself. The "-fallback" repo name distinguishes it
/// from a remotely fetched config.
public static let fallbackModelSupportConfig: ModelSupportConfig = {
    var config = ModelSupportConfig(
        repoName: "whisperkit-coreml-fallback",
        repoVersion: "0.2",
        deviceSupports: [
            DeviceSupport(
                identifiers: ["iPhone11", "iPhone12", "Watch7", "Watch8"],
                models: ModelSupport(
                    default: "openai_whisper-tiny",
                    supported: [
                        "openai_whisper-base",
                        "openai_whisper-base.en",
                        "openai_whisper-tiny",
                        "openai_whisper-tiny.en",
                    ]
                )
            ),
            DeviceSupport(
                identifiers: ["iPhone13", "iPad13,18", "iPad13,1"],
                models: ModelSupport(
                    default: "openai_whisper-base",
                    supported: [
                        "openai_whisper-tiny",
                        "openai_whisper-tiny.en",
                        "openai_whisper-base",
                        "openai_whisper-base.en",
                        "openai_whisper-small",
                        "openai_whisper-small.en",
                    ]
                )
            ),
            DeviceSupport(
                identifiers: ["iPhone14", "iPhone15", "iPhone16", "iPhone17", "iPad14,1", "iPad14,2"],
                models: ModelSupport(
                    default: "openai_whisper-base",
                    supported: [
                        "openai_whisper-tiny",
                        "openai_whisper-tiny.en",
                        "openai_whisper-base",
                        "openai_whisper-base.en",
                        "openai_whisper-small",
                        "openai_whisper-small.en",
                        "openai_whisper-large-v2_949MB",
                        "openai_whisper-large-v2_turbo_955MB",
                        "openai_whisper-large-v3_947MB",
                        "openai_whisper-large-v3_turbo_954MB",
                        "distil-whisper_distil-large-v3_594MB",
                        "distil-whisper_distil-large-v3_turbo_600MB",
                        "openai_whisper-large-v3-v20240930_626MB",
                        "openai_whisper-large-v3-v20240930_turbo_632MB",
                    ]
                )
            ),
            DeviceSupport(
                identifiers: [
                    "Mac13",
                    "iMac21",
                    "MacBookAir10,1",
                    "MacBookPro17",
                    "MacBookPro18",
                    "Macmini9",
                    "iPad13,16",
                    "iPad13,4",
                    "iPad13,8",
                ],
                models: ModelSupport(
                    default: "openai_whisper-large-v3-v20240930",
                    supported: [
                        "openai_whisper-tiny",
                        "openai_whisper-tiny.en",
                        "openai_whisper-base",
                        "openai_whisper-base.en",
                        "openai_whisper-small",
                        "openai_whisper-small.en",
                        "openai_whisper-large-v2",
                        "openai_whisper-large-v2_949MB",
                        "openai_whisper-large-v3",
                        "openai_whisper-large-v3_947MB",
                        "distil-whisper_distil-large-v3",
                        "distil-whisper_distil-large-v3_594MB",
                        "openai_whisper-large-v3-v20240930",
                        "openai_whisper-large-v3-v20240930_626MB",
                    ]
                )
            ),
            DeviceSupport(
                identifiers: [
                    "Mac14",
                    "Mac15",
                    "Mac16",
                    "iPad14,3",
                    "iPad14,4",
                    "iPad14,5",
                    "iPad14,6",
                    "iPad14,8",
                    "iPad14,9",
                    "iPad14,10",
                    "iPad14,11",
                    "iPad16",
                ],
                models: ModelSupport(
                    default: "openai_whisper-large-v3-v20240930",
                    supported: [
                        "openai_whisper-tiny",
                        "openai_whisper-tiny.en",
                        "openai_whisper-base",
                        "openai_whisper-base.en",
                        "openai_whisper-small",
                        "openai_whisper-small.en",
                        "openai_whisper-large-v2",
                        "openai_whisper-large-v2_949MB",
                        "openai_whisper-large-v2_turbo",
                        "openai_whisper-large-v2_turbo_955MB",
                        "openai_whisper-large-v3",
                        "openai_whisper-large-v3_947MB",
                        "openai_whisper-large-v3_turbo",
                        "openai_whisper-large-v3_turbo_954MB",
                        "distil-whisper_distil-large-v3",
                        "distil-whisper_distil-large-v3_594MB",
                        "distil-whisper_distil-large-v3_turbo",
                        "distil-whisper_distil-large-v3_turbo_600MB",
                        "openai_whisper-large-v3-v20240930",
                        "openai_whisper-large-v3-v20240930_turbo",
                        "openai_whisper-large-v3-v20240930_626MB",
                        "openai_whisper-large-v3-v20240930_turbo_632MB",
                    ]
                )
            ),
        ],
        includeFallback: false
    )

    return config
}()

/// Union of all model variants known locally, in first-seen order.
/// Reuses the list `ModelSupportConfig.init` already computes and stores
/// (`deviceSupports.flatMap { $0.models.supported }.orderedSet`) instead of
/// duplicating that expression here.
public static let knownModels: [String] = fallbackModelSupportConfig.knownModels
}
Loading