Skip to content

feat: new rag template #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ repos:
- id: check-types
name: check types
types_or: [python, jupyter]
entry: uv run mypy .
entry: uv run mypy . --exclude templates
language: system
always_run: false
pass_filenames: false
2 changes: 1 addition & 1 deletion src/create_ragbits_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def run() -> None:
# Let user select a template
from inquirer.shortcuts import list_input, text

selected_template_str = list_input("Select a template to use", choices=template_choices)
selected_template_str = list_input("Select a template to use (more to come soon!)", choices=template_choices)

# Get the directory name from the selection
selected_idx = template_choices.index(selected_template_str)
Expand Down
111 changes: 68 additions & 43 deletions src/create_ragbits_app/template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
import sys

import jinja2
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.tree import Tree

from create_ragbits_app.template_config_base import TemplateConfig

# Get templates directory
TEMPLATES_DIR = pathlib.Path(__file__).parent / "templates"

console = Console()


def get_available_templates() -> list[dict]:
"""Get list of available templates from templates directory with their metadata."""
Expand Down Expand Up @@ -79,48 +84,68 @@ def create_project(template_name: str, project_path: str, context: dict) -> None
os.makedirs(project_path, exist_ok=True)

# Process all template files and directories
for item in template_path.glob("**/*"):
if item.name == "template_config.py":
continue # Skip template config file

# Get relative path from template root
rel_path = str(item.relative_to(template_path))

# Process path parts for Jinja templating (for directory names)
path_parts = []
for part in pathlib.Path(rel_path).parts:
if "{{" in part and "}}" in part:
# Render the directory name as a template
name_template = jinja2.Template(part)
rendered_part = name_template.render(**context)
path_parts.append(rendered_part)
else:
path_parts.append(part)

# Construct the target path with processed directory names
target_rel_path = os.path.join(*path_parts) if path_parts else ""
target_path = pathlib.Path(project_path) / target_rel_path

if item.is_dir():
os.makedirs(target_path, exist_ok=True)
elif item.is_file():
# Process as template if it's a .j2 file
if item.suffix == ".j2":
with open(item) as f:
template_content = f.read()

# Render template with context
template = jinja2.Template(template_content)
rendered_content = template.render(**context)

# Save to target path without .j2 extension
target_path = target_path.with_suffix("")
with open(target_path, "w") as f:
f.write(rendered_content)
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
progress.add_task("[cyan]Creating project structure...", total=None)

for item in template_path.glob("**/*"):
if item.name == "template_config.py":
continue # Skip template config file

# Get relative path from template root
rel_path = str(item.relative_to(template_path))

# Process path parts for Jinja templating (for directory names)
path_parts = []
for part in pathlib.Path(rel_path).parts:
if "{{" in part and "}}" in part:
# Render the directory name as a template
name_template = jinja2.Template(part)
rendered_part = name_template.render(**context)
path_parts.append(rendered_part)
else:
path_parts.append(part)

# Construct the target path with processed directory names
target_rel_path = os.path.join(*path_parts) if path_parts else ""
target_path = pathlib.Path(project_path) / target_rel_path

if item.is_dir():
os.makedirs(target_path, exist_ok=True)
elif item.is_file():
# Process as template if it's a .j2 file
if item.suffix == ".j2":
with open(item) as f:
template_content = f.read()

# Render template with context
template = jinja2.Template(template_content)
rendered_content = template.render(**context)

# Save to target path without .j2 extension
target_path = target_path.with_suffix("")
with open(target_path, "w") as f:
f.write(rendered_content)
else:
# Create parent directories if they don't exist
os.makedirs(target_path.parent, exist_ok=True)
# Simple file copy
shutil.copy2(item, target_path)

# Display project structure
console.print("\n[bold green]✓ Project created successfully![/bold green]")
console.print(f"[bold]Project location:[/bold] {project_path}\n")

# Create and display project tree
tree = Tree("[bold blue]Project Structure[/bold blue]")
project_root = pathlib.Path(project_path)

def build_tree(node: Tree, path: pathlib.Path) -> None:
for item in path.iterdir():
if item.is_dir():
branch = node.add(f"[bold cyan]{item.name}[/bold cyan]")
build_tree(branch, item)
else:
# Create parent directories if they don't exist
os.makedirs(target_path.parent, exist_ok=True)
# Simple file copy
shutil.copy2(item, target_path)
node.add(f"[green]{item.name}[/green]")

print(f"Project created successfully at {project_path}")
build_tree(tree, project_root)
console.print(tree)
27 changes: 25 additions & 2 deletions src/create_ragbits_app/templates/rag/docker/docker-compose.yml.j2
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
version: '3.8'

services:
{% if vector_store == "Qdrant" %}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: comment to all jinja files

there are many unnecessary whitespace chars in the generated files. It can be fixed with whitespace control using minus signs - in the block delimiters (for example {%- if %})

qdrant:
image: qdrant/qdrant
image: qdrant/qdrant:latest
container_name: qdrant-db
ports:
- "6333:6333"
- "6334:6334"
volumes:
- /opt/qdrant_storage_ragbits_{{pkg_name}}:/qdrant/storage:z
- qdrant_data:/qdrant/storage
restart: unless-stopped
{% elif vector_store == "Postgresql with pgvector" %}
postgres:
image: ankane/pgvector:latest
container_name: postgres-db
ports:
- "5432:5432"
environment:
POSTGRES_DB: postgres
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
volumes:
- postgres_data:/var/lib/postgresql/data
restart: unless-stopped
{% endif %}

volumes:
{% if vector_store == "Qdrant" %}
qdrant_data:
{% elif vector_store == "Postgresql with pgvector" %}
postgres_data:
{% endif %}
23 changes: 0 additions & 23 deletions src/create_ragbits_app/templates/rag/ingest.py.j2

This file was deleted.

6 changes: 3 additions & 3 deletions src/create_ragbits_app/templates/rag/pyproject.toml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ description = "Repository generated with ragbits"
readme = "README.md"
requires-python = ">={{ python_version }}"
dependencies = [
"ragbits[qdrant]=={{ ragbits_version }}",
"pydantic-settings",
"unstructured[pdf]>=0.17.2",
{% for dep in dependencies %}
"{{ dep }}",
{% endfor %}
]

[build-system]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
import logging

{% if vector_store == "Qdrant" %}
from qdrant_client import AsyncQdrantClient
from ragbits.core.embeddings import LiteLLMEmbedder
from ragbits.core.llms import LiteLLM
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
{% elif vector_store == "Postgresql with pgvector" %}
from ragbits.core.vector_stores.pgvector import PgvectorVectorStore
import asyncpg
{% endif %}

from ragbits.core.embeddings.dense import LiteLLMEmbedder
{% if hybrid_search %}
from ragbits.core.embeddings.sparse.fastembed import FastEmbedSparseEmbedder
from ragbits.core.vector_stores.hybrid import HybridSearchVectorStore
{% endif %}
from ragbits.core.llms import LiteLLM
from ragbits.document_search import DocumentSearch
{% if image_description %}
from ragbits.document_search.documents.element import ImageElement
from ragbits.document_search.ingestion.enrichers import ElementEnricherRouter, ImageElementEnricher
{% endif %}
from ragbits.document_search.ingestion.parsers import DocumentParserRouter
from ragbits.document_search.documents.document import DocumentType
from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser

{% if parser == "unstructured" %}
from ragbits.document_search.ingestion.parsers.unstructured import UnstructuredDocumentParser
{% elif parser == "docling" %}
from ragbits.document_search.ingestion.parsers.docling import DoclingDocumentParser
{% endif %}

from {{pkg_name}}.config import config

# disable logging from LiteLLM as ragbits already logs the necessary information
Expand All @@ -21,21 +41,62 @@ def get_llm():

def get_vector_store():
store_prefix = "{{project_name}}"
qdrant_client = AsyncQdrantClient(config.qdrant_host)
dense_embedder = LiteLLMEmbedder(model=config.embedding_model)
dense_embedder = LiteLLMEmbedder(model_name=config.embedding_model)

return QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
{% if vector_store == "Qdrant" %}
qdrant_client = AsyncQdrantClient(config.qdrant_host)
dense_store = QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
{% if hybrid_search %}
sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
sparse_store = QdrantVectorStore(client=qdrant_client, embedder=sparse_embedder, index_name=store_prefix + "-sparse")
return HybridSearchVectorStore(dense_store, sparse_store)
{% else %}
return dense_store
{% endif %}
{% elif vector_store == "Postgresql with pgvector" %}
pool = await asyncpg.create_pool(config.postgres_dsn)
dense_store = PgvectorVectorStore(pool=pool, embedder=dense_embedder, table_name=store_prefix + "-dense")
{% if hybrid_search %}
sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
sparse_store = PgvectorVectorStore(pool=pool, embedder=sparse_embedder, table_name=store_prefix + "-sparse")
return HybridSearchVectorStore(dense_store, sparse_store)
{% else %}
return dense_store
{% endif %}
{% endif %}

def get_document_search():
llm = get_llm()
vector_store = get_vector_store()

# Configure document parsers
{% if parser == "unstructured" %}
parser_router = DocumentParserRouter({
DocumentType.PDF: UnstructuredDocumentParser(),
DocumentType.DOCX: UnstructuredDocumentParser(),
DocumentType.HTML: UnstructuredDocumentParser(),
DocumentType.TXT: UnstructuredDocumentParser(),
})
{% elif parser == "docling" %}
parser_router = DocumentParserRouter({
DocumentType.PDF: DoclingDocumentParser(),
DocumentType.DOCX: DoclingDocumentParser(),
DocumentType.HTML: DoclingDocumentParser(),
DocumentType.TXT: DoclingDocumentParser(),
})
{% endif %}

document_search = DocumentSearch(
vector_store=vector_store,
query_rephraser=LLMQueryRephraser(llm=llm), # Setup query rephrasing
parser_router=parser_router, # Add document parser configuration
{% if image_description %}
enricher_router=ElementEnricherRouter({
ImageElement: ImageElementEnricher(llm=llm) # Get image descriptions with multi-modal LLM
})
}),
{% else %}
enricher_router=None,
{% endif %}
)

return document_search
15 changes: 10 additions & 5 deletions src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ from pydantic_settings import BaseSettings


class AppConfig(BaseSettings):
# Vector store configuration
{% if vector_store == "Qdrant" %}
qdrant_host: str = "http://localhost:6333"

llm_model: str = "gpt-4o-mini"
embedding_model: str = "text-embedding-3-small"

openai_api_key: str = ""
{% elif vector_store == "Postgresql with pgvector" %}
postgres_dsn: str = "postgresql://postgres:postgres@localhost:5432/postgres"
{% endif %}

# LLM configuration
llm_model: str = "gpt-4.1-mini"
embedding_model: str = "text-embedding-3-large"
openai_api_key: str

class Config:
env_file = Path(__file__).parent.parent.parent / ".env"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import asyncio

from ragbits.document_search.documents.sources import WebSource

from {{pkg_name}}.components import get_document_search


async def main():
document_search = get_document_search()

# You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git.
result = await document_search.ingest([
WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf")
])
result = await document_search.ingest(
"web://https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf"
)

print(f"Number of successfully parsed documents: {len(result.successful)}")

Expand Down
Loading