@@ -5,19 +5,20 @@ import CoreML
 import Foundation
 
 /// Responsible for transcribing an audio chunk to text using the provided models and configurations.
-final class TranscribeTask {
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+open class TranscribeTask {
     private var timings: TranscriptionTimings
     private let progress: Progress
     private let audioEncoder: any AudioEncoding
     private let featureExtractor: any FeatureExtracting
     private let segmentSeeker: any SegmentSeeking
     private let textDecoder: any TextDecoding
-    private let tokenizer: any WhisperTokenizer
     private let audioProcessor: any AudioProcessing
 
+    public private(set) var tokenizer: any WhisperTokenizer
     public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
 
-    init(
+    public init(
         currentTimings: TranscriptionTimings,
         progress: Progress?,
         audioProcessor: (any AudioProcessing)? = nil,
@@ -37,7 +38,23 @@ final class TranscribeTask {
         self.tokenizer = tokenizer
     }
 
-    func run(
+    /// Hook for subclasses to launch work that can run alongside the main decoder pipeline.
+    open func windowPreprocess(
+        for paddedAudio: any AudioProcessorOutputType,
+        seek: Int,
+        segmentSize: Int
+    ) async {}
+
+    /// Hook for subclasses to finalize side work and optionally replace the segments for the current window.
+    open func windowPostProcess(
+        seek: Int,
+        segmentSize: Int,
+        originalSegments: [TranscriptionSegment]
+    ) async -> [TranscriptionSegment] {
+        originalSegments
+    }
+
+    public func run(
         audioArray: [Float],
         decodeOptions: DecodingOptions? = nil,
         callback: TranscriptionCallback = nil
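These two hooks are designed as a pair: windowPreprocess fires once per audio window, right after padding, so a subclass can launch side work that overlaps the encoder and decoder passes, while windowPostProcess joins that work and may swap in different segments before they are committed. A minimal sketch of such a subclass, using only the types visible in this diff; SpeechActivityScorer and the 0.5 threshold are illustrative assumptions, not part of the change:

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class GatedTranscribeTask: TranscribeTask {
    // Side work per window, keyed by the window's seek offset.
    private var pendingScores: [Int: Task<Float, Never>] = [:]

    override func windowPreprocess(
        for paddedAudio: any AudioProcessorOutputType,
        seek: Int,
        segmentSize: Int
    ) async {
        // Launch scoring for this window; the decode loop continues meanwhile.
        pendingScores[seek] = Task {
            SpeechActivityScorer.score(paddedAudio, frames: segmentSize) // hypothetical helper
        }
    }

    override func windowPostProcess(
        seek: Int,
        segmentSize: Int,
        originalSegments: [TranscriptionSegment]
    ) async -> [TranscriptionSegment] {
        // Join the score started in windowPreprocess for the same window.
        guard let score = await pendingScores.removeValue(forKey: seek)?.value else {
            return originalSegments
        }
        // Returning a different array here replaces the window's segments.
        return score < 0.5 ? [] : originalSegments
    }
}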
@@ -61,7 +78,6 @@ final class TranscribeTask {
         // These accumulate across windows
         var allSegments: [TranscriptionSegment] = []
         var allTokens: [Int] = []
-        var transcription = ""
 
         let startDecoderInit = CFAbsoluteTimeGetCurrent()
         var decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
@@ -107,6 +123,7 @@ final class TranscribeTask {
 
         let windowSamples = featureExtractor.windowSamples ?? Constants.defaultWindowSamples
         while seek < seekClipEnd - windowPadding {
+            let windowSeek = seek
             // calculate new encoder segment features
             let timeOffset = Float(seek) / Float(WhisperKit.sampleRate)
             let segmentSize = min(windowSamples, contentFrames - seek, seekClipEnd - seek)
@@ -119,6 +136,7 @@ final class TranscribeTask {
             guard let audioSamples = audioProcessor.padOrTrim(fromArray: clipAudioSamples, startAt: 0, toLength: windowSamples) else {
                 throw WhisperError.transcriptionFailed("Audio samples are nil")
             }
+            await windowPreprocess(for: audioSamples, seek: windowSeek, segmentSize: segmentSize)
             let processTime = Date().timeIntervalSince(audioProcessingStart)
             timings.audioProcessing += processTime
             timings.totalAudioProcessingRuns += 1
@@ -222,24 +240,30 @@ final class TranscribeTask {
                 seek = min(seek, maxSeekOffset)
             }
 
-            guard let currentSegments = currentSegments else {
+            guard let currentSegments else {
                 // No current segment found, skip to next window
                 continue
             }
 
+            let processedSegments = await windowPostProcess(
+                seek: windowSeek,
+                segmentSize: segmentSize,
+                originalSegments: currentSegments
+            )
+
             if options.verbose {
-                let lines = TranscriptionUtilities.formatSegments(currentSegments)
+                let lines = TranscriptionUtilities.formatSegments(processedSegments)
                 Logging.debug("Segments for window:")
                 for line in lines {
                     Logging.debug(line)
                 }
             }
 
-            segmentDiscoveryCallback?(currentSegments)
+            segmentDiscoveryCallback?(processedSegments)
 
             // add them to the `allSegments` list
-            allSegments.append(contentsOf: currentSegments)
-            let allCurrentTokens = currentSegments.flatMap { $0.tokens }
+            allSegments.append(contentsOf: processedSegments)
+            let allCurrentTokens = processedSegments.flatMap { $0.tokens }
            allTokens.append(contentsOf: allCurrentTokens)
 
             timings.decodingWindowing += Date().timeIntervalSince(windowingStart)
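Note that everything downstream of the guard now consumes processedSegments: verbose logging, the discovery callback, and the allSegments/allTokens accumulators. Whatever windowPostProcess returns therefore fully replaces the decoder's output for that window. Even without preprocess side work, a standalone override can prune segments; a sketch relying only on API shown in this diff:

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class PrunedTranscribeTask: TranscribeTask {
    override func windowPostProcess(
        seek: Int,
        segmentSize: Int,
        originalSegments: [TranscriptionSegment]
    ) async -> [TranscriptionSegment] {
        // Drop segments whose tokens are all special tokens (no actual words).
        originalSegments.filter { segment in
            segment.tokens.contains { $0 < tokenizer.specialTokens.specialTokenBegin }
        }
    }
}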
@@ -364,8 +388,23 @@ final class TranscribeTask {
         timings.decodingLoop = CFAbsoluteTimeGetCurrent() - startDecodeLoopTime
         timings.fullPipeline = CFAbsoluteTimeGetCurrent() - timings.pipelineStart
 
-        let wordTokens = allTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
-        transcription = tokenizer.decode(tokens: wordTokens).trimmingCharacters(in: .whitespaces)
+        let transcriptionResult = finalizeTranscriptionResult(
+            tokens: allTokens,
+            segments: allSegments,
+            language: detectedLanguage,
+            timings: timings
+        )
+        return transcriptionResult
+    }
+
+    open func finalizeTranscriptionResult(
+        tokens: [Int],
+        segments allSegments: [TranscriptionSegment],
+        language detectedLanguage: String?,
+        timings: TranscriptionTimings
+    ) -> TranscriptionResult {
+        let wordTokens = tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
+        let transcription = tokenizer.decode(tokens: wordTokens).trimmingCharacters(in: .whitespaces)
         return TranscriptionResult(
             text: transcription,
             segments: allSegments,
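Since the final assembly now lives in an open method rather than inline in run, a subclass can adjust the result while reusing the base token filtering and decoding via super. A sketch under the same assumptions; the blocklist is a hypothetical placeholder:

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class FilteringTranscribeTask: TranscribeTask {
    // Hypothetical token ids to strip from the final text; fill in per model.
    static let blockedTokenIDs: Set<Int> = []

    override func finalizeTranscriptionResult(
        tokens: [Int],
        segments allSegments: [TranscriptionSegment],
        language detectedLanguage: String?,
        timings: TranscriptionTimings
    ) -> TranscriptionResult {
        // Pre-filter tokens, then let the base implementation decode and
        // assemble the TranscriptionResult as usual.
        let kept = tokens.filter { !Self.blockedTokenIDs.contains($0) }
        return super.finalizeTranscriptionResult(
            tokens: kept,
            segments: allSegments,
            language: detectedLanguage,
            timings: timings
        )
    }
}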