
Commit 03f0bb4

Add public callbacks to help expose internal state a little more (#240)
* SegmentDiscovery callback
* ModelState callback
* FractionCompleted callback
* TranscriptionPhaseCallback callback
* Updates for review
* Formatting
* Remove remaining callback from init

---------

Co-authored-by: ZachNagengast <[email protected]>
1 parent dd2eb73 commit 03f0bb4

File tree: 4 files changed, +113 −4 lines

- Sources/WhisperKit/Core/Models.swift (+41)
- Sources/WhisperKit/Core/TranscribeTask.swift (+4)
- Sources/WhisperKit/Core/WhisperKit.swift (+31 −4)
- Tests/WhisperKitTests/UnitTests.swift (+37)

Sources/WhisperKit/Core/Models.swift

```diff
@@ -629,6 +629,47 @@ public struct TranscriptionProgress {
     }
 }
 
+// Callbacks to receive state updates during transcription.
+
+/// A callback that provides transcription segments as they are discovered.
+/// - Parameters:
+///   - segments: An array of `TranscriptionSegment` objects representing the transcribed segments
+public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void
+
+/// A callback that reports changes in the model's state.
+/// - Parameters:
+///   - oldState: The previous state of the model, if any
+///   - newState: The current state of the model
+public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void
+
+/// A callback that reports changes in the transcription process.
+/// - Parameter state: The current `TranscriptionState` of the transcription process
+public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void
+
+/// Represents the different states of the transcription process.
+public enum TranscriptionState: CustomStringConvertible {
+    /// The audio is being converted to the required format for transcription
+    case convertingAudio
+
+    /// The audio is actively being transcribed to text
+    case transcribing
+
+    /// The transcription process has completed
+    case finished
+
+    /// A human-readable description of the transcription state
+    public var description: String {
+        switch self {
+        case .convertingAudio:
+            return "Converting Audio"
+        case .transcribing:
+            return "Transcribing"
+        case .finished:
+            return "Finished"
+        }
+    }
+}
+
 /// Callback to receive progress updates during transcription.
 ///
 /// - Parameters:
```
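For orientation before the remaining files, here is a minimal sketch of how a client might wire up these new typealiases. The `pipe` name and the closure bodies are illustrative assumptions; the callback properties themselves are the ones added to `WhisperKit.swift` later in this diff.

```swift
import WhisperKit

// Illustrative sketch: assign the three new callbacks on a WhisperKit instance.
let pipe = try await WhisperKit()

pipe.segmentDiscoveryCallback = { segments in
    // Invoked as each batch of segments is decoded, before the final result.
    for segment in segments {
        print("Discovered segment: \(segment.text)")
    }
}

pipe.modelStateCallback = { oldState, newState in
    // Invoked on every modelState transition via the didSet observer added below.
    print("Model state: \(String(describing: oldState)) -> \(newState)")
}

pipe.transcriptionStateCallback = { state in
    // Prints "Converting Audio", "Transcribing", or "Finished" via the
    // CustomStringConvertible conformance defined above.
    print("Transcription state: \(state)")
}
```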

Sources/WhisperKit/Core/TranscribeTask.swift

```diff
@@ -15,6 +15,8 @@ final class TranscribeTask {
     private let textDecoder: any TextDecoding
     private let tokenizer: any WhisperTokenizer
 
+    public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
+
     init(
         currentTimings: TranscriptionTimings,
         progress: Progress?,
@@ -230,6 +232,8 @@ final class TranscribeTask {
            }
        }
 
+        segmentDiscoveryCallback?(currentSegments)
+
        // add them to the `allSegments` list
        allSegments.append(contentsOf: currentSegments)
        let allCurrentTokens = currentSegments.flatMap { $0.tokens }
```
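Since `TranscribeTask` now surfaces segments as they are decoded, a consumer can build a running transcript instead of waiting for the final `TranscriptionResult`. A hypothetical sketch, reusing the illustrative `pipe` instance from above; the accumulation logic is not part of this commit:

```swift
// Illustrative only: accumulate discovered segments into a live transcript.
// The callback may fire off the main thread, so a real app should hop to
// the main actor before updating UI state.
var liveTranscript = ""

pipe.segmentDiscoveryCallback = { segments in
    liveTranscript += segments.map(\.text).joined()
    print("Live transcript so far: \(liveTranscript)")
}
```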

Sources/WhisperKit/Core/WhisperKit.swift

```diff
@@ -13,7 +13,12 @@ import Tokenizers
 open class WhisperKit {
     /// Models
     public private(set) var modelVariant: ModelVariant = .tiny
-    public private(set) var modelState: ModelState = .unloaded
+    public private(set) var modelState: ModelState = .unloaded {
+        didSet {
+            modelStateCallback?(oldValue, modelState)
+        }
+    }
+
     public var modelCompute: ModelComputeOptions
     public var tokenizer: WhisperTokenizer?
 
@@ -42,6 +47,11 @@ open class WhisperKit {
     public var tokenizerFolder: URL?
     public private(set) var useBackgroundDownloadSession: Bool
 
+    /// Callbacks
+    public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
+    public var modelStateCallback: ModelStateCallback?
+    public var transcriptionStateCallback: TranscriptionStateCallback?
+
     public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
         modelCompute = config.computeOptions ?? ModelComputeOptions()
         audioProcessor = config.audioProcessor ?? AudioProcessor()
@@ -365,7 +375,7 @@ open class WhisperKit {
         } else {
             currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart
         }
-        
+
         Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
     }
 
@@ -378,13 +388,13 @@ open class WhisperKit {
             computeUnits: modelCompute.audioEncoderCompute,
             prewarmMode: prewarmMode
         )
-        
+
         if prewarmMode {
             currentTimings.encoderSpecializationTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
         } else {
             currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
         }
-        
+
         Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
     }
 
@@ -549,6 +559,8 @@ open class WhisperKit {
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil
     ) async -> [Result<[TranscriptionResult], Swift.Error>] {
+        transcriptionStateCallback?(.convertingAudio)
+
         // Start timing the audio loading and conversion process
         let loadAudioStart = Date()
 
@@ -561,6 +573,11 @@ open class WhisperKit {
         currentTimings.audioLoading = loadAndConvertTime
         Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)")
 
+        transcriptionStateCallback?(.transcribing)
+        defer {
+            transcriptionStateCallback?(.finished)
+        }
+
         // Transcribe the loaded audio arrays
         let transcribeResults = await transcribeWithResults(
             audioArrays: audioArrays,
@@ -733,6 +750,8 @@ open class WhisperKit {
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
+        transcriptionStateCallback?(.convertingAudio)
+
         // Process input audio file into audio samples
         let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
             let convertAudioStart = Date()
@@ -746,6 +765,12 @@ open class WhisperKit {
             return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
         }
 
+        transcriptionStateCallback?(.transcribing)
+        defer {
+            transcriptionStateCallback?(.finished)
+        }
+
         // Send converted samples to be transcribed
         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,
             decodeOptions: decodeOptions,
@@ -872,6 +897,8 @@ open class WhisperKit {
             tokenizer: tokenizer
         )
 
+        transcribeTask.segmentDiscoveryCallback = self.segmentDiscoveryCallback
+
         let transcribeTaskResult = try await transcribeTask.run(
             audioArray: audioArray,
             decodeOptions: decodeOptions,
```
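Two design points worth noting in the hunks above: the `didSet` observer reports every `modelState` transition without touching any call sites, and the `defer` blocks guarantee `.finished` is delivered even if the throwing `transcribe(audioPath:)` path fails partway through. A hypothetical sketch of driving status text from the state sequence (`pipe` as before):

```swift
pipe.transcriptionStateCallback = { state in
    // States arrive in order: .convertingAudio -> .transcribing -> .finished.
    switch state {
    case .convertingAudio:
        print("Preparing audio…")
    case .transcribing:
        print("Transcribing…")
    case .finished:
        print("Done")
    }
}
```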

Tests/WhisperKitTests/UnitTests.swift

```diff
@@ -1067,6 +1067,43 @@ final class UnitTests: XCTestCase {
         XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country.")
     }
 
+    func testCallbacks() async throws {
+        let config = try WhisperKitConfig(
+            modelFolder: tinyModelPath(),
+            verbose: true,
+            logLevel: .debug,
+            load: false
+        )
+        let whisperKit = try await WhisperKit(config)
+        let modelStateExpectation = XCTestExpectation(description: "Model state callback expectation")
+        whisperKit.modelStateCallback = { (oldState: ModelState?, newState: ModelState) in
+            Logging.debug("Model state: \(newState)")
+            modelStateExpectation.fulfill()
+        }
+
+        let segmentDiscoveryExpectation = XCTestExpectation(description: "Segment discovery callback expectation")
+        whisperKit.segmentDiscoveryCallback = { (segments: [TranscriptionSegment]) in
+            Logging.debug("Segments discovered: \(segments)")
+            segmentDiscoveryExpectation.fulfill()
+        }
+
+        let transcriptionStateExpectation = XCTestExpectation(description: "Transcription state callback expectation")
+        whisperKit.transcriptionStateCallback = { (state: TranscriptionState) in
+            Logging.debug("Transcription state: \(state)")
+            transcriptionStateExpectation.fulfill()
+        }
+
+        // Run the full pipeline
+        try await whisperKit.loadModels()
+        let audioFilePath = try XCTUnwrap(
+            Bundle.current.path(forResource: "jfk", ofType: "wav"),
+            "Audio file not found"
+        )
+        let _ = try await whisperKit.transcribe(audioPath: audioFilePath)
+
+        await fulfillment(of: [modelStateExpectation, segmentDiscoveryExpectation, transcriptionStateExpectation], timeout: 1)
+    }
+
     // MARK: - Utils Tests
 
     func testFillIndexesWithValue() throws {
```
