diff --git a/Makefile b/Makefile
index d575375..f9165a6 100644
--- a/Makefile
+++ b/Makefile
@@ -131,6 +131,6 @@ upload-benchmark-results:
 	@fastlane upload_results
 
 clean-package-caches:
-	@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit*
+	@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit* || true
 	@swift package purge-cache
 	@swift package reset
\ No newline at end of file
diff --git a/Package.swift b/Package.swift
index 8bbea16..315a27b 100644
--- a/Package.swift
+++ b/Package.swift
@@ -20,8 +20,8 @@ let package = Package(
         ),
     ],
     dependencies: [
-        .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
-        .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
+        .package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "0.1.8")),
+        .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
     ],
     targets: [
         .target(
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index 052ba20..bde93a2 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -16,6 +16,11 @@ public typealias DeviceID = String
 public struct AudioDevice: Identifiable, Hashable {
     public let id: DeviceID
     public let name: String
+
+    public init(id: DeviceID, name: String) {
+        self.id = id
+        self.name = name
+    }
 }
 
 public protocol AudioProcessing {
diff --git a/Sources/WhisperKit/Core/FeatureExtractor.swift b/Sources/WhisperKit/Core/FeatureExtractor.swift
index 0fb0f68..3025e1f 100644
--- a/Sources/WhisperKit/Core/FeatureExtractor.swift
+++ b/Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -37,7 +37,7 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         guard inputDescription.type == .multiArray else { return nil }
         guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
         let shape = shapeConstraint.shape.map { $0.intValue }
-        return shape[0] // The audio input is a 1D array
+        return shape[0] // The audio input is a 1D array
     }
 
     public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
@@ -54,5 +54,4 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
         let output = MelSpectrogramOutput(features: outputFeatures)
         return output.melspectrogramFeatures
     }
-
 }
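A note on the Package.swift change above: replacing `exact:` pins with range-based requirements lets SwiftPM pick up compatible updates without a manifest change. A minimal sketch of the ranges these two specifiers admit, under standard SwiftPM semver rules (not part of the patch itself):

```swift
// .upToNextMinor(from: "0.1.8") resolves any version in 0.1.8..<0.2.0,
// so swift-transformers patch releases are picked up automatically,
// but a potentially breaking 0.2.0 is not.
.package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "0.1.8")),

// from: "1.3.0" is shorthand for .upToNextMajor(from: "1.3.0"),
// resolving any version in 1.3.0..<2.0.0.
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
```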
diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift
index bf7d7eb..9f2fef4 100644
--- a/Sources/WhisperKit/Core/Models.swift
+++ b/Sources/WhisperKit/Core/Models.swift
@@ -176,11 +176,26 @@ public struct ModelSupport: Codable, Equatable {
 
     private enum CodingKeys: String, CodingKey {
         case `default`, supported
     }
+
+    public init(
+        default: String,
+        supported: [String],
+        disabled: [String] = []
+    ) {
+        self.default = `default`
+        self.supported = supported
+        self.disabled = disabled
+    }
 }
 
 public struct DeviceSupport: Codable {
     public let identifiers: [String]
     public var models: ModelSupport
+
+    public init(identifiers: [String], models: ModelSupport) {
+        self.identifiers = identifiers
+        self.models = models
+    }
 }
@@ -280,6 +295,11 @@ public struct ModelSupportConfig: Codable {
 public struct AudioChunk {
     public var seekOffsetIndex: Int
     public var audioSamples: [Float]
+
+    public init(seekOffsetIndex: Int, audioSamples: [Float]) {
+        self.seekOffsetIndex = seekOffsetIndex
+        self.audioSamples = audioSamples
+    }
 }
 
 // MARK: - Decoding
@@ -351,7 +371,12 @@ public struct DecodingCache {
     public var keyCache: MLMultiArray?
     public var valueCache: MLMultiArray?
     public var alignmentWeights: MLMultiArray?
-    public init(keyCache: MLMultiArray? = nil, valueCache: MLMultiArray? = nil, alignmentWeights: MLMultiArray? = nil) {
+
+    public init(
+        keyCache: MLMultiArray? = nil,
+        valueCache: MLMultiArray? = nil,
+        alignmentWeights: MLMultiArray? = nil
+    ) {
         self.keyCache = keyCache
         self.valueCache = valueCache
         self.alignmentWeights = alignmentWeights
@@ -432,7 +457,20 @@ public struct DecodingResult {
             fallback: nil)
     }
 
-    public init(language: String, languageProbs: [String: Float], tokens: [Int], tokenLogProbs: [[Int: Float]], text: String, avgLogProb: Float, noSpeechProb: Float, temperature: Float, compressionRatio: Float, cache: DecodingCache? = nil, timings: TranscriptionTimings? = nil, fallback: DecodingFallback? = nil) {
+    public init(
+        language: String,
+        languageProbs: [String: Float],
+        tokens: [Int],
+        tokenLogProbs: [[Int: Float]],
+        text: String,
+        avgLogProb: Float,
+        noSpeechProb: Float,
+        temperature: Float,
+        compressionRatio: Float,
+        cache: DecodingCache? = nil,
+        timings: TranscriptionTimings? = nil,
+        fallback: DecodingFallback? = nil
+    ) {
         self.language = language
         self.languageProbs = languageProbs
         self.tokens = tokens
@@ -510,6 +548,20 @@ public struct TranscriptionResult: Codable {
     public var timings: TranscriptionTimings
     public var seekTime: Float?
 
+    public init(
+        text: String,
+        segments: [TranscriptionSegment],
+        language: String,
+        timings: TranscriptionTimings,
+        seekTime: Float? = nil
+    ) {
+        self.text = text
+        self.segments = segments
+        self.language = language
+        self.timings = timings
+        self.seekTime = seekTime
+    }
+
     public func logSegments() {
         for (i, segment) in segments.enumerated() {
             let start = segment.start
@@ -593,18 +645,51 @@ public extension TranscriptionResult {
 }
 
 public struct TranscriptionSegment: Hashable, Codable {
-    public var id: Int = 0
-    public var seek: Int = 0
-    public var start: Float = 0.0
-    public var end: Float = 0.0
-    public var text: String = ""
-    public var tokens: [Int] = []
-    public var tokenLogProbs: [[Int: Float]] = [[:]]
-    public var temperature: Float = 1.0
-    public var avgLogprob: Float = 0.0
-    public var compressionRatio: Float = 1.0
-    public var noSpeechProb: Float = 0.0
-    public var words: [WordTiming]? = nil
+    public var id: Int
+    public var seek: Int
+    public var start: Float
+    public var end: Float
+    public var text: String
+    public var tokens: [Int]
+    public var tokenLogProbs: [[Int: Float]]
+    public var temperature: Float
+    public var avgLogprob: Float
+    public var compressionRatio: Float
+    public var noSpeechProb: Float
+    public var words: [WordTiming]?
+
+    /// Computed property for the duration of the segment
+    public var duration: Float {
+        return end - start
+    }
+
+    public init(
+        id: Int = 0,
+        seek: Int = 0,
+        start: Float = 0.0,
+        end: Float = 0.0,
+        text: String = "",
+        tokens: [Int] = [],
+        tokenLogProbs: [[Int: Float]] = [[:]],
+        temperature: Float = 1.0,
+        avgLogprob: Float = 0.0,
+        compressionRatio: Float = 1.0,
+        noSpeechProb: Float = 0.0,
+        words: [WordTiming]? = nil
+    ) {
+        self.id = id
+        self.seek = seek
+        self.start = start
+        self.end = end
+        self.text = text
+        self.tokens = tokens
+        self.tokenLogProbs = tokenLogProbs
+        self.temperature = temperature
+        self.avgLogprob = avgLogprob
+        self.compressionRatio = compressionRatio
+        self.noSpeechProb = noSpeechProb
+        self.words = words
+    }
 }
@@ -613,6 +698,19 @@ public struct WordTiming: Hashable, Codable {
     public var start: Float
     public var end: Float
     public var probability: Float
+
+    /// Computed property for the duration of the word
+    public var duration: Float {
+        return end - start
+    }
+
+    public init(word: String, tokens: [Int], start: Float, end: Float, probability: Float) {
+        self.word = word
+        self.tokens = tokens
+        self.start = start
+        self.end = end
+        self.probability = probability
+    }
 }
 
 public struct TranscriptionProgress {
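The memberwise initializers added throughout Models.swift make these result types constructible from outside the module, and the new `duration` computed properties replace ad-hoc `end - start` arithmetic at call sites. A sketch of how a consumer might now build word-level results (the values are illustrative only):

```swift
// Illustrative values only; real timings come from the alignment pass.
let word = WordTiming(word: " hello", tokens: [2425], start: 0.5, end: 1.1, probability: 0.92)

let segment = TranscriptionSegment(
    id: 0,
    start: word.start,
    end: word.end,
    text: "hello",
    tokens: word.tokens,
    words: [word]
)

let wordDuration = word.duration       // 0.6, via the new computed property
let segmentDuration = segment.duration // 0.6
```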
@@ -1198,17 +1296,40 @@ public struct SpecialTokens {
     }
 }
 
-public protocol WhisperTokenizer: Tokenizer {
+public protocol WhisperTokenizer {
+    /// swift-transformers pass through
+    func encode(text: String) -> [Int]
+    func decode(tokens: [Int]) -> String
+    func convertTokenToId(_ token: String) -> Int?
+    func convertIdToToken(_ id: Int) -> String?
+
+    /// WhisperKit specific
     var specialTokens: SpecialTokens { get }
     var allLanguageTokens: Set<Int> { get }
 
     func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]])
 }
 
-struct WhisperTokenizerWrapper: WhisperTokenizer {
+open class WhisperTokenizerWrapper: WhisperTokenizer {
     let tokenizer: any Tokenizer
-    let specialTokens: SpecialTokens
-    let allLanguageTokens: Set<Int>
+    public let specialTokens: SpecialTokens
+    public let allLanguageTokens: Set<Int>
+
+    public func encode(text: String) -> [Int] {
+        tokenizer.encode(text: text)
+    }
+
+    public func decode(tokens: [Int]) -> String {
+        tokenizer.decode(tokens: tokens)
+    }
+
+    public func convertTokenToId(_ token: String) -> Int? {
+        tokenizer.convertTokenToId(token)
+    }
+
+    public func convertIdToToken(_ id: Int) -> String? {
+        tokenizer.convertIdToToken(id)
+    }
 
     init(tokenizer: any Tokenizer) {
         let specialTokens = SpecialTokens(
@@ -1300,7 +1421,7 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
     /// Decodes token ids into individual words and per-word subtokens
     /// - Parameter tokenIds: Array of tokens to decode and then split
     /// - Returns: Tuple containing and array of the split words and all tokens for each word
-    func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
+    public func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
         let decodedWords = tokenizer.decode(tokens: tokenIds.filter { $0 < specialTokens.specialTokenBegin })
 
         // Detect language of input text
@@ -1316,52 +1437,6 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
     }
 }
 
-extension WhisperTokenizerWrapper: Tokenizer {
-    func tokenize(text: String) -> [String] {
-        tokenizer.tokenize(text: text)
-    }
-
-    func encode(text: String) -> [Int] {
-        tokenizer.encode(text: text)
-    }
-
-    func decode(tokens: [Int]) -> String {
-        tokenizer.decode(tokens: tokens)
-    }
-
-    func convertTokenToId(_ token: String) -> Int? {
-        tokenizer.convertTokenToId(token)
-    }
-
-    func convertIdToToken(_ id: Int) -> String? {
-        tokenizer.convertIdToToken(id)
-    }
-
-    var bosToken: String? {
-        tokenizer.bosToken
-    }
-
-    var bosTokenId: Int? {
-        tokenizer.bosTokenId
-    }
-
-    var eosToken: String? {
-        tokenizer.eosToken
-    }
-
-    var eosTokenId: Int? {
-        tokenizer.eosTokenId
-    }
-
-    var unknownToken: String? {
-        tokenizer.unknownToken
-    }
-
-    var unknownTokenId: Int? {
-        tokenizer.unknownTokenId
-    }
-}
-
 extension WhisperTokenizerWrapper {
     /// Default values for each token, using base vocab
     static var defaultWhitespaceToken: Int { 220 }
@@ -1512,6 +1587,9 @@ public enum Constants {
 
     public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models
 
+    public static let defaultPrependPunctuations: String = "\"'“¡¿([{-"
+    public static let defaultAppendPunctuations: String = "\"'.。,,!!??::”)]}、"
+
     public static let fallbackModelSupportConfig: ModelSupportConfig = {
         var config = ModelSupportConfig(
             repoName: "whisperkit-coreml-fallback",
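Because `WhisperTokenizer` no longer refines the swift-transformers `Tokenizer` protocol, a conforming type only has to supply the six requirements listed above rather than the full `Tokenizer` surface. A minimal sketch of a custom conformance; `StubWhisperTokenizer` and its stub bodies are hypothetical, for illustration only:

```swift
// Hypothetical conformance; a real implementation would wrap an actual vocabulary.
final class StubWhisperTokenizer: WhisperTokenizer {
    let specialTokens: SpecialTokens
    let allLanguageTokens: Set<Int>

    init(specialTokens: SpecialTokens, allLanguageTokens: Set<Int>) {
        self.specialTokens = specialTokens
        self.allLanguageTokens = allLanguageTokens
    }

    // swift-transformers pass-through requirements, stubbed out here
    func encode(text: String) -> [Int] { [] }
    func decode(tokens: [Int]) -> String { "" }
    func convertTokenToId(_ token: String) -> Int? { nil }
    func convertIdToToken(_ id: Int) -> String? { nil }

    // WhisperKit-specific requirement
    func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
        ([], [])
    }
}
```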
diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift
index 33a45dc..a263c7d 100644
--- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift
+++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift
@@ -280,7 +280,11 @@ open class SegmentSeeker: SegmentSeeking {
         return (textIndices.reversed(), timeIndices.reversed())
     }
 
-    func mergePunctuations(alignment: [WordTiming], prepended: String, appended: String) -> [WordTiming] {
+    func mergePunctuations(
+        alignment: [WordTiming],
+        prepended: String = Constants.defaultPrependPunctuations,
+        appended: String = Constants.defaultAppendPunctuations
+    ) -> [WordTiming] {
         var prependedAlignment = [WordTiming]()
         var appendedAlignment = [WordTiming]()
@@ -405,8 +409,8 @@ open class SegmentSeeker: SegmentSeeking {
         tokenizer: WhisperTokenizer,
         seek: Int,
         segmentSize: Int,
-        prependPunctuations: String,
-        appendPunctuations: String,
+        prependPunctuations: String = Constants.defaultPrependPunctuations,
+        appendPunctuations: String = Constants.defaultAppendPunctuations,
         lastSpeechTimestamp: Float,
         options: DecodingOptions,
         timings: TranscriptionTimings
@@ -415,7 +419,6 @@ open class SegmentSeeker: SegmentSeeking {
         var wordTokenIds = [Int]()
         var filteredLogProbs = [Float]()
         var filteredIndices = [Int]()
-        var lastSpeechTimestamp = lastSpeechTimestamp
 
         // Iterate through each segment
         var indexOffset = 0
@@ -455,6 +458,7 @@ open class SegmentSeeker: SegmentSeeking {
 
         Logging.debug("Alignment weights shape: \(filteredAlignmentWeights.shape)")
 
+        // Find alignment between text tokens and time indices
         var alignment = try findAlignment(
             wordTokenIds: wordTokenIds,
             alignmentWeights: filteredAlignmentWeights,
@@ -465,35 +469,73 @@ open class SegmentSeeker: SegmentSeeking {
 
         // TODO: This section is considered a "hack" in the source repo
         // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L305
-        var wordDurations = alignment.map { $0.end - $0.start }
+        let wordDurations = calculateWordDurationConstraints(alignment: alignment)
+        alignment = truncateLongWordsAtSentenceBoundaries(alignment, maxDuration: wordDurations.max)
+
+        // Process alignment for punctuations
+        if !alignment.isEmpty {
+            alignment = mergePunctuations(alignment: alignment, prepended: prependPunctuations, appended: appendPunctuations)
+        }
+
+        // Update segments based on more accurate word timings
+        let updatedSegments = updateSegmentsWithWordTimings(
+            segments: segments,
+            mergedAlignment: alignment,
+            seek: seek,
+            lastSpeechTimestamp: lastSpeechTimestamp,
+            constrainedMedianDuration: wordDurations.median,
+            maxDuration: wordDurations.max,
+            tokenizer: tokenizer
+        )
+
+        return updatedSegments
+    }
+
+    public func calculateWordDurationConstraints(alignment: [WordTiming]) -> (median: Float, max: Float) {
+        var wordDurations = alignment.map { $0.duration }
         wordDurations = wordDurations.filter { $0 > 0 }
 
         let medianDuration: Float = wordDurations.isEmpty ? 0.0 : wordDurations.sorted(by: <)[wordDurations.count / 2]
         let constrainedMedianDuration = min(0.7, medianDuration)
         let maxDuration = constrainedMedianDuration * 2
 
-        // Truncate long words at sentence boundaries
+        return (constrainedMedianDuration, maxDuration)
+    }
+
+    public func truncateLongWordsAtSentenceBoundaries(_ alignment: [WordTiming], maxDuration: Float) -> [WordTiming] {
         let sentenceEndMarks = [".", "。", "!", "!", "?", "?"]
-        if !wordDurations.isEmpty {
-            for i in 1..<alignment.count {
-                if alignment[i].end - alignment[i].start > maxDuration {
-                    if sentenceEndMarks.contains(alignment[i].word) {
-                        alignment[i].end = alignment[i].start + maxDuration
-                    } else if i > 0, sentenceEndMarks.contains(alignment[i - 1].word) {
-                        alignment[i].start = alignment[i].end - maxDuration
+        var truncatedAlignment = alignment
+
+        if !truncatedAlignment.isEmpty {
+            for i in 1..<truncatedAlignment.count {
+                if truncatedAlignment[i].duration > maxDuration {
+                    if sentenceEndMarks.contains(truncatedAlignment[i].word) {
+                        truncatedAlignment[i].end = truncatedAlignment[i].start + maxDuration
+                    } else if i > 0, sentenceEndMarks.contains(truncatedAlignment[i - 1].word) {
+                        truncatedAlignment[i].start = truncatedAlignment[i].end - maxDuration
                     }
                 }
             }
         }
 
-        // Process alignment for punctuations
-        let mergedAlignment = mergePunctuations(alignment: alignment, prepended: prependPunctuations, appended: appendPunctuations)
+        return truncatedAlignment
+    }
 
-        var wordIndex = 0
+    public func updateSegmentsWithWordTimings(
+        segments: [TranscriptionSegment],
+        mergedAlignment: [WordTiming],
+        seek: Int,
+        lastSpeechTimestamp: Float,
+        constrainedMedianDuration: Float,
+        maxDuration: Float,
+        tokenizer: WhisperTokenizer
+    ) -> [TranscriptionSegment] {
         let timeOffset = Float(seek) / Float(WhisperKit.sampleRate)
+        var wordIndex = 0
+        var lastSpeechTimestamp = lastSpeechTimestamp // Create a mutable copy, it will be updated below
         var updatedSegments = [TranscriptionSegment]()
-        for segment in segments {
+        for (segmentIndex, segment) in segments.enumerated() {
             var savedTokens = 0
             let textTokens = segment.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
             var wordsInSegment = [WordTiming]()
@@ -507,10 +549,50 @@ open class SegmentSeeker: SegmentSeeking {
                 continue
             }
 
-            let start = (timeOffset + timing.start).rounded(2)
+            // Retokenize the word if some tokens were filtered out
+            let word = timingTokens.count < timing.tokens.count ?
+                tokenizer.decode(tokens: timingTokens) :
+                timing.word
+
+            var start = (timeOffset + timing.start).rounded(2)
             let end = (timeOffset + timing.end).rounded(2)
+
+            // Handle short duration words by moving their start time back if there's space
+            // There is most commonly space when there is timestamp tokens or merged punctuations between words
+            if end - start < constrainedMedianDuration / 4 {
+                if wordsInSegment.count >= 1,
+                   let previousTiming = wordsInSegment.last
+                {
+                    let previousEnd = previousTiming.end
+
+                    // If there's space between this word and the previous word
+                    if start > previousEnd {
+                        // Move the word back, using either median duration or space available
+                        // Eg: [[0.5 - 2.0], [3.0 - 3.0]] (two word timings with one second gap)
+                        // ->  [[0.5 - 2.0], [2.7 - 3.0]] (second word start moved back)
+                        // however, if there is no space, it will not be adjusted
+                        // Eg: [[0.5 - 2.0], [2.0 - 2.0]] (two word timings with no gap)
+                        // ->  [[0.5 - 2.0], [2.0 - 2.0]] (second word start not moved back)
+                        let spaceAvailable = start - previousEnd
+                        let desiredDuration = min(spaceAvailable, constrainedMedianDuration / 2)
+                        start = (start - desiredDuration).rounded(2)
+                    }
+                } else if wordsInSegment.isEmpty,
+                          segmentIndex > 0,
+                          updatedSegments.count > segmentIndex - 1,
+                          start > updatedSegments[segmentIndex - 1].end
+                {
+                    // First word of segment - check space from last segment
+                    // Eg: [[0.5 - 1.5], [1.5 - 2.0]], [[3.0 - 3.0]] (two segments with one second gap)
+                    // ->  [[0.5 - 1.5], [1.5 - 2.0]], [[2.7 - 3.0]] (first word start of new segment moved back)
+                    let spaceAvailable = start - updatedSegments[segmentIndex - 1].end
+                    let desiredDuration = min(spaceAvailable, constrainedMedianDuration / 2)
+                    start = (start - desiredDuration).rounded(2)
+                }
+            }
+
             let probability = timing.probability.rounded(2)
-            let wordTiming = WordTiming(word: timing.word,
+            let wordTiming = WordTiming(word: word,
                                         tokens: timingTokens,
                                         start: start,
                                         end: end,
considered a "hack" in the source repo // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L342 // Truncate long words at segment boundaries - if let firstWord = wordsInSegment.first, let lastWord = wordsInSegment.last { - // Logic for the first word - if firstWord.end - lastSpeechTimestamp > constrainedMedianDuration * 4 && - (firstWord.end - firstWord.start > maxDuration || - (wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2)) + if let firstWord = wordsInSegment.first { + // Ensure the first and second word after a pause is not longer than + // twice the median word duration. + let pauseLength = firstWord.end - lastSpeechTimestamp + let firstWordTooLong = firstWord.duration > maxDuration + let bothWordsTooLong = wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2 + if pauseLength > constrainedMedianDuration * 4 && + (firstWordTooLong || bothWordsTooLong) { - if wordsInSegment.count > 1 && wordsInSegment[1].end - wordsInSegment[1].start > maxDuration { + // First word or both words are too long + if wordsInSegment.count > 1 && wordsInSegment[1].duration > maxDuration { + // Second word is too long, find a good boundary to readjust the words let boundary = max(wordsInSegment[1].end / 2, wordsInSegment[1].end - maxDuration) wordsInSegment[0].end = boundary wordsInSegment[1].start = boundary } - wordsInSegment[0].start = max(lastSpeechTimestamp, firstWord.end - maxDuration) + // In either case, make sure the first word is not too long + wordsInSegment[0].start = max(lastSpeechTimestamp, wordsInSegment[0].end - maxDuration) } // Prefer segment-level start timestamp if the first word is too long. - if segment.start < firstWord.end && segment.start - 0.5 > firstWord.start { - wordsInSegment[0].start = max(0, min(firstWord.end - constrainedMedianDuration, segment.start)) + if segment.start < wordsInSegment[0].end && segment.start - 0.5 > wordsInSegment[0].start { + wordsInSegment[0].start = max(0, min(wordsInSegment[0].end - constrainedMedianDuration, segment.start)) } else { - updatedSegment.start = firstWord.start + updatedSegment.start = wordsInSegment[0].start } - // Prefer segment-level end timestamp if the last word is too long. - if updatedSegment.end > lastWord.start && segment.end + 0.5 < lastWord.end { - wordsInSegment[wordsInSegment.count - 1].end = max(lastWord.start + constrainedMedianDuration, segment.end) - } else { - updatedSegment.end = lastWord.end + if let lastWord = wordsInSegment.last { + // Prefer segment-level end timestamp if the last word is too long. 
diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift
index 2193161..3ece220 100644
--- a/Sources/WhisperKit/Core/Text/TokenSampler.swift
+++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift
@@ -14,6 +14,16 @@ public struct SamplingResult {
     public var tokens: [Int]
     public var logProbs: [Float]
     public var completed: Bool
+
+    public init(
+        tokens: [Int],
+        logProbs: [Float],
+        completed: Bool
+    ) {
+        self.tokens = tokens
+        self.logProbs = logProbs
+        self.completed = completed
+    }
 }
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index f0e8219..f8f79b8 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -686,12 +686,23 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
             let loopStart = Date()
 
             let isPrefill = tokenIndex < intialPromptIndex - 1 // Prefill stops at the last token of the initial prompt
+            let isLastPrefillToken = tokenIndex == intialPromptIndex - 1
             let isFirstToken = tokenIndex == prefilledIndex
 
             // Check if current index is part of the initial prompt
             if tokenIndex < intialPromptIndex {
-                nextToken = currentTokens[tokenIndex]
-                Logging.debug("Forcing prompt tokenIndex: \(tokenIndex), token: \(nextToken), text: \(tokenizer.decode(tokens: [nextToken]))")
+                let isTimestampToken = currentTokens[tokenIndex] >= tokenizer.specialTokens.timeTokenBegin
+                let modelPredictedTimestamp = nextToken >= tokenizer.specialTokens.timeTokenBegin
+
+                // Force the token unless it's the last prefill token and both are timestamps
+                if !(isLastPrefillToken && isTimestampToken && modelPredictedTimestamp) {
+                    nextToken = currentTokens[tokenIndex]
+                    Logging.debug("Forcing prompt tokenIndex: \(tokenIndex), token: \(nextToken), text: \(tokenizer.decode(tokens: [nextToken]))")
+                } else {
+                    // Last prefill was a timestamp but the model predicted a timestamp
+                    currentTokens[tokenIndex] = nextToken
+                    Logging.debug("Skipping prompt tokenIndex: \(tokenIndex), token: \(nextToken), text: \(tokenizer.decode(tokens: [nextToken]))")
+                }
             }
 
             // Set the current token as model input
diff --git a/Sources/WhisperKit/Core/TranscribeTask.swift b/Sources/WhisperKit/Core/TranscribeTask.swift
index cc3565f..e149104 100644
--- a/Sources/WhisperKit/Core/TranscribeTask.swift
+++ b/Sources/WhisperKit/Core/TranscribeTask.swift
@@ -192,8 +192,8 @@ final class TranscribeTask {
                     tokenizer: tokenizer,
                     seek: previousSeek,
                     segmentSize: segmentSize,
-                    prependPunctuations: "\"'“¿([{-",
-                    appendPunctuations: "\"'.。,,!!??::”)]}、",
+                    prependPunctuations: Constants.defaultPrependPunctuations,
+                    appendPunctuations: Constants.defaultAppendPunctuations,
                     lastSpeechTimestamp: Float(Double(previousSeek) / Double(WhisperKit.sampleRate)),
                     options: options,
                     timings: timings
diff --git a/Sources/WhisperKit/Core/Utils/Utils.swift b/Sources/WhisperKit/Core/Utils/Utils.swift
index 241997d..7f4a6e1 100644
--- a/Sources/WhisperKit/Core/Utils/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils/Utils.swift
@@ -56,7 +56,7 @@ extension Array where Element: Hashable {
 }
 
 extension MLMultiArray {
-    /// Calculate the linear offset by summing the products of each dimension’s index with the dimension’s stride.
+    /// Calculate the linear offset by summing the products of each dimension's index with the dimension's stride.
     /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript)
     /// - Parameters:
    ///   - index: The index of the element
@@ -179,7 +179,7 @@ public extension MLTensor {
 }
 #endif
 
-extension MLModel {
+public extension MLModel {
     func asyncPrediction(
         from input: MLFeatureProvider,
         options: MLPredictionOptions
@@ -251,11 +251,8 @@ public extension Float {
 
 public extension String {
     var normalized: String {
-        // Trim whitespace and newlines
-        let trimmedString = self.trimmingCharacters(in: .whitespacesAndNewlines)
-
         // Convert to lowercase
-        let lowercaseString = trimmedString.lowercased()
+        let lowercaseString = self.lowercased()
 
         // Replace dashes with spaces
         let noDashesString = lowercaseString.replacingOccurrences(of: "-", with: " ")
@@ -266,7 +263,10 @@ public extension String {
         // Replace multiple spaces with a single space
         let singleSpacedString = noPunctuationString.replacingOccurrences(of: " +", with: " ", options: .regularExpression)
 
-        return singleSpacedString
+        // Trim whitespace and newlines
+        let trimmedString = singleSpacedString.trimmingCharacters(in: .whitespacesAndNewlines)
+
+        return trimmedString
     }
 
     func trimmingSpecialTokenCharacters() -> String {
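The reordering inside `String.normalized` matters because the dash and punctuation replacements can introduce fresh edge whitespace, which the old trim-first order never cleaned up. A small illustration of the behavior the new `testStringNormalization` below locks in:

```swift
// Old order: trim -> lowercase -> dashes to spaces -> strip punctuation -> collapse spaces.
// "hello-world-" became "hello world " (the trailing space survived the early trim).
// New order trims last, after all substitutions:
"hello-world-".normalized        // "hello world"
"Hello!!!---World???".normalized // "hello world"
" hello world ".normalized       // "hello world"
```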
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index 0fd780b..0dc58d6 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -1260,7 +1260,6 @@ final class UnitTests: XCTestCase {
         await fulfillment(of: [modelStateExpectation, segmentDiscoveryExpectation, transcriptionStateExpectation], timeout: 1)
     }
 
-    #if !os(watchOS) // FIXME: watchOS ignores the priority here for some reason
     func testCallbackWithEarlyStopping() async throws {
         let callbackTestTask = Task(priority: .userInitiated) {
             let computeOptions = ModelComputeOptions(
@@ -1315,12 +1314,16 @@ final class UnitTests: XCTestCase {
 
             // Assert that more tokens are returned in the callback with waiting
             XCTAssertGreaterThanOrEqual(tokenCountWithWait, 30, "Tokens for callback with wait should contain the full audio file")
-            XCTAssertGreaterThan(tokenCountWithWait, tokenCountWithEarlyStop, "More tokens should be returned in the callback with waiting")
+
+            #if os(watchOS) || os(iOS) // FIXME: Some OS ignore the priority here on github action runners for some reason
+            XCTAssertGreaterThanOrEqual(tokenCountWithWait, tokenCountWithEarlyStop, "More tokens should be returned in the callback with waiting (early stop: \(tokenCountWithEarlyStop), with wait: \(tokenCountWithWait))")
+            #else
+            XCTAssertGreaterThan(tokenCountWithWait, tokenCountWithEarlyStop, "More tokens should be returned in the callback with waiting (early stop: \(tokenCountWithEarlyStop), with wait: \(tokenCountWithWait))")
+            #endif
         }
 
         try await callbackTestTask.value
     }
-    #endif
 
     // MARK: - Utils Tests
 
@@ -1354,6 +1357,32 @@ final class UnitTests: XCTestCase {
         XCTAssertEqual([1, 2, 3, 4].batched(into: 3), [[1, 2, 3], [4]])
     }
 
+    func testStringNormalization() throws {
+        // Basic cases
+        XCTAssertEqual("Hello World!".normalized, "hello world")
+        XCTAssertEqual("hello world".normalized, "hello world")
+        XCTAssertEqual("HELLO WORLD".normalized, "hello world")
+
+        // Punctuation
+        XCTAssertEqual("Hello, World!".normalized, "hello world")
+        XCTAssertEqual("Hello... World???".normalized, "hello world")
+        XCTAssertEqual("'Hello' \"World\"".normalized, "hello world")
+
+        // Dashes and hyphens
+        XCTAssertEqual("hello-world".normalized, "hello world")
+        XCTAssertEqual("hello -- world".normalized, "hello world")
+
+        // Whitespace handling
+        XCTAssertEqual(" hello world ".normalized, "hello world")
+
+        // Mixed cases
+        XCTAssertEqual("Hello!!!---World???".normalized, "hello world")
+
+        // Numbers and special characters
+        XCTAssertEqual("Hello: World".normalized, "hello world")
+        XCTAssertEqual("Hello% world 100".normalized, "hello world 100")
+    }
+
     func testTrimmingSpecialTokenCharacters() {
         XCTAssertEqual("<|en|>".trimmingSpecialTokenCharacters(), "en")
         XCTAssertEqual("<|endoftext|>".trimmingSpecialTokenCharacters(), "endoftext")
World???".normalized, "hello world") + XCTAssertEqual("'Hello' \"World\"".normalized, "hello world") + + // Dashes and hyphens + XCTAssertEqual("hello-world".normalized, "hello world") + XCTAssertEqual("hello -- world".normalized, "hello world") + + // Whitespace handling + XCTAssertEqual(" hello world ".normalized, "hello world") + + // Mixed cases + XCTAssertEqual("Hello!!!---World???".normalized, "hello world") + + // Numbers and special characters + XCTAssertEqual("Hello: World".normalized, "hello world") + XCTAssertEqual("Hello% world 100".normalized, "hello world 100") + } + func testTrimmingSpecialTokenCharacters() { XCTAssertEqual("<|en|>".trimmingSpecialTokenCharacters(), "en") XCTAssertEqual("<|endoftext|>".trimmingSpecialTokenCharacters(), "endoftext") @@ -1898,7 +1927,7 @@ final class UnitTests: XCTestCase { WordTiming(word: "<|endoftext|>", tokens: [50257], start: 12, end: 13, probability: 1), ] - let mergedAlignmentTiming = SegmentSeeker().mergePunctuations(alignment: wordTimings, prepended: "\"'“¡¿([{-", appended: "\"'.。,,!!??::”)]}、") + let mergedAlignmentTiming = SegmentSeeker().mergePunctuations(alignment: wordTimings) let expectedWordTimings = [ WordTiming(word: "<|notimestamps|>", tokens: [50363], start: 0, end: 1, probability: 1), @@ -1942,7 +1971,7 @@ final class UnitTests: XCTestCase { WordTiming(word: "<|endoftext|>", tokens: [50257], start: 11, end: 12, probability: 1), ] - let mergedAlignmentTiming = SegmentSeeker().mergePunctuations(alignment: wordTimings, prepended: "\"'“¿([{-", appended: "\"'.。,,!!??::”)]}、") + let mergedAlignmentTiming = SegmentSeeker().mergePunctuations(alignment: wordTimings) let expectedWordTimings = [ WordTiming(word: "<|0.00|>", tokens: [50364], start: 0, end: 1, probability: 1), @@ -1969,6 +1998,30 @@ final class UnitTests: XCTestCase { } } + func testWordTimingComparison() throws { + let word1 = WordTiming(word: "Hello!", tokens: [1], start: 0, end: 1, probability: 1.0) + let word2 = WordTiming(word: "hello", tokens: [2], start: 0, end: 1, probability: 1.0) + let word3 = WordTiming(word: "World!", tokens: [3], start: 1, end: 2, probability: 1.0) + let word4 = WordTiming(word: "WORLD", tokens: [4], start: 1, end: 2, probability: 1.0) + + // Test common prefix finding + let sequence1 = [word1, word3] + let sequence2 = [word2, word4] + + let commonPrefix = findLongestCommonPrefix(sequence1, sequence2) + XCTAssertEqual(commonPrefix.count, 2) + XCTAssertEqual(commonPrefix[0].word, "hello") + XCTAssertEqual(commonPrefix[1].word, "WORLD") + + // Test different suffix finding + let sequence3 = [word1, word3] + let sequence4 = [word2, word4, WordTiming(word: "suffix", tokens: [5], start: 2, end: 3, probability: 1.0)] + + let differentSuffix = findLongestDifferentSuffix(sequence3, sequence4) + XCTAssertEqual(differentSuffix.count, 1) + XCTAssertEqual(differentSuffix[0].word, "suffix") + } + func testWordTimestampCorrectness() async throws { let options = DecodingOptions(wordTimestamps: true) @@ -2030,6 +2083,191 @@ final class UnitTests: XCTestCase { } } + func testLongWordDurations() async { + // Test case with words spanning very long durations + let wordTimings = [ + // Normal duration words + WordTiming(word: " The", tokens: [264], start: 0.5, end: 1.0, probability: 1), + WordTiming(word: " first", tokens: [4589], start: 1.0, end: 2.0, probability: 1), + // Long duration word + WordTiming(word: " segment", tokens: [234], start: 2.0, end: 3.0, probability: 1), + // Normal words + WordTiming(word: " with", tokens: [567], start: 3.0, end: 4.0, 
@@ -2030,6 +2083,191 @@ final class UnitTests: XCTestCase {
         }
     }
 
+    func testLongWordDurations() async {
+        // Test case with words spanning very long durations
+        let wordTimings = [
+            // Normal duration words
+            WordTiming(word: " The", tokens: [264], start: 0.5, end: 1.0, probability: 1),
+            WordTiming(word: " first", tokens: [4589], start: 1.0, end: 2.0, probability: 1),
+            WordTiming(word: " segment", tokens: [234], start: 2.0, end: 3.0, probability: 1),
+            WordTiming(word: " with", tokens: [567], start: 3.0, end: 4.0, probability: 1),
+            WordTiming(word: " a", tokens: [257], start: 4.0, end: 5.0, probability: 1),
+            WordTiming(word: " long", tokens: [890], start: 5.0, end: 6.0, probability: 1),
+            // Very long duration word ending the sentence
+            WordTiming(word: " ending", tokens: [123], start: 6.0, end: 35.0, probability: 1),
+            // Zero duration punctuation, merged into the previous word
+            WordTiming(word: ".", tokens: [13], start: 35.0, end: 35.0, probability: 1),
+        ]
+
+        let segmentSeeker = SegmentSeeker()
+
+        // Create test segments
+        let segments = [
+            TranscriptionSegment(
+                id: 0,
+                seek: 0,
+                start: 0.0,
+                end: 6.0,
+                text: "The first segment with a long",
+                tokens: [264, 4589, 234, 567, 257, 890], // Using tokens from wordTimings
+                tokenLogProbs: Array(repeating: [0: 0.0], count: 6)
+            ),
+            TranscriptionSegment(
+                id: 1,
+                seek: 0,
+                start: 6.5, // slight difference to test truncation
+                end: 30.0, // slight difference to test truncation
+                text: "ending.",
+                tokens: [123, 13],
+                tokenLogProbs: Array(repeating: [0: 0.0], count: 2)
+            ),
+        ]
+
+        // Test duration constraints calculation
+        let (constrainedMedianDuration, maxDuration) = segmentSeeker.calculateWordDurationConstraints(alignment: wordTimings)
+        XCTAssertEqual(constrainedMedianDuration, 0.7, "Constrained median duration should be capped at 0.7")
+        XCTAssertEqual(maxDuration, 1.4, "Max duration should be double the constrained median duration")
+
+        // Test truncation of long words
+        let truncatedAlignment = segmentSeeker.truncateLongWordsAtSentenceBoundaries(wordTimings, maxDuration: maxDuration)
+        let mergedAlignment = segmentSeeker.mergePunctuations(alignment: truncatedAlignment)
+        let updatedSegments = segmentSeeker.updateSegmentsWithWordTimings(
+            segments: segments,
+            mergedAlignment: mergedAlignment,
+            seek: 0,
+            lastSpeechTimestamp: 0,
+            constrainedMedianDuration: constrainedMedianDuration,
+            maxDuration: maxDuration,
+            tokenizer: try! await loadTokenizer(for: .tiny)
+        )
+
+        let updatedWords = updatedSegments.compactMap { $0.words }.flatMap { $0 }
+
+        // Test that long segments are properly truncated
+        let longSegmentDuration = updatedWords.last!.duration
+        XCTAssertEqual(
+            longSegmentDuration,
+            maxDuration,
+            accuracy: 0.0001,
+            "Long segment duration (\(longSegmentDuration)s) should be truncated to maximum allowed duration (\(maxDuration)s)"
+        )
+
+        // Test that segments were properly updated
+        XCTAssertEqual(updatedSegments.count, segments.count, "Number of segments should remain the same")
+        XCTAssertLessThanOrEqual(updatedSegments.last!.end - updatedSegments.last!.start, 19.5,
+                                 "Segment duration should not exceed max duration")
+
+        // Test the long word
+        let longWordIndex = updatedWords.firstIndex { $0.word == " ending." }!
+        let longWordDuration = updatedWords[longWordIndex].duration
+        XCTAssertEqual(
+            longWordDuration,
+            maxDuration,
+            accuracy: 0.0001,
+            "Very long pause duration (\(longWordDuration)s) should be truncated to maximum allowed duration (\(maxDuration)s)"
+        )
+
+        // Test that the long word is truncated after pause
+        XCTAssertEqual(
+            updatedWords[longWordIndex].start,
+            33.6,
+            "Long word start time should be moved up to fit within the max duration"
+        )
+
+        // Test that the timing sequence remains monotonic
+        for i in 1..<updatedWords.count {
+            XCTAssertLessThanOrEqual(
+                updatedWords[i - 1].end,
+                updatedWords[i].start,
+                "Word timings should be monotonically increasing"
+            )
+        }
+    }
+
+    func testSingleTokenLongDuration() async {
+        // Test case with a single token spanning a very long duration
+        let wordTimings = [
+            WordTiming(word: "<|notimestamps|>", tokens: [50363], start: 0, end: 0.5, probability: 1),
+            // Single token with 20s duration
+            WordTiming(word: " Hello", tokens: [314], start: 0.5, end: 20.5, probability: 1),
+            WordTiming(word: "<|endoftext|>", tokens: [50257], start: 20.5, end: 30, probability: 1),
+        ]
+
+        let segmentSeeker = SegmentSeeker()
+
+        // Create test segments
+        let segments = [
+            TranscriptionSegment(
+                id: 0,
+                seek: 0,
+                start: 0.0,
+                end: 30.0,
+                text: "Hello",
+                tokens: [314],
+                tokenLogProbs: Array(repeating: [0: 0.0], count: 1)
+            ),
+        ]
+
+        // Test duration constraints calculation
+        let (constrainedMedianDuration, maxDuration) = segmentSeeker.calculateWordDurationConstraints(alignment: wordTimings)
+        XCTAssertEqual(constrainedMedianDuration, 0.7, "Constrained median duration should be capped at 0.7")
+        XCTAssertEqual(maxDuration, 1.4, "Max duration should be double the constrained median duration")
+
+        // Test truncation of long words
+        let truncatedAlignment = segmentSeeker.truncateLongWordsAtSentenceBoundaries(wordTimings, maxDuration: maxDuration)
+        let mergedAlignment = segmentSeeker.mergePunctuations(alignment: truncatedAlignment)
+        let updatedSegments = segmentSeeker.updateSegmentsWithWordTimings(
+            segments: segments,
+            mergedAlignment: mergedAlignment,
+            seek: 0,
+            lastSpeechTimestamp: 0,
+            constrainedMedianDuration: constrainedMedianDuration,
+            maxDuration: maxDuration,
+            tokenizer: try! await loadTokenizer(for: .tiny)
+        )
+
+        let updatedWords = updatedSegments.first!.words!
+
+        // Test that long single tokens are properly truncated
+        let firstLongIndex = updatedWords.firstIndex { $0.word == " Hello" }!
+        let firstLongDuration = updatedWords[firstLongIndex].duration
+        XCTAssertLessThanOrEqual(
+            firstLongDuration,
+            maxDuration,
+            "Long single token duration (\(firstLongDuration)s) should be truncated"
+        )
+
+        // Verify timing sequence remains valid
+        var previousEnd: Float = 0
+        for timing in updatedWords {
+            XCTAssertGreaterThanOrEqual(
+                timing.start,
+                previousEnd,
+                "Start time should not be earlier than previous end time"
+            )
+            XCTAssertLessThanOrEqual(
+                timing.duration,
+                maxDuration,
+                "No word duration should significantly exceed maximum allowed duration"
+            )
+            previousEnd = timing.end
+        }
+    }
+
     // MARK: - Streaming Timestamp Tests
 
     func testStreamingTimestamps() async throws {