diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0fa7910..e4ce89e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - id: check-types name: check types types_or: [python, jupyter] - entry: uv run mypy . + entry: uv run mypy . --exclude templates language: system always_run: false pass_filenames: false diff --git a/src/create_ragbits_app/main.py b/src/create_ragbits_app/main.py index 685c725..7c7614f 100644 --- a/src/create_ragbits_app/main.py +++ b/src/create_ragbits_app/main.py @@ -44,7 +44,7 @@ async def run() -> None: # Let user select a template from inquirer.shortcuts import list_input, text - selected_template_str = list_input("Select a template to use", choices=template_choices) + selected_template_str = list_input("Select a template to use (more to come soon!)", choices=template_choices) # Get the directory name from the selection selected_idx = template_choices.index(selected_template_str) diff --git a/src/create_ragbits_app/template_utils.py b/src/create_ragbits_app/template_utils.py index e6e33ca..ea03687 100644 --- a/src/create_ragbits_app/template_utils.py +++ b/src/create_ragbits_app/template_utils.py @@ -19,12 +19,17 @@ import sys import jinja2 +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.tree import Tree from create_ragbits_app.template_config_base import TemplateConfig # Get templates directory TEMPLATES_DIR = pathlib.Path(__file__).parent / "templates" +console = Console() + def get_available_templates() -> list[dict]: """Get list of available templates from templates directory with their metadata.""" @@ -79,48 +84,68 @@ def create_project(template_name: str, project_path: str, context: dict) -> None os.makedirs(project_path, exist_ok=True) # Process all template files and directories - for item in template_path.glob("**/*"): - if item.name == "template_config.py": - continue # Skip template config file - - # Get relative path from template root - rel_path = str(item.relative_to(template_path)) - - # Process path parts for Jinja templating (for directory names) - path_parts = [] - for part in pathlib.Path(rel_path).parts: - if "{{" in part and "}}" in part: - # Render the directory name as a template - name_template = jinja2.Template(part) - rendered_part = name_template.render(**context) - path_parts.append(rendered_part) - else: - path_parts.append(part) - - # Construct the target path with processed directory names - target_rel_path = os.path.join(*path_parts) if path_parts else "" - target_path = pathlib.Path(project_path) / target_rel_path - - if item.is_dir(): - os.makedirs(target_path, exist_ok=True) - elif item.is_file(): - # Process as template if it's a .j2 file - if item.suffix == ".j2": - with open(item) as f: - template_content = f.read() - - # Render template with context - template = jinja2.Template(template_content) - rendered_content = template.render(**context) - - # Save to target path without .j2 extension - target_path = target_path.with_suffix("") - with open(target_path, "w") as f: - f.write(rendered_content) + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress: + progress.add_task("[cyan]Creating project structure...", total=None) + + for item in template_path.glob("**/*"): + if item.name == "template_config.py": + continue # Skip template config file + + # Get relative path from template root + rel_path = str(item.relative_to(template_path)) + + # Process path parts for Jinja templating (for directory names) + path_parts = [] + for part in pathlib.Path(rel_path).parts: + if "{{" in part and "}}" in part: + # Render the directory name as a template + name_template = jinja2.Template(part) + rendered_part = name_template.render(**context) + path_parts.append(rendered_part) + else: + path_parts.append(part) + + # Construct the target path with processed directory names + target_rel_path = os.path.join(*path_parts) if path_parts else "" + target_path = pathlib.Path(project_path) / target_rel_path + + if item.is_dir(): + os.makedirs(target_path, exist_ok=True) + elif item.is_file(): + # Process as template if it's a .j2 file + if item.suffix == ".j2": + with open(item) as f: + template_content = f.read() + + # Render template with context + template = jinja2.Template(template_content) + rendered_content = template.render(**context) + + # Save to target path without .j2 extension + target_path = target_path.with_suffix("") + with open(target_path, "w") as f: + f.write(rendered_content) + else: + # Create parent directories if they don't exist + os.makedirs(target_path.parent, exist_ok=True) + # Simple file copy + shutil.copy2(item, target_path) + + # Display project structure + console.print("\n[bold green]✓ Project created successfully![/bold green]") + console.print(f"[bold]Project location:[/bold] {project_path}\n") + + # Create and display project tree + tree = Tree("[bold blue]Project Structure[/bold blue]") + project_root = pathlib.Path(project_path) + + def build_tree(node: Tree, path: pathlib.Path) -> None: + for item in path.iterdir(): + if item.is_dir(): + branch = node.add(f"[bold cyan]{item.name}[/bold cyan]") + build_tree(branch, item) else: - # Create parent directories if they don't exist - os.makedirs(target_path.parent, exist_ok=True) - # Simple file copy - shutil.copy2(item, target_path) + node.add(f"[green]{item.name}[/green]") - print(f"Project created successfully at {project_path}") + build_tree(tree, project_root) + console.print(tree) diff --git a/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2 b/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2 index a64295d..9fb16d4 100644 --- a/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2 +++ b/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2 @@ -1,11 +1,34 @@ version: '3.8' + services: + {% if vector_store == "Qdrant" %} qdrant: - image: qdrant/qdrant + image: qdrant/qdrant:latest container_name: qdrant-db ports: - "6333:6333" - "6334:6334" volumes: - - /opt/qdrant_storage_ragbits_{{pkg_name}}:/qdrant/storage:z + - qdrant_data:/qdrant/storage restart: unless-stopped + {% elif vector_store == "Postgresql with pgvector" %} + postgres: + image: ankane/pgvector:latest + container_name: postgres-db + ports: + - "5432:5432" + environment: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + volumes: + - postgres_data:/var/lib/postgresql/data + restart: unless-stopped + {% endif %} + +volumes: + {% if vector_store == "Qdrant" %} + qdrant_data: + {% elif vector_store == "Postgresql with pgvector" %} + postgres_data: + {% endif %} diff --git a/src/create_ragbits_app/templates/rag/ingest.py.j2 b/src/create_ragbits_app/templates/rag/ingest.py.j2 deleted file mode 100644 index 72e8556..0000000 --- a/src/create_ragbits_app/templates/rag/ingest.py.j2 +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio - -from ragbits.document_search.documents.sources import WebSource - -from ragbits_rag.components import get_document_search - - -async def main(): - document_search = get_document_search() - - # You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git. - result = await document_search.ingest([ - WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf") - ]) - - print(f"Number of successfully parsed documents: {len(result.successful)}") - - for doc in result.failed: - print(f"Failed to parse document: {doc.document_uri}") - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/src/create_ragbits_app/templates/rag/pyproject.toml.j2 b/src/create_ragbits_app/templates/rag/pyproject.toml.j2 index 51081a5..c7c9596 100644 --- a/src/create_ragbits_app/templates/rag/pyproject.toml.j2 +++ b/src/create_ragbits_app/templates/rag/pyproject.toml.j2 @@ -5,9 +5,9 @@ description = "Repository generated with ragbits" readme = "README.md" requires-python = ">={{ python_version }}" dependencies = [ - "ragbits[qdrant]=={{ ragbits_version }}", - "pydantic-settings", - "unstructured[pdf]>=0.17.2", + {% for dep in dependencies %} + "{{ dep }}", + {% endfor %} ] [build-system] diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2 index 0543389..321aa2b 100644 --- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2 +++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2 @@ -1,14 +1,34 @@ import logging +{% if vector_store == "Qdrant" %} from qdrant_client import AsyncQdrantClient -from ragbits.core.embeddings import LiteLLMEmbedder -from ragbits.core.llms import LiteLLM from ragbits.core.vector_stores.qdrant import QdrantVectorStore +{% elif vector_store == "Postgresql with pgvector" %} +from ragbits.core.vector_stores.pgvector import PgvectorVectorStore +import asyncpg +{% endif %} + +from ragbits.core.embeddings.dense import LiteLLMEmbedder +{% if hybrid_search %} +from ragbits.core.embeddings.sparse.fastembed import FastEmbedSparseEmbedder +from ragbits.core.vector_stores.hybrid import HybridSearchVectorStore +{% endif %} +from ragbits.core.llms import LiteLLM from ragbits.document_search import DocumentSearch +{% if image_description %} from ragbits.document_search.documents.element import ImageElement from ragbits.document_search.ingestion.enrichers import ElementEnricherRouter, ImageElementEnricher +{% endif %} +from ragbits.document_search.ingestion.parsers import DocumentParserRouter +from ragbits.document_search.documents.document import DocumentType from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser +{% if parser == "unstructured" %} +from ragbits.document_search.ingestion.parsers.unstructured import UnstructuredDocumentParser +{% elif parser == "docling" %} +from ragbits.document_search.ingestion.parsers.docling import DoclingDocumentParser +{% endif %} + from {{pkg_name}}.config import config # disable logging from LiteLLM as ragbits already logs the necessary information @@ -21,21 +41,62 @@ def get_llm(): def get_vector_store(): store_prefix = "{{project_name}}" - qdrant_client = AsyncQdrantClient(config.qdrant_host) - dense_embedder = LiteLLMEmbedder(model=config.embedding_model) + dense_embedder = LiteLLMEmbedder(model_name=config.embedding_model) - return QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense") + {% if vector_store == "Qdrant" %} + qdrant_client = AsyncQdrantClient(config.qdrant_host) + dense_store = QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense") + {% if hybrid_search %} + sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25") + sparse_store = QdrantVectorStore(client=qdrant_client, embedder=sparse_embedder, index_name=store_prefix + "-sparse") + return HybridSearchVectorStore(dense_store, sparse_store) + {% else %} + return dense_store + {% endif %} + {% elif vector_store == "Postgresql with pgvector" %} + pool = await asyncpg.create_pool(config.postgres_dsn) + dense_store = PgvectorVectorStore(pool=pool, embedder=dense_embedder, table_name=store_prefix + "-dense") + {% if hybrid_search %} + sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25") + sparse_store = PgvectorVectorStore(pool=pool, embedder=sparse_embedder, table_name=store_prefix + "-sparse") + return HybridSearchVectorStore(dense_store, sparse_store) + {% else %} + return dense_store + {% endif %} + {% endif %} def get_document_search(): llm = get_llm() vector_store = get_vector_store() + # Configure document parsers + {% if parser == "unstructured" %} + parser_router = DocumentParserRouter({ + DocumentType.PDF: UnstructuredDocumentParser(), + DocumentType.DOCX: UnstructuredDocumentParser(), + DocumentType.HTML: UnstructuredDocumentParser(), + DocumentType.TXT: UnstructuredDocumentParser(), + }) + {% elif parser == "docling" %} + parser_router = DocumentParserRouter({ + DocumentType.PDF: DoclingDocumentParser(), + DocumentType.DOCX: DoclingDocumentParser(), + DocumentType.HTML: DoclingDocumentParser(), + DocumentType.TXT: DoclingDocumentParser(), + }) + {% endif %} + document_search = DocumentSearch( vector_store=vector_store, query_rephraser=LLMQueryRephraser(llm=llm), # Setup query rephrasing + parser_router=parser_router, # Add document parser configuration + {% if image_description %} enricher_router=ElementEnricherRouter({ ImageElement: ImageElementEnricher(llm=llm) # Get image descriptions with multi-modal LLM - }) + }), + {% else %} + enricher_router=None, + {% endif %} ) return document_search diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2 index 0ca405e..2971626 100644 --- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2 +++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2 @@ -4,12 +4,17 @@ from pydantic_settings import BaseSettings class AppConfig(BaseSettings): + # Vector store configuration + {% if vector_store == "Qdrant" %} qdrant_host: str = "http://localhost:6333" - - llm_model: str = "gpt-4o-mini" - embedding_model: str = "text-embedding-3-small" - - openai_api_key: str = "" + {% elif vector_store == "Postgresql with pgvector" %} + postgres_dsn: str = "postgresql://postgres:postgres@localhost:5432/postgres" + {% endif %} + + # LLM configuration + llm_model: str = "gpt-4.1-mini" + embedding_model: str = "text-embedding-3-large" + openai_api_key: str class Config: env_file = Path(__file__).parent.parent.parent / ".env" diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2 index ac35fe2..9ffa99f 100644 --- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2 +++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2 @@ -1,7 +1,5 @@ import asyncio -from ragbits.document_search.documents.sources import WebSource - from {{pkg_name}}.components import get_document_search @@ -9,9 +7,9 @@ async def main(): document_search = get_document_search() # You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git. - result = await document_search.ingest([ - WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf") - ]) + result = await document_search.ingest( + "web://https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf" + ) print(f"Number of successfully parsed documents: {len(result.successful)}") diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2 index a0dba21..dc28a05 100644 --- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2 +++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2 @@ -1,43 +1,43 @@ """ Main entry point for the {{ project_name }} application. +To run this file, use the following command: + +```bash +ragbits api run {{pkg_name}}.main:ChatApp +``` """ -import asyncio +from collections.abc import AsyncGenerator from pydantic import BaseModel from ragbits.core.prompt import Prompt - +from ragbits.chat.interface import ChatInterface +from ragbits.chat.interface.types import ChatResponse, Message +from {{pkg_name}}.prompt_qa import QAPrompt, QAPromptInput from {{pkg_name}}.components import get_llm, get_document_search -class QAPromptInput(BaseModel): - question: str - contexts: list[str] - -class QAPrompt(Prompt[QAPromptInput, str]): - system_prompt = """ - Your task is to answer user questions based on context. - If the question is not related to the context, say that the question is not related. - - - {% for context in contexts %} - {{context}} - {% endfor %} - - """ - - user_prompt = "{{ question }}" - -async def reply(question: str): - llm = get_llm() - document_search = get_document_search() +class ChatApp(ChatInterface): + """Chat interface for {{ project_name }} application.""" - context = await document_search.search(question) + def __init__(self) -> None: + self.llm = get_llm() + self.document_search = get_document_search() - async for chunk in llm.generate_streaming( - QAPrompt(QAPromptInput(question=question, contexts=[ctx.text_representation for ctx in context])) - ): - print(chunk, end="") + async def chat( + self, + message: str, + history: list[Message] | None = None, + context: dict | None = None, + ) -> AsyncGenerator[ChatResponse, None]: + # Search for relevant documents + search_results = await self.document_search.search(message) + # Create prompt with context + prompt = QAPrompt(QAPromptInput( + question=message, + contexts=[ctx.text_representation for ctx in search_results] + )) -if __name__ == '__main__': - asyncio.run(reply("In what countries in 2018 there was signs of yellow fever?")) + # Stream the response from the LLM + async for chunk in self.llm.generate_streaming(prompt): + yield self.create_text_response(chunk) diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py new file mode 100644 index 0000000..baa1889 --- /dev/null +++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from ragbits.core.prompt import Prompt + + +class QAPromptInput(BaseModel): + """Input for question answering prompt.""" + + question: str + contexts: list[str] + + +class QAPrompt(Prompt[QAPromptInput, str]): + """Prompt for question answering.""" + + system_prompt = """ + Your task is to answer user questions based on context. + If the question is not related to the context, say that the question is not related. + + + {% for context in contexts %} + {{context}} + {% endfor %} + + """ + + user_prompt = "{{ question }}" diff --git a/src/create_ragbits_app/templates/rag/template_config.py b/src/create_ragbits_app/templates/rag/template_config.py index 564d3dd..1d0dcd4 100644 --- a/src/create_ragbits_app/templates/rag/template_config.py +++ b/src/create_ragbits_app/templates/rag/template_config.py @@ -3,6 +3,9 @@ """ from create_ragbits_app.template_config_base import ( + ConfirmQuestion, + ListQuestion, + Question, TemplateConfig, ) @@ -13,7 +16,61 @@ class RagTemplateConfig(TemplateConfig): name: str = "RAG (Retrieval Augmented Generation)" description: str = "Basic RAG (Retrieval Augmented Generation) application" - questions: list = [] + questions: list[Question] = [ + ListQuestion( + name="vector_store", + message="What Vector database you want to use?", + choices=[ + "Qdrant", + "Postgresql with pgvector", + ], + ), + ListQuestion( + name="parser", + message="What parser you want to use parse documents?", + choices=[ + "docling", + "unstructured", + ], + ), + ConfirmQuestion( + name="hybrid_search", message="Do you want to use hybrid search with sparse embeddings?", default=True + ), + ConfirmQuestion( + name="image_description", message="Do you want to describe images with multi-modal LLM?", default=True + ), + ] + + def build_context(self, context: dict) -> dict: # noqa: PLR6301 + """Build additional context based on the answers.""" + vector_store = context.get("vector_store") + parser = context.get("parser") + hybrid_search = context.get("hybrid_search") + + # Collect all ragbits extras + ragbits_extras = [] + + if vector_store == "Qdrant": + ragbits_extras.append("qdrant") + elif vector_store == "Postgresql with pgvector": + ragbits_extras.append("pgvector") + + if parser == "docling": + ragbits_extras.append("docling") + + if hybrid_search: + ragbits_extras.append("fastembed") + + # Build dependencies list + dependencies = [ + f"ragbits[{','.join(ragbits_extras)}]=={context.get('ragbits_version')}", + "pydantic-settings", + ] + + if parser == "unstructured": + dependencies.append("unstructured[pdf]>=0.17.2") + + return {"dependencies": dependencies} # Create instance of the config to be imported