Skip to content

[CHORE] Add path support for onnx ef #4909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions chromadb/test/ef/test_default_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_provider_repeating(providers: List[str]) -> None:

def test_invalid_sha256() -> None:
ef = ONNXMiniLM_L6_V2()
shutil.rmtree(ef.DOWNLOAD_PATH) # clean up any existing models
shutil.rmtree(ef.download_path) # clean up any existing models
with pytest.raises(ValueError) as e:
ef._MODEL_SHA256 = "invalid"
ef(["test"])
Expand All @@ -76,15 +76,15 @@ def test_invalid_sha256() -> None:

def test_partial_download() -> None:
ef = ONNXMiniLM_L6_V2()
shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True) # clean up any existing models
os.makedirs(ef.DOWNLOAD_PATH, exist_ok=True)
path = os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)
shutil.rmtree(ef.download_path, ignore_errors=True) # clean up any existing models
os.makedirs(ef.download_path, exist_ok=True)
path = os.path.join(ef.download_path, ef.ARCHIVE_FILENAME)
with open(path, "wb") as f: # create invalid file to simulate partial download
f.write(b"invalid")
ef._download_model_if_not_exists() # re-download model
assert os.path.exists(path)
assert _verify_sha256(
str(os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)),
str(os.path.join(ef.download_path, ef.ARCHIVE_FILENAME)),
ef._MODEL_SHA256,
)
assert len(ef(["test"])) == 1
33 changes: 21 additions & 12 deletions chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,29 @@ def _verify_sha256(fname: str, expected_sha256: str) -> bool:
# and verify the ONNX model.
class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
MODEL_NAME = "all-MiniLM-L6-v2"
DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME
_DEFAULT_DOWNLOAD_PATH = (
Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME
)
EXTRACTED_FOLDER_NAME = "onnx"
ARCHIVE_FILENAME = "onnx.tar.gz"
MODEL_DOWNLOAD_URL = (
"https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
)
_MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"

def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
def __init__(
self,
preferred_providers: Optional[List[str]] = None,
download_path: Optional[Path] = None,
) -> None:
"""
Initialize the ONNXMiniLM_L6_V2 embedding function.

Args:
preferred_providers (List[str], optional): The preferred ONNX runtime providers.
Defaults to None.
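download_path (Path, optional): The directory the model archive is downloaded to and extracted in.
Defaults to None, in which case ~/.cache/chroma/onnx_models/all-MiniLM-L6-v2 is used.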
"""
# convert the list to set for unique values
if preferred_providers and not all(
Expand All @@ -64,6 +72,7 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
raise ValueError("Preferred providers must be unique")

self._preferred_providers = preferred_providers
self.download_path = download_path or self._DEFAULT_DOWNLOAD_PATH

try:
# Equivalent to import onnxruntime
Expand Down Expand Up @@ -205,7 +214,7 @@ def tokenizer(self) -> Any:
"""
tokenizer = self.Tokenizer.from_file(
os.path.join(
self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
self.download_path, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
)
)
# max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
Expand Down Expand Up @@ -249,7 +258,7 @@ def model(self) -> Any:
self._preferred_providers.remove("CoreMLExecutionProvider")

return self.ort.InferenceSession(
os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
# Since 1.9 onnx runtime requires providers to be specified when there are multiple available
providers=self._preferred_providers,
sess_options=so,
Expand Down Expand Up @@ -290,36 +299,36 @@ def _download_model_if_not_exists(self) -> None:
"tokenizer.json",
"vocab.txt",
]
extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
extracted_folder = os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME)
onnx_files_exist = True
for f in onnx_files:
if not os.path.exists(os.path.join(extracted_folder, f)):
onnx_files_exist = False
break
# Model is not downloaded yet
if not onnx_files_exist:
os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
os.makedirs(self.download_path, exist_ok=True)
if not os.path.exists(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
os.path.join(self.download_path, self.ARCHIVE_FILENAME)
) or not _verify_sha256(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
os.path.join(self.download_path, self.ARCHIVE_FILENAME),
self._MODEL_SHA256,
):
self._download(
url=self.MODEL_DOWNLOAD_URL,
fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
fname=os.path.join(self.download_path, self.ARCHIVE_FILENAME),
)
with tarfile.open(
name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
name=os.path.join(self.download_path, self.ARCHIVE_FILENAME),
mode="r:gz",
) as tar:
if sys.version_info >= (3, 12):
tar.extractall(path=self.DOWNLOAD_PATH, filter="data")
tar.extractall(path=self.download_path, filter="data")
else:
# filter argument was added in Python 3.12
# https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall
# In versions without the filter argument, extractall applies no extraction filter (this is NOT equivalent to filter="data"), so the archive must be trusted
tar.extractall(path=self.DOWNLOAD_PATH)
tar.extractall(path=self.download_path)

@staticmethod
def name() -> str:
Expand Down
Loading