diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 255f357..b7951fd 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -20,7 +20,6 @@ jobs: - { name: "macOS", condition: true, - clean-destination: "generic/platform=macOS", test-destination: "platform=macOS,arch=arm64", test-cases: "-only-testing WhisperKitTests/UnitTests -only-testing WhisperKitMLXTests/MLXUnitTests", mlx-disabled: "0", @@ -29,7 +28,6 @@ jobs: - { name: "iOS", condition: true, - clean-destination: "generic/platform=iOS", test-destination: "platform=iOS Simulator,OS=${{ inputs.ios-version }},name=iPhone 15", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -38,7 +36,6 @@ jobs: - { name: "watchOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", - clean-destination: "generic/platform=watchOS", test-destination: "platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -47,7 +44,6 @@ jobs: - { name: "visionOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", - clean-destination: "generic/platform=visionOS", test-destination: "platform=visionOS Simulator,name=Apple Vision Pro", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -65,7 +61,7 @@ jobs: id: model-cache uses: actions/cache@v4 with: - path: Sources/WhisperKitTestsUtils/Models + path: Models key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' @@ -96,5 +92,8 @@ jobs: if: ${{ matrix.run-config['condition'] == true }} run: | set -o pipefail - xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation + xcodebuild clean build-for-testing test \ + ${{ matrix.run-config['test-cases'] }} \ + -scheme ${{ matrix.run-config['scheme'] }} \ + -destination '${{ matrix.run-config['test-destination'] }}' \ + -skipPackagePluginValidation diff --git a/.gitignore b/.gitignore index fd725dc..bb8893b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ DerivedData/ **/*.xcscheme .netrc .env +/.vscode # Core ML Model Files Models diff --git a/.swiftpm/configuration/Package.resolved b/.swiftpm/configuration/Package.resolved new file mode 100644 index 0000000..bb2ef99 --- /dev/null +++ b/.swiftpm/configuration/Package.resolved @@ -0,0 +1,41 @@ +{ + "pins" : [ + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" + } + }, + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", + "version" : "1.0.2" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers.git", + 
"state" : { + "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", + "version" : "0.1.7" + } + } + ], + "version" : 2 +} diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 6d2ac24..7ae7da8 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,13 +1,13 @@ { - "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", + "originHash" : "829222b514832cb61fe0002e0eebda98f23a75169c63f7d6ed7a320d57d5318f", "pins" : [ { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "branch" : "main", - "revision" : "c11212bff42a1b88aea83811210d42a5f99440ad" + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" } }, { diff --git a/Makefile b/Makefile index 309b74a..a1314dd 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,7 @@ setup-model-repo: git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \ fi + setup-mlx-model-repo: @echo "Setting up mlx repository..." @mkdir -p $(BASE_MODEL_DIR) @@ -109,21 +110,40 @@ download-mlx-model: setup-mlx-model-repo @echo "Downloading mlx model $(MODEL)..." @cd $(MLX_MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" - @echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-mlx-$(MODEL)" + @echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-$(MODEL)" + build: @echo "Building WhisperKit..." - @swift build -v + @xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \ + -configuration Release \ + -scheme whisperkit-Package \ + -destination generic/platform=macOS \ + -derivedDataPath .build/.xcodebuild/ \ + -clonedSourcePackagesDirPath .build/ \ + -skipPackagePluginValidation build-cli: @echo "Building WhisperKit CLI..." - @swift build -c release --product whisperkit-cli + @xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \ + -configuration Release \ + -scheme whisperkit-cli \ + -destination generic/platform=macOS \ + -derivedDataPath .build/.xcodebuild/ \ + -clonedSourcePackagesDirPath .build/ \ + -skipPackagePluginValidation test: @echo "Running tests..." - @swift test -v + @xcodebuild clean build-for-testing test \ + -scheme whisperkit-Package \ + -only-testing WhisperKitMLXTests/MLXUnitTests \ + -only-testing WhisperKitTests/UnitTests \ + -destination 'platform=macOS,arch=arm64' \ + -skipPackagePluginValidation + clean-package-caches: @trash ~/Library/Caches/org.swift.swiftpm/repositories diff --git a/Package.resolved b/Package.resolved index d6aa442..bb2ef99 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "branch" : "main", - "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119" + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" } }, { diff --git a/Package.swift b/Package.swift index 09b33f8..824bb0f 100644 --- a/Package.swift +++ b/Package.swift @@ -4,67 +4,26 @@ import PackageDescription import Foundation -// NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. 
-// To manualy build for `watchOS` use: -// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator' -skipPackagePluginValidation` let package = Package( name: "whisperkit", platforms: [ .iOS(.v16), - .macOS("13.3") + .macOS("13.3"), + .watchOS(.v10) ], - products: products() + mlxProducts(), - dependencies: dependencies() + mlxDependencies(), - targets: targets() + mlxTargets() -) - -func products() -> [PackageDescription.Product] { - return [ + products: [ .library( name: "WhisperKit", targets: ["WhisperKit"] - ) - ] -} - -func mlxProducts() -> [PackageDescription.Product] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .library( - name: "WhisperKitMLX", - targets: ["WhisperKitMLX"] - ), - .executable( - name: "whisperkit-cli", - targets: ["WhisperKitCLI"] - ), - ] - } -} - -func dependencies() -> [PackageDescription.Package.Dependency] { - return [ + ), + ] + + cliProducts() + + mlxProducts(), + dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), - ] -} - -func mlxDependencies() -> [PackageDescription.Package.Dependency] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .package(url: "https://github.com/ml-explore/mlx-swift", branch: "main"), - ] - } -} - -func targets() -> [PackageDescription.Target] { - return [ + ] + mlxDependencies(), + targets: [ .target( name: "WhisperKit", dependencies: [ @@ -103,46 +62,89 @@ func targets() -> [PackageDescription.Target] { .product(name: "Transformers", package: "swift-transformers"), ] ) + ] + + cliTargets() + + mlxTargets() +) + +// MARK: - MLX Helper Functions + +// CLI +func cliProducts() -> [Product] { + guard !isMLXDisabled() else { return [] } + return [ + .executable( + name: "whisperkit-cli", + targets: ["WhisperKitCLI"] + ), + ] +} + +func cliTargets() -> [Target] { + guard !isMLXDisabled() else { return [] } + return [ + .executableTarget( + name: "WhisperKitCLI", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), ] } -func mlxTargets() -> [PackageDescription.Target] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .executableTarget( - name: "WhisperKitCLI", - dependencies: [ - "WhisperKit", - "WhisperKitMLX", - .product(name: "ArgumentParser", package: "swift-argument-parser"), - ] - ), - .target( - name: "WhisperKitMLX", - dependencies: [ - "WhisperKit", - .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXFFT", package: "mlx-swift"), - .product(name: "MLXNN", package: "mlx-swift") - ], - path: "Sources/WhisperKit/MLX", - resources: [ - .copy("Resources/mel_filters_80.npy"), - .copy("Resources/mel_filters_128.npy") - ] - ), - .testTarget( - name: "WhisperKitMLXTests", - dependencies: [ - "WhisperKit", - "WhisperKitMLX", - "WhisperKitTestsUtils", - .product(name: "Transformers", package: "swift-transformers"), - ] - ) - ] - } +// MLX +func mlxProducts() -> [Product] { + guard !isMLXDisabled() else { return [] } + return [ + .library( + name: "WhisperKitMLX", + targets: ["WhisperKitMLX"] + ), + ] +} + +func mlxDependencies() -> [Package.Dependency] { + guard 
!isMLXDisabled() else { return [] }
+    return [
+        .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.16.0"),
+    ]
+}
+
+func mlxTargets() -> [Target] {
+    guard !isMLXDisabled() else { return [] }
+    return [
+        .target(
+            name: "WhisperKitMLX",
+            dependencies: [
+                "WhisperKit",
+                .product(name: "MLX", package: "mlx-swift"),
+                .product(name: "MLXFFT", package: "mlx-swift"),
+                .product(name: "MLXNN", package: "mlx-swift")
+            ],
+            path: "Sources/WhisperKit/MLX",
+            resources: [
+                .copy("Resources/mel_filters_80.npy"),
+                .copy("Resources/mel_filters_128.npy")
+            ]
+        ),
+        .testTarget(
+            name: "WhisperKitMLXTests",
+            dependencies: [
+                "WhisperKit",
+                "WhisperKitMLX",
+                "WhisperKitTestsUtils",
+                .product(name: "Transformers", package: "swift-transformers"),
+            ]
+        )
+    ]
+}
+
+// NOTE: `MLX` doesn't support `watchOS` yet, which is why we control the build using the `MLX_DISABLED` environment variable.
+// To manually build for `watchOS` use:
+// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)' -skipPackagePluginValidation`
+
+func isMLXDisabled() -> Bool {
+    ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1"
+}
diff --git a/README.md b/README.md
index a4d5f5c..242087c 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,8 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ
 - [Model Selection](#model-selection)
   - [Generating Models](#generating-models)
   - [Swift CLI](#swift-cli)
+  - [Backend Selection](#backend-selection)
+  - [Testing](#testing)
 - [Contributing \& Roadmap](#contributing--roadmap)
 - [License](#license)
 - [Citation](#citation)
@@ -66,7 +68,7 @@ You can install `WhisperKit` command line app using [Homebrew](https://brew.sh)
 
 ```bash
 brew install whisperkit-cli
-```
+```
 
 ## Getting Started
 
@@ -79,38 +81,51 @@ This example demonstrates how to transcribe a local audio file:
 
 ```swift
 import WhisperKit
 
-// Initialize WhisperKit with default settings
-Task {
-    let pipe = try? await WhisperKit()
-    let transcription = try? await pipe!.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
-    print(transcription)
-}
+// Initialize WhisperKit by passing the model name (WhisperKit will automatically download it):
+let pipe = try await WhisperKit(model: "tiny")
+// Transcribe the audio file
+let transcription = try await pipe.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
+// Print the transcription
+print(transcription)
 ```
 
 ### Model Selection
 
-WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name:
+You have to specify the model by passing the model name:
 
 ```swift
-let pipe = try? await WhisperKit(model: "large-v3")
+let pipe = try await WhisperKit(model: "large-v3")
 ```
 
 This method also supports glob search, so you can use wildcards to select a model:
 
 ```swift
-let pipe = try? await WhisperKit(model: "distil*large-v3")
+let pipe = try await WhisperKit(model: "distil*large-v3")
 ```
 
 Note that the model search must return a single model from the source repo, otherwise an error will be thrown.
 
 For a list of available models, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml).
+For MLX models, see [here](https://huggingface.co/argmaxinc/whisperkit-mlx).
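+
+To load an MLX model instead, a minimal sketch (the `mlxModel` parameter and the `MLX*` components come from this change; see [Backend Selection](#backend-selection) below):
+
+```swift
+let pipe = try await WhisperKit(
+    mlxModel: "tiny",
+    featureExtractor: MLXFeatureExtractor(),
+    audioEncoder: MLXAudioEncoder(),
+    textDecoder: MLXTextDecoder()
+)
+```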
+
+If you want to get the recommended model for your device, you can use the following method:
+
+```swift
+print(WhisperKit.recommendedModels())
+```
+
+It should print the default model and a list of disabled models, e.g.:
+
+```bash
+(default: "openai_whisper-base", disabled: ["openai_whisper-large-v2_turbo", "openai_whisper-large-v2_turbo_955MB", "openai_whisper-large-v3_turbo", "openai_whisper-large-v3_turbo_954MB", "distil-whisper_distil-large-v3_turbo_600MB", "distil-whisper_distil-large-v3_turbo"])
+```
 
 ### Generating Models
 
 WhisperKit also comes with the supporting repo [`whisperkittools`](https://github.com/argmaxinc/whisperkittools) which lets you create and deploy your own fine tuned versions of Whisper in CoreML format to HuggingFace. Once generated, they can be loaded by simply changing the repo name to the one used to upload the model:
 
 ```swift
-let pipe = try? await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo")
+let pipe = try await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo")
 ```
 
 ### Swift CLI
@@ -152,6 +167,53 @@ Which should print a transcription of the audio file. If you would like to stream the audio live from microphone, you can use the `--stream` flag:
 
 ```bash
 swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream
 ```
 
+### Backend Selection
+
+WhisperKit supports both CoreML and MLX backends. By default, it uses CoreML, but you can switch some or all pipeline components to MLX.
+Available pipeline components are:
+- `featureExtractor`: `FeatureExtractor` is used by default; use `MLXFeatureExtractor` to switch to MLX
+- `audioEncoder`: `AudioEncoder` is used by default; use `MLXAudioEncoder` to switch to MLX
+- `textDecoder`: `TextDecoder` is used by default; use `MLXTextDecoder` to switch to MLX
+
+Here is an example of how to switch the `featureExtractor` and `audioEncoder` to MLX and keep the `textDecoder` as CoreML:
+
+```swift
+let pipe = try await WhisperKit(
+    model: "tiny",
+    mlxModel: "tiny",
+    featureExtractor: MLXFeatureExtractor(),
+    audioEncoder: MLXAudioEncoder()
+)
+```
+
+**Note**:
+
+The `swift run` and `swift test` commands won't work when the `mlx` backend is selected:
+SwiftPM cannot build the Metal shaders from the command line, so the final build has to be done via Xcode.
+
+### Testing
+
+If you want to run the unit tests locally, first clone the repo:
+
+```bash
+git clone https://github.com/argmaxinc/whisperkit.git
+cd whisperkit
+```
+
+then download the required models:
+
+```bash
+make setup
+make download-model MODEL=tiny
+make download-mlx-model MODEL=tiny
+```
+
+and then run the tests:
+
+```bash
+make test
+```
+
 ## Contributing & Roadmap
 
 Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future.
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index f2bd27d..250f895 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -482,6 +482,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel { return getModelInputDimention(model, named: "encoder_output_embeds", position: 1) } + public init() {} + /// Override default so we an unload the prefill data as well public func unloadModel() { model = nil diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index f4c7633..d8d759d 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -44,7 +44,7 @@ extension MLMultiArray { /// - index: The index of the element /// - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes. @inline(__always) - func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { + public func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { var linearOffset = 0 let strideInts = strideInts ?? strides.map { $0.intValue } for (dimension, stride) in zip(index, strideInts) { diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 0841c02..a9bb407 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -38,14 +38,18 @@ open class WhisperKit { /// Configuration public var modelFolder: URL? + public var mlxModelFolder: URL? public var tokenizerFolder: URL? public let useBackgroundDownloadSession: Bool public init( model: String? = nil, + mlxModel: String? = nil, downloadBase: URL? = nil, - modelRepo: String? = nil, + modelRepo: String = "argmaxinc/whisperkit-coreml", modelFolder: String? = nil, + mlxModelRepo: String = "argmaxinc/whisperkit-mlx", + mlxModelFolder: String? = nil, tokenizerFolder: URL? = nil, computeOptions: ModelComputeOptions? = nil, audioProcessor: (any AudioProcessing)? = nil, @@ -75,9 +79,12 @@ open class WhisperKit { try await setupModels( model: model, + mlxModel: mlxModel, downloadBase: downloadBase, modelRepo: modelRepo, modelFolder: modelFolder, + mlxModelRepo: mlxModelRepo, + mlxModelFolder: mlxModelFolder, download: download ) @@ -214,30 +221,56 @@ open class WhisperKit { /// Sets up the model folder either from a local path or by downloading from a repository. public func setupModels( model: String?, + mlxModel: String? = nil, downloadBase: URL? = nil, - modelRepo: String?, - modelFolder: String?, + modelRepo: String = "argmaxinc/whisperkit-coreml", + modelFolder: String? = nil, + mlxModelRepo: String = "argmaxinc/whisperkit-mlx", + mlxModelFolder: String? = nil, download: Bool ) async throws { - // Determine the model variant to use - let modelVariant = model ?? WhisperKit.recommendedModels().default + // If no model is provided, use the recommended model + var modelVariant = model + if model == nil, mlxModel == nil, mlxModelFolder == nil { + // Determine the model variant to use by default + modelVariant = WhisperKit.recommendedModels().default + } // If a local model folder is provided, use it; otherwise, download the model - if let folder = modelFolder { - self.modelFolder = URL(fileURLWithPath: folder) - } else if download { - let repo = modelRepo ?? 
"argmaxinc/whisperkit-coreml" + if let modelFolder { + self.modelFolder = URL(fileURLWithPath: modelFolder) + } else if download, let modelVariant { do { self.modelFolder = try await Self.download( variant: modelVariant, downloadBase: downloadBase, useBackgroundSession: useBackgroundDownloadSession, - from: repo + from: modelRepo + ) + } catch { + // Handle errors related to model downloading + throw WhisperError.modelsUnavailable(""" + CoreML Model not found. Please check the model or repo name and try again. + Error: \(error) + """) + } + } + + // Same for MLX + if let mlxModelFolder { + self.mlxModelFolder = URL(fileURLWithPath: mlxModelFolder) + } else if download, let mlxModel { + do { + self.mlxModelFolder = try await Self.download( + variant: mlxModel, + downloadBase: downloadBase, + useBackgroundSession: useBackgroundDownloadSession, + from: mlxModelRepo ) } catch { // Handle errors related to model downloading throw WhisperError.modelsUnavailable(""" - Model not found. Please check the model or repo name and try again. + MLX Model not found. Please check the model or repo name and try again. Error: \(error) """) } @@ -251,40 +284,37 @@ open class WhisperKit { public func loadModels( prewarmMode: Bool = false ) async throws { - modelState = prewarmMode ? .prewarming : .loading + assert(modelFolder != nil || mlxModelFolder != nil, "Please specify `modelFolder` or `mlxModelFolder`") + modelState = prewarmMode ? .prewarming : .loading let modelLoadStart = CFAbsoluteTimeGetCurrent() - guard let path = modelFolder else { - throw WhisperError.modelsUnavailable("Model folder is not set.") - } + Logging.debug("Loading models with prewarmMode: \(prewarmMode)") - Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)") - - if let featureExtractor = featureExtractor as? WhisperMLModel { - Logging.debug("Loading feature extractor") + if let path = modelFolder, let featureExtractor = featureExtractor as? WhisperMLModel { + Logging.debug("Loading feature extractor from \(path.path)") try await featureExtractor.loadModel( at: path.appending(path: "MelSpectrogram.mlmodelc"), computeUnits: modelCompute.melCompute, // hardcoded to use GPU prewarmMode: prewarmMode ) Logging.debug("Loaded feature extractor") - } else if let featureExtractor = featureExtractor as? WhisperMLXModel { - Logging.debug("Loading MLX feature extractor") + } else if let path = mlxModelFolder, let featureExtractor = featureExtractor as? WhisperMLXModel { + Logging.debug("Loading MLX feature extractor from \(path.path)") try await featureExtractor.loadModel(at: path, configPath: path) Logging.debug("Loaded MLX feature extractor") } - if let audioEncoder = audioEncoder as? WhisperMLModel { - Logging.debug("Loading audio encoder") + if let path = modelFolder, let audioEncoder = audioEncoder as? WhisperMLModel { + Logging.debug("Loading audio encoder from \(path.path)") try await audioEncoder.loadModel( at: path.appending(path: "AudioEncoder.mlmodelc"), computeUnits: modelCompute.audioEncoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded audio encoder") - } else if let audioEncoder = audioEncoder as? WhisperMLXModel { - Logging.debug("Loading MLX audio encoder") + } else if let path = mlxModelFolder, let audioEncoder = audioEncoder as? 
WhisperMLXModel { + Logging.debug("Loading MLX audio encoder from \(path.path)") try await audioEncoder.loadModel( at: path.appending(path: "encoder.safetensors"), configPath: path.appending(path: "config.json") @@ -292,16 +322,16 @@ open class WhisperKit { Logging.debug("Loaded MLX audio encoder") } - if let textDecoder = textDecoder as? WhisperMLModel { - Logging.debug("Loading text decoder") + if let path = modelFolder, let textDecoder = textDecoder as? WhisperMLModel { + Logging.debug("Loading text decoder from \(path.path)") try await textDecoder.loadModel( at: path.appending(path: "TextDecoder.mlmodelc"), computeUnits: modelCompute.textDecoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded text decoder") - } else if let textDecoder = textDecoder as? WhisperMLXModel { - Logging.debug("Loading MLX text decoder") + } else if let path = mlxModelFolder, let textDecoder = textDecoder as? WhisperMLXModel { + Logging.debug("Loading MLX text decoder from \(path.path)") try await textDecoder.loadModel( at: path.appending(path: "decoder.safetensors"), configPath: path.appending(path: "config.json") @@ -309,16 +339,18 @@ open class WhisperKit { Logging.debug("Loaded MLX text decoder") } - let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") - if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { - Logging.debug("Loading text decoder prefill data") - textDecoder.prefillData = TextDecoderContextPrefill() - try await textDecoder.prefillData?.loadModel( - at: decoderPrefillUrl, - computeUnits: modelCompute.prefillCompute, - prewarmMode: prewarmMode - ) - Logging.debug("Loaded text decoder prefill data") + if let path = modelFolder { + let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") + if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { + Logging.debug("Loading text decoder prefill data") + textDecoder.prefillData = TextDecoderContextPrefill() + try await textDecoder.prefillData?.loadModel( + at: decoderPrefillUrl, + computeUnits: modelCompute.prefillCompute, + prewarmMode: prewarmMode + ) + Logging.debug("Loaded text decoder prefill data") + } } if prewarmMode { diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 29aecc9..935bbf0 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -39,6 +39,11 @@ open class MLXFeatureExtractor: FeatureExtracting { } } +extension MLXFeatureExtractor: WhisperMLXModel { + public func loadModel(at modelPath: URL, configPath: URL) async throws {} + public func unloadModel() {} +} + public extension MLXFeatureExtractor { /// Return the Hanning window. 
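+    /// For example, `hanning(5)` should match `numpy.hanning(5)`, i.e. `[0.0, 0.5, 1.0, 0.5, 0.0]`.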
/// Taken from [numpy](https://numpy.org/doc/stable/reference/generated/numpy.hanning.html) implementation
@@ -103,9 +108,6 @@ public extension MLXFeatureExtractor {
         nFFT: Int = 400,
         hopLength: Int = 160
     ) -> MLXArray {
-        let device = MLX.Device.defaultDevice()
-        MLX.Device.setDefault(device: .cpu)
-        defer { MLX.Device.setDefault(device: device) }
         let window = hanning(nFFT)
         let freqs = stft(audio, window: window, nPerSeg: nFFT, nOverlap: hopLength)
         let magnitudes = freqs[..<(-1)].abs().square()
diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift
index 2661c40..d98b9a1 100644
--- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift
+++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift
@@ -13,7 +13,7 @@ public final class MLXTextDecoder: TextDecoding {
     public var isModelMultilingual: Bool = false
     public let supportsWordTimestamps: Bool = false
     public var logitsSize: Int? {
-        decoder?.nState
+        decoder?.nVocab
     }
 
     public var kvCacheEmbedDim: Int? {
diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift
index 5059ac1..52a33db 100644
--- a/Sources/WhisperKit/MLX/MLXUtils.swift
+++ b/Sources/WhisperKit/MLX/MLXUtils.swift
@@ -35,6 +35,19 @@ extension MLXArray {
     }
 }
 
+extension MLXArray {
+    var contiguousStrides: [Int] {
+        var contiguousStrides = [1]
+        var stride = 1
+        for dimension in shape.dropFirst().reversed() {
+            stride = stride * dimension
+            contiguousStrides.append(stride)
+        }
+        contiguousStrides.reverse()
+        return contiguousStrides
+    }
+}
+
 extension MLXArray {
     func asMLMultiArray() throws -> MLMultiArray {
         let dataType = multiArrayDataType()
@@ -45,11 +58,12 @@ extension MLXArray {
             let destination = UnsafeMutableRawBufferPointer(start: buffer, count: nbytes)
             ptr.copyBytes(to: destination)
         }
+        // `contiguousStrides` has to be used here, see the [discussion](https://github.com/ml-explore/mlx-swift/issues/117)
         return try MLMultiArray(
             dataPointer: buffer,
             shape: shape.map { NSNumber(value: $0) },
             dataType: dataType,
-            strides: strides.map { NSNumber(value: $0) },
+            strides: contiguousStrides.map { NSNumber(value: $0) },
             deallocator: { $0.deallocate() }
         )
     }
diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift
index b76439b..1ca5e0e 100644
--- a/Sources/WhisperKitCLI/CLIArguments.swift
+++ b/Sources/WhisperKitCLI/CLIArguments.swift
@@ -3,6 +3,11 @@
 import ArgumentParser
 
+enum ModelType: String, Decodable, ExpressibleByArgument {
+    case coreML = "coreml"
+    case mlx = "mlx"
+}
+
 struct CLIArguments: ParsableArguments {
     @Option(help: "Paths to audio files")
     var audioPath = [String]()
@@ -16,15 +21,33 @@ struct CLIArguments: ParsableArguments {
     @Option(help: "Model to download if no modelPath is provided")
     var model: String?
 
+    @Option(help: "Path of MLX model files")
+    var mlxModelPath: String?
+
+    @Option(help: "MLX model to download if no mlxModelPath is provided")
+    var mlxModel: String?
+
     @Option(help: "Text to add in front of the model name to specify between different types of the same variant (values: \"openai\", \"distil\")")
     var modelPrefix: String = "openai"
 
+    @Option(help: "Text to add in front of the mlx model name to specify between different types of the same variant (values: \"openai\")")
+    var mlxModelPrefix: String = "openai"
+
     @Option(help: "Path to save the downloaded model")
     var downloadModelPath: String?
 
     @Option(help: "Path to save the downloaded tokenizer files")
     var downloadTokenizerPath: String?
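+
+    // Example invocation mixing backends (a hypothetical sketch; the long option names
+    // assume swift-argument-parser's default kebab-case mapping of the properties below):
+    //
+    //   whisperkit-cli transcribe --audio-path audio.wav \
+    //     --model tiny --mlx-model tiny \
+    //     --feature-extractor-type mlx --audio-encoder-type mlx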
+ @Option(help: "Which feature extractor to use (supported: `coreml` and `mlx`)") + var featureExtractorType: ModelType = .coreML + + @Option(help: "Which audio encoder to use (supported: `coreml` and `mlx`)") + var audioEncoderType: ModelType = .coreML + + @Option(help: "Which text decoder to use (supported: `coreml` and `mlx`)") + var textDecoderType: ModelType = .coreML + @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index 172e423..cae3c33 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -305,12 +305,58 @@ struct TranscribeCLI: AsyncParsableCommand { nil } + let mlxModelName: String? = + if let modelVariant = cliArguments.mlxModel { + cliArguments.mlxModelPrefix + "*" + modelVariant + } else { + nil + } + + var featureExtractorType = cliArguments.featureExtractorType + var audioEncoderType = cliArguments.featureExtractorType + var textDecoderType = cliArguments.featureExtractorType + + if modelName == nil, mlxModelName != nil { + // CoreML model not provided, default to MLX + featureExtractorType = .mlx + audioEncoderType = .mlx + textDecoderType = .mlx + } + + let featureExtractor: FeatureExtracting = + switch featureExtractorType { + case .coreML: + FeatureExtractor() + case .mlx: + MLXFeatureExtractor() + } + + let audioEncoder: AudioEncoding = + switch audioEncoderType { + case .coreML: + AudioEncoder() + case .mlx: + MLXAudioEncoder() + } + + let textDecoder: TextDecoding = + switch textDecoderType { + case .coreML: + TextDecoder() + case .mlx: + MLXTextDecoder() + } + return try await WhisperKit( model: modelName, + mlxModel: mlxModelName, downloadBase: downloadModelFolder, modelFolder: cliArguments.modelPath, tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, + featureExtractor: featureExtractor, + audioEncoder: audioEncoder, + textDecoder: textDecoder, verbose: cliArguments.verbose, logLevel: .debug, load: true, diff --git a/Sources/WhisperKitTestsUtils/TestUtils.swift b/Sources/WhisperKitTestsUtils/TestUtils.swift index 7bfae95..4271313 100644 --- a/Sources/WhisperKitTestsUtils/TestUtils.swift +++ b/Sources/WhisperKitTestsUtils/TestUtils.swift @@ -1,7 +1,7 @@ import CoreML import Combine import Foundation -@testable import WhisperKit +import WhisperKit import XCTest public enum TestError: Error { @@ -133,7 +133,8 @@ public extension MLMultiArray { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension XCTestCase { func transcribe( - modelPath: String, + modelPath: String? = nil, + mlxModelPath: String? 
= nil, options: DecodingOptions, callback: TranscriptionCallback = nil, audioFile: String = "jfk.wav", @@ -151,6 +152,7 @@ public extension XCTestCase { ) let whisperKit = try await WhisperKit( modelFolder: modelPath, + mlxModelFolder: mlxModelPath, computeOptions: computeOptions, featureExtractor: featureExtractor, audioEncoder: audioEncoder, @@ -170,7 +172,7 @@ public extension XCTestCase { func tinyModelPath() throws -> String { let modelDir = "whisperkit-coreml/openai_whisper-tiny" guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "mlmodelc", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { - throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-models`") + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-model MODEL=tiny`") } return modelPath } @@ -178,7 +180,7 @@ public extension XCTestCase { func tinyMLXModelPath() throws -> String { let modelDir = "whisperkit-mlx/openai_whisper-tiny" guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "safetensors", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { - throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-mlx-models`") + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-mlx-model MODEL=tiny`") } return modelPath } diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 2368b3b..c51a03e 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -154,7 +154,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -173,7 +173,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -189,7 +189,7 @@ final class MLXUnitTests: XCTestCase { func testDetectSpanish() async throws { let targetLanguage = "es" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath, + mlxModelFolder: tinyModelPath, featureExtractor: MLXFeatureExtractor(), audioEncoder: MLXAudioEncoder(), textDecoder: MLXTextDecoder(), @@ -215,7 +215,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -234,7 +234,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -250,7 +250,7 @@ final class MLXUnitTests: XCTestCase { func testDetectJapanese() async throws { let targetLanguage = "ja" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath, + mlxModelFolder: tinyModelPath, featureExtractor: MLXFeatureExtractor(), audioEncoder: MLXAudioEncoder(), textDecoder: MLXTextDecoder(), @@ -283,7 +283,7 @@ 
final class MLXUnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -312,19 +312,35 @@ final class MLXUnitTests: XCTestCase { // MARK: - Utils Tests + func testContiguousStrides() { + let count = 24 + let arr1 = MLXArray(0..