diff --git a/.gitignore b/.gitignore
index 826fa7e54..c928174cb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,5 +8,10 @@ __pycache__
 # Ignore everything in the generated directory
 /generated/*
 
+# Ignore the locally persisted chromadb
+.chroma_db/
+**/chroma-embeddings.parquet
+
 # Don't ignore .gitkeep files in the generated directory
 !/generated/.gitkeep
+
diff --git a/constants.py b/constants.py
index b7ccc11b4..9eaa2b8f8 100644
--- a/constants.py
+++ b/constants.py
@@ -1,4 +1,9 @@
+import os
+
+
 EXTENSION_TO_SKIP = [".png",".jpg",".jpeg",".gif",".bmp",".svg",".ico",".tif",".tiff"]
 DEFAULT_DIR = "generated"
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+DEFAULT_DIR_FULL_PATH = os.path.join(ROOT_DIR, DEFAULT_DIR)
 DEFAULT_MODEL = "gpt-3.5-turbo" # we recommend 'gpt-4' if you have it # gpt3.5 is going to be worse at generating code so we strongly recommend gpt4. i know most people dont have access, we are working on a hosted version
 DEFAULT_MAX_TOKENS = 2000 # i wonder how to tweak this properly. we dont want it to be max length as it encourages verbosity of code. but too short and code also truncates suddenly.
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/embeddings.py b/src/embeddings.py
new file mode 100644
index 000000000..440774ac1
--- /dev/null
+++ b/src/embeddings.py
@@ -0,0 +1,99 @@
+"""Persist embeddings of the generated files in a local Chroma database."""
+import os
+
+from termcolor import colored
+import chromadb
+
+from constants import DEFAULT_DIR, EXTENSION_TO_SKIP, ROOT_DIR
+from src.traversal import traverse_dir
+
+
+class Embeddings:
+    """Wrapper around a locally persisted Chroma collection of generated files.
+
+    The database is stored in ROOT_DIR/.chroma_db; each generated file becomes
+    one document whose id is the path yielded by traverse_dir.
+    """
+
+    # Image files are not text, so they are never embedded.
+    EXCLUDE_PATTERNS = ["*" + extension for extension in EXTENSION_TO_SKIP]
+    GENERATED_FILES_COLLECTION_NAME = "generated_files"
+
+    def __init__(self, debug=False):
+        """Open the database and make sure the collection exists.
+
+        Args:
+            debug: When true, print progress information while embedding.
+        """
+        self.debug = debug
+
+        db_dir = os.path.join(ROOT_DIR, ".chroma_db")
+        # exist_ok avoids the check-then-create race of os.path.exists().
+        os.makedirs(db_dir, exist_ok=True)
+
+        self.client = chromadb.Client(
+            chromadb.config.Settings(
+                chroma_db_impl="duckdb+parquet",
+                persist_directory=db_dir,
+            )
+        )
+
+        self.ensure_generated_files_collection_exists()
+
+    def ensure_generated_files_collection_exists(self):
+        """Bind self.generated_files_collection, creating it if missing."""
+        self.generated_files_collection = self.client.get_or_create_collection(
+            self.GENERATED_FILES_COLLECTION_NAME
+        )
+
+    def persist_generated_file_contents(self, reset=False):
+        """Embed every text file in the generated directory and store it.
+
+        Args:
+            reset: When true, empty the collection first so documents of
+                files that no longer exist are dropped.
+        """
+        if reset:
+            # reset() is a client-level call that wipes the whole database;
+            # deleting and re-creating the collection clears only this one.
+            self.client.delete_collection(self.GENERATED_FILES_COLLECTION_NAME)
+            self.ensure_generated_files_collection_exists()
+
+        file_paths_list = []
+        file_contents_list = []
+        metadatas_list = []
+        for file_path in traverse_dir(DEFAULT_DIR, exclude_patterns=self.EXCLUDE_PATTERNS):
+            if self.debug:
+                print("embedding: " + colored(file_path, 'green'))
+            try:
+                with open(file_path, "r", encoding="utf-8") as file:
+                    file_contents = file.read()
+            except UnicodeDecodeError:
+                # Not a UTF-8 text file, so there is no text to embed.
+                if self.debug:
+                    print("skipping non-text file: " + colored(file_path, 'red'))
+                continue
+
+            filename = os.path.basename(file_path)
+            # splitext gives "" for names without an extension, whereas
+            # split(".")[-1] would report the whole name (e.g. "Makefile").
+            extension = os.path.splitext(filename)[1].lstrip(".")
+            # Append to all three lists together to keep them aligned.
+            file_paths_list.append(file_path)
+            file_contents_list.append(file_contents)
+            metadatas_list.append({
+                "filename": filename,
+                "extension": extension,
+            })
+
+        # With no files there is no batch to send to the database.
+        if file_paths_list:
+            self.generated_files_collection.upsert(
+                documents=file_contents_list,
+                metadatas=metadatas_list,
+                ids=file_paths_list,
+            )
+            # Write to disk now instead of relying on an implicit flush.
+            self.client.persist()
+        if self.debug:
+            print(colored("persisted embeddings for %s files." % len(file_paths_list), 'yellow'))
diff --git a/src/traversal.py b/src/traversal.py
new file mode 100644
index 000000000..67dfd9e82
--- /dev/null
+++ b/src/traversal.py
@@ -0,0 +1,55 @@
+"""Directory traversal with glob-style include and exclude filters."""
+import fnmatch
+import os
+
+
+def _matches_any(path, root_dir, patterns):
+    """Return True if *path* matches at least one of the glob *patterns*.
+
+    The path is tested as given and, for entries below *root_dir*, also
+    relative to *root_dir*, so a pattern such as ``exclude_dir1/*`` works
+    no matter where the traversal started.
+    """
+    candidates = [path]
+    relative_path = os.path.relpath(path, root_dir)
+    # The root itself is "." when made relative; skip it so that patterns
+    # aimed at hidden entries (".*") do not exclude the whole tree.
+    if relative_path != os.curdir:
+        candidates.append(relative_path)
+    return any(
+        fnmatch.fnmatch(candidate, pattern)
+        for pattern in patterns
+        for candidate in candidates
+    )
+
+
+def traverse_dir(root_dir, include_patterns=None, exclude_patterns=None):
+    """Yield the path of every file below *root_dir* that passes the filters.
+
+    Patterns use ``fnmatch`` syntax, in which ``*`` also matches path
+    separators. A file is skipped when it, or the directory containing it,
+    matches an exclude pattern; excluded directories are not descended
+    into. When include patterns are given, only files matching one of
+    them are yielded.
+
+    Example usage:
+        root_dir = '/path/to/directory'
+        include_patterns = ['*.txt', '*.py']  # Include files matching these patterns
+        exclude_patterns = ['exclude_dir1/*', '*.png']  # Exclude matching files and directories
+
+        for file_path in traverse_dir(root_dir, include_patterns=include_patterns, exclude_patterns=exclude_patterns):
+            print(file_path)
+    """
+    for root, dirs, files in os.walk(root_dir):
+        if exclude_patterns and _matches_any(root, root_dir, exclude_patterns):
+            # Prune in place so os.walk does not descend into the subtree.
+            dirs[:] = []
+            continue
+        for file in files:
+            file_path = os.path.join(root, file)
+            # Exclusions apply to files too, not only to directories.
+            if exclude_patterns and _matches_any(file_path, root_dir, exclude_patterns):
+                continue
+            if include_patterns and not _matches_any(file_path, root_dir, include_patterns):
+                continue
+            yield file_path