Skip to content

Commit 0af7146

Browse files
authored
Add ability to prevent config.json being written to ~/Documents/huggingface/... (#262)
* `fetchAvailableModels` and `fetchAvailableModels` take `downloadBase: URL?` param Which controls where HubAPI leaves the config.json that it downloads. * `recommendedRemoteModels` also gets passed the downloadBase (and repo param)
1 parent e8eebbe commit 0af7146

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

Sources/WhisperKit/Core/WhisperKit.swift

+8-8
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ open class WhisperKit {
152152
return modelSupport(for: deviceName)
153153
}
154154

155-
public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupport {
155+
public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupport {
156156
let deviceName = Self.deviceName()
157-
let config = await Self.fetchModelSupportConfig(from: repo)
157+
let config = await Self.fetchModelSupportConfig(from: repo, downloadBase: downloadBase)
158158
return modelSupport(for: deviceName, from: config)
159159
}
160160

161-
public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupportConfig {
162-
let hubApi = HubApi()
161+
public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupportConfig {
162+
let hubApi = HubApi(downloadBase: downloadBase)
163163
var modelSupportConfig = Constants.fallbackModelSupportConfig
164164

165165
do {
@@ -176,8 +176,8 @@ open class WhisperKit {
176176
return modelSupportConfig
177177
}
178178

179-
public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"]) async throws -> [String] {
180-
let modelSupportConfig = await fetchModelSupportConfig(from: repo)
179+
public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"], downloadBase: URL? = nil) async throws -> [String] {
180+
let modelSupportConfig = await fetchModelSupportConfig(from: repo, downloadBase: downloadBase)
181181
let supportedModels = modelSupportConfig.modelSupport().supported
182182
var filteredSupportSet: Set<String> = []
183183
for glob in matching {
@@ -290,10 +290,10 @@ open class WhisperKit {
290290
self.modelFolder = URL(fileURLWithPath: folder)
291291
} else if download {
292292
// Determine the model variant to use
293-
let modelSupport = await WhisperKit.recommendedRemoteModels()
293+
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
294+
let modelSupport = await WhisperKit.recommendedRemoteModels(from: repo, downloadBase: downloadBase)
294295
let modelVariant = model ?? modelSupport.default
295296

296-
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
297297
do {
298298
self.modelFolder = try await Self.download(
299299
variant: modelVariant,

0 commit comments

Comments
 (0)