From 1d5ff272ed648ebf9ce0c3b4d8690cef87f7737a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 19 Mar 2024 13:43:51 +0100 Subject: [PATCH 1/7] Add models, config, .flake8 --- .flake8 | 58 +++++++++++++++++++++++++++++++++++ .github/workflows/_deploy.yml | 2 +- config/config.yml | 34 ++++++++++++++++++++ models.py | 35 +++++++++++++++++++++ 4 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 .flake8 create mode 100644 config/config.yml create mode 100644 models.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..9d542f7 --- /dev/null +++ b/.flake8 @@ -0,0 +1,58 @@ +# use .flake8 until we can move this config to pyproject.toml (not possible yet (27/02/2024) according to issue below) +# https://github.com/PyCQA/flake8/issues/234 + +[flake8] +select = + # B: bugbear warnings + B, + + # B950: bugbear max-linelength warning + # as suggested in the black docs + # https://github.com/psf/black/blob/d038a24ca200da9dacc1dcb05090c9e5b45b7869/docs/the_black_code_style/current_style.md#line-length + B950, + + # C: currently only C901, mccabe code complexity + C, + + # E: pycodestyle errors + E, + + # F: flake8 codes for pyflakes + F, + + # W: pycodestyle warnings + W, + +extend-ignore = + # E203: pycodestyle's "whitespace before ',', ';' or ':'" error + # ignored as suggested in the black docs + # https://github.com/psf/black/blob/d038a24ca200da9dacc1dcb05090c9e5b45b7869/docs/the_black_code_style/current_style.md#slices + E203, + + # E501: pycodestyle's "line too long (82 > 79) characters" error + # ignored in favor of B950 as suggested in the black docs + # https://github.com/psf/black/blob/d038a24ca200da9dacc1dcb05090c9e5b45b7869/docs/the_black_code_style/current_style.md#line-length + E501, + + # W503 line break before binary operator + W503, + +# set max-line-length to be black compatible, as suggested in the black docs +# https://github.com/psf/black/blob/d038a24ca200da9dacc1dcb05090c9e5b45b7869/docs/the_black_code_style/current_style.md#line-length +max-line-length = 88 + +# set max cyclomatic complexity for mccabe plugin +max-complexity = 10 + +# show total number of errors, set exit code to 1 if tot is not empty +count = True + +# show the source generating each error or warning +show-source = True + +# count errors and warnings +statistics = True + +exclude = + .venv + misc \ No newline at end of file diff --git a/.github/workflows/_deploy.yml b/.github/workflows/_deploy.yml index 5c6e6a9..9a45483 100644 --- a/.github/workflows/_deploy.yml +++ b/.github/workflows/_deploy.yml @@ -1,4 +1,4 @@ -name: Deploy dane-visual-feature-extraction-worker to ghcr +name: Deploy dane-whisper-asr-worker to ghcr on: workflow_call: diff --git a/config/config.yml b/config/config.yml new file mode 100644 index 0000000..b238189 --- /dev/null +++ b/config/config.yml @@ -0,0 +1,34 @@ +RABBITMQ: + HOST: dane-rabbitmq-api.default.svc.cluster.local + PORT: 5672 + EXCHANGE: DANE-exchange + RESPONSE_QUEUE: DANE-response-queue + USER: guest # change this for production mode + PASSWORD: guest # change this for production mode +ELASTICSEARCH: + HOST: + - elasticsearch + PORT: 9200 + USER: '' # change this for production mode + PASSWORD: '' # change this for production mode + SCHEME: http + INDEX: dane-index-k8s +FILE_SYSTEM: + BASE_MOUNT: data # data when running locally + INPUT_DIR: input-files + OUTPUT_DIR: output-files +INPUT: + TEST_INPUT_PATH: testsource__testcarrier/inputfile.wav + S3_ENDPOINT_URL: https://s3-host + MODEL: s3://bucket/model + DELETE_ON_COMPLETION: 
False +OUTPUT: + DELETE_ON_COMPLETION: True + TRANSFER_ON_COMPLETION: True + S3_ENDPOINT_URL: https://s3-host + S3_BUCKET: bucket-name # bucket reserved for 1 type of output + S3_FOLDER_IN_BUCKET: folder # folder within the bucket +WHISPER_ASR_SETTINGS: + WORD_TIMESTAMPS: True +DANE_DEPENDENCIES: + - input-generating-worker \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 0000000..f8f9aa2 --- /dev/null +++ b/models.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional, TypedDict +from dane.provenance import Provenance + + +# returned by callback() +class CallbackResponse(TypedDict): + state: int + message: str + + +# These are the types of output this worker (possibly) provides (depending on configuration) +class OutputType(Enum): + # name of output type, should just have a significant name, no other restrictions + # (as far as I understand) + TRANSCRIPT = "transcript" + PROVENANCE = "provenance" # produced by provenance.py + + +@dataclass +class WhisperASRInput: + state: int # HTTP status code + message: str # error/success message + source_id: str = "" # __ + input_file_path: str = "" # where the audio was downloaded from + provenance: Optional[Provenance] = None # mostly: how long did it take to download + + +@dataclass +class WhisperASROutput: + state: int # HTTP status code + message: str # error/success message + output_file_path: str = "" # where to store the text file + provenance: Optional[Provenance] = None # audio extraction provenance From f7e67bd944652ed2da05c356a5d38605d491d8b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 19 Mar 2024 13:58:39 +0100 Subject: [PATCH 2/7] Add utils --- base_util.py | 171 ++++++++++++++++++++++++++++ io_util.py | 306 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 477 insertions(+) create mode 100644 base_util.py create mode 100644 io_util.py diff --git a/base_util.py b/base_util.py new file mode 100644 index 0000000..9456bc4 --- /dev/null +++ b/base_util.py @@ -0,0 +1,171 @@ +from typing import Any, List +from yacs.config import CfgNode +import os +from pathlib import Path +import logging + + +LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s" +logger = logging.getLogger(__name__) + + +def validate_config(config: CfgNode, validate_file_paths: bool = True) -> bool: + """Check the configuration (supplied by config.yml) + Most of the config is related to DANE and do not need to be altered when + developing locally, except the last part (settings for this worker specifically). 
+ consult https://github.com/beeldengeluid/dane-example-worker/wiki/Config + """ + try: + __validate_environment_variables() + except AssertionError as e: + print("Error malconfigured worker: env vars incomplete") + print(str(e)) + return False + + parent_dirs_to_check: List[str] = [] # parent dirs of file paths must exist + try: + # rabbitmq settings + assert config.RABBITMQ, "RABBITMQ" + assert check_setting(config.RABBITMQ.HOST, str), "RABBITMQ.HOST" + assert check_setting(config.RABBITMQ.PORT, int), "RABBITMQ.PORT" + assert check_setting(config.RABBITMQ.EXCHANGE, str), "RABBITMQ.EXCHANGE" + assert check_setting( + config.RABBITMQ.RESPONSE_QUEUE, str + ), "RABBITMQ.RESPONSE_QUEUE" + assert check_setting(config.RABBITMQ.USER, str), "RABBITMQ.USER" + assert check_setting(config.RABBITMQ.PASSWORD, str), "RABBITMQ.PASSWORD" + + # Elasticsearch settings + assert config.ELASTICSEARCH, "ELASTICSEARCH" + assert check_setting(config.ELASTICSEARCH.HOST, list), "ELASTICSEARCH.HOST" + assert ( + len(config.ELASTICSEARCH.HOST) == 1 + and type(config.ELASTICSEARCH.HOST[0]) is str + ), "Invalid ELASTICSEARCH.HOST" + + assert check_setting(config.ELASTICSEARCH.PORT, int), "ELASTICSEARCH.PORT" + assert check_setting(config.ELASTICSEARCH.USER, str, True), "ELASTICSEARCH.USER" + assert check_setting( + config.ELASTICSEARCH.PASSWORD, str, True + ), "ELASTICSEARCH.PASSWORD" + assert check_setting(config.ELASTICSEARCH.SCHEME, str), "ELASTICSEARCH.SCHEME" + assert check_setting(config.ELASTICSEARCH.INDEX, str), "ELASTICSEARCH.INDEX" + + # DANE python lib settings + assert config.PATHS, "PATHS" + assert check_setting(config.PATHS.TEMP_FOLDER, str), "PATHS.TEMP_FOLDER" + assert check_setting(config.PATHS.OUT_FOLDER, str), "PATHS.OUT_FOLDER" + + assert config.FILE_SYSTEM, "FILE_SYSTEM" + assert check_setting( + config.FILE_SYSTEM.BASE_MOUNT, str + ), "FILE_SYSTEM.BASE_MOUNT" + assert check_setting(config.FILE_SYSTEM.INPUT_DIR, str), "FILE_SYSTEM.INPUT_DIR" + assert check_setting( + config.FILE_SYSTEM.OUTPUT_DIR, str + ), "FILE_SYSTEM.OUTPUT_DIR" + + # settings for input & output handling + assert config.INPUT, "INPUT" + assert check_setting( + config.INPUT.S3_ENDPOINT_URL, str, True + ), "INPUT.S3_ENDPOINT_URL" + assert check_setting( + config.INPUT.MODEL_CHECKPOINT_S3_URI, str, True + ), "INPUT.MODEL_CHECKPOINT_S3_URI" + assert check_setting( + config.INPUT.MODEL_CONFIG_S3_URI, str, True + ), "INPUT.MODEL_CONFIG_S3_URI" + assert check_setting( + config.INPUT.DELETE_ON_COMPLETION, bool + ), "INPUT.DELETE_ON_COMPLETION" + + assert config.OUTPUT, "OUTPUT" + assert check_setting( + config.OUTPUT.DELETE_ON_COMPLETION, bool + ), "OUTPUT.DELETE_ON_COMPLETION" + assert check_setting( + config.OUTPUT.TRANSFER_ON_COMPLETION, bool + ), "OUTPUT.TRANSFER_ON_COMPLETION" + if config.OUTPUT.TRANSFER_ON_COMPLETION: + # required only in case output must be transferred + assert check_setting( + config.OUTPUT.S3_ENDPOINT_URL, str + ), "OUTPUT.S3_ENDPOINT_URL" + assert check_setting(config.OUTPUT.S3_BUCKET, str), "OUTPUT.S3_BUCKET" + assert check_setting( + config.OUTPUT.S3_FOLDER_IN_BUCKET, str + ), "OUTPUT.S3_FOLDER_IN_BUCKET" + + # settings for this worker specifically + assert check_setting( + config.WHISPER_ASR_SETTINGS.WORD_TIMESTAMPS, bool + ), "WHISPER_ASR_SETTINGS.WORD_TIMESTAMPS" + + assert __check_dane_dependencies(config.DANE_DEPENDENCIES), "DANE_DEPENDENCIES" + + # validate file paths (not while unit testing) + if validate_file_paths: + __validate_parent_dirs(parent_dirs_to_check) + 
__validate_dane_paths(config.PATHS.TEMP_FOLDER, config.PATHS.OUT_FOLDER) + + except AssertionError as e: + print(f"Configuration error: {str(e)}") + return False + + return True + + +def __validate_environment_variables() -> None: + # self.UNIT_TESTING = os.getenv('DW_ASR_UNIT_TESTING', False) + try: + assert True # TODO add secrets from the config.yml to the env + except AssertionError as e: + raise (e) + + +def __validate_dane_paths(dane_temp_folder: str, dane_out_folder: str) -> None: + i_dir = Path(dane_temp_folder) + o_dir = Path(dane_out_folder) + + try: + assert os.path.exists( + i_dir.parent.absolute() + ), f"{i_dir.parent.absolute()} does not exist" + assert os.path.exists( + o_dir.parent.absolute() + ), f"{o_dir.parent.absolute()} does not exist" + except AssertionError as e: + raise (e) + + +def check_setting(setting: Any, t: type, optional=False) -> bool: + return (type(setting) is t and optional is False) or ( + optional and (setting is None or type(setting) is t) + ) + + +def __check_dane_dependencies(deps: Any) -> bool: + """The idea is that you specify a bit more strictly that your worker can only + work on the OUTPUT of another worker. + If you want to define a dependency, you should populate the deps_allowed list + in this function with valid keys, that other workers use to identify themselves + within DANE: just use the queue_name + (see e.g. https://github.com/beeldengeluid/dane-video-segmentation-worker/blob/main/worker.py#L34-L35) + Then also make sure you define a valid dependency in your worker here: + https://github.com/beeldengeluid/dane-video-segmentation-worker/blob/main/worker.py#L36-L38 + (using another worker as an example) + """ + deps_to_check: list = deps if type(deps) is list else [] + deps_allowed: list = [] + return all(dep in deps_allowed for dep in deps_to_check) + + +def __validate_parent_dirs(paths: list) -> None: + try: + for p in paths: + assert os.path.exists( + Path(p).parent.absolute() + ), f"Parent dir of file does not exist: {p}" + except AssertionError as e: + raise (e) diff --git a/io_util.py b/io_util.py new file mode 100644 index 0000000..b8e872c --- /dev/null +++ b/io_util.py @@ -0,0 +1,306 @@ +import logging +import os +from pathlib import Path +import shutil +import tarfile +from time import time +from typing import Dict, List + +from dane import Document +from dane.config import cfg +from dane.s3_util import S3Store, parse_s3_uri, validate_s3_uri +from models import ( + OutputType, + Provenance, + WhisperASRInput, +) + + +logger = logging.getLogger(__name__) +INPUT_GENERATOR_TASK_KEY = "SOME_KEY" +OUTPUT_FILE_BASE_NAME = "out" +TAR_GZ_EXTENSION = ".tar.gz" +S3_OUTPUT_TYPES: List[OutputType] = [ + OutputType.TRANSCRIPT, + OutputType.PROVENANCE, +] # only upload this output to S3 + + +# make sure the necessary base dirs are there +def validate_data_dirs() -> bool: # TODO: perhaps add model dir + i_dir = Path(get_download_dir()) + o_dir = Path(get_base_output_dir()) + + if not os.path.exists(i_dir.parent.absolute()): + logger.info( + f"{i_dir.parent.absolute()} does not exist." 
+ "Make sure BASE_MOUNT_DIR exists before retrying" + ) + return False + + # make sure the input and output dirs are there + try: + os.makedirs(i_dir, 0o755) + logger.info("created input dir: {}".format(i_dir)) + except FileExistsError as e: + logger.info(e) + + try: + os.makedirs(o_dir, 0o755) + logger.info("created output dir: {}".format(o_dir)) + except FileExistsError as e: + logger.info(e) + + return True + + +# for each OutputType a subdir is created inside the base output dir +def generate_output_dirs(source_id: str) -> Dict[str, str]: + base_output_dir = get_base_output_dir(source_id) + output_dirs = {} + for output_type in OutputType: + output_dir = os.path.join(base_output_dir, output_type.value) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + output_dirs[output_type.value] = output_dir + return output_dirs + + +# below this dir each processing module will put its output data in a subfolder +def get_base_output_dir(source_id: str = "") -> str: + path_elements = [cfg.FILE_SYSTEM.BASE_MOUNT, cfg.FILE_SYSTEM.OUTPUT_DIR] + if source_id: + path_elements.append(source_id) + return os.path.join(*path_elements) + + +# output file name of the final tar.gz that will be uploaded to S3 +def get_archive_file_path(source_id: str) -> str: + return os.path.join( + get_base_output_dir(source_id), + f"{OUTPUT_FILE_BASE_NAME}__{source_id}{TAR_GZ_EXTENSION}", + ) + + +def get_output_file_name(source_id: str, output_type: OutputType) -> str: + output_file_name = "" + match output_type: + case OutputType.PROVENANCE: + output_file_name = source_id + "_provenance.json" + case OutputType.TRANSCRIPT: + output_file_name = ( + source_id + ".json" + ) + return output_file_name + + +# output file name of the final .pt file that will be uploaded to S3 +# TODO decide whether to tar.gz this as well +def get_output_file_path(source_id: str, output_type: OutputType) -> str: + return os.path.join( + get_base_output_dir(source_id), + output_type.value, + get_output_file_name(source_id, output_type), + ) + + +# e.g. s3:///assets/ +def get_s3_base_uri(source_id: str) -> str: + uri = os.path.join(cfg.OUTPUT.S3_BUCKET, cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id) + return f"s3://{uri}" + + +# e.g. s3:///assets//__.tar.gz +def get_s3_output_file_uri(source_id: str) -> str: + return f"{get_s3_base_uri(source_id)}/{get_archive_file_path(source_id)}" + + +# NOTE: only use for test run & unit test with input that points to tar file! +# e.g. ./data/input-files/__testob.tar.gz +def get_source_id_from_tar(input_path: str) -> str: + fn = os.path.basename(input_path) + tmp = fn.split("__") + source_id = tmp[1][: -len(TAR_GZ_EXTENSION)] + logger.info(f"Using source_id: {source_id}") + return source_id + + +# e.g. 
s3:///assets//__.tar.gz +def source_id_from_s3_uri(s3_uri: str) -> str: + fn = os.path.basename(s3_uri) + tmp = fn.split("__") + source_id = tmp[1][: -len(TAR_GZ_EXTENSION)] + return source_id + + +def delete_local_output(source_id: str) -> bool: + output_dir = get_base_output_dir(source_id) + logger.info(f"Deleting output folder: {output_dir}") + if output_dir == os.sep or output_dir == ".": + logger.warning(f"Rejected deletion of: {output_dir}") + return False + + if not _is_valid_output(output_dir): + logger.warning( + f"Tried to delete a dir that did not contain output: {output_dir}" + ) + return False + + try: + shutil.rmtree(output_dir) + logger.info(f"Cleaned up folder {output_dir}") + except Exception: + logger.exception(f"Failed to delete output dir {output_dir}") + return False + return True + + +# TODO implement some more validation +def _is_valid_output(output_dir: str) -> bool: + provenance_exists = os.path.exists( + os.path.join(output_dir, OutputType.PROVENANCE.value) + ) + audio_exists = os.path.exists(os.path.join(output_dir, OutputType.TRANSCRIPT.value)) + return provenance_exists and audio_exists + + +def _validate_transfer_config() -> bool: + if any( + [ + not x + for x in [ + cfg.OUTPUT.S3_ENDPOINT_URL, + cfg.OUTPUT.S3_BUCKET, + cfg.OUTPUT.S3_FOLDER_IN_BUCKET, + ] + ] + ): + logger.warning( + "TRANSFER_ON_COMPLETION configured without all the necessary S3 settings" + ) + return False + return True + + +# compresses all desired output dirs into a single tar and uploads it to S3 +def transfer_output(source_id: str) -> bool: + output_dir = get_base_output_dir(source_id) + logger.info(f"Transferring {output_dir} to S3 (asset={source_id})") + if not _validate_transfer_config(): + return False + + s3 = S3Store(cfg.OUTPUT.S3_ENDPOINT_URL) + file_list = [os.path.join(output_dir, ot.value) for ot in S3_OUTPUT_TYPES] + tar_file = get_archive_file_path(source_id) + + success = s3.transfer_to_s3( + cfg.OUTPUT.S3_BUCKET, + os.path.join( + cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id + ), # assets/__ + file_list, # this list of subdirs will be compressed into the tar below + tar_file, # this file will be uploaded + ) + if not success: + logger.error(f"Failed to upload: {tar_file}") + return False + return True + + +def get_download_dir() -> str: + return os.path.join(cfg.FILE_SYSTEM.BASE_MOUNT, cfg.FILE_SYSTEM.INPUT_DIR) + + +def get_base_input_dir(source_id: str) -> str: + return os.path.join(get_download_dir(), source_id) + + +def delete_input_file(input_file: str, source_id: str, actually_delete: bool) -> bool: + logger.info(f"Verifying deletion of input file: {input_file}") + if actually_delete is False: + logger.info("Configured to leave the input alone, skipping deletion") + return True + + # first remove the input file + try: + os.remove(input_file) + logger.info(f"Deleted input tar file: {input_file}") + except OSError: + logger.exception("Could not delete input file") + return False + + # now remove the folders that were extracted from the input tar file + base_input_dir = get_base_input_dir(source_id) + try: + for root, dirs, files in os.walk(base_input_dir): + for d in dirs: + dir_path = os.path.join(root, d) + logger.info(f"Deleting {dir_path}") + shutil.rmtree(dir_path) + logger.info("Deleted extracted input dirs") + os.removedirs(base_input_dir) + logger.info(f"Finally deleted the base_input_dir: {base_input_dir}") + except OSError: + logger.exception("OSError while removing empty input file dirs") + + return True # return True even if empty dirs were not removed + + +def 
obtain_input_file(s3_uri: str) -> WhisperASRInput: + + if not validate_s3_uri(s3_uri): + return WhisperASRInput(500, f"Invalid S3 URI: {s3_uri}") + + source_id = source_id_from_s3_uri(s3_uri) + start_time = time() + output_folder = get_base_input_dir(source_id) + + # download the content into get_download_dir() + s3 = S3Store(cfg.OUTPUT.S3_ENDPOINT_URL) + bucket, object_name = parse_s3_uri(s3_uri) + logger.info(f"OBJECT NAME: {object_name}") + input_file_path = os.path.join( + get_download_dir(), + source_id, + os.path.basename(object_name), # i.e. __.tar.gz + ) + success = s3.download_file(bucket, object_name, output_folder) + if success: + # uncompress the .tar.gz + untar_input_file(input_file_path) + + provenance = Provenance( + activity_name="download", + activity_description="Download input data", + start_time_unix=start_time, + processing_time_ms=time() - start_time, + input_data={}, + output_data={"file_path": input_file_path}, + ) + return WhisperASRInput( + 200, + f"Downloaded input from: {s3_uri}", + source_id_from_s3_uri(s3_uri), # source_id + input_file_path, # locally downloaded .tar.gz + provenance, + ) + logger.error("Failed to download input data from S3") + return WhisperASRInput(500, f"Failed to download: {s3_uri}") + + +def fetch_input_s3_uri(handler, doc: Document) -> str: + logger.info("checking input") + possibles = handler.searchResult(doc._id, INPUT_GENERATOR_TASK_KEY) + logger.info(possibles) + if len(possibles) > 0 and "s3_location" in possibles[0].payload: + return possibles[0].payload.get("s3_location", "") + logger.error(f"No s3_location found in result for {INPUT_GENERATOR_TASK_KEY}") + return "" + + +# untars somefile.tar.gz into the same dir +def untar_input_file(tar_file_path: str): + logger.info(f"Uncompressing {tar_file_path}") + path = str(Path(tar_file_path).parent) + with tarfile.open(tar_file_path) as tar: + tar.extractall(path=path, filter="data") # type: ignore From bb0b9c62b46558748b95104fd2a4e3705b807824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 19 Mar 2024 14:58:10 +0100 Subject: [PATCH 3/7] Add worker.py + fix black pipeline fail --- io_util.py | 4 +- worker.py | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 worker.py diff --git a/io_util.py b/io_util.py index b8e872c..e9d3d8b 100644 --- a/io_util.py +++ b/io_util.py @@ -88,9 +88,7 @@ def get_output_file_name(source_id: str, output_type: OutputType) -> str: case OutputType.PROVENANCE: output_file_name = source_id + "_provenance.json" case OutputType.TRANSCRIPT: - output_file_name = ( - source_id + ".json" - ) + output_file_name = source_id + ".json" return output_file_name diff --git a/worker.py b/worker.py new file mode 100644 index 0000000..02ad88d --- /dev/null +++ b/worker.py @@ -0,0 +1,173 @@ +import logging +import sys +import os +from base_util import validate_config +from dane import Document, Task, Result +from dane.base_classes import base_worker +from dane.config import cfg +from dane.provenance import Provenance +from models import CallbackResponse +from io_util import ( + fetch_input_s3_uri, + source_id_from_s3_uri, + get_s3_output_file_uri, +) +from pika.exceptions import ChannelClosedByBroker # type: ignore +import main_data_processor + + +logger = logging.getLogger() + + +class WhisperASRWorker(base_worker): + def __init__(self, config): + logger.info(config) + + if not validate_config(config, not self.UNIT_TESTING): + logger.error("Invalid config, quitting") + sys.exit() + + 
self.__queue_name = "QUEUE_NAME" + self.__binding_key = "#.BINDING_KEY" + self.__depends_on = ( + list(config.DANE_DEPENDENCIES) if "DANE_DEPENDENCIES" in config else [] + ) + + super().__init__( + self.__queue_name, + self.__binding_key, + config, + self.__depends_on, + auto_connect=not self.UNIT_TESTING, + no_api=self.UNIT_TESTING, + ) + + # NOTE: cannot be automatically filled, because no git client is present + if not self.generator: + logger.info("Generator was None, creating it now") + self.generator = { + "id": "dane-whisper-asr-worker", + "type": "dane-whisper-asr-worker", + "name": "ASR using Whisper", + "homepage": "https://github.com/beeldengeluid/dane-whisper-asr-worker", + } + + """----------------------------------INTERACTION WITH DANE SERVER ---------------------------------""" + + # DANE callback function, called whenever there is a job for this worker + def callback(self, task: Task, doc: Document) -> CallbackResponse: + logger.info("Receiving a task from the DANE server!") + logger.info(task) + logger.info(doc) + + # fetch s3 uri of input data: + s3_uri = fetch_input_s3_uri(self.handler, doc) + + # now run the main process! + processing_result, full_provenance_chain = main_data_processor.run(s3_uri) + + # if results are fine, save something to the DANE index + if processing_result.get("state", 500) == 200: + logger.info( + "applying IO on output went well, now finally saving to DANE index" + ) + self.save_to_dane_index( + doc, + task, + get_s3_output_file_uri(source_id_from_s3_uri(s3_uri)), + provenance=full_provenance_chain, + ) + return processing_result + + # TODO adapt + def save_to_dane_index( + self, + doc: Document, + task: Task, + s3_location: str, + provenance: Provenance, + ) -> None: + logger.info("saving results to DANE, task id={0}".format(task._id)) + # TODO figure out the multiple lines per transcript (refresh my memory) + r = Result( + self.generator, + payload={ + "doc_id": doc._id, + "task_id": task._id if task else None, + "doc_target_id": doc.target["id"], + "doc_target_url": doc.target["url"], + "s3_location": s3_location, + "provenance": provenance.to_json(), + }, + api=self.handler, + ) + r.save(task._id) + + +# Start the worker +# passing --run-test-file will run the whole process on the files in cfg.INPUT.TEST_FILES +if __name__ == "__main__": + from argparse import ArgumentParser + import json + from base_util import LOG_FORMAT + + # first read the CLI arguments + parser = ArgumentParser(description="dane-audio-extraction-worker") + parser.add_argument( + "--run-test-file", action="store", dest="run_test_file", default="n", nargs="?" 
+ ) + parser.add_argument("--log", action="store", dest="loglevel", default="INFO") + args = parser.parse_args() + + # initialises the root logger + logging.basicConfig( + stream=sys.stdout, # configure a stream handler only for now (single handler) + format=LOG_FORMAT, + ) + + # setting the loglevel + log_level = args.loglevel.upper() + logger.setLevel(log_level) + logger.info(f"Logger initialized (log level: {log_level})") + logger.info(f"Got the following CMD line arguments: {args}") + + # see if the test file must be run + if args.run_test_file != "n": + logger.info("Running feature extraction with INPUT.TEST_INPUT_PATH ") + if cfg.INPUT.TEST_INPUT_PATH: + processing_result, full_provenance_chain = main_data_processor.run( + os.path.join( + cfg.FILE_SYSTEM.BASE_MOUNT, + cfg.FILE_SYSTEM.INPUT_DIR, + cfg.INPUT.TEST_INPUT_PATH, + ) + ) + logger.info("Results after applying desired I/O") + logger.info(processing_result) + logger.info("Full provenance chain") + logger.info( + json.dumps(full_provenance_chain.to_json(), indent=4, sort_keys=True) + if full_provenance_chain + else None + ) + else: + logger.error("Please configure an input file in INPUT.TEST_INPUT_FILE") + sys.exit() + else: + logger.info("Starting the worker") + # start the worker + w = AudioExtractionWorker(cfg) + try: + w.run() + except ChannelClosedByBroker: + """ + (406, 'PRECONDITION_FAILED - delivery acknowledgement on channel 1 timed out. + Timeout value used: 1800000 ms. + This timeout value can be configured, see consumers doc guide to learn more') + """ + logger.critical( + "Please increase the consumer_timeout in your RabbitMQ server" + ) + w.stop() + except (KeyboardInterrupt, SystemExit): + w.stop() From 3af0061cb3b0ddebdc30a01f172a3817a79936d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 19 Mar 2024 15:17:21 +0100 Subject: [PATCH 4/7] Fix worker reference issue --- worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/worker.py b/worker.py index 02ad88d..bada107 100644 --- a/worker.py +++ b/worker.py @@ -112,7 +112,7 @@ def save_to_dane_index( from base_util import LOG_FORMAT # first read the CLI arguments - parser = ArgumentParser(description="dane-audio-extraction-worker") + parser = ArgumentParser(description="dane-whisper-asr-worker") parser.add_argument( "--run-test-file", action="store", dest="run_test_file", default="n", nargs="?" 
) @@ -156,7 +156,7 @@ def save_to_dane_index( else: logger.info("Starting the worker") # start the worker - w = AudioExtractionWorker(cfg) + w = WhisperASRWorker(cfg) try: w.run() except ChannelClosedByBroker: From d687741a7805a45d652e2532ff160a10f92f5e9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 19 Mar 2024 15:21:50 +0100 Subject: [PATCH 5/7] mypy pipeline fail fix --- main_data_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main_data_processor.py b/main_data_processor.py index e69de29..f299191 100644 --- a/main_data_processor.py +++ b/main_data_processor.py @@ -0,0 +1,2 @@ +def run(s3_uri): + return s3_uri, False From 29f0e8211754eca560031e0070c67e2ac2fa6f1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 26 Mar 2024 13:44:45 +0100 Subject: [PATCH 6/7] Change "data" to "/data" --- config/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.yml b/config/config.yml index b238189..7f38baf 100644 --- a/config/config.yml +++ b/config/config.yml @@ -14,7 +14,7 @@ ELASTICSEARCH: SCHEME: http INDEX: dane-index-k8s FILE_SYSTEM: - BASE_MOUNT: data # data when running locally + BASE_MOUNT: /data # data when running locally INPUT_DIR: input-files OUTPUT_DIR: output-files INPUT: From 337704a3d042b9ebb884fa70734c567809880115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Thu, 28 Mar 2024 10:53:22 +0100 Subject: [PATCH 7/7] Remove copy-pasted comments --- models.py | 4 +--- worker.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/models.py b/models.py index f8f9aa2..0c3872c 100644 --- a/models.py +++ b/models.py @@ -12,8 +12,6 @@ class CallbackResponse(TypedDict): # These are the types of output this worker (possibly) provides (depending on configuration) class OutputType(Enum): - # name of output type, should just have a significant name, no other restrictions - # (as far as I understand) TRANSCRIPT = "transcript" PROVENANCE = "provenance" # produced by provenance.py @@ -23,7 +21,7 @@ class WhisperASRInput: state: int # HTTP status code message: str # error/success message source_id: str = "" # __ - input_file_path: str = "" # where the audio was downloaded from + input_file_path: str = "" # where the audio was downloaded to provenance: Optional[Provenance] = None # mostly: how long did it take to download diff --git a/worker.py b/worker.py index bada107..c52c3ab 100644 --- a/worker.py +++ b/worker.py @@ -79,7 +79,6 @@ def callback(self, task: Task, doc: Document) -> CallbackResponse: ) return processing_result - # TODO adapt def save_to_dane_index( self, doc: Document, @@ -88,7 +87,6 @@ def save_to_dane_index( provenance: Provenance, ) -> None: logger.info("saving results to DANE, task id={0}".format(task._id)) - # TODO figure out the multiple lines per transcript (refresh my memory) r = Result( self.generator, payload={