From c21d4c484d051123bb3d493a3898679d7df4a627 Mon Sep 17 00:00:00 2001
From: Roberto Rodriguez <9653181+Cyb3rWard0g@users.noreply.github.com>
Date: Thu, 23 Jan 2025 01:32:27 -0500
Subject: [PATCH] added persistence for sentencetransformer models

---
 ...entencetransformers_all-MiniLM-L6-v2.ipynb | 193 ++++++++++++++----
 src/floki/document/embedder/sentence.py       |  27 ++-
 2 files changed, 182 insertions(+), 38 deletions(-)

diff --git a/cookbook/vectorstores/chroma_sentencetransformers_all-MiniLM-L6-v2.ipynb b/cookbook/vectorstores/chroma_sentencetransformers_all-MiniLM-L6-v2.ipynb
index e31bc5d..c44e0c4 100644
--- a/cookbook/vectorstores/chroma_sentencetransformers_all-MiniLM-L6-v2.ipynb
+++ b/cookbook/vectorstores/chroma_sentencetransformers_all-MiniLM-L6-v2.ipynb
@@ -34,6 +34,24 @@
     "!pip install floki-ai chromadb"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Enable Logging"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "\n",
+    "logging.basicConfig(level=logging.INFO)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -45,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -54,7 +72,7 @@
        "True"
       ]
      },
-     "execution_count": 1,
+     "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -75,14 +93,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:datasets:PyTorch version 2.5.1 available.\n",
+      "INFO:floki.document.embedder.sentence:Loading SentenceTransformer model from local path: model\n",
+      "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: model\n",
+      "INFO:floki.document.embedder.sentence:Model loaded successfully.\n"
+     ]
+    }
+   ],
    "source": [
     "from floki.document.embedder import SentenceTransformerEmbedder\n",
     "\n",
     "embedding_function = SentenceTransformerEmbedder(\n",
-    "    model=\"all-MiniLM-L6-v2\"\n",
+    "    model=\"all-MiniLM-L6-v2\",\n",
+    "    cache_dir=\"model\"\n",
     ")"
    ]
   },
@@ -97,9 +127,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:floki.storage.vectorstores.chroma:ChromaVectorStore initialized with collection: example_collection\n"
+     ]
+    }
+   ],
    "source": [
     "from floki.storage import ChromaVectorStore\n",
     "\n",
@@ -130,7 +168,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -190,9 +228,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:floki.document.embedder.sentence:Generating embeddings for 10 input(s).\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f2de0ae5fbe84c838b47b2cf393ca7ef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Batches: 0%| | 0/1 [00:00]}"
      ]
     },
-     "execution_count": 13,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -418,16 +541,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "['example_collection']"
+       "[Collection(name=example_collection)]"
       ]
      },
-     "execution_count": 14,
"execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -438,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -448,7 +571,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -457,7 +580,7 @@ "[]" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } diff --git a/src/floki/document/embedder/sentence.py b/src/floki/document/embedder/sentence.py index 10b48b9..3a0fa06 100644 --- a/src/floki/document/embedder/sentence.py +++ b/src/floki/document/embedder/sentence.py @@ -2,6 +2,7 @@ from typing import List, Any, Optional, Union, Literal from pydantic import Field import logging +import os logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ class SentenceTransformerEmbedder(EmbedderBase): device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.") normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.") multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.") + cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.") client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.") @@ -32,9 +34,28 @@ def model_post_init(self, __context: Any) -> None: "Install it using `pip install sentence-transformers`." ) - logger.info(f"Loading SentenceTransformer model: {self.model}") - self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=self.model, device=self.device) - logger.info("Model loaded successfully.") + # Determine whether to load from cache or download + model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model + + # Attempt to load the model + try: + if os.path.exists(model_path): + logger.info(f"Loading SentenceTransformer model from local path: {model_path}") + else: + logger.info(f"Downloading SentenceTransformer model: {self.model}") + if self.cache_dir: + logger.info(f"Model will be cached to: {self.cache_dir}") + + self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device) + logger.info("Model loaded successfully.") + except Exception as e: + logger.error(f"Failed to load SentenceTransformer model: {e}") + raise + + # Save to cache directory if downloaded + if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir): + logger.info(f"Saving the downloaded model to: {self.cache_dir}") + self.client.save(self.cache_dir) def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]: """