Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions hindsight-api-slim/hindsight_api/engine/retain/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
logger = logging.getLogger(__name__)


def _expected_dimension(embeddings_backend) -> int | None:
"""Return the backend's expected embedding dimension when available."""
try:
dimension = getattr(embeddings_backend, "dimension", None)
except Exception:
return None
return dimension if isinstance(dimension, int) and dimension > 0 else None


def _validate_embedding_vector(vector: list[float], *, index: int, expected_dimension: int | None) -> list[float]:
actual_dimension = len(vector)
if actual_dimension == 0:
if expected_dimension is None:
raise RuntimeError(f"embedding {index} has dimension 0; expected non-empty vector")
raise RuntimeError(f"embedding {index} has dimension 0; expected {expected_dimension}")
if expected_dimension is not None and actual_dimension != expected_dimension:
raise RuntimeError(f"embedding {index} has dimension {actual_dimension}; expected {expected_dimension}")
return vector


def generate_embedding(embeddings_backend, text: str) -> list[float]:
"""
Generate embedding for text using the provided embeddings backend.
Expand All @@ -21,7 +41,16 @@ def generate_embedding(embeddings_backend, text: str) -> list[float]:
"""
try:
embeddings = embeddings_backend.encode([text])
return embeddings[0]
if len(embeddings) != 1:
raise RuntimeError(
f"Embeddings backend returned {len(embeddings)} vectors for 1 input text; "
"expected exact 1:1 alignment"
)
return _validate_embedding_vector(
embeddings[0],
index=0,
expected_dimension=_expected_dimension(embeddings_backend),
)
except Exception as e:
raise Exception(f"Failed to generate embedding: {str(e)}")

Expand Down Expand Up @@ -59,4 +88,8 @@ async def generate_embeddings_batch(embeddings_backend, texts: list[str]) -> lis
"expected exact 1:1 alignment"
)

return embeddings
expected_dimension = _expected_dimension(embeddings_backend)
return [
_validate_embedding_vector(embedding, index=index, expected_dimension=expected_dimension)
for index, embedding in enumerate(embeddings)
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import uuid

import pytest

from hindsight_api.engine.consolidation import consolidator


class _ZeroLengthEmbeddings:
dimension = 384

def encode(self, texts):
assert texts == ["Consolidated observation text."]
return [[]]


class _FakeMemoryEngine:
embeddings = _ZeroLengthEmbeddings()


class _FailingConn:
async def fetchrow(self, *args, **kwargs):
raise AssertionError("zero-length embedding should be rejected before database insert")


@pytest.mark.asyncio
async def test_create_observation_rejects_zero_length_embedding_before_insert(monkeypatch):
source_id = uuid.uuid4()

async def fake_filter_live_source_memories(conn, bank_id, source_memory_ids):
return source_memory_ids

monkeypatch.setattr(consolidator, "_filter_live_source_memories", fake_filter_live_source_memories)

with pytest.raises(RuntimeError, match="embedding 0 has dimension 0; expected 384"):
await consolidator._create_observation_directly(
conn=_FailingConn(),
memory_engine=_FakeMemoryEngine(),
bank_id="test-bank",
source_memory_ids=[source_id],
observation_text="Consolidated observation text.",
)
16 changes: 16 additions & 0 deletions hindsight-api-slim/tests/test_retain_orchestrator_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ def test_passes_through_aligned_embeddings(self):
result = asyncio.run(embedding_utils.generate_embeddings_batch(backend, ["a", "b"]))

assert result == [[0.1], [0.2]]

def test_raises_when_backend_returns_empty_embedding_vector(self):
backend = MagicMock()
backend.dimension = 3
backend.encode.return_value = [[0.1, 0.2, 0.3], []]

with pytest.raises(RuntimeError, match="embedding 1 has dimension 0; expected 3"):
asyncio.run(embedding_utils.generate_embeddings_batch(backend, ["a", "b"]))

def test_raises_when_backend_returns_wrong_embedding_dimension(self):
backend = MagicMock()
backend.dimension = 3
backend.encode.return_value = [[0.1, 0.2, 0.3], [0.4, 0.5]]

with pytest.raises(RuntimeError, match="embedding 1 has dimension 2; expected 3"):
asyncio.run(embedding_utils.generate_embeddings_batch(backend, ["a", "b"]))