Add RAG document manager to LISA UI (#229)
* Add document library page
* Remove collectionName from pipeline RAG
* Update Delete API to accept multiple IDs
* Add LocalStorage for all UI config pages
bedanley authored Jan 29, 2025
1 parent 2bacb77 commit f21e1c2
Showing 44 changed files with 7,851 additions and 11,421 deletions.
929 changes: 0 additions & 929 deletions ecs_model_deployer/package-lock.json

This file was deleted.

4 changes: 2 additions & 2 deletions ecs_model_deployer/package.json
@@ -12,8 +12,8 @@
"license": "Apache-2.0",
"dependencies": {
"@cdklabs/cdk-enterprise-iac": "^0.0.512",
"aws-cdk": "^2.153.0",
"aws-cdk-lib": "^2.153.0",
"aws-cdk-lib": "^2.176.0",
"aws-cdk": "^2.176.0",
"zod": "^3.23.8"
},
"devDependencies": {
3 changes: 2 additions & 1 deletion ecs_model_deployer/src/lib/ecsCluster.ts
@@ -25,6 +25,7 @@ import {
Cluster,
ContainerDefinition,
ContainerImage,
ContainerInsights,
Ec2Service,
Ec2ServiceProps,
Ec2TaskDefinition,
@@ -89,7 +90,7 @@ export class ECSCluster extends Construct {
const cluster = new Cluster(this, createCdkId([ecsConfig.identifier, 'Cl']), {
clusterName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
vpc: vpc,
containerInsights: !config.region.includes('iso'),
containerInsightsV2: !config.region.includes('iso') ? ContainerInsights.ENABLED : ContainerInsights.DISABLED,
});

// Create auto scaling group
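Note on the two-line change above: the aws-cdk-lib 2.176 bump deprecates the boolean containerInsights cluster prop in favor of containerInsightsV2, which takes a ContainerInsights enum value, so the existing iso-region guard is preserved by mapping the condition to ENABLED/DISABLED. The same substitution appears in lib/api-base/ecsCluster.ts further down.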
2 changes: 1 addition & 1 deletion lambda/models/domain_objects.py
@@ -350,7 +350,7 @@ class RagDocument(BaseModel):
subdocs: List[str] = Field(default_factory=lambda: [], exclude=True)
chunk_strategy: dict[str, str] = {}
ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
upload_date: int = Field(default_factory=lambda: int(time.time()))
upload_date: int = Field(default_factory=lambda: int(time.time() * 1000))
chunks: Optional[int] = 0
model_config = ConfigDict(use_enum_values=True, validate_default=True)

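A note on the upload_date change above: multiplying by 1000 stores a millisecond-precision epoch, presumably so the new document library UI can hand the value straight to JavaScript's Date constructor; that motivation is inferred, not stated in the commit. A minimal sketch of the difference:

import time

seconds = int(time.time())        # e.g. 1738195200 (old format, second precision)
millis = int(time.time() * 1000)  # e.g. 1738195200000 (new format, millisecond precision)
# A JS client calling new Date(1738195200) would render a date in January 1970,
# while new Date(1738195200000) renders the intended 2025 timestamp.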
74 changes: 53 additions & 21 deletions lambda/repository/lambda_functions.py
@@ -21,10 +21,19 @@
import boto3
import requests
from botocore.config import Config
from langchain_core.vectorstores import VectorStore
from lisapy.langchain import LisaOpenAIEmbeddings
from models.domain_objects import ChunkStrategyType, IngestionType, RagDocument
from repository.rag_document_repo import RagDocumentRepository
from utilities.common_functions import api_wrapper, get_cert_path, get_groups, get_id_token, get_username, retry_config
from utilities.common_functions import (
api_wrapper,
get_cert_path,
get_groups,
get_id_token,
get_username,
is_admin,
retry_config,
)
from utilities.exceptions import HTTPException
from utilities.file_processing import process_record
from utilities.validation import validate_model_name, ValidationError
@@ -251,14 +260,23 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:


def ensure_repository_access(event: dict[str, Any], repository: dict[str, Any]) -> None:
"Ensures a user has access to the repository or else raises an HTTPException"
"""Ensures a user has access to the repository or else raises an HTTPException"""
user_groups = json.loads(event["requestContext"]["authorizer"]["groups"]) or []
if not user_has_group(user_groups, repository["allowedGroups"]):
raise HTTPException(status_code=403, message="User does not have permission to access this repository")


def _ensure_document_ownership(event: dict[str, Any], docs: list[dict[str, Any]]) -> None:
"""Verify ownership of documents"""
username = get_username(event)
admin = is_admin(event)
for doc in docs:
if not (admin or doc.get("username") == username):
raise ValueError(f"Document {doc.get('document_id')} is not owned by {username}")


@api_wrapper
def delete_document(event: dict, context: dict) -> Dict[str, Any]:
def delete_documents(event: dict, context: dict) -> Dict[str, Any]:
"""Purge all records related to the specified document from the RAG repository. If a documentId is supplied, a
single document will be removed. If a documentName is supplied, all documents with that name will be removed
@@ -267,7 +285,7 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
- pathParameters.repositoryId: The repository id of VectorStore
- queryStringParameters.collectionId: The collection identifier
- queryStringParameters.repositoryType: Type of repository of VectorStore
- queryStringParameters.documentId (optional): Name of document to purge
- body.documentIds (optional): Array of document IDs to purge
- queryStringParameters.documentName (optional): Name of document to purge
context (dict): The Lambda context object
@@ -281,22 +299,27 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
"""
path_params = event.get("pathParameters", {})
repository_id = path_params.get("repositoryId")

query_string_params = event.get("queryStringParameters", {})
collection_id = query_string_params.get("collectionId")
document_id = query_string_params.get("documentId")
query_string_params = event.get("queryStringParameters", {}) or {}
collection_id = query_string_params.get("collectionId", None)
body = json.loads(event.get("body") or "{}")  # tolerate a missing body (e.g. delete-by-name requests)
document_ids = body.get("documentIds", None)
document_name = query_string_params.get("documentName")

ensure_repository_access(event, find_repository_by_id(repository_id))
if not document_ids and not document_name:
raise ValidationError("No 'documentIds' or 'documentName' parameter supplied")
if document_ids and document_name:
raise ValidationError("Only one of documentIds or documentName must be specified")
if not collection_id and document_name:
raise ValidationError("A 'collectionId' must be included to delete a document by name")

if not document_id and not document_name:
raise ValidationError("Either documentId or documentName must be specified")
if document_id and document_name:
raise ValidationError("Only one of documentId or documentName must be specified")
ensure_repository_access(event, find_repository_by_id(repository_id))

docs: list[RagDocument.model_dump] = []
if document_id:
docs = [doc_repo.find_by_id(repository_id=repository_id, document_id=document_id, join_docs=True)]
if document_ids:
docs = [
doc_repo.find_by_id(repository_id=repository_id, document_id=doc_id, join_docs=True)
for doc_id in document_ids
]
elif document_name:
docs = doc_repo.find_by_name(
repository_id=repository_id, collection_id=collection_id, document_name=document_name, join_docs=True
@@ -305,16 +328,25 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
if not docs:
raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}")

id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)
_ensure_document_ownership(event, docs)

id_token = get_id_token(event)
vs_collection_map: dict[str, VectorStore] = {}
for doc in docs:
# Get vector store for document collection
collection_id = doc.get("collection_id")
vs = vs_collection_map.get(collection_id)
if not vs:
embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)
vs_collection_map[collection_id] = vs
# Delete all document chunks from vector store collection
vs.delete(ids=doc.get("subdocs"))

for doc in docs:
doc_repo.delete_by_id(repository_id=repository_id, document_id=doc.get("document_id"))

# Collect all document parts for summary of deletion
doc_ids = {doc.get("document_id") for doc in docs}
subdoc_ids = []
for doc in docs:
@@ -471,11 +503,11 @@ def list_docs(event: dict, context: dict) -> dict[str, list[RagDocument.model_du
KeyError: If collectionId is not provided in queryStringParameters
"""

path_params = event.get("pathParameters", {})
path_params = event.get("pathParameters", {}) or {}
repository_id = path_params.get("repositoryId")

query_string_params = event.get("queryStringParameters", {})
collection_id = query_string_params.get("collectionId")
query_string_params = event.get("queryStringParameters", {}) or {}
collection_id = query_string_params.get("collectionId", None)
last_evaluated = query_string_params.get("lastEvaluated")

docs, last_evaluated = doc_repo.list_all(
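For reference, a sketch of how a client might call the updated delete endpoint now that document IDs travel in the request body instead of the query string. The host, route, and auth header are illustrative assumptions; only the {"documentIds": [...]} body contract comes from the handler above:

import requests

# Hypothetical endpoint and token, for illustration only; the handler reads
# repositoryId from the path and expects {"documentIds": [...]} in the JSON body.
response = requests.delete(
    "https://<api-host>/repository/my-repo/document",
    json={"documentIds": ["doc-123", "doc-456"]},
    headers={"Authorization": "Bearer <id-token>"},
)
response.raise_for_status()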
33 changes: 25 additions & 8 deletions lambda/repository/rag_document_repo.py
@@ -168,7 +168,7 @@ def find_by_name(
def list_all(
self,
repository_id: str,
collection_id: str,
collection_id: Optional[str],
last_evaluated_key: Optional[dict] = None,
limit: int = 100,
join_docs: bool = False,
@@ -177,19 +177,36 @@ def list_all(
Args:
repository_id: Repository ID
collection_id: Collection ID
collection_id: Optional collection ID; when omitted, documents across all collections in the repository are listed
last_evaluated_key: last key for pagination
limit: maximum returned items
join_docs: whether to include subdoc ids with parent doc
Returns:
List of documents
"""
try:
pk = RagDocument.createPartitionKey(repository_id, collection_id)
query_params = {"KeyConditionExpression": Key("pk").eq(pk), "Limit": limit}
if last_evaluated_key:
query_params["ExclusiveStartKey"] = last_evaluated_key
response = self.doc_table.query(**query_params)
response = None
# Find all rag documents using repo id only
if not collection_id:
query_params = {
"IndexName": "repository_index",
"KeyConditionExpression": Key("repository_id").eq(repository_id),
"Limit": limit,
}
if last_evaluated_key:
query_params["ExclusiveStartKey"] = last_evaluated_key
response = self.doc_table.query(**query_params)
# Find all rag documents using repo id and collection
else:
pk = RagDocument.createPartitionKey(repository_id, collection_id)
query_params = {"KeyConditionExpression": Key("pk").eq(pk), "Limit": limit}
if last_evaluated_key:
query_params["ExclusiveStartKey"] = last_evaluated_key
response = self.doc_table.query(**query_params)

docs: list[RagDocumentDict] = response.get("Items", [])
next_key = response.get("LastEvaluatedKey", None)

if join_docs:
for doc in docs:
subdocs = RagDocumentRepository._get_subdoc_ids(self.find_subdocs_by_id(doc.get("document_id")))
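The new repository-wide branch of list_all queries a repository_index GSI rather than the table's partition key. That index is not defined anywhere in this diff; below is a hypothetical fragment of the table definition, illustrating only the contract the Key("repository_id") condition implies:

# Assumed shape of the GSI backing the repository-wide query; key names are
# inferred from the KeyConditionExpression above, not taken from this commit.
GlobalSecondaryIndexes = [
    {
        "IndexName": "repository_index",
        "KeySchema": [{"AttributeName": "repository_id", "KeyType": "HASH"}],
        "Projection": {"ProjectionType": "ALL"},
    }
]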
5 changes: 2 additions & 3 deletions lambda/repository/state_machine/pipeline_ingest_documents.py
@@ -46,13 +46,12 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
chunk_size = int(os.environ["CHUNK_SIZE"])
chunk_overlap = int(os.environ["CHUNK_OVERLAP"])
embedding_model = os.environ["EMBEDDING_MODEL"]
collection_name = os.environ["COLLECTION_NAME"]
repository_id = os.environ["REPOSITORY_ID"]
username = get_username(event)

# Initialize document processor and vectorstore
doc_processor = DocumentProcessor()
vectorstore = VectorStore(collection_name=collection_name, embedding_model=embedding_model)
vectorstore = VectorStore(collection_name=embedding_model, embedding_model=embedding_model)

# Download and process document
s3_client = boto3.client("s3")
@@ -69,7 +68,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
# Store in DocTable
doc_entity = RagDocument(
repository_id=repository_id,
collection_id=collection_name,
collection_id=embedding_model,
document_name=key,
source=source,
subdocs=ids,
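With COLLECTION_NAME dropped (per the "Remove collectionName from pipeline RAG" bullet in the commit message), the pipeline now keys its vector-store collection by the embedding model name. This matches delete_documents above, which passes each document's collection_id straight to _get_embeddings(model_name=...).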
7 changes: 7 additions & 0 deletions lambda/utilities/common_functions.py
@@ -352,6 +352,13 @@ def get_username(event: dict) -> str:
return username


def is_admin(event: dict) -> bool:
"""Get admin status from event."""
admin_group = os.environ.get("ADMIN_GROUP", "")
groups = get_groups(event)
return admin_group in groups


def get_session_id(event: dict) -> str:
"""Get session_id from event."""
session_id: str = event.get("pathParameters", {}).get("sessionId")
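A hedged usage sketch for the new is_admin helper. The authorizer event shape below is an assumption modeled on how ensure_repository_access reads groups in lambda_functions.py above; this diff does not show get_groups itself:

import json
import os

from utilities.common_functions import is_admin

os.environ["ADMIN_GROUP"] = "lisa-admins"  # hypothetical group name

# Assumed shape: the authorizer serializes group membership as a JSON string.
event = {"requestContext": {"authorizer": {"groups": json.dumps(["lisa-admins"])}}}
print(is_admin(event))  # True under these assumptions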
3 changes: 2 additions & 1 deletion lib/api-base/ecsCluster.ts
@@ -25,6 +25,7 @@ import {
Cluster,
ContainerDefinition,
ContainerImage,
ContainerInsights,
Ec2Service,
Ec2ServiceProps,
Ec2TaskDefinition,
@@ -85,7 +86,7 @@ export class ECSCluster extends Construct {
const cluster = new Cluster(this, createCdkId(['Cl']), {
clusterName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
vpc: vpc.vpc,
containerInsights: !config.region.includes('iso'),
containerInsightsV2: !config.region.includes('iso') ? ContainerInsights.ENABLED : ContainerInsights.DISABLED,
});

// Create auto-scaling group