diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0fa7910..e4ce89e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,7 +34,7 @@ repos:
- id: check-types
name: check types
types_or: [python, jupyter]
- entry: uv run mypy .
+ entry: uv run mypy . --exclude templates
language: system
always_run: false
pass_filenames: false
diff --git a/src/create_ragbits_app/main.py b/src/create_ragbits_app/main.py
index 685c725..7c7614f 100644
--- a/src/create_ragbits_app/main.py
+++ b/src/create_ragbits_app/main.py
@@ -44,7 +44,7 @@ async def run() -> None:
# Let user select a template
from inquirer.shortcuts import list_input, text
- selected_template_str = list_input("Select a template to use", choices=template_choices)
+ selected_template_str = list_input("Select a template to use (more to come soon!)", choices=template_choices)
# Get the directory name from the selection
selected_idx = template_choices.index(selected_template_str)
diff --git a/src/create_ragbits_app/template_utils.py b/src/create_ragbits_app/template_utils.py
index e6e33ca..ea03687 100644
--- a/src/create_ragbits_app/template_utils.py
+++ b/src/create_ragbits_app/template_utils.py
@@ -19,12 +19,17 @@
import sys
import jinja2
+from rich.console import Console
+from rich.progress import Progress, SpinnerColumn, TextColumn
+from rich.tree import Tree
from create_ragbits_app.template_config_base import TemplateConfig
# Get templates directory
TEMPLATES_DIR = pathlib.Path(__file__).parent / "templates"
+console = Console()
+
def get_available_templates() -> list[dict]:
"""Get list of available templates from templates directory with their metadata."""
@@ -79,48 +84,68 @@ def create_project(template_name: str, project_path: str, context: dict) -> None
os.makedirs(project_path, exist_ok=True)
# Process all template files and directories
- for item in template_path.glob("**/*"):
- if item.name == "template_config.py":
- continue # Skip template config file
-
- # Get relative path from template root
- rel_path = str(item.relative_to(template_path))
-
- # Process path parts for Jinja templating (for directory names)
- path_parts = []
- for part in pathlib.Path(rel_path).parts:
- if "{{" in part and "}}" in part:
- # Render the directory name as a template
- name_template = jinja2.Template(part)
- rendered_part = name_template.render(**context)
- path_parts.append(rendered_part)
- else:
- path_parts.append(part)
-
- # Construct the target path with processed directory names
- target_rel_path = os.path.join(*path_parts) if path_parts else ""
- target_path = pathlib.Path(project_path) / target_rel_path
-
- if item.is_dir():
- os.makedirs(target_path, exist_ok=True)
- elif item.is_file():
- # Process as template if it's a .j2 file
- if item.suffix == ".j2":
- with open(item) as f:
- template_content = f.read()
-
- # Render template with context
- template = jinja2.Template(template_content)
- rendered_content = template.render(**context)
-
- # Save to target path without .j2 extension
- target_path = target_path.with_suffix("")
- with open(target_path, "w") as f:
- f.write(rendered_content)
+ with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
+ progress.add_task("[cyan]Creating project structure...", total=None)
+
+ for item in template_path.glob("**/*"):
+ if item.name == "template_config.py":
+ continue # Skip template config file
+
+ # Get relative path from template root
+ rel_path = str(item.relative_to(template_path))
+
+ # Process path parts for Jinja templating (for directory names)
+ path_parts = []
+ for part in pathlib.Path(rel_path).parts:
+ if "{{" in part and "}}" in part:
+ # Render the directory name as a template
+ name_template = jinja2.Template(part)
+ rendered_part = name_template.render(**context)
+ path_parts.append(rendered_part)
+ else:
+ path_parts.append(part)
+
+ # Construct the target path with processed directory names
+ target_rel_path = os.path.join(*path_parts) if path_parts else ""
+ target_path = pathlib.Path(project_path) / target_rel_path
+
+ if item.is_dir():
+ os.makedirs(target_path, exist_ok=True)
+ elif item.is_file():
+ # Process as template if it's a .j2 file
+ if item.suffix == ".j2":
+ with open(item) as f:
+ template_content = f.read()
+
+ # Render template with context
+ template = jinja2.Template(template_content)
+ rendered_content = template.render(**context)
+
+ # Save to target path without .j2 extension
+ target_path = target_path.with_suffix("")
+ with open(target_path, "w") as f:
+ f.write(rendered_content)
+ else:
+ # Create parent directories if they don't exist
+ os.makedirs(target_path.parent, exist_ok=True)
+ # Simple file copy
+ shutil.copy2(item, target_path)
+
+ # Display project structure
+ console.print("\n[bold green]✓ Project created successfully![/bold green]")
+ console.print(f"[bold]Project location:[/bold] {project_path}\n")
+
+ # Create and display project tree
+ tree = Tree("[bold blue]Project Structure[/bold blue]")
+ project_root = pathlib.Path(project_path)
+
+ def build_tree(node: Tree, path: pathlib.Path) -> None:
+ for item in path.iterdir():
+ if item.is_dir():
+ branch = node.add(f"[bold cyan]{item.name}[/bold cyan]")
+ build_tree(branch, item)
else:
- # Create parent directories if they don't exist
- os.makedirs(target_path.parent, exist_ok=True)
- # Simple file copy
- shutil.copy2(item, target_path)
+ node.add(f"[green]{item.name}[/green]")
- print(f"Project created successfully at {project_path}")
+ build_tree(tree, project_root)
+ console.print(tree)
diff --git a/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2 b/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2
index a64295d..9fb16d4 100644
--- a/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2
+++ b/src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2
@@ -1,11 +1,34 @@
version: '3.8'
+
services:
+ {% if vector_store == "Qdrant" %}
qdrant:
- image: qdrant/qdrant
+ image: qdrant/qdrant:latest
container_name: qdrant-db
ports:
- "6333:6333"
- "6334:6334"
volumes:
- - /opt/qdrant_storage_ragbits_{{pkg_name}}:/qdrant/storage:z
+ - qdrant_data:/qdrant/storage
restart: unless-stopped
+ {% elif vector_store == "Postgresql with pgvector" %}
+ postgres:
+ image: ankane/pgvector:latest
+ container_name: postgres-db
+ ports:
+ - "5432:5432"
+ environment:
+ POSTGRES_DB: postgres
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ volumes:
+ - postgres_data:/var/lib/postgresql/data
+ restart: unless-stopped
+ {% endif %}
+
+volumes:
+ {% if vector_store == "Qdrant" %}
+ qdrant_data:
+ {% elif vector_store == "Postgresql with pgvector" %}
+ postgres_data:
+ {% endif %}
diff --git a/src/create_ragbits_app/templates/rag/ingest.py.j2 b/src/create_ragbits_app/templates/rag/ingest.py.j2
deleted file mode 100644
index 72e8556..0000000
--- a/src/create_ragbits_app/templates/rag/ingest.py.j2
+++ /dev/null
@@ -1,23 +0,0 @@
-import asyncio
-
-from ragbits.document_search.documents.sources import WebSource
-
-from ragbits_rag.components import get_document_search
-
-
-async def main():
- document_search = get_document_search()
-
- # You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git.
- result = await document_search.ingest([
- WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf")
- ])
-
- print(f"Number of successfully parsed documents: {len(result.successful)}")
-
- for doc in result.failed:
- print(f"Failed to parse document: {doc.document_uri}")
-
-
-if __name__ == '__main__':
- asyncio.run(main())
diff --git a/src/create_ragbits_app/templates/rag/pyproject.toml.j2 b/src/create_ragbits_app/templates/rag/pyproject.toml.j2
index 51081a5..c7c9596 100644
--- a/src/create_ragbits_app/templates/rag/pyproject.toml.j2
+++ b/src/create_ragbits_app/templates/rag/pyproject.toml.j2
@@ -5,9 +5,9 @@ description = "Repository generated with ragbits"
readme = "README.md"
requires-python = ">={{ python_version }}"
dependencies = [
- "ragbits[qdrant]=={{ ragbits_version }}",
- "pydantic-settings",
- "unstructured[pdf]>=0.17.2",
+ {% for dep in dependencies %}
+ "{{ dep }}",
+ {% endfor %}
]
[build-system]
diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2
index 0543389..321aa2b 100644
--- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2
+++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2
@@ -1,14 +1,34 @@
import logging
+{% if vector_store == "Qdrant" %}
from qdrant_client import AsyncQdrantClient
-from ragbits.core.embeddings import LiteLLMEmbedder
-from ragbits.core.llms import LiteLLM
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
+{% elif vector_store == "Postgresql with pgvector" %}
+from ragbits.core.vector_stores.pgvector import PgvectorVectorStore
+import asyncpg
+{% endif %}
+
+from ragbits.core.embeddings.dense import LiteLLMEmbedder
+{% if hybrid_search %}
+from ragbits.core.embeddings.sparse.fastembed import FastEmbedSparseEmbedder
+from ragbits.core.vector_stores.hybrid import HybridSearchVectorStore
+{% endif %}
+from ragbits.core.llms import LiteLLM
from ragbits.document_search import DocumentSearch
+{% if image_description %}
from ragbits.document_search.documents.element import ImageElement
from ragbits.document_search.ingestion.enrichers import ElementEnricherRouter, ImageElementEnricher
+{% endif %}
+from ragbits.document_search.ingestion.parsers import DocumentParserRouter
+from ragbits.document_search.documents.document import DocumentType
from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser
+{% if parser == "unstructured" %}
+from ragbits.document_search.ingestion.parsers.unstructured import UnstructuredDocumentParser
+{% elif parser == "docling" %}
+from ragbits.document_search.ingestion.parsers.docling import DoclingDocumentParser
+{% endif %}
+
from {{pkg_name}}.config import config
# disable logging from LiteLLM as ragbits already logs the necessary information
@@ -21,21 +41,62 @@ def get_llm():
def get_vector_store():
store_prefix = "{{project_name}}"
- qdrant_client = AsyncQdrantClient(config.qdrant_host)
- dense_embedder = LiteLLMEmbedder(model=config.embedding_model)
+ dense_embedder = LiteLLMEmbedder(model_name=config.embedding_model)
- return QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
+ {% if vector_store == "Qdrant" %}
+ qdrant_client = AsyncQdrantClient(config.qdrant_host)
+ dense_store = QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
+ {% if hybrid_search %}
+ sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
+ sparse_store = QdrantVectorStore(client=qdrant_client, embedder=sparse_embedder, index_name=store_prefix + "-sparse")
+ return HybridSearchVectorStore(dense_store, sparse_store)
+ {% else %}
+ return dense_store
+ {% endif %}
+ {% elif vector_store == "Postgresql with pgvector" %}
+ pool = await asyncpg.create_pool(config.postgres_dsn)
+ dense_store = PgvectorVectorStore(pool=pool, embedder=dense_embedder, table_name=store_prefix + "-dense")
+ {% if hybrid_search %}
+ sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
+ sparse_store = PgvectorVectorStore(pool=pool, embedder=sparse_embedder, table_name=store_prefix + "-sparse")
+ return HybridSearchVectorStore(dense_store, sparse_store)
+ {% else %}
+ return dense_store
+ {% endif %}
+ {% endif %}
def get_document_search():
llm = get_llm()
vector_store = get_vector_store()
+ # Configure document parsers
+ {% if parser == "unstructured" %}
+ parser_router = DocumentParserRouter({
+ DocumentType.PDF: UnstructuredDocumentParser(),
+ DocumentType.DOCX: UnstructuredDocumentParser(),
+ DocumentType.HTML: UnstructuredDocumentParser(),
+ DocumentType.TXT: UnstructuredDocumentParser(),
+ })
+ {% elif parser == "docling" %}
+ parser_router = DocumentParserRouter({
+ DocumentType.PDF: DoclingDocumentParser(),
+ DocumentType.DOCX: DoclingDocumentParser(),
+ DocumentType.HTML: DoclingDocumentParser(),
+ DocumentType.TXT: DoclingDocumentParser(),
+ })
+ {% endif %}
+
document_search = DocumentSearch(
vector_store=vector_store,
query_rephraser=LLMQueryRephraser(llm=llm), # Setup query rephrasing
+ parser_router=parser_router, # Add document parser configuration
+ {% if image_description %}
enricher_router=ElementEnricherRouter({
ImageElement: ImageElementEnricher(llm=llm) # Get image descriptions with multi-modal LLM
- })
+ }),
+ {% else %}
+ enricher_router=None,
+ {% endif %}
)
return document_search
diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2
index 0ca405e..2971626 100644
--- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2
+++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2
@@ -4,12 +4,17 @@ from pydantic_settings import BaseSettings
class AppConfig(BaseSettings):
+ # Vector store configuration
+ {% if vector_store == "Qdrant" %}
qdrant_host: str = "http://localhost:6333"
-
- llm_model: str = "gpt-4o-mini"
- embedding_model: str = "text-embedding-3-small"
-
- openai_api_key: str = ""
+ {% elif vector_store == "Postgresql with pgvector" %}
+ postgres_dsn: str = "postgresql://postgres:postgres@localhost:5432/postgres"
+ {% endif %}
+
+ # LLM configuration
+ llm_model: str = "gpt-4.1-mini"
+ embedding_model: str = "text-embedding-3-large"
+ openai_api_key: str
class Config:
env_file = Path(__file__).parent.parent.parent / ".env"
diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2
index ac35fe2..9ffa99f 100644
--- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2
+++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2
@@ -1,7 +1,5 @@
import asyncio
-from ragbits.document_search.documents.sources import WebSource
-
from {{pkg_name}}.components import get_document_search
@@ -9,9 +7,9 @@ async def main():
document_search = get_document_search()
# You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git.
- result = await document_search.ingest([
- WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf")
- ])
+ result = await document_search.ingest(
+ "web://https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf"
+ )
print(f"Number of successfully parsed documents: {len(result.successful)}")
diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2 b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2
index a0dba21..dc28a05 100644
--- a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2
+++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/main.py.j2
@@ -1,43 +1,43 @@
"""
Main entry point for the {{ project_name }} application.
+To run this file, use the following command:
+
+```bash
+ragbits api run {{pkg_name}}.main:ChatApp
+```
"""
-import asyncio
+from collections.abc import AsyncGenerator
from pydantic import BaseModel
from ragbits.core.prompt import Prompt
-
+from ragbits.chat.interface import ChatInterface
+from ragbits.chat.interface.types import ChatResponse, Message
+from {{pkg_name}}.prompt_qa import QAPrompt, QAPromptInput
from {{pkg_name}}.components import get_llm, get_document_search
-class QAPromptInput(BaseModel):
- question: str
- contexts: list[str]
-
-class QAPrompt(Prompt[QAPromptInput, str]):
- system_prompt = """
- Your task is to answer user questions based on context.
- If the question is not related to the context, say that the question is not related.
-
-
- {% for context in contexts %}
- {{context}}
- {% endfor %}
-
- """
-
- user_prompt = "{{ question }}"
-
-async def reply(question: str):
- llm = get_llm()
- document_search = get_document_search()
+class ChatApp(ChatInterface):
+ """Chat interface for {{ project_name }} application."""
- context = await document_search.search(question)
+ def __init__(self) -> None:
+ self.llm = get_llm()
+ self.document_search = get_document_search()
- async for chunk in llm.generate_streaming(
- QAPrompt(QAPromptInput(question=question, contexts=[ctx.text_representation for ctx in context]))
- ):
- print(chunk, end="")
+ async def chat(
+ self,
+ message: str,
+ history: list[Message] | None = None,
+ context: dict | None = None,
+ ) -> AsyncGenerator[ChatResponse, None]:
+ # Search for relevant documents
+ search_results = await self.document_search.search(message)
+ # Create prompt with context
+ prompt = QAPrompt(QAPromptInput(
+ question=message,
+ contexts=[ctx.text_representation for ctx in search_results]
+ ))
-if __name__ == '__main__':
- asyncio.run(reply("In what countries in 2018 there was signs of yellow fever?"))
+ # Stream the response from the LLM
+ async for chunk in self.llm.generate_streaming(prompt):
+ yield self.create_text_response(chunk)
diff --git a/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py
new file mode 100644
index 0000000..baa1889
--- /dev/null
+++ b/src/create_ragbits_app/templates/rag/src/{{pkg_name}}/prompt_qa.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel
+from ragbits.core.prompt import Prompt
+
+
+class QAPromptInput(BaseModel):
+ """Input for question answering prompt."""
+
+ question: str
+ contexts: list[str]
+
+
+class QAPrompt(Prompt[QAPromptInput, str]):
+ """Prompt for question answering."""
+
+ system_prompt = """
+ Your task is to answer user questions based on context.
+ If the question is not related to the context, say that the question is not related.
+
+
+ {% for context in contexts %}
+ {{context}}
+ {% endfor %}
+
+ """
+
+ user_prompt = "{{ question }}"
diff --git a/src/create_ragbits_app/templates/rag/template_config.py b/src/create_ragbits_app/templates/rag/template_config.py
index 564d3dd..1d0dcd4 100644
--- a/src/create_ragbits_app/templates/rag/template_config.py
+++ b/src/create_ragbits_app/templates/rag/template_config.py
@@ -3,6 +3,9 @@
"""
from create_ragbits_app.template_config_base import (
+ ConfirmQuestion,
+ ListQuestion,
+ Question,
TemplateConfig,
)
@@ -13,7 +16,61 @@ class RagTemplateConfig(TemplateConfig):
name: str = "RAG (Retrieval Augmented Generation)"
description: str = "Basic RAG (Retrieval Augmented Generation) application"
- questions: list = []
+ questions: list[Question] = [
+ ListQuestion(
+ name="vector_store",
+ message="What Vector database you want to use?",
+ choices=[
+ "Qdrant",
+ "Postgresql with pgvector",
+ ],
+ ),
+ ListQuestion(
+ name="parser",
+ message="What parser you want to use parse documents?",
+ choices=[
+ "docling",
+ "unstructured",
+ ],
+ ),
+ ConfirmQuestion(
+ name="hybrid_search", message="Do you want to use hybrid search with sparse embeddings?", default=True
+ ),
+ ConfirmQuestion(
+ name="image_description", message="Do you want to describe images with multi-modal LLM?", default=True
+ ),
+ ]
+
+ def build_context(self, context: dict) -> dict: # noqa: PLR6301
+ """Build additional context based on the answers."""
+ vector_store = context.get("vector_store")
+ parser = context.get("parser")
+ hybrid_search = context.get("hybrid_search")
+
+ # Collect all ragbits extras
+ ragbits_extras = []
+
+ if vector_store == "Qdrant":
+ ragbits_extras.append("qdrant")
+ elif vector_store == "Postgresql with pgvector":
+ ragbits_extras.append("pgvector")
+
+ if parser == "docling":
+ ragbits_extras.append("docling")
+
+ if hybrid_search:
+ ragbits_extras.append("fastembed")
+
+ # Build dependencies list
+ dependencies = [
+ f"ragbits[{','.join(ragbits_extras)}]=={context.get('ragbits_version')}",
+ "pydantic-settings",
+ ]
+
+ if parser == "unstructured":
+ dependencies.append("unstructured[pdf]>=0.17.2")
+
+ return {"dependencies": dependencies}
# Create instance of the config to be imported