Skip to content

Commit bbf08cd

Browse files
authored
feat: new rag template (#2)
1 parent dacbb75 commit bbf08cd

13 files changed

+293
-121
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ repos:
3434
- id: check-types
3535
name: check types
3636
types_or: [python, jupyter]
37-
entry: uv run mypy .
37+
entry: uv run mypy . --exclude templates
3838
language: system
3939
always_run: false
4040
pass_filenames: false

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "create-ragbits-app"
3-
version = "0.0.4"
3+
version = "0.0.6"
44
description = "Set up a modern LLM app by running one command"
55
readme = "README.md"
66
requires-python = ">=3.11"

src/create_ragbits_app/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def run() -> None:
4444
# Let user select a template
4545
from inquirer.shortcuts import list_input, text
4646

47-
selected_template_str = list_input("Select a template to use", choices=template_choices)
47+
selected_template_str = list_input("Select a template to use (more to come soon!)", choices=template_choices)
4848

4949
# Get the directory name from the selection
5050
selected_idx = template_choices.index(selected_template_str)

src/create_ragbits_app/template_utils.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
import sys
2020

2121
import jinja2
22+
from rich.console import Console
23+
from rich.progress import Progress, SpinnerColumn, TextColumn
24+
from rich.tree import Tree
2225

2326
from create_ragbits_app.template_config_base import TemplateConfig
2427

2528
# Get templates directory
2629
TEMPLATES_DIR = pathlib.Path(__file__).parent / "templates"
2730

31+
console = Console()
32+
2833

2934
def get_available_templates() -> list[dict]:
3035
"""Get list of available templates from templates directory with their metadata."""
@@ -79,48 +84,68 @@ def create_project(template_name: str, project_path: str, context: dict) -> None
7984
os.makedirs(project_path, exist_ok=True)
8085

8186
# Process all template files and directories
82-
for item in template_path.glob("**/*"):
83-
if item.name == "template_config.py":
84-
continue # Skip template config file
85-
86-
# Get relative path from template root
87-
rel_path = str(item.relative_to(template_path))
88-
89-
# Process path parts for Jinja templating (for directory names)
90-
path_parts = []
91-
for part in pathlib.Path(rel_path).parts:
92-
if "{{" in part and "}}" in part:
93-
# Render the directory name as a template
94-
name_template = jinja2.Template(part)
95-
rendered_part = name_template.render(**context)
96-
path_parts.append(rendered_part)
97-
else:
98-
path_parts.append(part)
99-
100-
# Construct the target path with processed directory names
101-
target_rel_path = os.path.join(*path_parts) if path_parts else ""
102-
target_path = pathlib.Path(project_path) / target_rel_path
103-
104-
if item.is_dir():
105-
os.makedirs(target_path, exist_ok=True)
106-
elif item.is_file():
107-
# Process as template if it's a .j2 file
108-
if item.suffix == ".j2":
109-
with open(item) as f:
110-
template_content = f.read()
111-
112-
# Render template with context
113-
template = jinja2.Template(template_content)
114-
rendered_content = template.render(**context)
115-
116-
# Save to target path without .j2 extension
117-
target_path = target_path.with_suffix("")
118-
with open(target_path, "w") as f:
119-
f.write(rendered_content)
87+
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
88+
progress.add_task("[cyan]Creating project structure...", total=None)
89+
90+
for item in template_path.glob("**/*"):
91+
if item.name == "template_config.py":
92+
continue # Skip template config file
93+
94+
# Get relative path from template root
95+
rel_path = str(item.relative_to(template_path))
96+
97+
# Process path parts for Jinja templating (for directory names)
98+
path_parts = []
99+
for part in pathlib.Path(rel_path).parts:
100+
if "{{" in part and "}}" in part:
101+
# Render the directory name as a template
102+
name_template = jinja2.Template(part)
103+
rendered_part = name_template.render(**context)
104+
path_parts.append(rendered_part)
105+
else:
106+
path_parts.append(part)
107+
108+
# Construct the target path with processed directory names
109+
target_rel_path = os.path.join(*path_parts) if path_parts else ""
110+
target_path = pathlib.Path(project_path) / target_rel_path
111+
112+
if item.is_dir():
113+
os.makedirs(target_path, exist_ok=True)
114+
elif item.is_file():
115+
# Process as template if it's a .j2 file
116+
if item.suffix == ".j2":
117+
with open(item) as f:
118+
template_content = f.read()
119+
120+
# Render template with context
121+
template = jinja2.Template(template_content)
122+
rendered_content = template.render(**context)
123+
124+
# Save to target path without .j2 extension
125+
target_path = target_path.with_suffix("")
126+
with open(target_path, "w") as f:
127+
f.write(rendered_content)
128+
else:
129+
# Create parent directories if they don't exist
130+
os.makedirs(target_path.parent, exist_ok=True)
131+
# Simple file copy
132+
shutil.copy2(item, target_path)
133+
134+
# Display project structure
135+
console.print("\n[bold green]✓ Project created successfully![/bold green]")
136+
console.print(f"[bold]Project location:[/bold] {project_path}\n")
137+
138+
# Create and display project tree
139+
tree = Tree("[bold blue]Project Structure[/bold blue]")
140+
project_root = pathlib.Path(project_path)
141+
142+
def build_tree(node: Tree, path: pathlib.Path) -> None:
143+
for item in path.iterdir():
144+
if item.is_dir():
145+
branch = node.add(f"[bold cyan]{item.name}[/bold cyan]")
146+
build_tree(branch, item)
120147
else:
121-
# Create parent directories if they don't exist
122-
os.makedirs(target_path.parent, exist_ok=True)
123-
# Simple file copy
124-
shutil.copy2(item, target_path)
148+
node.add(f"[green]{item.name}[/green]")
125149

126-
print(f"Project created successfully at {project_path}")
150+
build_tree(tree, project_root)
151+
console.print(tree)
Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,34 @@
11
version: '3.8'
2+
23
services:
4+
{%- if vector_store == "Qdrant" %}
35
qdrant:
4-
image: qdrant/qdrant
6+
image: qdrant/qdrant:latest
57
container_name: qdrant-db
68
ports:
79
- "6333:6333"
810
- "6334:6334"
911
volumes:
10-
- /opt/qdrant_storage_ragbits_{{pkg_name}}:/qdrant/storage:z
12+
- qdrant_data:/qdrant/storage
1113
restart: unless-stopped
14+
{%- elif vector_store == "Postgresql with pgvector" %}
15+
postgres:
16+
image: ankane/pgvector:latest
17+
container_name: postgres-db
18+
ports:
19+
- "5432:5432"
20+
environment:
21+
POSTGRES_DB: postgres
22+
POSTGRES_USER: postgres
23+
POSTGRES_PASSWORD: postgres
24+
volumes:
25+
- postgres_data:/var/lib/postgresql/data
26+
restart: unless-stopped
27+
{%- endif %}
28+
29+
volumes:
30+
{%- if vector_store == "Qdrant" %}
31+
qdrant_data:
32+
{%- elif vector_store == "Postgresql with pgvector" %}
33+
postgres_data:
34+
{%- endif %}

src/create_ragbits_app/templates/rag/ingest.py.j2

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/create_ragbits_app/templates/rag/pyproject.toml.j2

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ description = "Repository generated with ragbits"
55
readme = "README.md"
66
requires-python = ">={{ python_version }}"
77
dependencies = [
8-
"ragbits[qdrant]=={{ ragbits_version }}",
9-
"pydantic-settings",
10-
"unstructured[pdf]>=0.17.2",
8+
{%- for dep in dependencies %}
9+
"{{ dep }}",
10+
{%- endfor %}
1111
]
1212

1313
[build-system]
src/create_ragbits_app/templates/rag/src/{{pkg_name}}/components.py.j2

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
import logging
22

3+
{%- if vector_store == "Qdrant" %}
34
from qdrant_client import AsyncQdrantClient
4-
from ragbits.core.embeddings import LiteLLMEmbedder
5-
from ragbits.core.llms import LiteLLM
65
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
6+
{%- elif vector_store == "Postgresql with pgvector" %}
7+
from ragbits.core.vector_stores.pgvector import PgvectorVectorStore
8+
import asyncpg
9+
{%- endif %}
10+
11+
from ragbits.core.embeddings.dense import LiteLLMEmbedder
12+
{%- if hybrid_search %}
13+
from ragbits.core.embeddings.sparse.fastembed import FastEmbedSparseEmbedder
14+
from ragbits.core.vector_stores.hybrid import HybridSearchVectorStore
15+
{%- endif %}
16+
from ragbits.core.llms import LiteLLM
717
from ragbits.document_search import DocumentSearch
18+
{%- if image_description %}
819
from ragbits.document_search.documents.element import ImageElement
920
from ragbits.document_search.ingestion.enrichers import ElementEnricherRouter, ImageElementEnricher
21+
{%- endif %}
22+
from ragbits.document_search.ingestion.parsers import DocumentParserRouter
23+
from ragbits.document_search.documents.document import DocumentType
1024
from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser
1125

26+
{%- if parser == "unstructured" %}
27+
from ragbits.document_search.ingestion.parsers.unstructured import UnstructuredDocumentParser
28+
{%- elif parser == "docling" %}
29+
from ragbits.document_search.ingestion.parsers.docling import DoclingDocumentParser
30+
{%- endif %}
31+
1232
from {{pkg_name}}.config import config
1333

1434
# disable logging from LiteLLM as ragbits already logs the necessary information
@@ -21,21 +41,62 @@ def get_llm():
2141

2242
def get_vector_store():
2343
store_prefix = "{{project_name}}"
24-
qdrant_client = AsyncQdrantClient(config.qdrant_host)
25-
dense_embedder = LiteLLMEmbedder(model=config.embedding_model)
44+
dense_embedder = LiteLLMEmbedder(model_name=config.embedding_model)
2645

27-
return QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
46+
{%- if vector_store == "Qdrant" %}
47+
qdrant_client = AsyncQdrantClient(config.qdrant_host)
48+
dense_store = QdrantVectorStore(client=qdrant_client, embedder=dense_embedder, index_name=store_prefix + "-dense")
49+
{%- if hybrid_search %}
50+
sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
51+
sparse_store = QdrantVectorStore(client=qdrant_client, embedder=sparse_embedder, index_name=store_prefix + "-sparse")
52+
return HybridSearchVectorStore(dense_store, sparse_store)
53+
{%- else %}
54+
return dense_store
55+
{%- endif %}
56+
{% elif vector_store == "Postgresql with pgvector" %}
57+
pool = await asyncpg.create_pool(config.postgres_dsn)
58+
dense_store = PgvectorVectorStore(pool=pool, embedder=dense_embedder, table_name=store_prefix + "-dense")
59+
{%- if hybrid_search %}
60+
sparse_embedder = FastEmbedSparseEmbedder(model_name="Qdrant/bm25")
61+
sparse_store = PgvectorVectorStore(pool=pool, embedder=sparse_embedder, table_name=store_prefix + "-sparse")
62+
return HybridSearchVectorStore(dense_store, sparse_store)
63+
{%- else %}
64+
return dense_store
65+
{%- endif %}
66+
{%- endif %}
2867

2968
def get_document_search():
3069
llm = get_llm()
3170
vector_store = get_vector_store()
3271

72+
# Configure document parsers
73+
{%- if parser == "unstructured" %}
74+
parser_router = DocumentParserRouter({
75+
DocumentType.PDF: UnstructuredDocumentParser(),
76+
DocumentType.DOCX: UnstructuredDocumentParser(),
77+
DocumentType.HTML: UnstructuredDocumentParser(),
78+
DocumentType.TXT: UnstructuredDocumentParser(),
79+
})
80+
{%- elif parser == "docling" %}
81+
parser_router = DocumentParserRouter({
82+
DocumentType.PDF: DoclingDocumentParser(),
83+
DocumentType.DOCX: DoclingDocumentParser(),
84+
DocumentType.HTML: DoclingDocumentParser(),
85+
DocumentType.TXT: DoclingDocumentParser(),
86+
})
87+
{%- endif %}
88+
3389
document_search = DocumentSearch(
3490
vector_store=vector_store,
3591
query_rephraser=LLMQueryRephraser(llm=llm), # Setup query rephrasing
92+
parser_router=parser_router, # Add document parser configuration
93+
{%- if image_description %}
3694
enricher_router=ElementEnricherRouter({
3795
ImageElement: ImageElementEnricher(llm=llm) # Get image descriptions with multi-modal LLM
38-
})
96+
}),
97+
{%- else %}
98+
enricher_router=None,
99+
{%- endif %}
39100
)
40101

41102
return document_search

src/create_ragbits_app/templates/rag/src/{{pkg_name}}/config.py.j2

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@ from pydantic_settings import BaseSettings
44

55

66
class AppConfig(BaseSettings):
7+
# Vector store configuration
8+
{%- if vector_store == "Qdrant" %}
79
qdrant_host: str = "http://localhost:6333"
8-
9-
llm_model: str = "gpt-4o-mini"
10-
embedding_model: str = "text-embedding-3-small"
11-
12-
openai_api_key: str = ""
10+
{%- elif vector_store == "Postgresql with pgvector" %}
11+
postgres_dsn: str = "postgresql://postgres:postgres@localhost:5432/postgres"
12+
{%- endif %}
13+
14+
# LLM configuration
15+
llm_model: str = "gpt-4.1-mini"
16+
embedding_model: str = "text-embedding-3-large"
17+
openai_api_key: str
1318

1419
class Config:
1520
env_file = Path(__file__).parent.parent.parent / ".env"

src/create_ragbits_app/templates/rag/src/{{pkg_name}}/ingest.py.j2

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import asyncio
22

3-
from ragbits.document_search.documents.sources import WebSource
4-
53
from {{pkg_name}}.components import get_document_search
64

75

86
async def main():
97
document_search = get_document_search()
108

119
# You can ingest a document using one of ragbits built-in sources, such as web, s3, gcs, azure or git.
12-
result = await document_search.ingest([
13-
WebSource(url="https://www.ecdc.europa.eu/sites/default/files/documents/AER-Yellow-Fever-2019.pdf")
14-
])
10+
result = await document_search.ingest(
11+
"web://https://arxiv.org/pdf/2310.06825"
12+
)
1513

1614
print(f"Number of successfully parsed documents: {len(result.successful)}")
1715

0 commit comments

Comments
 (0)