Skip to content

Commit 1090784

Browse files
authored
[feat add]FastEmbed embedding for local embeddings (#3552)
1 parent 394203d commit 1090784

File tree

5 files changed

+78
-0
lines changed

5 files changed

+78
-0
lines changed

mem0/embeddings/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def validate_config(cls, v, values):
2424
"lmstudio",
2525
"langchain",
2626
"aws_bedrock",
27+
"fastembed",
2728
]:
2829
return v
2930
else:

mem0/embeddings/fastembed.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional, Literal
2+
3+
from mem0.embeddings.base import EmbeddingBase
4+
from mem0.configs.embeddings.base import BaseEmbedderConfig
5+
6+
try:
7+
from fastembed import TextEmbedding
8+
except ImportError:
9+
raise ImportError("FastEmbed is not installed. Please install it using `pip install fastembed`")
10+
11+
class FastEmbedEmbedding(EmbeddingBase):
    """Embedder backed by a locally running FastEmbed ``TextEmbedding`` model."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        Initialise the embedder and load the FastEmbed model.

        Args:
            config (BaseEmbedderConfig, optional): Embedder configuration. If it does
                not name a model, "thenlper/gte-large" is used. Defaults to None.
        """
        super().__init__(config)

        # Fall back to the default model when none is configured.
        self.config.model = self.config.model or "thenlper/gte-large"
        self.dense_model = TextEmbedding(model_name=self.config.model)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Convert the text to embeddings using FastEmbed running in the ONNX runtime.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of
                "add", "search", or "update". Currently unused by this embedder.
                Defaults to None.

        Returns:
            list: The embedding vector.
        """
        # Newlines are flattened to spaces before the text is handed to the model.
        text = text.replace("\n", " ")
        embeddings = list(self.dense_model.embed(text))
        vector = embeddings[0]
        # The model yields array-like vectors (NumPy arrays in the tests); convert to
        # a plain list so the return value matches the documented type.
        if hasattr(vector, "tolist"):
            return vector.tolist()
        return list(vector)

mem0/utils/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class EmbedderFactory:
145145
"lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
146146
"langchain": "mem0.embeddings.langchain.LangchainEmbedding",
147147
"aws_bedrock": "mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
148+
"fastembed": "mem0.embeddings.fastembed.FastEmbedEmbedding",
148149
}
149150

150151
@classmethod

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ extras = [
7272
"sentence-transformers>=5.0.0",
7373
"elasticsearch>=8.0.0,<9.0.0",
7474
"opensearch-py>=2.0.0",
75+
"fastembed>=0.3.1",
7576
]
7677
test = [
7778
"pytest>=8.2.2",
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from unittest.mock import Mock, patch
2+
3+
import pytest
4+
import numpy as np
5+
from mem0.configs.embeddings.base import BaseEmbedderConfig
6+
7+
try:
8+
from mem0.embeddings.fastembed import FastEmbedEmbedding
9+
except ImportError:
10+
pytest.skip("fastembed not installed", allow_module_level=True)
11+
12+
13+
@pytest.fixture
def mock_fastembed_client():
    """Patch ``TextEmbedding`` in the fastembed embedder module and yield the mock
    instance that its constructor returns for the duration of the test."""
    client = Mock()
    # Configure the patched class up front so that calling it returns ``client``.
    with patch("mem0.embeddings.fastembed.TextEmbedding", return_value=client):
        yield client
19+
20+
21+
def test_embed_with_jina_model(mock_fastembed_client):
    """A single-line text is passed to the model unchanged and the first yielded
    vector is returned."""
    expected_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
    mock_fastembed_client.embed.return_value = iter([np.array(expected_vector)])

    embedder = FastEmbedEmbedding(
        BaseEmbedderConfig(model="jinaai/jina-embeddings-v2-base-en", embedding_dims=768)
    )
    sample = "Sample text to embed."
    result = embedder.embed(sample)

    mock_fastembed_client.embed.assert_called_once_with(sample)
    assert list(result) == expected_vector
33+
34+
35+
def test_embed_removes_newlines(mock_fastembed_client):
    """Newlines in the input are replaced by spaces before the model is called."""
    expected_vector = [0.7, 0.8, 0.9]
    mock_fastembed_client.embed.return_value = iter([np.array(expected_vector)])

    embedder = FastEmbedEmbedding(
        BaseEmbedderConfig(model="jinaai/jina-embeddings-v2-base-en", embedding_dims=768)
    )
    result = embedder.embed("Hello\nworld")

    mock_fastembed_client.embed.assert_called_once_with("Hello world")
    assert list(result) == expected_vector

0 commit comments

Comments
 (0)