<br />

<p align="center">
- <img src="/logo.png" title="Logo" />
+ <img width="40%" src="/logo.jpg" title="Logo" />
</p>

<p align="center">
@@ -67,11 +67,10 @@ conda create -n diart python=3.8
conda activate diart
```

- 2) Install `PortAudio` and `soundfile`:
+ 2) Install audio libraries:

```shell
- conda install portaudio
- conda install pysoundfile -c conda-forge
+ conda install portaudio pysoundfile ffmpeg -c conda-forge
```

3) Install diart:
@@ -101,6 +100,8 @@ diart.stream /path/to/audio.wav
A live conversation:

```shell
+ # Use "microphone:ID" to select a non-default device
+ # See `python -m sounddevice` for available devices
diart.stream microphone
```
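+
+ For example, to stream from the device with ID 1 (an illustrative ID; list available devices with `python -m sounddevice`):
+
+ ```shell
+ diart.stream microphone:1
+ ```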
@@ -127,29 +128,49 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n

## Custom models

- Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+ Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):

```python
- import torch
- from typing import Optional
from diart import OnlineSpeakerDiarization, PipelineConfig
- from diart.models import EmbeddingModel
+ from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference

+
+ def model_loader():
+     return load_pretrained_model("my_model.ckpt")
+
+
+ class MySegmentationModel(SegmentationModel):
+     def __init__(self):
+         super().__init__(model_loader)
+
+     @property
+     def sample_rate(self) -> int:
+         return 16000
+
+     @property
+     def duration(self) -> float:
+         return 2  # seconds
+
+     def forward(self, waveform):
+         # self.model is created lazily
+         return self.model(waveform)
+
+
class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
-         super().__init__()
-         self.my_pretrained_model = load("my_model.ckpt")
+         super().__init__(model_loader)
+
+     def forward(self, waveform, weights):
+         # self.model is created lazily
+         return self.model(waveform, weights)
+

-     def __call__(
-         self,
-         waveform: torch.Tensor,
-         weights: Optional[torch.Tensor] = None
-     ) -> torch.Tensor:
-         return self.my_pretrained_model(waveform, weights)
-
- config = PipelineConfig(embedding=MyEmbeddingModel())
+ config = PipelineConfig(
+     segmentation=MySegmentationModel(),
+     embedding=MyEmbeddingModel()
+ )
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference(pipeline, mic)
@@ -225,7 +246,7 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
- sample_rate = segmentation.model.get_sample_rate()
+ sample_rate = segmentation.model.sample_rate
mic = MicrophoneAudioSource(sample_rate)

stream = mic.stream.pipe(
@@ -252,7 +273,20 @@ torch.Size([1, 3, 512])

Diart is also compatible with the WebSocket protocol to serve pipelines on the web.

- In the following example we build a minimal server that receives audio chunks and sends back predictions in RTTM format:
+ ### From the command line
+
+ ```commandline
+ diart.serve --host 0.0.0.0 --port 7007
+ diart.client microphone --host <server-address> --port 7007
+ ```
+
+ **Note:** please make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.
+
+ See `-h` for more options.
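+
+ For example, keeping a 0.5s step and a 16kHz sample rate on both ends (illustrative values; `--step` and `-sr` are the flags mentioned above):
+
+ ```commandline
+ diart.serve --host 0.0.0.0 --port 7007 --step 0.5 -sr 16000
+ diart.client microphone --host <server-address> --port 7007 --step 0.5 -sr 16000
+ ```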
+
+ ### From Python
+
+ For customized solutions, a server can also be created in Python using the `WebSocketAudioSource`:

```python
from diart import OnlineSpeakerDiarization
@@ -261,7 +295,7 @@ from diart.inference import RealTimeInference

pipeline = OnlineSpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
- inference = RealTimeInference(pipeline, source, do_plot=True)
+ inference = RealTimeInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
```
@@ -318,22 +352,29 @@ diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1
or using the inference API:

```python
- from diart.inference import Benchmark
+ from diart.inference import Benchmark, Parallelize
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.models import SegmentationModel

+ benchmark = Benchmark("/wav/dir", "/rttm/dir")
+
+ name = "pyannote/segmentation@Interspeech2021"
+ segmentation = SegmentationModel.from_pyannote(name)
config = PipelineConfig(
    # Set the model used in the paper
-     segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
+     segmentation=segmentation,
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    rho_update=0.422,
    delta_new=1.517
)
- pipeline = OnlineSpeakerDiarization(config)
- benchmark = Benchmark("/wav/dir", "/rttm/dir")
- benchmark(pipeline)
+ benchmark(OnlineSpeakerDiarization, config)
+
+ # Run the same benchmark in parallel
+ p_benchmark = Parallelize(benchmark, num_workers=4)
+ if __name__ == "__main__":  # Needed for multiprocessing
+     p_benchmark(OnlineSpeakerDiarization, config)
```

This pre-calculates model outputs in batches, so it runs a lot faster.
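+
+ If memory is tight, the batch size can be tuned; a minimal sketch, assuming `Benchmark` accepts a `batch_size` argument (check the signature in your version):
+
+ ```python
+ # Hypothetical: smaller batches use less memory, larger ones run faster
+ benchmark = Benchmark("/wav/dir", "/rttm/dir", batch_size=32)
+ ```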