diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift
index 5ca6995f..325e0622 100644
--- a/Sources/WhisperKit/Core/Models.swift
+++ b/Sources/WhisperKit/Core/Models.swift
@@ -629,6 +629,47 @@ public struct TranscriptionProgress {
     }
 }
 
+// Callbacks to receive state updates during transcription.
+
+/// A callback that provides transcription segments as they are discovered.
+/// - Parameters:
+///   - segments: An array of `TranscriptionSegment` objects representing the transcribed segments
+public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void
+
+/// A callback that reports changes in the model's state.
+/// - Parameters:
+///   - oldState: The previous state of the model, if any
+///   - newState: The current state of the model
+public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void
+
+/// A callback that reports changes in the transcription process.
+/// - Parameter state: The current `TranscriptionState` of the transcription process
+public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void
+
+/// Represents the different states of the transcription process.
+public enum TranscriptionState: CustomStringConvertible {
+    /// The audio is being converted to the required format for transcription
+    case convertingAudio
+
+    /// The audio is actively being transcribed to text
+    case transcribing
+
+    /// The transcription process has completed
+    case finished
+
+    /// A human-readable description of the transcription state
+    public var description: String {
+        switch self {
+            case .convertingAudio:
+                return "Converting Audio"
+            case .transcribing:
+                return "Transcribing"
+            case .finished:
+                return "Finished"
+        }
+    }
+}
+
 /// Callback to receive progress updates during transcription.
 ///
 /// - Parameters:
diff --git a/Sources/WhisperKit/Core/TranscribeTask.swift b/Sources/WhisperKit/Core/TranscribeTask.swift
index a6939e7a..0ec031a3 100644
--- a/Sources/WhisperKit/Core/TranscribeTask.swift
+++ b/Sources/WhisperKit/Core/TranscribeTask.swift
@@ -15,6 +15,8 @@ final class TranscribeTask {
     private let textDecoder: any TextDecoding
     private let tokenizer: any WhisperTokenizer
 
+    public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
+
     init(
         currentTimings: TranscriptionTimings,
         progress: Progress?,
@@ -230,6 +232,8 @@ final class TranscribeTask {
                 }
             }
 
+            segmentDiscoveryCallback?(currentSegments)
+
             // add them to the `allSegments` list
             allSegments.append(contentsOf: currentSegments)
             let allCurrentTokens = currentSegments.flatMap { $0.tokens }
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 88a665fc..5f617e6a 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -13,7 +13,12 @@ import Tokenizers
 open class WhisperKit {
     /// Models
     public private(set) var modelVariant: ModelVariant = .tiny
-    public private(set) var modelState: ModelState = .unloaded
+    public private(set) var modelState: ModelState = .unloaded {
+        didSet {
+            modelStateCallback?(oldValue, modelState)
+        }
+    }
+
     public var modelCompute: ModelComputeOptions
     public var tokenizer: WhisperTokenizer?
 
@@ -42,6 +47,11 @@ open class WhisperKit {
     public var tokenizerFolder: URL?
     public private(set) var useBackgroundDownloadSession: Bool
 
+    /// Callbacks
+    public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
+    public var modelStateCallback: ModelStateCallback?
+    public var transcriptionStateCallback: TranscriptionStateCallback?
+
     public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
         modelCompute = config.computeOptions ?? ModelComputeOptions()
         audioProcessor = config.audioProcessor ?? AudioProcessor()
@@ -365,7 +375,7 @@ open class WhisperKit {
         } else {
             currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart
         }
-        
+
         Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
     }
 
@@ -378,13 +388,13 @@ open class WhisperKit {
             computeUnits: modelCompute.audioEncoderCompute,
             prewarmMode: prewarmMode
         )
-        
+
         if prewarmMode {
             currentTimings.encoderSpecializationTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
         } else {
             currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
         }
-        
+
         Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
     }
 
@@ -549,6 +559,8 @@ open class WhisperKit {
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil
     ) async -> [Result<[TranscriptionResult], Swift.Error>] {
+        transcriptionStateCallback?(.convertingAudio)
+
         // Start timing the audio loading and conversion process
         let loadAudioStart = Date()
 
@@ -561,6 +573,11 @@ open class WhisperKit {
         currentTimings.audioLoading = loadAndConvertTime
         Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)")
 
+        transcriptionStateCallback?(.transcribing)
+        defer {
+            transcriptionStateCallback?(.finished)
+        }
+
         // Transcribe the loaded audio arrays
         let transcribeResults = await transcribeWithResults(
             audioArrays: audioArrays,
@@ -733,6 +750,8 @@ open class WhisperKit {
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
+        transcriptionStateCallback?(.convertingAudio)
+
         // Process input audio file into audio samples
         let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
             let convertAudioStart = Date()
@@ -746,6 +765,12 @@ open class WhisperKit {
             return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
         }
 
+        transcriptionStateCallback?(.transcribing)
+        defer {
+            transcriptionStateCallback?(.finished)
+        }
+
+        // Send converted samples to be transcribed
         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,
             decodeOptions: decodeOptions,
@@ -872,6 +897,8 @@ open class WhisperKit {
             tokenizer: tokenizer
         )
 
+        transcribeTask.segmentDiscoveryCallback = self.segmentDiscoveryCallback
+
        let transcribeTaskResult = try await transcribeTask.run(
             audioArray: audioArray,
             decodeOptions: decodeOptions,
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index df258a46..0ec1b2d0 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -1067,6 +1067,43 @@ final class UnitTests: XCTestCase {
         XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country.")
     }
 
+    func testCallbacks() async throws {
+        let config = try WhisperKitConfig(
+            modelFolder: tinyModelPath(),
+            verbose: true,
+            logLevel: .debug,
+            load: false
+        )
+        let whisperKit = try await WhisperKit(config)
+        let modelStateExpectation = XCTestExpectation(description: "Model state callback expectation")
+        whisperKit.modelStateCallback = { (oldState: ModelState?, newState: ModelState) in
+            Logging.debug("Model state: \(newState)")
+            modelStateExpectation.fulfill()
+        }
+
+        let segmentDiscoveryExpectation = XCTestExpectation(description: "Segment discovery callback expectation")
+        whisperKit.segmentDiscoveryCallback = { (segments: [TranscriptionSegment]) in
+            Logging.debug("Segments discovered: \(segments)")
+            segmentDiscoveryExpectation.fulfill()
+        }
+
+        let transcriptionStateExpectation = XCTestExpectation(description: "Transcription state callback expectation")
+        whisperKit.transcriptionStateCallback = { (state: TranscriptionState) in
+            Logging.debug("Transcription state: \(state)")
+            transcriptionStateExpectation.fulfill()
+        }
+
+        // Run the full pipeline
+        try await whisperKit.loadModels()
+        let audioFilePath = try XCTUnwrap(
+            Bundle.current.path(forResource: "jfk", ofType: "wav"),
+            "Audio file not found"
+        )
+        let _ = try await whisperKit.transcribe(audioPath: audioFilePath)
+
+        await fulfillment(of: [modelStateExpectation, segmentDiscoveryExpectation, transcriptionStateExpectation], timeout: 1)
+    }
+
     // MARK: - Utils Tests
 
     func testFillIndexesWithValue() throws {
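
A minimal usage sketch of the three callbacks introduced in this patch, wired up the same way the new testCallbacks test does. The function name, the default WhisperKitConfig(), the audio path, and the print statements below are illustrative placeholders, not values taken from the diff:

import WhisperKit

func transcribeWithCallbacks() async throws {
    // Placeholder configuration; any WhisperKitConfig works here.
    let whisperKit = try await WhisperKit(WhisperKitConfig())

    // Fires whenever modelState changes, via the new didSet observer on WhisperKit.modelState.
    whisperKit.modelStateCallback = { oldState, newState in
        print("Model state: \(String(describing: oldState)) -> \(newState)")
    }

    // Reports the coarse pipeline phase: convertingAudio -> transcribing -> finished.
    whisperKit.transcriptionStateCallback = { state in
        print("Transcription state: \(state)")
    }

    // Receives segments as each decoding window completes, before the full result is ready.
    whisperKit.segmentDiscoveryCallback = { segments in
        segments.forEach { print("Discovered segment: \($0.text)") }
    }

    // Load models explicitly so the model state callback fires, mirroring the unit test.
    try await whisperKit.loadModels()

    let results = try await whisperKit.transcribe(audioPath: "path/to/audio.wav")
    let fullText = results.flatMap { $0.segments }.map { $0.text }.joined()
    print(fullText)
}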