Bug/Store RAG Document Metadata Subdocs in Separate Table #223

Merged · 11 commits · Jan 15, 2025
58 changes: 55 additions & 3 deletions lambda/models/domain_objects.py
@@ -14,16 +14,19 @@

"""Domain objects for interacting with the model endpoints."""

import logging
import time
import uuid
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, TypeAlias, Union

from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
from pydantic.functional_validators import AfterValidator, field_validator, model_validator
from typing_extensions import Self
from utilities.validators import validate_all_fields_defined, validate_any_fields_defined, validate_instance_type

logger = logging.getLogger(__name__)


class InferenceContainer(str, Enum):
"""Enum representing the interface container type."""
@@ -312,6 +315,22 @@ class IngestionType(Enum):
MANUAL = "manual"


RagDocumentDict: TypeAlias = Dict[str, Any]


class RagSubDocument(BaseModel):
"""Rag Sub-Document Entity for storing in DynamoDB."""

document_id: str
subdocs: list[str] = Field(default_factory=lambda: [])
index: int = Field(exclude=True)
sk: Optional[str] = None

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.sk = f"subdoc#{self.document_id}#{self.index}"
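
A minimal usage sketch (ids are hypothetical): each sub-document item derives a deterministic sort key, so every chunk of a parent document shares the `subdoc#<document_id>#` prefix, and the `exclude=True` index stays out of the persisted item.

```python
# Hypothetical ids; illustrates how the sort key is composed.
sub = RagSubDocument(document_id="doc-123", subdocs=["vec-1", "vec-2"], index=0)
assert sub.sk == "subdoc#doc-123#0"
# `index` is excluded from serialization, so the DynamoDB item carries only
# document_id, subdocs, and sk.
assert "index" not in sub.model_dump()
```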


class RagDocument(BaseModel):
"""Rag Document Entity for storing in DynamoDB."""

@@ -322,16 +341,49 @@ class RagDocument(BaseModel):
document_name: str
source: str
username: str
sub_docs: List[str] = Field(default_factory=lambda: [])
subdocs: List[str] = Field(default_factory=lambda: [], exclude=True)
Contributor: IMHO this feels kinda clunky having this on this object, considering it isn't persisted or returned to the frontend. I'd just use the result of join_docs directly instead of trying to store it on the RagDocument. That way you can stop worrying about it in the find* methods too.

Contributor (author): Keeping for now, since this is a convenient way to pass the subdocs along as part of the document. It isn't returned by the API unless explicitly requested.

chunk_size: int
chunk_overlap: int
Contributor: Especially after yesterday's conversation about having different strategies, I could see these fields being their own object so we can easily represent multiple strategies. Something like what we ended up with for features, maybe?

Contributor (author): I'm looking into storing a map for the strategy. It may not make it into this release.

Contributor (author): Updated to store a map.
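
Purely illustrative, since the merged schema isn't shown in this diff: a strategy map in that spirit might look like the sketch below, with field names that are assumptions rather than the actual shape.

```python
# Hypothetical chunk-strategy map; keys are illustrative assumptions.
chunk_strategy = {
    "type": "fixed",  # e.g. a fixed-size chunking strategy
    "chunk_size": 512,
    "chunk_overlap": 51,
}
```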

ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
upload_date: int = Field(default_factory=lambda: int(time.time()))

chunks: Optional[int] = 0
model_config = ConfigDict(use_enum_values=True, validate_default=True)

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.pk = self.createPartitionKey(self.repository_id, self.collection_id)
self.chunks = len(self.subdocs)

@staticmethod
def createPartitionKey(repository_id: str, collection_id: str) -> str:
return f"{repository_id}#{collection_id}"

def chunk_doc(self, chunk_size: int = 1000) -> list[RagSubDocument]:
"""Chunk the document into smaller sub-documents."""
chunked_docs: list[RagSubDocument] = []
for i in range(0, len(self.subdocs), chunk_size):
subdocs = self.subdocs[i : i + chunk_size]
logger.info(f"Chunking document {self.document_id}: {len(subdocs)} sub-documents starting at index {i}")
chunk = RagSubDocument(document_id=self.document_id, subdocs=subdocs, index=i)
chunked_docs.append(chunk)
return chunked_docs
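
As a usage sketch, assuming a parent document with 2,500 embedding ids (all values hypothetical; fields hidden in this diff, such as document_id and pk, are assumed to carry defaults): chunk_doc yields sub-documents of 1000/1000/500 ids, each indexed by its starting offset.

```python
# Hypothetical: 2,500 chunk ids split into DynamoDB-sized sub-document items.
doc = RagDocument(
    document_id="doc-123",
    repository_id="repo-1",
    collection_id="coll-1",
    document_name="report.pdf",
    source="s3://bucket/report.pdf",
    username="jdoe",
    subdocs=[f"vec-{n}" for n in range(2500)],
    chunk_size=512,
    chunk_overlap=51,
)
parts = doc.chunk_doc(chunk_size=1000)
assert [len(p.subdocs) for p in parts] == [1000, 1000, 500]
assert parts[2].sk == "subdoc#doc-123#2000"  # index is the starting offset
```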

@staticmethod
def join_docs(documents: List[RagDocumentDict]) -> List[RagDocumentDict]:
"""Join the multiple sub-documents into a single document."""
# Group documents by document_id
grouped_docs: dict[str, List[RagDocumentDict]] = {}
for doc in documents:
doc_id = doc.get("document_id", "")
if doc_id not in grouped_docs:
grouped_docs[doc_id] = []
grouped_docs[doc_id].append(doc)

# Join same document_id into single RagDocument
joined_docs: List[RagDocumentDict] = []
for docs in grouped_docs.values():
joined_doc = docs[0]
joined_doc["subdocs"] = [sub_doc for doc in docs for sub_doc in (doc.get("subdocs", []) or [])]
joined_docs.append(joined_doc)

return joined_docs
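
For example (items abbreviated and hypothetical), two DynamoDB items belonging to the same parent fold back into a single record whose subdocs list is the concatenation of the pieces:

```python
# Hypothetical query result: one parent document stored as two sub-document rows.
items = [
    {"document_id": "doc-123", "document_name": "report.pdf", "subdocs": ["vec-0", "vec-1"]},
    {"document_id": "doc-123", "subdocs": ["vec-2"]},
]
merged = RagDocument.join_docs(items)
assert len(merged) == 1
assert merged[0]["subdocs"] == ["vec-0", "vec-1", "vec-2"]
```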
51 changes: 32 additions & 19 deletions lambda/repository/lambda_functions.py
@@ -16,7 +16,7 @@
import json
import logging
import os
from typing import Any, Dict, List
from typing import Any, cast, Dict, List

import boto3
import requests
@@ -49,7 +49,7 @@
)
lisa_api_endpoint = ""
registered_repositories: List[Dict[str, Any]] = []
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"])


def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
@@ -295,34 +295,42 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
document_id = query_string_params.get("documentId")
document_name = query_string_params.get("documentName")

ensure_repository_access(event, find_repository_by_id(repository_id))

if not document_id and not document_name:
raise ValidationError("Either documentId or documentName must be specified")
if document_id and document_name:
raise ValidationError("Only one of documentId or documentName must be specified")

docs = []
docs: list[RagDocumentDict] = []
if document_id:
docs = [doc_repo.find_by_id(document_id)]
docs = [doc_repo.find_by_id(repository_id=repository_id, document_id=document_id, join_docs=True)]
elif document_name:
docs = doc_repo.find_by_name(repository_id, collection_id, document_name)
docs = doc_repo.find_by_name(
repository_id=repository_id, collection_id=collection_id, document_name=document_name, join_docs=True
)

if not docs:
raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}")

# Grab all sub document ids related to the parent document(s)
subdoc_ids = [sub_doc for doc in docs for sub_doc in doc.get("sub_docs", [])]

id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)

vs.delete(ids=subdoc_ids)
for doc in docs:
vs.delete(ids=doc.get("subdocs"))

for doc in docs:
doc_repo.delete_by_id(repository_id=repository_id, document_id=doc.get("document_id"))

doc_repo.batch_delete(docs)
doc_ids = {doc.get("document_id") for doc in docs}
subdoc_ids = [sub_id for doc in docs for sub_id in (doc.get("subdocs") or [])]

return {
"documentName": docs[0].get("document_name"),
"removedDocuments": len(docs),
"removedDocuments": len(doc_ids),
"removedDocumentChunks": len(subdoc_ids),
}
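
Illustratively (values hypothetical), deleting one parent document whose chunks spanned several sub-document items now reports parents and chunks separately:

```python
# Hypothetical response for deleting a single document with 1,200 chunks:
response = {
    "documentName": "report.pdf",
    "removedDocuments": 1,
    "removedDocumentChunks": 1200,
}
```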

@@ -370,7 +378,7 @@ def ingest_documents(event: dict, context: dict) -> dict:

texts = [] # list of strings
metadatas = [] # list of dicts
all_ids = []
doc_entities = []
id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)
@@ -391,15 +399,21 @@
collection_id=model_name,
document_name=document_name,
source=doc_source,
sub_docs=ids,
subdocs=ids,
username=username,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
ingestion_type=IngestionType.MANUAL,
)
doc_repo.save(doc_entity)
doc_entities.append(doc_entity)

all_ids.extend(ids)

return {"ids": all_ids, "count": len(all_ids)}
doc_ids = [doc.document_id for doc in doc_entities]
subdoc_ids = [sub_id for doc in doc_entities for sub_id in doc.subdocs]
return {
"documentIds": doc_ids,
"chunkCount": len(subdoc_ids),
}
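
Sketching the response change with hypothetical payloads: the old shape echoed every chunk id back to the caller, while the new one stays bounded by returning parent ids plus an aggregate count.

```python
# Before: every vector-store id was returned (payload grows with chunk count).
old = {"ids": ["vec-0", "vec-1", "vec-2"], "count": 3}
# After: parent document ids plus a chunk count.
new = {"documentIds": ["doc-123"], "chunkCount": 3}
```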


@api_wrapper
Expand Down Expand Up @@ -445,7 +459,7 @@ def presigned_url(event: dict, context: dict) -> dict:


@api_wrapper
def list_docs(event: dict, context: dict) -> List[RagDocument]:
def list_docs(event: dict, context: dict) -> list[RagDocumentDict]:
"""List all documents for a given repository/collection.

Args:
@@ -468,5 +482,4 @@ def list_docs(event: dict, context: dict) -> List[RagDocument]:
query_string_params = event.get("queryStringParameters", {})
collection_id = query_string_params.get("collectionId")

docs: List[RagDocument] = doc_repo.list_all(repository_id, collection_id)
return docs
return cast(list[RagDocumentDict], doc_repo.list_all(repository_id, collection_id))
6 changes: 4 additions & 2 deletions lambda/repository/pipeline_ingest_documents.py
@@ -31,7 +31,7 @@
session = boto3.Session()
ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)

doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"])


def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]:
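
Given that signature, a usage sketch with hypothetical inputs, assuming the helper slices texts and metadatas in lockstep as the signature suggests:

```python
# Hypothetical: 1,200 texts and matching metadata come back as three batches.
texts = [f"passage {n}" for n in range(1200)]
metadatas = [{"source": "s3://bucket/report.pdf"} for _ in range(1200)]
batches = batch_texts(texts, metadatas, batch_size=500)
assert [len(t) for t, _ in batches] == [500, 500, 200]
```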
@@ -158,7 +158,9 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
collection_id=embedding_model,
document_name=key,
source=docs[0][0].metadata.get("source"),
sub_docs=all_ids,
subdocs=all_ids,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
username=username,
ingestion_type=IngestionType.AUTO,
)