Commit 17b29b2

Merge pull request #66 from juanmc2005/develop
Version 0.4
2 parents 0ebf729 + b40f091 · commit 17b29b2

25 files changed: +1809 -1079 lines

README.md (+175 -38)
@@ -12,13 +12,49 @@
   <img alt="License" src="https://img.shields.io/github/license/juanmc2005/StreamingSpeakerDiarization?color=g">
 </p>
 
+<div align="center">
+  <h4>
+    <a href="#installation">
+      Installation
+    </a>
+    <span> | </span>
+    <a href="#stream-audio">
+      Stream audio
+    </a>
+    <span> | </span>
+    <a href="#add-your-model">
+      Add your model
+    </a>
+    <span> | </span>
+    <a href="#tune-hyper-parameters">
+      Tune hyper-parameters
+    </a>
+    <span> | </span>
+    <a href="#build-pipelines">
+      Build pipelines
+    </a>
+    <br/>
+    <a href="#powered-by-research">
+      Research
+    </a>
+    <span> | </span>
+    <a href="#citation">
+      Citation
+    </a>
+    <span> | </span>
+    <a href="#reproducibility">
+      Reproducibility
+    </a>
+  </h4>
+</div>
+
 <br/>
 
 <p align="center">
   <img width="100%" src="/demo.gif" title="Real-time diarization example" />
 </p>
 
-## Install
+## Installation
 
 1) Create environment:
 
@@ -27,85 +63,186 @@ conda create -n diart python=3.8
 conda activate diart
 ```
 
-2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
+2) Install `PortAudio` and `soundfile`:
+
+```shell
+conda install portaudio
+conda install pysoundfile -c conda-forge
+```
+
+3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
+
+4) Install pyannote.audio 2.0 (currently in development)
 
-3) Install pyannote.audio 2.0 (currently in development)
 ```shell
 pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
 ```
 
-4) Install diart:
+**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
+
+5) Install diart:
 ```shell
 pip install diart
 ```
 
-## Stream your own audio
+## Stream audio
 
-### A recorded conversation
+### From the command line
+
+A recorded conversation:
 
 ```shell
-python -m diart.stream /path/to/audio.wav
+diart.stream /path/to/audio.wav
 ```
 
-### From your microphone
+A live conversation:
 
 ```shell
-python -m diart.stream microphone
+diart.stream microphone
 ```
 
-See `python -m diart.stream -h` for more options.
+See `diart.stream -h` for more options.
 
-## Inference API
+### From python
 
-Run a customized real-time speaker diarization pipeline over an audio stream with `diart.inference.RealTimeInference`:
+Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:
 
 ```python
 from diart.sources import MicrophoneAudioSource
 from diart.inference import RealTimeInference
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
 
-pipeline = OnlineSpeakerDiarization(PipelineConfig())
-audio_source = MicrophoneAudioSource(pipeline.sample_rate)
+config = PipelineConfig()  # Default parameters
+pipeline = OnlineSpeakerDiarization(config)
+audio_source = MicrophoneAudioSource(config.sample_rate)
 inference = RealTimeInference("/output/path", do_plot=True)
-
 inference(pipeline, audio_source)
 ```
 
-For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility))
+For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+
+## Add your model
+
+Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+
+```python
+import torch
+from typing import Optional
+from diart.models import EmbeddingModel
+from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization
+from diart.sources import MicrophoneAudioSource
+from diart.inference import RealTimeInference
+
+class MyEmbeddingModel(EmbeddingModel):
+    def __init__(self):
+        super().__init__()
+        self.my_pretrained_model = load("my_model.ckpt")
+
+    def __call__(
+        self,
+        waveform: torch.Tensor,
+        weights: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return self.my_pretrained_model(waveform, weights)
+
+config = PipelineConfig(embedding=MyEmbeddingModel())
+pipeline = OnlineSpeakerDiarization(config)
+mic = MicrophoneAudioSource(config.sample_rate)
+inference = RealTimeInference("/out/dir")
+inference(pipeline, mic)
+```
+
+## Tune hyper-parameters
+
+Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+
+### From the command line
+
+```shell
+diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+```
 
-## Build your own pipeline
+See `diart.tune -h` for more options.
 
-Diart also provides building blocks that can be combined to create your own pipeline.
+### From python
+
+```python
+from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
+from diart.pipelines import PipelineConfig
+from diart.inference import Benchmark
+
+# Benchmark runs and evaluates the pipeline on a dataset
+benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
+# Base configuration for the pipeline we're going to tune
+base_config = PipelineConfig()
+# Hyper-parameters to optimize
+hparams = [TauActive, RhoUpdate, DeltaNew]
+# Optimizer implements the optimization loop
+optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
+# Run optimization
+optimizer.optimize(num_iter=100, show_progress=True)
+```
+
+This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`.
+
+### Distributed optimization
+
+For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
+To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL) making sure that the study and database names match:
+
+```shell
+mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
+optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
+```
+
+Then you can run multiple identical optimizers pointing to the database:
+
+```shell
+diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+```
+
+If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
+
+```python
+from diart.optim import Optimizer
+from diart.inference import Benchmark
+from optuna.samplers import TPESampler
+import optuna
+
+ID = 0  # Worker identifier
+base_config, hparams = ...
+benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
+study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
+optimizer = Optimizer(benchmark, base_config, hparams, study)
+optimizer.optimize(num_iter=100, show_progress=True)
+```
+
+## Build pipelines
+
+For a more advanced usage, diart also provides building blocks that can be combined to create your own pipeline.
 Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.
 
 ### Example
 
 Obtain overlap-aware speaker embeddings from a microphone stream:
 
 ```python
-import rx
 import rx.operators as ops
 import diart.operators as dops
 from diart.sources import MicrophoneAudioSource
-from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding
+from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
 
-sample_rate = 16000
+segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
+sample_rate = segmentation.model.get_sample_rate()
 mic = MicrophoneAudioSource(sample_rate)
 
-# Initialize independent modules
-segmentation = FramewiseModel("pyannote/segmentation")
-embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")
-
-# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
-# Branch the microphone stream to calculate segmentation
-segmentation_stream = regular_stream.pipe(ops.map(segmentation))
-# Join audio and segmentation stream to calculate speaker embeddings
-embedding_stream = rx.zip(
-    regular_stream, segmentation_stream
-).pipe(ops.starmap(embedding))
-
-embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))
+stream = mic.stream.pipe(
+    # Reformat stream to 5s duration and 500ms shift
+    dops.regularize_audio_stream(sample_rate),
+    ops.map(lambda wav: (wav, segmentation(wav))),
+    ops.starmap(embedding)
+).subscribe(on_next=lambda emb: print(emb.shape))
 
 mic.read()
 ```
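The new "Add your model" section in the hunk above only spells out the `EmbeddingModel` side. A custom segmentation model would plug in the same way; the sketch below assumes that `SegmentationModel` exposes a comparable `__call__` contract (waveform in, per-frame speaker activations out) and that `PipelineConfig` accepts a `segmentation` argument symmetric to `embedding`. Neither detail is shown in this diff, so treat both as assumptions.

```python
import torch
from diart.models import SegmentationModel
from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization

class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__()
        # `load` mirrors the placeholder used in the EmbeddingModel example above
        self.my_pretrained_model = load("my_segmentation.ckpt")

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # Assumed contract: return per-frame speaker activity scores for the chunk
        return self.my_pretrained_model(waveform)

# Assumption: PipelineConfig takes a `segmentation` argument, like `embedding`
config = PipelineConfig(segmentation=MySegmentationModel())
pipeline = OnlineSpeakerDiarization(config)
```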
@@ -165,7 +302,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
 `diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
 
 ```shell
-python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
 ```
 
 or using the inference API:
@@ -184,11 +321,11 @@ config = PipelineConfig(
 pipeline = OnlineSpeakerDiarization(config)
 benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")
 
-benchmark(pipeline, batch_size=32)
+benchmark(pipeline)
 ```
 
 This runs a faster inference by pre-calculating model outputs in batches.
-See `python -m diart.benchmark -h` for more options.
+See `diart.benchmark -h` for more options.
 
 For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
 
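The tuning section added in this commit says results are written to an sqlite database in `/out/dir`, but neither the database file name nor the study name appears in the diff. Assuming hypothetical names for both, the results can be inspected with plain optuna:

```python
import optuna

# Hypothetical file name; the README only says "an sqlite database in /out/dir"
storage = "sqlite:////out/dir/example.db"

# List the studies the optimizer wrote to this database
for summary in optuna.get_all_study_summaries(storage=storage):
    print(summary.study_name, summary.n_trials)

# Load one of them and read the best hyper-parameters found so far
study = optuna.load_study(study_name="example", storage=storage)  # "example" is hypothetical
print(study.best_params)  # best values of the tuned hyper-parameters
```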
requirements.txt (+5 -1)
@@ -6,4 +6,8 @@ sounddevice>=0.4.2
 einops>=0.3.0
 tqdm>=4.64.0
 pandas>=1.4.2
-git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
+torchaudio>=0.10,<1.0
+pyannote.core>=4.4
+pyannote.database>=4.1.1
+pyannote.metrics>=3.2
+optuna>=2.10
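The git pyannote.audio dependency is replaced here by torchaudio, the lighter pyannote sub-packages and optuna. Per the README note in this same commit, pyannote.audio itself still has to be installed manually when the default system or other pyannote-based models are used, for example:

```shell
pip install diart
# Only needed for the default system or other pyannote-based models (see the README note above)
pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
```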

setup.cfg (+14 -5)
@@ -1,6 +1,6 @@
 [metadata]
 name=diart
-version=0.3.0
+version=0.4.0
 author=Juan Manuel Coria
 description=Speaker diarization in real time
 long_description=file: README.md
@@ -9,7 +9,7 @@ keywords=speaker diarization, streaming, online, real time, rxpy
 url=https://github.com/juanmc2005/StreamingSpeakerDiarization
 license=MIT
 classifiers=
-    Development Status :: 3 - Alpha
+    Development Status :: 4 - Beta
     License :: OSI Approved :: MIT License
     Topic :: Multimedia :: Sound/Audio :: Analysis
     Topic :: Multimedia :: Sound/Audio :: Speech
@@ -19,7 +19,7 @@ classifiers=
 package_dir=
     =src
 packages=find:
-install_requires =
+install_requires=
     numpy>=1.20.2
     matplotlib>=3.3.3
    rx>=3.2.0
@@ -28,8 +28,17 @@ install_requires =
     einops>=0.3.0
     tqdm>=4.64.0
     pandas>=1.4.2
-    pyannote-audio @ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
-
+    torchaudio>=0.10,<1.0
+    pyannote.core>=4.4
+    pyannote.database>=4.1.1
+    pyannote.metrics>=3.2
+    optuna>=2.10
 
 [options.packages.find]
 where=src
+
+[options.entry_points]
+console_scripts=
+    diart.stream=diart.stream:run
+    diart.benchmark=diart.benchmark:run
+    diart.tune=diart.tune:run
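The new `[options.entry_points]` section is what turns `diart.stream`, `diart.benchmark` and `diart.tune` into installed commands: each console script calls the `run()` function of the corresponding module (standard setuptools `console_scripts` behaviour). In practice the pre-0.4 `python -m` invocation from the old README and the new console script should be interchangeable, assuming the modules keep their `__main__` entry point, which this diff does not show:

```shell
# Console script installed by this commit
diart.stream /path/to/audio.wav

# Pre-0.4 invocation; equivalent as long as src/diart/stream.py still calls run() under __main__
python -m diart.stream /path/to/audio.wav
```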

src/diart/argdoc.py (+2 -1)
@@ -6,5 +6,6 @@
 GAMMA = "Parameter gamma for overlapped speech penalty"
 BETA = "Parameter beta for overlapped speech penalty"
 MAX_SPEAKERS = "Maximum number of speakers"
-GPU = "Run on GPU"
+CPU = "Force models to run on CPU"
+BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency"
 OUTPUT = "Directory to store the system's output in RTTM format"
