From c0d1a33b8469f0045fae24ce6f7df3e9bea798ff Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sun, 31 Dec 2023 12:37:41 +0200
Subject: [PATCH] facilitate model download

---
 llamashepherd/__init__.py     |  1 +
 llamashepherd/main.py         | 61 +++++++++++++++++++++++++++++++++--
 llamashepherd/model_config.py |  8 +++++
 setup.py                      |  2 +-
 4 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/llamashepherd/__init__.py b/llamashepherd/__init__.py
index 5b87028..bacf502 100644
--- a/llamashepherd/__init__.py
+++ b/llamashepherd/__init__.py
@@ -2,4 +2,5 @@
 
 from .main import main
 from .config import options
+from .model_config import urls
 
diff --git a/llamashepherd/main.py b/llamashepherd/main.py
index 730ad1c..c2b10f8 100644
--- a/llamashepherd/main.py
+++ b/llamashepherd/main.py
@@ -4,8 +4,12 @@
 import os
 import subprocess
 import argparse
+import wget
+
+
 from tabulate import tabulate
 import config
+import model_config
 
 # Access the 'options' dictionary from the config module
 llamas = config.options
@@ -125,10 +129,61 @@ def interactive_action(default_llama_shepherd_path):
             print("Invalid input. Please enter 'y', 'n', or '0'.")
 
 
-def initialize_action():
-    # Add initialization logic here
+def initialize_action(default_llama_shepherd_path):
+    """Interactively download the tokenizer and/or the TinyLlama stories models."""
     print("Initializing models...")
 
+    while True:
+        user_input = input(
+            "Do you want to download and config Tokenizer and/or TinyLLama models? (y/n, 0 to exit): ").lower()
+
+        if user_input == 'y':
+            # Each model group is opt-in; any answer other than 'y' skips it
+            download_tokenizer = input("Do you want to download the Tokenizer model? (y/n): ").lower()
+            if download_tokenizer == 'y':
+                # Stored as models/Tokenizer.bin under default_llama_shepherd_path
+                download_and_configure_model("Tokenizer", model_config.urls["tokenizer"], default_llama_shepherd_path)
+
+            # Ask whether to download stories models
+            download_stories = input("Do you want to download the stories models? (y/n): ").lower()
+            if download_stories == 'y':
+                # All three checkpoint sizes are fetched together
+                download_and_configure_model("Stories15M", model_config.urls["stories15M"], default_llama_shepherd_path)
+                download_and_configure_model("Stories42M", model_config.urls["stories42M"], default_llama_shepherd_path)
+                download_and_configure_model("Stories110M", model_config.urls["stories110M"], default_llama_shepherd_path)
+
+            break
+        elif user_input in ('n', '0'):  # the prompt advertises '0' as an exit key
+            sys.exit()
+        else:
+            print("Invalid input. Please enter 'y', 'n', or '0'.")
+
+
+def download_and_configure_model(model_name, model_url, destination_directory):
+    """Download model_url to {destination_directory}/models/{model_name}.bin.
+
+    Download failures are printed rather than raised, so a caller can go on
+    to the next model.
+    """
+    print(f"Downloading and configuring {model_name} model from: {model_url}")
+
+    # Ensure the models directory exists
+    models_directory = os.path.join(destination_directory, "models")
+    os.makedirs(models_directory, exist_ok=True)
+
+    # Specify the destination file path
+    destination_path = os.path.join(models_directory, f"{model_name}.bin")
+
+    try:
+        # Download the model using wget
+        wget.download(model_url, out=destination_path)
+        print(f"\n{model_name} model downloaded successfully to {destination_path}")
+
+        # TODO: no configuration step exists yet; only the download is performed
+
+    except Exception as e:
+        print(f"Error downloading {model_name} model: {e}")
+
 
 def main():
     home_directory = os.path.expanduser("~")
@@ -160,7 +215,7 @@ def main():
     elif args.action == "interactive":
         interactive_action(default_llama_shepherd_path)
     elif args.action == "initialize":
-        initialize_action()
+        initialize_action(default_llama_shepherd_path)
     elif args.action == "--help":
         parser.print_help()
 
diff --git a/llamashepherd/model_config.py b/llamashepherd/model_config.py
index e69de29..470fecb 100644
--- a/llamashepherd/model_config.py
+++ b/llamashepherd/model_config.py
@@ -0,0 +1,8 @@
+# Direct-download URLs: "resolve"/"raw" paths serve the file; "blob" paths serve an HTML page
+
+urls = {
+    "stories15M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin",
+    "stories42M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin",
+    "stories110M": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
+    "tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin"
+}
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f994d77..a73a57f 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
     packages=find_packages(),
     entry_points={
         'console_scripts': [
-            'llamashepherd = main:main',
+            'llamashepherd = llamashepherd.main:main',
         ],
     },
 )