Skip to content

Commit 936ec3f

Browse files
Add repo option to regression test matrix (#293)
* Add repo and token option to regression test matrix
* Add default Debug.xcconfig file
* Update fastlane to run on repo from benchmark config
* Formatting
1 parent a7e3858 commit 936ec3f

File tree

7 files changed

+89
-28
lines changed

7 files changed

+89
-28
lines changed

Examples/WhisperAX/Debug.xcconfig

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
// Run `make setup` to add your team here
2+
DEVELOPMENT_TEAM=

Examples/WhisperAX/WhisperAX.xcodeproj/xcshareddata/xcschemes/WhisperAX.xcscheme

+5
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@
7979
value = "$(MODEL_NAME)"
8080
isEnabled = "YES">
8181
</EnvironmentVariable>
82+
<EnvironmentVariable
83+
key = "MODEL_REPO"
84+
value = "$(MODEL_REPO)"
85+
isEnabled = "YES">
86+
</EnvironmentVariable>
8287
</EnvironmentVariables>
8388
</LaunchAction>
8489
<ProfileAction

Sources/WhisperKit/Core/Configurations.swift

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ open class WhisperKitConfig {
1212
public var downloadBase: URL?
1313
/// Repository for downloading models
1414
public var modelRepo: String?
15+
/// Token for downloading models from repo (if required)
16+
public var modelToken: String?
1517

1618
/// Folder to store models
1719
public var modelFolder: String?
@@ -47,6 +49,7 @@ open class WhisperKitConfig {
4749
public init(model: String? = nil,
4850
downloadBase: URL? = nil,
4951
modelRepo: String? = nil,
52+
modelToken: String? = nil,
5053
modelFolder: String? = nil,
5154
tokenizerFolder: URL? = nil,
5255
computeOptions: ModelComputeOptions? = nil,
@@ -67,6 +70,7 @@ open class WhisperKitConfig {
6770
self.model = model
6871
self.downloadBase = downloadBase
6972
self.modelRepo = modelRepo
73+
self.modelToken = modelToken
7074
self.modelFolder = modelFolder
7175
self.tokenizerFolder = tokenizerFolder
7276
self.computeOptions = computeOptions

Sources/WhisperKit/Core/WhisperKit.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ open class WhisperKit {
6868
model: config.model,
6969
downloadBase: config.downloadBase,
7070
modelRepo: config.modelRepo,
71+
modelToken: config.modelToken,
7172
modelFolder: config.modelFolder,
7273
download: config.download
7374
)
74-
7575

7676
if let prewarm = config.prewarm, prewarm {
7777
Logging.info("Prewarming models...")
@@ -295,6 +295,7 @@ open class WhisperKit {
295295
model: String?,
296296
downloadBase: URL? = nil,
297297
modelRepo: String?,
298+
modelToken: String? = nil,
298299
modelFolder: String?,
299300
download: Bool
300301
) async throws {
@@ -312,7 +313,8 @@ open class WhisperKit {
312313
variant: modelVariant,
313314
downloadBase: downloadBase,
314315
useBackgroundSession: useBackgroundDownloadSession,
315-
from: repo
316+
from: repo,
317+
token: modelToken
316318
)
317319
} catch {
318320
// Handle errors related to model downloading

Tests/WhisperKitTests/RegressionTestUtils.swift

+6
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class TestInfo: JSONCodable {
5454
let datasetDir: String
5555
let datasetRepo: String
5656
let model: String
57+
let modelRepo: String
5758
let modelSizeMB: Double
5859
let date: String
5960
let timeElapsedInSeconds: TimeInterval
@@ -69,6 +70,7 @@ class TestInfo: JSONCodable {
6970
datasetDir: String,
7071
datasetRepo: String,
7172
model: String,
73+
modelRepo: String,
7274
modelSizeMB: Double,
7375
date: String,
7476
timeElapsedInSeconds: TimeInterval,
@@ -83,6 +85,7 @@ class TestInfo: JSONCodable {
8385
self.datasetDir = datasetDir
8486
self.datasetRepo = datasetRepo
8587
self.model = model
88+
self.modelRepo = modelRepo
8689
self.modelSizeMB = modelSizeMB
8790
self.date = date
8891
self.timeElapsedInSeconds = timeElapsedInSeconds
@@ -101,6 +104,7 @@ struct TestReport: JSONCodable {
101104
let osType: String
102105
let osVersion: String
103106
let modelsTested: [String]
107+
let modelReposTested: [String]
104108
let failureInfo: [String: String]
105109
let attachments: [String: String]
106110

@@ -109,13 +113,15 @@ struct TestReport: JSONCodable {
109113
osType: String,
110114
osVersion: String,
111115
modelsTested: [String],
116+
modelReposTested: [String],
112117
failureInfo: [String: String],
113118
attachments: [String: String]
114119
) {
115120
self.deviceModel = deviceModel
116121
self.osType = osType
117122
self.osVersion = osVersion
118123
self.modelsTested = modelsTested
124+
self.modelReposTested = modelReposTested
119125
self.failureInfo = failureInfo
120126
self.attachments = attachments
121127
}

Tests/WhisperKitTests/RegressionTests.swift

+56-22
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,22 @@ import WatchKit
1313
#endif
1414

1515
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
16-
final class RegressionTests: XCTestCase {
16+
class RegressionTests: XCTestCase {
1717
var audioFileURLs: [URL]?
1818
var remoteFileURLs: [URL]?
1919
var metadataURL: URL?
2020
var testWERURLs: [URL]?
2121
var modelsToTest: [String] = []
22+
var modelReposToTest: [String] = []
2223
var modelsTested: [String] = []
24+
var modelReposTested: [String] = []
2325
var optionsToTest: [DecodingOptions] = [DecodingOptions()]
2426

2527
/// One cell of the regression test matrix: a single combination of dataset,
/// compute options, model, model repo, and decoding options to run.
struct TestConfig {
    /// Name of the dataset to evaluate (one of the entries in `datasets`).
    let dataset: String

    /// Compute options forwarded to `WhisperKitConfig` when loading the model.
    let modelComputeOptions: ModelComputeOptions

    /// Model to load. Mutable because the test pipeline overwrites it with the
    /// resolved model file name once the model has been loaded.
    var model: String

    /// Repository the model is downloaded from; forwarded as `modelRepo`
    /// to `WhisperKitConfig`.
    var modelRepo: String

    /// Decoding options used for this configuration.
    let decodingOptions: DecodingOptions
}
3134

@@ -34,6 +37,7 @@ final class RegressionTests: XCTestCase {
3437
var datasets = ["librispeech-10mins", "earnings22-10mins"]
3538
let debugDataset = ["earnings22-10mins"]
3639
let debugModels = ["tiny"]
40+
let debugRepos = ["argmaxinc/whisperkit-coreml"]
3741

3842
var computeOptions: [ModelComputeOptions] = [
3943
ModelComputeOptions(audioEncoderCompute: .cpuAndNeuralEngine, textDecoderCompute: .cpuAndNeuralEngine),
@@ -71,16 +75,29 @@ final class RegressionTests: XCTestCase {
7175
Logging.debug("Max memory before warning: \(maxMemory)")
7276
}
7377

74-
func testEnvConfigurations(defaultModels: [String]? = nil) {
78+
/// Token passed as `modelToken` when the regression tests create a
/// `WhisperKit` instance.
///
/// Returns `nil` by default. Add a token here, or override this method in a
/// subclass, when the model repo requires one.
class func getModelToken() -> String? {
    nil
}
82+
83+
/// Resolves which models and model repos to test from the process environment.
///
/// When `MODEL_NAME` is set and non-empty, only that model is tested. The repo
/// is taken from `MODEL_REPO` when it is set and non-empty; otherwise it falls
/// back to `defaultRepos` (or `debugRepos`). The fallback matters because
/// `getTestMatrix()` iterates over `modelReposToTest`, so leaving it empty
/// would silently produce an empty test matrix.
///
/// When `MODEL_NAME` is not set, both lists come from the supplied defaults,
/// falling back to `debugModels` / `debugRepos`.
///
/// - Parameters:
///   - defaultModels: Models to test when `MODEL_NAME` is not set.
///   - defaultRepos: Repos to test when no repo is provided by the environment.
func testEnvConfigurations(defaultModels: [String]? = nil, defaultRepos: [String]? = nil) {
    let environment = ProcessInfo.processInfo.environment

    if let modelSizeEnv = environment["MODEL_NAME"], !modelSizeEnv.isEmpty {
        modelsToTest = [modelSizeEnv]
        Logging.debug("Model size: \(modelSizeEnv)")

        // An unset or empty MODEL_REPO (e.g. an unexpanded scheme variable) must not
        // leave the repo list empty or containing "", so fall back to the defaults.
        if let repoEnv = environment["MODEL_REPO"], !repoEnv.isEmpty {
            modelReposToTest = [repoEnv]
            Logging.debug("Using repo: \(repoEnv)")
        } else {
            modelReposToTest = defaultRepos ?? debugRepos
            Logging.debug("Model repo not set by env, using: \(modelReposToTest)")
        }

        XCTAssertTrue(modelsToTest.count > 0, "Invalid model size: \(modelSizeEnv)")

        // Deliberate crash hook, triggered by passing the model name "crash_test".
        if modelSizeEnv == "crash_test" {
            fatalError("Crash test triggered")
        }
    } else {
        modelsToTest = defaultModels ?? debugModels
        modelReposToTest = defaultRepos ?? debugRepos
        Logging.debug("Model size not set by env")
    }
}
@@ -116,7 +133,7 @@ final class RegressionTests: XCTestCase {
116133

117134
// MARK: - Test Pipeline
118135

119-
private func runRegressionTests(with testMatrix: [TestConfig]) async throws {
136+
public func runRegressionTests(with testMatrix: [TestConfig]) async throws {
120137
var failureInfo: [String: String] = [:]
121138
var attachments: [String: String] = [:]
122139
let device = getCurrentDevice()
@@ -159,8 +176,7 @@ final class RegressionTests: XCTestCase {
159176

160177
// Create WhisperKit instance with checks for memory usage
161178
let whisperKit = try await createWithMemoryCheck(
162-
model: config.model,
163-
computeOptions: config.modelComputeOptions,
179+
testConfig: config,
164180
verbose: true,
165181
logLevel: .debug
166182
)
@@ -169,6 +185,8 @@ final class RegressionTests: XCTestCase {
169185
config.model = modelFile
170186
modelsTested.append(modelFile)
171187
modelsTested = Array(Set(modelsTested))
188+
modelReposTested.append(config.modelRepo)
189+
modelReposTested = Array(Set(modelReposTested))
172190
}
173191

174192
for audioFilePath in audioFilePaths {
@@ -295,6 +313,7 @@ final class RegressionTests: XCTestCase {
295313
datasetDir: config.dataset,
296314
datasetRepo: datasetRepo,
297315
model: config.model,
316+
modelRepo: config.modelRepo,
298317
modelSizeMB: modelSizeMB ?? -1,
299318
date: startTime.formatted(Date.ISO8601FormatStyle().dateSeparator(.dash)),
300319
timeElapsedInSeconds: Date().timeIntervalSince(startTime),
@@ -432,20 +451,23 @@ final class RegressionTests: XCTestCase {
432451
}
433452
}
434453

435-
private func getTestMatrix() -> [TestConfig] {
454+
public func getTestMatrix() -> [TestConfig] {
436455
var regressionTestConfigMatrix: [TestConfig] = []
437456
for dataset in datasets {
438457
for computeOption in computeOptions {
439458
for options in optionsToTest {
440-
for model in modelsToTest {
441-
regressionTestConfigMatrix.append(
442-
TestConfig(
443-
dataset: dataset,
444-
modelComputeOptions: computeOption,
445-
model: model,
446-
decodingOptions: options
459+
for repo in modelReposToTest {
460+
for model in modelsToTest {
461+
regressionTestConfigMatrix.append(
462+
TestConfig(
463+
dataset: dataset,
464+
modelComputeOptions: computeOption,
465+
model: model,
466+
modelRepo: repo,
467+
decodingOptions: options
468+
)
447469
)
448-
)
470+
}
449471
}
450472
}
451473
}
@@ -555,6 +577,7 @@ final class RegressionTests: XCTestCase {
555577
osType: osDetails.osType,
556578
osVersion: osDetails.osVersion,
557579
modelsTested: modelsTested,
580+
modelReposTested: modelReposTested,
558581
failureInfo: failureInfo,
559582
attachments: attachments
560583
)
@@ -610,17 +633,14 @@ final class RegressionTests: XCTestCase {
610633
return Double(modelSize / (1024 * 1024)) // Convert to MB
611634
}
612635

613-
func createWithMemoryCheck(
614-
model: String,
615-
computeOptions: ModelComputeOptions,
616-
verbose: Bool,
617-
logLevel: Logging.LogLevel
618-
) async throws -> WhisperKit {
636+
public func initWhisperKitTask(testConfig config: TestConfig, verbose: Bool, logLevel: Logging.LogLevel) -> Task<WhisperKit, Error> {
619637
// Create the initialization task
620638
let initializationTask = Task { () -> WhisperKit in
621639
let whisperKit = try await WhisperKit(WhisperKitConfig(
622-
model: model,
623-
computeOptions: computeOptions,
640+
model: config.model,
641+
modelRepo: config.modelRepo,
642+
modelToken: Self.getModelToken(),
643+
computeOptions: config.modelComputeOptions,
624644
verbose: verbose,
625645
logLevel: logLevel,
626646
prewarm: true,
@@ -629,6 +649,20 @@ final class RegressionTests: XCTestCase {
629649
try Task.checkCancellation()
630650
return whisperKit
631651
}
652+
return initializationTask
653+
}
654+
655+
func createWithMemoryCheck(
656+
testConfig: TestConfig,
657+
verbose: Bool,
658+
logLevel: Logging.LogLevel
659+
) async throws -> WhisperKit {
660+
// Create the initialization task
661+
let initializationTask = initWhisperKitTask(
662+
testConfig: testConfig,
663+
verbose: verbose,
664+
logLevel: logLevel
665+
)
632666

633667
// Start the memory monitoring task
634668
let monitorTask = Task {

fastlane/Fastfile

+12-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ BASE_BENCHMARK_PATH = "#{WORKING_DIR}/benchmark_data".freeze
2323
BASE_UPLOAD_PATH = "#{WORKING_DIR}/upload_folder".freeze
2424
XCRESULT_PATH = File.expand_path("#{BASE_BENCHMARK_PATH}/#{COMMIT_TIMESTAMP}_#{COMMIT_HASH}/")
2525
BENCHMARK_REPO = 'argmaxinc/whisperkit-evals-dataset'.freeze
26-
BENCHMARK_CONFIGS = {
26+
BENCHMARK_CONFIGS ||= {
2727
full: {
2828
test_identifier: 'WhisperAXTests/RegressionTests/testModelPerformance',
2929
name: 'full',
@@ -50,12 +50,14 @@ BENCHMARK_CONFIGS = {
5050
'openai_whisper-large-v3-v20240930_turbo',
5151
'openai_whisper-large-v3-v20240930_626MB',
5252
'openai_whisper-large-v3-v20240930_turbo_632MB'
53-
]
53+
],
54+
repo: 'argmaxinc/whisperkit-coreml'
5455
},
5556
debug: {
5657
test_identifier: 'WhisperAXTests/RegressionTests/testModelPerformanceWithDebugConfig',
5758
name: 'debug',
58-
models: ['tiny', 'crash_test', 'unknown_model', 'small.en']
59+
models: ['tiny', 'crash_test', 'unknown_model', 'small.en'],
60+
repo: 'argmaxinc/whisperkit-coreml'
5961
}
6062
}.freeze
6163

@@ -200,7 +202,9 @@ end
200202

201203
def run_benchmark(devices, config)
202204
summaries = []
203-
BENCHMARK_CONFIGS[config][:models].each do |model|
205+
config_data = BENCHMARK_CONFIGS[config]
206+
207+
config_data[:models].each do |model|
204208
begin
205209
# Sanitize device name for use in file path
206210
devices_to_test = devices.map { |device_info| device_info[:name] }.compact
@@ -228,8 +232,12 @@ def run_benchmark(devices, config)
228232
UI.message "Running in #{BENCHMARK_CONFIGS[config][:name]} mode"
229233

230234
UI.message "Running benchmark for model: #{model}"
235+
UI.message 'Using Hugging Face:'
236+
UI.message " • Repository: #{config_data[:repo]}"
237+
231238
xcargs = [
232239
"MODEL_NAME=#{model}",
240+
"MODEL_REPO=#{config_data[:repo]}",
233241
'-allowProvisioningUpdates',
234242
'-allowProvisioningDeviceRegistration'
235243
].join(' ')

0 commit comments

Comments
 (0)