diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 151b4d36..2c1bc2cd 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -87,7 +87,8 @@ def sample_rate(self) -> int:
 
 
 class SpeakerDiarization(base.Pipeline):
-    def __init__(self, config: SpeakerDiarizationConfig | None = None):
+    def __init__(self, config: SpeakerDiarizationConfig | None = None, return_embeddings: bool = False):
+        self.return_embeddings = return_embeddings
         self._config = SpeakerDiarizationConfig() if config is None else config
 
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
@@ -225,6 +226,9 @@ def __call__(
                 agg_prediction = shifted_agg_prediction
 
             outputs.append((agg_prediction, agg_waveform))
+            if self.return_embeddings:
+                outputs[-1] = outputs[-1] + (self.clustering.centers,)  # extend output with speakers' embeddings
+
 
             # Make place for new chunks in buffer if required
             if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 1ae3adf9..243eac9d 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -7,6 +7,7 @@
 from pyannote.metrics.diarization import DiarizationErrorRate
 from rx.core import Observer
 from typing_extensions import Literal
+import json
 
 
 class WindowClosedException(Exception):
@@ -55,6 +56,61 @@ def on_error(self, error: Exception):
     def on_completed(self):
         self.patch()
 
 
+class RedisWriter(Observer):
+    """Push every diarized speech turn to a Redis list as a JSON string.
+
+    Each turn becomes an object with keys ``start``, ``end`` and
+    ``speaker_id`` appended to the list ``diarization_{uri}``. When the
+    observed value is a 3-tuple, its last item is serialized with
+    ``.tolist()`` and attached to every turn under the ``centroids`` key.
+
+    Parameters
+    ----------
+    uri: Text
+        Conversation identifier, used to build the name of the Redis list.
+    redis_client
+        Client object exposing ``rpush(name, value)``.
+    patch_collar: float
+        Stored on the instance but not used by this sink.
+    """
+
+    def __init__(self, uri: Text, redis_client, patch_collar: float = 0.05):
+        super().__init__()
+        self.uri = uri
+        self.redis_client = redis_client
+        self.conversation_id = uri  # the URI doubles as the conversation identifier
+        self.patch_collar = patch_collar
+
+    def on_next(self, value: Union[Tuple, Annotation]):
+        is_tuple = isinstance(value, tuple)
+        # A bare annotation is written too, instead of being silently dropped
+        prediction = value[0] if is_tuple else value
+
+        # Serialize the centroids once per update rather than once per segment
+        centroids = None
+        if is_tuple and len(value) == 3 and value[-1] is not None:
+            centroids = value[-1].tolist()
+
+        queue = f'diarization_{self.conversation_id}'
+        for segment, _, label in prediction.itertracks(yield_label=True):
+            diarization_data = {
+                'start': segment.start,
+                'end': segment.end,
+                'speaker_id': label,
+            }
+            if centroids is not None:
+                diarization_data['centroids'] = centroids
+            self.redis_client.rpush(queue, json.dumps(diarization_data))
+
+    def on_error(self, error: Exception):
+        # Errors are deliberately ignored by this sink
+        pass
+
+    def on_completed(self):
+        # Nothing to flush or close on completion
+        pass
+
+
 class PredictionAccumulator(Observer):
     def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):