
Commit 2734c04

Merge pull request #87 from juanmc2005/develop
Version 0.5
2 parents d4ff0ee + b75dc9f commit 2734c04

17 files changed: +451 -169 lines

README.md (+59 -47)

@@ -22,8 +22,8 @@
     Stream audio
   </a>
   <span> | </span>
-  <a href="#add-your-model">
-    Add your model
+  <a href="#custom-models">
+    Custom models
   </a>
   <span> | </span>
   <a href="#tune-hyper-parameters">
@@ -34,6 +34,10 @@
     Build pipelines
   </a>
   <br/>
+  <a href="#websockets">
+    WebSockets
+  </a>
+  <span> | </span>
   <a href="#powered-by-research">
     Research
   </a>
@@ -72,10 +76,10 @@ conda install pysoundfile -c conda-forge
 
 3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
 
-4) Install pyannote.audio 2.0 (currently no official release)
+4) Install pyannote.audio
 
 ```shell
-pip install git+https://github.com/pyannote/pyannote-audio.git@2.0.1#egg=pyannote-audio
+pip install pyannote.audio==2.0.1
 ```
 
 **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
@@ -105,25 +109,26 @@ See `diart.stream -h` for more options.
 
 ### From python
 
-Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:
+Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
 
 ```python
 from diart.sources import MicrophoneAudioSource
 from diart.inference import RealTimeInference
-from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-
-config = PipelineConfig() # Default parameters
-pipeline = OnlineSpeakerDiarization(config)
-audio_source = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference("/output/path", do_plot=True)
-inference(pipeline, audio_source)
+from diart.pipelines import OnlineSpeakerDiarization
+from diart.sinks import RTTMWriter
+
+pipeline = OnlineSpeakerDiarization()
+mic = MicrophoneAudioSource(pipeline.config.sample_rate)
+inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference.attach_observers(RTTMWriter("/output/file.rttm"))
+inference()
 ```
 
-For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
 
-## Add your model
+## Custom models
 
-Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
 
 ```python
 import torch
@@ -148,8 +153,8 @@ class MyEmbeddingModel(EmbeddingModel):
 config = PipelineConfig(embedding=MyEmbeddingModel())
 pipeline = OnlineSpeakerDiarization(config)
 mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference("/out/dir")
-inference(pipeline, mic)
+inference = RealTimeInference(pipeline, mic)
+inference()
 ```
 
 ## Tune hyper-parameters
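
The body of `MyEmbeddingModel` referenced in the hunk header above is unchanged context and therefore not shown in this diff. For reference only, a minimal sketch of such a subclass (the `__call__` signature and the `load_model` helper are assumptions, not part of this commit) could look like:

```python
import torch
from typing import Optional
from diart.models import EmbeddingModel

class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
        super().__init__()
        # Hypothetical loader for a pretrained third-party embedding model
        self.my_pretrained_model = load_model("my_model.ckpt")

    def __call__(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Return one embedding per speaker, optionally weighting frames by `weights`
        return self.my_pretrained_model(waveform, weights)
```
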
@@ -159,31 +164,21 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re
 ### From the command line
 
 ```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+diart.tune /wav/dir --reference /rttm/dir --output /output/dir
 ```
 
 See `diart.tune -h` for more options.
 
 ### From python
 
 ```python
-from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
-from diart.pipelines import PipelineConfig
-from diart.inference import Benchmark
+from diart.optim import Optimizer
 
-# Benchmark runs and evaluates the pipeline on a dataset
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
-# Base configuration for the pipeline we're going to tune
-base_config = PipelineConfig()
-# Hyper-parameters to optimize
-hparams = [TauActive, RhoUpdate, DeltaNew]
-# Optimizer implements the optimization loop
-optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
-# Run optimization
-optimizer.optimize(num_iter=100, show_progress=True)
+optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
+optimizer(num_iter=100)
 ```
 
-This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`.
+This will write results to an sqlite database in `/output/dir`.
 
 ### Distributed optimization
 
@@ -195,26 +190,23 @@ mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
 optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
 ```
 
-Then you can run multiple identical optimizers pointing to the database:
+You can now run multiple identical optimizers pointing to this database:
 
 ```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example
 ```
 
-If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
+or in python:
 
 ```python
 from diart.optim import Optimizer
-from diart.inference import Benchmark
 from optuna.samplers import TPESampler
 import optuna
 
-ID = 0 # Worker identifier
-base_config, hparams = ...
-benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
-study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
-optimizer = Optimizer(benchmark, base_config, hparams, study)
-optimizer.optimize(num_iter=100, show_progress=True)
+db = "mysql://root@localhost/example"
+study = optuna.load_study("example", db, TPESampler())
+optimizer = Optimizer("/wav/dir", "/rttm/dir", study)
+optimizer(num_iter=100)
 ```
 
 ## Build pipelines
@@ -256,6 +248,24 @@ torch.Size([4, 512])
 ...
 ```
 
+## WebSockets
+
+Diart is also compatible with the WebSocket protocol to serve pipelines on the web.
+
+In the following example we build a minimal server that receives audio chunks and sends back predictions in RTTM format:
+
+```python
+from diart.pipelines import OnlineSpeakerDiarization
+from diart.sources import WebSocketAudioSource
+from diart.inference import RealTimeInference
+
+pipeline = OnlineSpeakerDiarization()
+source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
+inference = RealTimeInference(pipeline, source, do_plot=True)
+inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
+inference()
+```
+
 ## Powered by research
 
 Diart is the official implementation of the paper *[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](/paper.pdf)* by [Juan Manuel Coria](https://juanmc2005.github.io/), [Hervé Bredin](https://herve.niderb.fr), [Sahar Ghannay](https://saharghannay.github.io/) and [Sophie Rosset](https://perso.limsi.fr/rosset/).
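
This commit does not include a client for the server above. As an illustration only, a bare-bones client built on the `websockets` package (added to the requirements below) might stream raw audio chunks and print the RTTM lines it gets back; the payload format expected by `WebSocketAudioSource` is an assumption here and should be checked against its implementation:

```python
import asyncio
import soundfile as sf
import websockets

async def stream_file(path: str, sample_rate: int = 16000, chunk_seconds: float = 0.5):
    # Assumption: the server decodes each message as raw float32 samples at the pipeline sample rate
    audio, _ = sf.read(path, dtype="float32")
    step = int(sample_rate * chunk_seconds)
    async with websockets.connect("ws://localhost:7007") as ws:
        for start in range(0, len(audio), step):
            await ws.send(audio[start:start + step].tobytes())
            # In practice predictions arrive at the pipeline's own step rate
            print(await ws.recv())

asyncio.run(stream_file("/path/to/audio.wav"))
```
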
@@ -299,32 +309,34 @@ To obtain the best results, make sure to use the following hyper-parameters:
 | DIHARD II | 1s | 0.619 | 0.326 | 0.997 |
 | DIHARD II | 5s | 0.555 | 0.422 | 1.517 |
 
-`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
+`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
 
 ```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
 ```
 
 or using the inference API:
 
 ```python
 from diart.inference import Benchmark
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
+from diart.models import SegmentationModel
 
 config = PipelineConfig(
+    # Set the model used in the paper
+    segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
     step=0.5,
     latency=0.5,
     tau_active=0.555,
     rho_update=0.422,
     delta_new=1.517
 )
 pipeline = OnlineSpeakerDiarization(config)
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")
-
+benchmark = Benchmark("/wav/dir", "/rttm/dir")
 benchmark(pipeline)
 ```
 
-This runs a faster inference by pre-calculating model outputs in batches.
+This pre-calculates model outputs in batches, so it runs a lot faster.
 See `diart.benchmark -h` for more options.
 
 For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
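
Related note (not part of the commit): since pyannote.metrics is already a dependency, comparing one of your own RTTM outputs against the corresponding expected output takes only a few lines; the file paths below are placeholders:

```python
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

# load_rttm returns a {uri: Annotation} mapping; take the single annotation in each file
reference = load_rttm("/expected_outputs/DH_DEV_0001.rttm").popitem()[1]
hypothesis = load_rttm("/out/dir/DH_DEV_0001.rttm").popitem()[1]
print(f"DER = {DiarizationErrorRate()(reference, hypothesis):.3f}")
```
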

requirements.txt (+2 -1)

@@ -7,7 +7,8 @@ einops>=0.3.0
 tqdm>=4.64.0
 pandas>=1.4.2
 torchaudio>=0.10,<1.0
-pyannote.core>=4.4
+pyannote.core>=4.5
 pyannote.database>=4.1.1
 pyannote.metrics>=3.2
 optuna>=2.10
+websockets>=10.3

setup.cfg (+3 -2)

@@ -1,6 +1,6 @@
 [metadata]
 name=diart
-version=0.4.0
+version=0.5.0
 author=Juan Manuel Coria
 description=Speaker diarization in real time
 long_description=file: README.md
@@ -29,10 +29,11 @@ install_requires=
     tqdm>=4.64.0
     pandas>=1.4.2
     torchaudio>=0.10,<1.0
-    pyannote.core>=4.4
+    pyannote.core>=4.5
     pyannote.database>=4.1.1
     pyannote.metrics>=3.2
     optuna>=2.10
+    websockets>=10.3
 
 [options.packages.find]
 where=src

src/diart/argdoc.py (+2)

@@ -1,3 +1,5 @@
+SEGMENTATION = "Segmentation model name from pyannote"
+EMBEDDING = "Embedding model name from pyannote"
 STEP = "Sliding window step (in seconds)"
 LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
 TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"

src/diart/benchmark.py (+16 -5)

@@ -1,16 +1,22 @@
 import argparse
+from pathlib import Path
 
 import torch
 
 import diart.argdoc as argdoc
 from diart.inference import Benchmark
+from diart.models import SegmentationModel, EmbeddingModel
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
 
 
 def run():
     parser = argparse.ArgumentParser()
-    parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
-    parser.add_argument("--reference", type=str,
+    parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+    parser.add_argument("--reference", type=Path,
                         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
@@ -23,20 +29,25 @@ def run():
     parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
     parser.add_argument("--cpu", dest="cpu", action="store_true",
                         help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
-    parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
+    parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
+    args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
+    args.embedding = EmbeddingModel.from_pyannote(args.embedding)
 
     benchmark = Benchmark(
         args.root,
         args.reference,
         args.output,
         show_progress=True,
         show_report=True,
-        batch_size=args.batch_size
+        batch_size=args.batch_size,
     )
 
-    benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
+    pipeline = OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)
+    report = benchmark(pipeline)
+    if args.output is not None and report is not None:
+        report.to_csv(args.output / "benchmark_report.csv")
 
 
 if __name__ == "__main__":
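
Taken together, these changes mean callers of `Benchmark` now get the evaluation report back as a value. A rough Python equivalent of the new CLI behavior (paths are placeholders, not part of the commit) would be:

```python
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

benchmark = Benchmark("/wav/dir", "/rttm/dir", show_progress=True, show_report=True, batch_size=32)
report = benchmark(OnlineSpeakerDiarization(PipelineConfig()))
# As in the updated CLI, only write the report if one was produced
if report is not None:
    report.to_csv("/out/dir/benchmark_report.csv")
```
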

src/diart/blocks/__init__.py (+1 -1)

@@ -13,4 +13,4 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .utils import Binarize
+from .utils import Binarize, Resample, AdjustVolume

src/diart/blocks/segmentation.py (+1 -1)

@@ -37,4 +37,4 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
         with torch.no_grad():
             wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
             output = self.model(wave.to(self.device)).cpu()
-        return self.formatter.restore_type(output)
+        return self.formatter.restore_type(output)
