Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
greenw0lf committed Oct 18, 2024
1 parent 29ba4f6 commit 874eb52
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 41 deletions.
9 changes: 1 addition & 8 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
/data/*
!/data/README.md
!/data/input/
/data/input/*
!/data/input/whisper-test.mp3
!/data/output
/data/output/*
!/data/output/whisper-test
!/data/output/whisper-test/*
!/data/whisper-test
/model/*
__pycache__
.pytest_cache
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,11 @@ The pre-trained Whisper model version can be adjusted in the `.env` file by edit
We recommend version `large-v2` as it performs better than `large-v3` in our [benchmarks](https://opensource-spraakherkenning-nl.github.io/ASR_NL_results/NISV/bn_nl/res_labelled.html).

You can also specify an S3/HTTP URI if you want to load your own (custom) model (by modifying the `W_MODEL` parameter).

## Config

The parameters used to configure the application can be found under `.env` file. You will also need to create a `.env.override` file that contains secrets related to the S3 connection that should normally not be exposed in the `.env` file. The parameters that should be updated with valid values in the `.env.override` are:

- `S3_ENDPOINT_URL`
- `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
12 changes: 6 additions & 6 deletions asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:
logger.error(
"The transcode failed to yield a valid file to continue with, quitting..."
)
remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return transcode_output.error
else:
input_path = transcode_output.transcoded_file_path
Expand All @@ -73,7 +73,7 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:
if isinstance(whisper_prov_or_error, dict):
prov_steps.append(whisper_prov_or_error)
else:
remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return whisper_prov_or_error
else:
logger.info(f"Whisper transcript already present in {output_path}")
Expand All @@ -98,7 +98,7 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:
prov_steps.append(daan_prov)
else:
logger.error("Could not generate DAAN transcript")
remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return "DAAN Transcript failure: Could not generate DAAN transcript"
else:
logger.info(f"DAAN transcript already present in {output_path}")
Expand Down Expand Up @@ -138,20 +138,20 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:
prov_success = save_provenance(final_prov, output_path)
if not prov_success:
logger.error("Could not save the provenance")
remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return "Provenance failure: Could not save the provenance"

# 5. transfer output
if output_uri:
success = transfer_asr_output(output_path, output_uri)
if not success:
logger.error("Could not upload output to S3")
remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return "Upload failure: Could not upload output to S3"
else:
logger.info("No output_uri specified, so all is done")

remove_all_input_output(input_path, asset_id, output_path)
remove_all_input_output(output_path)
return None


Expand Down
24 changes: 12 additions & 12 deletions base_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def get_asset_info(input_file: str) -> Tuple[str, str]:
return asset_id, extension


# i.e. {output_base_dir}/output/{input_filename_without_extension}
# i.e. {output_base_dir}/{input_filename_without_extension}
def asr_output_dir(input_path):
return os.path.join(data_base_dir, "output", get_asset_info(input_path)[0])
return os.path.join(data_base_dir, get_asset_info(input_path)[0])


def extension_to_mime_type(extension: str) -> str:
Expand Down Expand Up @@ -87,17 +87,17 @@ def validate_http_uri(http_uri: str) -> bool:
return True


def remove_all_input_output(input_path: str, asset_id: str, output_path: str) -> bool:
def remove_all_input_output(path: str) -> bool:
try:
if os.path.exists(input_path):
os.remove(input_path)
dirname, _ = os.path.split(input_path)
if os.path.exists(os.path.join(dirname, asset_id + ".mp3")):
os.remove(os.path.join(dirname, asset_id + ".mp3"))
if os.path.exists(output_path):
for file in os.listdir(output_path):
os.remove(file)
if os.path.exists(path):
for file in os.listdir(path):
os.remove(os.path.join(path, file))
logger.info(f"{file} has been removed successfully")
os.rmdir(path)
logger.info("All data has been deleted")
else:
logger.warning(f"{path} not found")
return False
return True

except OSError:
return False
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 6 additions & 5 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

logger = logging.getLogger(__name__)

input_file_dir = os.path.join(data_base_dir, "input/")


@dataclass
class DownloadResult:
Expand Down Expand Up @@ -41,8 +39,9 @@ def http_download(url: str) -> DownloadResult:
steps = [] # to report if input is already downloaded

fn = os.path.basename(urlparse(url).path)
input_file = os.path.join(input_file_dir, fn)
_, extension = get_asset_info(input_file)
asset_id, extension = get_asset_info(fn)
input_file_dir = os.path.join(data_base_dir, asset_id)
input_file = os.path.join(data_base_dir, asset_id, fn)
mime_type = extension_to_mime_type(extension)

# download if the file is not present (preventing unnecessary downloads)
Expand All @@ -58,7 +57,7 @@ def http_download(url: str) -> DownloadResult:
with open(input_file, "wb") as file:
response = requests.get(url)
if response.status_code >= 400:
logger.error(f"Could not download url: {response.status_code}")
logger.error(f"Could not download url. Response code: {response.status_code}")
download_time = (time.time() - start_time) * 1000
return DownloadResult(
input_file,
Expand Down Expand Up @@ -94,6 +93,8 @@ def s3_download(s3_uri: str) -> DownloadResult:

# parse S3 URI
bucket, object_name = parse_s3_uri(s3_uri)
asset_id, extension = get_asset_info(object_name)
input_file_dir = os.path.join(data_base_dir, asset_id)
logger.info(f"OBJECT NAME: {object_name}")
input_file = os.path.join(
input_file_dir,
Expand Down
2 changes: 1 addition & 1 deletion transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def try_transcode(input_path, asset_id, extension) -> TranscodeOutput:
)

# check if the input file was already transcoded
transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3")
transcoded_file_path = os.path.join(data_base_dir, asset_id, f"{asset_id}.mp3")
if os.path.exists(transcoded_file_path):
logger.info("Transcoded file is already available, no new transcode needed")
end_time = (time.time() - start_time) * 1000
Expand Down
17 changes: 8 additions & 9 deletions whisper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# import ast
import json
import logging
import traceback
import os
import time
from typing import Optional

import faster_whisper
from config import (
Expand Down Expand Up @@ -51,7 +53,6 @@ def load_model(
def run_asr(input_path, output_dir, model=None) -> dict | str:
logger.info(f"Starting ASR on {input_path}")
start_time = time.time()
# I wanted to add more detailed errors, but function becomes too complex...
try:
if not model:
logger.info("Model not passed as param, need to obtain it first")
Expand Down Expand Up @@ -112,16 +113,14 @@ def run_asr(input_path, output_dir, model=None) -> dict | str:
"steps": [],
}

if write_whisper_json(transcript, output_dir):
return provenance
else:
return "Transcribe failure: Could not save the transcript into a JSON file"
error = write_whisper_json(transcript, output_dir)
return error if error else provenance
except Exception as e:
logger.exception(str(e))
return "Transcribe failure: Something went wrong during transcribing"
return traceback.format_exc()


def write_whisper_json(transcript: dict, output_dir: str) -> bool:
def write_whisper_json(transcript: dict, output_dir: str) -> Optional[str]:
logger.info("Writing whisper-transcript.json")
try:
if not os.path.exists(output_dir):
Expand All @@ -135,5 +134,5 @@ def write_whisper_json(transcript: dict, output_dir: str) -> bool:
json.dump(transcript, f, ensure_ascii=False, indent=4)
except EnvironmentError as e: # OSError or IOError...
logger.exception(os.strerror(e.errno))
return False
return True
return traceback.format_exc()
return None

0 comments on commit 874eb52

Please sign in to comment.