
Commit 35c4dba

Merge pull request #139 from juanmc2005/develop
Version 0.7

2 parents: e81917c + 8982d4f

24 files changed: +1090 -347 lines
README.md: +67 -26

@@ -1,7 +1,7 @@
 <br/>
 
 <p align="center">
-  <img src="/logo.png" title="Logo" />
+  <img width="40%" src="/logo.jpg" title="Logo" />
 </p>
 
 <p align="center">
@@ -67,11 +67,10 @@ conda create -n diart python=3.8
 conda activate diart
 ```
 
-2) Install `PortAudio` and `soundfile`:
+2) Install audio libraries:
 
 ```shell
-conda install portaudio
-conda install pysoundfile -c conda-forge
+conda install portaudio pysoundfile ffmpeg -c conda-forge
 ```
 
 3) Install diart:
@@ -101,6 +100,8 @@ diart.stream /path/to/audio.wav
 A live conversation:
 
 ```shell
+# Use "microphone:ID" to select a non-default device
+# See `python -m sounddevice` for available devices
 diart.stream microphone
 ```
 
@@ -127,29 +128,49 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n
 
 ## Custom models
 
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
 
 ```python
-import torch
-from typing import Optional
 from diart import OnlineSpeakerDiarization, PipelineConfig
-from diart.models import EmbeddingModel
+from diart.models import EmbeddingModel, SegmentationModel
 from diart.sources import MicrophoneAudioSource
 from diart.inference import RealTimeInference
 
+
+def model_loader():
+    return load_pretrained_model("my_model.ckpt")
+
+
+class MySegmentationModel(SegmentationModel):
+    def __init__(self):
+        super().__init__(model_loader)
+
+    @property
+    def sample_rate(self) -> int:
+        return 16000
+
+    @property
+    def duration(self) -> float:
+        return 2  # seconds
+
+    def forward(self, waveform):
+        # self.model is created lazily
+        return self.model(waveform)
+
+
 class MyEmbeddingModel(EmbeddingModel):
     def __init__(self):
-        super().__init__()
-        self.my_pretrained_model = load("my_model.ckpt")
+        super().__init__(model_loader)
+
+    def forward(self, waveform, weights):
+        # self.model is created lazily
+        return self.model(waveform, weights)
+
 
-    def __call__(
-        self,
-        waveform: torch.Tensor,
-        weights: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        return self.my_pretrained_model(waveform, weights)
-
-config = PipelineConfig(embedding=MyEmbeddingModel())
+config = PipelineConfig(
+    segmentation=MySegmentationModel(),
+    embedding=MyEmbeddingModel()
+)
 pipeline = OnlineSpeakerDiarization(config)
 mic = MicrophoneAudioSource(config.sample_rate)
 inference = RealTimeInference(pipeline, mic)
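
Note that `load_pretrained_model` in the snippet above is a placeholder, not a diart function. A minimal sketch of such a loader, under the assumption that the checkpoint is a pyannote.audio model (an assumption, not something this commit prescribes):

```python
# Hypothetical loader for the custom-model example above.
# Assumes pyannote.audio is installed and "my_model.ckpt" is a valid checkpoint path.
from pyannote.audio import Model


def model_loader():
    return Model.from_pretrained("my_model.ckpt")
```

Any callable returning a PyTorch module should work the same way, since the wrapper classes only invoke it lazily.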
@@ -225,7 +246,7 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
 
 segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
 embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
-sample_rate = segmentation.model.get_sample_rate()
+sample_rate = segmentation.model.sample_rate
 mic = MicrophoneAudioSource(sample_rate)
 
 stream = mic.stream.pipe(
@@ -252,7 +273,20 @@ torch.Size([1, 3, 512])
 
 Diart is also compatible with the WebSocket protocol to serve pipelines on the web.
 
-In the following example we build a minimal server that receives audio chunks and sends back predictions in RTTM format:
+### From the command line
+
+```commandline
+diart.serve --host 0.0.0.0 --port 7007
+diart.client microphone --host <server-address> --port 7007
+```
+
+**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+
+See `-h` for more options.
+
+### From python
+
+For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
 
 ```python
 from diart import OnlineSpeakerDiarization
@@ -261,7 +295,7 @@ from diart.inference import RealTimeInference
 
 pipeline = OnlineSpeakerDiarization()
 source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source, do_plot=True)
+inference = RealTimeInference(pipeline, source)
 inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
 prediction = inference()
 ```
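
Since `attach_hooks` takes a callable over the `(annotation, waveform)` pair, additional hooks can be attached besides the WebSocket reply. A hedged sketch, with an illustrative output path that is not part of this commit:

```python
# Hypothetical extra hook: also append each prediction to a local RTTM file.
# ann_wav is the (pyannote Annotation, waveform) pair passed to every hook.
def save_prediction(ann_wav):
    annotation, _ = ann_wav
    with open("predictions.rttm", "a") as rttm_file:
        annotation.write_rttm(rttm_file)


inference.attach_hooks(save_prediction)
```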
@@ -318,22 +352,29 @@ diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1
 or using the inference API:
 
 ```python
-from diart.inference import Benchmark
+from diart.inference import Benchmark, Parallelize
 from diart import OnlineSpeakerDiarization, PipelineConfig
 from diart.models import SegmentationModel
 
+benchmark = Benchmark("/wav/dir", "/rttm/dir")
+
+name = "pyannote/segmentation@Interspeech2021"
+segmentation = SegmentationModel.from_pyannote(name)
 config = PipelineConfig(
     # Set the model used in the paper
-    segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
+    segmentation=segmentation,
     step=0.5,
     latency=0.5,
     tau_active=0.555,
     rho_update=0.422,
     delta_new=1.517
 )
-pipeline = OnlineSpeakerDiarization(config)
-benchmark = Benchmark("/wav/dir", "/rttm/dir")
-benchmark(pipeline)
+benchmark(OnlineSpeakerDiarization, config)
+
+# Run the same benchmark in parallel
+p_benchmark = Parallelize(benchmark, num_workers=4)
+if __name__ == "__main__":  # Needed for multiprocessing
+    p_benchmark(OnlineSpeakerDiarization, config)
 ```
 
 This pre-calculates model outputs in batches, so it runs a lot faster.

demo.gif (5.48 MB)

logo.jpg (45.9 KB)

logo.png (-6.1 KB): binary file not shown

requirements.txt: +4 -1

@@ -7,10 +7,13 @@ einops>=0.3.0
 tqdm>=4.64.0
 pandas>=1.4.2
 torch>=1.12.1
+torchvision>=0.14.0
 torchaudio>=0.12.1,<1.0
 pyannote.audio>=2.1.1
 pyannote.core>=4.5
 pyannote.database>=4.1.1
 pyannote.metrics>=3.2
 optuna>=2.10
-websockets>=10.3
+websocket-server>=0.6.4
+websocket-client>=0.58.0
+rich>=12.5.1

setup.cfg: +11 -6

@@ -1,10 +1,10 @@
 [metadata]
 name=diart
-version=0.6.0
+version=0.7.0
 author=Juan Manuel Coria
 description=Speaker diarization in real time
 long_description=file: README.md
-long_description_content_type = text/markdown
+long_description_content_type=text/markdown
 keywords=speaker diarization, streaming, online, real time, rxpy
 url=https://github.com/juanmc2005/StreamingSpeakerDiarization
 license=MIT
@@ -29,19 +29,24 @@ install_requires=
     tqdm>=4.64.0
     pandas>=1.4.2
     torch>=1.12.1
+    torchvision>=0.14.0
     torchaudio>=0.12.1,<1.0
     pyannote.audio>=2.1.1
     pyannote.core>=4.5
     pyannote.database>=4.1.1
     pyannote.metrics>=3.2
     optuna>=2.10
-    websockets>=10.3
+    websocket-server>=0.6.4
+    websocket-client>=0.58.0
+    rich>=12.5.1
 
 [options.packages.find]
 where=src
 
 [options.entry_points]
 console_scripts=
-    diart.stream=diart.stream:run
-    diart.benchmark=diart.benchmark:run
-    diart.tune=diart.tune:run
+    diart.stream=diart.console.stream:run
+    diart.benchmark=diart.console.benchmark:run
+    diart.tune=diart.console.tune:run
+    diart.serve=diart.console.serve:run
+    diart.client=diart.console.client:run

src/diart/__init__.py: +6 -1

@@ -1 +1,6 @@
-from .blocks import OnlineSpeakerDiarization, PipelineConfig
+from .blocks import (
+    OnlineSpeakerDiarization,
+    BasePipeline,
+    PipelineConfig,
+    BasePipelineConfig,
+)
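
With this change the base classes are re-exported at the package root, so downstream code can refer to them without reaching into `diart.blocks`. A small illustrative sketch (the helper function is hypothetical, not part of this commit):

```python
# Illustrative only: using the newly exported base classes for type hints.
from diart import BasePipeline, BasePipelineConfig


def describe(pipeline: BasePipeline, config: BasePipelineConfig) -> None:
    print(type(pipeline).__name__, type(config).__name__)
```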

src/diart/argdoc.py: +2

@@ -10,5 +10,7 @@
 MAX_SPEAKERS = "Maximum number of speakers"
 CPU = "Force models to run on CPU"
 BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency"
+NUM_WORKERS = "Number of parallel workers"
 OUTPUT = "Directory to store the system's output in RTTM format"
 HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
+SAMPLE_RATE = "Sample rate of the audio stream"
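
These constants are shared help strings for the console scripts. A hedged sketch of how the two new entries would typically be wired into an argparse parser (flag spellings and defaults here are assumptions; see `diart.console.*` for the actual parsers):

```python
# Illustrative only: attaching the new argdoc strings to an argument parser.
# Flag names and default values are assumptions, not taken from this commit.
import argparse

from diart import argdoc

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", type=int, default=0, help=argdoc.NUM_WORKERS)
parser.add_argument("-sr", "--sample_rate", type=int, default=16000, help=argdoc.SAMPLE_RATE)
args = parser.parse_args()
```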

src/diart/blocks/__init__.py: +2 -1

@@ -13,5 +13,6 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, PipelineConfig
+from .diarization import OnlineSpeakerDiarization, BasePipeline
+from .config import BasePipelineConfig, PipelineConfig
 from .utils import Binarize, Resample, AdjustVolume
