Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(securi
return credentials.credentials


def get_and_format_embedding(request: EmbeddingCreateRequest) -> List[float]:
"""Get or generate embedding for the request"""
if request.embedding is not None and len(request.embedding) > 0:
return "[" + ",".join(map(str, request.embedding)) + "]"
elif (request.embedding is None or len(request.embedding) == 0) and request.content is not None:
embeddings = embedding_service.generate_embedding(request.content)
return "[" + ",".join(map(str, embeddings)) + "]"
else:
raise ValueError("Embedding must be provided or content must be non-empty")

@app.on_event("startup")
async def startup():
"""Connect to database on startup"""
Expand Down Expand Up @@ -338,8 +348,7 @@ async def create_embedding(
if not vector_store_result:
raise HTTPException(status_code=404, detail="Vector store not found")

# Convert embedding to vector string format
embedding_vector_str = "[" + ",".join(map(str, request.embedding)) + "]"
embedding_vector_str = get_and_format_embedding(request)

# Insert embedding using configurable field names
fields = settings.db_fields
Expand Down Expand Up @@ -434,7 +443,7 @@ async def create_embeddings_batch(
param_count = 1

for embedding_req in request.embeddings:
embedding_vector_str = "[" + ",".join(map(str, embedding_req.embedding)) + "]"
embedding_vector_str = get_and_format_embedding(embedding_req)
values_clauses.append(f"(gen_random_uuid(), ${param_count}, ${param_count + 1}, ${param_count + 2}::vector, ${param_count + 3}, NOW())")
params.extend([
vector_store_id,
Expand Down