Skip to content

Expose dimensions in the Embedders #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
10 changes: 8 additions & 2 deletions examples/customize/embeddings/cohere_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
api_key = None

embeder = CohereEmbeddings(
model="embed-english-v3.0",
model="embed-v4.0",
api_key=api_key,
)
res = embeder.embed_query("my question")
res = embeder.embed_query(
"my question",
# optionally, set output dimensions if it's supported by the model
dimensions=256,
input_type="search_query",
)
print("Embedding dimensions", len(res))
print(res[:10])
9 changes: 7 additions & 2 deletions examples/customize/embeddings/custom_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ class CustomEmbeddings(Embedder):
def __init__(self, dimension: int = 10, **kwargs: Any):
self.dimension = dimension

def embed_query(self, input: str) -> list[float]:
return [random.random() for _ in range(self.dimension)]
def embed_query(
self, input: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
v = [random.random() for _ in range(self.dimension)]
if dimensions:
return v[:dimensions]
return v


llm = CustomEmbeddings(dimensions=1024)
Expand Down
7 changes: 6 additions & 1 deletion examples/customize/embeddings/mistalai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@
api_key = None

embeder = MistralAIEmbeddings(model="mistral-embed", api_key=api_key)
res = embeder.embed_query("my question")
res = embeder.embed_query(
"my question",
# optionally, set output dimensions
dimensions=256,
)
print("Embedding dimensions", len(res))
print(res[:10])
10 changes: 8 additions & 2 deletions examples/customize/embeddings/openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
# set api key here on in the OPENAI_API_KEY env var
api_key = None

embeder = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=api_key)
res = embeder.embed_query("my question")
embeder = OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key)
res = embeder.embed_query(
"my question",
# optionally, set output dimensions
# dimensions=256,
)

print("Embedding dimensions", len(res))
print(res[:10])
6 changes: 5 additions & 1 deletion examples/customize/embeddings/vertexai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@
from neo4j_graphrag.embeddings import VertexAIEmbeddings

embeder = VertexAIEmbeddings(model="text-embedding-005")
res = embeder.embed_query("my question")
res = embeder.embed_query(
"my question",
dimensions=256,
)
print("Embedding dimensions", len(res))
print(res[:10])
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import Any
from random import random

from neo4j import GraphDatabase
Expand All @@ -12,16 +13,19 @@

INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
DIMENSION = 1536
EMBEDDING_DIMENSIONS = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
dimensions = dimensions or EMBEDDING_DIMENSIONS
return [random() for _ in range(dimensions)]


embedder = CustomEmbedder()
Expand All @@ -32,7 +36,7 @@ def embed_query(self, text: str) -> list[float]:
INDEX_NAME,
label="Document",
embedding_property="vectorProperty",
dimensions=DIMENSION,
dimensions=EMBEDDING_DIMENSIONS,
similarity_fn="euclidean",
)
create_fulltext_index(
Expand All @@ -43,7 +47,7 @@ def embed_query(self, text: str) -> list[float]:
retriever = HybridRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder)

# Upsert the query
vector = [random() for _ in range(DIMENSION)]
vector = [random() for _ in range(EMBEDDING_DIMENSIONS)]
insert_query = (
"MERGE (n:Document {id: $id})"
"WITH n "
Expand Down
6 changes: 5 additions & 1 deletion src/neo4j_graphrag/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any

from abc import ABC, abstractmethod

Expand All @@ -24,11 +25,14 @@ class Embedder(ABC):
"""

@abstractmethod
def embed_query(self, text: str) -> list[float]:
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
"""Embed query text.

Args:
text (str): Text to convert to vector embedding
dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.

Returns:
list[float]: A vector embedding.
Expand Down
17 changes: 14 additions & 3 deletions src/neo4j_graphrag/embeddings/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,23 @@ def __init__(self, model: str = "", **kwargs: Any) -> None:
Please install it with `pip install "neo4j-graphrag[cohere]"`."""
)
self.model = model
self.client = cohere.Client(**kwargs)
self.client = cohere.ClientV2(**kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
"""
Generate embeddings for a given query using a Cohere text embedding model.

Args:
text (str): The text to generate an embedding for.
dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
**kwargs (Any): Additional keyword arguments to pass to the Cohere ClientV2.embed method.
"""
response = self.client.embed(
texts=[text],
model=self.model,
output_dimension=dimensions,
**kwargs,
)
return response.embeddings[0] # type: ignore
return response.embeddings.float[0] # type: ignore
9 changes: 6 additions & 3 deletions src/neo4j_graphrag/embeddings/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,19 @@ def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None:
self.model = model
self.mistral_client = Mistral(api_key=api_key, **kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
"""
Generate embeddings for a given query using a Mistral AI text embedding model.

Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
**kwargs (Any): Additional keyword arguments to pass to the embeddings.create method.
"""
embeddings_batch_response = self.mistral_client.embeddings.create(
model=self.model, inputs=[text], **kwargs
model=self.model, inputs=[text], output_dimension=dimensions, **kwargs
)
if embeddings_batch_response is None or not embeddings_batch_response.data:
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/embeddings/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, model: str, **kwargs: Any) -> None:
self.model = model
self.client = ollama.Client(**kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
def embed_query(self, text: str, **kwargs: Any) -> list[float]: # type: ignore[override]
"""
Generate embeddings for a given query using an Ollama text embedding model.

Expand Down
13 changes: 11 additions & 2 deletions src/neo4j_graphrag/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import abc
from typing import TYPE_CHECKING, Any

from openai import NotGiven

from neo4j_graphrag.embeddings.base import Embedder

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,15 +53,22 @@ def _initialize_client(self, **kwargs: Any) -> Any:
"""
pass

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
"""
Generate embeddings for a given query using an OpenAI text embedding model.

Args:
text (str): The text to generate an embedding for.
dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.

**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
"""
response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
d = dimensions or NotGiven()
response = self.client.embeddings.create(
input=text, model=self.model, dimensions=d, **kwargs
)
embedding: list[float] = response.data[0].embedding
return embedding

Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/embeddings/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self.np = np
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

def embed_query(self, text: str) -> Any:
def embed_query(self, text: str, **kwargs: Any) -> Any: # type: ignore[override]
result = self.model.encode([text])
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
return result.flatten().tolist()
Expand Down
21 changes: 12 additions & 9 deletions src/neo4j_graphrag/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, TYPE_CHECKING
from typing import Any

from neo4j_graphrag.embeddings.base import Embedder


try:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
except (ImportError, AttributeError):
TextEmbeddingModel = TextEmbeddingInput = None # type: ignore[misc, assignment]


if TYPE_CHECKING:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


class VertexAIEmbeddings(Embedder):
"""
Vertex AI embeddings class.
Expand All @@ -45,18 +42,24 @@ def __init__(self, model: str = "text-embedding-004") -> None:
)
self.model = TextEmbeddingModel.from_pretrained(model)

def embed_query(
self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any
def embed_query( # type: ignore[override]
self,
text: str,
task_type: str = "RETRIEVAL_QUERY",
dimensions: int | None = None,
**kwargs: Any,
) -> list[float]:
"""
Generate embeddings for a given query using a Vertex AI text embedding model.

Args:
text (str): The text to generate an embedding for.
dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
task_type (str): The type of the text embedding task. Defaults to "RETRIEVAL_QUERY". See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#tasktype for a full list.
**kwargs (Any): Additional keyword arguments to pass to the Vertex AI client's get_embeddings method.
"""
# type annotation needed for mypy
inputs: list[str | TextEmbeddingInput] = [TextEmbeddingInput(text, task_type)]
embeddings = self.model.get_embeddings(inputs, **kwargs)
embeddings = self.model.get_embeddings(
inputs, output_dimensionality=dimensions, **kwargs
)
return embeddings[0].values
7 changes: 5 additions & 2 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ def embedder() -> Embedder:


class RandomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random.random() for _ in range(1536)]
def embed_query(
self, text: str, dimensions: int | None = None, **kwargs: Any
) -> list[float]:
d = dimensions or 1536
return [random.random() for _ in range(d)]


class BiologyEmbedder(Embedder):
Expand Down
Loading