diff --git a/whisper/decoding.py b/whisper/decoding.py
index 49485d009..cc06081a6 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -18,18 +18,20 @@
 @torch.no_grad()
 def detect_language(
     model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
-) -> Tuple[Tensor, List[dict]]:
+) -> Tuple[Tensor, Union[Dict[str, float], List[Dict[str, float]]]]:
     """
-    Detect the spoken language in the audio, and return them as list of strings, along with the ids
+    Detect the spoken language in the audio, and return them as a list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.
     This is performed outside the main decode loop in order to not interfere with kv-caching.
 
     Returns
     -------
     language_tokens : Tensor, shape = (n_audio,)
-        ids of the most probable language tokens, which appears after the startoftranscript token.
-    language_probs : List[Dict[str, float]], length = n_audio
-        list of dictionaries containing the probability distribution over all languages.
+        ids of the most probable language tokens, which appear after the startoftranscript token.
+    language_probs : Union[Dict[str, float], List[Dict[str, float]]]
+        If the input contains a single audio sample, this will be a dictionary containing the
+        probability distribution over all languages for that sample. If the input contains multiple
+        audio samples, this will be a list of such dictionaries, one for each sample (length = n_audio).
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(
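
For reference, a minimal caller-side sketch of the two return shapes the updated annotation documents. The model size "base" and the file name "audio.mp3" are placeholder assumptions; the loading and mel helpers are the ones exported by the whisper package.

import whisper

# Placeholder assumptions: a locally available "base" checkpoint and an
# "audio.mp3" file on disk.
model = whisper.load_model("base")

# Single sample: a 2-D mel spectrogram (n_mels, n_frames) is documented to
# yield a plain dict mapping language codes to probabilities.
audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
print(f"detected language: {max(probs, key=probs.get)}")

# Batched samples: a 3-D mel tensor (n_audio, n_mels, n_frames) is documented
# to yield a list of such dicts, one per sample.
mels = mel.unsqueeze(0).repeat(2, 1, 1)
_, probs_per_sample = model.detect_language(mels)
for probs in probs_per_sample:
    print(f"detected language: {max(probs, key=probs.get)}")

The patch itself changes no behavior: detect_language already squeezes single (2-D) inputs before returning, so the new Union annotation and docstring simply describe what the function returns in each case.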