Bug/Store RAG Document Metadata Subdocs in Separate Table (#223)
* Move subdocs to separate table
bedanley authored Jan 15, 2025
1 parent 5dd8bce commit f3fc00e
Showing 7 changed files with 273 additions and 109 deletions.
62 changes: 59 additions & 3 deletions lambda/models/domain_objects.py
@@ -14,16 +14,19 @@

"""Domain objects for interacting with the model endpoints."""

import logging
import time
import uuid
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, Generator, List, Optional, TypeAlias, Union

from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
from pydantic.functional_validators import AfterValidator, field_validator, model_validator
from typing_extensions import Self
from utilities.validators import validate_all_fields_defined, validate_any_fields_defined, validate_instance_type

logger = logging.getLogger(__name__)


class InferenceContainer(str, Enum):
"""Enum representing the interface container type."""
@@ -312,6 +315,28 @@ class IngestionType(Enum):
MANUAL = "manual"


RagDocumentDict: TypeAlias = Dict[str, Any]


class ChunkStrategyType(Enum):
"""Enum for different types of chunking strategies."""

FIXED = "fixed"


class RagSubDocument(BaseModel):
"""Rag Sub-Document Entity for storing in DynamoDB."""

document_id: str
subdocs: list[str] = Field(default_factory=lambda: [])
index: int = Field(exclude=True)
sk: Optional[str] = None

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.sk = f"subdoc#{self.document_id}#{self.index}"


class RagDocument(BaseModel):
"""Rag Document Entity for storing in DynamoDB."""

@@ -322,16 +347,47 @@ class RagDocument(BaseModel):
document_name: str
source: str
username: str
sub_docs: List[str] = Field(default_factory=lambda: [])
subdocs: List[str] = Field(default_factory=lambda: [], exclude=True)
chunk_strategy: dict[str, str] = {}
ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
upload_date: int = Field(default_factory=lambda: int(time.time()))

chunks: Optional[int] = 0
model_config = ConfigDict(use_enum_values=True, validate_default=True)

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.pk = self.createPartitionKey(self.repository_id, self.collection_id)
self.chunks = len(self.subdocs)

@staticmethod
def createPartitionKey(repository_id: str, collection_id: str) -> str:
return f"{repository_id}#{collection_id}"

def chunk_doc(self, chunk_size: int = 1000) -> Generator[RagSubDocument, None, None]:
"""Chunk the document into smaller sub-documents."""
total_subdocs = len(self.subdocs)
for start_index in range(0, total_subdocs, chunk_size):
end_index = min(start_index + chunk_size, total_subdocs)
yield RagSubDocument(
document_id=self.document_id, subdocs=self.subdocs[start_index:end_index], index=start_index
)

@staticmethod
def join_docs(documents: List[RagDocumentDict]) -> List[RagDocumentDict]:
"""Join the multiple sub-documents into a single document."""
# Group documents by document_id
grouped_docs: dict[str, List[RagDocumentDict]] = {}
for doc in documents:
doc_id = doc.get("document_id", "")
if doc_id not in grouped_docs:
grouped_docs[doc_id] = []
grouped_docs[doc_id].append(doc)

# Join same document_id into single RagDocument
joined_docs: List[RagDocumentDict] = []
for docs in grouped_docs.values():
joined_doc = docs[0]
joined_doc["subdocs"] = [sub_doc for doc in docs for sub_doc in (doc.get("subdocs", []) or [])]
joined_docs.append(joined_doc)

return joined_docs
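
Taken together, chunk_doc and join_docs round-trip a large document. A sketch, assuming the fields truncated from this diff (pk, document_id) have defaults — as the ingest call further down suggests:

    doc = RagDocument(
        repository_id="repo", collection_id="coll", document_id="doc-123",
        document_name="file.txt", source="s3://bucket/file.txt", username="user",
        subdocs=[f"chunk-{i}" for i in range(2500)],
    )
    # chunk_doc shards 2,500 ids into items of at most 1,000, indexed by offset.
    shards = list(doc.chunk_doc(chunk_size=1000))
    assert [s.sk for s in shards] == ["subdoc#doc-123#0", "subdoc#doc-123#1000", "subdoc#doc-123#2000"]

    # join_docs reassembles one logical document from the per-shard rows a query returns.
    rows = [{"document_id": "doc-123", "subdocs": s.subdocs} for s in shards]
    merged = RagDocument.join_docs(rows)
    assert len(merged) == 1 and len(merged[0]["subdocs"]) == 2500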
62 changes: 41 additions & 21 deletions lambda/repository/lambda_functions.py
@@ -22,7 +22,7 @@
import requests
from botocore.config import Config
from lisapy.langchain import LisaOpenAIEmbeddings
from models.domain_objects import IngestionType, RagDocument
from models.domain_objects import ChunkStrategyType, IngestionType, RagDocument, RagDocumentDict
from repository.rag_document_repo import RagDocumentRepository
from utilities.common_functions import api_wrapper, get_cert_path, get_groups, get_id_token, get_username, retry_config
from utilities.exceptions import HTTPException
@@ -49,7 +49,7 @@
)
lisa_api_endpoint = ""
registered_repositories: List[Dict[str, Any]] = []
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"])


def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
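
The companion changes to RagDocumentRepository live in lambda/repository/rag_document_repo.py, which this diff does not show. A hypothetical sketch of the two-table shape the new constructor call implies — class internals and method names here are assumptions, not the actual implementation:

    import boto3
    from models.domain_objects import RagDocument

    class RagDocumentRepository:
        """Hypothetical: parent metadata and sub-document shards in separate tables."""

        def __init__(self, document_table: str, sub_document_table: str) -> None:
            dynamodb = boto3.resource("dynamodb")
            self.doc_table = dynamodb.Table(document_table)
            self.subdoc_table = dynamodb.Table(sub_document_table)

        def save(self, doc: RagDocument) -> None:
            # 'subdocs' is declared with exclude=True, so the parent item stays small...
            self.doc_table.put_item(Item=doc.model_dump())
            # ...and the chunk ids are persisted separately as <=1000-id shard items.
            with self.subdoc_table.batch_writer() as batch:
                for shard in doc.chunk_doc():
                    batch.put_item(Item=shard.model_dump())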
@@ -287,34 +287,42 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
document_id = query_string_params.get("documentId")
document_name = query_string_params.get("documentName")

ensure_repository_access(event, find_repository_by_id(repository_id))

if not document_id and not document_name:
raise ValidationError("Either documentId or documentName must be specified")
if document_id and document_name:
raise ValidationError("Only one of documentId or documentName must be specified")

docs = []
docs: list[RagDocumentDict] = []
if document_id:
docs = [doc_repo.find_by_id(document_id)]
docs = [doc_repo.find_by_id(repository_id=repository_id, document_id=document_id, join_docs=True)]
elif document_name:
docs = doc_repo.find_by_name(repository_id, collection_id, document_name)
docs = doc_repo.find_by_name(
repository_id=repository_id, collection_id=collection_id, document_name=document_name, join_docs=True
)

if not docs:
raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}")

# Grab all sub document ids related to the parent document(s)
subdoc_ids = [sub_doc for doc in docs for sub_doc in doc.get("sub_docs", [])]

id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)

vs.delete(ids=subdoc_ids)
for doc in docs:
vs.delete(ids=doc.get("subdocs"))

for doc in docs:
doc_repo.delete_by_id(repository_id=repository_id, document_id=doc.get("document_id"))

doc_repo.batch_delete(docs)
doc_ids = {doc.get("document_id") for doc in docs}
subdoc_ids = []
for doc in docs:
subdoc_ids.extend(doc.get("subdocs"))

return {
"documentName": docs[0].get("document_name"),
"removedDocuments": len(docs),
"removedDocuments": len(doc_ids),
"removedDocumentChunks": len(subdoc_ids),
}
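
For orientation, an illustrative response under the new accounting (values assumed): a logical document whose chunks were spread across several sub-document items still counts once in removedDocuments:

    {
        "documentName": "file.txt",
        "removedDocuments": 1,          # distinct document_ids
        "removedDocumentChunks": 1500,  # vector-store chunks removed
    }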

@@ -362,7 +370,7 @@ def ingest_documents(event: dict, context: dict) -> dict:

texts = [] # list of strings
metadatas = [] # list of dicts
all_ids = []
doc_entities = []
id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)
@@ -383,15 +391,24 @@
collection_id=model_name,
document_name=document_name,
source=doc_source,
sub_docs=ids,
subdocs=ids,
username=username,
chunk_strategy={
"type": ChunkStrategyType.FIXED.value,
"size": str(chunk_size),
"overlap": str(chunk_overlap),
},
ingestion_type=IngestionType.MANUAL,
)
doc_repo.save(doc_entity)
doc_entities.append(doc_entity)

all_ids.extend(ids)

return {"ids": all_ids, "count": len(all_ids)}
    doc_ids = [doc.document_id for doc in doc_entities]
subdoc_ids = [sub_id for doc in doc_entities for sub_id in doc.subdocs]
return {
"documentIds": doc_ids,
"chunkCount": len(subdoc_ids),
}
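
The ingest response thus changes shape from chunk-level to document-level identifiers; an illustrative before/after (values assumed):

    # Before: {"ids": ["chunk-0", "chunk-1", ...], "count": 1500}
    # After:
    {
        "documentIds": ["doc-123", "doc-456"],  # one id per ingested document
        "chunkCount": 1500,                     # total embedded chunks
    }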


@api_wrapper
@@ -437,7 +454,7 @@ def presigned_url(event: dict, context: dict) -> dict:


@api_wrapper
def list_docs(event: dict, context: dict) -> List[RagDocument]:
def list_docs(event: dict, context: dict) -> dict[str, list[RagDocumentDict] | str | None]:
"""List all documents for a given repository/collection.
Args:
@@ -447,8 +464,8 @@ def list_docs(event: dict, context: dict) -> List[RagDocument]:
context (dict): The Lambda context object
Returns:
list[RagDocument]: A list of RagDocument objects representing all documents
in the specified collection
dict: A dict with "documents" (the RagDocument entries in the specified
collection) and "lastEvaluated" (the pagination key, or None when complete)
Raises:
KeyError: If collectionId is not provided in queryStringParameters
@@ -459,6 +476,9 @@

query_string_params = event.get("queryStringParameters", {})
collection_id = query_string_params.get("collectionId")
last_evaluated = query_string_params.get("lastEvaluated")

docs: List[RagDocument] = doc_repo.list_all(repository_id, collection_id)
return docs
docs, last_evaluated = doc_repo.list_all(
repository_id=repository_id, collection_id=collection_id, last_evaluated_key=last_evaluated
)
return {"documents": docs, "lastEvaluated": last_evaluated}
11 changes: 8 additions & 3 deletions lambda/repository/pipeline_ingest_documents.py
@@ -25,13 +25,13 @@
from utilities.validation import validate_chunk_params, validate_model_name, validate_repository_type, ValidationError
from utilities.vector_store import get_vector_store_client

from .lambda_functions import _get_embeddings_pipeline, IngestionType, RagDocument
from .lambda_functions import _get_embeddings_pipeline, ChunkStrategyType, IngestionType, RagDocument

logger = logging.getLogger(__name__)
session = boto3.Session()
ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)

doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"])


def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]:
@@ -158,7 +158,12 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
collection_id=embedding_model,
document_name=key,
source=docs[0][0].metadata.get("source"),
sub_docs=all_ids,
subdocs=all_ids,
chunk_strategy={
"type": ChunkStrategyType.FIXED.value,
"size": str(chunk_size),
"overlap": str(chunk_overlap),
},
username=username,
ingestion_type=IngestionType.AUTO,
)
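
batch_texts (defined near the top of this file) is unchanged by this commit; inferring its behavior from the signature alone, a usage sketch:

    texts = [f"text {i}" for i in range(1200)]
    metadatas = [{"source": "s3://bucket/key"}] * len(texts)

    # Expected: three (texts, metadatas) pairs sized 500, 500, 200.
    for text_batch, metadata_batch in batch_texts(texts, metadatas, batch_size=500):
        assert len(text_batch) == len(metadata_batch) <= 500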