Add VoiceActivityDetector base class #199

Merged: 2 commits, Sep 5, 2024
Sources/WhisperKit/Core/AudioChunker.swift (5 changes: 3 additions & 2 deletions)

@@ -46,10 +46,11 @@ public extension AudioChunking
 open class VADAudioChunker: AudioChunking {
     /// prevent hallucinations at the end of the clip by stopping up to 1.0s early
     private let windowPadding: Int
-    private let vad = EnergyVAD()
+    private let vad: VoiceActivityDetector

-    init(windowPadding: Int = 16000) {
+    init(windowPadding: Int = 16000, vad: VoiceActivityDetector = EnergyVAD()) {
         self.windowPadding = windowPadding
+        self.vad = vad
     }

     private func splitOnMiddleOfLongestSilence(audioArray: [Float], startIndex: Int, endIndex: Int) -> Int {
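The practical effect of this change: VADAudioChunker now takes its detector by injection instead of hard-coding EnergyVAD, and the default behavior is unchanged. A minimal usage sketch, assuming code inside the WhisperKit module (the initializers are internal in this diff) and an illustrative threshold value:

    // Default: the chunker still falls back to EnergyVAD.
    let defaultChunker = VADAudioChunker()

    // Any VoiceActivityDetector subclass can now be injected, e.g. an
    // EnergyVAD with a stricter threshold (0.05 is illustrative only).
    let customChunker = VADAudioChunker(
        windowPadding: 16000,
        vad: EnergyVAD(energyThreshold: 0.05)
    )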
Sources/WhisperKit/Core/VAD/EnergyVAD.swift (new file, 58 additions)

@@ -0,0 +1,58 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Foundation

/// Voice activity detection based on an energy threshold
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class EnergyVAD: VoiceActivityDetector {
    var energyThreshold: Float

    /// Initialize a new EnergyVAD instance
    /// - Parameters:
    ///   - sampleRate: Audio sample rate
    ///   - frameLength: Frame length in seconds
    ///   - frameOverlap: Frame overlap in seconds. Each frame is extended by `frameOverlap` seconds of the following audio, which helps catch speech that starts exactly at a chunk boundary
    ///   - energyThreshold: Minimum energy threshold for detecting voice activity
    convenience init(
        sampleRate: Int = WhisperKit.sampleRate,
        frameLength: Float = 0.1,
        frameOverlap: Float = 0.0,
        energyThreshold: Float = 0.02
    ) {
        self.init(
            sampleRate: sampleRate,
            // Compute frame length and overlap in number of samples
            frameLengthSamples: Int(frameLength * Float(sampleRate)),
            frameOverlapSamples: Int(frameOverlap * Float(sampleRate)),
            energyThreshold: energyThreshold
        )
    }

    required init(
        sampleRate: Int = 16000,
        frameLengthSamples: Int,
        frameOverlapSamples: Int = 0,
        energyThreshold: Float = 0.02
    ) {
        self.energyThreshold = energyThreshold
        super.init(sampleRate: sampleRate, frameLengthSamples: frameLengthSamples, frameOverlapSamples: frameOverlapSamples)
    }

    override func voiceActivity(in waveform: [Float]) -> [Bool] {
        let chunkRatio = Double(waveform.count) / Double(frameLengthSamples)

        // Round up if uneven; the final chunk will not be a full `frameLengthSamples` long
        let count = Int(chunkRatio.rounded(.up))

        let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks(
            of: waveform,
            chunkCount: count,
            frameLengthSamples: frameLengthSamples,
            frameOverlapSamples: frameOverlapSamples,
            energyThreshold: energyThreshold
        )

        return chunkedVoiceActivity
    }
}
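Taken together with the base class below, EnergyVAD now contributes only the energy threshold and the frame-energy computation. A quick sketch of calling it directly; the silent `samples` buffer is a placeholder, and the parameter values are the defaults from this file:

    // One second of silence at 16 kHz as placeholder input.
    let samples = [Float](repeating: 0, count: 16000)

    // Defaults: 0.1 s frames, no overlap, 0.02 energy threshold.
    let vad = EnergyVAD()

    // 16000 samples at 1600 samples per frame -> 10 Bool values, one per frame.
    let activity = vad.voiceActivity(in: samples)

    // Inherited helper maps a frame index back to a sample offset.
    let firstActiveSample = activity.firstIndex(of: true)
        .map { vad.voiceActivityIndexToAudioSampleIndex($0) }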
Sources/WhisperKit/Core/VAD/VoiceActivityDetector.swift (previously EnergyVAD.swift)

@@ -1,67 +1,47 @@
 // For licensing see accompanying LICENSE.md file.
 // Copyright © 2024 Argmax, Inc. All rights reserved.

 import Accelerate
 import Foundation

-/// Voice activity detection based on energy threshold
+/// A base class for Voice Activity Detection (VAD), used to identify and separate segments of audio that contain human speech from those that do not.
+/// Subclasses must implement the `voiceActivity(in:)` method to provide specific voice activity detection functionality.
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-final class EnergyVAD {
+class VoiceActivityDetector {
+    /// The sample rate of the audio signal, in samples per second.
     var sampleRate: Int
+
+    /// The length of each frame in samples.
     var frameLengthSamples: Int
+
+    /// The number of samples overlapping between consecutive frames.
     var frameOverlapSamples: Int
-    var energyThreshold: Float

-    /// Initialize a new EnergyVAD instance
+    /// Initializes a new `VoiceActivityDetector` instance with the specified parameters.
     /// - Parameters:
-    ///   - sampleRate: Audio sample rate
-    ///   - frameLength: Frame length in seconds
-    ///   - frameOverlap: frame overlap in seconds, this will include `frameOverlap` length audio into the `frameLength` and is helpful to catch audio that starts exactly at chunk boundaries
-    ///   - energyThreshold: minimal energy threshold
-    convenience init(
-        sampleRate: Int = WhisperKit.sampleRate,
-        frameLength: Float = 0.1,
-        frameOverlap: Float = 0.0,
-        energyThreshold: Float = 0.02
-    ) {
-        self.init(
-            sampleRate: sampleRate,
-            // Compute frame length and overlap in number of samples
-            frameLengthSamples: Int(frameLength * Float(sampleRate)),
-            frameOverlapSamples: Int(frameOverlap * Float(sampleRate)),
-            energyThreshold: energyThreshold
-        )
-    }
-
-    required init(
+    ///   - sampleRate: The sample rate of the audio signal in samples per second. Defaults to 16000.
+    ///   - frameLengthSamples: The length of each frame in samples.
+    ///   - frameOverlapSamples: The number of samples overlapping between consecutive frames. Defaults to 0.
+    /// - Note: Subclasses should override the `voiceActivity(in:)` method to provide specific VAD functionality.
+    init(
         sampleRate: Int = 16000,
         frameLengthSamples: Int,
-        frameOverlapSamples: Int = 0,
-        energyThreshold: Float = 0.02
+        frameOverlapSamples: Int = 0
     ) {
         self.sampleRate = sampleRate
         self.frameLengthSamples = frameLengthSamples
         self.frameOverlapSamples = frameOverlapSamples
-        self.energyThreshold = energyThreshold
     }

+    /// Analyzes the provided audio waveform to determine which segments contain voice activity.
+    /// - Parameter waveform: An array of `Float` values representing the audio waveform.
+    /// - Returns: An array of `Bool` values where `true` indicates the presence of voice activity and `false` indicates silence.
     func voiceActivity(in waveform: [Float]) -> [Bool] {
-        let chunkRatio = Double(waveform.count) / Double(frameLengthSamples)
-
-        // Round up if uneven, the final chunk will not be a full `frameLengthSamples` long
-        let count = Int(chunkRatio.rounded(.up))
-
-        let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks(
-            of: waveform,
-            chunkCount: count,
-            frameLengthSamples: frameLengthSamples,
-            frameOverlapSamples: frameOverlapSamples,
-            energyThreshold: energyThreshold
-        )
-
-        return chunkedVoiceActivity
+        fatalError("`voiceActivity(in:)` must be implemented by a subclass")
     }

+    /// Calculates and returns a list of active audio chunks, each represented by a start and end index.
+    /// - Parameter waveform: An array of `Float` values representing the audio waveform.
+    /// - Returns: An array of tuples where each tuple contains the start and end indices of an active audio chunk.
     func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] {
         let vad: [Bool] = voiceActivity(in: waveform)
         var result = [(startIndex: Int, endIndex: Int)]()
@@ -91,41 +71,9 @@ final class EnergyVAD {
         return result
     }

-    func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
-        let nonSilentChunks = calculateActiveChunks(in: waveform)
-        var clipTimestamps = [Float]()
-
-        for chunk in nonSilentChunks {
-            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
-            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
-
-            clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp])
-        }
-
-        return clipTimestamps
-    }
-
-    func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] {
-        let clipTimestamps = voiceActivityClipTimestamps(in: waveform)
-        let options = DecodingOptions(clipTimestamps: clipTimestamps)
-        let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options)
-        return seekClips
-    }
-
-    func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] {
-        let nonSilentChunks = calculateActiveChunks(in: waveform)
-        var seekTimestamps = [(startTime: Float, endTime: Float)]()
-
-        for chunk in nonSilentChunks {
-            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
-            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
-
-            seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)])
-        }
-
-        return seekTimestamps
-    }
-
+    /// Converts a voice activity index to the corresponding audio sample index.
+    /// - Parameter index: The voice activity index to convert.
+    /// - Returns: The corresponding audio sample index.
     func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int {
         return index * frameLengthSamples
     }
@@ -134,6 +82,9 @@ final class EnergyVAD {
         return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate)
     }

+    /// Identifies the longest continuous period of silence within the provided voice activity detection results.
+    /// - Parameter vadResult: An array of `Bool` values representing voice activity detection results.
+    /// - Returns: A tuple containing the start and end indices of the longest silence period, or `nil` if no silence is found.
     func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? {
         var longestStartIndex: Int?
         var longestEndIndex: Int?
@@ -165,4 +116,41 @@ final class EnergyVAD {
             return nil
         }
     }

+    // MARK: - Utility
+
+    func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
+        let nonSilentChunks = calculateActiveChunks(in: waveform)
+        var clipTimestamps = [Float]()
+
+        for chunk in nonSilentChunks {
+            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
+            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
+
+            clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp])
+        }
+
+        return clipTimestamps
+    }
+
+    func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] {
+        let clipTimestamps = voiceActivityClipTimestamps(in: waveform)
+        let options = DecodingOptions(clipTimestamps: clipTimestamps)
+        let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options)
+        return seekClips
+    }
+
+    func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] {
+        let nonSilentChunks = calculateActiveChunks(in: waveform)
+        var seekTimestamps = [(startTime: Float, endTime: Float)]()
+
+        for chunk in nonSilentChunks {
+            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
+            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
+
+            seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)])
+        }
+
+        return seekTimestamps
+    }
 }
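To make the subclassing contract concrete (override `voiceActivity(in:)`, inherit everything else), here is a hypothetical detector; the class name and its peak-amplitude rule are illustrative only and not part of this PR:

    /// Hypothetical subclass: a frame is "voiced" when any sample exceeds a fixed amplitude.
    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
    final class PeakAmplitudeVAD: VoiceActivityDetector {
        override func voiceActivity(in waveform: [Float]) -> [Bool] {
            let frameCount = Int((Double(waveform.count) / Double(frameLengthSamples)).rounded(.up))
            return (0..<frameCount).map { frame in
                let start = frame * frameLengthSamples
                let end = min(start + frameLengthSamples + frameOverlapSamples, waveform.count)
                return waveform[start..<end].contains { abs($0) > 0.1 } // illustrative threshold
            }
        }
    }

    // The inherited utilities then work unchanged, e.g. producing clip timestamps
    // in the flat [start0, end0, start1, end1, ...] format used by DecodingOptions.
    let someAudio = [Float](repeating: 0.5, count: 32000) // placeholder buffer
    let detector = PeakAmplitudeVAD(sampleRate: 16000, frameLengthSamples: 1600)
    let clips = detector.voiceActivityClipTimestamps(in: someAudio)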