facilitate model download
mikepapadim committed Dec 31, 2023
1 parent e2f516d commit c0d1a33
Showing 4 changed files with 62 additions and 4 deletions.
1 change: 1 addition & 0 deletions llamashepherd/__init__.py
@@ -2,4 +2,5 @@

from .main import main
from .config import options
from .model_config import urls

55 changes: 52 additions & 3 deletions llamashepherd/main.py
@@ -4,8 +4,12 @@
import os
import subprocess
import argparse
import wget


from tabulate import tabulate
import config
import model_config

# Access the 'options' dictionary from the config module
llamas = config.options
@@ -125,10 +129,55 @@ def interactive_action(default_llama_shepherd_path):
print("Invalid input. Please enter 'y', 'n', or '0'.")


def initialize_action():
    # Add initialization logic here
def initialize_action(default_llama_shepherd_path):
    print("Initializing models...")

    while True:
        user_input = input(
            "Do you want to download and configure the Tokenizer and/or TinyLlama models? (y/n, 0 to exit): ").lower()

        if user_input == 'y':
            # Ask whether to download the Tokenizer
            download_tokenizer = input("Do you want to download the Tokenizer model? (y/n): ").lower()
            if download_tokenizer == 'y':
                # Download and configure the Tokenizer model
                download_and_configure_model("Tokenizer", model_config.urls["tokenizer"], default_llama_shepherd_path)

            # Ask whether to download the stories models
            download_stories = input("Do you want to download the stories models? (y/n): ").lower()
            if download_stories == 'y':
                # Download and configure the three stories models
                download_and_configure_model("Stories15M", model_config.urls["stories15M"], default_llama_shepherd_path)
                download_and_configure_model("Stories42M", model_config.urls["stories42M"], default_llama_shepherd_path)
                download_and_configure_model("Stories110M", model_config.urls["stories110M"], default_llama_shepherd_path)

            break
        elif user_input in ('n', '0'):
            sys.exit()
        else:
            print("Invalid input. Please enter 'y', 'n', or '0'.")


def download_and_configure_model(model_name, model_url, destination_directory):
    print(f"Downloading and configuring {model_name} model from: {model_url}")

    # Ensure the models directory exists
    models_directory = os.path.join(destination_directory, "models")
    os.makedirs(models_directory, exist_ok=True)

    # Specify the destination file path
    destination_path = os.path.join(models_directory, f"{model_name}.bin")

    try:
        # Download the model using wget
        wget.download(model_url, out=destination_path)
        print(f"\n{model_name} model downloaded successfully to {destination_path}")

        # Add logic to configure the model if needed

    except Exception as e:
        print(f"Error downloading {model_name} model: {e}")


def main():
    home_directory = os.path.expanduser("~")
@@ -160,7 +209,7 @@ def main():
    elif args.action == "interactive":
        interactive_action(default_llama_shepherd_path)
    elif args.action == "initialize":
        initialize_action()
        initialize_action(default_llama_shepherd_path)
    elif args.action == "--help":
        parser.print_help()

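For reference, a minimal sketch of calling the new download helper directly, outside the interactive prompt. It assumes it is run from inside the llamashepherd/ directory (so the module's own plain "import config" and "import model_config" resolve) and that main.py keeps its CLI behind an if __name__ == "__main__" guard; the destination directory is purely illustrative.

# Sketch only: fetch a single model with the new helper.
import main          # llamashepherd/main.py
import model_config

main.download_and_configure_model(
    "Stories15M",
    model_config.urls["stories15M"],
    "/tmp/llama-shepherd",  # illustrative destination; the file lands in <dest>/models/
)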
8 changes: 8 additions & 0 deletions llamashepherd/model_config.py
@@ -0,0 +1,8 @@
# model_config.py

urls = {
    "stories15M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin",
    "stories42M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin",
    "stories110M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
    "tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin"
}
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
    packages=find_packages(),
    entry_points={
        'console_scripts': [
            'llamashepherd = main:main',
            'llamashepherd = llamashepherd.main:main',
        ],
    },
)
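A console_scripts entry point has to name a module that is importable after installation. Once the project is installed as a package, only llamashepherd is on the import path, so the old 'main:main' target would fail to resolve when the command is launched; 'llamashepherd.main:main' points at the packaged module instead.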
