From 2bf82980e3d37e02f472d6f8937c0ea4c4f8efcb Mon Sep 17 00:00:00 2001
From: ZachNagengast <znagengast@gmail.com>
Date: Sun, 4 Feb 2024 23:31:05 -0800
Subject: [PATCH 1/3] Fixes and cleanup from early feedback

---
 .github/workflows/unit-tests.yml              |  1 +
 .../WhisperAX/Views/ContentView.swift         | 11 +++-
 Makefile                                      |  2 +-
 README.md                                     | 45 ++++++++++---
 Sources/WhisperKit/Core/AudioProcessor.swift  | 11 ++--
 Sources/WhisperKit/Core/TextDecoder.swift     |  8 ++-
 Sources/WhisperKit/Core/Utils.swift           | 63 +++++++++++++++++--
 Sources/WhisperKit/Core/WhisperKit.swift      |  2 +-
 Sources/WhisperKitCLI/transcribe.swift        | 17 +++--
 9 files changed, 126 insertions(+), 34 deletions(-)

diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index fe92120..672579c 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -39,5 +39,6 @@ jobs:
       run: xcodebuild build-for-testing -scheme whisperkit-Package -destination 'platform=macOS'
     - name: Run tests
       run: |
+        set -o pipefail
         xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
         xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination "platform=macOS,arch=arm64" | xcpretty
\ No newline at end of file
diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 32e5575..afbfe81 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -157,11 +157,11 @@ struct ContentView: View {
 #if os(macOS)
             selectedCategoryId = menu.first(where: { $0.name == selectedTab })?.id
 #endif
-
             fetchModels()
         }
     }
 
+
     // MARK: - Transcription
 
     var transcriptionView: some View {
@@ -169,7 +169,7 @@ struct ContentView: View {
             ScrollView(.horizontal) {
                 HStack(spacing: 1) {
                     let startIndex = max(bufferEnergy.count - 300, 0)
-                    ForEach(Array(bufferEnergy.enumerated())[startIndex...], id: \.offset) { index, energy in
+                    ForEach(Array(bufferEnergy.enumerated())[startIndex...], id: \.element) { index, energy in
                         ZStack {
                             RoundedRectangle(cornerRadius: 2)
                                 .frame(width: 2, height: CGFloat(energy) * 24)
@@ -660,7 +660,12 @@ struct ContentView: View {
         }
 
         localModels = WhisperKit.formatModelFiles(localModels)
-        availableModels = localModels
+        for model in localModels {
+            if !availableModels.contains(model),
+               !disabledModels.contains(model){
+                availableModels.append(model)
+            }
+        }
 
         print("Found locally: \(localModels)")
         print("Previously selected model: \(selectedModel)")
diff --git a/Makefile b/Makefile
index a1a8b56..2ad3451 100644
--- a/Makefile
+++ b/Makefile
@@ -49,7 +49,7 @@ setup-model-repo:
 		cd $(MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \
 	else \
 		echo "Repository not found, initializing..."; \
-		GIT_LFS_SKIP_SMUDGE=1 git clone https://hf.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
+		GIT_LFS_SKIP_SMUDGE=1 git clone  https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
 	fi
 
 # Download all models
diff --git a/README.md b/README.md
index 4a57f41..41c405e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,19 @@
+
+<div align="center">
+  
+<a href="https://github.com/argmaxinc/WhisperKit#gh-light-mode-only">
+  <img src="https://github.com/argmaxinc/WhisperKit/assets/1981179/6ac3360b-2f5c-4392-a71a-05c5dda71093" alt="WhisperKit" width="20%" />
+</a>
+
 # WhisperKit
 
-WhisperKit is a Swift package that integrates OpenAI's popular [Whisper](https://github.com/openai/whisper) speech recognition model with Apple's CoreML framework for efficient, local inference on Apple devices. 
+[![Unit Tests](https://github.com/argmaxinc/whisperkit/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/argmaxinc/whisperkit/actions/workflows/unit-tests.yml)
+[![Supported Swift Version](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/argmaxinc/WhisperKit) [![Supported Platforms](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fargmaxinc%2FWhisperKit%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/argmaxinc/WhisperKit)
+[![License](https://img.shields.io/github/license/argmaxinc/whisperkit?color=green)](LICENSE.md)
+
+</div>
+
+WhisperKit is a Swift package that integrates OpenAI's popular [Whisper](https://github.com/openai/whisper) speech recognition model with Apple's CoreML framework for efficient, local inference on Apple devices.
 
 Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZW).
 
@@ -21,13 +34,16 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ
 - [Citation](#citation)
 
 ## Installation
+
 WhisperKit can be integrated into your Swift project using the Swift Package Manager.
 
 ### Prerequisites
+
 - macOS 14.0 or later.
 - Xcode 15.0 or later.
 
 ### Steps
+
 1. Open your Swift project in Xcode.
 2. Navigate to `File` > `Add Package Dependencies...`.
 3. Enter the package repository URL: `https://github.com/argmaxinc/whisperkit`.
@@ -35,9 +51,11 @@ WhisperKit can be integrated into your Swift project using the Swift Package Man
 5. Click `Finish` to add WhisperKit to your project.
 
 ## Getting Started
+
 To get started with WhisperKit, you need to initialize it in your project.
 
 ### Quick Example
+
 This example demonstrates how to transcribe a local audio file:
 
 ```swift
@@ -52,7 +70,9 @@ Task {
 ```
 
 ### Model Selection
+
 WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name:
+
 ```swift
 let pipe = try? await WhisperKit(model: "large-v3")
 ```
@@ -76,18 +96,25 @@ git clone https://github.com/argmaxinc/whisperkit.git
 cd whisperkit
 ```
 
-Then, setup the environment and download the models.
+Then, setup the environment and download your desired model.
+
+```bash
+make setup
+make download-model MODEL=large-v3
+```
 
 **Note**:
-1. this will download all available models to your local folder, if you only want to download a specific model, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml))
-2. before running `download-models`, make sure [git-lfs](https://git-lfs.com) is installed
+
+1. This will download only the model specified by `MODEL` (see what's available in our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml), where we use the prefix `openai_whisper-{MODEL}`)
+2. Before running `download-model`, make sure [git-lfs](https://git-lfs.com) is installed
+
+If you would like download all available models to your local folder, use this command instead:
 
 ```bash
-make setup
 make download-models
 ```
 
-You can then run the CLI with:
+You can then run them via the CLI with:
 
 ```bash
 swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}" 
@@ -95,19 +122,21 @@ swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large
 
 Which should print a transcription of the audio file.
 
-
 ## Contributing & Roadmap
+
 Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future.
 
 ## License
+
 WhisperKit is released under the MIT License. See [LICENSE.md](LICENSE.md) for more details.
 
 ## Citation
+
 If you use WhisperKit for something cool or just find it useful, please drop us a note at [info@takeargmax.com](mailto:info@takeargmax.com)!
 
 If you use WhisperKit for academic work, here is the BibTeX:
 
-```
+```bibtex
 @misc{whisperkit-argmax,
    title = {WhisperKit},
    author = {Argmax, Inc.},
diff --git a/Sources/WhisperKit/Core/AudioProcessor.swift b/Sources/WhisperKit/Core/AudioProcessor.swift
index 48dc402..12ad63d 100644
--- a/Sources/WhisperKit/Core/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/AudioProcessor.swift
@@ -40,7 +40,7 @@ public protocol AudioProcessing {
     var relativeEnergyWindow: Int { get set }
 
     /// Starts recording audio from the specified input device, resetting the previous state
-    func startRecordingLive(from inputDevice: AVCaptureDevice?, callback: (([Float]) -> Void)?) throws
+    func startRecordingLive(callback: (([Float]) -> Void)?) throws
 
     /// Pause recording
     func pauseRecording()
@@ -53,7 +53,7 @@ public protocol AudioProcessing {
 public extension AudioProcessing {
     // Use default recording device
     func startRecordingLive(callback: (([Float]) -> Void)?) throws {
-        try startRecordingLive(from: nil, callback: callback)
+        try startRecordingLive(callback: callback)
     }
 
     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
@@ -382,14 +382,11 @@ public extension AudioProcessor {
         }
     }
 
-    func startRecordingLive(from inputDevice: AVCaptureDevice? = nil, callback: (([Float]) -> Void)? = nil) throws {
+    func startRecordingLive(callback: (([Float]) -> Void)? = nil) throws {
         audioSamples = []
         audioEnergy = []
 
-        if inputDevice != nil {
-            // TODO: implement selecting input device
-            Logging.debug("Input device selection not yet supported")
-        }
+        // TODO: implement selecting input device
 
         audioEngine = try setupEngine()
 
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index a7b12b7..289e15f 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -45,7 +45,7 @@ public protocol TextDecoding {
 
 @available(macOS 14, iOS 17, tvOS 14, watchOS 10, *)
 public extension TextDecoding {
-    func prepareDecoderInputs(withPrompt initialPrompt: [Int]) -> DecodingInputs {
+    func prepareDecoderInputs(withPrompt initialPrompt: [Int]) -> DecodingInputs? {
         let tokenShape = [NSNumber(value: 1), NSNumber(value: initialPrompt.count)]
 
         // Initialize MLMultiArray for tokens
@@ -59,11 +59,13 @@ public extension TextDecoding {
         }
 
         guard let kvCacheEmbedDim = self.kvCacheEmbedDim else {
-            fatalError("Unable to determine kvCacheEmbedDim")
+            Logging.error("Unable to determine kvCacheEmbedDim")
+            return nil
         }
 
         guard let kvCacheMaxSequenceLength = self.kvCacheMaxSequenceLength else {
-            fatalError("Unable to determine kvCacheMaxSequenceLength")
+            Logging.error("Unable to determine kvCacheMaxSequenceLength")
+            return nil
         }
 
         // Initialize each MLMultiArray
diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift
index 8dac7f4..fbccae8 100644
--- a/Sources/WhisperKit/Core/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils.swift
@@ -153,14 +153,65 @@ public func modelSupport(for deviceName: String) -> (default: String, disabled:
          let model where model.hasPrefix("iPhone16"): // A17
         return ("base", ["large-v3_turbo", "large-v3", "large-v2_turbo", "large-v2"])
 
-    // TODO: Disable turbo variants for M1
-    case let model where model.hasPrefix("arm64"): // Mac
-        return ("base", [""])
-
-    // Catch-all for unhandled models or macs
+    // Fall through to macOS checks
     default:
-        return ("base", [""])
+        break
+    }
+
+#if os(macOS)
+    if deviceName.hasPrefix("arm64") {
+        if Process.processor.contains("Apple M1") {
+            // Disable turbo variants for M1
+            return ("base", ["large-v3_turbo", "large-v3_turbo_1049MB", "large-v3_turbo_1307MB", "large-v2_turbo", "large-v2_turbo_1116MB", "large-v2_turbo_1430MB"])
+        } else {
+            // Enable all variants for M2 or M3, none disabled
+            return ("base", [])
+        }
+    }
+#endif
+    
+    // Unhandled device to base variant
+    return ("base", [""])
+}
+
+#if os(macOS)
+// From: https://stackoverflow.com/a/71726663
+extension Process {
+    static func stringFromTerminal(command: String) -> String {
+        let task = Process()
+        let pipe = Pipe()
+        task.standardOutput = pipe
+        task.launchPath = "/bin/bash"
+        task.arguments = ["-c", "sysctl -n " + command]
+        task.launch()
+        return String(bytes: pipe.fileHandleForReading.availableData, encoding: .utf8) ?? ""
     }
+    static let processor = stringFromTerminal(command: "machdep.cpu.brand_string")
+    static let cores = stringFromTerminal(command: "machdep.cpu.core_count")
+    static let threads = stringFromTerminal(command: "machdep.cpu.thread_count")
+    static let vendor = stringFromTerminal(command: "machdep.cpu.vendor")
+    static let family = stringFromTerminal(command: "machdep.cpu.family")
+}
+#endif
+
+public func resolveAbsolutePath(_ inputPath: String) -> String {
+    let fileManager = FileManager.default
+
+    // Expanding tilde if present
+    let pathWithTildeExpanded = NSString(string: inputPath).expandingTildeInPath
+
+    // If the path is already absolute, return it
+    if pathWithTildeExpanded.hasPrefix("/") {
+        return pathWithTildeExpanded
+    }
+
+    // Resolving relative path based on the current working directory
+    if let cwd = fileManager.currentDirectoryPath as String? {
+        let resolvedPath = URL(fileURLWithPath: cwd).appendingPathComponent(pathWithTildeExpanded).path
+        return resolvedPath
+    }
+
+    return inputPath
 }
 
 func loadTokenizer(for pretrained: ModelVariant) async throws -> Tokenizer {
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 3982315..62e2edb 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -141,7 +141,7 @@ public class WhisperKit {
             return (modelInfo + additionalInfo).trimmingFromEnd(character: "/", upto: 1)
         }
 
-        // Custom sorting order
+        // Sorting order based on enum
         let sizeOrder = ModelVariant.allCases.map { $0.description }
 
         let sortedModels = availableModels.sorted { firstModel, secondModel in
diff --git a/Sources/WhisperKitCLI/transcribe.swift b/Sources/WhisperKitCLI/transcribe.swift
index e2922c8..3ad862c 100644
--- a/Sources/WhisperKitCLI/transcribe.swift
+++ b/Sources/WhisperKitCLI/transcribe.swift
@@ -11,10 +11,10 @@ import WhisperKit
 @main
 struct WhisperKitCLI: AsyncParsableCommand {
     @Option(help: "Path to audio file")
-    var audioPath: String = "./Tests/WhisperKitTests/Resources/jfk.wav"
+    var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"
 
     @Option(help: "Path of model files")
-    var modelPath: String = "./Models/whisperkit-coreml/openai_whisper-tiny"
+    var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"
 
     @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
     var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
@@ -71,10 +71,17 @@ struct WhisperKitCLI: AsyncParsableCommand {
     var reportPath: String = "."
 
     func transcribe(audioPath: String, modelPath: String) async throws {
-        guard FileManager.default.fileExists(atPath: modelPath) else {
-            fatalError("Resource path does not exist \(modelPath)")
+        let resolvedModelPath = resolveAbsolutePath(modelPath)
+        guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
+            fatalError("Model path does not exist \(resolvedModelPath)")
         }
 
+        let resolvedAudioPath = resolveAbsolutePath(audioPath)
+        guard FileManager.default.fileExists(atPath: resolvedAudioPath) else {
+            fatalError("Resource path does not exist \(resolvedAudioPath)")
+        }
+
+
         let computeOptions = ModelComputeOptions(
             audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
             textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
@@ -104,7 +111,7 @@ struct WhisperKitCLI: AsyncParsableCommand {
             noSpeechThreshold: noSpeechThreshold
         )
 
-        let transcribeResult = try await whisperKit.transcribe(audioPath: audioPath, decodeOptions: options)
+        let transcribeResult = try await whisperKit.transcribe(audioPath: resolvedAudioPath, decodeOptions: options)
 
         let transcription = transcribeResult?.text ?? "Transcription failed"
 

From 90a9ebbcdd009ad138245bf8b299329e68c9d322 Mon Sep 17 00:00:00 2001
From: ZachNagengast <znagengast@gmail.com>
Date: Sun, 4 Feb 2024 23:34:36 -0800
Subject: [PATCH 2/3] Formatting

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 2ad3451..e329842 100644
--- a/Makefile
+++ b/Makefile
@@ -49,7 +49,7 @@ setup-model-repo:
 		cd $(MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \
 	else \
 		echo "Repository not found, initializing..."; \
-		GIT_LFS_SKIP_SMUDGE=1 git clone  https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
+		GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
 	fi
 
 # Download all models

From 7844745514c2b4d1835870a933aa83e1cf706f94 Mon Sep 17 00:00:00 2001
From: ZachNagengast <znagengast@gmail.com>
Date: Sun, 4 Feb 2024 23:49:36 -0800
Subject: [PATCH 3/3] Update tests

---
 .github/workflows/unit-tests.yml            |  2 +-
 Tests/WhisperKitTests/FunctionalTests.swift |  6 ++++++
 Tests/WhisperKitTests/UnitTests.swift       | 15 +++++++--------
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index 672579c..f53d934 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -36,7 +36,7 @@ jobs:
         path: Models
         key: ${{ runner.os }}-models
     - name: Build
-      run: xcodebuild build-for-testing -scheme whisperkit-Package -destination 'platform=macOS'
+      run: xcodebuild clean build-for-testing -scheme whisperkit-Package -destination 'platform=macOS'
     - name: Run tests
       run: |
         set -o pipefail
diff --git a/Tests/WhisperKitTests/FunctionalTests.swift b/Tests/WhisperKitTests/FunctionalTests.swift
index 9d16b05..a1d49fd 100644
--- a/Tests/WhisperKitTests/FunctionalTests.swift
+++ b/Tests/WhisperKitTests/FunctionalTests.swift
@@ -7,6 +7,12 @@ import XCTest
 
 @available(macOS 14, iOS 17, *)
 final class FunctionalTests: XCTestCase {
+    func testInitLarge() async {
+        let modelPath = largev3ModelPath()
+        let whisperKit = try? await WhisperKit(modelFolder: modelPath, logLevel: .error)
+        XCTAssertNotNil(whisperKit)
+    }
+
     func testOutputAll() async throws {
         let modelPaths = allModelPaths()
 
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index ccaaf9f..3691cc5 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -15,7 +15,7 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(whisperKit)
     }
 
-    // MARK: - Model Loading Tests
+    // MARK: - Model Loading Test
 
     func testInitTiny() async {
         let modelPath = tinyModelPath()
@@ -23,12 +23,6 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(whisperKit)
     }
 
-    func testInitLarge() async {
-        let modelPath = largev3ModelPath()
-        let whisperKit = try? await WhisperKit(modelFolder: modelPath, logLevel: .error)
-        XCTAssertNotNil(whisperKit)
-    }
-
     // MARK: - Audio Tests
 
     func testAudioFileLoading() {
@@ -161,7 +155,12 @@ final class UnitTests: XCTestCase {
         let decoderInputs = textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.startOfTranscriptToken])
         let expectedShape: Int = 1
 
-        let decoderOutput = try! await textDecoder.decodeText(from: encoderInput, using: decoderInputs, sampler: tokenSampler, options: decodingOptions)
+        guard let inputs = decoderInputs else {
+            XCTFail("Failed to prepare decoder inputs")
+            return
+        }
+
+        let decoderOutput = try! await textDecoder.decodeText(from: encoderInput, using: inputs, sampler: tokenSampler, options: decodingOptions)
         XCTAssertNotNil(decoderOutput, "Failed to decode text")
         XCTAssertEqual(decoderOutput.count, expectedShape, "Decoder output shape is not as expected")
     }