diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 5f617e6..cdb17f5 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -152,14 +152,14 @@ open class WhisperKit { return modelSupport(for: deviceName) } - public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupport { + public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupport { let deviceName = Self.deviceName() - let config = await Self.fetchModelSupportConfig(from: repo) + let config = await Self.fetchModelSupportConfig(from: repo, downloadBase: downloadBase) return modelSupport(for: deviceName, from: config) } - public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupportConfig { - let hubApi = HubApi() + public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupportConfig { + let hubApi = HubApi(downloadBase: downloadBase) var modelSupportConfig = Constants.fallbackModelSupportConfig do { @@ -176,8 +176,8 @@ open class WhisperKit { return modelSupportConfig } - public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"]) async throws -> [String] { - let modelSupportConfig = await fetchModelSupportConfig(from: repo) + public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"], downloadBase: URL? = nil) async throws -> [String] { + let modelSupportConfig = await fetchModelSupportConfig(from: repo, downloadBase: downloadBase) let supportedModels = modelSupportConfig.modelSupport().supported var filteredSupportSet: Set = [] for glob in matching { @@ -290,10 +290,10 @@ open class WhisperKit { self.modelFolder = URL(fileURLWithPath: folder) } else if download { // Determine the model variant to use - let modelSupport = await WhisperKit.recommendedRemoteModels() + let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" + let modelSupport = await WhisperKit.recommendedRemoteModels(from: repo, downloadBase: downloadBase) let modelVariant = model ?? modelSupport.default - let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" do { self.modelFolder = try await Self.download( variant: modelVariant,