Refactor code and improve interactive interface
The code has been substantially refactored to improve function design, with a particular focus on user interaction. Changes include simplifying option selection, making the choice display easier to manage, adding a new function that retrieves the language for a selected option, and reordering functions for a clearer logical flow. The argparse setup was also updated with new command-line choices.
mikepapadim committed Jan 4, 2024
1 parent 71b896f commit aa13097
Showing 1 changed file with 101 additions and 56 deletions.
157 changes: 101 additions & 56 deletions llamashepherd/main.py
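Before the diff, here is a minimal, self-contained sketch (not the committed file) of the selection flow the commit message describes: options are flattened across language categories, and the new helper maps a chosen option back to its language. The get_language_for_option body and the flattening expression mirror the diff below; the llamas sample data here is hypothetical.

# Minimal sketch of the refactored selection flow (illustrative only, not the committed file).
# The structure of `llamas` mirrors llamashepherd/main.py; the two entries below are
# hypothetical sample data.
llamas = {
    "python": [{"name": "llama2.py", "author": "alice", "url": "https://example.org/a"}],
    "rust": [{"name": "llama2.rs", "author": "bob", "url": "https://example.org/b"}],
}


def get_language_for_option(options_dict, selected_option):
    # Walk every language category and return the one that contains the chosen option.
    for language, options_list in options_dict.items():
        for option in options_list:
            if option == selected_option:
                return language
    return None  # The option was not found in any language category


# Flatten all options across categories, as the new choose_option() does,
# then map a selection back to its language with the new helper.
all_options = [option for options in llamas.values() for option in options]
selected = all_options[1]  # e.g. the user typed "2" at the prompt
print(get_language_for_option(llamas, selected))  # -> "rust"
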
@@ -26,7 +26,7 @@ def display_options(category=None, language=None):

category_rows = []
for option in options:
if language and option['name'].lower() != language.lower():
if language and option["name"].lower() != language.lower():
continue

# Check if "author" key exists in the option dictionary
@@ -55,104 +55,149 @@ def print_table(table):
print("\n" + "-" * 121 + "\n")


def choose_implemenation():
def clone_repository(url, destination):
try:
subprocess.run(["git", "clone", url, destination], check=True)
print(f"Repository cloned successfully to {destination}")
except subprocess.CalledProcessError as e:
print(f"Error cloning repository: {e}")


def list_action(language=None):
display_options(language)


def choose_implementation():
display_options()
while True:
try:
choice = int(input("Enter the number of your choice (0 to exit): "))
if 0 <= choice <= len(llamas):
return None if choice == 0 else list(llamas.keys())[choice - 1]
selected_category = (
None if choice == 0 else list(llamas.keys())[choice - 1]
)
return choice, selected_category
else:
print("Invalid choice. Please enter a number between 0 and", len(llamas))
print(
"Invalid choice. Please enter a number between 0 and", len(llamas)
)
except ValueError:
print("Invalid input. Please enter a number.")


def choose_option(category):
category_options = llamas[category]
print(f"\nChoose an option from the {category} category (0 to go back):")
for i, option in enumerate(category_options, start=1):
name = option.get('name', 'N/A')
author = option.get('author', 'N/A')
print(f"{i}. {name} by {author}")
# ...


def choose_option():
all_options = [option for options in llamas.values() for option in options]
total_options = len(all_options)

print("Choose an option (0 to go back):")

while True:
try:
choice = int(input("Enter the number of your choice (0 to go back): "))
if 0 <= choice <= len(category_options):
return None if choice == 0 else category_options[choice - 1]
choice = int(input(f"Enter the number of your choice (0 to go back): "))
if 0 <= choice <= total_options:
return None if choice == 0 else all_options[choice - 1]
else:
print(f"Invalid choice. Please enter a number between 0 and {len(category_options)}")
print(
f"Invalid choice. Please enter a number between 0 and {total_options}"
)
except ValueError:
print("Invalid input. Please enter a number.")


def clone_repository(url, destination):
try:
subprocess.run(["git", "clone", url, destination], check=True)
print(f"Repository cloned successfully to {destination}")
except subprocess.CalledProcessError as e:
print(f"Error cloning repository: {e}")


def list_action(language=None):
display_options(language)
def get_language_for_option(options_dict, selected_option):
for language, options_list in options_dict.items():
for option in options_list:
if option == selected_option:
return language
return (
None # Return None if the selected option is not found in any language category
)


def interactive_action(default_llama_shepherd_path):
display_options()
while True:
category = choose_implemenation()
if category is None:
break # Exit the program if the user chooses to exit

selected_option = choose_option(category)
if selected_option is None:
continue # Go back to category selection if the user chooses to go back
selected_option = choose_option()
selected_category = get_language_for_option(llamas, selected_option)

default_path = os.path.join(default_llama_shepherd_path, category, selected_option['name'])
destination = input(f"Enter the destination directory (default: {default_path}): ").strip() or default_path
default_path = os.path.join(
default_llama_shepherd_path, selected_category, selected_option["name"]
)
destination = (
input(
f"Enter the destination directory (default: {default_path}): "
).strip()
or default_path
)

# Create the destination directory if it doesn't exist
os.makedirs(destination, exist_ok=True)

clone_repository(selected_option['url'], destination)
clone_repository(selected_option["url"], destination)

# Ask whether to initialize or exit
while True:
user_input = input(
"Do you want to download and config Tokenizer and/or TinyLLama models? (y/n, 0 to exit): ").lower()
if user_input == 'y':
initialize_action()
"Do you want to download and config Tokenizer and/or TinyLLama models? (y/n, 0 to exit): "
).lower()
if user_input == "y":
initialize_action(default_llama_shepherd_path)
break
elif user_input == 'n':
elif user_input == "n":
sys.exit()
else:
print("Invalid input. Please enter 'y', 'n', or '0'.")
print("Invalid input. Please enter 'y', 'n', or '0.'")


def initialize_action(default_llama_shepherd_path):
print("Initializing models...")

while True:
user_input = input(
"Do you want to download and config Tokenizer and/or TinyLLama models? (y/n, 0 to exit): ").lower()
"Do you want to download and config Tokenizer and/or TinyLLama models? (y/n, 0 to exit): "
).lower()

if user_input == 'y':
if user_input == "y":
# Ask whether to download Tokenizer
download_tokenizer = input(f"Do you want to download the Tokenizer model? (y/n): ").lower()
if download_tokenizer == 'y':
download_tokenizer = input(
f"Do you want to download the Tokenizer model? (y/n): "
).lower()
if download_tokenizer == "y":
# Add logic to download and configure the Tokenizer model
download_and_configure_model("Tokenizer", model_config.urls["tokenizer"], default_llama_shepherd_path)
download_and_configure_model(
"Tokenizer",
model_config.urls["tokenizer"],
default_llama_shepherd_path,
)

# Ask whether to download stories models
download_stories = input(f"Do you want to download the stories models? (y/n): ").lower()
if download_stories == 'y':
download_stories = input(
f"Do you want to download the stories models? (y/n): "
).lower()
if download_stories == "y":
# Add logic to download and configure stories models
download_and_configure_model("Stories15M", model_config.urls["stories15M"], default_llama_shepherd_path)
download_and_configure_model("Stories42M", model_config.urls["stories42M"], default_llama_shepherd_path)
download_and_configure_model("Stories110M", model_config.urls["stories110M"], default_llama_shepherd_path)
download_and_configure_model(
"Stories15M",
model_config.urls["stories15M"],
default_llama_shepherd_path,
)
download_and_configure_model(
"Stories42M",
model_config.urls["stories42M"],
default_llama_shepherd_path,
)
download_and_configure_model(
"Stories110M",
model_config.urls["stories110M"],
default_llama_shepherd_path,
)

break
elif user_input == 'n':
elif user_input == "n":
sys.exit()
else:
print("Invalid input. Please enter 'y', 'n', or '0'.")
@@ -185,30 +230,30 @@ def main():

parser = argparse.ArgumentParser(
description="Llama Shepherd CLI: Manage your llama-related projects.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter # Show default values in the help menu
formatter_class=argparse.ArgumentDefaultsHelpFormatter, # Show default values in the help menu
)
parser.add_argument(
"action",
nargs="?",
default="--help", # Set default action to "--help"
choices=["list", "interactive", "initialize", "--help"],
help="Action to perform"
choices=["list", "install", "models", "--help"],
help="Action to perform",
)

parser.add_argument(
"language",
nargs="?", # Make language optional
default=None,
help="Specify the language for the 'list' action"
help="Specify the language for the 'list' action",
)

args = parser.parse_args()

if args.action == "list":
list_action(args.language) # Pass the language argument
elif args.action == "interactive":
elif args.action == "install":
interactive_action(default_llama_shepherd_path)
elif args.action == "initialize":
elif args.action == "models":
initialize_action(default_llama_shepherd_path)
elif args.action == "--help":
parser.print_help()
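
For orientation, the following standalone sketch (not the committed main()) mirrors the updated argparse setup from the hunk above and shows how the renamed command-line actions parse; the example argv lists are hypothetical.

import argparse

# Standalone sketch of the updated parser (illustrative only): the former
# "interactive" and "initialize" actions are exposed as "install" and "models".
parser = argparse.ArgumentParser(
    description="Llama Shepherd CLI: Manage your llama-related projects.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "action",
    nargs="?",
    default="--help",
    choices=["list", "install", "models", "--help"],
    help="Action to perform",
)
parser.add_argument(
    "language",
    nargs="?",
    default=None,
    help="Specify the language for the 'list' action",
)

# Hypothetical command lines after this commit and the namespaces they produce:
print(parser.parse_args(["list", "rust"]))  # Namespace(action='list', language='rust')
print(parser.parse_args(["install"]))       # Namespace(action='install', language=None)
print(parser.parse_args(["models"]))        # Namespace(action='models', language=None)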
