diff --git a/ingest.py b/ingest.py
index 5819e0f8..6508a445 100644
--- a/ingest.py
+++ b/ingest.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import sys
+from utils import default_device_type
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 
 import click
@@ -38,7 +40,7 @@ def load_single_document(file_path: str) -> Document:
         return loader.load()[0]
     except Exception as ex:
         file_log('%s loading error: \n%s' % (file_path, ex))
-        return None 
+        return None
 
 def load_document_batch(filepaths):
     logging.info("Loading document batch")
@@ -93,7 +95,7 @@ def load_documents(source_dir: str) -> list[Document]:
                 docs.extend(contents)
             except Exception as ex:
                 file_log('Exception: %s' % (ex))
-            
+
     return docs
 
 
@@ -109,11 +111,10 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
             text_docs.append(doc)
     return text_docs, python_docs
 
-
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda" if torch.cuda.is_available() else "cpu",
+    default=default_device_type(),
     type=click.Choice(
         [
             "cpu",
@@ -137,7 +138,7 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
             "mtia",
         ],
     ),
-    help="Device to run on. (Default is cuda)",
+    help=f"Device to run on. (Default is {default_device_type()})",
 )
 def main(device_type):
     # Load documents and split in chunks
@@ -171,7 +172,7 @@ def main(device_type):
         persist_directory=PERSIST_DIRECTORY,
         client_settings=CHROMA_SETTINGS,
     )
-    
+
 
 
 if __name__ == "__main__":
diff --git a/run_localGPT.py b/run_localGPT.py
index c9f0a2fa..aee6fdce 100644
--- a/run_localGPT.py
+++ b/run_localGPT.py
@@ -3,6 +3,7 @@
 import click
 import torch
 import utils
+from utils import default_device_type
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
@@ -161,11 +162,11 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     return qa
 
 
-# chose device typ to run on as well as to show source documents.
+# choose device type to run on as well as to show source documents.
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda" if torch.cuda.is_available() else "cpu",
+    default=default_device_type(),
     type=click.Choice(
         [
             "cpu",
@@ -189,7 +190,7 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
             "mtia",
         ],
     ),
-    help="Device to run on. (Default is cuda)",
+    help=f"Device to run on. (Default is {default_device_type()})",
 )
 @click.option(
     "--show_sources",
@@ -269,7 +270,7 @@ def main(device_type, show_sources, use_history, model_type, save_qa):
                 print("\n> " + document.metadata["source"] + ":")
                 print(document.page_content)
             print("----------------------------------SOURCE DOCUMENTS---------------------------")
-        
+
         # Log the Q&A to CSV only if save_qa is True
         if save_qa:
             utils.log_to_csv(query, answer)
diff --git a/utils.py b/utils.py
index 0440d214..b5c780ff 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,7 @@
 import os
 import csv
+import sys
+import torch
 from datetime import datetime
 
 def log_to_csv(question, answer):
@@ -22,4 +24,10 @@ def log_to_csv(question, answer):
     with open(log_path, mode='a', newline='', encoding='utf-8') as file:
         writer = csv.writer(file)
         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        writer.writerow([timestamp, question, answer])
\ No newline at end of file
+        writer.writerow([timestamp, question, answer])
+
+def has_mps():
+    return sys.platform == "darwin" and torch.backends.mps.is_available()
+
+def default_device_type():
+    return "cuda" if torch.cuda.is_available() else "mps" if has_mps() else "cpu"
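Because the patched click options call default_device_type() at decoration time, the default shown in --help always reflects the machine the command actually runs on. The fallback order itself (cuda, then mps, then cpu) can be verified without real hardware by stubbing the torch probes; below is a minimal, hypothetical pytest sketch (the file name and test names are illustrative and not part of this patch):

# test_device_default.py - illustrative sketch, not part of this patch.
# Verifies the cuda -> mps -> cpu fallback in utils.default_device_type()
# by stubbing the hardware probes so the tests pass on any machine.
from unittest import mock

import utils


def test_prefers_cuda_when_available():
    with mock.patch("torch.cuda.is_available", return_value=True):
        assert utils.default_device_type() == "cuda"


def test_falls_back_to_mps_without_cuda():
    with mock.patch("torch.cuda.is_available", return_value=False), \
         mock.patch.object(utils, "has_mps", return_value=True):
        assert utils.default_device_type() == "mps"


def test_falls_back_to_cpu_without_any_gpu():
    with mock.patch("torch.cuda.is_available", return_value=False), \
         mock.patch.object(utils, "has_mps", return_value=False):
        assert utils.default_device_type() == "cpu"

Patching utils.has_mps (rather than torch.backends.mps.is_available) keeps the second test independent of sys.platform, so it also passes on Linux and Windows.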