diff --git a/chromadb/test/ef/test_default_ef.py b/chromadb/test/ef/test_default_ef.py index c471a3237f6..92eec4c0e72 100644 --- a/chromadb/test/ef/test_default_ef.py +++ b/chromadb/test/ef/test_default_ef.py @@ -67,7 +67,7 @@ def test_provider_repeating(providers: List[str]) -> None: def test_invalid_sha256() -> None: ef = ONNXMiniLM_L6_V2() - shutil.rmtree(ef.DOWNLOAD_PATH) # clean up any existing models + shutil.rmtree(ef.download_path) # clean up any existing models with pytest.raises(ValueError) as e: ef._MODEL_SHA256 = "invalid" ef(["test"]) @@ -76,15 +76,15 @@ def test_invalid_sha256() -> None: def test_partial_download() -> None: ef = ONNXMiniLM_L6_V2() - shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True) # clean up any existing models - os.makedirs(ef.DOWNLOAD_PATH, exist_ok=True) - path = os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME) + shutil.rmtree(ef.download_path, ignore_errors=True) # clean up any existing models + os.makedirs(ef.download_path, exist_ok=True) + path = os.path.join(ef.download_path, ef.ARCHIVE_FILENAME) with open(path, "wb") as f: # create invalid file to simulate partial download f.write(b"invalid") ef._download_model_if_not_exists() # re-download model assert os.path.exists(path) assert _verify_sha256( - str(os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)), + str(os.path.join(ef.download_path, ef.ARCHIVE_FILENAME)), ef._MODEL_SHA256, ) assert len(ef(["test"])) == 1 diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py index 4084121cd6e..fb6e162a1a0 100644 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -36,7 +36,9 @@ def _verify_sha256(fname: str, expected_sha256: str) -> bool: # and verify the ONNX model. class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): MODEL_NAME = "all-MiniLM-L6-v2" - DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME + _DEFAULT_DOWNLOAD_PATH = ( + Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME + ) EXTRACTED_FOLDER_NAME = "onnx" ARCHIVE_FILENAME = "onnx.tar.gz" MODEL_DOWNLOAD_URL = ( @@ -44,13 +46,19 @@ class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): ) _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" - def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: + def __init__( + self, + preferred_providers: Optional[List[str]] = None, + download_path: Optional[Path] = None, + ) -> None: """ Initialize the ONNXMiniLM_L6_V2 embedding function. Args: preferred_providers (List[str], optional): The preferred ONNX runtime providers. Defaults to None. + download_path (Path, optional): The path to download the model to. + Defaults to None. """ # convert the list to set for unique values if preferred_providers and not all( @@ -64,6 +72,7 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: raise ValueError("Preferred providers must be unique") self._preferred_providers = preferred_providers + self.download_path = download_path or self._DEFAULT_DOWNLOAD_PATH try: # Equivalent to import onnxruntime @@ -205,7 +214,7 @@ def tokenizer(self) -> Any: """ tokenizer = self.Tokenizer.from_file( os.path.join( - self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" + self.download_path, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" ) ) # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128 @@ -249,7 +258,7 @@ def model(self) -> Any: self._preferred_providers.remove("CoreMLExecutionProvider") return self.ort.InferenceSession( - os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"), + os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME, "model.onnx"), # Since 1.9 onnyx runtime requires providers to be specified when there are multiple available providers=self._preferred_providers, sess_options=so, @@ -290,7 +299,7 @@ def _download_model_if_not_exists(self) -> None: "tokenizer.json", "vocab.txt", ] - extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME) + extracted_folder = os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME) onnx_files_exist = True for f in onnx_files: if not os.path.exists(os.path.join(extracted_folder, f)): @@ -298,28 +307,28 @@ def _download_model_if_not_exists(self) -> None: break # Model is not downloaded yet if not onnx_files_exist: - os.makedirs(self.DOWNLOAD_PATH, exist_ok=True) + os.makedirs(self.download_path, exist_ok=True) if not os.path.exists( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME) + os.path.join(self.download_path, self.ARCHIVE_FILENAME) ) or not _verify_sha256( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + os.path.join(self.download_path, self.ARCHIVE_FILENAME), self._MODEL_SHA256, ): self._download( url=self.MODEL_DOWNLOAD_URL, - fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + fname=os.path.join(self.download_path, self.ARCHIVE_FILENAME), ) with tarfile.open( - name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + name=os.path.join(self.download_path, self.ARCHIVE_FILENAME), mode="r:gz", ) as tar: if sys.version_info >= (3, 12): - tar.extractall(path=self.DOWNLOAD_PATH, filter="data") + tar.extractall(path=self.download_path, filter="data") else: # filter argument was added in Python 3.12 # https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall # In versions prior to 3.12, this provides the same behavior as filter="data" - tar.extractall(path=self.DOWNLOAD_PATH) + tar.extractall(path=self.download_path) @staticmethod def name() -> str: