Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SegmentSeeker word alignment #305

Merged
merged 27 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
85e847a
Reorder text normalization to remove whitespace at final step
ZachNagengast Feb 1, 2025
0cd78a2
Fix word alignment for long words or short token length segments
ZachNagengast Feb 20, 2025
1775c39
Support protocols from new swift-transformers version
ZachNagengast Feb 20, 2025
fe7c8e5
Merge branch 'main' into fix-word-timestamps-and-normalization
ZachNagengast Feb 20, 2025
2af850c
Backwards support for older swift-transformers
ZachNagengast Feb 20, 2025
9313eac
Formatting
ZachNagengast Feb 20, 2025
797d274
Revert default packages
ZachNagengast Feb 20, 2025
0423be9
Fix package.resolved
ZachNagengast Feb 20, 2025
2e905fe
Fix pacakge.resolved
ZachNagengast Feb 20, 2025
3ca28fc
Remove protocol conformance for tokenizer
ZachNagengast Feb 20, 2025
ffc2515
Add public attributes
a2they Feb 20, 2025
7c98d63
Add public initializers
ZachNagengast Feb 20, 2025
d725821
Add computed properties for duration in TranscriptionSegment and Word…
ZachNagengast Feb 20, 2025
d4c62e6
Lower token count for early stopping test
ZachNagengast Feb 20, 2025
64938b7
Attempt to fix early stopping test using detached tasks
ZachNagengast Feb 21, 2025
287d0ab
Use file with more tokens for early stopping test
ZachNagengast Feb 21, 2025
15d9282
Add timeout to early stopping callback test to prevent blocking
ZachNagengast Feb 21, 2025
1d79254
Revert watchos early stopping test check
ZachNagengast Feb 21, 2025
eb4ba45
Update more standard package versioning
ZachNagengast Feb 21, 2025
402a86e
Improve short duration word timestamps
ZachNagengast Feb 21, 2025
aed2237
PR review
ZachNagengast Feb 21, 2025
17f8f8c
Ammend early callback tests on some OS's
ZachNagengast Feb 21, 2025
95eed32
Remove captured var from test
ZachNagengast Feb 21, 2025
4b46ebb
Remove unecessary tasks from test
ZachNagengast Feb 22, 2025
9ce99a5
Revert test changes keeping os filter
ZachNagengast Feb 22, 2025
3d1a709
Formatting
ZachNagengast Feb 22, 2025
d6ff6ce
Cleanup
ZachNagengast Feb 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,6 @@ upload-benchmark-results:
@fastlane upload_results

clean-package-caches:
@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit*
@trash ~/Library/Developer/Xcode/DerivedData/WhisperKit* || true
@swift package purge-cache
@swift package reset
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
.package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "0.1.8")),
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
],
targets: [
.target(
Expand Down
5 changes: 5 additions & 0 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ public typealias DeviceID = String
public struct AudioDevice: Identifiable, Hashable {
public let id: DeviceID
public let name: String

public init(id: DeviceID, name: String) {
self.id = id
self.name = name
}
}

public protocol AudioProcessing {
Expand Down
3 changes: 1 addition & 2 deletions Sources/WhisperKit/Core/FeatureExtractor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
guard inputDescription.type == .multiArray else { return nil }
guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
let shape = shapeConstraint.shape.map { $0.intValue }
return shape[0] // The audio input is a 1D array
return shape[0] // The audio input is a 1D array
}

public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
Expand All @@ -54,5 +54,4 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
let output = MelSpectrogramOutput(features: outputFeatures)
return output.melspectrogramFeatures
}

}
208 changes: 143 additions & 65 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,26 @@ public struct ModelSupport: Codable, Equatable {
private enum CodingKeys: String, CodingKey {
case `default`, supported
}

public init(
default: String,
supported: [String],
disabled: [String] = []
) {
self.default = `default`
self.supported = supported
self.disabled = disabled
}
}

public struct DeviceSupport: Codable {
public let identifiers: [String]
public var models: ModelSupport

public init(identifiers: [String], models: ModelSupport) {
self.identifiers = identifiers
self.models = models
}
}

public struct ModelSupportConfig: Codable {
Expand Down Expand Up @@ -280,6 +295,11 @@ public struct ModelSupportConfig: Codable {
public struct AudioChunk {
public var seekOffsetIndex: Int
public var audioSamples: [Float]

public init(seekOffsetIndex: Int, audioSamples: [Float]) {
self.seekOffsetIndex = seekOffsetIndex
self.audioSamples = audioSamples
}
}

// MARK: - Decoding
Expand Down Expand Up @@ -351,7 +371,12 @@ public struct DecodingCache {
public var keyCache: MLMultiArray?
public var valueCache: MLMultiArray?
public var alignmentWeights: MLMultiArray?
public init(keyCache: MLMultiArray? = nil, valueCache: MLMultiArray? = nil, alignmentWeights: MLMultiArray? = nil) {

public init(
keyCache: MLMultiArray? = nil,
valueCache: MLMultiArray? = nil,
alignmentWeights: MLMultiArray? = nil
) {
self.keyCache = keyCache
self.valueCache = valueCache
self.alignmentWeights = alignmentWeights
Expand Down Expand Up @@ -432,7 +457,20 @@ public struct DecodingResult {
fallback: nil)
}

public init(language: String, languageProbs: [String: Float], tokens: [Int], tokenLogProbs: [[Int: Float]], text: String, avgLogProb: Float, noSpeechProb: Float, temperature: Float, compressionRatio: Float, cache: DecodingCache? = nil, timings: TranscriptionTimings? = nil, fallback: DecodingFallback? = nil) {
public init(
language: String,
languageProbs: [String: Float],
tokens: [Int],
tokenLogProbs: [[Int: Float]],
text: String,
avgLogProb: Float,
noSpeechProb: Float,
temperature: Float,
compressionRatio: Float,
cache: DecodingCache? = nil,
timings: TranscriptionTimings? = nil,
fallback: DecodingFallback? = nil
) {
self.language = language
self.languageProbs = languageProbs
self.tokens = tokens
Expand Down Expand Up @@ -510,6 +548,20 @@ public struct TranscriptionResult: Codable {
public var timings: TranscriptionTimings
public var seekTime: Float?

public init(
text: String,
segments: [TranscriptionSegment],
language: String,
timings: TranscriptionTimings,
seekTime: Float? = nil
) {
self.text = text
self.segments = segments
self.language = language
self.timings = timings
self.seekTime = seekTime
}

public func logSegments() {
for (i, segment) in segments.enumerated() {
let start = segment.start
Expand Down Expand Up @@ -593,18 +645,51 @@ public extension TranscriptionResult {
}

public struct TranscriptionSegment: Hashable, Codable {
public var id: Int = 0
public var seek: Int = 0
public var start: Float = 0.0
public var end: Float = 0.0
public var text: String = ""
public var tokens: [Int] = []
public var tokenLogProbs: [[Int: Float]] = [[:]]
public var temperature: Float = 1.0
public var avgLogprob: Float = 0.0
public var compressionRatio: Float = 1.0
public var noSpeechProb: Float = 0.0
public var words: [WordTiming]? = nil
public var id: Int
public var seek: Int
public var start: Float
public var end: Float
public var text: String
public var tokens: [Int]
public var tokenLogProbs: [[Int: Float]]
public var temperature: Float
public var avgLogprob: Float
public var compressionRatio: Float
public var noSpeechProb: Float
public var words: [WordTiming]?

/// Computed property for the duration of the segment
public var duration: Float {
return end - start
}

public init(
id: Int = 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we set defaults here, we should remove redundancy above

seek: Int = 0,
start: Float = 0.0,
end: Float = 0.0,
text: String = "",
tokens: [Int] = [],
tokenLogProbs: [[Int: Float]] = [[:]],
temperature: Float = 1.0,
avgLogprob: Float = 0.0,
compressionRatio: Float = 1.0,
noSpeechProb: Float = 0.0,
words: [WordTiming]? = nil
) {
self.id = id
self.seek = seek
self.start = start
self.end = end
self.text = text
self.tokens = tokens
self.tokenLogProbs = tokenLogProbs
self.temperature = temperature
self.avgLogprob = avgLogprob
self.compressionRatio = compressionRatio
self.noSpeechProb = noSpeechProb
self.words = words
}
}

public struct WordTiming: Hashable, Codable {
Expand All @@ -613,6 +698,19 @@ public struct WordTiming: Hashable, Codable {
public var start: Float
public var end: Float
public var probability: Float

/// Computed property for the duration of the word
public var duration: Float {
return end - start
}

public init(word: String, tokens: [Int], start: Float, end: Float, probability: Float) {
self.word = word
self.tokens = tokens
self.start = start
self.end = end
self.probability = probability
}
}

public struct TranscriptionProgress {
Expand Down Expand Up @@ -1198,17 +1296,40 @@ public struct SpecialTokens {
}
}

public protocol WhisperTokenizer: Tokenizer {
public protocol WhisperTokenizer {
/// swift-transformers pass through
func encode(text: String) -> [Int]
func decode(tokens: [Int]) -> String
func convertTokenToId(_ token: String) -> Int?
func convertIdToToken(_ id: Int) -> String?

/// WhisperKit specific
var specialTokens: SpecialTokens { get }
var allLanguageTokens: Set<Int> { get }

func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]])
}

struct WhisperTokenizerWrapper: WhisperTokenizer {
open class WhisperTokenizerWrapper: WhisperTokenizer {
let tokenizer: any Tokenizer
let specialTokens: SpecialTokens
let allLanguageTokens: Set<Int>
public let specialTokens: SpecialTokens
public let allLanguageTokens: Set<Int>

public func encode(text: String) -> [Int] {
tokenizer.encode(text: text)
}

public func decode(tokens: [Int]) -> String {
tokenizer.decode(tokens: tokens)
}

public func convertTokenToId(_ token: String) -> Int? {
tokenizer.convertTokenToId(token)
}

public func convertIdToToken(_ id: Int) -> String? {
tokenizer.convertIdToToken(id)
}

init(tokenizer: any Tokenizer) {
let specialTokens = SpecialTokens(
Expand Down Expand Up @@ -1300,7 +1421,7 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
/// Decodes token ids into individual words and per-word subtokens
/// - Parameter tokenIds: Array of tokens to decode and then split
/// - Returns: Tuple containing and array of the split words and all tokens for each word
func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
public func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
let decodedWords = tokenizer.decode(tokens: tokenIds.filter { $0 < specialTokens.specialTokenBegin })

// Detect language of input text
Expand All @@ -1316,52 +1437,6 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
}
}

extension WhisperTokenizerWrapper: Tokenizer {
func tokenize(text: String) -> [String] {
tokenizer.tokenize(text: text)
}

func encode(text: String) -> [Int] {
tokenizer.encode(text: text)
}

func decode(tokens: [Int]) -> String {
tokenizer.decode(tokens: tokens)
}

func convertTokenToId(_ token: String) -> Int? {
tokenizer.convertTokenToId(token)
}

func convertIdToToken(_ id: Int) -> String? {
tokenizer.convertIdToToken(id)
}

var bosToken: String? {
tokenizer.bosToken
}

var bosTokenId: Int? {
tokenizer.bosTokenId
}

var eosToken: String? {
tokenizer.eosToken
}

var eosTokenId: Int? {
tokenizer.eosTokenId
}

var unknownToken: String? {
tokenizer.unknownToken
}

var unknownTokenId: Int? {
tokenizer.unknownTokenId
}
}

extension WhisperTokenizerWrapper {
/// Default values for each token, using base vocab
static var defaultWhitespaceToken: Int { 220 }
Expand Down Expand Up @@ -1512,6 +1587,9 @@ public enum Constants {

public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models

public static let defaultPrependPunctuations: String = "\"'“¡¿([{-"
public static let defaultAppendPunctuations: String = "\"'.。,,!!??::”)]}、"

public static let fallbackModelSupportConfig: ModelSupportConfig = {
var config = ModelSupportConfig(
repoName: "whisperkit-coreml-fallback",
Expand Down
Loading