Proposed model download code

greenw0lf committed Oct 4, 2024
1 parent 5caf5e2 commit 976b51f
Showing 5 changed files with 74 additions and 33 deletions.
12 changes: 12 additions & 0 deletions base_util.py
@@ -2,6 +2,7 @@
 import os
 import subprocess
 import json
+from urllib.parse import urlparse
 from typing import Tuple
 from config import data_base_dir
 
@@ -73,3 +74,14 @@ def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
         return False
 
     return True
+
+
+def validate_http_uri(http_uri: str) -> bool:
+    o = urlparse(http_uri, allow_fragments=False)
+    if o.scheme != "http" and o.scheme != "https":
+        logger.error(f"Invalid protocol in {http_uri}")
+        return False
+    if o.path == "":
+        logger.error(f"No object_name specified in {http_uri}")
+        return False
+    return True
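
For reference, a standalone sketch of the check the new validate_http_uri helper performs (the example URLs below are hypothetical, not taken from the repository):

from urllib.parse import urlparse

def has_http_scheme_and_path(uri: str) -> bool:
    # Mirrors validate_http_uri: require an http/https scheme and a non-empty path (the object name).
    o = urlparse(uri, allow_fragments=False)
    return o.scheme in ("http", "https") and o.path != ""

print(has_http_scheme_and_path("https://example.com/models/large-v3.tar.gz"))  # True
print(has_http_scheme_and_path("ftp://example.com/models/large-v3.tar.gz"))    # False: scheme is not http/https
print(has_http_scheme_and_path("https://example.com"))                         # False: empty path, so no object_name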
19 changes: 10 additions & 9 deletions config.py
@@ -73,12 +73,13 @@ def assert_tuple(param: str) -> str:
 
 
 assert w_device in ["cuda", "cpu"], "Please use either cuda|cpu for W_DEVICE"
-assert w_model in [
-    "tiny",
-    "base",
-    "small",
-    "medium",
-    "large",
-    "large-v2",
-    "large-v3",
-], "Please use one of: tiny|base|small|medium|large|large-v2|large-v3 for W_MODEL"
+if input_uri[0:5] != "s3://" and not validators.url(output_uri):
+    assert w_model in [
+        "tiny",
+        "base",
+        "small",
+        "medium",
+        "large",
+        "large-v2",
+        "large-v3",
+    ], "Please use one of: tiny|base|small|medium|large|large-v2|large-v3 for W_MODEL"
72 changes: 50 additions & 22 deletions model_download.py
@@ -1,34 +1,62 @@
 import logging
 import os
 import tarfile
+from urllib.parse import urlparse
+import requests
 from s3_util import S3Store, parse_s3_uri, validate_s3_uri
+from base_util import get_asset_info, validate_http_uri
 from config import model_base_dir, w_model, s3_endpoint_url
 
 
-# makes sure the model is available locally, if not download it from S3, if that fails download from Huggingface
-# FIXME should also check if the correct w_model type is available locally!
-def check_model_availability() -> bool:
+def extract_model(w_model: str, model_base_dir: str, destination: str, asset_id: str, extension: str) -> str:
     logger = logging.getLogger(__name__)
-    if os.path.exists(model_base_dir + "/model.bin"):
-        logger.info("Model found locally")
-        return True
-    else:
-        logger.info("Model not found locally, attempting to download from S3")
-        if not validate_s3_uri(w_model):
-            logger.info("No S3 URI detected")
-            logger.info(f"Downloading version {w_model} from Huggingface instead")
-            return False
-        s3 = S3Store(s3_endpoint_url)
+    logger.info(f"Downloaded {w_model} into {model_base_dir}")
+    if not os.path.exists(destination):  # Create dir for model to be extracted in
+        os.makedirs(destination)
+    logger.info(f"Extracting the model into {destination}")
+    tar_path = model_base_dir + "/" + asset_id + "." + extension
+    with tarfile.open(tar_path) as tar:
+        tar.extractall(path=destination)
+    # cleanup: delete the tar file
+    os.remove(tar_path)
+    return model_base_dir + "/" + asset_id
+
+
+# makes sure the model is obtained from S3/HTTP/Huggingface, if w_model doesn't exist locally
+def check_model_availability() -> str:
+    logger = logging.getLogger(__name__)
+
+    if validate_s3_uri(w_model):
+        logger.info(f"{w_model} is an S3 URI. Attempting to download")
         bucket, object_name = parse_s3_uri(w_model)
+        asset_id, extension = get_asset_info(object_name)
+        destination = model_base_dir + "/" + asset_id
+        if os.path.exists(destination):
+            logger.info("Model already exists")
+            return model_base_dir + "/" + asset_id
+        s3 = S3Store(s3_endpoint_url)
         success = s3.download_file(bucket, object_name, model_base_dir)
         if not success:
             logger.error(f"Could not download {w_model} into {model_base_dir}")
-            return False
-        logger.info(f"Downloaded {w_model} into {model_base_dir}")
-        logger.info("Extracting the model")
-        tar_path = model_base_dir + "/" + object_name
-        with tarfile.open(tar_path) as tar:
-            tar.extractall(path=model_base_dir)
-        # cleanup: delete the tar file
-        os.remove(tar_path)
-        return True
+            return ""
+        return extract_model(w_model, model_base_dir, destination, asset_id, extension)
+
+    elif validate_http_uri(w_model):
+        logger.info(f"{w_model} is an HTTP URI. Attempting to download")
+        asset_id, extension = get_asset_info(urlparse(w_model).path)
+        destination = model_base_dir + "/" + asset_id
+        if os.path.exists(destination):
+            logger.info("Model already exists")
+            return model_base_dir + "/" + asset_id
+        with open(model_base_dir + "/" + asset_id + "." + extension, "wb") as file:
+            response = requests.get(w_model)
+            if response.status_code >= 400:
+                logger.error(f"Could not download {w_model} into {model_base_dir}")
+                return ""
+            file.write(response.content)
+            file.close()
+        return extract_model(w_model, model_base_dir, destination, asset_id, extension)
+
+    # The faster-whisper API can auto-detect if the version exists locally. No need to add extra checks
+    logger.info(f"{w_model} is not an S3/HTTP URI. Using HuggingFace instead")
+    return w_model
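
To summarize the branching above, a self-contained sketch of how a W_MODEL value is dispatched (the URIs are hypothetical; the real function also downloads and extracts the model archive):

from urllib.parse import urlparse

def classify_w_model(w_model: str) -> str:
    # Illustrative dispatch mirroring check_model_availability's three branches.
    if w_model.startswith("s3://"):
        return "S3 URI: download from S3, extract, return the local model directory"
    if urlparse(w_model).scheme in ("http", "https"):
        return "HTTP URI: download via requests, extract, return the local model directory"
    return "plain name: returned as-is so faster-whisper resolves it via Hugging Face"

print(classify_w_model("s3://assets/models/large-v3.tar.gz"))
print(classify_w_model("https://example.com/models/large-v3.tar.gz"))
print(classify_w_model("large-v3"))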
2 changes: 1 addition & 1 deletion s3_util.py
@@ -65,7 +65,7 @@ def validate_s3_uri(s3_uri: str) -> bool:
         logger.error(f"Invalid protocol in {s3_uri}")
         return False
     if o.path == "":
-        logger.error(f"No object_name specified {s3_uri}")
+        logger.error(f"No object_name specified in {s3_uri}")
         return False
     return True

2 changes: 1 addition & 1 deletion whisper.py
@@ -34,7 +34,7 @@ def load_model(model_base_dir: str, model_type: str, device: str) -> WhisperModel:
     os.environ["HF_HOME"] = model_base_dir
 
     # determine loading locally or have Whisper download from HuggingFace
-    model_location = model_base_dir if check_model_availability() else model_type
+    model_location = check_model_availability()
     model = WhisperModel(
         model_location,  # either local path or e.g. large-v2 (means HuggingFace download)
         device=device,
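
A short usage sketch of how the changed line is consumed; faster-whisper's WhisperModel accepts either a local directory or a model name (the values below are stand-ins, not repository defaults):

from faster_whisper import WhisperModel

# check_model_availability() now returns either a local directory (a hypothetical
# path such as "/models/large-v3" after an S3/HTTP download) or a plain model name
# such as "large-v3", which WhisperModel downloads from Hugging Face.
model_location = "large-v3"  # stand-in for check_model_availability()
model = WhisperModel(model_location, device="cpu")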
