@@ -13,19 +13,22 @@ import WatchKit
13
13
#endif
14
14
15
15
@available ( macOS 13 , iOS 16 , watchOS 10 , visionOS 1 , * )
16
- final class RegressionTests : XCTestCase {
16
+ class RegressionTests : XCTestCase {
17
17
var audioFileURLs : [ URL ] ?
18
18
var remoteFileURLs : [ URL ] ?
19
19
var metadataURL : URL ?
20
20
var testWERURLs : [ URL ] ?
21
21
var modelsToTest : [ String ] = [ ]
22
+ var modelReposToTest : [ String ] = [ ]
22
23
var modelsTested : [ String ] = [ ]
24
+ var modelReposTested : [ String ] = [ ]
23
25
var optionsToTest : [ DecodingOptions ] = [ DecodingOptions ( ) ]
24
26
25
27
struct TestConfig {
26
28
let dataset : String
27
29
let modelComputeOptions : ModelComputeOptions
28
30
var model : String
31
+ var modelRepo : String
29
32
let decodingOptions : DecodingOptions
30
33
}
31
34
@@ -34,6 +37,7 @@ final class RegressionTests: XCTestCase {
34
37
var datasets = [ " librispeech-10mins " , " earnings22-10mins " ]
35
38
let debugDataset = [ " earnings22-10mins " ]
36
39
let debugModels = [ " tiny " ]
40
+ let debugRepos = [ " argmaxinc/whisperkit-coreml " ]
37
41
38
42
var computeOptions : [ ModelComputeOptions ] = [
39
43
ModelComputeOptions ( audioEncoderCompute: . cpuAndNeuralEngine, textDecoderCompute: . cpuAndNeuralEngine) ,
@@ -71,16 +75,29 @@ final class RegressionTests: XCTestCase {
71
75
Logging . debug ( " Max memory before warning: \( maxMemory) " )
72
76
}
73
77
74
- func testEnvConfigurations( defaultModels: [ String ] ? = nil ) {
78
+ class func getModelToken( ) -> String ? {
79
+ // Add token here or override
80
+ return nil
81
+ }
82
+
83
+ func testEnvConfigurations( defaultModels: [ String ] ? = nil , defaultRepos: [ String ] ? = nil ) {
75
84
if let modelSizeEnv = ProcessInfo . processInfo. environment [ " MODEL_NAME " ] , !modelSizeEnv. isEmpty {
76
85
modelsToTest = [ modelSizeEnv]
77
86
Logging . debug ( " Model size: \( modelSizeEnv) " )
87
+
88
+ if let repoEnv = ProcessInfo . processInfo. environment [ " MODEL_REPO " ] {
89
+ modelReposToTest = [ repoEnv]
90
+ Logging . debug ( " Using repo: \( repoEnv) " )
91
+ }
92
+
78
93
XCTAssertTrue ( modelsToTest. count > 0 , " Invalid model size: \( modelSizeEnv) " )
94
+
79
95
if modelSizeEnv == " crash_test " {
80
96
fatalError ( " Crash test triggered " )
81
97
}
82
98
} else {
83
99
modelsToTest = defaultModels ?? debugModels
100
+ modelReposToTest = defaultRepos ?? debugRepos
84
101
Logging . debug ( " Model size not set by env " )
85
102
}
86
103
}
@@ -116,7 +133,7 @@ final class RegressionTests: XCTestCase {
116
133
117
134
// MARK: - Test Pipeline
118
135
119
- private func runRegressionTests( with testMatrix: [ TestConfig ] ) async throws {
136
+ public func runRegressionTests( with testMatrix: [ TestConfig ] ) async throws {
120
137
var failureInfo : [ String : String ] = [ : ]
121
138
var attachments : [ String : String ] = [ : ]
122
139
let device = getCurrentDevice ( )
@@ -159,8 +176,7 @@ final class RegressionTests: XCTestCase {
159
176
160
177
// Create WhisperKit instance with checks for memory usage
161
178
let whisperKit = try await createWithMemoryCheck (
162
- model: config. model,
163
- computeOptions: config. modelComputeOptions,
179
+ testConfig: config,
164
180
verbose: true ,
165
181
logLevel: . debug
166
182
)
@@ -169,6 +185,8 @@ final class RegressionTests: XCTestCase {
169
185
config. model = modelFile
170
186
modelsTested. append ( modelFile)
171
187
modelsTested = Array ( Set ( modelsTested) )
188
+ modelReposTested. append ( config. modelRepo)
189
+ modelReposTested = Array ( Set ( modelReposTested) )
172
190
}
173
191
174
192
for audioFilePath in audioFilePaths {
@@ -295,6 +313,7 @@ final class RegressionTests: XCTestCase {
295
313
datasetDir: config. dataset,
296
314
datasetRepo: datasetRepo,
297
315
model: config. model,
316
+ modelRepo: config. modelRepo,
298
317
modelSizeMB: modelSizeMB ?? - 1 ,
299
318
date: startTime. formatted ( Date . ISO8601FormatStyle ( ) . dateSeparator ( . dash) ) ,
300
319
timeElapsedInSeconds: Date ( ) . timeIntervalSince ( startTime) ,
@@ -432,20 +451,23 @@ final class RegressionTests: XCTestCase {
432
451
}
433
452
}
434
453
435
- private func getTestMatrix( ) -> [ TestConfig ] {
454
+ public func getTestMatrix( ) -> [ TestConfig ] {
436
455
var regressionTestConfigMatrix : [ TestConfig ] = [ ]
437
456
for dataset in datasets {
438
457
for computeOption in computeOptions {
439
458
for options in optionsToTest {
440
- for model in modelsToTest {
441
- regressionTestConfigMatrix. append (
442
- TestConfig (
443
- dataset: dataset,
444
- modelComputeOptions: computeOption,
445
- model: model,
446
- decodingOptions: options
459
+ for repo in modelReposToTest {
460
+ for model in modelsToTest {
461
+ regressionTestConfigMatrix. append (
462
+ TestConfig (
463
+ dataset: dataset,
464
+ modelComputeOptions: computeOption,
465
+ model: model,
466
+ modelRepo: repo,
467
+ decodingOptions: options
468
+ )
447
469
)
448
- )
470
+ }
449
471
}
450
472
}
451
473
}
@@ -555,6 +577,7 @@ final class RegressionTests: XCTestCase {
555
577
osType: osDetails. osType,
556
578
osVersion: osDetails. osVersion,
557
579
modelsTested: modelsTested,
580
+ modelReposTested: modelReposTested,
558
581
failureInfo: failureInfo,
559
582
attachments: attachments
560
583
)
@@ -610,17 +633,14 @@ final class RegressionTests: XCTestCase {
610
633
return Double ( modelSize / ( 1024 * 1024 ) ) // Convert to MB
611
634
}
612
635
613
- func createWithMemoryCheck(
614
- model: String ,
615
- computeOptions: ModelComputeOptions ,
616
- verbose: Bool ,
617
- logLevel: Logging . LogLevel
618
- ) async throws -> WhisperKit {
636
+ public func initWhisperKitTask( testConfig config: TestConfig , verbose: Bool , logLevel: Logging . LogLevel ) -> Task < WhisperKit , Error > {
619
637
// Create the initialization task
620
638
let initializationTask = Task { ( ) -> WhisperKit in
621
639
let whisperKit = try await WhisperKit ( WhisperKitConfig (
622
- model: model,
623
- computeOptions: computeOptions,
640
+ model: config. model,
641
+ modelRepo: config. modelRepo,
642
+ modelToken: Self . getModelToken ( ) ,
643
+ computeOptions: config. modelComputeOptions,
624
644
verbose: verbose,
625
645
logLevel: logLevel,
626
646
prewarm: true ,
@@ -629,6 +649,20 @@ final class RegressionTests: XCTestCase {
629
649
try Task . checkCancellation ( )
630
650
return whisperKit
631
651
}
652
+ return initializationTask
653
+ }
654
+
655
+ func createWithMemoryCheck(
656
+ testConfig: TestConfig ,
657
+ verbose: Bool ,
658
+ logLevel: Logging . LogLevel
659
+ ) async throws -> WhisperKit {
660
+ // Create the initialization task
661
+ let initializationTask = initWhisperKitTask (
662
+ testConfig: testConfig,
663
+ verbose: verbose,
664
+ logLevel: logLevel
665
+ )
632
666
633
667
// Start the memory monitoring task
634
668
let monitorTask = Task {
0 commit comments