diff --git a/app/faster_whisper/core.py b/app/faster_whisper/core.py index dc1817d..771f4e7 100644 --- a/app/faster_whisper/core.py +++ b/app/faster_whisper/core.py @@ -7,7 +7,7 @@ import whisper from faster_whisper import WhisperModel -from .utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON +from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON model_name = os.getenv("ASR_MODEL", "base") @@ -31,6 +31,11 @@ model_lock = Lock() +def to_whisper_word(word): + word_dict = word._asdict() + word_dict["confidence"] = word_dict.pop("probability") + return word_dict + def transcribe( audio, task: Union[str, None], @@ -54,7 +59,10 @@ def transcribe( text = "" segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict) for segment in segment_generator: - segments.append(segment) + seg_dict = segment._asdict() + if "words" in seg_dict: + seg_dict["words"] = [to_whisper_word(word) for word in seg_dict["words"]] + segments.append(seg_dict) text = text + segment.text result = { "language": options_dict.get("language", info.language), @@ -84,15 +92,20 @@ def language_detection(audio): def write_result( result: dict, file: BinaryIO, output: Union[str, None] ): + options = { + 'max_line_width': 1000, + 'max_line_count': 10, + 'highlight_words': False + } if output == "srt": - WriteSRT(ResultWriter).write_result(result, file=file) + WriteSRT(ResultWriter).write_result(result, file=file, options=options) elif output == "vtt": - WriteVTT(ResultWriter).write_result(result, file=file) + WriteVTT(ResultWriter).write_result(result, file=file, options=options) elif output == "tsv": - WriteTSV(ResultWriter).write_result(result, file=file) + WriteTSV(ResultWriter).write_result(result, file=file, options=options) elif output == "json": - WriteJSON(ResultWriter).write_result(result, file=file) + WriteJSON(ResultWriter).write_result(result, file=file, options=options) elif output == "txt": - WriteTXT(ResultWriter).write_result(result, file=file) + WriteTXT(ResultWriter).write_result(result, file=file, options=options) else: return 'Please select an output method!' diff --git a/app/faster_whisper/utils.py b/app/faster_whisper/utils.py deleted file mode 100644 index 4a41acf..0000000 --- a/app/faster_whisper/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import os -from typing import TextIO - -from faster_whisper.utils import format_timestamp - - -class ResultWriter: - extension: str - - def __init__(self, output_dir: str): - self.output_dir = output_dir - - def __call__(self, result: dict, audio_path: str): - audio_basename = os.path.basename(audio_path) - output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) - - with open(output_path, "w", encoding="utf-8") as f: - self.write_result(result, file=f) - - def write_result(self, result: dict, file: TextIO): - raise NotImplementedError - - -class WriteTXT(ResultWriter): - extension: str = "txt" - - def write_result(self, result: dict, file: TextIO): - for segment in result["segments"]: - print(segment.text.strip(), file=file, flush=True) - - -class WriteVTT(ResultWriter): - extension: str = "vtt" - - def write_result(self, result: dict, file: TextIO): - print("WEBVTT\n", file=file) - for segment in result["segments"]: - print( - f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n" - f"{segment.text.strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) - - -class WriteSRT(ResultWriter): - extension: str = "srt" - - def write_result(self, result: dict, file: TextIO): - for i, segment in enumerate(result["segments"], start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n" - f"{segment.text.strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) - - -class WriteTSV(ResultWriter): - """ - Write a transcript to a file in TSV (tab-separated values) format containing lines like: - \t\t - - Using integer milliseconds as start and end times means there's no chance of interference from - an environment setting a language encoding that causes the decimal in a floating point number - to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. - """ - extension: str = "tsv" - - def write_result(self, result: dict, file: TextIO): - print("start", "end", "text", sep="\t", file=file) - for segment in result["segments"]: - print(round(1000 * segment.start), file=file, end="\t") - print(round(1000 * segment.end), file=file, end="\t") - print(segment.text.strip().replace("\t", " "), file=file, flush=True) - - -class WriteJSON(ResultWriter): - extension: str = "json" - - def write_result(self, result: dict, file: TextIO): - json.dump(result, file)