From 2704ad11d30d963b13761eb873d9015a846cb6e9 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 20 Jun 2024 10:56:21 +0200 Subject: [PATCH 01/64] feat: question fallback support --- examples/recruiting/candidates_freeform.py | 42 ++++++++ examples/visualize_multicollections_code.py | 32 ++++++ src/dbally/audit/event_handlers/base.py | 9 ++ .../audit/event_handlers/cli_event_handler.py | 14 +++ src/dbally/audit/event_tracker.py | 12 +++ src/dbally/collection/collection.py | 102 +++++++++++------- 6 files changed, 173 insertions(+), 38 deletions(-) create mode 100644 examples/recruiting/candidates_freeform.py create mode 100644 examples/visualize_multicollections_code.py diff --git a/examples/recruiting/candidates_freeform.py b/examples/recruiting/candidates_freeform.py new file mode 100644 index 00000000..5aa67d9d --- /dev/null +++ b/examples/recruiting/candidates_freeform.py @@ -0,0 +1,42 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +from typing import List + +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base + +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig + +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + +_Base = automap_base() +_Base.prepare(autoload_with=engine) +_Candidate = _Base.classes.candidates + + +class CandidateFreeformView(BaseText2SQLView): + """ + A view for retrieving candidates from the database. + """ + + def get_tables(self) -> List[TableConfig]: + """ + Get the tables used by the view. + + Returns: + A list of tables. + """ + return [ + TableConfig( + name="candidates", + columns=[ + ColumnConfig("name", "TEXT"), + ColumnConfig("country", "TEXT"), + ColumnConfig("years_of_experience", "INTEGER"), + ColumnConfig("position", "TEXT"), + ColumnConfig("university", "TEXT"), + ColumnConfig("skills", "TEXT"), + ColumnConfig("tags", "TEXT"), + ColumnConfig("id", "INTEGER PRIMARY KEY"), + ], + ), + ] diff --git a/examples/visualize_multicollections_code.py b/examples/visualize_multicollections_code.py new file mode 100644 index 00000000..119d4235 --- /dev/null +++ b/examples/visualize_multicollections_code.py @@ -0,0 +1,32 @@ +# pylint: disable=missing-function-docstring +import asyncio + +from recruiting import candidate_view_with_similarity_store, candidates_freeform +from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity +from recruiting.candidates_freeform import CandidateFreeformView +from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine +from sqlalchemy import create_engine + +import dbally +from dbally.audit import CLIEventHandler +from dbally.gradio import create_gradio_interface +from dbally.llms.litellm import LiteLLM + +cm_engine = create_engine("postgresql+pg8000://postgres:ikar89pl@localhost:5432/codebase_community") + + +async def main(): + await country_similarity.update() + llm = LiteLLM(model_name="gpt-3.5-turbo") + collection1 = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) + collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[CLIEventHandler()]) + collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) + collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + collection2.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) + 
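    # The call on the next line registers collection2 as the fallback for
    # collection1: if answering the question from collection1's views fails,
    # the same question is retried against collection2 (the freeform
    # text-to-SQL view). A rough sketch of the intended flow, where the
    # question text is only an illustration and not part of this example:
    #
    #     result = await collection1.ask("Which candidates know Python?")
    #     # tried against CandidateView / SampleText2SQLViewCyphers first,
    #     # then against CandidateFreeformView if that attempt fails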
collection1.set_fallback_collection(collection2) + gradio_interface = await create_gradio_interface(user_collection=collection1) + gradio_interface.launch() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 10fce0cf..95468e4a 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -61,3 +61,12 @@ async def request_end(self, output: RequestEnd, request_context: RequestCtx) -> output: The output of the request. request_context: Optional context passed from request_start method """ + + @abc.abstractmethod + async def log_message(self, message: str, log_level: str) -> None: + """ + Displays the response from the LLM. + + Args: + message: db-ally event to be logged with all the details. + """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index f738f90b..d6fe2c89 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -98,6 +98,20 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) + async def log_message(self, message: str, log_level="INFO") -> None: + """ + Displays message logged by user + + Args: + message: Message to be sent + log_level: Message log level. + """ + self._print_syntax("[grey53]\n=======================================") + self._print_syntax("[grey53]=======================================") + self._print_syntax(f"[green]{log_level}: {message}") + self._print_syntax("[grey53]=======================================") + self._print_syntax("[grey53]=======================================\n") + async def event_end( self, event: Union[None, LLMEvent, SimilarityEvent], request_context: None, event_context: None ) -> None: diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index c483a65e..2d3f9e99 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -92,3 +92,15 @@ async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIte await handler.event_end( span.data, event_context=contexts[handler], request_context=self._request_contexts[handler] ) + + async def log_message(self, message, log_level="INFO") -> None: + """ + Send message to the handler + + Args: + message: Message to be sent to + log_level: Log level. + """ + + for handler in self._handlers: + await handler.log_message(message, log_level) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c207d95b..345b3c3a 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -60,6 +60,7 @@ def __init__( self._nl_responder = nl_responder self._event_handlers = event_handlers self._llm = llm + self._fallback_collection: Optional[Collection] = None T = TypeVar("T", bound=BaseView) @@ -120,6 +121,19 @@ def add_event_handler(self, event_handler: EventHandler): """ self._event_handlers.append(event_handler) + def add_fallback_collection(self, fallback_collection: "Collection"): + """ + Add fallback collection which will be asked if the base collection does not succeed. + + Args: + fallback_collection: Collection to be asked in case of base collection failure. 
+ + Returns: + The fallback collection to create chains call + """ + self._fallback_collection = fallback_collection + return fallback_collection + def get(self, name: str) -> BaseView: """ Returns an instance of the view with the given name @@ -192,49 +206,61 @@ async def ask( # select view views = self.list() - - if len(views) == 0: - raise ValueError("Empty collection") - if len(views) == 1: - selected_view = next(iter(views)) - else: - selected_view = await self._view_selector.select_view( - question=question, - views=views, + selected_view = None + + try: + if len(views) == 0: + raise ValueError("Empty collection") + if len(views) == 1: + selected_view = next(iter(views)) + else: + selected_view = await self._view_selector.select_view( + question=question, + views=views, + event_tracker=event_tracker, + llm_options=llm_options, + ) + + view = self.get(selected_view) + + start_time_view = time.monotonic() + view_result = await view.ask( + query=question, + llm=self._llm, event_tracker=event_tracker, + n_retries=self.n_retries, + dry_run=dry_run, llm_options=llm_options, ) - - view = self.get(selected_view) - - start_time_view = time.monotonic() - view_result = await view.ask( - query=question, - llm=self._llm, - event_tracker=event_tracker, - n_retries=self.n_retries, - dry_run=dry_run, - llm_options=llm_options, - ) - end_time_view = time.monotonic() - - textual_response = None - if not dry_run and return_natural_response: - textual_response = await self._nl_responder.generate_response( - result=view_result, - question=question, - event_tracker=event_tracker, - llm_options=llm_options, + end_time_view = time.monotonic() + + textual_response = None + if not dry_run and return_natural_response: + textual_response = await self._nl_responder.generate_response( + result=view_result, + question=question, + event_tracker=event_tracker, + llm_options=llm_options, + ) + + result = ExecutionResult( + results=view_result.results, + context=view_result.context, + execution_time=time.monotonic() - start_time, + execution_time_view=end_time_view - start_time_view, + view_name=selected_view, + textual_response=textual_response, ) + except Exception as e: + await event_tracker.log_message( + f"Exception occurred during {selected_view} processing. 
Executing view from fallback collection",
+                log_level="Warning",
+            )
+            if self._fallback_collection:
+                result = await self._fallback_collection.ask(question, dry_run, return_natural_response, llm_options)

-        result = ExecutionResult(
-            results=view_result.results,
-            context=view_result.context,
-            execution_time=time.monotonic() - start_time,
-            execution_time_view=end_time_view - start_time_view,
-            view_name=selected_view,
-            textual_response=textual_response,
-        )
+
+            else:
+                raise e

         await event_tracker.request_end(RequestEnd(result=result))

From 01491c20df72f724a86ae68d43cdfa8fa46e128e Mon Sep 17 00:00:00 2001
From: Lukasz Karlowski
Date: Thu, 20 Jun 2024 12:08:34 +0200
Subject: [PATCH 02/64] Add syntactic sugar
---
 examples/visualize_multicollections_code.py   |  2 +-
 src/dbally/_main.py                           |  3 +++
 src/dbally/audit/event_handlers/base.py       |  4 ++-
 .../audit/event_handlers/cli_event_handler.py | 16 ++++++++++--
 src/dbally/audit/event_tracker.py             | 10 +++++++-
 src/dbally/collection/collection.py           | 25 +++++++++++++++----
 6 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/examples/visualize_multicollections_code.py b/examples/visualize_multicollections_code.py
index 119d4235..3709c585 100644
--- a/examples/visualize_multicollections_code.py
+++ b/examples/visualize_multicollections_code.py
@@ -23,7 +23,7 @@ async def main():
     collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine))
     collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine()))
     collection2.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine))
-    collection1.set_fallback_collection(collection2)
+    collection1.add_fallback(collection2)
     gradio_interface = await create_gradio_interface(user_collection=collection1)
     gradio_interface.launch()

diff --git a/src/dbally/_main.py b/src/dbally/_main.py
index eeb2d836..ab26c4b7 100644
--- a/src/dbally/_main.py
+++ b/src/dbally/_main.py
@@ -14,6 +14,7 @@ def create_collection(
     event_handlers: Optional[List[EventHandler]] = None,
     view_selector: Optional[ViewSelector] = None,
     nl_responder: Optional[NLResponder] = None,
+    fallback_collection: Optional[Collection] = None,
 ) -> Collection:
     """
     Create a new [Collection](collection.md) that is a container for registering views and the\
@@ -44,6 +45,7 @@ def create_collection(
         will be used.
         nl_responder: NL responder used by the collection to respond to natural language queries. If None,\
         a new instance of [NLResponder][dbally.nl_responder.nl_responder.NLResponder] will be used.
+        fallback_collection: Collection to be asked in case of base collection failure.
Returns: a new instance of db-ally Collection @@ -61,4 +63,5 @@ def create_collection( view_selector=view_selector, llm=llm, event_handlers=event_handlers, + fallback_collection=fallback_collection, ) diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 95468e4a..067b45fb 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -2,6 +2,7 @@ from abc import ABC from typing import Generic, TypeVar, Union +from dbally.audit.event_tracker import LogLevel from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent RequestCtx = TypeVar("RequestCtx") @@ -63,10 +64,11 @@ async def request_end(self, output: RequestEnd, request_context: RequestCtx) -> """ @abc.abstractmethod - async def log_message(self, message: str, log_level: str) -> None: + async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: """ Displays the response from the LLM. Args: message: db-ally event to be logged with all the details. + log_level: log level/importance """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index d6fe2c89..498fa78f 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -3,6 +3,8 @@ from sys import stdout from typing import Optional, Union +from dbally.audit.event_tracker import LogLevel + try: from rich import print as pprint from rich.console import Console @@ -98,7 +100,7 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) - async def log_message(self, message: str, log_level="INFO") -> None: + async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: """ Displays message logged by user @@ -106,9 +108,19 @@ async def log_message(self, message: str, log_level="INFO") -> None: message: Message to be sent log_level: Message log level. 
""" + colour = None + if log_level == LogLevel.INFO: + colour = "white" + elif log_level == LogLevel.WARNING: + colour = "orange" + elif log_level == LogLevel.ERROR: + colour = "red" + elif log_level == LogLevel.DEBUG: + colour = "blue" + self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================") - self._print_syntax(f"[green]{log_level}: {message}") + self._print_syntax(f"[{colour}]{log_level}: {message}") self._print_syntax("[grey53]=======================================") self._print_syntax("[grey53]=======================================\n") diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index 2d3f9e99..ecbbbf4f 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -1,4 +1,5 @@ from contextlib import asynccontextmanager +from enum import StrEnum from typing import AsyncIterator, Dict, List, Optional, Union from dbally.audit.event_handlers.base import EventHandler @@ -6,6 +7,13 @@ from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +class LogLevel(StrEnum): + INFO = "Info" + WARNING = "Warning" + ERROR = "Error" + DEBUG = "Debug" + + class EventTracker: """ Container for event handlers and is responsible for processing events.""" @@ -93,7 +101,7 @@ async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIte span.data, event_context=contexts[handler], request_context=self._request_contexts[handler] ) - async def log_message(self, message, log_level="INFO") -> None: + async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: """ Send message to the handler diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 345b3c3a..69830c25 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -35,12 +35,13 @@ def __init__( event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, + fallback_collection: Optional["Collection"] = None, ) -> None: """ Args: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. - view_selector: As you register more then one [View](views/index.md) within single collection,\ + view_selector: As you register more than one [View](views/index.md) within single collection,\ before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. @@ -51,6 +52,8 @@ def __init__( n_retries: IQL generator may produce invalid IQL. If this is the case this argument specifies\ how many times db-ally will try to regenerate it. Previous try with the error message is\ appended to the chat history to guide next generations. + fallback_collection: collection to be asked when the ask function could not find answer in views registered + to this collection """ self.name = name self.n_retries = n_retries @@ -60,7 +63,7 @@ def __init__( self._nl_responder = nl_responder self._event_handlers = event_handlers self._llm = llm - self._fallback_collection: Optional[Collection] = None + self._fallback_collection: Optional[Collection] = fallback_collection T = TypeVar("T", bound=BaseView) @@ -73,7 +76,7 @@ def add(self, view: Type[T], builder: Optional[Callable[[], T]] = None, name: Op query execution. 
We expect Class instead of object, as otherwise Views must have been implemented\ stateless, which would be cumbersome. builder: Optional factory function that will be used to create the View instance. Use it when you\ - need to pass outcome of API call or database connection to the view and it can change over time. + need to pass outcome of API call or database connection to the view, and it can change over time. name: Custom name of the view (defaults to the name of the class). Raises: @@ -121,9 +124,9 @@ def add_event_handler(self, event_handler: EventHandler): """ self._event_handlers.append(event_handler) - def add_fallback_collection(self, fallback_collection: "Collection"): + def add_fallback(self, fallback_collection: "Collection"): """ - Add fallback collection which will be asked if the base collection does not succeed. + Add fallback collection which will be asked if the ask to base collection does not succeed. Args: fallback_collection: Collection to be asked in case of base collection failure. @@ -134,6 +137,18 @@ def add_fallback_collection(self, fallback_collection: "Collection"): self._fallback_collection = fallback_collection return fallback_collection + def __rshift__(self, fallback_collection: "Collection"): + """ + Add fallback collection which will be asked if the ask to base collection does not succeed. + + Args: + fallback_collection: Collection to be asked in case of base collection failure. + + Returns: + The fallback collection to create chains call + """ + return self.add_fallback(fallback_collection) + def get(self, name: str) -> BaseView: """ Returns an instance of the view with the given name From 594fa4f6fdd1a1b725f30ab126f826a5e2727d42 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 20 Jun 2024 14:06:25 +0200 Subject: [PATCH 03/64] polishing --- src/dbally/audit/event_handlers/base.py | 24 ++++++++++++++++++- .../audit/event_handlers/cli_event_handler.py | 15 +++++++----- src/dbally/audit/event_tracker.py | 10 +------- src/dbally/collection/collection.py | 18 +++++++------- src/dbally/collection/results.py | 4 ++-- src/dbally/gradio/gradio_interface.py | 11 ++++++--- 6 files changed, 53 insertions(+), 29 deletions(-) diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 067b45fb..2ac143b3 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -2,13 +2,35 @@ from abc import ABC from typing import Generic, TypeVar, Union -from dbally.audit.event_tracker import LogLevel +from strenum import StrEnum + from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent RequestCtx = TypeVar("RequestCtx") EventCtx = TypeVar("EventCtx") +class LogLevel(StrEnum): + """ + An enumeration representing different logging levels. + + This enumeration inherits from `StrEnum`, making each log level a string value. + The log levels indicate the severity or type of events that occur within an application, + and they are commonly used to filter and categorize log messages. + + Attributes: + INFO (str): Represents informational messages that highlight the progress of the application. + WARNING (str): Represents potentially harmful situations that require attention. + ERROR (str): Represents error events that might still allow the application to continue running. + DEBUG (str): Represents detailed debugging messages useful during development and troubleshooting. 
+ """ + + INFO = "Info" + WARNING = "Warning" + ERROR = "Error" + DEBUG = "Debug" + + class EventHandler(Generic[RequestCtx, EventCtx], ABC): """ A base class that every custom handler should inherit from diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 498fa78f..62ba7466 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -3,7 +3,9 @@ from sys import stdout from typing import Optional, Union +from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import LogLevel +from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent try: from rich import print as pprint @@ -16,8 +18,6 @@ RICH_OUTPUT = False pprint = print # type: ignore -from dbally.audit.event_handlers.base import EventHandler -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent _RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"} _RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]" @@ -152,8 +152,11 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict] output: The output of the request. request_context: Optional context passed from request_start method """ - self._print_syntax("[green bold]REQUEST OUTPUT:") - self._print_syntax(f"Number of rows: {len(output.result.results)}") + if output.result: + self._print_syntax("[green bold]REQUEST OUTPUT:") + self._print_syntax(f"Number of rows: {len(output.result.results)}") - if "sql" in output.result.context: - self._print_syntax(f"{output.result.context['sql']}", "psql") + if "sql" in output.result.context: + self._print_syntax(f"{output.result.context['sql']}", "psql") + else: + self._print_syntax("[red bold]No results found") diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index ecbbbf4f..89fbe552 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -1,19 +1,11 @@ from contextlib import asynccontextmanager -from enum import StrEnum from typing import AsyncIterator, Dict, List, Optional, Union -from dbally.audit.event_handlers.base import EventHandler +from dbally.audit.event_handlers.base import EventHandler, LogLevel from dbally.audit.event_span import EventSpan from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent -class LogLevel(StrEnum): - INFO = "Info" - WARNING = "Warning" - ERROR = "Error" - DEBUG = "Debug" - - class EventTracker: """ Container for event handlers and is responsible for processing events.""" diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 69830c25..293948c5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,8 +5,9 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar +import dbally from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.event_tracker import EventTracker +from dbally.audit.event_tracker import EventTracker, LogLevel from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.data_models.audit import RequestEnd, RequestStart @@ -124,7 +125,7 @@ def add_event_handler(self, event_handler: EventHandler): """ self._event_handlers.append(event_handler) - def add_fallback(self, 
fallback_collection: "Collection"): + def add_fallback(self, fallback_collection: "Collection") -> "Collection": """ Add fallback collection which will be asked if the ask to base collection does not succeed. @@ -185,7 +186,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - ) -> ExecutionResult: + ) -> Optional[ExecutionResult]: """ Ask question in a text form and retrieve the answer based on the available views. @@ -222,6 +223,7 @@ async def ask( # select view views = self.list() selected_view = None + result = None try: if len(views) == 0: @@ -266,16 +268,16 @@ async def ask( view_name=selected_view, textual_response=textual_response, ) - except Exception as e: + except dbally.DbAllyError: await event_tracker.log_message( - f"Exception occurred during {selected_view} processing. Executing view from fallback collection", - log_level="Warning", + f"Exception occurred during {selected_view} processing. Executing view from fallback" f"collection.", + log_level=LogLevel.INFO, ) if self._fallback_collection: result = await self._fallback_collection.ask(question, dry_run, return_natural_response, llm_options) - else: - raise e + await event_tracker.log_message(r"No results found", LogLevel.ERROR) + return None await event_tracker.request_end(RequestEnd(result=result)) diff --git a/src/dbally/collection/results.py b/src/dbally/collection/results.py index b33cf5e3..3a427505 100644 --- a/src/dbally/collection/results.py +++ b/src/dbally/collection/results.py @@ -39,6 +39,6 @@ class ExecutionResult: results: List[Dict[str, Any]] context: Dict[str, Any] execution_time: float - execution_time_view: float - view_name: str + execution_time_view: float = 0 + view_name: str = "" textual_response: Optional[str] = None diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 30182b37..8902d315 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -114,9 +114,14 @@ async def _ui_ask_query( execution_result = await self.collection.ask( question=question_query, return_natural_response=natural_language_flag ) - generated_query = str(execution_result.context) - data = self._load_results_into_dataframe(execution_result.results) - textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response + if execution_result: + generated_query = str(execution_result.context) + data = self._load_results_into_dataframe(execution_result.results) + textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response + else: + generated_query = "No results generated" + data = pd.DataFrame() + except UnsupportedQueryError: generated_query = {"Query": "unsupported"} data = pd.DataFrame() From 8c43cf6bd65400989bea8cf3e53a5822edf44e16 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 20 Jun 2024 14:35:13 +0200 Subject: [PATCH 04/64] resolve cyclic import --- src/dbally/collection/collection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 809b93a8..04de6e97 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar -import dbally +from dbally import DbAllyError from dbally.audit.event_handlers.base import EventHandler, LogLevel from 
dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart @@ -24,7 +24,7 @@ class Collection: Collection is a container for a set of views that can be used by db-ally to answer user questions. Tip: - It is recommended to create new collections using the [`dbally.create_colletion`][dbally.create_collection]\ + It is recommended to create new collections using the [`dbally.create_collection`][dbally.create_collection]\ function instead of instantiating this class directly. """ @@ -73,7 +73,7 @@ def add(self, view: Type[T], builder: Optional[Callable[[], T]] = None, name: Op Register new [View](views/index.md) that will be available to query via the collection. Args: - view: A class inherithing from BaseView. Object of this type will be initialized during\ + view: A class inheriting from BaseView. Object of this type will be initialized during\ query execution. We expect Class instead of object, as otherwise Views must have been implemented\ stateless, which would be cumbersome. builder: Optional factory function that will be used to create the View instance. Use it when you\ @@ -268,7 +268,7 @@ async def ask( view_name=selected_view, textual_response=textual_response, ) - except dbally.DbAllyError: + except DbAllyError: await event_tracker.log_message( f"Exception occurred during {selected_view} processing. Executing view from fallback" f"collection.", log_level=LogLevel.INFO, From 63b80c6e726c1db5e0c312b4ef149b68f7e2cd46 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 20 Jun 2024 14:36:55 +0200 Subject: [PATCH 05/64] fix build --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 631f4d4c..b50724fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = tabulate>=0.9.0 click~=8.1.7 numpy>=1.24.0 + StrEnum>=0.4.15 [options.extras_require] litellm = From 839a187b0bdd112768967eda1fd6610d93437a89 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 20 Jun 2024 14:47:22 +0200 Subject: [PATCH 06/64] fixups --- ...ections_code.py => visualize_fallback_code.py} | 0 src/dbally/collection/collection.py | 15 +++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) rename examples/{visualize_multicollections_code.py => visualize_fallback_code.py} (100%) diff --git a/examples/visualize_multicollections_code.py b/examples/visualize_fallback_code.py similarity index 100% rename from examples/visualize_multicollections_code.py rename to examples/visualize_fallback_code.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 04de6e97..6120f96b 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,18 +5,22 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar -from dbally import DbAllyError +from dbally.assistants.base import FunctionCallingError from dbally.audit.event_handlers.base import EventHandler, LogLevel from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult +from dbally.iql import IQLError +from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder +from dbally.prompts import PromptTemplateError from 
dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation +from dbally.views.freeform.text2sql import Text2SQLError class Collection: @@ -268,7 +272,14 @@ async def ask( view_name=selected_view, textual_response=textual_response, ) - except DbAllyError: + except ( + NoViewFoundError, + IQLError, + FunctionCallingError, + UnsupportedQueryError, + PromptTemplateError, + Text2SQLError, + ): await event_tracker.log_message( f"Exception occurred during {selected_view} processing. Executing view from fallback" f"collection.", log_level=LogLevel.INFO, From 0334b69d585c298338f83aa31114896c39681d4d Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 21 Jun 2024 09:48:09 +0200 Subject: [PATCH 07/64] fixups --- examples/visualize_fallback_code.py | 7 +--- setup.cfg | 1 - src/dbally/audit/event_handlers/base.py | 33 --------------- .../audit/event_handlers/cli_event_handler.py | 40 +++++++------------ src/dbally/audit/event_tracker.py | 14 +------ src/dbally/audit/events.py | 12 ++++++ src/dbally/collection/collection.py | 29 ++++++++------ 7 files changed, 46 insertions(+), 90 deletions(-) diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index 3709c585..6a5848e4 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -5,25 +5,22 @@ from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity from recruiting.candidates_freeform import CandidateFreeformView from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine -from sqlalchemy import create_engine import dbally from dbally.audit import CLIEventHandler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM -cm_engine = create_engine("postgresql+pg8000://postgres:ikar89pl@localhost:5432/codebase_community") - async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") collection1 = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) - collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[CLIEventHandler()]) + collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[]) collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) collection2.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) - collection1.add_fallback(collection2) + collection1.set_fallback(collection2) gradio_interface = await create_gradio_interface(user_collection=collection1) gradio_interface.launch() diff --git a/setup.cfg b/setup.cfg index b50724fd..631f4d4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,6 @@ install_requires = tabulate>=0.9.0 click~=8.1.7 numpy>=1.24.0 - StrEnum>=0.4.15 [options.extras_require] litellm = diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 25b69347..dc3ea7f8 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -2,35 +2,12 @@ from abc import ABC from typing import Generic, Optional, TypeVar -from strenum import StrEnum - from dbally.audit.events import Event, RequestEnd, RequestStart RequestCtx = TypeVar("RequestCtx") EventCtx = TypeVar("EventCtx") -class 
LogLevel(StrEnum): - """ - An enumeration representing different logging levels. - - This enumeration inherits from `StrEnum`, making each log level a string value. - The log levels indicate the severity or type of events that occur within an application, - and they are commonly used to filter and categorize log messages. - - Attributes: - INFO (str): Represents informational messages that highlight the progress of the application. - WARNING (str): Represents potentially harmful situations that require attention. - ERROR (str): Represents error events that might still allow the application to continue running. - DEBUG (str): Represents detailed debugging messages useful during development and troubleshooting. - """ - - INFO = "Info" - WARNING = "Warning" - ERROR = "Error" - DEBUG = "Debug" - - class EventHandler(Generic[RequestCtx, EventCtx], ABC): """ A base class that every custom handler should inherit from @@ -82,13 +59,3 @@ async def request_end(self, output: RequestEnd, request_context: RequestCtx) -> output: The output of the request. request_context: Optional context passed from request_start method """ - - @abc.abstractmethod - async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: - """ - Displays the response from the LLM. - - Args: - message: db-ally event to be logged with all the details. - log_level: log level/importance - """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 6cf91909..33b30a15 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -3,8 +3,8 @@ from sys import stdout from typing import Optional -from dbally.audit.event_handlers.base import EventHandler, LogLevel -from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.event_handlers.base import EventHandler +from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent try: from rich import print as pprint @@ -98,30 +98,18 @@ async def event_start(self, event: Event, request_context: None) -> None: f"[cyan bold]STORE: [grey53]{event.store}\n" f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) - - async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: - """ - Displays message logged by user - - Args: - message: Message to be sent - log_level: Message log level. 
- """ - colour = None - if log_level == LogLevel.INFO: - colour = "white" - elif log_level == LogLevel.WARNING: - colour = "orange" - elif log_level == LogLevel.ERROR: - colour = "red" - elif log_level == LogLevel.DEBUG: - colour = "blue" - - self._print_syntax("[grey53]\n=======================================") - self._print_syntax("[grey53]=======================================") - self._print_syntax(f"[{colour}]{log_level}: {message}") - self._print_syntax("[grey53]=======================================") - self._print_syntax("[grey53]=======================================\n") + elif isinstance(event, FallbackEvent): + self._print_syntax( + "[grey53]\n=======================================\n" + "[grey53]=======================================\n" + f"[orange bold]Fallback event starts \n" + f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n" + f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n" + f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" + f"[orange bold]Error description: [grey53]{event.error_description}\n" + "[grey53]=======================================\n" + "[grey53]=======================================\n" + ) async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None: """ diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index 4b1da73c..34faf803 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager from typing import AsyncIterator, Dict, List, Optional -from dbally.audit.event_handlers.base import EventHandler, LogLevel +from dbally.audit.event_handlers.base import EventHandler from dbally.audit.events import Event, RequestEnd, RequestStart from dbally.audit.spans import EventSpan @@ -92,15 +92,3 @@ async def track_event(self, event: Event) -> AsyncIterator[EventSpan]: await handler.event_end( span.data, event_context=contexts[handler], request_context=self._request_contexts[handler] ) - - async def log_message(self, message: str, log_level: LogLevel = LogLevel.INFO) -> None: - """ - Send message to the handler - - Args: - message: Message to be sent to - log_level: Log level. - """ - - for handler in self._handlers: - await handler.log_message(message, log_level) diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index c02cd5cb..1cc74ca4 100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -41,6 +41,18 @@ class SimilarityEvent(Event): output_value: Optional[str] = None +@dataclass +class FallbackEvent(Event): + """ + FallbackEvent is fired when a processed view/collection raise an exception. 
+ """ + + triggering_collection_name: str + triggering_view_name: str + fallback_collection_name: str + error_description: str + + @dataclass class RequestStart: """ diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 6120f96b..0d767bd8 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -6,9 +6,9 @@ from typing import Callable, Dict, List, Optional, Type, TypeVar from dbally.assistants.base import FunctionCallingError -from dbally.audit.event_handlers.base import EventHandler, LogLevel +from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker -from dbally.audit.events import RequestEnd, RequestStart +from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.iql import IQLError @@ -129,9 +129,9 @@ def add_event_handler(self, event_handler: EventHandler): """ self._event_handlers.append(event_handler) - def add_fallback(self, fallback_collection: "Collection") -> "Collection": + def set_fallback(self, fallback_collection: "Collection") -> "Collection": """ - Add fallback collection which will be asked if the ask to base collection does not succeed. + Set fallback collection which will be asked if the ask to base collection does not succeed. Args: fallback_collection: Collection to be asked in case of base collection failure. @@ -152,7 +152,7 @@ def __rshift__(self, fallback_collection: "Collection"): Returns: The fallback collection to create chains call """ - return self.add_fallback(fallback_collection) + return self.set_fallback(fallback_collection) def get(self, name: str) -> BaseView: """ @@ -279,15 +279,20 @@ async def ask( UnsupportedQueryError, PromptTemplateError, Text2SQLError, - ): - await event_tracker.log_message( - f"Exception occurred during {selected_view} processing. 
Executing view from fallback" f"collection.", - log_level=LogLevel.INFO, - ) + ) as e: if self._fallback_collection: - result = await self._fallback_collection.ask(question, dry_run, return_natural_response, llm_options) + event = FallbackEvent( + triggering_collection_name=self.name, + triggering_view_name=selected_view, + error_description=repr(e), + fallback_collection_name=self._fallback_collection.name, + ) + async with event_tracker.track_event(event) as span: + result = await self._fallback_collection.ask( + question, dry_run, return_natural_response, llm_options + ) + span(event) else: - await event_tracker.log_message(r"No results found", LogLevel.ERROR) return None await event_tracker.request_end(RequestEnd(result=result)) From 9fdb5bd7440f82cab90feb02b144314a9c5278cb Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 21 Jun 2024 21:58:06 +0200 Subject: [PATCH 08/64] error handling decorator --- src/dbally/collection/collection.py | 141 +++++++++++++--------------- src/dbally/collection/decorators.py | 41 ++++++++ 2 files changed, 107 insertions(+), 75 deletions(-) create mode 100644 src/dbally/collection/decorators.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 0d767bd8..62499f4c 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,22 +5,19 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar -from dbally.assistants.base import FunctionCallingError from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker -from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart +from dbally.audit.events import RequestEnd, RequestStart +from dbally.collection.decorators import handle_exception from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.iql import IQLError from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder -from dbally.prompts import PromptTemplateError from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation -from dbally.views.freeform.text2sql import Text2SQLError class Collection: @@ -184,6 +181,57 @@ def list(self) -> Dict[str, str]: name: (textwrap.dedent(view.__doc__).strip() if view.__doc__ else "") for name, view in self._views.items() } + @handle_exception((UnsupportedQueryError, NoViewFoundError)) + async def _select_view(self, question, event_tracker, llm_options): + views = self.list() + if len(views) == 0: + raise ValueError("Empty collection") + if len(views) == 1: + selected_view_name = next(iter(views)) + else: + selected_view_name = await self._view_selector.select_view( + question=question, + views=views, + event_tracker=event_tracker, + llm_options=llm_options, + ) + return selected_view_name + + @handle_exception((UnsupportedQueryError, NoViewFoundError)) + async def _ask_view( + self, selected_view_name, question, event_tracker, dry_run, llm_options, return_natural_response, start_time + ): + selected_view = self.get(selected_view_name) + start_time_view = time.monotonic() + view_result = await selected_view.ask( + query=question, + llm=self._llm, + event_tracker=event_tracker, + 
n_retries=self.n_retries, + dry_run=dry_run, + llm_options=llm_options, + ) + end_time_view = time.monotonic() + + textual_response = None + if not dry_run and return_natural_response: + textual_response = await self._nl_responder.generate_response( + result=view_result, + question=question, + event_tracker=event_tracker, + llm_options=llm_options, + ) + + result = ExecutionResult( + results=view_result.results, + context=view_result.context, + execution_time=time.monotonic() - start_time, + execution_time_view=end_time_view - start_time_view, + view_name=selected_view_name, + textual_response=textual_response, + ) + return result + async def ask( self, question: str, @@ -224,76 +272,19 @@ async def ask( await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) - # select view - views = self.list() - selected_view = None - result = None - - try: - if len(views) == 0: - raise ValueError("Empty collection") - if len(views) == 1: - selected_view = next(iter(views)) - else: - selected_view = await self._view_selector.select_view( - question=question, - views=views, - event_tracker=event_tracker, - llm_options=llm_options, - ) - - view = self.get(selected_view) - - start_time_view = time.monotonic() - view_result = await view.ask( - query=question, - llm=self._llm, - event_tracker=event_tracker, - n_retries=self.n_retries, - dry_run=dry_run, - llm_options=llm_options, - ) - end_time_view = time.monotonic() - - textual_response = None - if not dry_run and return_natural_response: - textual_response = await self._nl_responder.generate_response( - result=view_result, - question=question, - event_tracker=event_tracker, - llm_options=llm_options, - ) - - result = ExecutionResult( - results=view_result.results, - context=view_result.context, - execution_time=time.monotonic() - start_time, - execution_time_view=end_time_view - start_time_view, - view_name=selected_view, - textual_response=textual_response, - ) - except ( - NoViewFoundError, - IQLError, - FunctionCallingError, - UnsupportedQueryError, - PromptTemplateError, - Text2SQLError, - ) as e: - if self._fallback_collection: - event = FallbackEvent( - triggering_collection_name=self.name, - triggering_view_name=selected_view, - error_description=repr(e), - fallback_collection_name=self._fallback_collection.name, - ) - async with event_tracker.track_event(event) as span: - result = await self._fallback_collection.ask( - question, dry_run, return_natural_response, llm_options - ) - span(event) - else: - return None + selected_view_name = await self._select_view( + question=question, event_tracker=event_tracker, llm_options=llm_options + ) + + result = await self._ask_view( + selected_view_name=selected_view_name, + question=question, + event_tracker=event_tracker, + dry_run=dry_run, + llm_options=llm_options, + return_natural_response=return_natural_response, + start_time=start_time, + ) await event_tracker.request_end(RequestEnd(result=result)) diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py new file mode 100644 index 00000000..78f2f402 --- /dev/null +++ b/src/dbally/collection/decorators.py @@ -0,0 +1,41 @@ +from dbally.audit.events import FallbackEvent + + +def handle_exception(handle_exception_list): + def handle_exception_inner(func): + async def wrapper(self, **kwargs): # pylint: disable=missing-return-doc + try: + result = await func(self, **kwargs) + except handle_exception_list as error: + question = kwargs.get("question") + dry_run = kwargs.get("dry_run") + 
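                # These kwargs are read back out so that the original call can be
                # replayed, unchanged, against self._fallback_collection.ask(...) below.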
return_natural_response = kwargs.get("return_natural_response") + llm_options = kwargs.get("llm_options") + selected_view_name = str(kwargs.get("selected_view_name")) + event_tracker = kwargs.get("event_tracker") + + if self._fallback_collection: + event = FallbackEvent( + triggering_collection_name=self.name, + triggering_view_name=selected_view_name, + fallback_collection_name=self._fallback_collection.name, + error_description=repr(error), + ) + + async with event_tracker.track_event(event) as span: + result = await self._fallback_collection.ask( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + ) + span(event) + + else: + raise error + + return result + + return wrapper + + return handle_exception_inner From 5a6b0d231d6a7a5fe7f437e50baaeed619f29b59 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Sat, 22 Jun 2024 11:02:08 +0200 Subject: [PATCH 09/64] pylint fixups --- src/dbally/collection/collection.py | 65 +++++++++++++++++++++++++++-- src/dbally/collection/decorators.py | 46 ++++++++++++++++++-- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 62499f4c..82ad0ae3 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -182,7 +182,25 @@ def list(self) -> Dict[str, str]: } @handle_exception((UnsupportedQueryError, NoViewFoundError)) - async def _select_view(self, question, event_tracker, llm_options): + async def _select_view(self, question: str, event_tracker: EventTracker, llm_options: Optional[LLMOptions]) -> str: + """ + Selects a view from the collection based on the given question. + + This method retrieves the list of views from the collection and selects one based on the + specified `question`. If there is only one view, it selects that one. If there are multiple + views, it uses a view selector to determine the most appropriate view. + + Args: + question: The question to be asked. + event_tracker: An instance of the event tracker to record events. + llm_options: Options for the language model. + + Returns: + The name of the selected view. + + Raises: + ValueError: If the collection of views is empty. + """ views = self.list() if len(views) == 0: raise ValueError("Empty collection") @@ -199,8 +217,49 @@ async def _select_view(self, question, event_tracker, llm_options): @handle_exception((UnsupportedQueryError, NoViewFoundError)) async def _ask_view( - self, selected_view_name, question, event_tracker, dry_run, llm_options, return_natural_response, start_time - ): + self, + selected_view_name: str, + question: str, + event_tracker: EventTracker, + dry_run: bool, + llm_options: Optional[LLMOptions], + return_natural_response: bool, + start_time: float, + ) -> ExecutionResult: + """ + Executes a query on the selected view and processes the result. + + This method performs the query on the selected view and measures the execution time. It also + optionally generates a natural language response if `return_natural_response` is True and + `dry_run` is False. + + Args: + selected_view_name: The name of the selected view. + question: The query to be executed. + event_tracker: An instance of the event tracker to record events. + dry_run: Whether to perform a dry run without executing the actual query. + llm_options: Options for the language model. + return_natural_response: Whether to return a natural response. + start_time: The start time of the execution. 
+ + Returns: + ExecutionResult: An object containing the results, context, execution time, view execution + time, view name, and optionally a textual response. + + Example: + result = await self._ask_view( + selected_view_name="example_view", + question="What is the capital of France?", + event_tracker=my_event_tracker, + dry_run=False, + llm_options={"option1": "value1"}, + return_natural_response=True, + start_time=time.monotonic() + ) + + Raises: + KeyError: If the specified view does not exist in the collection. + """ selected_view = self.get(selected_view_name) start_time_view = time.monotonic() view_result = await selected_view.ask( diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py index 78f2f402..2ce5a876 100644 --- a/src/dbally/collection/decorators.py +++ b/src/dbally/collection/decorators.py @@ -1,9 +1,49 @@ +from typing import Callable + from dbally.audit.events import FallbackEvent -def handle_exception(handle_exception_list): - def handle_exception_inner(func): - async def wrapper(self, **kwargs): # pylint: disable=missing-return-doc +# pylint: disable=protected-access +def handle_exception(handle_exception_list) -> Callable: + """ + Decorator to handle specified exceptions during the execution of an asynchronous function. + + This decorator is designed to be used with class methods, and it handles exceptions specified + in `handle_exception_list`. If an exception occurs and a fallback collection is available, + it will attempt to perform the same operation using the fallback collection. + + Args: + handle_exception_list (tuple): A tuple of exception classes that should be handled + by this decorator. + + Returns: + function: A wrapper function that handles the specified exceptions and attempts a fallback + operation if applicable. + + Example: + @handle_exception((SomeException, AnotherException)) + async def some_method(self, **kwargs): + # method implementation + + The decorated method can expect the following keyword arguments in `kwargs`: + - question (str): The question to be asked. + - dry_run (bool): Whether to perform a dry run. + - return_natural_response (bool): Whether to return a natural response. + - llm_options (dict): Options for the language model. + - selected_view_name (str): The name of the selected view. + - event_tracker (EventTracker): An event tracker instance. + + If an exception is caught and a fallback collection is available, an event of type + `FallbackEvent` will be tracked, and the fallback collection's `ask` method will be called + with the same arguments. + + Raises: + Exception: If an exception in `handle_exception_list` occurs and no fallback collection is + available, the original exception is re-raised. 
+ """ + + def handle_exception_inner(func: Callable) -> Callable: # pylint: disable=missing-return-doc + async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return-doc try: result = await func(self, **kwargs) except handle_exception_list as error: From 0b9669ec3d61bc07b5ababf6835eb2ec7a930771 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Sat, 22 Jun 2024 14:01:05 +0200 Subject: [PATCH 10/64] adjustments --- src/dbally/collection/decorators.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py index 2ce5a876..f4ff3bcc 100644 --- a/src/dbally/collection/decorators.py +++ b/src/dbally/collection/decorators.py @@ -30,7 +30,6 @@ async def some_method(self, **kwargs): - dry_run (bool): Whether to perform a dry run. - return_natural_response (bool): Whether to return a natural response. - llm_options (dict): Options for the language model. - - selected_view_name (str): The name of the selected view. - event_tracker (EventTracker): An event tracker instance. If an exception is caught and a fallback collection is available, an event of type @@ -61,6 +60,10 @@ async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return fallback_collection_name=self._fallback_collection.name, error_description=repr(error), ) + if not self.fallback_collection_chain: + self.fallback_collection_chain = [] + else: + self._fallback_collection.append(self._fallback_collection) async with event_tracker.track_event(event) as span: result = await self._fallback_collection.ask( From 5cb4117e340af2369a1404ad152951aae776c885 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 09:52:53 +0200 Subject: [PATCH 11/64] Fallback monitor idea --- examples/visualize_fallback_code.py | 5 +- src/dbally/_main.py | 4 ++ src/dbally/collection/collection.py | 73 +++++++---------------- src/dbally/collection/decorators.py | 7 +-- src/dbally/collection/fallback_monitor.py | 21 +++++++ src/dbally/gradio/gradio_interface.py | 2 + 6 files changed, 57 insertions(+), 55 deletions(-) create mode 100644 src/dbally/collection/fallback_monitor.py diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index 6a5848e4..d4843c32 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -1,6 +1,7 @@ # pylint: disable=missing-function-docstring import asyncio +from dbally.collection.fallback_monitor import FallbackMonitor from recruiting import candidate_view_with_similarity_store, candidates_freeform from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity from recruiting.candidates_freeform import CandidateFreeformView @@ -15,7 +16,9 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - collection1 = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) + collection1 = dbally.create_collection( + "candidates", llm, event_handlers=[CLIEventHandler()], fallback_monitor=FallbackMonitor() + ) collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[]) collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/_main.py b/src/dbally/_main.py index ab26c4b7..3d085653 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -2,6 +2,7 
@@ from .audit.event_handlers.base import EventHandler from .collection import Collection +from .collection.fallback_monitor import FallbackMonitor from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector @@ -15,6 +16,7 @@ def create_collection( view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, fallback_collection: Optional[Collection] = None, + fallback_monitor: Optional[FallbackMonitor] = None, ) -> Collection: """ Create a new [Collection](collection.md) that is a container for registering views and the\ @@ -34,6 +36,7 @@ def create_collection( ``` Args: + fallback_monitor: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. @@ -64,4 +67,5 @@ def create_collection( llm=llm, event_handlers=event_handlers, fallback_collection=fallback_collection, + fallback_monitor=fallback_monitor, ) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 82ad0ae3..be386315 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,6 +10,7 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.decorators import handle_exception from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError +from dbally.collection.fallback_monitor import FallbackMonitor from dbally.collection.results import ExecutionResult from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.llms.base import LLM @@ -38,6 +39,7 @@ def __init__( nl_responder: NLResponder, n_retries: int = 3, fallback_collection: Optional["Collection"] = None, + fallback_monitor: Optional[FallbackMonitor] = None, ) -> None: """ Args: @@ -66,6 +68,7 @@ def __init__( self._event_handlers = event_handlers self._llm = llm self._fallback_collection: Optional[Collection] = fallback_collection + self._fallback_monitor = fallback_monitor T = TypeVar("T", bound=BaseView) @@ -182,43 +185,8 @@ def list(self) -> Dict[str, str]: } @handle_exception((UnsupportedQueryError, NoViewFoundError)) - async def _select_view(self, question: str, event_tracker: EventTracker, llm_options: Optional[LLMOptions]) -> str: - """ - Selects a view from the collection based on the given question. - - This method retrieves the list of views from the collection and selects one based on the - specified `question`. If there is only one view, it selects that one. If there are multiple - views, it uses a view selector to determine the most appropriate view. - - Args: - question: The question to be asked. - event_tracker: An instance of the event tracker to record events. - llm_options: Options for the language model. - - Returns: - The name of the selected view. - - Raises: - ValueError: If the collection of views is empty. 
- """ - views = self.list() - if len(views) == 0: - raise ValueError("Empty collection") - if len(views) == 1: - selected_view_name = next(iter(views)) - else: - selected_view_name = await self._view_selector.select_view( - question=question, - views=views, - event_tracker=event_tracker, - llm_options=llm_options, - ) - return selected_view_name - - @handle_exception((UnsupportedQueryError, NoViewFoundError)) - async def _ask_view( + async def _ask_question( self, - selected_view_name: str, question: str, event_tracker: EventTracker, dry_run: bool, @@ -227,20 +195,18 @@ async def _ask_view( start_time: float, ) -> ExecutionResult: """ - Executes a query on the selected view and processes the result. + Find matching view and executes a query on the view and processes the result. This method performs the query on the selected view and measures the execution time. It also optionally generates a natural language response if `return_natural_response` is True and `dry_run` is False. Args: - selected_view_name: The name of the selected view. question: The query to be executed. event_tracker: An instance of the event tracker to record events. dry_run: Whether to perform a dry run without executing the actual query. llm_options: Options for the language model. return_natural_response: Whether to return a natural response. - start_time: The start time of the execution. Returns: ExecutionResult: An object containing the results, context, execution time, view execution @@ -253,13 +219,25 @@ async def _ask_view( event_tracker=my_event_tracker, dry_run=False, llm_options={"option1": "value1"}, - return_natural_response=True, - start_time=time.monotonic() - ) + return_natural_response=True) Raises: KeyError: If the specified view does not exist in the collection. """ + + views = self.list() + if len(views) == 0: + raise ValueError("Empty collection") + if len(views) == 1: + selected_view_name = next(iter(views)) + else: + selected_view_name = await self._view_selector.select_view( + question=question, + views=views, + event_tracker=event_tracker, + llm_options=llm_options, + ) + selected_view = self.get(selected_view_name) start_time_view = time.monotonic() view_result = await selected_view.ask( @@ -325,24 +303,19 @@ async def ask( IQLError: if incorrect IQL was generated `n_retries` amount of times. ValueError: if incorrect IQL was generated `n_retries` amount of times. 
""" - start_time = time.monotonic() event_tracker = EventTracker.initialize_with_handlers(self._event_handlers) await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) + start_time = time.monotonic() - selected_view_name = await self._select_view( - question=question, event_tracker=event_tracker, llm_options=llm_options - ) - - result = await self._ask_view( - selected_view_name=selected_view_name, + result = await self._ask_question( question=question, + start_time=start_time, event_tracker=event_tracker, dry_run=dry_run, llm_options=llm_options, return_natural_response=return_natural_response, - start_time=start_time, ) await event_tracker.request_end(RequestEnd(result=result)) diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py index f4ff3bcc..36c44207 100644 --- a/src/dbally/collection/decorators.py +++ b/src/dbally/collection/decorators.py @@ -52,6 +52,7 @@ async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return llm_options = kwargs.get("llm_options") selected_view_name = str(kwargs.get("selected_view_name")) event_tracker = kwargs.get("event_tracker") + start_time = kwargs.get("start_time") if self._fallback_collection: event = FallbackEvent( @@ -60,10 +61,8 @@ async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return fallback_collection_name=self._fallback_collection.name, error_description=repr(error), ) - if not self.fallback_collection_chain: - self.fallback_collection_chain = [] - else: - self._fallback_collection.append(self._fallback_collection) + if self._fallback_monitor: + self._fallback_monitor.add_fallback_event(question, start_time, event) async with event_tracker.track_event(event) as span: result = await self._fallback_collection.ask( diff --git a/src/dbally/collection/fallback_monitor.py b/src/dbally/collection/fallback_monitor.py new file mode 100644 index 00000000..15fdd467 --- /dev/null +++ b/src/dbally/collection/fallback_monitor.py @@ -0,0 +1,21 @@ +from typing import Tuple, List, Dict + +from dbally.audit.events import FallbackEvent + + +class FallbackMonitor: + + def __init__(self): + self.fallback_log: Dict[Tuple[str, int], List[FallbackEvent]] = {} + + def add_fallback_event(self, question: str, start_time: int, fallback_event: FallbackEvent): + if not self.fallback_log.get((question, start_time)): + self.fallback_log[(question, start_time)] = [fallback_event] + else: + self.fallback_log[(question, start_time)].append(fallback_event) + + def number_of_fallback_queries(self) -> int: + return len(self.fallback_log) + + def __str__(self): + return f"Called fallbacks: {self.fallback_log}" diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 8902d315..327206be 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -136,6 +136,8 @@ async def _ui_ask_query( log_content = self.log.read() gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results", "No matching results found") + if self.collection._fallback_monitor and self.collection._fallback_monitor.number_of_fallback_queries() > 0: + print(self.collection._fallback_monitor) return ( gradio_dataframe, empty_dataframe_warning, From c864f18b472dfd45997609d2f6f2562fcf87380c Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 10:29:45 +0200 Subject: [PATCH 12/64] remove fallback monitor --- examples/visualize_fallback_code.py | 5 +---- src/dbally/_main.py | 4 ---- 
src/dbally/collection/collection.py | 3 --- src/dbally/collection/decorators.py | 3 --- src/dbally/collection/fallback_monitor.py | 21 --------------------- src/dbally/gradio/gradio_interface.py | 2 -- 6 files changed, 1 insertion(+), 37 deletions(-) delete mode 100644 src/dbally/collection/fallback_monitor.py diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index d4843c32..6a5848e4 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -1,7 +1,6 @@ # pylint: disable=missing-function-docstring import asyncio -from dbally.collection.fallback_monitor import FallbackMonitor from recruiting import candidate_view_with_similarity_store, candidates_freeform from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity from recruiting.candidates_freeform import CandidateFreeformView @@ -16,9 +15,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - collection1 = dbally.create_collection( - "candidates", llm, event_handlers=[CLIEventHandler()], fallback_monitor=FallbackMonitor() - ) + collection1 = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[]) collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 3d085653..ab26c4b7 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -2,7 +2,6 @@ from .audit.event_handlers.base import EventHandler from .collection import Collection -from .collection.fallback_monitor import FallbackMonitor from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector @@ -16,7 +15,6 @@ def create_collection( view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, fallback_collection: Optional[Collection] = None, - fallback_monitor: Optional[FallbackMonitor] = None, ) -> Collection: """ Create a new [Collection](collection.md) that is a container for registering views and the\ @@ -36,7 +34,6 @@ def create_collection( ``` Args: - fallback_monitor: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. 
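With the monitor gone, fallback behaviour is configured solely through the `fallback_collection` argument that `create_collection` keeps. A minimal sketch of that wiring, assuming the LiteLLM model name used elsewhere in this series and illustrative collection names:

```python
# Sketch only: the collection names and the model are illustrative, not part of this patch.
import dbally
from dbally.llms.litellm import LiteLLM

llm = LiteLLM(model_name="gpt-3.5-turbo")

# The fallback collection is asked the same question whenever the primary one
# fails with one of the handled exceptions.
freeform = dbally.create_collection("freeform candidates", llm)
primary = dbally.create_collection("candidates", llm, fallback_collection=freeform)
```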
@@ -67,5 +64,4 @@ def create_collection( llm=llm, event_handlers=event_handlers, fallback_collection=fallback_collection, - fallback_monitor=fallback_monitor, ) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index be386315..03be5e07 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,7 +10,6 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.decorators import handle_exception from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError -from dbally.collection.fallback_monitor import FallbackMonitor from dbally.collection.results import ExecutionResult from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.llms.base import LLM @@ -39,7 +38,6 @@ def __init__( nl_responder: NLResponder, n_retries: int = 3, fallback_collection: Optional["Collection"] = None, - fallback_monitor: Optional[FallbackMonitor] = None, ) -> None: """ Args: @@ -68,7 +66,6 @@ def __init__( self._event_handlers = event_handlers self._llm = llm self._fallback_collection: Optional[Collection] = fallback_collection - self._fallback_monitor = fallback_monitor T = TypeVar("T", bound=BaseView) diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py index 36c44207..913ae73a 100644 --- a/src/dbally/collection/decorators.py +++ b/src/dbally/collection/decorators.py @@ -52,7 +52,6 @@ async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return llm_options = kwargs.get("llm_options") selected_view_name = str(kwargs.get("selected_view_name")) event_tracker = kwargs.get("event_tracker") - start_time = kwargs.get("start_time") if self._fallback_collection: event = FallbackEvent( @@ -61,8 +60,6 @@ async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return fallback_collection_name=self._fallback_collection.name, error_description=repr(error), ) - if self._fallback_monitor: - self._fallback_monitor.add_fallback_event(question, start_time, event) async with event_tracker.track_event(event) as span: result = await self._fallback_collection.ask( diff --git a/src/dbally/collection/fallback_monitor.py b/src/dbally/collection/fallback_monitor.py deleted file mode 100644 index 15fdd467..00000000 --- a/src/dbally/collection/fallback_monitor.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Tuple, List, Dict - -from dbally.audit.events import FallbackEvent - - -class FallbackMonitor: - - def __init__(self): - self.fallback_log: Dict[Tuple[str, int], List[FallbackEvent]] = {} - - def add_fallback_event(self, question: str, start_time: int, fallback_event: FallbackEvent): - if not self.fallback_log.get((question, start_time)): - self.fallback_log[(question, start_time)] = [fallback_event] - else: - self.fallback_log[(question, start_time)].append(fallback_event) - - def number_of_fallback_queries(self) -> int: - return len(self.fallback_log) - - def __str__(self): - return f"Called fallbacks: {self.fallback_log}" diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 327206be..8902d315 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -136,8 +136,6 @@ async def _ui_ask_query( log_content = self.log.read() gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results", "No matching results found") - if self.collection._fallback_monitor and self.collection._fallback_monitor.number_of_fallback_queries() > 0: 
- print(self.collection._fallback_monitor) return ( gradio_dataframe, empty_dataframe_warning, From e6f1dbd10848e4db022930d44d686795c87b8045 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 14:45:43 +0200 Subject: [PATCH 13/64] decorator clean up --- src/dbally/collection/collection.py | 156 ++++++++++++++++------------ src/dbally/collection/decorators.py | 80 -------------- 2 files changed, 88 insertions(+), 148 deletions(-) delete mode 100644 src/dbally/collection/decorators.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 03be5e07..2cd3115b 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -7,8 +7,7 @@ from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker -from dbally.audit.events import RequestEnd, RequestStart -from dbally.collection.decorators import handle_exception +from dbally.audit.events import RequestEnd, RequestStart, FallbackEvent from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError @@ -181,46 +180,12 @@ def list(self) -> Dict[str, str]: name: (textwrap.dedent(view.__doc__).strip() if view.__doc__ else "") for name, view in self._views.items() } - @handle_exception((UnsupportedQueryError, NoViewFoundError)) - async def _ask_question( + async def _select_view( self, question: str, event_tracker: EventTracker, - dry_run: bool, llm_options: Optional[LLMOptions], - return_natural_response: bool, - start_time: float, - ) -> ExecutionResult: - """ - Find matching view and executes a query on the view and processes the result. - - This method performs the query on the selected view and measures the execution time. It also - optionally generates a natural language response if `return_natural_response` is True and - `dry_run` is False. - - Args: - question: The query to be executed. - event_tracker: An instance of the event tracker to record events. - dry_run: Whether to perform a dry run without executing the actual query. - llm_options: Options for the language model. - return_natural_response: Whether to return a natural response. - - Returns: - ExecutionResult: An object containing the results, context, execution time, view execution - time, view name, and optionally a textual response. - - Example: - result = await self._ask_view( - selected_view_name="example_view", - question="What is the capital of France?", - event_tracker=my_event_tracker, - dry_run=False, - llm_options={"option1": "value1"}, - return_natural_response=True) - - Raises: - KeyError: If the specified view does not exist in the collection. 
- """ + ) -> str: views = self.list() if len(views) == 0: @@ -234,9 +199,10 @@ async def _ask_question( event_tracker=event_tracker, llm_options=llm_options, ) + return selected_view_name + async def _ask_view(self, selected_view_name, question, event_tracker, llm_options, dry_run): selected_view = self.get(selected_view_name) - start_time_view = time.monotonic() view_result = await selected_view.ask( query=question, llm=self._llm, @@ -245,26 +211,55 @@ async def _ask_question( dry_run=dry_run, llm_options=llm_options, ) - end_time_view = time.monotonic() + return view_result - textual_response = None - if not dry_run and return_natural_response: - textual_response = await self._nl_responder.generate_response( - result=view_result, - question=question, - event_tracker=event_tracker, - llm_options=llm_options, + async def _generate_textual_response( + self, + view_result, + question, + event_tracker, + llm_options, + ): + textual_response = await self._nl_responder.generate_response( + result=view_result, + question=question, + event_tracker=event_tracker, + llm_options=llm_options, + ) + return textual_response + + async def _handle_fallback( + self, + question, + dry_run, + return_natural_response, + llm_options, + selected_view_name, + event_tracker, + caught_exception, + ): + + if self._fallback_collection: + + event = FallbackEvent( + triggering_collection_name=self.name, + triggering_view_name=selected_view_name, + fallback_collection_name=self._fallback_collection.name, + error_description=repr(caught_exception), ) - result = ExecutionResult( - results=view_result.results, - context=view_result.context, - execution_time=time.monotonic() - start_time, - execution_time_view=end_time_view - start_time_view, - view_name=selected_view_name, - textual_response=textual_response, - ) - return result + async with event_tracker.track_event(event) as span: + result = await self._fallback_collection.ask( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + ) + span(event) + return result + + else: + raise caught_exception async def ask( self, @@ -300,22 +295,47 @@ async def ask( IQLError: if incorrect IQL was generated `n_retries` amount of times. ValueError: if incorrect IQL was generated `n_retries` amount of times. 
""" - + handle_exceptions = (NoViewFoundError, UnsupportedQueryError, IndexUpdateError) event_tracker = EventTracker.initialize_with_handlers(self._event_handlers) - await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) - start_time = time.monotonic() + selected_view_name = "" - result = await self._ask_question( - question=question, - start_time=start_time, - event_tracker=event_tracker, - dry_run=dry_run, - llm_options=llm_options, - return_natural_response=return_natural_response, - ) + try: + + await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) + + start_time = time.monotonic() + selected_view_name = await self._select_view(question, event_tracker, llm_options) + + start_time_view = time.monotonic() + view_result = await self._ask_view(selected_view_name, question, event_tracker, llm_options, dry_run) + end_time_view = time.monotonic() - await event_tracker.request_end(RequestEnd(result=result)) + natural_response = ( + self._generate_textual_response(view_result, question, event_tracker, llm_options) + if not dry_run and return_natural_response + else "" + ) + result = ExecutionResult( + results=view_result.results, + context=view_result.context, + execution_time=time.monotonic() - start_time, + execution_time_view=end_time_view - start_time_view, + view_name=selected_view_name, + textual_response=natural_response, + ) + + except handle_exceptions as caught_exception: + result = await self._handle_fallback( + question, + dry_run, + return_natural_response, + llm_options, + selected_view_name, + event_tracker, + caught_exception, + ) + await event_tracker.request_end(RequestEnd(result=result)) return result diff --git a/src/dbally/collection/decorators.py b/src/dbally/collection/decorators.py deleted file mode 100644 index 913ae73a..00000000 --- a/src/dbally/collection/decorators.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Callable - -from dbally.audit.events import FallbackEvent - - -# pylint: disable=protected-access -def handle_exception(handle_exception_list) -> Callable: - """ - Decorator to handle specified exceptions during the execution of an asynchronous function. - - This decorator is designed to be used with class methods, and it handles exceptions specified - in `handle_exception_list`. If an exception occurs and a fallback collection is available, - it will attempt to perform the same operation using the fallback collection. - - Args: - handle_exception_list (tuple): A tuple of exception classes that should be handled - by this decorator. - - Returns: - function: A wrapper function that handles the specified exceptions and attempts a fallback - operation if applicable. - - Example: - @handle_exception((SomeException, AnotherException)) - async def some_method(self, **kwargs): - # method implementation - - The decorated method can expect the following keyword arguments in `kwargs`: - - question (str): The question to be asked. - - dry_run (bool): Whether to perform a dry run. - - return_natural_response (bool): Whether to return a natural response. - - llm_options (dict): Options for the language model. - - event_tracker (EventTracker): An event tracker instance. - - If an exception is caught and a fallback collection is available, an event of type - `FallbackEvent` will be tracked, and the fallback collection's `ask` method will be called - with the same arguments. 
- - Raises: - Exception: If an exception in `handle_exception_list` occurs and no fallback collection is - available, the original exception is re-raised. - """ - - def handle_exception_inner(func: Callable) -> Callable: # pylint: disable=missing-return-doc - async def wrapper(self, **kwargs) -> Callable: # pylint: disable=missing-return-doc - try: - result = await func(self, **kwargs) - except handle_exception_list as error: - question = kwargs.get("question") - dry_run = kwargs.get("dry_run") - return_natural_response = kwargs.get("return_natural_response") - llm_options = kwargs.get("llm_options") - selected_view_name = str(kwargs.get("selected_view_name")) - event_tracker = kwargs.get("event_tracker") - - if self._fallback_collection: - event = FallbackEvent( - triggering_collection_name=self.name, - triggering_view_name=selected_view_name, - fallback_collection_name=self._fallback_collection.name, - error_description=repr(error), - ) - - async with event_tracker.track_event(event) as span: - result = await self._fallback_collection.ask( - question=question, - dry_run=dry_run, - return_natural_response=return_natural_response, - llm_options=llm_options, - ) - span(event) - - else: - raise error - - return result - - return wrapper - - return handle_exception_inner From 97689a341874e2e6f851f3b6c61934bda327b3f9 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 14:54:39 +0200 Subject: [PATCH 14/64] add docstrings --- src/dbally/collection/collection.py | 63 +++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 2cd3115b..6fc63935 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -9,7 +9,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart, FallbackEvent from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError -from dbally.collection.results import ExecutionResult +from dbally.collection.results import ExecutionResult, ViewExecutionResult from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions @@ -186,6 +186,23 @@ async def _select_view( event_tracker: EventTracker, llm_options: Optional[LLMOptions], ) -> str: + """ + Select a view based on the provided question and options. + + If there is only one view available, it selects that view directly. Otherwise, it + uses the view selector to choose the most appropriate view. + + Args: + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. + llm_options: Options for the LLM client. + + Returns: + str: The name of the selected view. + + Raises: + ValueError: If the collection of views is empty. + """ views = self.list() if len(views) == 0: @@ -201,7 +218,27 @@ async def _select_view( ) return selected_view_name - async def _ask_view(self, selected_view_name, question, event_tracker, llm_options, dry_run): + async def _ask_view( + self, + selected_view_name: str, + question: str, + event_tracker: EventTracker, + llm_options: Optional[LLMOptions], + dry_run: bool, + ): + """ + Ask the selected view to provide an answer to the question. + + Args: + selected_view_name: The name of the selected view. + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. 
+ llm_options: Options for the LLM client. + dry_run: If True, only generate the query without executing it. + + Returns: + Any: The result from the selected view. + """ selected_view = self.get(selected_view_name) view_result = await selected_view.ask( query=question, @@ -215,11 +252,23 @@ async def _ask_view(self, selected_view_name, question, event_tracker, llm_optio async def _generate_textual_response( self, - view_result, - question, - event_tracker, - llm_options, - ): + view_result: ViewExecutionResult, + question: str, + event_tracker: EventTracker, + llm_options: Optional[LLMOptions], + ) -> str: + """ + Generate a textual response from the view result. + + Args: + view_result: The result from the view. + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. + llm_options: Options for the LLM client. + + Returns: + The generated textual response. + """ textual_response = await self._nl_responder.generate_response( result=view_result, question=question, From ee50bc784ac85e51b6d1f5294ba3a3a82f837a86 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 14:56:49 +0200 Subject: [PATCH 15/64] fix docstrings --- src/dbally/collection/collection.py | 33 +++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 6fc63935..7e314bc0 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar +from dbally import DbAllyError from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart, FallbackEvent @@ -279,14 +280,32 @@ async def _generate_textual_response( async def _handle_fallback( self, - question, - dry_run, - return_natural_response, - llm_options, - selected_view_name, - event_tracker, - caught_exception, + question: str, + dry_run: bool, + return_natural_response: bool, + llm_options: Optional[LLMOptions], + selected_view_name: str, + event_tracker: EventTracker, + caught_exception: DbAllyError, ): + """ + Handle fallback if the main query fails. + + Args: + question: The question to be answered. + dry_run: If True, only generate the query without executing it. + return_natural_response: If True, return the natural language response. + llm_options: Options for the LLM client. + selected_view_name: The name of the selected view. + event_tracker: The event tracker for logging and tracking events. + caught_exception: The exception that was caught. + + Returns: + Any: The result from the fallback collection. + + Raises: + Exception: If there is no fallback collection or if an error occurs in the fallback. 
+ """ if self._fallback_collection: From eddd20338a657df3ccfc4ff9e0ddb6a261438c43 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 18:17:48 +0200 Subject: [PATCH 16/64] isort fix --- src/dbally/collection/collection.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 7e314bc0..57fbf15f 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -8,7 +8,7 @@ from dbally import DbAllyError from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker -from dbally.audit.events import RequestEnd, RequestStart, FallbackEvent +from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError @@ -308,7 +308,6 @@ async def _handle_fallback( """ if self._fallback_collection: - event = FallbackEvent( triggering_collection_name=self.name, triggering_view_name=selected_view_name, @@ -369,7 +368,6 @@ async def ask( selected_view_name = "" try: - await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) start_time = time.monotonic() @@ -393,6 +391,28 @@ async def ask( textual_response=natural_response, ) + + + + + + + + + + + + + + + + + + + + + + except handle_exceptions as caught_exception: result = await self._handle_fallback( question, From d2d7acf755ca3ebe562ad44d3958db306225601e Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 24 Jun 2024 20:18:55 +0200 Subject: [PATCH 17/64] Feat: Implement global event handlers --- docs/how-to/visualize_views.md | 1 + examples/visualize_views_code.py | 5 +- src/dbally/__init__.py | 17 +++--- src/dbally/_main.py | 53 +++++++++++++++++-- .../audit/event_handlers/cli_event_handler.py | 3 +- src/dbally/collection/collection.py | 22 +++----- src/dbally/gradio/__init__.py | 6 ++- src/dbally/gradio/gradio_interface.py | 3 +- 8 files changed, 80 insertions(+), 30 deletions(-) diff --git a/docs/how-to/visualize_views.md b/docs/how-to/visualize_views.md index 6553e07b..45cd3066 100644 --- a/docs/how-to/visualize_views.md +++ b/docs/how-to/visualize_views.md @@ -19,6 +19,7 @@ Define collection with implemented views ```python llm = LiteLLM(model_name="gpt-3.5-turbo") await country_similarity.update() +dbally.add_event_handler(CLIEventHandler(gradio_buffer)) collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 31806fe8..42b7aec4 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -6,14 +6,15 @@ import dbally from dbally.audit import CLIEventHandler -from dbally.gradio import create_gradio_interface +from dbally.gradio import create_gradio_interface, gradio_buffer from dbally.llms.litellm import LiteLLM async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - collection = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) + dbally.add_event_handler(CLIEventHandler(gradio_buffer)) + 
collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) gradio_interface = await create_gradio_interface(user_collection=collection) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 15caf8cc..f9ee37f7 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -10,7 +10,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import create_collection +from ._main import add_event_handler, create_collection, event_handlers, set_event_handlers from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, @@ -23,27 +23,30 @@ __all__ = [ "__version__", + "event_handlers", + "add_event_handler", "create_collection", + "set_event_handlers", "decorators", - "MethodsBaseView", - "SqlAlchemyBaseView", - "Collection", "BaseStructuredView", + "Collection", "DataFrameBaseView", - "ExecutionResult", "DbAllyError", + "ExecutionResult", "EmbeddingError", "EmbeddingConnectionError", "EmbeddingResponseError", "EmbeddingStatusError", + "IndexUpdateError", "LLMError", "LLMConnectionError", "LLMResponseError", "LLMStatusError", - "NoViewFoundError", - "IndexUpdateError", + "MethodsBaseView", "NotGiven", "NOT_GIVEN", + "NoViewFoundError", + "SqlAlchemyBaseView", ] # Update the __module__ attribute for exported symbols so that diff --git a/src/dbally/_main.py b/src/dbally/_main.py index eeb2d836..e2814daf 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,5 +1,6 @@ from typing import List, Optional +from .audit import CLIEventHandler from .audit.event_handlers.base import EventHandler from .collection import Collection from .llms import LLM @@ -7,11 +8,54 @@ from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector +# Global list of event handlers initialized with a default CLIEventHandler. +event_handlers: List[EventHandler] = [CLIEventHandler()] + + +def set_event_handlers(event_handler_list: List[EventHandler]) -> None: + """ + Set the global list of event handlers. + + This function replaces the current list of event handlers with the provided list. + It ensures that each handler in the provided list is an instance of EventHandler. + If any handler is not an instance of EventHandler, it raises a ValueError. + + Args: + event_handler_list (List[EventHandler]): The list of event handlers to set. + + Raises: + ValueError: If any handler in the list is not an instance of EventHandler. + """ + for handler in event_handler_list: + if isinstance(type(handler), EventHandler): + raise ValueError(f"{handler} is not an instance of EventHandler") + global event_handlers # pylint: disable=global-statement + event_handlers = event_handler_list + + +def add_event_handler(event_handler: EventHandler) -> None: + """ + Add an event handler to the global list. + + This function appends the provided event handler to the global list of event handlers. + It ensures that the handler is an instance of EventHandler. + If the handler is not an instance of EventHandler, it raises a ValueError. + + Args: + event_handler (EventHandler): The event handler to add. + + Raises: + ValueError: If the handler is not an instance of EventHandler. 
+ """ + if isinstance(type(event_handler), EventHandler): + raise ValueError(f"{event_handler} is not an instance of EventHandler") + event_handlers.append(event_handler) + def create_collection( name: str, llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, + collection_event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: @@ -36,7 +80,7 @@ def create_collection( name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used to\ + collection_event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ [LangSmithEventHandler](event_handlers/langsmith_handler.md). view_selector: View selector used by the collection to select the best view for the given query.\ @@ -53,12 +97,13 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or [] + + collection_event_handlers = collection_event_handlers or event_handlers return Collection( name, nl_responder=nl_responder, view_selector=view_selector, llm=llm, - event_handlers=event_handlers, + collection_event_handlers=collection_event_handlers, ) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index aa48e049..13e4d848 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -32,7 +32,8 @@ class CLIEventHandler(EventHandler): import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - my_collection = dbally.create_collection("my_collection", llm, event_handlers=[CLIEventHandler()]) + dbally.set_event_handlers([CLIEventHandler()]) + my_collection = dbally.create_collection("my_collection", llm) ``` After using `CLIEventHandler`, during every `Collection.ask` execution you will see output similar to the one below: diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 5c059fbc..fba7d763 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar +import dbally from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart @@ -32,7 +33,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - event_handlers: List[EventHandler], + collection_event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -40,11 +41,11 @@ def __init__( Args: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. 
- view_selector: As you register more then one [View](views/index.md) within single collection,\ + view_selector: As you register more than one [View](views/index.md) within single collection,\ before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used\ + collection_event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). nl_responder: Object that translates RAW response from db-ally into natural language. @@ -58,7 +59,9 @@ def __init__( self._builders: Dict[str, Callable[[], BaseView]] = {} self._view_selector = view_selector self._nl_responder = nl_responder - self._event_handlers = event_handlers + if collection_event_handlers != dbally.event_handlers: + print("WARNING Default event handler has been overwritten") + self._event_handlers = collection_event_handlers self._llm = llm T = TypeVar("T", bound=BaseView) @@ -68,7 +71,7 @@ def add(self, view: Type[T], builder: Optional[Callable[[], T]] = None, name: Op Register new [View](views/index.md) that will be available to query via the collection. Args: - view: A class inherithing from BaseView. Object of this type will be initialized during\ + view: A class inheriting from BaseView. Object of this type will be initialized during\ query execution. We expect Class instead of object, as otherwise Views must have been implemented\ stateless, which would be cumbersome. builder: Optional factory function that will be used to create the View instance. Use it when you\ @@ -111,15 +114,6 @@ def build_dogs_df_view(): self._views[name] = view self._builders[name] = builder - def add_event_handler(self, event_handler: EventHandler): - """ - Adds an event handler to the list of event handlers. - - Args: - event_handler: The event handler to be added. 
- """ - self._event_handlers.append(event_handler) - def get(self, name: str) -> BaseView: """ Returns an instance of the view with the given name diff --git a/src/dbally/gradio/__init__.py b/src/dbally/gradio/__init__.py index 41d84c3c..ca878ca7 100644 --- a/src/dbally/gradio/__init__.py +++ b/src/dbally/gradio/__init__.py @@ -1,3 +1,7 @@ +from io import StringIO + from dbally.gradio.gradio_interface import create_gradio_interface -__all__ = ["create_gradio_interface"] +gradio_buffer = StringIO() + +__all__ = ["create_gradio_interface", "gradio_buffer"] diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 30182b37..2609083a 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -5,6 +5,7 @@ import gradio import pandas as pd +import dbally from dbally import BaseStructuredView from dbally.audit import CLIEventHandler from dbally.collection import Collection @@ -177,7 +178,7 @@ async def create_interface(self, user_collection: Collection, preview_limit: int self.preview_limit = preview_limit self.collection = user_collection - self.collection.add_event_handler(CLIEventHandler(self.log)) + dbally.add_event_handler(CLIEventHandler(self.log)) data_preview_frame = pd.DataFrame() question_interactive = False From 434e37a69d00080c224f0d2954ec5aac61d49883 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 25 Jun 2024 10:49:47 +0200 Subject: [PATCH 18/64] enhancments --- docs/how-to/visualize_views.md | 4 +-- examples/visualize_views_code.py | 6 ++-- src/dbally/__init__.py | 3 +- src/dbally/_main.py | 19 ++++++++++++- .../event_handlers/buffer_event_handler.py | 28 +++++++++++++++++++ .../audit/event_handlers/cli_event_handler.py | 11 ++------ src/dbally/gradio/__init__.py | 6 +--- src/dbally/gradio/gradio_interface.py | 19 +++++++------ 8 files changed, 67 insertions(+), 29 deletions(-) create mode 100644 src/dbally/audit/event_handlers/buffer_event_handler.py diff --git a/docs/how-to/visualize_views.md b/docs/how-to/visualize_views.md index 45cd3066..90291598 100644 --- a/docs/how-to/visualize_views.md +++ b/docs/how-to/visualize_views.md @@ -19,8 +19,8 @@ Define collection with implemented views ```python llm = LiteLLM(model_name="gpt-3.5-turbo") await country_similarity.update() -dbally.add_event_handler(CLIEventHandler(gradio_buffer)) -collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) +dbally.add_event_handler(BufferEventHandler()) +collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) ``` diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 42b7aec4..17b2cad4 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -5,15 +5,15 @@ from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine import dbally -from dbally.audit import CLIEventHandler -from dbally.gradio import create_gradio_interface, gradio_buffer +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.add_event_handler(CLIEventHandler(gradio_buffer)) + dbally.add_event_handler(BufferEventHandler()) 
collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index f9ee37f7..fa7177b7 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -10,7 +10,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import add_event_handler, create_collection, event_handlers, set_event_handlers +from ._main import add_event_handler, create_collection, event_handlers, find_event_handler, set_event_handlers from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, @@ -26,6 +26,7 @@ "event_handlers", "add_event_handler", "create_collection", + "find_event_handler", "set_event_handlers", "decorators", "BaseStructuredView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index e2814daf..f8272297 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Type from .audit import CLIEventHandler from .audit.event_handlers.base import EventHandler @@ -52,6 +52,23 @@ def add_event_handler(event_handler: EventHandler) -> None: event_handlers.append(event_handler) +def find_event_handler(object_type: Type): + """ + Finds an event handler of the specified type from a list of event handlers. + + Args: + object_type (Type[Any]): The type of the event handler to find. + + Returns: + Optional[Any]: The first event handler of the specified type if found, + otherwise None. + """ + for event_handler in event_handlers: + if type(event_handler) is object_type: # pylint disable=unidiomatic-typecheck + return event_handler + return None + + def create_collection( name: str, llm: LLM, diff --git a/src/dbally/audit/event_handlers/buffer_event_handler.py b/src/dbally/audit/event_handlers/buffer_event_handler.py new file mode 100644 index 00000000..a4c0fcb5 --- /dev/null +++ b/src/dbally/audit/event_handlers/buffer_event_handler.py @@ -0,0 +1,28 @@ +from io import StringIO + +from rich.console import Console + +from dbally.audit import CLIEventHandler + + +class BufferEventHandler(CLIEventHandler): + """ + This handler stores in buffer all interactions between LLM and user happening during `Collection.ask`\ + execution. + + ### Usage + + ```python + import dbally + from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler + + dbally.set_event_handlers([BufferEventHandler()]) + my_collection = dbally.create_collection("my_collection", llm) + ``` + """ + + def __init__(self) -> None: + super().__init__() + + self.buffer = StringIO() + self._console = Console(file=self.buffer, record=True) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 13e4d848..883465dd 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -1,6 +1,4 @@ import re -from io import StringIO -from sys import stdout from typing import Optional try: @@ -24,7 +22,7 @@ class CLIEventHandler(EventHandler): """ This handler displays all interactions between LLM and user happening during `Collection.ask`\ - execution inside the terminal or store them in the given buffer. + execution inside the terminal. 
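A hedged sketch of how the two registration paths introduced by this patch combine — the module-level handler list versus a per-collection `collection_event_handlers` override; the model name and handler choice are assumptions:

```python
# Sketch under the patch-17 API: the global list is managed via dbally.add_event_handler /
# dbally.set_event_handlers, while `collection_event_handlers` overrides it per collection.
import dbally
from dbally.audit import CLIEventHandler
from dbally.llms.litellm import LiteLLM

llm = LiteLLM(model_name="gpt-3.5-turbo")  # assumed model

dbally.add_event_handler(CLIEventHandler())  # appended to the global defaults
uses_globals = dbally.create_collection("candidates", llm)

# An explicit list bypasses dbally.event_handlers for this collection only.
uses_own = dbally.create_collection(
    "freeform candidates", llm, collection_event_handlers=[CLIEventHandler()]
)
```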
### Usage @@ -41,12 +39,9 @@ class CLIEventHandler(EventHandler): ![Example output from CLIEventHandler](../../assets/event_handler_example.png) """ - def __init__(self, buffer: Optional[StringIO] = None) -> None: + def __init__(self) -> None: super().__init__() - - self.buffer = buffer - out = self.buffer if buffer else stdout - self._console = Console(file=out, record=True) if RICH_OUTPUT else None + self._console = Console(record=True) if RICH_OUTPUT else None def _print_syntax(self, content: str, lexer: Optional[str] = None) -> None: if self._console: diff --git a/src/dbally/gradio/__init__.py b/src/dbally/gradio/__init__.py index ca878ca7..41d84c3c 100644 --- a/src/dbally/gradio/__init__.py +++ b/src/dbally/gradio/__init__.py @@ -1,7 +1,3 @@ -from io import StringIO - from dbally.gradio.gradio_interface import create_gradio_interface -gradio_buffer = StringIO() - -__all__ = ["create_gradio_interface", "gradio_buffer"] +__all__ = ["create_gradio_interface"] diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 2609083a..b92bd8f2 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -1,5 +1,4 @@ import json -from io import StringIO from typing import Any, Dict, List, Tuple import gradio @@ -7,7 +6,7 @@ import dbally from dbally import BaseStructuredView -from dbally.audit import CLIEventHandler +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError @@ -42,12 +41,15 @@ def __init__(self): self.preview_limit = None self.selected_view_name = None self.collection = None - self.log = StringIO() - - def _load_gradio_data(self, preview_dataframe, label, empty_warning=None) -> Tuple[gradio.DataFrame, gradio.Label]: - if not empty_warning: - empty_warning = "Preview not available" + buffer_handler = dbally.find_event_handler(BufferEventHandler) + if not buffer_handler: + raise ValueError( + "Could not initialize gradio console. 
Missing buffer handler.\n" + "Add dbally.add_event_handler(BufferEventHandler()) to fix it" + ) + self.log = buffer_handler.buffer + def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: if preview_dataframe.empty: gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=False) empty_frame_label = gradio.Label(value=f"{label} not available", visible=True, show_label=False) @@ -131,7 +133,7 @@ async def _ui_ask_query( self.log.seek(0) log_content = self.log.read() - gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results", "No matching results found") + gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results") return ( gradio_dataframe, empty_dataframe_warning, @@ -178,7 +180,6 @@ async def create_interface(self, user_collection: Collection, preview_limit: int self.preview_limit = preview_limit self.collection = user_collection - dbally.add_event_handler(CLIEventHandler(self.log)) data_preview_frame = pd.DataFrame() question_interactive = False From c5b5f80f5dcf318980ef2a44c823cadd7dde95e0 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 25 Jun 2024 15:00:53 +0200 Subject: [PATCH 19/64] fix linters --- examples/freeform.py | 2 +- examples/recruiting.py | 2 +- src/dbally/_main.py | 13 ++++++------- src/dbally/gradio/gradio_interface.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/freeform.py b/examples/freeform.py index da3bea5e..3b3439c5 100644 --- a/examples/freeform.py +++ b/examples/freeform.py @@ -63,7 +63,7 @@ async def main(): connection.execute(sqlalchemy.text(table_config.ddl)) llm = LiteLLM() - collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) + collection = dbally.create_collection("text2sql", llm=llm, collection_event_handlers=[CLIEventHandler()]) collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) await collection.ask("What are the names of products bought by customers from London?") diff --git a/examples/recruiting.py b/examples/recruiting.py index a4813b41..0dc77f89 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -102,7 +102,7 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example recruitment_db = dbally.create_collection( "recruitment", llm=LiteLLM(), - event_handlers=[CLIEventHandler()], + collection_event_handlers=[CLIEventHandler()], ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) diff --git a/src/dbally/_main.py b/src/dbally/_main.py index f8272297..1353e8f8 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -52,20 +52,19 @@ def add_event_handler(event_handler: EventHandler) -> None: event_handlers.append(event_handler) -def find_event_handler(object_type: Type): +def find_event_handler(object_type: Type) -> Optional[EventHandler]: """ Finds an event handler of the specified type from a list of event handlers. Args: - object_type (Type[Any]): The type of the event handler to find. + object_type: The type of the event handler to find. Returns: - Optional[Any]: The first event handler of the specified type if found, - otherwise None. + The first event handler of the specified type if found, otherwise None. 
""" - for event_handler in event_handlers: - if type(event_handler) is object_type: # pylint disable=unidiomatic-typecheck - return event_handler + for single_event_handler in event_handlers: + if type(single_event_handler) is object_type: # pylint: disable=unidiomatic-typecheck + return single_event_handler return None diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index b92bd8f2..65b17e6d 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -47,7 +47,7 @@ def __init__(self): "Could not initialize gradio console. Missing buffer handler.\n" "Add dbally.add_event_handler(BufferEventHandler()) to fix it" ) - self.log = buffer_handler.buffer + self.log: BufferEventHandler = buffer_handler.buffer # pylint: disable=no-member def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: if preview_dataframe.empty: From e5a808bdfa50d6ecd277c98e21dfc2953a1e5eb4 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 25 Jun 2024 18:12:06 +0200 Subject: [PATCH 20/64] global variables moved to module --- examples/freeform.py | 2 +- examples/recruiting.py | 2 +- examples/visualize_views_code.py | 3 +- src/dbally/__init__.py | 6 +-- src/dbally/_main.py | 72 +++------------------------ src/dbally/collection/collection.py | 10 ++-- src/dbally/global_handlers.py | 63 +++++++++++++++++++++++ src/dbally/gradio/gradio_interface.py | 4 +- 8 files changed, 82 insertions(+), 80 deletions(-) create mode 100644 src/dbally/global_handlers.py diff --git a/examples/freeform.py b/examples/freeform.py index 3b3439c5..da3bea5e 100644 --- a/examples/freeform.py +++ b/examples/freeform.py @@ -63,7 +63,7 @@ async def main(): connection.execute(sqlalchemy.text(table_config.ddl)) llm = LiteLLM() - collection = dbally.create_collection("text2sql", llm=llm, collection_event_handlers=[CLIEventHandler()]) + collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) await collection.ask("What are the names of products bought by customers from London?") diff --git a/examples/recruiting.py b/examples/recruiting.py index 0dc77f89..a4813b41 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -102,7 +102,7 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example recruitment_db = dbally.create_collection( "recruitment", llm=LiteLLM(), - collection_event_handlers=[CLIEventHandler()], + event_handlers=[CLIEventHandler()], ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 17b2cad4..ef697123 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -6,6 +6,7 @@ import dbally from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.global_handlers import add_event_handler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM @@ -13,7 +14,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.add_event_handler(BufferEventHandler()) + add_event_handler(BufferEventHandler()) collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git 
a/src/dbally/__init__.py b/src/dbally/__init__.py index fa7177b7..f41e1866 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -10,7 +10,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import add_event_handler, create_collection, event_handlers, find_event_handler, set_event_handlers +from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, @@ -23,11 +23,7 @@ __all__ = [ "__version__", - "event_handlers", - "add_event_handler", "create_collection", - "find_event_handler", - "set_event_handlers", "decorators", "BaseStructuredView", "Collection", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 1353e8f8..8f55ae96 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,6 +1,7 @@ -from typing import List, Optional, Type +from typing import List, Optional + +from dbally.global_handlers import global_event_handlers_list -from .audit import CLIEventHandler from .audit.event_handlers.base import EventHandler from .collection import Collection from .llms import LLM @@ -8,70 +9,11 @@ from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector -# Global list of event handlers initialized with a default CLIEventHandler. -event_handlers: List[EventHandler] = [CLIEventHandler()] - - -def set_event_handlers(event_handler_list: List[EventHandler]) -> None: - """ - Set the global list of event handlers. - - This function replaces the current list of event handlers with the provided list. - It ensures that each handler in the provided list is an instance of EventHandler. - If any handler is not an instance of EventHandler, it raises a ValueError. - - Args: - event_handler_list (List[EventHandler]): The list of event handlers to set. - - Raises: - ValueError: If any handler in the list is not an instance of EventHandler. - """ - for handler in event_handler_list: - if isinstance(type(handler), EventHandler): - raise ValueError(f"{handler} is not an instance of EventHandler") - global event_handlers # pylint: disable=global-statement - event_handlers = event_handler_list - - -def add_event_handler(event_handler: EventHandler) -> None: - """ - Add an event handler to the global list. - - This function appends the provided event handler to the global list of event handlers. - It ensures that the handler is an instance of EventHandler. - If the handler is not an instance of EventHandler, it raises a ValueError. - - Args: - event_handler (EventHandler): The event handler to add. - - Raises: - ValueError: If the handler is not an instance of EventHandler. - """ - if isinstance(type(event_handler), EventHandler): - raise ValueError(f"{event_handler} is not an instance of EventHandler") - event_handlers.append(event_handler) - - -def find_event_handler(object_type: Type) -> Optional[EventHandler]: - """ - Finds an event handler of the specified type from a list of event handlers. - - Args: - object_type: The type of the event handler to find. - - Returns: - The first event handler of the specified type if found, otherwise None. 
- """ - for single_event_handler in event_handlers: - if type(single_event_handler) is object_type: # pylint: disable=unidiomatic-typecheck - return single_event_handler - return None - def create_collection( name: str, llm: LLM, - collection_event_handlers: Optional[List[EventHandler]] = None, + event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: @@ -96,7 +38,7 @@ def create_collection( name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. - collection_event_handlers: Event handlers used by the collection during query executions. Can be used to\ + event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ [LangSmithEventHandler](event_handlers/langsmith_handler.md). view_selector: View selector used by the collection to select the best view for the given query.\ @@ -114,12 +56,12 @@ def create_collection( view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - collection_event_handlers = collection_event_handlers or event_handlers + event_handlers = event_handlers or global_event_handlers_list return Collection( name, nl_responder=nl_responder, view_selector=view_selector, llm=llm, - collection_event_handlers=collection_event_handlers, + event_handlers=event_handlers, ) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index fba7d763..90549ab9 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -5,12 +5,12 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar -import dbally from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult +from dbally.global_handlers import global_event_handlers_list from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -33,7 +33,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - collection_event_handlers: List[EventHandler], + event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -45,7 +45,7 @@ def __init__( before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. - collection_event_handlers: Event handlers used by the collection during query executions. Can be used\ + event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). nl_responder: Object that translates RAW response from db-ally into natural language. 
@@ -59,9 +59,9 @@ def __init__( self._builders: Dict[str, Callable[[], BaseView]] = {} self._view_selector = view_selector self._nl_responder = nl_responder - if collection_event_handlers != dbally.event_handlers: + if event_handlers != global_event_handlers_list: print("WARNING Default event handler has been overwritten") - self._event_handlers = collection_event_handlers + self._event_handlers = event_handlers self._llm = llm T = TypeVar("T", bound=BaseView) diff --git a/src/dbally/global_handlers.py b/src/dbally/global_handlers.py new file mode 100644 index 00000000..86346e30 --- /dev/null +++ b/src/dbally/global_handlers.py @@ -0,0 +1,63 @@ +from typing import List, Optional, Type + +from dbally.audit.event_handlers import EventHandler +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler + +# Global list of event handlers initialized with a default CLIEventHandler. +global_event_handlers_list: List[EventHandler] = [CLIEventHandler()] + + +def set_event_handlers(event_handler_list: List[EventHandler]) -> None: + """ + Set the global list of event handlers. + + This function replaces the current list of event handlers with the provided list. + It ensures that each handler in the provided list is an instance of EventHandler. + If any handler is not an instance of EventHandler, it raises a ValueError. + + Args: + event_handler_list (List[EventHandler]): The list of event handlers to set. + + Raises: + ValueError: If any handler in the list is not an instance of EventHandler. + """ + for handler in event_handler_list: + if isinstance(type(handler), EventHandler): + raise ValueError(f"{handler} is not an instance of EventHandler") + global global_event_handlers_list # pylint: disable=global-statement + global_event_handlers_list = event_handler_list + + +def add_event_handler(event_handler: EventHandler) -> None: + """ + Add an event handler to the global list. + + This function appends the provided event handler to the global list of event handlers. + It ensures that the handler is an instance of EventHandler. + If the handler is not an instance of EventHandler, it raises a ValueError. + + Args: + event_handler (EventHandler): The event handler to add. + + Raises: + ValueError: If the handler is not an instance of EventHandler. + """ + if isinstance(type(event_handler), EventHandler): + raise ValueError(f"{event_handler} is not an instance of EventHandler") + global_event_handlers_list.append(event_handler) + + +def find_event_handler(object_type: Type) -> Optional[EventHandler]: + """ + Finds an event handler of the specified type from a list of event handlers. + + Args: + object_type: The type of the event handler to find. + + Returns: + The first event handler of the specified type if found, otherwise None. 
+ """ + for single_event_handler in global_event_handlers_list: + if type(single_event_handler) is object_type: # pylint: disable=unidiomatic-typecheck + return single_event_handler + return None diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 65b17e6d..8575117a 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -4,11 +4,11 @@ import gradio import pandas as pd -import dbally from dbally import BaseStructuredView from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError +from dbally.global_handlers import find_event_handler from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.prompts import PromptTemplateError @@ -41,7 +41,7 @@ def __init__(self): self.preview_limit = None self.selected_view_name = None self.collection = None - buffer_handler = dbally.find_event_handler(BufferEventHandler) + buffer_handler = find_event_handler(BufferEventHandler) if not buffer_handler: raise ValueError( "Could not initialize gradio console. Missing buffer handler.\n" From e1d09c0f6cd51c62f35e291a060c187322aaece4 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 25 Jun 2024 18:17:05 +0200 Subject: [PATCH 21/64] event handlers --- src/dbally/audit/event_handlers/buffer_event_handler.py | 2 +- src/dbally/audit/event_handlers/cli_event_handler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dbally/audit/event_handlers/buffer_event_handler.py b/src/dbally/audit/event_handlers/buffer_event_handler.py index a4c0fcb5..4ef2eef4 100644 --- a/src/dbally/audit/event_handlers/buffer_event_handler.py +++ b/src/dbally/audit/event_handlers/buffer_event_handler.py @@ -16,7 +16,7 @@ class BufferEventHandler(CLIEventHandler): import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - dbally.set_event_handlers([BufferEventHandler()]) + dbally.global_handlers.set_event_handlers([BufferEventHandler()]) my_collection = dbally.create_collection("my_collection", llm) ``` """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 883465dd..2bc5743f 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -30,7 +30,7 @@ class CLIEventHandler(EventHandler): import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - dbally.set_event_handlers([CLIEventHandler()]) + dbally.global_handlers.set_event_handlers([CLIEventHandler()]) my_collection = dbally.create_collection("my_collection", llm) ``` From ae5d83817c38c90359cec6b04cf45ea8c5301634 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 28 Jun 2024 17:37:25 +0200 Subject: [PATCH 22/64] wrap into singleton --- examples/visualize_views_code.py | 3 +- src/dbally/__init__.py | 2 + src/dbally/_main.py | 5 +- .../audit/event_handlers/cli_event_handler.py | 3 +- src/dbally/audit/event_tracker.py | 5 + src/dbally/collection/collection.py | 5 +- src/dbally/global_handlers.py | 63 -------- src/dbally/gradio/gradio_interface.py | 17 +- src/dbally/index.py | 153 ++++++++++++++++++ tests/unit/test_index.py | 106 ++++++++++++ 10 files changed, 282 insertions(+), 80 deletions(-) delete mode 100644 src/dbally/global_handlers.py create mode 100644 src/dbally/index.py create mode 100644 tests/unit/test_index.py 
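PATCH 22/64 wraps the global event-handler registry in a decorator-based singleton (`GlobalEventHandlerClass` in `src/dbally/index.py`, exposed as `dbally.global_event_handlers`). Before the diffs, here is a minimal sketch of the pattern they rely on: a `singleton` decorator that caches the first instance of a class and hands back that same instance on every later call. The sketch is illustrative only; `HandlerRegistry` and `find` are hypothetical names, not db-ally's actual classes.

```python
# A minimal, self-contained sketch of the decorator-based singleton pattern the
# diffs below rely on. `HandlerRegistry` and `find` are illustrative names only,
# not part of db-ally's API.
from typing import Callable, Dict, List, Optional, Type


def singleton(cls: Type) -> Callable:
    """Return a factory that always hands back the same instance of `cls`."""
    instances: Dict[Type, object] = {}

    def get_instance(*args, **kwargs):
        # Create the instance on the first call, reuse it on every later call.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class HandlerRegistry:
    """Keeps a single, process-wide list of event handlers."""

    def __init__(self) -> None:
        self._handlers: List[object] = []

    def append(self, handler: object) -> None:
        self._handlers.append(handler)

    def find(self, handler_type: Type) -> Optional[object]:
        # Exact-type lookup, mirroring the `find_buffer` helper in the patch.
        return next((h for h in self._handlers if type(h) is handler_type), None)

    def __len__(self) -> int:
        return len(self._handlers)


if __name__ == "__main__":
    registry_a = HandlerRegistry()
    registry_b = HandlerRegistry()
    assert registry_a is registry_b      # the decorator caches one shared instance
    registry_a.append("cli-handler")
    assert len(registry_b) == 1          # state registered anywhere is visible everywhere
    assert registry_b.find(str) == "cli-handler"
```

A note on the design choice: the series itself revisits this decision, and PATCH 24/64 ("singleton remove") replaces the singleton with a plain module-level `global_event_handlers` list.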
diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index ef697123..44650c24 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -6,7 +6,6 @@ import dbally from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler -from dbally.global_handlers import add_event_handler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM @@ -14,7 +13,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - add_event_handler(BufferEventHandler()) + dbally.global_event_handlers = [BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index f41e1866..c96166f4 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -19,12 +19,14 @@ EmbeddingStatusError, ) from .exceptions import DbAllyError +from .index import global_event_handlers from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError __all__ = [ "__version__", "create_collection", "decorators", + "global_event_handlers", "BaseStructuredView", "Collection", "DataFrameBaseView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 8f55ae96..bf58845e 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,9 +1,8 @@ from typing import List, Optional -from dbally.global_handlers import global_event_handlers_list - from .audit.event_handlers.base import EventHandler from .collection import Collection +from .index import global_event_handlers from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector @@ -56,7 +55,7 @@ def create_collection( view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or global_event_handlers_list + event_handlers = event_handlers or global_event_handlers return Collection( name, diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 2bc5743f..2533b94a 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -29,8 +29,9 @@ class CLIEventHandler(EventHandler): ```python import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler + from dbally.audit.index import set_event_handlers - dbally.global_handlers.set_event_handlers([CLIEventHandler()]) + set_event_handlers([CLIEventHandler()]) my_collection = dbally.create_collection("my_collection", llm) ``` diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index 34faf803..79e7acf9 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -27,11 +27,16 @@ def initialize_with_handlers(cls, event_handlers: List[EventHandler]) -> "EventT Returns: The initialized event store. + + Raises: + ValueError: if invalid event handler object is passed as argument. """ instance = cls() for handler in event_handlers: + if not isinstance(handler, EventHandler): + raise ValueError(f"Could not register {handler}. 
Handler must be instance of EvenHandler type") instance.subscribe(handler) return instance diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 90549ab9..49c302d4 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,7 +10,7 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.global_handlers import global_event_handlers_list +from dbally.index import global_event_handlers from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -59,7 +59,8 @@ def __init__( self._builders: Dict[str, Callable[[], BaseView]] = {} self._view_selector = view_selector self._nl_responder = nl_responder - if event_handlers != global_event_handlers_list: + if event_handlers != global_event_handlers: + # At this moment there are no event tracker initialize to record an event print("WARNING Default event handler has been overwritten") self._event_handlers = event_handlers self._llm = llm diff --git a/src/dbally/global_handlers.py b/src/dbally/global_handlers.py deleted file mode 100644 index 86346e30..00000000 --- a/src/dbally/global_handlers.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import List, Optional, Type - -from dbally.audit.event_handlers import EventHandler -from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - -# Global list of event handlers initialized with a default CLIEventHandler. -global_event_handlers_list: List[EventHandler] = [CLIEventHandler()] - - -def set_event_handlers(event_handler_list: List[EventHandler]) -> None: - """ - Set the global list of event handlers. - - This function replaces the current list of event handlers with the provided list. - It ensures that each handler in the provided list is an instance of EventHandler. - If any handler is not an instance of EventHandler, it raises a ValueError. - - Args: - event_handler_list (List[EventHandler]): The list of event handlers to set. - - Raises: - ValueError: If any handler in the list is not an instance of EventHandler. - """ - for handler in event_handler_list: - if isinstance(type(handler), EventHandler): - raise ValueError(f"{handler} is not an instance of EventHandler") - global global_event_handlers_list # pylint: disable=global-statement - global_event_handlers_list = event_handler_list - - -def add_event_handler(event_handler: EventHandler) -> None: - """ - Add an event handler to the global list. - - This function appends the provided event handler to the global list of event handlers. - It ensures that the handler is an instance of EventHandler. - If the handler is not an instance of EventHandler, it raises a ValueError. - - Args: - event_handler (EventHandler): The event handler to add. - - Raises: - ValueError: If the handler is not an instance of EventHandler. - """ - if isinstance(type(event_handler), EventHandler): - raise ValueError(f"{event_handler} is not an instance of EventHandler") - global_event_handlers_list.append(event_handler) - - -def find_event_handler(object_type: Type) -> Optional[EventHandler]: - """ - Finds an event handler of the specified type from a list of event handlers. - - Args: - object_type: The type of the event handler to find. - - Returns: - The first event handler of the specified type if found, otherwise None. 
- """ - for single_event_handler in global_event_handlers_list: - if type(single_event_handler) is object_type: # pylint: disable=unidiomatic-typecheck - return single_event_handler - return None diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 8575117a..e80e3a59 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -4,11 +4,10 @@ import gradio import pandas as pd -from dbally import BaseStructuredView +from dbally import BaseStructuredView, global_event_handlers from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.global_handlers import find_event_handler from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError from dbally.prompts import PromptTemplateError @@ -41,13 +40,13 @@ def __init__(self): self.preview_limit = None self.selected_view_name = None self.collection = None - buffer_handler = find_event_handler(BufferEventHandler) - if not buffer_handler: - raise ValueError( - "Could not initialize gradio console. Missing buffer handler.\n" - "Add dbally.add_event_handler(BufferEventHandler()) to fix it" - ) - self.log: BufferEventHandler = buffer_handler.buffer # pylint: disable=no-member + if buffer_event_handler := global_event_handlers.find_buffer(): + pass + else: + buffer_event_handler = BufferEventHandler() + global_event_handlers.append(buffer_event_handler) + + self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: if preview_dataframe.empty: diff --git a/src/dbally/index.py b/src/dbally/index.py new file mode 100644 index 00000000..a117acf8 --- /dev/null +++ b/src/dbally/index.py @@ -0,0 +1,153 @@ +from .audit import EventHandler +from .audit.event_handlers.buffer_event_handler import BufferEventHandler + + +def singleton(self): + """ + A decorator to make a class a singleton. + + Args: + self: The class to be made singleton. + + Returns: + function: A function that returns the single instance of the class. + """ + instances = {} + + def get_instance(*args, **kwargs): + """ + Returns the single instance of the decorated class, creating it if necessary. + + Args: + *args: Positional arguments to pass to the class constructor. + **kwargs: Keyword arguments to pass to the class constructor. + + Returns: + object: The single instance of the class. + """ + if self not in instances: + instances[self] = self(*args, **kwargs) + return instances[self] + + return get_instance + + +@singleton +class GlobalEventHandlerClass: + """ + A singleton class to manage a list of event handlers. + """ + + def __init__(self): + """ + A singleton class to manage a list of event handlers. + """ + self._list = [] + + def add_item(self, item: EventHandler): + """ + Adds an event handler to the list. + + Args: + item: The event handler to add. + + Raises: + ValueError: If the item is not an instance of EventHandler. + """ + if not isinstance(item, EventHandler): + raise ValueError(f"Handler {item} is not EventHandler type") + self._list.append(item) + + def remove_item(self, item): + """ + Removes an event handler from the list. + + Args: + item (EventHandler): The event handler to remove. + """ + self._list.remove(item) + + def get_list(self): + """ + Returns the list of event handlers. + + Returns: + list: The list of event handlers. 
+ """ + return self._list + + def clear_list(self): + """ + Clears the list of event handlers. + """ + self._list.clear() + + def __getitem__(self, index): + """ + Gets the event handler at the specified index. + + Args: + index (int): The index of the event handler to get. + + Returns: + EventHandler: The event handler at the specified index. + """ + return self._list[index] + + def __setitem__(self, index, value): + """ + Sets the event handler at the specified index. + + Args: + index (int): The index at which to set the event handler. + value (EventHandler): The event handler to set. + + Raises: + ValueError: If the value is not an instance of EventHandler. + """ + if not isinstance(value, EventHandler): + raise ValueError(f"Handler {value} is not EventHandler type") + self._list[index] = value + + def __delitem__(self, index): + """ + Deletes the event handler at the specified index. + + Args: + index (int): The index of the event handler to delete. + """ + del self._list[index] + + def __len__(self): + """ + Returns the number of event handlers in the list. + + Returns: + int: The number of event handlers. + """ + return len(self._list) + + def append(self, value): + """ + Appends an event handler to the list. + + Args: + value (EventHandler): The event handler to append. + """ + self.add_item(value) + + def find_buffer(self): + """ + Finds and returns the buffer of a BufferEventHandler in the list. + + Returns: + Buffer: The buffer of a BufferEventHandler if found, None otherwise. + """ + for handler in self._list: + if type(handler) is BufferEventHandler: # pylint: disable=C0123 + return handler.buffer + return None + + +# Global instance of the event handler singleton +global_event_handlers = GlobalEventHandlerClass() diff --git a/tests/unit/test_index.py b/tests/unit/test_index.py new file mode 100644 index 00000000..a72b9367 --- /dev/null +++ b/tests/unit/test_index.py @@ -0,0 +1,106 @@ +import pytest + +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.index import GlobalEventHandlerClass, global_event_handlers + + +def test_singleton(): + handler1 = GlobalEventHandlerClass() + handler2 = GlobalEventHandlerClass() + assert handler1 is handler2 + + +def test_add_item(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + assert len(global_event_handlers) == 1 + assert global_event_handlers[0] is handler + + +def test_add_item_invalid_type(): + global_event_handlers.clear_list() + with pytest.raises(ValueError): + global_event_handlers.add_item("not an event handler") + + +def test_remove_item(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + global_event_handlers.remove_item(handler) + assert len(global_event_handlers) == 0 + + +def test_get_list(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + assert global_event_handlers.get_list() == [handler] + + +def test_clear_list(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + global_event_handlers.clear_list() + assert len(global_event_handlers) == 0 + + +def test_getitem(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + assert global_event_handlers[0] is handler + + +def test_setitem(): + 
global_event_handlers.clear_list() + handler1 = CLIEventHandler() + handler2 = CLIEventHandler() + global_event_handlers.add_item(handler1) + global_event_handlers[0] = handler2 + assert global_event_handlers[0] is handler2 + + +def test_setitem_invalid_type(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + with pytest.raises(ValueError): + global_event_handlers[0] = "not an event handler" + + +def test_delitem(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + del global_event_handlers[0] + assert len(global_event_handlers) == 0 + + +def test_len(): + global_event_handlers.clear_list() + assert len(global_event_handlers) == 0 + handler = CLIEventHandler() + global_event_handlers.add_item(handler) + assert len(global_event_handlers) == 1 + + +def test_append(): + global_event_handlers.clear_list() + handler = CLIEventHandler() + global_event_handlers.append(handler) + assert len(global_event_handlers) == 1 + assert global_event_handlers[0] is handler + + +def test_find_buffer(): + global_event_handlers.clear_list() + handler = BufferEventHandler() + global_event_handlers.add_item(handler) + non_buffer_handler = CLIEventHandler() + global_event_handlers.add_item(non_buffer_handler) + global_event_handlers.remove_item(handler) + assert global_event_handlers.find_buffer() is None From 933a696192462f085eebd1563a84f9d6b12fddf7 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 1 Jul 2024 10:38:47 +0200 Subject: [PATCH 23/64] collection ehnacmentS --- docs/how-to/visualize_views.md | 1 - docs/quickstart/quickstart2_code.py | 5 +- docs/quickstart/quickstart3_code.py | 4 +- docs/quickstart/quickstart_code.py | 5 +- examples/visualize_views_code.py | 2 - .../audit/event_handlers/cli_event_handler.py | 5 +- src/dbally/collection/collection.py | 18 ++++--- src/dbally/gradio/gradio_interface.py | 22 ++++++-- src/dbally/index.py | 50 +++++++++++++++++++ 9 files changed, 90 insertions(+), 22 deletions(-) diff --git a/docs/how-to/visualize_views.md b/docs/how-to/visualize_views.md index 90291598..c9b79791 100644 --- a/docs/how-to/visualize_views.md +++ b/docs/how-to/visualize_views.md @@ -19,7 +19,6 @@ Define collection with implemented views ```python llm = LiteLLM(model_name="gpt-3.5-turbo") await country_similarity.update() -dbally.add_event_handler(BufferEventHandler()) collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 593e7b4a..cffddff9 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.automap import automap_base import dbally -from dbally import decorators, SqlAlchemyBaseView +from dbally import decorators, SqlAlchemyBaseView, global_event_handlers from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient @@ -77,10 +77,11 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): + global_event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - collection = 
dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) result = await collection.ask("Find someone from the United States with more than 2 years of experience.") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index f0c9270a..42f92305 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -9,7 +9,8 @@ from sqlalchemy.ext.automap import automap_base import pandas as pd -from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult +from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult, global_event_handlers +from dbally.audit import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM @@ -125,6 +126,7 @@ def display_results(result: ExecutionResult): async def main(): + global_event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 34ee9765..79d3f84a 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -6,7 +6,7 @@ from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -from dbally import decorators, SqlAlchemyBaseView +from dbally import decorators, SqlAlchemyBaseView, global_event_handlers from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM @@ -58,7 +58,8 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) + global_event_handlers.append(CLIEventHandler()) + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 44650c24..edd49b43 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -5,7 +5,6 @@ from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine import dbally -from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM @@ -13,7 +12,6 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.global_event_handlers = [BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 2533b94a..52565297 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -27,11 +27,10 @@ class CLIEventHandler(EventHandler): ### Usage ```python - import dbally from 
dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - from dbally.audit.index import set_event_handlers + from dbally.index import global_event_handlers - set_event_handlers([CLIEventHandler()]) + dbally.global_event_handlers.append(CLIEventHandler()) my_collection = dbally.create_collection("my_collection", llm) ``` diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 49c302d4..9cdcede5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -3,14 +3,14 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, List, Optional, Type, TypeVar +from typing import Callable, Dict, List, Optional, Type, TypeVar, Union from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.index import global_event_handlers +from dbally.index import GlobalEventHandlerClass, global_event_handlers from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -33,7 +33,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - event_handlers: List[EventHandler], + event_handlers: Union[List[EventHandler], GlobalEventHandlerClass], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -59,12 +59,16 @@ def __init__( self._builders: Dict[str, Callable[[], BaseView]] = {} self._view_selector = view_selector self._nl_responder = nl_responder - if event_handlers != global_event_handlers: - # At this moment there are no event tracker initialize to record an event - print("WARNING Default event handler has been overwritten") - self._event_handlers = event_handlers self._llm = llm + if not event_handlers: + event_handlers = global_event_handlers + elif event_handlers != global_event_handlers: + # At this moment, there is no event tracker initialized to record an event + print(f"WARNING: Default event handler has been overwritten for {self.name}.") + + self._event_handlers = event_handlers + T = TypeVar("T", bound=BaseView) def add(self, view: Type[T], builder: Optional[Callable[[], T]] = None, name: Optional[str] = None) -> None: diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index e80e3a59..82233567 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -40,15 +40,29 @@ def __init__(self): self.preview_limit = None self.selected_view_name = None self.collection = None - if buffer_event_handler := global_event_handlers.find_buffer(): - pass - else: + + buffer_event_handler = global_event_handlers.find_buffer() + if not buffer_event_handler: buffer_event_handler = BufferEventHandler() global_event_handlers.append(buffer_event_handler) - self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: + """ + Load data into Gradio components for preview. + + This function takes a DataFrame and a label, and returns a tuple containing a Gradio DataFrame + and a Gradio Label. The visibility of these components is determined by whether the input + DataFrame is empty. 
+ + Args: + preview_dataframe: The DataFrame to be loaded into the Gradio DataFrame component. + label: The label to be associated with the Gradio components. + + Returns: + A tuple containing the Gradio DataFrame component with the provided data and label and A Gradio Label + indicating the availability of data. + """ if preview_dataframe.empty: gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=False) empty_frame_label = gradio.Label(value=f"{label} not available", visible=True, show_label=False) diff --git a/src/dbally/index.py b/src/dbally/index.py index a117acf8..0a9efb4f 100644 --- a/src/dbally/index.py +++ b/src/dbally/index.py @@ -127,6 +127,56 @@ def __len__(self): """ return len(self._list) + def __eq__(self, other): + """ + Determine if this instance is equal to another object. + + The comparison is based on the equality of the `_list` attribute. + If the other object is an instance of `GlobalEventHandlerClass`, + their `_list` attributes are compared. If the other object is a list, + it is compared directly to this instance's `_list` attribute. + + Args: + other (object): The object to compare with this instance. + + Returns: + bool: True if the objects are considered equal, False otherwise. + + Raises: + NotImplemented: If the `other` object is neither a `GlobalEventHandlerClass` + instance nor a list. + """ + if isinstance(other, type(self)): + return self._list == other._list + if isinstance(other, list): + return self._list == other + return NotImplemented + + def __ne__(self, other): + """ + Determine if this instance is not equal to another object. + + The comparison is based on the inequality of the `_list` attribute. + If the other object is an instance of `GlobalEventHandlerClass`, + their `_list` attributes are compared. If the other object is a list, + it is compared directly to this instance's `_list` attribute. + + Args: + other (object): The object to compare with this instance. + + Returns: + bool: True if the objects are considered not equal, False otherwise. + + Raises: + NotImplemented: If the `other` object is neither a `GlobalEventHandlerClass` + instance nor a list. + """ + if isinstance(other, type(self)): + return self._list != other._list + if isinstance(other, list): + return self._list != other + return NotImplemented + def append(self, value): """ Appends an event handler to the list. 
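The changes above give `Collection.__init__` a fallback to the global registry and add `__eq__`/`__ne__` to `GlobalEventHandlerClass`, so a collection can compare the handlers it received against `dbally.global_event_handlers` and warn when the defaults were overridden. The sketch below isolates that comparison-and-fallback pattern under the assumption that only list-style equality is needed; `EventHandlerList` and `pick_handlers` are hypothetical stand-ins, not db-ally names.

```python
# Hedged sketch of the fallback-and-warn pattern shown above: a list-like
# registry that compares equal to a plain list of the same handlers, so the
# caller can tell whether the defaults were overridden. Names are illustrative.
from typing import List, Optional


class EventHandlerList:
    """List-like registry that compares equal to a plain list of the same handlers."""

    def __init__(self) -> None:
        self._handlers: List[object] = []

    def append(self, handler: object) -> None:
        self._handlers.append(handler)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, EventHandlerList):
            return self._handlers == other._handlers
        if isinstance(other, list):
            return self._handlers == other
        return NotImplemented


def pick_handlers(explicit: Optional[list], global_registry: EventHandlerList) -> object:
    """Fall back to the global registry; warn when the defaults were overridden."""
    if not explicit:
        return global_registry
    if explicit != global_registry:
        print("WARNING: Default event handlers have been overwritten.")
    return explicit


if __name__ == "__main__":
    global_handlers = EventHandlerList()
    global_handlers.append("cli-handler")
    assert global_handlers == ["cli-handler"]               # registry vs. plain list
    assert pick_handlers(None, global_handlers) is global_handlers
    pick_handlers(["langsmith-handler"], global_handlers)   # prints the warning
```

Comparing equal to a plain list is what lets the `event_handlers != global_event_handlers` check keep working whether callers pass their own list or rely on the shared registry.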
From 200eaf1c252f377df9db227d81de1a6dc218e3e7 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 1 Jul 2024 14:47:49 +0200 Subject: [PATCH 24/64] singleton remove --- docs/quickstart/quickstart_code.py | 1 + examples/visualize_views_code.py | 5 +- src/dbally/__init__.py | 3 +- src/dbally/_main.py | 5 +- .../audit/event_handlers/cli_event_handler.py | 3 +- src/dbally/audit/event_tracker.py | 1 + src/dbally/collection/collection.py | 10 +- src/dbally/gradio/gradio_interface.py | 31 ++- src/dbally/index.py | 203 ------------------ tests/unit/test_index.py | 106 --------- 10 files changed, 45 insertions(+), 323 deletions(-) delete mode 100644 src/dbally/index.py delete mode 100644 tests/unit/test_index.py diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 79d3f84a..9de87e68 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -59,6 +59,7 @@ async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") global_event_handlers.append(CLIEventHandler()) + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index edd49b43..127c3403 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -1,10 +1,12 @@ # pylint: disable=missing-function-docstring import asyncio +import dbally +from dbally.audit import CLIEventHandler +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity, engine from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine -import dbally from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM @@ -12,6 +14,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") + dbally.global_event_handlers = [CLIEventHandler(), BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index c96166f4..03e2aa50 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -19,9 +19,10 @@ EmbeddingStatusError, ) from .exceptions import DbAllyError -from .index import global_event_handlers from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError +global_event_handlers = [] + __all__ = [ "__version__", "create_collection", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index bf58845e..1f010a11 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -2,7 +2,7 @@ from .audit.event_handlers.base import EventHandler from .collection import Collection -from .index import global_event_handlers +import dbally from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector @@ -54,8 +54,7 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - - event_handlers = event_handlers or global_event_handlers + event_handlers = event_handlers or dbally.global_event_handlers return Collection( name, diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py 
b/src/dbally/audit/event_handlers/cli_event_handler.py index 52565297..f9d20631 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -28,7 +28,7 @@ class CLIEventHandler(EventHandler): ```python from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - from dbally.index import global_event_handlers + from dbally.index import dbally dbally.global_event_handlers.append(CLIEventHandler()) my_collection = dbally.create_collection("my_collection", llm) @@ -61,6 +61,7 @@ async def request_start(self, user_request: RequestStart) -> None: Args: user_request: Object containing name of collection and asked query """ + print(f"buffer {self._console.file}") self._print_syntax(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}") self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================\n") diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index 79e7acf9..7c177b7a 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -35,6 +35,7 @@ def initialize_with_handlers(cls, event_handlers: List[EventHandler]) -> "EventT instance = cls() for handler in event_handlers: + if not isinstance(handler, EventHandler): raise ValueError(f"Could not register {handler}. Handler must be instance of EvenHandler type") instance.subscribe(handler) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 9cdcede5..0321dc7a 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,7 +10,7 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.index import GlobalEventHandlerClass, global_event_handlers +import dbally from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -33,7 +33,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - event_handlers: Union[List[EventHandler], GlobalEventHandlerClass], + event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -62,8 +62,8 @@ def __init__( self._llm = llm if not event_handlers: - event_handlers = global_event_handlers - elif event_handlers != global_event_handlers: + event_handlers = dbally.global_event_handlers + elif event_handlers != dbally.global_event_handlers: # At this moment, there is no event tracker initialized to record an event print(f"WARNING: Default event handler has been overwritten for {self.name}.") @@ -236,7 +236,7 @@ async def ask( ) await event_tracker.request_end(RequestEnd(result=result)) - + # print(dbally.my_callback) return result def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 82233567..0c8c4c1e 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -4,7 +4,8 @@ import gradio import pandas as pd -from dbally import BaseStructuredView, global_event_handlers +import dbally +from dbally import BaseStructuredView from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection from 
dbally.collection.exceptions import NoViewFoundError @@ -27,6 +28,23 @@ async def create_gradio_interface(user_collection: Collection, preview_limit: in return gradio_interface +def find_event_buffer(): + """ + Searches through global event handlers to find an instance of BufferEventHandler. + + This function iterates over the list of global event handlers stored in `dbally.global_event_handlers`. + It checks the type of each handler, and if it finds one that is an instance of `BufferEventHandler`, it + returns that handler. If no such handler is found, the function returns `None`. + + Returns: + BufferEventHandler or None: The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. + """ + for handler in dbally.global_event_handlers: + if type(handler) is BufferEventHandler: + return handler + return None + + class GradioAdapter: """ A class to adapt and integrate data collection and query execution with Gradio interface components. @@ -41,11 +59,16 @@ def __init__(self): self.selected_view_name = None self.collection = None - buffer_event_handler = global_event_handlers.find_buffer() + buffer_event_handler = find_event_buffer() if not buffer_event_handler: + print("buffer_event_handler not found") buffer_event_handler = BufferEventHandler() - global_event_handlers.append(buffer_event_handler) + dbally.global_event_handlers.append(buffer_event_handler) + else: + print("buffer_event_handler found") + print(dbally.global_event_handlers) self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member + print(f" init 1 {self.log}") def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: """ @@ -143,10 +166,12 @@ async def _ui_ask_query( generated_query = {"Query": "No view matched to query"} data = pd.DataFrame() finally: + print(f" ask log 1 {self.log}") self.log.seek(0) log_content = self.log.read() gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results") + print(f" ask log 2 {self.log}") return ( gradio_dataframe, empty_dataframe_warning, diff --git a/src/dbally/index.py b/src/dbally/index.py deleted file mode 100644 index 0a9efb4f..00000000 --- a/src/dbally/index.py +++ /dev/null @@ -1,203 +0,0 @@ -from .audit import EventHandler -from .audit.event_handlers.buffer_event_handler import BufferEventHandler - - -def singleton(self): - """ - A decorator to make a class a singleton. - - Args: - self: The class to be made singleton. - - Returns: - function: A function that returns the single instance of the class. - """ - instances = {} - - def get_instance(*args, **kwargs): - """ - Returns the single instance of the decorated class, creating it if necessary. - - Args: - *args: Positional arguments to pass to the class constructor. - **kwargs: Keyword arguments to pass to the class constructor. - - Returns: - object: The single instance of the class. - """ - if self not in instances: - instances[self] = self(*args, **kwargs) - return instances[self] - - return get_instance - - -@singleton -class GlobalEventHandlerClass: - """ - A singleton class to manage a list of event handlers. - """ - - def __init__(self): - """ - A singleton class to manage a list of event handlers. - """ - self._list = [] - - def add_item(self, item: EventHandler): - """ - Adds an event handler to the list. - - Args: - item: The event handler to add. - - Raises: - ValueError: If the item is not an instance of EventHandler. 
- """ - if not isinstance(item, EventHandler): - raise ValueError(f"Handler {item} is not EventHandler type") - self._list.append(item) - - def remove_item(self, item): - """ - Removes an event handler from the list. - - Args: - item (EventHandler): The event handler to remove. - """ - self._list.remove(item) - - def get_list(self): - """ - Returns the list of event handlers. - - Returns: - list: The list of event handlers. - """ - return self._list - - def clear_list(self): - """ - Clears the list of event handlers. - """ - self._list.clear() - - def __getitem__(self, index): - """ - Gets the event handler at the specified index. - - Args: - index (int): The index of the event handler to get. - - Returns: - EventHandler: The event handler at the specified index. - """ - return self._list[index] - - def __setitem__(self, index, value): - """ - Sets the event handler at the specified index. - - Args: - index (int): The index at which to set the event handler. - value (EventHandler): The event handler to set. - - Raises: - ValueError: If the value is not an instance of EventHandler. - """ - if not isinstance(value, EventHandler): - raise ValueError(f"Handler {value} is not EventHandler type") - self._list[index] = value - - def __delitem__(self, index): - """ - Deletes the event handler at the specified index. - - Args: - index (int): The index of the event handler to delete. - """ - del self._list[index] - - def __len__(self): - """ - Returns the number of event handlers in the list. - - Returns: - int: The number of event handlers. - """ - return len(self._list) - - def __eq__(self, other): - """ - Determine if this instance is equal to another object. - - The comparison is based on the equality of the `_list` attribute. - If the other object is an instance of `GlobalEventHandlerClass`, - their `_list` attributes are compared. If the other object is a list, - it is compared directly to this instance's `_list` attribute. - - Args: - other (object): The object to compare with this instance. - - Returns: - bool: True if the objects are considered equal, False otherwise. - - Raises: - NotImplemented: If the `other` object is neither a `GlobalEventHandlerClass` - instance nor a list. - """ - if isinstance(other, type(self)): - return self._list == other._list - if isinstance(other, list): - return self._list == other - return NotImplemented - - def __ne__(self, other): - """ - Determine if this instance is not equal to another object. - - The comparison is based on the inequality of the `_list` attribute. - If the other object is an instance of `GlobalEventHandlerClass`, - their `_list` attributes are compared. If the other object is a list, - it is compared directly to this instance's `_list` attribute. - - Args: - other (object): The object to compare with this instance. - - Returns: - bool: True if the objects are considered not equal, False otherwise. - - Raises: - NotImplemented: If the `other` object is neither a `GlobalEventHandlerClass` - instance nor a list. - """ - if isinstance(other, type(self)): - return self._list != other._list - if isinstance(other, list): - return self._list != other - return NotImplemented - - def append(self, value): - """ - Appends an event handler to the list. - - Args: - value (EventHandler): The event handler to append. - """ - self.add_item(value) - - def find_buffer(self): - """ - Finds and returns the buffer of a BufferEventHandler in the list. - - Returns: - Buffer: The buffer of a BufferEventHandler if found, None otherwise. 
- """ - for handler in self._list: - if type(handler) is BufferEventHandler: # pylint: disable=C0123 - return handler.buffer - return None - - -# Global instance of the event handler singleton -global_event_handlers = GlobalEventHandlerClass() diff --git a/tests/unit/test_index.py b/tests/unit/test_index.py deleted file mode 100644 index a72b9367..00000000 --- a/tests/unit/test_index.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest - -from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler -from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler -from dbally.index import GlobalEventHandlerClass, global_event_handlers - - -def test_singleton(): - handler1 = GlobalEventHandlerClass() - handler2 = GlobalEventHandlerClass() - assert handler1 is handler2 - - -def test_add_item(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - assert len(global_event_handlers) == 1 - assert global_event_handlers[0] is handler - - -def test_add_item_invalid_type(): - global_event_handlers.clear_list() - with pytest.raises(ValueError): - global_event_handlers.add_item("not an event handler") - - -def test_remove_item(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - global_event_handlers.remove_item(handler) - assert len(global_event_handlers) == 0 - - -def test_get_list(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - assert global_event_handlers.get_list() == [handler] - - -def test_clear_list(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - global_event_handlers.clear_list() - assert len(global_event_handlers) == 0 - - -def test_getitem(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - assert global_event_handlers[0] is handler - - -def test_setitem(): - global_event_handlers.clear_list() - handler1 = CLIEventHandler() - handler2 = CLIEventHandler() - global_event_handlers.add_item(handler1) - global_event_handlers[0] = handler2 - assert global_event_handlers[0] is handler2 - - -def test_setitem_invalid_type(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - with pytest.raises(ValueError): - global_event_handlers[0] = "not an event handler" - - -def test_delitem(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - del global_event_handlers[0] - assert len(global_event_handlers) == 0 - - -def test_len(): - global_event_handlers.clear_list() - assert len(global_event_handlers) == 0 - handler = CLIEventHandler() - global_event_handlers.add_item(handler) - assert len(global_event_handlers) == 1 - - -def test_append(): - global_event_handlers.clear_list() - handler = CLIEventHandler() - global_event_handlers.append(handler) - assert len(global_event_handlers) == 1 - assert global_event_handlers[0] is handler - - -def test_find_buffer(): - global_event_handlers.clear_list() - handler = BufferEventHandler() - global_event_handlers.add_item(handler) - non_buffer_handler = CLIEventHandler() - global_event_handlers.add_item(non_buffer_handler) - global_event_handlers.remove_item(handler) - assert global_event_handlers.find_buffer() is None From aa3ea916184733781552c1312085ff7e70a9d7d6 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: 
Mon, 1 Jul 2024 14:57:34 +0200 Subject: [PATCH 25/64] fixups --- docs/quickstart/index.md | 5 ++++- docs/quickstart/quickstart2_code.py | 4 ++-- docs/quickstart/quickstart3_code.py | 6 +++--- docs/quickstart/quickstart_code.py | 6 +++--- examples/visualize_views_code.py | 6 +++--- src/dbally/_main.py | 3 ++- .../audit/event_handlers/cli_event_handler.py | 1 - src/dbally/audit/event_tracker.py | 1 - src/dbally/collection/collection.py | 5 ++--- src/dbally/gradio/gradio_interface.py | 14 +++++--------- 10 files changed, 24 insertions(+), 27 deletions(-) diff --git a/docs/quickstart/index.md b/docs/quickstart/index.md index 9754fe08..764f351b 100644 --- a/docs/quickstart/index.md +++ b/docs/quickstart/index.md @@ -114,12 +114,15 @@ Replace `...` with your OpenAI API key. Alternatively, you can set the `OPENAI_A ## Collection Definition Next, create a db-ally collection. A [collection](../concepts/collections.md) is an object where you register views and execute queries. It also requires an AI model to use for generating [IQL queries](../concepts/iql.md) (in this case, the GPT model defined above). +The collection could have its own event handlers which override the globally defined handlers. ```python import dbally +from dbally.audit import CLIEventHandler + async def main(): - collection = dbally.create_collection("recruitment", llm) + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler]) collection.add(CandidateView, lambda: CandidateView(engine)) ``` diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index cffddff9..550b7f5c 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.automap import automap_base import dbally -from dbally import decorators, SqlAlchemyBaseView, global_event_handlers +from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient @@ -77,7 +77,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): - global_event_handlers.append(CLIEventHandler()) + dbally.global_event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 42f92305..8366777f 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -1,5 +1,4 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring -import dbally import os import asyncio from typing_extensions import Annotated @@ -9,7 +8,8 @@ from sqlalchemy.ext.automap import automap_base import pandas as pd -from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult, global_event_handlers +import dbally +from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult from dbally.audit import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient @@ -126,7 +126,7 @@ def display_results(result: ExecutionResult): async def main(): - global_event_handlers.append(CLIEventHandler()) + dbally.global_event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = 
LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 9de87e68..7f0fc612 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -6,7 +6,8 @@ from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -from dbally import decorators, SqlAlchemyBaseView, global_event_handlers +import dbally +from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM @@ -57,8 +58,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - - global_event_handlers.append(CLIEventHandler()) + dbally.global_event_handlers.append(CLIEventHandler()) collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 127c3403..c1a6e462 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -1,12 +1,12 @@ # pylint: disable=missing-function-docstring import asyncio -import dbally -from dbally.audit import CLIEventHandler -from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity, engine from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine +import dbally +from dbally.audit import CLIEventHandler +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 1f010a11..8d161856 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,8 +1,9 @@ from typing import List, Optional +import dbally + from .audit.event_handlers.base import EventHandler from .collection import Collection -import dbally from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index f9d20631..450386d5 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -61,7 +61,6 @@ async def request_start(self, user_request: RequestStart) -> None: Args: user_request: Object containing name of collection and asked query """ - print(f"buffer {self._console.file}") self._print_syntax(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}") self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================\n") diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index 7c177b7a..79e7acf9 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -35,7 +35,6 @@ def initialize_with_handlers(cls, event_handlers: List[EventHandler]) -> "EventT instance = cls() for handler in event_handlers: - if not isinstance(handler, EventHandler): raise ValueError(f"Could not register {handler}. 
Handler must be instance of EvenHandler type") instance.subscribe(handler) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 0321dc7a..22dfa50a 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -3,14 +3,14 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Callable, Dict, List, Optional, Type, TypeVar +import dbally from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -import dbally from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -236,7 +236,6 @@ async def ask( ) await event_tracker.request_end(RequestEnd(result=result)) - # print(dbally.my_callback) return result def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 0c8c4c1e..5fa95395 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import gradio import pandas as pd @@ -28,7 +28,7 @@ async def create_gradio_interface(user_collection: Collection, preview_limit: in return gradio_interface -def find_event_buffer(): +def find_event_buffer() -> Optional[BufferEventHandler]: """ Searches through global event handlers to find an instance of BufferEventHandler. @@ -37,10 +37,10 @@ def find_event_buffer(): returns that handler. If no such handler is found, the function returns `None`. Returns: - BufferEventHandler or None: The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. + The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. 
""" for handler in dbally.global_event_handlers: - if type(handler) is BufferEventHandler: + if type(handler) is BufferEventHandler: # pylint: disable=C0123 return handler return None @@ -61,14 +61,10 @@ def __init__(self): buffer_event_handler = find_event_buffer() if not buffer_event_handler: - print("buffer_event_handler not found") buffer_event_handler = BufferEventHandler() dbally.global_event_handlers.append(buffer_event_handler) - else: - print("buffer_event_handler found") - print(dbally.global_event_handlers) + self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member - print(f" init 1 {self.log}") def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: """ From 0f7a20a41b40b39e5187c749d973031571f1bb2c Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 1 Jul 2024 16:26:29 +0200 Subject: [PATCH 26/64] fixups --- src/dbally/__init__.py | 4 +++- src/dbally/gradio/gradio_interface.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 03e2aa50..30087f57 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -1,5 +1,7 @@ """ dbally """ +from typing import List, Callable + from dbally.collection.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult @@ -21,7 +23,7 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -global_event_handlers = [] +global_event_handlers: List[Callable] = [] __all__ = [ "__version__", diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 5fa95395..17f63bdb 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -162,12 +162,10 @@ async def _ui_ask_query( generated_query = {"Query": "No view matched to query"} data = pd.DataFrame() finally: - print(f" ask log 1 {self.log}") self.log.seek(0) log_content = self.log.read() gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results") - print(f" ask log 2 {self.log}") return ( gradio_dataframe, empty_dataframe_warning, From f654d84161df387a3400875e34269cf7a087c083 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 1 Jul 2024 16:29:22 +0200 Subject: [PATCH 27/64] fixups --- src/dbally/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 30087f57..380c00b5 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -1,6 +1,6 @@ """ dbally """ -from typing import List, Callable +from typing import Callable, List from dbally.collection.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError From 72224253ff0c33b3b6e929e0397cfc2466c01315 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 10:04:13 +0200 Subject: [PATCH 28/64] fixups --- docs/quickstart/quickstart2_code.py | 2 +- docs/quickstart/quickstart3_code.py | 2 +- docs/quickstart/quickstart_code.py | 2 +- examples/recruiting.py | 2 +- examples/visualize_views_code.py | 2 +- src/dbally/__init__.py | 4 ++-- src/dbally/_main.py | 8 ++++---- src/dbally/audit/event_handlers/cli_event_handler.py | 2 +- src/dbally/collection/collection.py | 11 ++++++----- src/dbally/gradio/gradio_interface.py | 8 ++++---- tests/unit/test_collection.py | 8 ++++---- 11 files 
changed, 26 insertions(+), 25 deletions(-) diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 550b7f5c..2384e865 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -77,7 +77,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): - dbally.global_event_handlers.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 8366777f..ef83a1bd 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -126,7 +126,7 @@ def display_results(result: ExecutionResult): async def main(): - dbally.global_event_handlers.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 7f0fc612..a565350d 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -58,7 +58,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.global_event_handlers.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/examples/recruiting.py b/examples/recruiting.py index a4813b41..3f4c14f0 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -102,7 +102,7 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example recruitment_db = dbally.create_collection( "recruitment", llm=LiteLLM(), - event_handlers=[CLIEventHandler()], + override_event_handlers=[CLIEventHandler()], ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index c1a6e462..504f2ddc 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -14,7 +14,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.global_event_handlers = [CLIEventHandler(), BufferEventHandler()] + dbally.event_handlers = [CLIEventHandler(), BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 380c00b5..fc1a94fd 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -23,13 +23,13 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -global_event_handlers: List[Callable] = [] +event_handlers: List[Callable] = [] __all__ = [ "__version__", "create_collection", "decorators", - "global_event_handlers", + "event_handlers", "BaseStructuredView", "Collection", "DataFrameBaseView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 8d161856..9f0cd501 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -13,7 +13,7 @@ def create_collection( name: str, llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, + 
override_event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: @@ -38,7 +38,7 @@ def create_collection( name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used to\ + override_event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ [LangSmithEventHandler](event_handlers/langsmith_handler.md). view_selector: View selector used by the collection to select the best view for the given query.\ @@ -55,12 +55,12 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or dbally.global_event_handlers + event_handlers = override_event_handlers or dbally.event_handlers return Collection( name, nl_responder=nl_responder, view_selector=view_selector, llm=llm, - event_handlers=event_handlers, + override_event_handlers=event_handlers, ) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 450386d5..accda82f 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -30,7 +30,7 @@ class CLIEventHandler(EventHandler): from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.index import dbally - dbally.global_event_handlers.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) my_collection = dbally.create_collection("my_collection", llm) ``` diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 22dfa50a..345ae49f 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -33,7 +33,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - event_handlers: List[EventHandler], + override_event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -45,7 +45,7 @@ def __init__( before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used\ + override_event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). nl_responder: Object that translates RAW response from db-ally into natural language. 
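At this point in the series the per-collection parameter is called `override_event_handlers`, while the module-level registry is `dbally.event_handlers` (a later commit in this series renames the parameter back to `event_handlers`). A minimal sketch of the two registration paths described in the docstrings above, assuming the patch-28 state of the API; the model and collection names are placeholders borrowed from the examples, not requirements:

```python
import dbally
from dbally.audit import CLIEventHandler
from dbally.llms.litellm import LiteLLM

llm = LiteLLM(model_name="gpt-3.5-turbo")

# Global path: collections created without an explicit handler list fall back
# to this module-level registry.
dbally.event_handlers.append(CLIEventHandler())
candidates = dbally.create_collection("candidates", llm)

# Override path: an explicit list replaces the global registry for this one
# collection; Collection.__init__ reports that the default was overridden.
freeform = dbally.create_collection(
    "freeform candidates", llm, override_event_handlers=[CLIEventHandler()]
)
```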
@@ -60,10 +60,11 @@ def __init__( self._view_selector = view_selector self._nl_responder = nl_responder self._llm = llm + event_handlers = override_event_handlers - if not event_handlers: - event_handlers = dbally.global_event_handlers - elif event_handlers != dbally.global_event_handlers: + if not override_event_handlers: + event_handlers = dbally.event_handlers + elif override_event_handlers != dbally.event_handlers: # At this moment, there is no event tracker initialized to record an event print(f"WARNING: Default event handler has been overwritten for {self.name}.") diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 17f63bdb..44d4039e 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -32,15 +32,15 @@ def find_event_buffer() -> Optional[BufferEventHandler]: """ Searches through global event handlers to find an instance of BufferEventHandler. - This function iterates over the list of global event handlers stored in `dbally.global_event_handlers`. + This function iterates over the list of global event handlers stored in `dbally.event_handlers`. It checks the type of each handler, and if it finds one that is an instance of `BufferEventHandler`, it returns that handler. If no such handler is found, the function returns `None`. Returns: The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. """ - for handler in dbally.global_event_handlers: - if type(handler) is BufferEventHandler: # pylint: disable=C0123 + for handler in dbally.event_handlers: + if isinstance(handler, BufferEventHandler): return handler return None @@ -62,7 +62,7 @@ def __init__(self): buffer_event_handler = find_event_buffer() if not buffer_event_handler: buffer_event_handler = BufferEventHandler() - dbally.global_event_handlers.append(buffer_event_handler) + dbally.event_handlers.append(buffer_event_handler) self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 3e5bddb5..04e13477 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -270,7 +270,7 @@ class ViewWithMockGenerator(MockViewBase): def get_iql_generator(self, *_, **__): return iql_generator - collection = Collection("foo", view_selector=Mock(), llm=MockLLM(), nl_responder=Mock(), event_handlers=[]) + collection = Collection("foo", view_selector=Mock(), llm=MockLLM(), nl_responder=Mock(), override_event_handlers=[]) collection.add(ViewWithMockGenerator) return collection @@ -320,7 +320,7 @@ async def test_ask_view_selection_single_view() -> None: view_selector=MockViewSelector(""), llm=MockLLM(), nl_responder=AsyncMock(), - event_handlers=[], + override_event_handlers=[], ) collection.add(MockViewWithResults) @@ -339,7 +339,7 @@ async def test_ask_view_selection_multiple_views() -> None: view_selector=MockViewSelector("MockViewWithResults"), llm=MockLLM(), nl_responder=AsyncMock(), - event_handlers=[], + override_event_handlers=[], ) collection.add(MockView1) collection.add(MockViewWithResults) @@ -360,7 +360,7 @@ async def test_ask_view_selection_no_views() -> None: view_selector=MockViewSelector(""), llm=MockLLM(), nl_responder=AsyncMock(), - event_handlers=[], + override_event_handlers=[], ) with pytest.raises(ValueError): From 2c5bbd0eb9ebe46e5188f51a646d52da8da780a7 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 10:12:27 +0200 Subject: [PATCH 29/64] fixups 
--- examples/freeform.py | 2 +- src/dbally/audit/event_handlers/cli_event_handler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/freeform.py b/examples/freeform.py index da3bea5e..bed615f9 100644 --- a/examples/freeform.py +++ b/examples/freeform.py @@ -63,7 +63,7 @@ async def main(): connection.execute(sqlalchemy.text(table_config.ddl)) llm = LiteLLM() - collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) + collection = dbally.create_collection("text2sql", llm=llm, override_event_handlers=[CLIEventHandler()]) collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) await collection.ask("What are the names of products bought by customers from London?") diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index accda82f..d6ae68d8 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -27,8 +27,8 @@ class CLIEventHandler(EventHandler): ### Usage ```python + import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - from dbally.index import dbally dbally.event_handlers.append(CLIEventHandler()) my_collection = dbally.create_collection("my_collection", llm) From 70141cc1a29e2aa0eaf35f672899fe4fa1edd8fc Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 10:30:32 +0200 Subject: [PATCH 30/64] fixup --- src/dbally/collection/collection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 345ae49f..7aac2bff 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -1,3 +1,5 @@ +import warnings + import asyncio import inspect import textwrap @@ -66,7 +68,7 @@ def __init__( event_handlers = dbally.event_handlers elif override_event_handlers != dbally.event_handlers: # At this moment, there is no event tracker initialized to record an event - print(f"WARNING: Default event handler has been overwritten for {self.name}.") + warnings.warn("Default event handler has been overwritten for {self.name}.") self._event_handlers = event_handlers From 272a5236dad7f840b1449e7532cc5a9427c7206f Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 10:37:29 +0200 Subject: [PATCH 31/64] fixup --- src/dbally/collection/collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 7aac2bff..f3f9abbb 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -1,9 +1,8 @@ -import warnings - import asyncio import inspect import textwrap import time +import warnings from collections import defaultdict from typing import Callable, Dict, List, Optional, Type, TypeVar From 068c345b06e4a5080d375fe627f69050d6f560ba Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 13:44:31 +0200 Subject: [PATCH 32/64] fixups --- examples/freeform.py | 2 +- examples/recruiting.py | 2 +- src/dbally/_main.py | 11 ++++++----- src/dbally/collection/collection.py | 9 ++++----- tests/unit/test_collection.py | 8 ++++---- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/freeform.py b/examples/freeform.py index bed615f9..da3bea5e 100644 --- a/examples/freeform.py +++ b/examples/freeform.py @@ -63,7 +63,7 @@ async def main(): 
connection.execute(sqlalchemy.text(table_config.ddl)) llm = LiteLLM() - collection = dbally.create_collection("text2sql", llm=llm, override_event_handlers=[CLIEventHandler()]) + collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) await collection.ask("What are the names of products bought by customers from London?") diff --git a/examples/recruiting.py b/examples/recruiting.py index 3f4c14f0..a4813b41 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -102,7 +102,7 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example recruitment_db = dbally.create_collection( "recruitment", llm=LiteLLM(), - override_event_handlers=[CLIEventHandler()], + event_handlers=[CLIEventHandler()], ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 9f0cd501..9167fa0e 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -13,7 +13,7 @@ def create_collection( name: str, llm: LLM, - override_event_handlers: Optional[List[EventHandler]] = None, + event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: @@ -38,9 +38,10 @@ def create_collection( name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. llm: LLM used by the collection to generate responses for natural language queries. - override_event_handlers: Event handlers used by the collection during query executions. Can be used to\ + event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ - [LangSmithEventHandler](event_handlers/langsmith_handler.md). + [LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the + global dbally.event_handlers view_selector: View selector used by the collection to select the best view for the given query.\ If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ will be used. @@ -55,12 +56,12 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = override_event_handlers or dbally.event_handlers + event_handlers = event_handlers or dbally.event_handlers return Collection( name, nl_responder=nl_responder, view_selector=view_selector, llm=llm, - override_event_handlers=event_handlers, + event_handlers=event_handlers, ) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index f3f9abbb..c7ed86c5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -34,7 +34,7 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - override_event_handlers: List[EventHandler], + event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, ) -> None: @@ -46,7 +46,7 @@ def __init__( before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. - override_event_handlers: Event handlers used by the collection during query executions. 
Can be used\ + event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). nl_responder: Object that translates RAW response from db-ally into natural language. @@ -61,11 +61,10 @@ def __init__( self._view_selector = view_selector self._nl_responder = nl_responder self._llm = llm - event_handlers = override_event_handlers - if not override_event_handlers: + if not event_handlers: event_handlers = dbally.event_handlers - elif override_event_handlers != dbally.event_handlers: + elif event_handlers != dbally.event_handlers: # At this moment, there is no event tracker initialized to record an event warnings.warn("Default event handler has been overwritten for {self.name}.") diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 04e13477..3e5bddb5 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -270,7 +270,7 @@ class ViewWithMockGenerator(MockViewBase): def get_iql_generator(self, *_, **__): return iql_generator - collection = Collection("foo", view_selector=Mock(), llm=MockLLM(), nl_responder=Mock(), override_event_handlers=[]) + collection = Collection("foo", view_selector=Mock(), llm=MockLLM(), nl_responder=Mock(), event_handlers=[]) collection.add(ViewWithMockGenerator) return collection @@ -320,7 +320,7 @@ async def test_ask_view_selection_single_view() -> None: view_selector=MockViewSelector(""), llm=MockLLM(), nl_responder=AsyncMock(), - override_event_handlers=[], + event_handlers=[], ) collection.add(MockViewWithResults) @@ -339,7 +339,7 @@ async def test_ask_view_selection_multiple_views() -> None: view_selector=MockViewSelector("MockViewWithResults"), llm=MockLLM(), nl_responder=AsyncMock(), - override_event_handlers=[], + event_handlers=[], ) collection.add(MockView1) collection.add(MockViewWithResults) @@ -360,7 +360,7 @@ async def test_ask_view_selection_no_views() -> None: view_selector=MockViewSelector(""), llm=MockLLM(), nl_responder=AsyncMock(), - override_event_handlers=[], + event_handlers=[], ) with pytest.raises(ValueError): From d8466ba11dc9f4be58f893fbf39f156cabd6ed4b Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 13:49:18 +0200 Subject: [PATCH 33/64] global event handlers --- src/dbally/audit/event_handlers/buffer_event_handler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dbally/audit/event_handlers/buffer_event_handler.py b/src/dbally/audit/event_handlers/buffer_event_handler.py index 4ef2eef4..571db8d8 100644 --- a/src/dbally/audit/event_handlers/buffer_event_handler.py +++ b/src/dbally/audit/event_handlers/buffer_event_handler.py @@ -14,9 +14,8 @@ class BufferEventHandler(CLIEventHandler): ```python import dbally - from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - dbally.global_handlers.set_event_handlers([BufferEventHandler()]) + dbally.global_handlers=[BufferEventHandler()] my_collection = dbally.create_collection("my_collection", llm) ``` """ From 18b1637f933ab36bf3f0daff469e68d022d774cd Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 13:59:30 +0200 Subject: [PATCH 34/64] cirucalr --- src/dbally/__init__.py | 5 +---- src/dbally/_main.py | 6 ++++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index fc1a94fd..6eab49bc 100644 --- 
a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -1,6 +1,5 @@ """ dbally """ -from typing import Callable, List from dbally.collection.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError @@ -12,7 +11,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import create_collection +from ._main import create_collection, event_handlers from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, @@ -23,8 +22,6 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers: List[Callable] = [] - __all__ = [ "__version__", "create_collection", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 9167fa0e..c3e6a255 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Callable, List, Optional import dbally @@ -9,11 +9,13 @@ from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector +event_handlers: List[Callable] = [] + def create_collection( name: str, llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, + event_handlers: Optional[List[EventHandler]] = None, # pylint: disable=redefined-outer-name view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: From f30c965ba4e63513a5d0c1f7a6e25cf656e83422 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 17:32:50 +0200 Subject: [PATCH 35/64] pylint check --- src/dbally/__init__.py | 5 ++++- src/dbally/_main.py | 8 +++----- src/dbally/audit/event_handlers/cli_event_handler.py | 3 +++ .../audit/event_handlers/langsmith_event_handler.py | 1 + src/dbally/view_selection/random_view_selector.py | 4 ++-- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 6eab49bc..fc1a94fd 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -1,5 +1,6 @@ """ dbally """ +from typing import Callable, List from dbally.collection.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError @@ -11,7 +12,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import create_collection, event_handlers +from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, @@ -22,6 +23,8 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError +event_handlers: List[Callable] = [] + __all__ = [ "__version__", "create_collection", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index c3e6a255..0e084816 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,6 +1,6 @@ -from typing import Callable, List, Optional +from typing import List, Optional -import dbally +from dbally import event_handlers as global_event_handlers from .audit.event_handlers.base import EventHandler from .collection import Collection @@ -9,8 +9,6 @@ from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector -event_handlers: List[Callable] = [] - def create_collection( name: str, @@ -58,7 +56,7 @@ def create_collection( """ view_selector = view_selector or 
LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or dbally.event_handlers + event_handlers = event_handlers or global_event_handlers return Collection( name, diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index d6ae68d8..6137cafa 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -65,6 +65,7 @@ async def request_start(self, user_request: RequestStart) -> None: self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================\n") + # pylint: disable=unused-argument async def event_start(self, event: Event, request_context: None) -> None: """ Displays information that event has started, then all messages inside the prompt @@ -94,6 +95,7 @@ async def event_start(self, event: Event, request_context: None) -> None: f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) + # pylint: disable=unused-argument async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None: """ Displays the response from the LLM. @@ -112,6 +114,7 @@ async def event_end(self, event: Optional[Event], request_context: None, event_c self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================\n") + # pylint: disable=unused-argument async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None: """ Displays the output of the request, namely the `results` and the `context` diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py index c0b619c2..89394f8d 100644 --- a/src/dbally/audit/event_handlers/langsmith_event_handler.py +++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py @@ -79,6 +79,7 @@ async def event_start(self, event: Event, request_context: RunTree) -> RunTree: raise ValueError("Unsupported event") + # pylint: disable=unused-argument async def event_end(self, event: Optional[Event], request_context: RunTree, event_context: RunTree) -> None: """ Log the end of the event. diff --git a/src/dbally/view_selection/random_view_selector.py b/src/dbally/view_selection/random_view_selector.py index 61dce39d..cd85e978 100644 --- a/src/dbally/view_selection/random_view_selector.py +++ b/src/dbally/view_selection/random_view_selector.py @@ -15,8 +15,8 @@ async def select_view( self, question: str, views: Dict[str, str], - event_tracker: EventTracker, - llm_options: Optional[LLMOptions] = None, + event_tracker: EventTracker, # pylint: disable=unused-argument + llm_options: Optional[LLMOptions] = None, # pylint: disable=unused-argument ) -> str: """ Dummy implementation returning random view. 
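The CLI handler hunks above expose the four async hooks an `EventHandler` receives: `request_start`, `event_start`, `event_end`, and `request_end`. A toy custom handler patterned on those signatures; the event parameters are left untyped because only the `RequestStart`/`RequestEnd` import path is visible in these diffs, and the base class may define requirements not shown here:

```python
from typing import Optional

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.events import RequestEnd, RequestStart


class PrintEventHandler(EventHandler):
    """Toy handler mirroring the hook points shown in the CLI handler diff."""

    async def request_start(self, user_request: RequestStart) -> None:
        # RequestStart carries the user's question, as used by CLIEventHandler.
        print(f"request: {user_request.question}")

    async def event_start(self, event, request_context) -> None:
        # Concrete event classes are not shown in these diffs, so the
        # parameter is left untyped.
        print(f"event started: {type(event).__name__}")

    async def event_end(self, event, request_context, event_context) -> None:
        print("event finished")

    async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None:
        print("request finished")
```

Appending an instance to the global registry, as the quickstart diffs do with `CLIEventHandler`, would make it observe every collection created afterwards.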
From d1aa01299cd69f2576de3509fcc14fc6763aff22 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 18:16:41 +0200 Subject: [PATCH 36/64] fixups --- docs/quickstart/quickstart2_code.py | 2 +- docs/quickstart/quickstart3_code.py | 2 +- docs/quickstart/quickstart_code.py | 2 +- examples/visualize_views_code.py | 2 +- src/dbally/__init__.py | 4 ++-- src/dbally/_main.py | 4 ++-- src/dbally/collection/collection.py | 4 ++-- src/dbally/gradio/gradio_interface.py | 4 ++-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 2384e865..22c9b5f6 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -77,7 +77,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): - dbally.event_handlers.append(CLIEventHandler()) + dbally.event_handlers_list.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index ef83a1bd..62bd4b88 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -126,7 +126,7 @@ def display_results(result: ExecutionResult): async def main(): - dbally.event_handlers.append(CLIEventHandler()) + dbally.event_handlers_list.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index a565350d..1bfef4f4 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -58,7 +58,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.event_handlers.append(CLIEventHandler()) + dbally.event_handlers_list.append(CLIEventHandler()) collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index 504f2ddc..a1bdc68d 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -14,7 +14,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.event_handlers = [CLIEventHandler(), BufferEventHandler()] + dbally.event_handlers_list = [CLIEventHandler(), BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index fc1a94fd..c5277f0d 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -23,13 +23,13 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers: List[Callable] = [] +event_handlers_list: List[Callable] = [] __all__ = [ "__version__", "create_collection", "decorators", - "event_handlers", + "event_handlers_list", "BaseStructuredView", "Collection", "DataFrameBaseView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 0e084816..1e54c523 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,6 +1,6 @@ from typing import List, Optional -from dbally import event_handlers as global_event_handlers +import 
dbally from .audit.event_handlers.base import EventHandler from .collection import Collection @@ -56,7 +56,7 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or global_event_handlers + event_handlers = event_handlers or dbally.event_handlers_list return Collection( name, diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c7ed86c5..a19e1ada 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -63,8 +63,8 @@ def __init__( self._llm = llm if not event_handlers: - event_handlers = dbally.event_handlers - elif event_handlers != dbally.event_handlers: + event_handlers = dbally.event_handlers_list + elif event_handlers != dbally.event_handlers_list: # At this moment, there is no event tracker initialized to record an event warnings.warn("Default event handler has been overwritten for {self.name}.") diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 44d4039e..4cae4925 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -39,7 +39,7 @@ def find_event_buffer() -> Optional[BufferEventHandler]: Returns: The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. """ - for handler in dbally.event_handlers: + for handler in dbally.event_handlers_list: if isinstance(handler, BufferEventHandler): return handler return None @@ -62,7 +62,7 @@ def __init__(self): buffer_event_handler = find_event_buffer() if not buffer_event_handler: buffer_event_handler = BufferEventHandler() - dbally.event_handlers.append(buffer_event_handler) + dbally.event_handlers_list.append(buffer_event_handler) self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member From 4552d852389e7660aae4632a066ec477783858ef Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 18:23:33 +0200 Subject: [PATCH 37/64] comment fixups --- src/dbally/_main.py | 4 ++-- src/dbally/audit/event_handlers/buffer_event_handler.py | 2 +- src/dbally/audit/event_handlers/cli_event_handler.py | 2 +- src/dbally/view_selection/random_view_selector.py | 5 +++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 1e54c523..b22854b1 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -13,7 +13,7 @@ def create_collection( name: str, llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, # pylint: disable=redefined-outer-name + event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, ) -> Collection: @@ -41,7 +41,7 @@ def create_collection( event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ [LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the - global dbally.event_handlers + global dbally.event_handlers_list view_selector: View selector used by the collection to select the best view for the given query.\ If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ will be used. 
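The Gradio hunks above scan the temporarily renamed `dbally.event_handlers_list` registry for a `BufferEventHandler`, appending one lazily so its buffer can back the interface's log pane; the revert a couple of commits below restores the `dbally.event_handlers` name. A sketch of the wiring the visualize examples rely on at this point in the series, assuming it runs from the examples directory so the recruiting view and engine resolve:

```python
import asyncio

from recruiting.candidate_view_with_similarity_store import CandidateView, engine

import dbally
from dbally.audit import CLIEventHandler
from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler
from dbally.gradio import create_gradio_interface
from dbally.llms.litellm import LiteLLM


async def build_ui():
    # Registering the buffer handler up front lets find_event_buffer() reuse it
    # instead of appending a fresh one; the CLI handler keeps console output.
    dbally.event_handlers_list = [CLIEventHandler(), BufferEventHandler()]

    llm = LiteLLM(model_name="gpt-3.5-turbo")
    collection = dbally.create_collection("candidates", llm)
    collection.add(CandidateView, lambda: CandidateView(engine))

    # The returned Gradio interface can then be launched by the caller.
    return await create_gradio_interface(collection)


if __name__ == "__main__":
    asyncio.run(build_ui())
```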
diff --git a/src/dbally/audit/event_handlers/buffer_event_handler.py b/src/dbally/audit/event_handlers/buffer_event_handler.py index 571db8d8..b0ef18e2 100644 --- a/src/dbally/audit/event_handlers/buffer_event_handler.py +++ b/src/dbally/audit/event_handlers/buffer_event_handler.py @@ -15,7 +15,7 @@ class BufferEventHandler(CLIEventHandler): ```python import dbally - dbally.global_handlers=[BufferEventHandler()] + dbally.event_handlers_list=[BufferEventHandler()] my_collection = dbally.create_collection("my_collection", llm) ``` """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 6137cafa..af105e60 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -30,7 +30,7 @@ class CLIEventHandler(EventHandler): import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - dbally.event_handlers.append(CLIEventHandler()) + dbally.event_handlers_list.append(CLIEventHandler()) my_collection = dbally.create_collection("my_collection", llm) ``` diff --git a/src/dbally/view_selection/random_view_selector.py b/src/dbally/view_selection/random_view_selector.py index cd85e978..0252b9fd 100644 --- a/src/dbally/view_selection/random_view_selector.py +++ b/src/dbally/view_selection/random_view_selector.py @@ -11,12 +11,13 @@ class RandomViewSelector(ViewSelector): Mock View Selector selecting a random view. """ + # pylint: disable=unused-argument async def select_view( self, question: str, views: Dict[str, str], - event_tracker: EventTracker, # pylint: disable=unused-argument - llm_options: Optional[LLMOptions] = None, # pylint: disable=unused-argument + event_tracker: EventTracker, + llm_options: Optional[LLMOptions] = None, ) -> str: """ Dummy implementation returning random view. From db3fc6f90e2e00d06784780dc121456b93bfb3af Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 19:31:20 +0200 Subject: [PATCH 38/64] Revert "fixups" This reverts commit d1aa01299cd69f2576de3509fcc14fc6763aff22. 
--- docs/quickstart/quickstart2_code.py | 2 +- docs/quickstart/quickstart3_code.py | 2 +- docs/quickstart/quickstart_code.py | 2 +- examples/visualize_views_code.py | 2 +- src/dbally/__init__.py | 4 ++-- src/dbally/_main.py | 4 ++-- src/dbally/collection/collection.py | 4 ++-- src/dbally/gradio/gradio_interface.py | 4 ++-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 22c9b5f6..2384e865 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -77,7 +77,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): - dbally.event_handlers_list.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 62bd4b88..ef83a1bd 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -126,7 +126,7 @@ def display_results(result: ExecutionResult): async def main(): - dbally.event_handlers_list.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 1bfef4f4..a565350d 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -58,7 +58,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.event_handlers_list.append(CLIEventHandler()) + dbally.event_handlers.append(CLIEventHandler()) collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py index a1bdc68d..504f2ddc 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_views_code.py @@ -14,7 +14,7 @@ async def main(): await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - dbally.event_handlers_list = [CLIEventHandler(), BufferEventHandler()] + dbally.event_handlers = [CLIEventHandler(), BufferEventHandler()] collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index c5277f0d..fc1a94fd 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -23,13 +23,13 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers_list: List[Callable] = [] +event_handlers: List[Callable] = [] __all__ = [ "__version__", "create_collection", "decorators", - "event_handlers_list", + "event_handlers", "BaseStructuredView", "Collection", "DataFrameBaseView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index b22854b1..b953f7fb 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,6 +1,6 @@ from typing import List, Optional -import dbally +from dbally import event_handlers as global_event_handlers from .audit.event_handlers.base import EventHandler from .collection import Collection @@ -56,7 +56,7 @@ def create_collection( """ view_selector = 
view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or dbally.event_handlers_list + event_handlers = event_handlers or global_event_handlers return Collection( name, diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index a19e1ada..c7ed86c5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -63,8 +63,8 @@ def __init__( self._llm = llm if not event_handlers: - event_handlers = dbally.event_handlers_list - elif event_handlers != dbally.event_handlers_list: + event_handlers = dbally.event_handlers + elif event_handlers != dbally.event_handlers: # At this moment, there is no event tracker initialized to record an event warnings.warn("Default event handler has been overwritten for {self.name}.") diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 4cae4925..44d4039e 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -39,7 +39,7 @@ def find_event_buffer() -> Optional[BufferEventHandler]: Returns: The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. """ - for handler in dbally.event_handlers_list: + for handler in dbally.event_handlers: if isinstance(handler, BufferEventHandler): return handler return None @@ -62,7 +62,7 @@ def __init__(self): buffer_event_handler = find_event_buffer() if not buffer_event_handler: buffer_event_handler = BufferEventHandler() - dbally.event_handlers_list.append(buffer_event_handler) + dbally.event_handlers.append(buffer_event_handler) self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member From 94ab6c70ad7a413d2e13cbe1dc9d061dd214a02b Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 20:19:14 +0200 Subject: [PATCH 39/64] move create collections to collections --- src/dbally/__init__.py | 3 +- src/dbally/_main.py | 67 ----------------------------- src/dbally/collection/__init__.py | 3 +- src/dbally/collection/collection.py | 58 +++++++++++++++++++++++++ tests/unit/test_collection.py | 3 +- 5 files changed, 62 insertions(+), 72 deletions(-) delete mode 100644 src/dbally/_main.py diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index fc1a94fd..754e897c 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -2,7 +2,7 @@ from typing import Callable, List -from dbally.collection.collection import Collection +from dbally.collection.collection import Collection, create_collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.views import decorators @@ -12,7 +12,6 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .embeddings.exceptions import ( EmbeddingConnectionError, diff --git a/src/dbally/_main.py b/src/dbally/_main.py deleted file mode 100644 index b953f7fb..00000000 --- a/src/dbally/_main.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import List, Optional - -from dbally import event_handlers as global_event_handlers - -from .audit.event_handlers.base import EventHandler -from .collection import Collection -from .llms import LLM -from .nl_responder.nl_responder import NLResponder -from .view_selection.base import ViewSelector -from .view_selection.llm_view_selector import 
LLMViewSelector - - -def create_collection( - name: str, - llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, - view_selector: Optional[ViewSelector] = None, - nl_responder: Optional[NLResponder] = None, -) -> Collection: - """ - Create a new [Collection](collection.md) that is a container for registering views and the\ - main entrypoint to db-ally features. - - Unlike instantiating a [Collection][dbally.Collection] directly, this function\ - provides a set of default values for various dependencies like LLM client, view selector,\ - IQL generator, and NL responder. - - ##Example - - ```python - from dbally import create_collection - from dbally.llms.litellm import LiteLLM - - collection = create_collection("my_collection", llm=LiteLLM()) - ``` - - Args: - name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ - used to distinguish different db-ally runs. - llm: LLM used by the collection to generate responses for natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used to\ - log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ - [LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the - global dbally.event_handlers_list - view_selector: View selector used by the collection to select the best view for the given query.\ - If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ - will be used. - nl_responder: NL responder used by the collection to respond to natural language queries. If None,\ - a new instance of [NLResponder][dbally.nl_responder.nl_responder.NLResponder] will be used. - - Returns: - a new instance of db-ally Collection - - Raises: - ValueError: if default LLM client is not configured - """ - view_selector = view_selector or LLMViewSelector(llm=llm) - nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or global_event_handlers - - return Collection( - name, - nl_responder=nl_responder, - view_selector=view_selector, - llm=llm, - event_handlers=event_handlers, - ) diff --git a/src/dbally/collection/__init__.py b/src/dbally/collection/__init__.py index 66eea8fe..db041159 100644 --- a/src/dbally/collection/__init__.py +++ b/src/dbally/collection/__init__.py @@ -1,8 +1,9 @@ -from dbally.collection.collection import Collection +from dbally.collection.collection import Collection, create_collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult __all__ = [ + "create_collection", "Collection", "ExecutionResult", "ViewExecutionResult", diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c7ed86c5..9ab3b6f0 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -16,6 +16,7 @@ from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder from dbally.similarity.index import AbstractSimilarityIndex +from dbally.view_selection import LLMViewSelector from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation @@ -275,3 +276,60 @@ async def update_similarity_indexes(self) -> None: if failed_indexes: failed_locations = [loc for index in failed_indexes for loc in indexes[index]] raise IndexUpdateError(failed_indexes, failed_locations) + 
+ +def create_collection( + name: str, + llm: LLM, + event_handlers: Optional[List[EventHandler]] = None, + view_selector: Optional[ViewSelector] = None, + nl_responder: Optional[NLResponder] = None, +) -> Collection: + """ + Create a new [Collection](collection.md) that is a container for registering views and the\ + main entrypoint to db-ally features. + + Unlike instantiating a [Collection][dbally.Collection] directly, this function\ + provides a set of default values for various dependencies like LLM client, view selector,\ + IQL generator, and NL responder. + + ##Example + + ```python + from dbally import create_collection + from dbally.llms.litellm import LiteLLM + + collection = create_collection("my_collection", llm=LiteLLM()) + ``` + + Args: + name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ + used to distinguish different db-ally runs. + llm: LLM used by the collection to generate responses for natural language queries. + event_handlers: Event handlers used by the collection during query executions. Can be used to\ + log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ + [LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the + global dbally.event_handlers_list + view_selector: View selector used by the collection to select the best view for the given query.\ + If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ + will be used. + nl_responder: NL responder used by the collection to respond to natural language queries. If None,\ + a new instance of [NLResponder][dbally.nl_responder.nl_responder.NLResponder] will be used. + + Returns: + a new instance of db-ally Collection + + Raises: + ValueError: if default LLM client is not configured + """ + view_selector = view_selector or LLMViewSelector(llm=llm) + nl_responder = nl_responder or NLResponder(llm=llm) + event_handlers = event_handlers or dbally.event_handlers + + return Collection( + name, + nl_responder=nl_responder, + view_selector=view_selector, + llm=llm, + event_handlers=event_handlers, + ) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 3e5bddb5..e12ecec6 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -6,8 +6,7 @@ import pytest from typing_extensions import Annotated -from dbally._main import create_collection -from dbally.collection import Collection +from dbally.collection import Collection, create_collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult from dbally.iql._exceptions import IQLError From 219c52d2179d1c317d868a219399bcfdd8788ef6 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Wed, 3 Jul 2024 09:22:31 +0200 Subject: [PATCH 40/64] pre commit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73ebe267..f3a439d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,7 @@ repos: # Enforces a coding standard, looks for code smells, and can make suggestions about how the code could be refactored. 
- repo: https://github.com/pycqa/pylint - rev: v3.0.1 + rev: v3.1.0 hooks: - id: pylint exclude: (/test_|tests/|docs/) From 765f163eff8c2d6b3be49da527daaad8c23ce047 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 2 Jul 2024 09:17:32 +0200 Subject: [PATCH 41/64] adjustments --- .../candidate_view_with_similarity_store.py | 3 ++- examples/visualize_fallback_code.py | 8 +++----- src/dbally/collection/collection.py | 14 ++++++-------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/examples/recruiting/candidate_view_with_similarity_store.py b/examples/recruiting/candidate_view_with_similarity_store.py index f50c4545..4fb394a9 100644 --- a/examples/recruiting/candidate_view_with_similarity_store.py +++ b/examples/recruiting/candidate_view_with_similarity_store.py @@ -5,9 +5,10 @@ from sqlalchemy.ext.automap import automap_base from typing_extensions import Annotated -from dbally import SqlAlchemyBaseView, decorators from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher +from dbally.views import decorators +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index 6a5848e4..af34933a 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -2,21 +2,19 @@ import asyncio from recruiting import candidate_view_with_similarity_store, candidates_freeform -from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity +from recruiting.candidate_view_with_similarity_store import CandidateView from recruiting.candidates_freeform import CandidateFreeformView from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine import dbally -from dbally.audit import CLIEventHandler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM async def main(): - await country_similarity.update() llm = LiteLLM(model_name="gpt-3.5-turbo") - collection1 = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) - collection2 = dbally.create_collection("freeform candidates", llm, event_handlers=[]) + collection1 = dbally.create_collection("candidates", llm) + collection2 = dbally.create_collection("freeform candidates", llm) collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) collection2.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 5ef0a860..38cc1b9a 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -6,7 +6,6 @@ from typing import Callable, Dict, List, Optional, Type, TypeVar import dbally -from dbally import DbAllyError from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart @@ -20,6 +19,8 @@ from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation +HANDLED_EXCEPTION_TYPES = (NoViewFoundError, UnsupportedQueryError, IndexUpdateError) + class Collection: """ @@ -285,7 +286,7 @@ 
async def _handle_fallback( llm_options: Optional[LLMOptions], selected_view_name: str, event_tracker: EventTracker, - caught_exception: DbAllyError, + caught_exception: HANDLED_EXCEPTION_TYPES, ): """ Handle fallback if the main query fails. @@ -324,8 +325,7 @@ async def _handle_fallback( span(event) return result - else: - raise caught_exception + raise caught_exception async def ask( self, @@ -363,7 +363,6 @@ async def ask( IQLError: if incorrect IQL was generated `n_retries` amount of times. ValueError: if incorrect IQL was generated `n_retries` amount of times. """ - handle_exceptions = (NoViewFoundError, UnsupportedQueryError, IndexUpdateError) if not event_tracker: event_tracker = EventTracker.initialize_with_handlers(self._event_handlers) @@ -395,7 +394,7 @@ async def ask( textual_response=natural_response, ) - except handle_exceptions as caught_exception: + except HANDLED_EXCEPTION_TYPES as caught_exception: result = await self._handle_fallback( question=question, dry_run=dry_run, @@ -405,8 +404,7 @@ async def ask( event_tracker=event_tracker, caught_exception=caught_exception, ) - finally: - await event_tracker.request_end(RequestEnd(result=result)) + await event_tracker.request_end(RequestEnd(result=result)) return result From 1f0ecbb63e29656b4f10a18232b892c0bd92dff1 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 4 Jul 2024 12:53:59 +0200 Subject: [PATCH 42/64] review fixups --- src/dbally/__init__.py | 2 +- src/dbally/collection/collection.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 754e897c..f6c082ac 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -22,7 +22,7 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers: List[Callable] = [] +event_handlers: List = [] __all__ = [ "__version__", diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 9ab3b6f0..d01e0806 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -35,8 +35,8 @@ def __init__( name: str, view_selector: ViewSelector, llm: LLM, - event_handlers: List[EventHandler], nl_responder: NLResponder, + event_handlers: Optional[List[EventHandler]] = None, n_retries: int = 3, ) -> None: """ @@ -47,10 +47,10 @@ def __init__( before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). llm: LLM used by the collection to generate views and respond to natural language queries. + nl_responder: Object that translates RAW response from db-ally into natural language. event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). - nl_responder: Object that translates RAW response from db-ally into natural language. n_retries: IQL generator may produce invalid IQL. If this is the case this argument specifies\ how many times db-ally will try to regenerate it. Previous try with the error message is\ appended to the chat history to guide next generations. 
@@ -324,7 +324,6 @@ def create_collection( """ view_selector = view_selector or LLMViewSelector(llm=llm) nl_responder = nl_responder or NLResponder(llm=llm) - event_handlers = event_handlers or dbally.event_handlers return Collection( name, From efdca498520ffc6fb5cde4792304fe3bc284472a Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 4 Jul 2024 12:56:57 +0200 Subject: [PATCH 43/64] event handler type --- src/dbally/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index f6c082ac..d59b3ea7 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -1,6 +1,6 @@ """ dbally """ -from typing import Callable, List +from typing import List from dbally.collection.collection import Collection, create_collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError @@ -13,6 +13,7 @@ from .__version__ import __version__ from ._types import NOT_GIVEN, NotGiven +from .audit import EventHandler from .embeddings.exceptions import ( EmbeddingConnectionError, EmbeddingError, @@ -22,7 +23,7 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers: List = [] +event_handlers: List[EventHandler] = [] __all__ = [ "__version__", From 2d089179f888732620579e2516ac3ab91e075add Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 4 Jul 2024 12:59:57 +0200 Subject: [PATCH 44/64] Remove cyclic with EventHandlers typing --- src/dbally/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index d59b3ea7..6ae32ec4 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -13,7 +13,6 @@ from .__version__ import __version__ from ._types import NOT_GIVEN, NotGiven -from .audit import EventHandler from .embeddings.exceptions import ( EmbeddingConnectionError, EmbeddingError, @@ -23,7 +22,7 @@ from .exceptions import DbAllyError from .llms.clients.exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError -event_handlers: List[EventHandler] = [] +event_handlers: List = [] __all__ = [ "__version__", From e3b5fd7d6091ee94c4995d86f61de951bdbf172b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:31:34 +0200 Subject: [PATCH 45/64] chore: doggify project (#67) --- README.md | 2 +- docs/assets/guide_dog_lg.png | Bin 0 -> 15525 bytes docs/assets/guide_dog_sm.png | Bin 0 -> 2533 bytes docs/stylesheets/extra.css | 10 ++++++++++ mkdocs.yml | 2 ++ 5 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/assets/guide_dog_lg.png create mode 100644 docs/assets/guide_dog_sm.png diff --git a/README.md b/README.md index 9088bd15..0bbc2843 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -#
 db-ally
+# 🦮 db-ally

 Efficient, consistent and secure library for querying structured data with natural language

diff --git a/docs/assets/guide_dog_lg.png b/docs/assets/guide_dog_lg.png
new file mode 100644
index 0000000000000000000000000000000000000000..dee16c227311f84319c95ebc81fcc5e740eb5486
GIT binary patch
literal 15525
zdDIoLIZot2LH9KIc7B|l{~2sHT1qgQADv=YcV2%sS$3h^{ zVNxLDxJh24Nh2^|rqoyRKwv#%rR}j7#UQ%r!b_LO<)0^uwi<9QK)uJm%s)5jB*ST1m+sji4$Pz8{~R%<#F*vA7lW6UdG+4mB`>yXwb>f5%+{pgcHjWA25F3k$14WWaU0ez16N=sWCS>69?hvU!3iiLm;;eyt Date: Tue, 2 Jul 2024 13:33:40 +0000 Subject: [PATCH 46/64] refactor(prompts): prompt templates (#66) --- benchmark/dbally_benchmark/e2e_benchmark.py | 10 +- benchmark/dbally_benchmark/iql_benchmark.py | 14 +- .../text2sql/prompt_template.py | 2 +- docs/about/roadmap.md | 3 +- docs/how-to/llms/custom.md | 25 +- docs/how-to/views/few-shots.md | 97 ++++++++ docs/reference/collection.md | 2 - docs/reference/index.md | 1 - docs/reference/iql/iql_generator.md | 4 - docs/reference/nl_responder.md | 4 - docs/reference/prompt.md | 7 + .../view_selection/llm_view_selector.md | 2 - examples/recruiting.py | 39 ++- examples/recruiting/views.py | 2 +- mkdocs.yml | 2 + src/dbally/assistants/openai.py | 2 +- src/dbally/audit/events.py | 2 +- src/dbally/gradio/gradio_interface.py | 4 +- src/dbally/iql/_query.py | 14 +- src/dbally/iql_generator/iql_generator.py | 100 ++++---- .../iql_generator/iql_prompt_template.py | 69 ------ src/dbally/iql_generator/prompt.py | 87 +++++++ src/dbally/llms/base.py | 42 +--- src/dbally/llms/clients/base.py | 10 +- src/dbally/llms/clients/litellm.py | 16 +- src/dbally/llms/litellm.py | 17 +- src/dbally/nl_responder/nl_responder.py | 97 +++----- .../nl_responder_prompt_template.py | 47 ---- src/dbally/nl_responder/prompts.py | 111 +++++++++ .../query_explainer_prompt_template.py | 48 ---- src/dbally/prompt/__init__.py | 3 + src/dbally/{prompts => prompt}/elements.py | 2 +- src/dbally/prompt/template.py | 234 ++++++++++++++++++ src/dbally/prompts/__init__.py | 4 - src/dbally/prompts/common_validation_utils.py | 51 ---- src/dbally/prompts/formatters.py | 119 --------- src/dbally/prompts/prompt_template.py | 83 ------- .../view_selection/llm_view_selector.py | 44 +--- src/dbally/view_selection/prompt.py | 52 ++++ .../view_selector_prompt_template.py | 51 ---- src/dbally/views/base.py | 2 +- src/dbally/views/freeform/text2sql/prompt.py | 61 +++++ src/dbally/views/freeform/text2sql/view.py | 55 ++-- src/dbally/views/structured.py | 41 +-- src/dbally_codegen/autodiscovery.py | 93 +++++-- tests/integration/test_llm_options.py | 12 +- tests/unit/mocks.py | 9 +- tests/unit/test_assistants_adapters.py | 2 +- tests/unit/test_collection.py | 46 +--- tests/unit/test_fewshot.py | 27 +- tests/unit/test_iql_format.py | 147 ++++++----- tests/unit/test_iql_generator.py | 112 +++++---- tests/unit/test_prompt_builder.py | 149 +++++------ tests/unit/test_view_selector.py | 2 +- 54 files changed, 1200 insertions(+), 1081 deletions(-) create mode 100644 docs/how-to/views/few-shots.md create mode 100644 docs/reference/prompt.md delete mode 100644 src/dbally/iql_generator/iql_prompt_template.py create mode 100644 src/dbally/iql_generator/prompt.py delete mode 100644 src/dbally/nl_responder/nl_responder_prompt_template.py create mode 100644 src/dbally/nl_responder/prompts.py delete mode 100644 src/dbally/nl_responder/query_explainer_prompt_template.py create mode 100644 src/dbally/prompt/__init__.py rename src/dbally/{prompts => prompt}/elements.py (97%) create mode 100644 src/dbally/prompt/template.py delete mode 100644 src/dbally/prompts/__init__.py delete mode 100644 src/dbally/prompts/common_validation_utils.py delete mode 100644 src/dbally/prompts/formatters.py delete mode 100644 src/dbally/prompts/prompt_template.py create mode 100644 src/dbally/view_selection/prompt.py delete mode 100644 
src/dbally/view_selection/view_selector_prompt_template.py create mode 100644 src/dbally/views/freeform/text2sql/prompt.py diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index 9ba0871c..aa686727 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -23,9 +23,9 @@ import dbally from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM -from dbally.view_selection.view_selector_prompt_template import default_view_selector_template +from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult: @@ -126,9 +126,9 @@ async def evaluate(cfg: DictConfig) -> Any: logger.info(f"db-ally predictions saved under directory: {output_dir}") if run: - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat) - run["config/view_selection_prompt_template"] = stringify_unsupported(default_view_selector_template.chat) - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat) + run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE) run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix()) run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix()) run["evaluation/metrics"] = stringify_unsupported(metrics) diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py index 7bb2ae28..2557b2c2 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -21,9 +21,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM -from dbally.prompts.formatters import IQLInputFormatter from dbally.views.structured import BaseStructuredView @@ -32,14 +31,17 @@ async def _run_iql_for_single_example( ) -> IQLResult: filter_list = view.list_filters() event_tracker = EventTracker() - input_formatter = IQLInputFormatter(question=example.question, filters=filter_list) try: - iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker) + iql_filters = await iql_generator.generate_iql( + question=example.question, + filters=filter_list, + event_tracker=event_tracker, + ) except UnsupportedQueryError: return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True) - return IQLResult(question=example.question, iql_filters=iql_filters, exception_raised=False) + return IQLResult(question=example.question, iql_filters=str(iql_filters), exception_raised=False) async def run_iql_for_dataset( @@ -139,7 +141,7 @@ async def evaluate(cfg: DictConfig) -> Any: logger.info(f"IQL 
predictions saved under directory: {output_dir}") if run: - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat) run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix()) run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix()) run["evaluation/metrics"] = stringify_unsupported(metrics) diff --git a/benchmark/dbally_benchmark/text2sql/prompt_template.py b/benchmark/dbally_benchmark/text2sql/prompt_template.py index abee9659..60349f38 100644 --- a/benchmark/dbally_benchmark/text2sql/prompt_template.py +++ b/benchmark/dbally_benchmark/text2sql/prompt_template.py @@ -1,4 +1,4 @@ -from dbally.prompts import PromptTemplate +from dbally.prompt import PromptTemplate TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( ( diff --git a/docs/about/roadmap.md b/docs/about/roadmap.md index f6449c88..288aa359 100644 --- a/docs/about/roadmap.md +++ b/docs/about/roadmap.md @@ -10,14 +10,13 @@ Below you can find a list of planned features and integrations. ## Planned Features - [ ] **Support analytical queries**: support for exposing operations beyond filtering. -- [ ] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to +- [x] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to improve IQL generation accuracy. - [ ] **Request contextualization**: allow to provide extra context for db-ally runs, such as user asking the question. - [X] **OpenAI Assistants API adapter**: allow to embed db-ally into OpenAI's Assistants API to easily extend the capabilities of the assistant. - [ ] **Langchain adapter**: allow to embed db-ally into Langchain applications. - ## Integrations Being agnostic to the underlying technology is one of the main goals of db-ally. diff --git a/docs/how-to/llms/custom.md b/docs/how-to/llms/custom.md index c262351d..7e249847 100644 --- a/docs/how-to/llms/custom.md +++ b/docs/how-to/llms/custom.md @@ -44,42 +44,29 @@ class MyLLMClient(LLMClient[LiteLLMOptions]): async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LiteLLMOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: # Your LLM API call ``` -The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response. +The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response in string format. ### Step 3: Use tokenizer to count tokens -The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the messages. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model. +The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the prompt. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model. 
 ```python
 class MyLLM(LLM[LiteLLMOptions]):
 
-    def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
-        # Count tokens in the messages in a custom way
+    def count_tokens(self, prompt: PromptTemplate) -> int:
+        # Count tokens in the prompt in a custom way
 ```
 
 !!!warning
     Incorrect token counting can cause problems in the [`NLResponder`](../../reference/nl_responder.md#dbally.nl_responder.nl_responder.NLResponder) and force the use of an explanation prompt template that is more generic and does not include specific rows from the IQL response.
 
-### Step 4: Define custom prompt formatting
-
-The [`format_prompt`](../../reference/llms/index.md#dbally.llms.base.LLM.format_prompt) method is used to apply formatting to the prompt template. You can override this method in your custom class to change how the formatting is performed.
-
-```python
-class MyLLM(LLM[LiteLLMOptions]):
-
-    def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat:
-        # Apply custom formatting to the prompt template
-```
-!!!note
-    In general, implementation of this method is not required unless the LLM API does not support [OpenAI conversation formatting](https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages){:target="_blank"}. If your model API expects a different format, override this method to avoid issues with inference call.
-
 ## Customising LLM Options
 
 [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) is a class that defines the options your LLM will use. To create a custom options, you need to create a subclass of [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) and define the required properties that will be passed to the [`LLMClient`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient).
diff --git a/docs/how-to/views/few-shots.md b/docs/how-to/views/few-shots.md
new file mode 100644
index 00000000..806ab171
--- /dev/null
+++ b/docs/how-to/views/few-shots.md
@@ -0,0 +1,97 @@
+# How-To: Define few shots
+
+There are many ways to improve the accuracy of IQL generation - one of them is to use few-shot prompting. db-ally allows you to inject few-shot examples for any type of defined view, both structured and freeform.
+
+Few-shot examples are defined in the [`list_few_shots`](../../reference/views/index.md#dbally.views.base.BaseView.list_few_shots) method. Each example should be an instance of the [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) class, which defines an example question and the expected LLM answer.
+
+## Structured views
+
+For structured views, both questions and answers for [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) can be defined as strings; for answers, Python expressions are also allowed (see the lambda function in the example below).
+
+```python
+from dbally.prompt.elements import FewShotExample
+from dbally.views.sqlalchemy_base import SqlAlchemyBaseView
+
+class RecruitmentView(SqlAlchemyBaseView):
+    """
+    A view for retrieving candidates from the database.
+    """
+
+    def list_few_shots(self) -> List[FewShotExample]:
+        return [
+            FewShotExample(
+                "Which candidates studied at University of Toronto?",
+                'studied_at("University of Toronto")',
+            ),
+            FewShotExample(
+                "Do we have any soon available perfect fits for senior data scientist positions?",
+                lambda: (
+                    self.is_available_within_months(1)
+                    and self.data_scientist_position()
+                    and self.has_seniority("senior")
+                ),
+            ),
+            ...
+        ]
+```
+
+## Freeform views
+
+Currently, freeform views accept SQL query syntax as a raw string. Support for a wider variety of parameter formats is planned for future db-ally releases.
+
+```python
+from dbally.prompt.elements import FewShotExample
+from dbally.views.freeform.text2sql import BaseText2SQLView
+
+class RecruitmentView(BaseText2SQLView):
+    """
+    A view for retrieving candidates from the database.
+    """
+
+    def list_few_shots(self) -> List[FewShotExample]:
+        return [
+            FewShotExample(
+                "Which candidates studied at University of Toronto?",
+                'SELECT name FROM candidates WHERE university = "University of Toronto"',
+            ),
+            FewShotExample(
+                "Which clients are from NY?",
+                'SELECT name FROM clients WHERE city = "NY"',
+            ),
+            ...
+        ]
+```
+
+## Prompt format
+
+By default, each few-shot example is injected right after the system prompt message. The format is as follows:
+
+```python
+[
+    {
+        "role": "user",
+        "content": "Question",
+    },
+    {
+        "role": "assistant",
+        "content": "Answer",
+    }
+]
+```
+
+If you use the `examples` formatting tag in the content field of the system or user message, all examples are injected inside that message instead of being added as extra conversation turns.
+
+An example of a prompt using the `examples` tag:
+
+```python
+[
+    {
+        "role": "system",
+        "content": "Here are example responses:\n {examples}",
+    },
+]
+```
+
+!!!info
+    There is no single best way to inject few-shot examples. Different models can behave differently depending on the chosen few-shot formatting.
+    Generally, the first approach should yield the best results in most cases. Therefore, adding example tags in your custom prompts is not recommended.
diff --git a/docs/reference/collection.md b/docs/reference/collection.md
index cb9b4b97..c7b7269a 100644
--- a/docs/reference/collection.md
+++ b/docs/reference/collection.md
@@ -3,8 +3,6 @@
 !!! tip
     To understand the general idea better, visit the [Collection concept page](../concepts/collections.md).
 
-::: dbally.create_collection
-
 ::: dbally.collection.Collection
 
 ::: dbally.collection.results.ExecutionResult
diff --git a/docs/reference/index.md b/docs/reference/index.md
index 0deb591a..fa1abc4f 100644
--- a/docs/reference/index.md
+++ b/docs/reference/index.md
@@ -1,4 +1,3 @@
 # dbally
 
-
 ::: dbally.create_collection
diff --git a/docs/reference/iql/iql_generator.md b/docs/reference/iql/iql_generator.md
index 15edcb56..b91a0b0c 100644
--- a/docs/reference/iql/iql_generator.md
+++ b/docs/reference/iql/iql_generator.md
@@ -1,7 +1,3 @@
 # IQLGenerator
 
 ::: dbally.iql_generator.iql_generator.IQLGenerator
-
-::: dbally.iql_generator.iql_prompt_template.IQLPromptTemplate
-
-::: dbally.iql_generator.iql_prompt_template.default_iql_template
diff --git a/docs/reference/nl_responder.md b/docs/reference/nl_responder.md
index fb80741c..531243de 100644
--- a/docs/reference/nl_responder.md
+++ b/docs/reference/nl_responder.md
@@ -26,7 +26,3 @@ Otherwise, a response is generated using a `nl_responder_prompt_template`.
     To understand general idea better, visit the [NL Responder concept page](../concepts/nl_responder.md).
 
 ::: dbally.nl_responder.nl_responder.NLResponder
-
-::: dbally.nl_responder.query_explainer_prompt_template
-
-::: dbally.nl_responder.nl_responder_prompt_template.default_nl_responder_template
diff --git a/docs/reference/prompt.md b/docs/reference/prompt.md
new file mode 100644
index 00000000..42ab8901
--- /dev/null
+++ b/docs/reference/prompt.md
@@ -0,0 +1,7 @@
+# Prompt
+
+::: dbally.prompt.template.PromptTemplate
+
+::: dbally.prompt.template.PromptFormat
+
+::: dbally.prompt.elements.FewShotExample
diff --git a/docs/reference/view_selection/llm_view_selector.md b/docs/reference/view_selection/llm_view_selector.md
index 774aa4b9..a177a8bd 100644
--- a/docs/reference/view_selection/llm_view_selector.md
+++ b/docs/reference/view_selection/llm_view_selector.md
@@ -1,5 +1,3 @@
 # LLMViewSelector
 
 ::: dbally.view_selection.LLMViewSelector
-
-::: dbally.view_selection.view_selector_prompt_template.default_view_selector_template
diff --git a/examples/recruiting.py b/examples/recruiting.py
index a4813b41..ea16a934 100644
--- a/examples/recruiting.py
+++ b/examples/recruiting.py
@@ -9,9 +9,37 @@
 from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
 from dbally.audit.event_tracker import EventTracker
 from dbally.llms.litellm import LiteLLM
-from dbally.prompts import PromptTemplate
+from dbally.prompt import PromptTemplate
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat
 
-TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
+
+class Text2SQLPromptFormat(PromptFormat):
+    """
+    Formats provided parameters to a form acceptable by the SQL prompt.
+    """
+
+    def __init__(
+        self,
+        *,
+        question: str,
+        schema: str,
+        examples: List[FewShotExample] = None,
+    ) -> None:
+        """
+        Constructs a new Text2SQLPromptFormat instance.
+
+        Args:
+            question: Question to be asked.
+            schema: SQL schema description.
+            examples: List of examples to be injected into the conversation.
+ """ + super().__init__(examples) + self.question = question + self.schema = schema + + +TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate[Text2SQLPromptFormat]( ( { "role": "system", @@ -112,9 +140,10 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example for question in benchmark.questions: await recruitment_db.ask(question.dbally_question, return_natural_response=True) gpt_question = question.gpt_question if question.gpt_question else question.dbally_question - gpt_response = await llm.generate_text( - TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker - ) + + prompt_format = Text2SQLPromptFormat(question=gpt_question, schema=db_description) + formatted_prompt = TEXT2SQL_PROMPT_TEMPLATE.format_prompt(prompt_format) + gpt_response = await llm.generate_text(formatted_prompt, event_tracker=event_tracker) print(f"GPT response: {gpt_response}") diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py index 63a6c821..773d3f62 100644 --- a/examples/recruiting/views.py +++ b/examples/recruiting/views.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, select from dbally import SqlAlchemyBaseView, decorators -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample from .db import Candidate diff --git a/mkdocs.yml b/mkdocs.yml index 852eac20..826ffe15 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,6 +21,7 @@ nav: - how-to/views/text-to-sql.md - how-to/views/pandas.md - how-to/views/custom.md + - how-to/views/few-shots.md - Using LLMs: - how-to/llms/litellm.md - how-to/llms/custom.md @@ -59,6 +60,7 @@ nav: - LLMs: - reference/llms/index.md - reference/llms/litellm.md + - reference/prompt.md - Similarity: - reference/similarity/index.md - Store: diff --git a/src/dbally/assistants/openai.py b/src/dbally/assistants/openai.py index 8560cc95..4ec239df 100644 --- a/src/dbally/assistants/openai.py +++ b/src/dbally/assistants/openai.py @@ -6,7 +6,7 @@ from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState from dbally.collection import Collection -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError +from dbally.iql_generator.prompt import UnsupportedQueryError _DBALLY_INFO = "Dbally has access to the following database views: " diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index 1cc74ca4..3bb23e17 100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -3,7 +3,7 @@ from typing import Optional, Union from dbally.collection.results import ExecutionResult -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat @dataclass diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 5eb77943..5f023659 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -9,8 +9,8 @@ from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError -from dbally.prompts import PromptTemplateError +from dbally.iql_generator.prompt import UnsupportedQueryError +from dbally.prompt.template import PromptTemplateError async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface: diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 
7ad86490..c2131a57 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -15,12 +15,19 @@ class IQLQuery: root: syntax.Node - def __init__(self, root: syntax.Node): + def __init__(self, root: syntax.Node, source: str) -> None: self.root = root + self._source = source + + def __str__(self) -> str: + return self._source @classmethod async def parse( - cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + cls, + source: str, + allowed_functions: List["ExposedFunction"], + event_tracker: Optional[EventTracker] = None, ) -> "IQLQuery": """ Parse IQL string to IQLQuery object. @@ -32,4 +39,5 @@ async def parse( Returns: IQLQuery object """ - return cls(await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()) + root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + return cls(root=root, source=source) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index cea13957..7eeb9154 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,10 +1,16 @@ -from typing import List, Optional, Tuple, TypeVar +from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template # noqa +from dbally.iql import IQLError, IQLQuery +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.formatters import IQLInputFormatter +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + +ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" class IQLGenerator: @@ -18,67 +24,61 @@ class IQLGenerator: It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. """ - _ERROR_MSG_PREFIX = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system: \n" - - TException = TypeVar("TException", bound=Exception) - - def __init__(self, llm: LLM) -> None: + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None: """ + Constructs a new IQLGenerator instance. + Args: llm: LLM used to generate IQL """ self._llm = llm + self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE async def generate_iql( self, - input_formatter: IQLInputFormatter, + question: str, + filters: List[ExposedFunction], event_tracker: EventTracker, - conversation: Optional[IQLPromptTemplate] = None, + examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, - ) -> Tuple[str, IQLPromptTemplate]: + n_retries: int = 3, + ) -> IQLQuery: """ - Uses LLM to generate IQL in text form + Generates IQL in text form using LLM. Args: - input_formatter: formatter used to prepare prompt arguments dictionary - event_tracker: event store used to audit the generation process - conversation: conversation to be continued - llm_options: options to use for the LLM client + question: User question. 
+ filters: List of filters exposed by the view. + event_tracker: Event store used to audit the generation process. + examples: List of examples to be injected into the conversation. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors. Returns: - IQL - iql generated based on the user question + Generated IQL query. """ - - conversation, fmt = input_formatter(conversation or default_iql_template) - - llm_response = await self._llm.generate_text( - template=conversation, - fmt=fmt, - event_tracker=event_tracker, - options=llm_options, + prompt_format = IQLGenerationPromptFormat( + question=question, + filters=filters, + examples=examples, ) - - iql_filters = conversation.llm_response_parser(llm_response) - - conversation = conversation.add_assistant_message(content=llm_response) - - return iql_filters, conversation - - def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException]) -> IQLPromptTemplate: - """ - Appends to the conversation error messages returned due to the invalid IQL generated by the LLM. - - Args: - conversation (IQLPromptTemplate): conversation containing current IQL generation trace - errors (List[Exception]): errors to be appended - - Returns: - IQLPromptTemplate: Conversation extended with errors - """ - - msg = self._ERROR_MSG_PREFIX - for error in errors: - msg += str(error) + "\n" - - return conversation.add_user_message(content=msg) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) + + for _ in range(n_retries + 1): + try: + response = await self._llm.generate_text( + prompt=formatted_prompt, + event_tracker=event_tracker, + options=llm_options, + ) + # TODO: Move response parsing to llm generate_text method + iql = formatted_prompt.response_parser(response) + # TODO: Move IQL query parsing to prompt response parser + return await IQLQuery.parse( + source=iql, + allowed_functions=filters, + event_tracker=event_tracker, + ) + except IQLError as exc: + formatted_prompt = formatted_prompt.add_assistant_message(response) + formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index 2da8abd2..00000000 --- a/src/dbally/iql_generator/iql_prompt_template.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.exceptions import DbAllyError -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class UnsupportedQueryError(DbAllyError): - """ - Error raised when IQL generator is unable to construct a query - with given filters. - """ - - -class IQLPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the IQL - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"filters", "question"}) - - -def _validate_iql_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing IQL for filters. - - Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. 
- """ - - if "unsupported query" in llm_response.lower(): - raise UnsupportedQueryError - return llm_response - - -default_iql_template = IQLPromptTemplate( - chat=( - { - "role": "system", - "content": "You have access to API that lets you query a database:\n" - "\n{filters}\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" - "Remember! Don't give any comments, just the function calls.\n" - "The output will look like this:\n" - 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{filters}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ - "This is CRUCIAL, otherwise the system will crash. ", - }, - {"role": "user", "content": "{question}"}, - ), - llm_response_parser=_validate_iql_response, -) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py new file mode 100644 index 00000000..44bb2cd4 --- /dev/null +++ b/src/dbally/iql_generator/prompt.py @@ -0,0 +1,87 @@ +# pylint: disable=C0301 + +from typing import List + +from dbally.exceptions import DbAllyError +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + + +class UnsupportedQueryError(DbAllyError): + """ + Error raised when IQL generator is unable to construct a query + with given filters. + """ + + +def _validate_iql_response(llm_response: str) -> str: + """ + Validates LLM response to IQL + + Args: + llm_response: LLM response + + Returns: + A string containing IQL for filters. + + Raises: + UnsuppotedQueryError: When IQL generator is unable to construct a query + with given filters. + """ + if "unsupported query" in llm_response.lower(): + raise UnsupportedQueryError + return llm_response + + +class IQLGenerationPromptFormat(PromptFormat): + """ + IQL prompt format, providing a question and filters to be used in the conversation. + """ + + def __init__( + self, + *, + question: str, + filters: List[ExposedFunction], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new IQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + filters: List of filters exposed by the view. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.filters = "\n".join([str(filter) for filter in filters]) + + +IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You have access to API that lets you query a database:\n" + "\n{filters}\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{filters}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. 
" + ), + }, + { + "role": "user", + "content": "{question}", + }, + ], + response_parser=_validate_iql_response, +) diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py index 067fbe56..7e2381e1 100644 --- a/src/dbally/llms/base.py +++ b/src/dbally/llms/base.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import Dict, Generic, Optional, Type +from typing import Generic, Optional, Type from dbally.audit.event_tracker import EventTracker from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMClientOptions, LLMOptions -from dbally.prompts.common_validation_utils import ChatFormat -from dbally.prompts.prompt_template import PromptTemplate +from dbally.prompt.template import PromptTemplate class LLM(Generic[LLMClientOptions], ABC): @@ -41,36 +40,21 @@ def client(self) -> LLMClient: Client for the LLM. """ - def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat: + def count_tokens(self, prompt: PromptTemplate) -> int: """ - Applies formatting to the prompt template. + Counts tokens in the prompt. Args: - template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. + prompt: Formatted prompt template with conversation and response parsing configuration. Returns: - Prompt in the format of the client. + Number of tokens in the prompt. """ - return [{"role": message["role"], "content": message["content"].format(**fmt)} for message in template.chat] - - def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: - """ - Counts tokens in the messages. - - Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. - - Returns: - Number of tokens in the messages. - """ - return sum(len(message["content"].format(**fmt)) for message in messages) + return sum(len(message["content"]) for message in prompt.chat) async def generate_text( self, - template: PromptTemplate, - fmt: Dict[str, str], + prompt: PromptTemplate, *, event_tracker: Optional[EventTracker] = None, options: Optional[LLMOptions] = None, @@ -79,8 +63,7 @@ async def generate_text( Prepares and sends a prompt to the LLM and returns the response. Args: - template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. + prompt: Formatted prompt template with conversation and response parsing configuration. event_tracker: Event store used to audit the generation process. options: Options to use for the LLM client. @@ -88,16 +71,15 @@ async def generate_text( Text response from LLM. 
""" options = (self.default_options | options) if options else self.default_options - prompt = self.format_prompt(template, fmt) - event = LLMEvent(prompt=prompt, type=type(template).__name__) + event = LLMEvent(prompt=prompt.chat, type=type(prompt).__name__) event_tracker = event_tracker or EventTracker() async with event_tracker.track_event(event) as span: event.response = await self.client.call( - prompt=prompt, - response_format=template.response_format, + conversation=prompt.chat, options=options, event=event, + json_mode=prompt.json_mode, ) span(event) diff --git a/src/dbally/llms/clients/base.py b/src/dbally/llms/clients/base.py index 5de63ce7..0293390f 100644 --- a/src/dbally/llms/clients/base.py +++ b/src/dbally/llms/clients/base.py @@ -3,7 +3,7 @@ from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar from dbally.audit.events import LLMEvent -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat from ..._types import NotGiven @@ -67,19 +67,19 @@ def __init__(self, model_name: str) -> None: @abstractmethod async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LLMClientOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: """ Calls LLM inference API. Args: - prompt: Prompt passed to the LLM. - response_format: Optional argument used in the OpenAI API - used to force a json output + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. options: Additional settings used by LLM. event: LLMEvent instance which fields should be filled during the method execution. + json_mode: Force the response to be in JSON format. Returns: Response string from LLM. diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py index b15ad362..1e23df91 100644 --- a/src/dbally/llms/clients/litellm.py +++ b/src/dbally/llms/clients/litellm.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union try: import litellm @@ -12,7 +12,7 @@ from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.llms.clients.exceptions import LLMConnectionError, LLMResponseError, LLMStatusError -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat from ..._types import NOT_GIVEN, NotGiven @@ -72,19 +72,19 @@ def __init__( async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LiteLLMOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: """ Calls the appropriate LLM endpoint with the given prompt and options. Args: - prompt: Prompt as an OpenAI client style list. - response_format: Optional argument used in the OpenAI API - used to force the json output + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. options: Additional settings used by the LLM. event: Container with the prompt, LLM response and call metrics. + json_mode: Force the response to be in JSON format. Returns: Response string from LLM. @@ -94,9 +94,11 @@ async def call( LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. 
""" + response_format = {"type": "json_object"} if json_mode else None + try: response = await litellm.acompletion( - messages=prompt, + messages=conversation, model=self.model_name, base_url=self.base_url, api_key=self.api_key, diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py index c5699a1e..077474e9 100644 --- a/src/dbally/llms/litellm.py +++ b/src/dbally/llms/litellm.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Dict, Optional +from typing import Optional try: import litellm @@ -10,7 +10,7 @@ from dbally.llms.base import LLM from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions -from dbally.prompts import ChatFormat +from dbally.prompt.template import PromptTemplate class LiteLLM(LLM[LiteLLMOptions]): @@ -65,17 +65,14 @@ def client(self) -> LiteLLMClient: api_version=self.api_version, ) - def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: + def count_tokens(self, prompt: PromptTemplate) -> int: """ - Counts tokens in the messages using a specified model. + Counts tokens in the prompt. Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. + prompt: Formatted prompt template with conversation and response parsing configuration. Returns: - Number of tokens in the messages. + Number of tokens in the prompt. """ - return sum( - litellm.token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages - ) + return sum(litellm.token_counter(model=self.model_name, text=message["content"]) for message in prompt.chat) diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 8bcafb11..7a8f98e4 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -1,48 +1,44 @@ -import copy -from typing import Dict, List, Optional - -import pandas as pd +from typing import Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template -from dbally.nl_responder.query_explainer_prompt_template import ( - QueryExplainerPromptTemplate, - default_query_explainer_template, +from dbally.nl_responder.prompts import ( + NL_RESPONSE_TEMPLATE, + QUERY_EXPLANATION_TEMPLATE, + NLResponsePromptFormat, + QueryExplanationPromptFormat, ) +from dbally.prompt.template import PromptTemplate class NLResponder: - """Class used to generate natural language response from the database output.""" - - # Keys used to extract the query from the context (ordered by priority) - QUERY_KEYS = ["iql", "sql", "query"] + """ + Class used to generate natural language response from the database output. + """ def __init__( self, llm: LLM, - query_explainer_prompt_template: Optional[QueryExplainerPromptTemplate] = None, - nl_responder_prompt_template: Optional[NLResponderPromptTemplate] = None, + prompt_template: Optional[PromptTemplate[NLResponsePromptFormat]] = None, + explainer_prompt_template: Optional[PromptTemplate[QueryExplanationPromptFormat]] = None, max_tokens_count: int = 4096, ) -> None: """ + Constructs a new NLResponder instance. 
+ Args: - llm: LLM used to generate natural language response - query_explainer_prompt_template: template for the prompt used to generate the iql explanation - if not set defaults to `default_query_explainer_template` - nl_responder_prompt_template: template for the prompt used to generate the NL response - if not set defaults to `nl_responder_prompt_template` - max_tokens_count: maximum number of tokens that can be used in the prompt + llm: LLM used to generate natural language response. + prompt_template: Template for the prompt used to generate the NL response + if not set defaults to `NL_RESPONSE_TEMPLATE`. + explainer_prompt_template: Template for the prompt used to generate the iql explanation + if not set defaults to `QUERY_EXPLANATION_TEMPLATE`. + max_tokens_count: Maximum number of tokens that can be used in the prompt. """ self._llm = llm - self._nl_responder_prompt_template = nl_responder_prompt_template or copy.deepcopy( - default_nl_responder_template - ) - self._query_explainer_prompt_template = query_explainer_prompt_template or copy.deepcopy( - default_query_explainer_template - ) + self._prompt_template = prompt_template or NL_RESPONSE_TEMPLATE + self._explainer_prompt_template = explainer_prompt_template or QUERY_EXPLANATION_TEMPLATE self._max_tokens_count = max_tokens_count async def generate_response( @@ -56,53 +52,38 @@ async def generate_response( Uses LLM to generate a response in natural language form. Args: - result: object representing the result of the query execution - question: user question - event_tracker: event store used to audit the generation process - llm_options: options to use for the LLM client. + result: Object representing the result of the query execution. + question: User question. + event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. Returns: Natural language response to the user question. """ - rows = _promptify_rows(result.results) - - tokens_count = self._llm.count_tokens( - messages=self._nl_responder_prompt_template.chat, - fmt={"rows": rows, "question": question}, + prompt_format = NLResponsePromptFormat( + question=question, + results=result.results, ) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) + tokens_count = self._llm.count_tokens(formatted_prompt) if tokens_count > self._max_tokens_count: - context = result.context - query = next((context.get(key) for key in self.QUERY_KEYS if context.get(key)), question) + prompt_format = QueryExplanationPromptFormat( + question=question, + context=result.context, + results=result.results, + ) + formatted_prompt = self._explainer_prompt_template.format_prompt(prompt_format) llm_response = await self._llm.generate_text( - template=self._query_explainer_prompt_template, - fmt={"question": question, "query": query, "number_of_results": len(result.results)}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) - return llm_response llm_response = await self._llm.generate_text( - template=self._nl_responder_prompt_template, - fmt={"rows": _promptify_rows(result.results), "question": question}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) return llm_response - - -def _promptify_rows(rows: List[Dict]) -> str: - """ - Formats rows into a markdown table. 
- - Args: - rows: list of rows to be formatted - - Returns: - str: formatted rows - """ - - df = pd.DataFrame.from_records(rows) - - return df.to_markdown(index=False, headers="keys", tablefmt="psql") diff --git a/src/dbally/nl_responder/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py deleted file mode 100644 index 9e6e687e..00000000 --- a/src/dbally/nl_responder/nl_responder_prompt_template.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class NLResponderPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the natural response. - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ) -> None: - """ - Initializes NLResponderPromptTemplate class. - - Args: - chat: chat format - response_format: response format - llm_response_parser: function to parse llm response - """ - - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"rows", "question"}) - - -default_nl_responder_template = NLResponderPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a helpful assistant that helps answer the user's questions " - "based on the table provided. You MUST use the table to answer the question. " - "You are very intelligent and obedient.\n" - "The table ALWAYS contains full answer to a question.\n" - "Answer the question in a way that is easy to understand and informative.\n" - "DON'T MENTION using a table in your answer.", - }, - { - "role": "user", - "content": "The table below represents the answer to a question: {question}.\n" - "{rows}\nAnswer the question: {question}.", - }, - ) -) diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py new file mode 100644 index 00000000..f99a8a6c --- /dev/null +++ b/src/dbally/nl_responder/prompts.py @@ -0,0 +1,111 @@ +from typing import Any, Dict, List + +import pandas as pd + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate + + +class NLResponsePromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default NL response prompt. + """ + + def __init__( + self, + *, + question: str, + results: List[Dict[str, Any]], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new IQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + filters: List of filters exposed by the view. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.results = pd.DataFrame.from_records(results).to_markdown(index=False, headers="keys", tablefmt="psql") + + +class QueryExplanationPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default query explanation prompt. + """ + + def __init__( + self, + *, + question: str, + context: Dict[str, Any], + results: List[Dict[str, Any]], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new QueryExplanationPromptFormat instance. + + Args: + question: Question to be asked. + context: Context of the query. + results: List of results returned by the query. + examples: List of examples to be injected into the conversation. 
+ """ + super().__init__(examples) + self.question = question + self.query = next((context.get(key) for key in ("iql", "sql", "query") if context.get(key)), question) + self.number_of_results = len(results) + + +NL_RESPONSE_TEMPLATE = PromptTemplate[NLResponsePromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a helpful assistant that helps answer the user's questions " + "based on the table provided. You MUST use the table to answer the question. " + "You are very intelligent and obedient.\n" + "The table ALWAYS contains full answer to a question.\n" + "Answer the question in a way that is easy to understand and informative.\n" + "DON'T MENTION using a table in your answer." + ), + }, + { + "role": "user", + "content": ( + "The table below represents the answer to a question: {question}.\n" + "{results}\n" + "Answer the question: {question}." + ), + }, + ], +) + +QUERY_EXPLANATION_TEMPLATE = PromptTemplate[QueryExplanationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a helpful assistant that helps describe a table generated by a query " + "that answers users' question. " + "You are very intelligent and obedient.\n" + "Your task is to provide natural language description of the table used by the logical query " + "to the database.\n" + "Describe the table in a way that is short and informative.\n" + "Make your answer as short as possible, start it by infroming the user that the underlying " + "data is too long to print and then describe the table based on the question and the query.\n" + "DON'T MENTION using a query in your answer." + ), + }, + { + "role": "user", + "content": ( + "The query below represents the answer to a question: {question}.\n" + "Describe the table generated using this query: {query}.\n" + "Number of results to this query: {number_of_results}." + ), + }, + ], +) diff --git a/src/dbally/nl_responder/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py deleted file mode 100644 index 00a3e6a6..00000000 --- a/src/dbally/nl_responder/query_explainer_prompt_template.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class QueryExplainerPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant to generate explanations for queries - (when the data cannot be shown due to token limit). - - Args: - chat: chat format - response_format: response format - llm_response_parser: function to parse llm response - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ) -> None: - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"}) - - -default_query_explainer_template = QueryExplainerPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a helpful assistant that helps describe a table generated by a query " - "that answers users' question. 
" - "You are very intelligent and obedient.\n" - "Your task is to provide natural language description of the table used by the logical query " - "to the database.\n" - "Describe the table in a way that is short and informative.\n" - "Make your answer as short as possible, start it by infroming the user that the underlying " - "data is too long to print and then describe the table based on the question and the query.\n" - "DON'T MENTION using a query in your answer.\n", - }, - { - "role": "user", - "content": "The query below represents the answer to a question: {question}.\n" - "Describe the table generated using this query: {query}.\n" - "Number of results to this query: {number_of_results}.\n", - }, - ) -) diff --git a/src/dbally/prompt/__init__.py b/src/dbally/prompt/__init__.py new file mode 100644 index 00000000..61495d33 --- /dev/null +++ b/src/dbally/prompt/__init__.py @@ -0,0 +1,3 @@ +from .template import ChatFormat, PromptTemplate, PromptTemplateError + +__all__ = ["PromptTemplate", "PromptTemplateError", "ChatFormat"] diff --git a/src/dbally/prompts/elements.py b/src/dbally/prompt/elements.py similarity index 97% rename from src/dbally/prompts/elements.py rename to src/dbally/prompt/elements.py index 2937d7c1..37375508 100644 --- a/src/dbally/prompts/elements.py +++ b/src/dbally/prompt/elements.py @@ -58,4 +58,4 @@ def _parse_lambda(self, expr: Callable) -> str: return parsed_expr def __str__(self) -> str: - return self.answer + return f"{self.question} -> {self.answer}" diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py new file mode 100644 index 00000000..124a3e1c --- /dev/null +++ b/src/dbally/prompt/template.py @@ -0,0 +1,234 @@ +import copy +import re +from typing import Callable, Dict, Generic, List, TypeVar + +from typing_extensions import Self + +from dbally.exceptions import DbAllyError +from dbally.prompt.elements import FewShotExample + +ChatFormat = List[Dict[str, str]] + + +class PromptTemplateError(DbAllyError): + """ + Error raised on incorrect PromptTemplate construction. + """ + + +def _check_chat_order(chat: ChatFormat) -> ChatFormat: + """ + Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating). + + Args: + chat: Chat template + + Raises: + PromptTemplateError: if chat template is not constructed correctly. + + Returns: + Chat template + """ + if len(chat) == 0: + raise PromptTemplateError("Template should not be empty") + + expected_order = ["user", "assistant"] + for i, message in enumerate(chat): + role = message["role"] + if role == "system": + if i != 0: + raise PromptTemplateError("Only first message should come from system") + continue + index = i % len(expected_order) + if role != expected_order[index - 1]: + raise PromptTemplateError( + "Template format is not correct. It should be system, and then user/assistant alternating." + ) + + if expected_order[index] not in ["user", "assistant"]: + raise PromptTemplateError("Template needs to end on either user or assistant turn") + return chat + + +class PromptFormat: + """ + Generic format for prompts allowing to inject few shot examples into the conversation. + """ + + def __init__(self, examples: List[FewShotExample] = None) -> None: + """ + Constructs a new PromptFormat instance. + + Args: + examples: List of examples to be injected into the conversation. 
+ """ + self.examples = examples or [] + + +PromptFormatT = TypeVar("PromptFormatT", bound=PromptFormat) + + +class PromptTemplate(Generic[PromptFormatT]): + """ + Class for prompt templates. + """ + + def __init__( + self, + chat: ChatFormat, + *, + json_mode: bool = False, + response_parser: Callable = lambda x: x, + ) -> None: + """ + Constructs a new PromptTemplate instance. + + Args: + chat: Chat-formatted conversation template. + json_mode: Whether to enforce JSON response from LLM. + response_parser: Function parsing the LLM response into the desired format. + """ + self.chat: ChatFormat = _check_chat_order(chat) + self.json_mode = json_mode + self.response_parser = response_parser + + def __eq__(self, other: "PromptTemplate") -> bool: + return isinstance(other, PromptTemplate) and self.chat == other.chat + + def _has_variable(self, variable: str) -> bool: + """ + Validates a given chat to make sure it contains variables required. + + Args: + variable: Variable to check. + + Returns: + True if the variable is present in the chat. + """ + for message in self.chat: + if re.match(rf"{{{variable}}}", message["content"]): + return True + return False + + def format_prompt(self, prompt_format: PromptFormatT) -> Self: + """ + Applies formatting to the prompt template chat contents. + + Args: + prompt_format: Format to be applied to the prompt. + + Returns: + PromptTemplate with formatted chat contents. + """ + formatted_prompt = copy.deepcopy(self) + formatting = dict(prompt_format.__dict__) + + if self._has_variable("examples"): + formatting["examples"] = "\n".join(prompt_format.examples) + else: + formatted_prompt = formatted_prompt.clear_few_shot_messages() + for example in prompt_format.examples: + formatted_prompt = formatted_prompt.add_few_shot_message(example) + + formatted_prompt.chat = [ + { + "role": message.get("role"), + "content": message.get("content").format(**formatting), + "is_example": message.get("is_example", False), + } + for message in formatted_prompt.chat + ] + return formatted_prompt + + def set_system_message(self, content: str) -> Self: + """ + Sets a system message to the template prompt. + + Args: + content: Message to be added. + + Returns: + PromptTemplate with appended system message. + """ + return self.__class__( + chat=[{"role": "system", "content": content}, *self.chat], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_user_message(self, content: str) -> Self: + """ + Add a user message to the template prompt. + + Args: + content: Message to be added. + + Returns: + PromptTemplate with appended user message. + """ + return self.__class__( + chat=[*self.chat, {"role": "user", "content": content}], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_assistant_message(self, content: str) -> Self: + """ + Add an assistant message to the template prompt. + + Args: + content: Message to be added. + + Returns: + PromptTemplate with appended assistant message. + """ + return self.__class__( + chat=[*self.chat, {"role": "assistant", "content": content}], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_few_shot_message(self, example: FewShotExample) -> Self: + """ + Add a few-shot message to the template prompt. + + Args: + example: Few-shot example to be added. + + Returns: + PromptTemplate with appended few-shot message. + + Raises: + PromptTemplateError: if the template is empty. 
+ """ + if len(self.chat) == 0: + raise PromptTemplateError("Cannot add few-shot messages to an empty template.") + + few_shot = [ + {"role": "user", "content": example.question, "is_example": True}, + {"role": "assistant", "content": example.answer, "is_example": True}, + ] + few_shot_index = max( + (i for i, entry in enumerate(self.chat) if entry.get("is_example") or entry.get("role") == "system"), + default=0, + ) + chat = self.chat[: few_shot_index + 1] + few_shot + self.chat[few_shot_index + 1 :] + + return self.__class__( + chat=chat, + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def clear_few_shot_messages(self) -> Self: + """ + Removes all few-shot messages from the template prompt. + + Returns: + PromptTemplate with few-shot messages removed. + """ + return self.__class__( + chat=[message for message in self.chat if not message.get("is_example")], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) diff --git a/src/dbally/prompts/__init__.py b/src/dbally/prompts/__init__.py deleted file mode 100644 index 38e20cc7..00000000 --- a/src/dbally/prompts/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables -from .prompt_template import PromptTemplate - -__all__ = ["PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"] diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py deleted file mode 100644 index f4660810..00000000 --- a/src/dbally/prompts/common_validation_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import re -from typing import Dict, List, Set - -from dbally.exceptions import DbAllyError - -ChatFormat = List[Dict[str, str]] - - -class PromptTemplateError(DbAllyError): - """Error raised on incorrect PromptTemplate construction""" - - -def _extract_variables(text: str) -> List[str]: - """ - Given a text string, extract all variables that can be filled using .format - - Args: - text: string to process - - Returns: - list of variables extracted from text - """ - pattern = r"\{([^}]+)\}" - return re.findall(pattern, text) - - -def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat: - """ - Function validates a given chat to make sure it contains variables required. - - Args: - chat: chat to validate - variables_to_check: set of variables to assert - - Raises: - PromptTemplateError: If required variables are missing - - Returns: - Chat, if it's valid. - """ - variables = [] - for message in chat: - content = message["content"] - variables.extend(_extract_variables(content)) - if not set(variables_to_check).issubset(variables): - raise PromptTemplateError( - "Cannot build a prompt template from the provided chat, " - "because it lacks necessary string variables. 
" - "You need to format the following variables: {variables_to_check}" - ) - return chat diff --git a/src/dbally/prompts/formatters.py b/src/dbally/prompts/formatters.py deleted file mode 100644 index c2cce950..00000000 --- a/src/dbally/prompts/formatters.py +++ /dev/null @@ -1,119 +0,0 @@ -import copy -from abc import ABCMeta, abstractmethod -from typing import Dict, List, Tuple - -from dbally.prompts.elements import FewShotExample -from dbally.prompts.prompt_template import PromptTemplate -from dbally.views.exposed_functions import ExposedFunction - - -def _promptify_filters( - filters: List[ExposedFunction], -) -> str: - """ - Formats filters for prompt - - Args: - filters: list of filters exposed by the view - - Returns: - filters formatted for prompt - """ - filters_for_prompt = "\n".join([str(filter) for filter in filters]) - return filters_for_prompt - - -class InputFormatter(metaclass=ABCMeta): - """ - Formats provided parameters to a form acceptable by IQL prompt - """ - - @abstractmethod - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Runs the input formatting for provided prompt template. - - Args: - conversation_template: a prompt template to use. - - Returns: - A tuple with template and a dictionary with formatted inputs. - """ - - -class IQLInputFormatter(InputFormatter): - """ - Formats provided parameters to a form acceptable by default IQL prompt - """ - - def __init__(self, filters: List[ExposedFunction], question: str) -> None: - self.filters = filters - self.question = question - - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Runs the input formatting for provided prompt template. - - Args: - conversation_template: a prompt template to use. - - Returns: - A tuple with template and a dictionary with formatted filters and a question. - """ - return conversation_template, { - "filters": _promptify_filters(self.filters), - "question": self.question, - } - - -class IQLFewShotInputFormatter(InputFormatter): - """ - Formats provided parameters to a form acceptable by default IQL prompt. - Calling it will inject `examples` before last message in a conversation. - """ - - def __init__( - self, - filters: List[ExposedFunction], - examples: List[FewShotExample], - question: str, - ) -> None: - self.filters = filters - self.question = question - self.examples = examples - - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Performs a deep copy of provided template and injects examples into chat history. - Also prepares filters and question to be included within the prompt. - - Args: - conversation_template: a prompt template to use to inject few-shot examples. - - Returns: - A tuple with deeply-copied and enriched with examples template - and a dictionary with formatted filters and a question. 
- """ - - template_copy = copy.deepcopy(conversation_template) - sys_msg = template_copy.chat[0] - existing_msgs = [msg for msg in template_copy.chat[1:] if "is_example" not in msg] - chat_examples = [ - msg - for example in self.examples - for msg in [ - {"role": "user", "content": example.question, "is_example": True}, - {"role": "assistant", "content": example.answer, "is_example": True}, - ] - ] - - template_copy.chat = ( - sys_msg, - *chat_examples, - *existing_msgs, - ) - - return template_copy, { - "filters": _promptify_filters(self.filters), - "question": self.question, - } diff --git a/src/dbally/prompts/prompt_template.py b/src/dbally/prompts/prompt_template.py deleted file mode 100644 index 8e2746fe..00000000 --- a/src/dbally/prompts/prompt_template.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Callable, Dict, Optional - -from typing_extensions import Self - -from .common_validation_utils import ChatFormat, PromptTemplateError - - -def _check_chat_order(chat: ChatFormat) -> ChatFormat: - """ - Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating). - - Args: - chat: Chat template - - Raises: - PromptTemplateError: if chat template is not constructed correctly. - - Returns: - Chat template - """ - expected_order = ["user", "assistant"] - for i, message in enumerate(chat): - role = message["role"] - if role == "system": - if i != 0: - raise PromptTemplateError("Only first message should come from system") - continue - index = i % len(expected_order) - if role != expected_order[index - 1]: - raise PromptTemplateError( - "Template format is not correct. It should be system, and then user/assistant alternating." - ) - - if expected_order[index] not in ["user", "assistant"]: - raise PromptTemplateError("Template needs to end on either user or assistant turn") - return chat - - -class PromptTemplate: - """ - Class for prompt templates - - Attributes: - response_format: Optional argument for OpenAI Turbo models - may be used to force json output - llm_response_parser: Function parsing the LLM response into IQL - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - self.chat: ChatFormat = _check_chat_order(chat) - self.response_format = response_format - self.llm_response_parser = llm_response_parser - - def __eq__(self, __value: object) -> bool: - return isinstance(__value, PromptTemplate) and self.chat == __value.chat - - def add_user_message(self, content: str) -> Self: - """ - Add a user message to the template prompt. - - Args: - content: Message to be added - - Returns: - PromptTemplate with appended user message - """ - return self.__class__((*self.chat, {"role": "user", "content": content})) - - def add_assistant_message(self, content: str) -> Self: - """ - Add an assistant message to the template prompt. 
- - Args: - content: Message to be added - - Returns: - PromptTemplate with appended assistant message - """ - return self.__class__((*self.chat, {"role": "assistant", "content": content})) diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index 2d501922..b4069bb1 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -1,12 +1,11 @@ -import copy -from typing import Callable, Dict, Optional +from typing import Dict, Optional from dbally.audit.event_tracker import EventTracker -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions +from dbally.prompt.template import PromptTemplate from dbally.view_selection.base import ViewSelector -from dbally.view_selection.view_selector_prompt_template import default_view_selector_template +from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE, ViewSelectionPromptFormat class LLMViewSelector(ViewSelector): @@ -20,22 +19,16 @@ class LLMViewSelector(ViewSelector): ultimately returning the name of the most suitable view. """ - def __init__( - self, - llm: LLM, - prompt_template: Optional[IQLPromptTemplate] = None, - promptify_views: Optional[Callable[[Dict[str, str]], str]] = None, - ) -> None: + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[ViewSelectionPromptFormat]] = None) -> None: """ + Constructs a new LLMViewSelector instance. + Args: llm: LLM used to generate IQL prompt_template: template for the prompt used for the view selection - promptify_views: Function formatting filters for prompt. By default names and descriptions of\ - all views are concatenated """ self._llm = llm - self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) - self._promptify_views = promptify_views or _promptify_views + self._prompt_template = prompt_template or VIEW_SELECTION_TEMPLATE async def select_view( self, @@ -56,28 +49,13 @@ async def select_view( Returns: The most relevant view name. """ - - views_for_prompt = self._promptify_views(views) + prompt_format = ViewSelectionPromptFormat(question=question, views=views) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) llm_response = await self._llm.generate_text( - template=self._prompt_template, - fmt={"views": views_for_prompt, "question": question}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) - selected_view = self._prompt_template.llm_response_parser(llm_response) + selected_view = self._prompt_template.response_parser(llm_response) return selected_view - - -def _promptify_views(views: Dict[str, str]) -> str: - """ - Formats views for prompt - - Args: - views: dictionary of available view names with corresponding descriptions. - - Returns: - views_for_prompt: views formatted for prompt - """ - - return "\n".join([f"{name}: {description}" for name, description in views.items()]) diff --git a/src/dbally/view_selection/prompt.py b/src/dbally/view_selection/prompt.py new file mode 100644 index 00000000..cdbedf5a --- /dev/null +++ b/src/dbally/view_selection/prompt.py @@ -0,0 +1,52 @@ +from typing import Dict, List + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate + + +class ViewSelectionPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default IQL prompt. 
+ """ + + def __init__( + self, + *, + question: str, + views: Dict[str, str], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new ViewSelectionPromptFormat instance. + + Args: + question: Question to be asked. + views: Dictionary of available view names with corresponding descriptions. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.views = "\n".join([f"{name}: {description}" for name, description in views.items()]) + + +VIEW_SELECTION_TEMPLATE = PromptTemplate[ViewSelectionPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a very smart database programmer. " + "You have access to API that lets you query a database:\n" + "First you need to select a class to query, based on its description and the user question. " + "You have the following classes to choose from:\n" + "{views}\n" + "Return only the selected view name. Don't give any comments.\n" + "You can only use the classes that were listed. " + "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`" + ), + }, + { + "role": "user", + "content": "{question}", + }, + ], +) diff --git a/src/dbally/view_selection/view_selector_prompt_template.py b/src/dbally/view_selection/view_selector_prompt_template.py deleted file mode 100644 index 60440c84..00000000 --- a/src/dbally/view_selection/view_selector_prompt_template.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class ViewSelectorPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the ViewSelector - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"views"}) - - -def _convert_llm_json_response_to_selected_view(llm_response_json: str) -> str: - """ - Converts LLM json response to IQL - - Args: - llm_response_json: LLM response in JSON format - - Returns: - A string containing selected view - """ - llm_response_dict = json.loads(llm_response_json) - return llm_response_dict.get("view") - - -default_view_selector_template = ViewSelectorPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a very smart database programmer. " - "You have access to API that lets you query a database:\n" - "First you need to select a class to query, based on its description and the user question. " - "You have the following classes to choose from:\n" - "{views}\n" - "Return only the selected view name. Don't give any comments.\n" - "You can only use the classes that were listed. 
" - "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`", - }, - {"role": "user", "content": "{question}"}, - ), -) diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index a3278281..d5103884 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex IndexLocation = Tuple[str, str, str] diff --git a/src/dbally/views/freeform/text2sql/prompt.py b/src/dbally/views/freeform/text2sql/prompt.py new file mode 100644 index 00000000..5f9a547d --- /dev/null +++ b/src/dbally/views/freeform/text2sql/prompt.py @@ -0,0 +1,61 @@ +# pylint: disable=C0301 + +from typing import List + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate +from dbally.views.freeform.text2sql.config import TableConfig + + +class SQLGenerationPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default SQL prompt. + """ + + def __init__( + self, + *, + question: str, + dialect: str, + tables: List[TableConfig], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new SQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + context: Context of the query. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.dialect = dialect + self.tables = "\n".join(table.ddl for table in tables) + + +SQL_GENERATION_TEMPLATE = PromptTemplate[SQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a very smart database programmer. " + "You have access to the following {dialect} tables:\n" + "{tables}\n" + "Create SQL query to answer user question. Response with JSON containing following keys:\n\n" + "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n" + "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n" + " - name: the name of the parameter\n" + " - value: the value of the parameter\n" + " - table: the table the parameter is used with (if any)\n" + " - column: the column the parameter is compared to (if any)\n\n" + "Respond ONLY with the raw JSON response. Don't include any additional text or characters." 
+ ), + }, + { + "role": "user", + "content": "{question}", + }, + ], + json_mode=True, +) diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 6891785e..7f24f00e 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -10,32 +10,12 @@ from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts import PromptTemplate +from dbally.prompt.template import PromptTemplate from dbally.similarity import AbstractSimilarityIndex, SimpleSqlAlchemyFetcher from dbally.views.base import BaseView, IndexLocation from dbally.views.freeform.text2sql.config import TableConfig from dbally.views.freeform.text2sql.exceptions import Text2SQLError - -text2sql_prompt = PromptTemplate( - chat=( - { - "role": "system", - "content": "You are a very smart database programmer. " - "You have access to the following {dialect} tables:\n" - "{tables}\n" - "Create SQL query to answer user question. Response with JSON containing following keys:\n\n" - "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n" - "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n" - " - name: the name of the parameter\n" - " - value: the value of the parameter\n" - " - table: the table the parameter is used with (if any)\n" - " - column: the column the parameter is compared to (if any)\n\n" - "Respond ONLY with the raw JSON response. Don't include any additional text or characters.", - }, - {"role": "user", "content": "{question}"}, - ), - response_format={"type": "json_object"}, -) +from dbally.views.freeform.text2sql.prompt import SQL_GENERATION_TEMPLATE, SQLGenerationPromptFormat @dataclass @@ -142,17 +122,26 @@ async def ask( Raises: Text2SQLError: If the text2sql query generation fails after n_retries. """ - conversation = text2sql_prompt sql, rows = None, None exceptions = [] - for _ in range(n_retries): + tables = self.get_tables() + examples = self.list_few_shots() + + prompt_format = SQLGenerationPromptFormat( + question=query, + dialect=self._engine.dialect.name, + tables=tables, + examples=examples, + ) + formatted_prompt = SQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + for _ in range(n_retries + 1): # We want to catch all exceptions to retry the process. # pylint: disable=broad-except try: - sql, parameters, conversation = await self._generate_sql( - query=query, - conversation=conversation, + sql, parameters, formatted_prompt = await self._generate_sql( + conversation=formatted_prompt, llm=llm, event_tracker=event_tracker, llm_options=llm_options, @@ -164,7 +153,7 @@ async def ask( rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker) break except Exception as e: - conversation = conversation.add_user_message(f"Response is invalid! Error: {e}") + formatted_prompt = formatted_prompt.add_user_message(f"Response is invalid! 
Error: {e}") exceptions.append(e) continue @@ -182,15 +171,13 @@ async def ask( async def _generate_sql( self, - query: str, conversation: PromptTemplate, llm: LLM, event_tracker: EventTracker, llm_options: Optional[LLMOptions] = None, ) -> Tuple[str, List[SQLParameterOption], PromptTemplate]: response = await llm.generate_text( - template=conversation, - fmt={"tables": self._get_tables_context(), "dialect": self._engine.dialect.name, "question": query}, + prompt=conversation, event_tracker=event_tracker, options=llm_options, ) @@ -221,12 +208,6 @@ async def _execute_sql( with self._engine.connect() as conn: return conn.execute(text(sql), param_values).fetchall() - def _get_tables_context(self) -> str: - context = "" - for table in self._table_index.values(): - context += f"{table.ddl}\n" - return context - def _create_default_fetcher(self, table: str, column: str) -> SimpleSqlAlchemyFetcher: return SimpleSqlAlchemyFetcher( sqlalchemy_engine=self._engine, diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 8b95ecaa..b5863075 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,11 +4,10 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLError, IQLQuery +from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter from dbally.views.exposed_functions import ExposedFunction from ..similarity import AbstractSimilarityIndex @@ -26,10 +25,10 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator: Returns the IQL generator for the view. Args: - llm: LLM used to generate the IQL queries + llm: LLM used to generate the IQL queries. Returns: - IQLGenerator: IQL generator for the view + IQL generator for the view. """ return IQLGenerator(llm=llm) @@ -57,46 +56,30 @@ async def ask( Returns: The result of the query. """ + iql_generator = self.get_iql_generator(llm) filters = self.list_filters() examples = self.list_few_shots() - iql_generator = self.get_iql_generator(llm) - input_formatter = ( - IQLFewShotInputFormatter(question=query, filters=filters, examples=examples) - if examples - else IQLInputFormatter(question=query, filters=filters) - ) - - iql_filters, conversation = await iql_generator.generate_iql( - input_formatter=input_formatter, + iql = await iql_generator.generate_iql( + question=query, + filters=filters, + examples=examples, event_tracker=event_tracker, llm_options=llm_options, + n_retries=n_retries, ) - - for _ in range(n_retries): - try: - filters = await IQLQuery.parse(iql_filters, filters, event_tracker=event_tracker) - await self.apply_filters(filters) - break - except (IQLError, ValueError) as e: - conversation = iql_generator.add_error_msg(conversation, [e]) - iql_filters, conversation = await iql_generator.generate_iql( - input_formatter=input_formatter, - event_tracker=event_tracker, - conversation=conversation, - llm_options=llm_options, - ) - continue + await self.apply_filters(iql) result = self.execute(dry_run=dry_run) - result.context["iql"] = iql_filters + result.context["iql"] = f"{iql}" return result @abc.abstractmethod def list_filters(self) -> List[ExposedFunction]: """ + Lists all available filters for the View. Returns: Filters defined inside the View. 
diff --git a/src/dbally_codegen/autodiscovery.py b/src/dbally_codegen/autodiscovery.py index 1e20c542..c842a07f 100644 --- a/src/dbally_codegen/autodiscovery.py +++ b/src/dbally_codegen/autodiscovery.py @@ -6,12 +6,59 @@ from typing_extensions import Self from dbally.llms.base import LLM -from dbally.prompts import PromptTemplate +from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.similarity.index import SimilarityIndex -from dbally.views.freeform.text2sql import ColumnConfig, TableConfig +from dbally.views.freeform.text2sql.config import ColumnConfig, TableConfig -DISCOVERY_TEMPLATE = PromptTemplate( - chat=( + +class DiscoveryPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default discovery prompt. + """ + + def __init__( + self, + *, + dialect: str, + table_ddl: str, + samples: List[Dict[str, Any]], + ) -> None: + """ + Constructs a new DiscoveryPromptFormat instance. + + Args: + dialect: The SQL dialect of the database. + table_ddl: The DDL of the table. + samples: The example rows from the table. + """ + super().__init__() + self.dialect = dialect + self.table_ddl = table_ddl + self.samples = samples + + +class SimilarityPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default similarity prompt. + """ + + def __init__(self, *, table_summary: str, column_name: str, samples: List[Any]) -> None: + """ + Constructs a new SimilarityPromptFormat instance. + + Args: + table_summary: The summary of the table. + column_name: The name of the column. + samples: The example values from the column. + """ + super().__init__() + self.table_summary = table_summary + self.column_name = column_name + self.samples = samples + + +DISCOVERY_TEMPLATE = PromptTemplate[DiscoveryPromptFormat]( + [ { "role": "system", "content": ( @@ -24,11 +71,11 @@ "role": "user", "content": "DDL:\n {table_ddl}\n" "EXAMPLE ROWS:\n {samples}", }, - ), + ], ) -SIMILARITY_TEMPLATE = PromptTemplate( - chat=( +SIMILARITY_TEMPLATE = PromptTemplate[SimilarityPromptFormat]( + [ { "role": "system", "content": ( @@ -43,7 +90,7 @@ "role": "user", "content": "TABLE SUMMARY: {table_summary}\n" "COLUMN NAME: {column_name}\n" "EXAMPLE VALUES: {samples}", }, - ) + ], ) @@ -108,14 +155,15 @@ async def extract_description(self, table: Table, connection: Connection) -> str """ ddl = self._generate_ddl(table) samples = self._fetch_samples(connection, table) - return await self.llm.generate_text( - template=DISCOVERY_TEMPLATE, - fmt={ - "dialect": self.engine.dialect.name, - "table_ddl": ddl, - "samples": samples, - }, + + prompt_format = DiscoveryPromptFormat( + dialect=self.engine.dialect.name, + table_ddl=ddl, + samples=samples, ) + formatted_prompt = DISCOVERY_TEMPLATE.format_prompt(prompt_format) + + return await self.llm.generate_text(formatted_prompt) def _fetch_samples(self, connection: Connection, table: Table) -> List[Dict[str, Any]]: rows = connection.execute(table.select().limit(self.samples_count)).fetchall() @@ -218,14 +266,15 @@ async def select_index( table=table, column=column, ) - use_index = await self.llm.generate_text( - template=SIMILARITY_TEMPLATE, - fmt={ - "table_summary": description, - "column_name": column.name, - "samples": samples, - }, + + prompt_format = SimilarityPromptFormat( + table_summary=description, + column_name=column.name, + samples=samples, ) + formatted_prompt = SIMILARITY_TEMPLATE.format_prompt(prompt_format) + + use_index = await self.llm.generate_text(formatted_prompt) return 
self.index_builder(connection.engine, table, column) if use_index.upper() == "TRUE" else None def _fetch_samples(self, connection: Connection, table: Table, column: Column) -> List[Any]: diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index e8c53435..fb8cfba4 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -35,20 +35,20 @@ async def test_llm_options_propagation(): llm.client.call.assert_has_calls( [ call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9858e45f..75cc914b 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -6,12 +6,11 @@ from dataclasses import dataclass from functools import cached_property -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.similarity.index import AbstractSimilarityIndex @@ -35,12 +34,12 @@ def execute(self, dry_run=False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: str) -> None: + def __init__(self, iql: IQLQuery) -> None: self.iql = iql super().__init__(llm=MockLLM()) - async def generate_iql(self, *_, **__) -> Tuple[str, IQLPromptTemplate]: - return self.iql, default_iql_template + async def generate_iql(self, *_, **__) -> IQLQuery: + return self.iql class MockViewSelector(ViewSelector): diff --git a/tests/unit/test_assistants_adapters.py b/tests/unit/test_assistants_adapters.py index 72a55e06..9c203bd6 100644 --- a/tests/unit/test_assistants_adapters.py +++ b/tests/unit/test_assistants_adapters.py @@ -8,7 +8,7 @@ from dbally.assistants.base import FunctionCallingError, FunctionCallState from dbally.assistants.openai import _DBALLY_INFO, _DBALLY_INSTRUCTION, OpenAIAdapter, OpenAIDballyResponse -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError +from dbally.iql_generator.prompt import UnsupportedQueryError MOCK_VIEWS = {"view1": "description1", "view2": "description2"} F_ID = "f_id" diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index e12ecec6..cd6934dd 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -1,7 +1,7 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name, missing-return-type-doc from typing import List, Tuple, Type -from unittest.mock import AsyncMock, Mock, call, patch +from unittest.mock import AsyncMock, Mock import pytest from typing_extensions import Annotated @@ -9,9 +9,9 @@ from dbally.collection import Collection, create_collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql._exceptions import IQLError +from dbally.iql import IQLQuery +from dbally.iql.syntax import FunctionCall from dbally.views.exposed_functions import ExposedFunction, 
MethodParamWithTyping -from dbally.views.structured import BaseStructuredView from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector @@ -58,8 +58,8 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__): - return MockIQLGenerator("test_filter()") + def get_iql_generator(self, *_, **__) -> MockIQLGenerator: + return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) @pytest.fixture(name="similarity_classes") @@ -274,42 +274,6 @@ def get_iql_generator(self, *_, **__): return collection -async def test_ask_feedback_loop(collection_feedback: Collection) -> None: - """ - Tests that the ask_feedback_loop method works correctly - """ - - mock_node = Mock(col_offset=0, end_col_offset=-1) - errors = [ - IQLError("err1", mock_node, "src1"), - IQLError("err2", mock_node, "src2"), - ValueError("err3"), - ValueError("err4"), - ] - with patch("dbally.iql._query.IQLQuery.parse") as mock_iql_query: - mock_iql_query.side_effect = errors - view = collection_feedback.get("ViewWithMockGenerator") - assert isinstance(view, BaseStructuredView) - iql_generator = view.get_iql_generator(llm=MockLLM()) - - await collection_feedback.ask("Mock question") - - iql_gen_error: Mock = iql_generator.add_error_msg # type: ignore - - iql_gen_error.assert_has_calls( - [call("iql1_c", [errors[0]]), call("iql2_c", [errors[1]]), call("iql3_c", [errors[2]])] - ) - assert iql_gen_error.call_count == 3 - - iql_gen_gen_iql: Mock = iql_generator.generate_iql # type: ignore - - for i, c in enumerate(iql_gen_gen_iql.call_args_list): - if i > 0: - assert c[1]["conversation"] == f"err{i}" - - assert iql_gen_gen_iql.call_count == 4 - - async def test_ask_view_selection_single_view() -> None: """ Tests that the ask method select view correctly when there is only one view diff --git a/tests/unit/test_fewshot.py b/tests/unit/test_fewshot.py index e2f4cf8d..2b8ba8b3 100644 --- a/tests/unit/test_fewshot.py +++ b/tests/unit/test_fewshot.py @@ -2,20 +2,20 @@ import pytest -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample class TestExamples: - def studied_at(self, _: str): + def studied_at(self, _: str) -> bool: return False - def is_available_within_months(self, _: int): + def is_available_within_months(self, _: int) -> bool: return False - def data_scientist_position(self): + def data_scientist_position(self) -> bool: return False - def has_seniority(self, _: str): + def has_seniority(self, _: str) -> bool: return False def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C0116, W9011 @@ -57,16 +57,17 @@ def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C011 ] -def test_fewshot_string(): - result = FewShotExample("question", "answer") - assert result.answer == "answer" - assert str(result) == "answer" - - @pytest.mark.parametrize( "repr_lambda", TestExamples()(), ) -def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]): +def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]) -> None: result = FewShotExample("question", repr_lambda[1]) - assert str(result) == repr_lambda[0] + assert result.answer == repr_lambda[0] + assert str(result) == f"question -> {repr_lambda[0]}" + + +def test_fewshot_string() -> None: + result = FewShotExample("question", "answer") + assert result.answer == "answer" + assert str(result) == 
"question -> answer" diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index c2fb4274..8f583c4c 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,68 +1,89 @@ -from typing import List - -import pytest - -from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.prompts.elements import FewShotExample -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter - - -async def test_iql_input_format_default() -> None: - input_fmt = IQLInputFormatter([], "") - - conversation, format = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) - assert "filters" in format - assert "question" in format - - -async def test_iql_input_format_few_shot_default() -> None: - input_fmt = IQLFewShotInputFormatter([], [], "") - - conversation, format = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) - assert "filters" in format - assert "question" in format - - -@pytest.mark.parametrize( - "examples", - [ - [], - [FewShotExample("q1", "a1")], - ], -) -async def test_iql_input_format_few_shot_examples_injected(examples: List[FewShotExample]) -> None: +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.prompt.elements import FewShotExample + + +async def test_iql_prompt_format_default() -> None: + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=[], + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert formatted_prompt.chat == [ + { + "role": "system", + "content": "You have access to API that lets you query a database:\n" + "\n\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. ", + "is_example": False, + }, + {"role": "user", "content": "", "is_example": False}, + ] + + +async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] - input_fmt = IQLFewShotInputFormatter([], examples, "") - - conversation, format = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) - assert "filters" in format - assert "question" in format + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=examples, + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert formatted_prompt.chat == [ + { + "role": "system", + "content": "You have access to API that lets you query a database:\n" + "\n\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. 
Only the values.\n" + "You MUST use only these methods:\n" + "\n\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. ", + "is_example": False, + }, + {"role": "user", "content": examples[0].question, "is_example": True}, + {"role": "assistant", "content": examples[0].answer, "is_example": True}, + {"role": "user", "content": "", "is_example": False}, + ] async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() -> None: examples = [FewShotExample("q1", "a1")] - input_fmt = IQLFewShotInputFormatter([], examples, "q") - - conversation, _ = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) - assert conversation.chat[1]["role"] == "user" - assert conversation.chat[1]["content"] == examples[0].question - assert conversation.chat[2]["role"] == "assistant" - assert conversation.chat[2]["content"] == examples[0].answer - - conversation = conversation.add_assistant_message("response") - - conversation2, _ = input_fmt(conversation) - - assert len(conversation2.chat) == len(conversation.chat) - assert conversation2.chat[1]["role"] == "user" - assert conversation2.chat[1]["content"] == examples[0].question - assert conversation2.chat[2]["role"] == "assistant" - assert conversation2.chat[2]["content"] == examples[0].answer + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=examples, + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert formatted_prompt.chat[1]["role"] == "user" + assert formatted_prompt.chat[1]["content"] == examples[0].question + assert formatted_prompt.chat[2]["role"] == "assistant" + assert formatted_prompt.chat[2]["content"] == examples[0].answer + + formatted_prompt = formatted_prompt.add_assistant_message("response") + + formatted_prompt2 = formatted_prompt.format_prompt(prompt_format) + + assert len(formatted_prompt2.chat) == len(formatted_prompt.chat) + assert formatted_prompt2.chat[1]["role"] == "user" + assert formatted_prompt2.chat[1]["content"] == examples[0].question + assert formatted_prompt2.chat[2]["role"] == "assistant" + assert formatted_prompt2.chat[2]["content"] == examples[0].answer diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 8c8df9e7..ce3f593d 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,17 +1,15 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLQuery +from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.prompts.elements import FewShotExample -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM @@ -43,7 +41,7 @@ def view() -> MockView: @pytest.fixture def llm() -> MockLLM: llm = MockLLM() - 
llm.client.call = AsyncMock(return_value="LLM IQL mock answer") + llm.generate_text = AsyncMock(return_value="filter_by_id(1)") return llm @@ -52,58 +50,64 @@ def event_tracker() -> EventTracker: return EventTracker() -@pytest.mark.asyncio -async def test_iql_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: - iql_generator = IQLGenerator(llm) - - filters = {str(_filter) for _filter in view.list_filters()} - assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} - - input_formatter = IQLInputFormatter(question="Mock_question", filters=view.list_filters()) - - response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) - - template_after_response = default_iql_template.add_assistant_message(content="LLM IQL mock answer") - assert response == ("LLM IQL mock answer", template_after_response) - - template_after_response = template_after_response.add_user_message(content="Mock_error") - response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) - template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") - assert response2 == ("LLM IQL mock answer", template_after_2nd_response) +@pytest.fixture +def iql_generator(llm: MockLLM) -> IQLGenerator: + return IQLGenerator(llm) @pytest.mark.asyncio -async def test_iql_few_shot_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: - iql_generator = IQLGenerator(llm) - - filters = {str(_filter) for _filter in view.list_filters()} - assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} - - input_formatter = IQLFewShotInputFormatter( +async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: + filters = view.list_filters() + prompt_format = IQLGenerationPromptFormat( question="Mock_question", - filters=view.list_filters(), - examples=[FewShotExample("question", "filter_by_id(0)")], + filters=filters, ) - - response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) - - expected_conversation, _ = input_formatter(default_iql_template) - template_after_response = expected_conversation.add_assistant_message(content="LLM IQL mock answer") - assert response == ("LLM IQL mock answer", template_after_response) - - template_after_response = template_after_response.add_user_message(content="Mock_error") - response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) - template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") - assert response2 == ("LLM IQL mock answer", template_after_2nd_response) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: + iql = await iql_generator.generate_iql( + question="Mock_question", + filters=filters, + event_tracker=event_tracker, + ) + assert iql == "filter_by_id(1)" + iql_generator._llm.generate_text.assert_called_once_with( + prompt=formatted_prompt, + event_tracker=event_tracker, + options=None, + ) + mock_parse.assert_called_once_with( + source="filter_by_id(1)", + allowed_functions=filters, + event_tracker=event_tracker, + ) -def test_add_error_msg(llm: MockLLM) -> None: - iql_generator = IQLGenerator(llm) - errors = [ValueError("Mock_error")] - - conversation = 
default_iql_template.add_assistant_message(content="Assistant") - - conversation_with_error = iql_generator.add_error_msg(conversation, errors) - - error_msg = iql_generator._ERROR_MSG_PREFIX + "Mock_error\n" - assert conversation_with_error == conversation.add_user_message(content=error_msg) +@pytest.mark.asyncio +async def test_iql_generation_error_handling( + iql_generator: IQLGenerator, + event_tracker: EventTracker, + view: MockView, +) -> None: + filters = view.list_filters() + + mock_node = Mock(col_offset=0, end_col_offset=-1) + errors = [ + IQLError("err1", mock_node, "src1"), + IQLError("err2", mock_node, "src2"), + IQLError("err3", mock_node, "src3"), + IQLError("err4", mock_node, "src4"), + ] + + with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: + mock_parse.side_effect = errors + iql = await iql_generator.generate_iql( + question="Mock_question", + filters=filters, + event_tracker=event_tracker, + ) + + assert iql is None + assert iql_generator._llm.generate_text.call_count == 4 + for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] diff --git a/tests/unit/test_prompt_builder.py b/tests/unit/test_prompt_builder.py index f8a886fe..00fa7fd5 100644 --- a/tests/unit/test_prompt_builder.py +++ b/tests/unit/test_prompt_builder.py @@ -1,116 +1,99 @@ +from typing import List + import pytest -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate -from dbally.prompts import ChatFormat, PromptTemplate, PromptTemplateError -from tests.unit.mocks import MockLLM +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import ChatFormat, PromptFormat, PromptTemplate, PromptTemplateError + + +class QuestionPromptFormat(PromptFormat): + """ + Generic format for prompts allowing to inject few shot examples into the conversation. + """ + + def __init__(self, question: str, examples: List[FewShotExample] = None) -> None: + """ + Constructs a new PromptFormat instance. + + Args: + question: Question to be asked. + examples: List of examples to be injected into the conversation. 
+ """ + super().__init__(examples) + self.question = question @pytest.fixture() -def simple_template(): - simple_template = PromptTemplate( - chat=( +def template() -> PromptTemplate[QuestionPromptFormat]: + return PromptTemplate[QuestionPromptFormat]( + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, - ) + ] ) - return simple_template -@pytest.fixture() -def llm(): - return MockLLM() +def test_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None: + prompt_format = QuestionPromptFormat(question="Example user question?") + formatted_prompt = template.format_prompt(prompt_format) + assert formatted_prompt.chat == [ + {"content": "You are a helpful assistant.", "role": "system", "is_example": False}, + {"content": "Example user question?", "role": "user", "is_example": False}, + ] -def test_default_llm_format_prompt(llm, simple_template): - prompt = llm.format_prompt( - template=simple_template, - fmt={"question": "Example user question?"}, - ) - assert prompt == [ - {"content": "You are a helpful assistant.", "role": "system"}, - {"content": "Example user question?", "role": "user"}, +def test_missing_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None: + prompt_format = PromptFormat() + with pytest.raises(KeyError): + template.format_prompt(prompt_format) + + +def test_add_few_shots(template: PromptTemplate[QuestionPromptFormat]) -> None: + examples = [ + FewShotExample( + question="What is the capital of France?", + answer_expr="Paris", + ), + FewShotExample( + question="What is the capital of Germany?", + answer_expr="Berlin", + ), ] + for example in examples: + template = template.add_few_shot_message(example) -def test_missing_format_dict(llm, simple_template): - with pytest.raises(KeyError): - _ = llm.format_prompt(simple_template, fmt={}) + assert template.chat == [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?", "is_example": True}, + {"role": "assistant", "content": "Paris", "is_example": True}, + {"role": "user", "content": "What is the capital of Germany?", "is_example": True}, + {"role": "assistant", "content": "Berlin", "is_example": True}, + {"role": "user", "content": "{question}"}, + ] @pytest.mark.parametrize( "invalid_chat", [ - ( + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, {"role": "user", "content": "{question}"}, - ), - ( + ], + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "assistant", "content": "{question}"}, {"role": "assistant", "content": "{question}"}, - ), - ( + ], + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, {"role": "assistant", "content": "{question}"}, {"role": "system", "content": "{question}"}, - ), + ], ], ) -def test_chat_order_validation(invalid_chat): +def test_chat_order_validation(invalid_chat: ChatFormat) -> None: with pytest.raises(PromptTemplateError): - _ = PromptTemplate(chat=invalid_chat) - - -def test_dynamic_few_shot(llm, simple_template): - assert ( - len( - llm.format_prompt( - simple_template.add_assistant_message("assistant message").add_user_message("user message"), - fmt={"question": "user question"}, - ) - ) - == 4 - ) - - -@pytest.mark.parametrize( - "invalid_chat", - [ - ( - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "{question}"}, 
- ), - ( - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - ), - ( - {"role": "system", "content": "You are a helpful assistant. {filters}}"}, - {"role": "user", "content": "Hello"}, - ), - ], - ids=["Missing filters", "Missing filters, question", "Missing question"], -) -def test_bad_iql_prompt_template(invalid_chat: ChatFormat): - with pytest.raises(PromptTemplateError): - _ = IQLPromptTemplate(invalid_chat) - - -@pytest.mark.parametrize( - "chat", - [ - ( - {"role": "system", "content": "You are a helpful assistant.{filters}"}, - {"role": "user", "content": "{question}"}, - ), - ( - {"role": "system", "content": "{filters}{filters}{filters}}}"}, - {"role": "user", "content": "{question}"}, - ), - ], - ids=["Good template", "Good template with repeating variables"], -) -def test_good_iql_prompt_template(chat: ChatFormat): - _ = IQLPromptTemplate(chat) + PromptTemplate[QuestionPromptFormat](invalid_chat) diff --git a/tests/unit/test_view_selector.py b/tests/unit/test_view_selector.py index 2d3b1d9c..8de038e2 100644 --- a/tests/unit/test_view_selector.py +++ b/tests/unit/test_view_selector.py @@ -31,7 +31,7 @@ def views() -> Dict[str, str]: @pytest.mark.asyncio -async def test_view_selection(llm: LLM, views: Dict[str, str]): +async def test_view_selection(llm: LLM, views: Dict[str, str]) -> None: view_selector = LLMViewSelector(llm) view = await view_selector.select_view("Mock question?", views, event_tracker=EventTracker()) assert view == "MockView1" From ef025ac200f1c2cefeaef497987935938274a226 Mon Sep 17 00:00:00 2001 From: akotyla <79326805+akotyla@users.noreply.github.com> Date: Wed, 3 Jul 2024 16:21:37 +0200 Subject: [PATCH 47/64] feat(llms): add support for HuggingFace models loaded locally (#61) --- benchmark/dbally_benchmark/e2e_benchmark.py | 12 ++- benchmark/dbally_benchmark/iql_benchmark.py | 11 +-- .../dbally_benchmark/text2sql_benchmark.py | 7 +- docs/how-to/llms/local.md | 66 +++++++++++++ docs/reference/llms/local.md | 7 ++ mkdocs.yml | 2 + setup.cfg | 5 +- src/dbally/llms/clients/local.py | 95 +++++++++++++++++++ src/dbally/llms/local.py | 60 ++++++++++++ 9 files changed, 252 insertions(+), 13 deletions(-) create mode 100644 docs/how-to/llms/local.md create mode 100644 docs/reference/llms/local.md create mode 100644 src/dbally/llms/clients/local.py create mode 100644 src/dbally/llms/local.py diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index aa686727..f122a1ea 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -25,6 +25,7 @@ from dbally.collection.exceptions import NoViewFoundError from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM +from dbally.llms.local import LocalLLM from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE @@ -82,10 +83,13 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") - llm = LiteLLM( - model_name="gpt-4", - api_key=benchmark_cfg.openai_api_key, - ) + if cfg.model_name.startswith("local/"): + llm = LocalLLM(api_key=benchmark_cfg.hf_api_key, model_name=cfg.model_name.split("/", 1)[1]) + else: + llm = LiteLLM( + model_name=cfg.model_name, + api_key=benchmark_cfg.openai_api_key, + ) db = dbally.create_collection(cfg.db_name, llm) diff --git a/benchmark/dbally_benchmark/iql_benchmark.py 
b/benchmark/dbally_benchmark/iql_benchmark.py index 2557b2c2..d5b6b2ef 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -23,6 +23,7 @@ from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM +from dbally.llms.local import LocalLLM from dbally.views.structured import BaseStructuredView @@ -96,13 +97,11 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") view = VIEW_REGISTRY[ViewName(view_name)](engine) - if "gpt" in cfg.model_name: - llm = LiteLLM( - model_name=cfg.model_name, - api_key=benchmark_cfg.openai_api_key, - ) + if cfg.model_name.startswith("local/"): + llm = LocalLLM(model_name=cfg.model_name.split("/", 1)[1], api_key=benchmark_cfg.hf_api_key) else: - raise ValueError("Only OpenAI's GPT models are supported for now.") + llm = LiteLLM(api_key=benchmark_cfg.openai_api_key, model_name=cfg.model_name) + iql_generator = IQLGenerator(llm=llm) run = None diff --git a/benchmark/dbally_benchmark/text2sql_benchmark.py b/benchmark/dbally_benchmark/text2sql_benchmark.py index ede53f88..5e4c5860 100644 --- a/benchmark/dbally_benchmark/text2sql_benchmark.py +++ b/benchmark/dbally_benchmark/text2sql_benchmark.py @@ -22,6 +22,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.llms.litellm import LiteLLM +from dbally.llms.local import LocalLLM def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str: @@ -84,10 +85,12 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") - if "gpt" in cfg.model_name: + if cfg.model_name.startswith("local/"): + llm = LocalLLM(model_name=cfg.model_name.split("/", 1)[1], api_key=benchmark_cfg.hf_api_key) + else: llm = LiteLLM( - model_name=cfg.model_name, api_key=benchmark_cfg.openai_api_key, + model_name=cfg.model_name, ) run = None diff --git a/docs/how-to/llms/local.md b/docs/how-to/llms/local.md new file mode 100644 index 00000000..fec271e7 --- /dev/null +++ b/docs/how-to/llms/local.md @@ -0,0 +1,66 @@ +# How-To: Use Local LLMs + +db-ally includes a ready-to-use implementation for local LLMs called [`LocalLLM`](../../reference/llms/local.md#dbally.llms.local.LocalLLM), which leverages the Hugging Face Transformers library to provide access to various LLMs available on Hugging Face. + +## Basic Usage + +Install the required dependencies for using local LLMs. + +```bash +pip install dbally[local] +``` + +Integrate db-ally with your Local LLM + +First, set up your environment to use a Hugging Face model. + +```python + +import os +from dbally.llms.localllm import LocalLLM + +os.environ["HUGGINGFACE_API_KEY"] = "your-api-key" + +llm = LocalLLM(model_name="meta-llama/Meta-Llama-3-8B-Instruct") +``` + +Use LLM in your collection + +```python + +my_collection = dbally.create_collection("my_collection", llm) +response = await my_collection.ask("Which LLM should I use?") +``` + +## Advanced Usage + +For advanced users, you can customize your LLM using [`LocalLLMOptions`](../../reference/llms/local.md#dbally.llms.clients.local.LocalLLMOptions). Here is a list of available parameters: + +- `repetition_penalty`: *float or null (optional)* - Penalizes repeated tokens to avoid repetitions. +- `do_sample`: *bool or null (optional)* - Enables sampling instead of greedy decoding. 
+- `best_of`: *int or null (optional)* - Generates multiple sequences and returns the one with the highest score. +- `max_new_tokens`: *int (optional)* - The maximum number of new tokens to generate. +- `top_k`: *int or null (optional)* - Limits the next token choices to the top-k probability tokens. +- `top_p`: *float or null (optional)* - Limits the next token choices to tokens within the top-p probability mass. +- `seed`: *int or null (optional)* - Sets the seed for random number generation to ensure reproducibility. +- `stop_sequences`: *list of strings or null (optional)* - Specifies sequences where the generation should stop. +- `temperature`: *float or null (optional)* - Adjusts the randomness of token selection. + +```python +import dbally +from dbally.llms.clients.localllm import LocalLLMOptions + +llm = LocalLLM("meta-llama/Meta-Llama-3-8B-Instruct", default_options=LocalLLMOptions(temperature=0.7)) +my_collection = dbally.create_collection("my_collection", llm) +``` + +You can also override any default parameter on the ask [`ask`](../../reference/collection.md#dbally.Collection.ask) call. + +```python +response = await my_collection.ask( + question="Which LLM should I use?", + llm_options=LocalLLMOptions( + temperature=0.65, + ), +) +``` \ No newline at end of file diff --git a/docs/reference/llms/local.md b/docs/reference/llms/local.md new file mode 100644 index 00000000..cfb40c39 --- /dev/null +++ b/docs/reference/llms/local.md @@ -0,0 +1,7 @@ +# Local + +::: dbally.llms.local.LocalLLM + +::: dbally.llms.clients.local.LocalLLMClient + +::: dbally.llms.clients.local.LocalLLMOptions diff --git a/mkdocs.yml b/mkdocs.yml index 826ffe15..59cdef42 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,6 +24,7 @@ nav: - how-to/views/few-shots.md - Using LLMs: - how-to/llms/litellm.md + - how-to/llms/local.md - how-to/llms/custom.md - Using similarity indexes: - how-to/use_custom_similarity_fetcher.md @@ -60,6 +61,7 @@ nav: - LLMs: - reference/llms/index.md - reference/llms/litellm.md + - reference/llms/local.md - reference/prompt.md - Similarity: - reference/similarity/index.md diff --git a/setup.cfg b/setup.cfg index 631f4d4c..7a162c5b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -69,7 +69,10 @@ elasticsearch = gradio = gradio~=4.31.5 gradio_client~=0.16.4 - +local = + accelerate~=0.31.0 + torch~=2.2.1 + transformers~=4.41.2 [options.packages.find] where = src diff --git a/src/dbally/llms/clients/local.py b/src/dbally/llms/clients/local.py new file mode 100644 index 00000000..d77be3f3 --- /dev/null +++ b/src/dbally/llms/clients/local.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from dbally.audit.events import LLMEvent +from dbally.llms.clients.base import LLMClient, LLMOptions +from dbally.prompt.template import ChatFormat + +from ..._types import NOT_GIVEN, NotGiven + + +@dataclass +class LocalLLMOptions(LLMOptions): + """ + Dataclass that represents all available LLM call options for the local LLM client. + Each of them is described in the [HuggingFace documentation] + (https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). 
# pylint: disable=line-too-long + """ + + repetition_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN + do_sample: Union[Optional[bool], NotGiven] = NOT_GIVEN + best_of: Union[Optional[int], NotGiven] = NOT_GIVEN + max_new_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN + top_k: Union[Optional[int], NotGiven] = NOT_GIVEN + top_p: Union[Optional[float], NotGiven] = NOT_GIVEN + seed: Union[Optional[int], NotGiven] = NOT_GIVEN + stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN + temperature: Union[Optional[float], NotGiven] = NOT_GIVEN + + +class LocalLLMClient(LLMClient[LocalLLMOptions]): + """ + Client for the local LLM that supports Hugging Face models. + """ + + _options_cls = LocalLLMOptions + + def __init__( + self, + model_name: str, + *, + hf_api_key: Optional[str] = None, + ) -> None: + """ + Constructs a new local LLMClient instance. + + Args: + model_name: Name of the model to use. + hf_api_key: The Hugging Face API key for authentication. + """ + + super().__init__(model_name) + + self.model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_api_key + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key) + + async def call( + self, + conversation: ChatFormat, + options: LocalLLMOptions, + event: LLMEvent, + json_mode: bool = False, + ) -> str: + """ + Makes a call to the local LLM with the provided prompt and options. + + Args: + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. + options: Additional settings used by the LLM. + event: Container with the prompt, LLM response, and call metrics. + json_mode: Force the response to be in JSON format. + + Returns: + Response string from LLM. + """ + + input_ids = self.tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, return_tensors="pt" + ).to(self.model.device) + + outputs = self.model.generate( + input_ids, + eos_token_id=self.tokenizer.eos_token_id, + **options.dict(), + ) + response = outputs[0][input_ids.shape[-1] :] + event.completion_tokens = len(outputs[0][input_ids.shape[-1] :]) + event.prompt_tokens = len(outputs[0][: input_ids.shape[-1]]) + event.total_tokens = input_ids.shape[-1] + decoded_response = self.tokenizer.decode(response, skip_special_tokens=True) + return decoded_response diff --git a/src/dbally/llms/local.py b/src/dbally/llms/local.py new file mode 100644 index 00000000..198513b3 --- /dev/null +++ b/src/dbally/llms/local.py @@ -0,0 +1,60 @@ +from functools import cached_property +from typing import Optional + +from transformers import AutoTokenizer + +from dbally.llms.base import LLM +from dbally.llms.clients.local import LocalLLMClient, LocalLLMOptions +from dbally.prompt.template import PromptTemplate + + +class LocalLLM(LLM[LocalLLMOptions]): + """ + Class for interaction with any LLM available in HuggingFace. + """ + + _options_cls = LocalLLMOptions + + def __init__( + self, + model_name: str, + default_options: Optional[LocalLLMOptions] = None, + *, + api_key: Optional[str] = None, + ) -> None: + """ + Constructs a new local LLM instance. + + Args: + model_name: Name of the model to use. This should be a model from the CausalLM class. + default_options: Default options for the LLM. + api_key: The API key for Hugging Face authentication. 
+ """ + + super().__init__(model_name, default_options) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) + self.api_key = api_key + + @cached_property + def client(self) -> LocalLLMClient: + """ + Client for the LLM. + + Returns: + The client used to interact with the LLM. + """ + return LocalLLMClient(model_name=self.model_name, hf_api_key=self.api_key) + + def count_tokens(self, prompt: PromptTemplate) -> int: + """ + Counts tokens in the messages. + + Args: + prompt: Messages to count tokens for. + + Returns: + Number of tokens in the messages. + """ + + input_ids = self.tokenizer.apply_chat_template(prompt.chat) + return len(input_ids) From 2c3dccba07c6758402dfeeb63beef2caaa7db61d Mon Sep 17 00:00:00 2001 From: Bartosz Mikulski Date: Thu, 4 Jul 2024 07:48:18 +0200 Subject: [PATCH 48/64] fix(nl-responder): prevent hallucination when no data is returned (#68) --- src/dbally/nl_responder/prompts.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py index f99a8a6c..17f63898 100644 --- a/src/dbally/nl_responder/prompts.py +++ b/src/dbally/nl_responder/prompts.py @@ -19,16 +19,24 @@ def __init__( examples: List[FewShotExample] = None, ) -> None: """ - Constructs a new IQLGenerationPromptFormat instance. + Constructs a new NLResponsePromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + results: List of records, where dictionary keys store column names. examples: List of examples to be injected into the conversation. """ super().__init__(examples) self.question = question - self.results = pd.DataFrame.from_records(results).to_markdown(index=False, headers="keys", tablefmt="psql") + + if results: + self.results = pd.DataFrame.from_records(results).to_markdown(index=False, headers="keys", tablefmt="psql") + else: + self.results = ( + "The query returned 0 rows. The table has no data. Don't hallucinate responses. " + "Make sure to inform the user that there are no data points that satisfy their question. " + "Be brief in your response and remember to answer the question correctly."
+ ) class QueryExplanationPromptFormat(PromptFormat): From f0c0cbabd93360f3cf36f6ac6652bd3a2674f0fd Mon Sep 17 00:00:00 2001 From: semantic-release Date: Thu, 4 Jul 2024 06:07:51 +0000 Subject: [PATCH 49/64] 0.4.0 Automatically generated by python-semantic-release --- src/dbally/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/__version__.py b/src/dbally/__version__.py index 9dc38a9c..2c7e3d99 100644 --- a/src/dbally/__version__.py +++ b/src/dbally/__version__.py @@ -1,3 +1,3 @@ """Version information.""" -__version__ = "0.3.1" +__version__ = "0.4.0" From 5c0db8f43483469104ffa90fe5e5c9e42943c119 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= Date: Thu, 4 Jul 2024 08:21:15 +0200 Subject: [PATCH 50/64] chore: changelog update after v0.4.0 --- CHANGELOG.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebfeff36..74a24143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Changelog +# v0.4.0 (2024-07-04) + +## Feature + +* Added support for local HuggingFace models (#61) ([`953d8a1`](https://github.com/deepsense-ai/db-ally/commit/953d8a1f3c39c624dcc3927e9dfb4df08121df35)) + +* Few-shot examples can be now injected into Structured / Freeform view generation prompts (#42) ([`d482638`](https://github.com/deepsense-ai/db-ally/commit/d4826385e95505c077a1c710feeba68ddcaef20c)) + +## Documentation + +* Added docs explaining how to use AzureOpenAI (#55) ([`d890fec`](https://github.com/deepsense-ai/db-ally/commit/d890fecad38ed11d90a85e6472e64c81c607cf91)) + +## Fix + +* Fixed a bug with natural language responder hallucination when no data is returned (#68) ([`e3fec18`](https://github.com/deepsense-ai/db-ally/commit/e3fec186cca0cace7db4b6e92da5b047a27dfa80)) + +## Chore + +* Project was doggified 🦮 (#67) ([`a4fd411`](https://github.com/deepsense-ai/db-ally/commit/a4fd4115bc7884f5043a6839cfefdd36c97e94ab)) + +* `enhancment` label was replaced by `feature` ([`cd5bf7b`](https://github.com/deepsense-ai/db-ally/commit/cd5bf7b76b97e8d9e46ff872859ccd0ffdef859e)) + +## Refactor + +* Refactor of prompt templates (#66) ([`6510bd8`](https://github.com/deepsense-ai/db-ally/commit/6510bd83923c83c69f082b63c722065fd0e7a3cd)) + +* Refactor audit module (#58) ([`9fd817f`](https://github.com/deepsense-ai/db-ally/commit/9fd817f3955e4e0c61da1cf9be44e9b6ac426c15)) + + ## v0.3.1 (2024-06-17) ### Documentation From 32691372c63e4a202144073201649b92291b6b9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= Date: Thu, 4 Jul 2024 08:23:46 +0200 Subject: [PATCH 51/64] chore: changelog heading fix --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74a24143..f875efdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -# v0.4.0 (2024-07-04) +## v0.4.0 (2024-07-04) ## Feature From 986757a7601dd8f9b630907f48232760acb81aaa Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 4 Jul 2024 11:05:47 +0200 Subject: [PATCH 52/64] import adjustments --- src/dbally/collection/collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 10f87aa0..dc60d419 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -12,7 +12,7 @@ from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, 
NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder From f432b618cf65e32aace5000bf5047c385b8cb362 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Thu, 4 Jul 2024 16:20:45 +0200 Subject: [PATCH 53/64] chained fallback --- docs/concepts/collections.md | 25 +++++++++++++++++++++++++ examples/visualize_fallback_code.py | 23 ++++++++++++++++------- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/concepts/collections.md b/docs/concepts/collections.md index ef5186ed..bdca28d4 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -25,6 +25,31 @@ my_collection.ask("Find me Italian recipes for soups") In this scenario, the LLM first determines the most suitable view to address the query, and then that view is used to pull the relevant data. +Sometimes the selected view may not be able to answer the question and will raise an error. In such situations, a fallback collection can be used. +```python + llm = LiteLLM(model_name="gpt-3.5-turbo") + user_collection = dbally.create_collection("candidates", llm) + user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) + user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + user_collection.add(CandidateView, lambda: (candidate_view_with_similarity_store.engine)) + + fallback_collection = dbally.create_collection("freeform candidates", llm) + fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) + user_collection.set_fallback(fallback_collection) +``` +The fallback collection processes the same question using its own declared set of views, and fallback collections can be chained. + +```python + second_fallback_collection = dbally.create_collection("recruitment", llm) + second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine)) + + fallback_collection.set_fallback(second_fallback_collection) + +``` + + + + !!! info The result of a query is an [`ExecutionResult`][dbally.collection.results.ExecutionResult] object, which contains the data fetched by the view. It contains a `results` attribute that holds the actual data, structured as a list of dictionaries. The exact structure of these dictionaries depends on the view that was used to fetch the data, which can be obtained by looking at the `view_name` attribute of the `ExecutionResult` object.
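To make the chain in the docs snippet above concrete, here is a rough usage sketch: when the primary collection fails with one of the handled errors (for example an unsupported query or no matching view), the same question is re-asked on the fallback collection and the combined timing is reported in the result. View registration is omitted for brevity, and the exact set of handled exceptions is an assumption here, not a guarantee of the implementation.

```python
import asyncio

import dbally
from dbally.llms.litellm import LiteLLM


async def main() -> None:
    llm = LiteLLM(model_name="gpt-3.5-turbo")

    user_collection = dbally.create_collection("candidates", llm)
    fallback_collection = dbally.create_collection("freeform candidates", llm)
    # Views are registered with .add(...) exactly as in the snippet above.
    user_collection.set_fallback(fallback_collection)

    # If "candidates" cannot answer, the question is transparently retried
    # on "freeform candidates" before any error reaches the caller.
    result = await user_collection.ask("Find me candidates with Python experience")

    print(result.view_name)       # the view that ultimately produced the answer
    print(result.execution_time)  # total time, including the fallback call
    print(result.results)         # retrieved rows as a list of dictionaries


if __name__ == "__main__":
    asyncio.run(main())
```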
diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index af34933a..18d387fc 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -5,6 +5,8 @@ from recruiting.candidate_view_with_similarity_store import CandidateView from recruiting.candidates_freeform import CandidateFreeformView from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine +from recruiting.db import ENGINE as recruiting_engine +from recruiting.views import RecruitmentView import dbally from dbally.gradio import create_gradio_interface @@ -13,13 +15,20 @@ async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") - collection1 = dbally.create_collection("candidates", llm) - collection2 = dbally.create_collection("freeform candidates", llm) - collection1.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) - collection1.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) - collection2.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) - collection1.set_fallback(collection2) - gradio_interface = await create_gradio_interface(user_collection=collection1) + user_collection = dbally.create_collection("candidates", llm) + user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) + user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + user_collection.add(CandidateView, lambda: (candidate_view_with_similarity_store.engine)) + + fallback_collection = dbally.create_collection("freeform candidates", llm) + fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) + + second_fallback_collection = dbally.create_collection("recruitment", llm) + second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine)) + + user_collection.set_fallback(fallback_collection).set_fallback(second_fallback_collection) + + gradio_interface = await create_gradio_interface(user_collection=user_collection) gradio_interface.launch() From 05704840d831203ea4df8c99c899fa81a0512802 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 5 Jul 2024 08:50:07 +0200 Subject: [PATCH 54/64] collections --- examples/visualize_fallback_code.py | 1 - src/dbally/collection/collection.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index 18d387fc..37211e69 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -18,7 +18,6 @@ async def main(): user_collection = dbally.create_collection("candidates", llm) user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) - user_collection.add(CandidateView, lambda: (candidate_view_with_similarity_store.engine)) fallback_collection = dbally.create_collection("freeform candidates", llm) fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 272ddb82..1eccbf7f 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -374,7 +374,7 @@ async def ask( 
view_result = await self._ask_view(selected_view_name, question, event_tracker, llm_options, dry_run) end_time_view = time.monotonic() - natural_response = ( + natural_response = await ( self._generate_textual_response(view_result, question, event_tracker, llm_options) if not dry_run and return_natural_response else "" From d36c93948b1c165ecfc53b16275cda5dfa3ecb41 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 5 Jul 2024 13:34:03 +0200 Subject: [PATCH 55/64] override global events --- examples/recruiting/views.py | 2 +- src/dbally/__init__.py | 2 +- .../audit/event_handlers/cli_event_handler.py | 5 ++-- src/dbally/audit/events.py | 1 + src/dbally/collection/collection.py | 28 +++++++++++++------ 5 files changed, 25 insertions(+), 13 deletions(-) diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py index 773d3f62..9765ba51 100644 --- a/examples/recruiting/views.py +++ b/examples/recruiting/views.py @@ -75,7 +75,7 @@ def is_available_within_months( # pylint: disable=W0602, C0116, W9011 end = start + relativedelta(months=months) return Candidate.available_from.between(start, end) - def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011 + def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011, C0116 return [ FewShotExample( "Which candidates studied at University of Toronto?", diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index f17d840e..7aca9e48 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List -from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError +from dbally.collection.exceptions import NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.views import decorators from dbally.views.methods_base import MethodsBaseView diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 50c2e989..57e30545 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -13,7 +13,7 @@ pprint = print # type: ignore from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent, FallbackEvent +from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent _RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"} _RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]" @@ -96,12 +96,13 @@ async def event_start(self, event: Event, request_context: None) -> None: ) elif isinstance(event, FallbackEvent): self._print_syntax( - "[grey53]\n=======================================\n" + f"[grey53]\n=======================================\n" "[grey53]=======================================\n" f"[orange bold]Fallback event starts \n" f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n" f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n" f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" + f"[orange bold]Override event handlers: [grey53]{event.override_global_event}\n" f"[orange bold]Error description: [grey53]{event.error_description}\n" "[grey53]=======================================\n" "[grey53]=======================================\n" diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index 3bb23e17..cb9a01e2 
100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -51,6 +51,7 @@ class FallbackEvent(Event): triggering_view_name: str fallback_collection_name: str error_description: str + override_global_event: bool = False @dataclass diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 1eccbf7f..83338646 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -67,9 +67,9 @@ def __init__( self._builders: Dict[str, Callable[[], BaseView]] = {} self._view_selector = view_selector self._nl_responder = nl_responder - self._event_handlers = event_handlers or dbally.event_handlers self._llm = llm self._fallback_collection: Optional[Collection] = fallback_collection + self._event_handlers = event_handlers or dbally.event_handlers T = TypeVar("T", bound=BaseView) @@ -132,6 +132,7 @@ def set_fallback(self, fallback_collection: "Collection") -> "Collection": The fallback collection to create chains call """ self._fallback_collection = fallback_collection + return fallback_collection def __rshift__(self, fallback_collection: "Collection"): @@ -303,21 +304,26 @@ async def _handle_fallback(
deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 83338646..51f146d2 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -449,59 +449,3 @@ async def update_similarity_indexes(self) -> None: if failed_indexes: failed_locations = [loc for index in failed_indexes for loc in indexes[index]] raise IndexUpdateError(failed_indexes, failed_locations) - - -def create_collection( - name: str, - llm: LLM, - event_handlers: Optional[List[EventHandler]] = None, - view_selector: Optional[ViewSelector] = None, - nl_responder: Optional[NLResponder] = None, -) -> Collection: - """ - Create a new [Collection](collection.md) that is a container for registering views and the\ - main entrypoint to db-ally features. - - Unlike instantiating a [Collection][dbally.Collection] directly, this function\ - provides a set of default values for various dependencies like LLM client, view selector,\ - IQL generator, and NL responder. - - ##Example - - ```python - from dbally import create_collection - from dbally.llms.litellm import LiteLLM - - collection = create_collection("my_collection", llm=LiteLLM()) - ``` - - Args: - name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ - used to distinguish different db-ally runs. - llm: LLM used by the collection to generate responses for natural language queries. - event_handlers: Event handlers used by the collection during query executions. Can be used to\ - log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ - [LangSmithEventHandler](event_handlers/langsmith_handler.md). If provided, this parameter overrides the - global dbally.event_handlers_list - view_selector: View selector used by the collection to select the best view for the given query.\ - If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ - will be used. - nl_responder: NL responder used by the collection to respond to natural language queries. If None,\ - a new instance of [NLResponder][dbally.nl_responder.nl_responder.NLResponder] will be used. 
- - Returns: - a new instance of db-ally Collection - - Raises: - ValueError: if default LLM client is not configured - """ - view_selector = view_selector or LLMViewSelector(llm=llm) - nl_responder = nl_responder or NLResponder(llm=llm) - - return Collection( - name, - nl_responder=nl_responder, - view_selector=view_selector, - llm=llm, - event_handlers=event_handlers, - ) diff --git a/src/dbally/collection/results.py b/src/dbally/collection/results.py index 3a427505..b33cf5e3 100644 --- a/src/dbally/collection/results.py +++ b/src/dbally/collection/results.py @@ -39,6 +39,6 @@ class ExecutionResult: results: List[Dict[str, Any]] context: Dict[str, Any] execution_time: float - execution_time_view: float = 0 - view_name: str = "" + execution_time_view: float + view_name: str textual_response: Optional[str] = None From feed5a8e04908464b256b863dea4d9a91d36e7cc Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 8 Jul 2024 10:00:01 +0200 Subject: [PATCH 57/64] collection fixups --- src/dbally/collection/__init__.py | 3 +-- src/dbally/collection/collection.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dbally/collection/__init__.py b/src/dbally/collection/__init__.py index db041159..66eea8fe 100644 --- a/src/dbally/collection/__init__.py +++ b/src/dbally/collection/__init__.py @@ -1,9 +1,8 @@ -from dbally.collection.collection import Collection, create_collection +from dbally.collection.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult __all__ = [ - "create_collection", "Collection", "ExecutionResult", "ViewExecutionResult", diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 51f146d2..642701da 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -16,7 +16,6 @@ from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder from dbally.similarity.index import AbstractSimilarityIndex -from dbally.view_selection import LLMViewSelector from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation From 6465f1ed2edf7d51ad4f92ae8786d57cc09d1318 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 8 Jul 2024 14:06:50 +0200 Subject: [PATCH 58/64] collection polishing --- src/dbally/collection/collection.py | 73 ++++++++++++++------------- src/dbally/gradio/gradio_interface.py | 10 ++-- 2 files changed, 40 insertions(+), 43 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 642701da..020ec02c 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -282,7 +282,7 @@ async def _handle_fallback( selected_view_name: str, event_tracker: EventTracker, caught_exception: HANDLED_EXCEPTION_TYPES, - ): + ) -> ExecutionResult: """ Handle fallback if the main query fails. @@ -296,36 +296,31 @@ async def _handle_fallback( caught_exception: The exception that was caught. Returns: - Any: The result from the fallback collection. + The result from the fallback collection. - Raises: - Exception: If there is no fallback collection or if an error occurs in the fallback. 
""" - if self._fallback_collection: - override_global_event = ( - self._fallback_collection._event_handlers != dbally.event_handlers # pylint: disable=W0212 - ) - - fallback_event = FallbackEvent( - triggering_collection_name=self.name, - triggering_view_name=selected_view_name, - fallback_collection_name=self._fallback_collection.name, - error_description=repr(caught_exception), - override_global_event=override_global_event, - ) + override_global_event = ( + self._fallback_collection._event_handlers != self._event_handlers # pylint: disable=W0212 + ) - async with event_tracker.track_event(fallback_event) as span: - result = await self._fallback_collection.ask( - question=question, - dry_run=dry_run, - return_natural_response=return_natural_response, - llm_options=llm_options, - ) - span(fallback_event) - return result + fallback_event = FallbackEvent( + triggering_collection_name=self.name, + triggering_view_name=selected_view_name, + fallback_collection_name=self._fallback_collection.name, + error_description=repr(caught_exception), + override_global_event=override_global_event, + ) - raise caught_exception + async with event_tracker.track_event(fallback_event) as span: + result = await self._fallback_collection.ask( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + ) + span(fallback_event) + return result async def ask( self, @@ -334,7 +329,7 @@ async def ask( return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, event_tracker: Optional[EventTracker] = None, - ) -> Optional[ExecutionResult]: + ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -362,6 +357,9 @@ async def ask( ValueError: if collection is empty IQLError: if incorrect IQL was generated `n_retries` amount of times. ValueError: if incorrect IQL was generated `n_retries` amount of times. 
+ NoViewFoundError: if question does not match to any registered view, + UnsupportedQueryError: if the question could not be answered + IndexUpdateError: if index update failed """ if not event_tracker: @@ -397,15 +395,18 @@ async def ask( ) except HANDLED_EXCEPTION_TYPES as caught_exception: - result = await self._handle_fallback( - question=question, - dry_run=dry_run, - return_natural_response=return_natural_response, - llm_options=llm_options, - selected_view_name=selected_view_name, - event_tracker=event_tracker, - caught_exception=caught_exception, - ) + if self._fallback_collection: + result = await self._handle_fallback( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + selected_view_name=selected_view_name, + event_tracker=event_tracker, + caught_exception=caught_exception, + ) + else: + raise caught_exception if not is_fallback_call: await event_tracker.request_end(RequestEnd(result=result)) diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 5f023659..4a8de2b4 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -149,13 +149,9 @@ async def _ui_ask_query( execution_result = await self.collection.ask( question=question_query, return_natural_response=natural_language_flag ) - if execution_result: - generated_query = str(execution_result.context) - data = self._load_results_into_dataframe(execution_result.results) - textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response - else: - generated_query = "No results generated" - data = pd.DataFrame() + generated_query = str(execution_result.context) + data = self._load_results_into_dataframe(execution_result.results) + textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response except UnsupportedQueryError: generated_query = {"Query": "unsupported"} From 4c81e39ce9288416e581784b79b965a0fca6d048 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 8 Jul 2024 14:30:04 +0200 Subject: [PATCH 59/64] moving events to print --- docs/concepts/collections.md | 2 +- src/dbally/audit/event_handlers/cli_event_handler.py | 1 - src/dbally/audit/events.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/concepts/collections.md b/docs/concepts/collections.md index bdca28d4..f64035bf 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -25,7 +25,7 @@ my_collection.ask("Find me Italian recipes for soups") In this scenario, the LLM first determines the most suitable view to address the query, and then that view is used to pull the relevant data. -Sometimes the selected view may not be able to answer the question and will raise an error. In such situations, the fallback collections can be used. +Sometimes, the selected view may not be able to answer the question and will raise an error. In such situations, the fallback collections can be used. This will cause a next view selection, but from the fallback collection. 
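The hunk above makes `ask()` re-raise the caught exception when no fallback collection is configured, so callers handle it directly. A minimal sketch of that path, assuming a `candidates` collection with views already registered as in the recruiting examples elsewhere in this series; the question string is illustrative, and `UnsupportedQueryError` is the same exception the new unit tests import from `dbally.iql_generator.prompt`:

```python
import asyncio

import dbally
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.litellm import LiteLLM


async def main() -> None:
    llm = LiteLLM(model_name="gpt-3.5-turbo")
    collection = dbally.create_collection("candidates", llm)
    # View registration omitted for brevity; see the recruiting examples in this series.

    try:
        result = await collection.ask("Find me candidates with Python experience")
        print(result.results)
    except UnsupportedQueryError:
        # With no fallback collection configured, ask() re-raises the handled exception.
        print("No registered view could answer this question.")


if __name__ == "__main__":
    asyncio.run(main())
```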
```python llm = LiteLLM(model_name="gpt-3.5-turbo") user_collection = dbally.create_collection("candidates", llm) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 57e30545..c375968d 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -102,7 +102,6 @@ async def event_start(self, event: Event, request_context: None) -> None: f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n" f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n" f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" - f"[orange bold]Override event handlers: [grey53]{event.override_global_event}\n" f"[orange bold]Error description: [grey53]{event.error_description}\n" "[grey53]=======================================\n" "[grey53]=======================================\n" diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index cb9a01e2..3bb23e17 100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -51,7 +51,6 @@ class FallbackEvent(Event): triggering_view_name: str fallback_collection_name: str error_description: str - override_global_event: False @dataclass From a088ca37116e31bd5e1b0b3f0721f4f9ee8b4735 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 8 Jul 2024 14:33:50 +0200 Subject: [PATCH 60/64] display improvementS --- .../audit/event_handlers/cli_event_handler.py | 2 +- src/dbally/collection/collection.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index c375968d..5c97a016 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -101,8 +101,8 @@ async def event_start(self, event: Event, request_context: None) -> None: f"[orange bold]Fallback event starts \n" f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n" f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n" - f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" f"[orange bold]Error description: [grey53]{event.error_description}\n" + f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" "[grey53]=======================================\n" "[grey53]=======================================\n" ) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 020ec02c..2917e453 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -1,5 +1,6 @@ import asyncio import inspect +import logging import textwrap import time from collections import defaultdict @@ -130,8 +131,13 @@ def set_fallback(self, fallback_collection: "Collection") -> "Collection": Returns: The fallback collection to create chains call """ - self._fallback_collection = fallback_collection + if fallback_collection._event_handlers != self._event_handlers: # pylint: disable=W0212 + logging.warning( + "Global event handlers are override by fallback. 
New event handlers are: %s", + fallback_collection._event_handlers, # pylint: disable=W0212 + ) + self._fallback_collection = fallback_collection return fallback_collection def __rshift__(self, fallback_collection: "Collection"): @@ -300,16 +306,11 @@ async def _handle_fallback( """ - override_global_event = ( - self._fallback_collection._event_handlers != self._event_handlers # pylint: disable=W0212 - ) - fallback_event = FallbackEvent( triggering_collection_name=self.name, triggering_view_name=selected_view_name, fallback_collection_name=self._fallback_collection.name, error_description=repr(caught_exception), - override_global_event=override_global_event, ) async with event_tracker.track_event(fallback_event) as span: From 2fc55fae77f3b2370bdf5b36ece631f1741aae50 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 9 Jul 2024 21:32:16 +0200 Subject: [PATCH 61/64] documentation update --- docs/concepts/collections.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/concepts/collections.md b/docs/concepts/collections.md index f64035bf..2959de57 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -25,7 +25,9 @@ my_collection.ask("Find me Italian recipes for soups") In this scenario, the LLM first determines the most suitable view to address the query, and then that view is used to pull the relevant data. -Sometimes, the selected view may not be able to answer the question and will raise an error. In such situations, the fallback collections can be used. This will cause a next view selection, but from the fallback collection. +Sometimes, the selected view does not match question (LLM select wrong view) and will raise an error. In such situations, the fallback collections can be used. +This will cause a next view selection, but from the fallback collection. 
+ ```python llm = LiteLLM(model_name="gpt-3.5-turbo") user_collection = dbally.create_collection("candidates", llm) From bc018d3bbfcf69f3d46f6230b612d00e2cdbde9e Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Mon, 15 Jul 2024 13:07:25 +0200 Subject: [PATCH 62/64] collection enhancment --- src/dbally/collection/collection.py | 1 + tests/unit/test_fallback_collection.py | 0 2 files changed, 1 insertion(+) create mode 100644 tests/unit/test_fallback_collection.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 2917e453..887fbac5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -319,6 +319,7 @@ async def _handle_fallback( dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, + event_tracker=event_tracker, ) span(fallback_event) return result diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py new file mode 100644 index 00000000..e69de29b From 67b51f3bf3d64557e45c79fa3be89a490567432d Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Wed, 17 Jul 2024 11:57:46 +0200 Subject: [PATCH 63/64] tests --- examples/visualize_fallback_code.py | 5 +- .../event_handlers/otel_event_handler.py | 7 +- src/dbally/collection/collection.py | 36 +++- tests/unit/test_fallback_collection.py | 172 ++++++++++++++++++ 4 files changed, 210 insertions(+), 10 deletions(-) diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index 37211e69..f70f1eed 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -9,6 +9,7 @@ from recruiting.views import RecruitmentView import dbally +from dbally.audit import CLIEventHandler, OtelEventHandler from dbally.gradio import create_gradio_interface from dbally.llms.litellm import LiteLLM @@ -19,10 +20,10 @@ async def main(): user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) - fallback_collection = dbally.create_collection("freeform candidates", llm) + fallback_collection = dbally.create_collection("freeform candidates", llm, event_handlers=[OtelEventHandler()]) fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) - second_fallback_collection = dbally.create_collection("recruitment", llm) + second_fallback_collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine)) user_collection.set_fallback(fallback_collection).set_fallback(second_fallback_collection) diff --git a/src/dbally/audit/event_handlers/otel_event_handler.py b/src/dbally/audit/event_handlers/otel_event_handler.py index 00a106a2..91a709f5 100644 --- a/src/dbally/audit/event_handlers/otel_event_handler.py +++ b/src/dbally/audit/event_handlers/otel_event_handler.py @@ -7,7 +7,7 @@ from opentelemetry.util.types import AttributeValue from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent TRACER_NAME = "db-ally.events" FORBIDDEN_CONTEXT_KEYS = {"filter_mask"} @@ -172,8 +172,11 @@ async def event_start(self, event: Event, request_context: 
SpanHandler) -> SpanH .set("db-ally.similarity.fetcher", event.fetcher) .set_input("db-ally.similarity.input", event.input_value) ) + if isinstance(event, FallbackEvent): + with self._new_child_span(request_context, "fallback") as span: + return self._handle_span(span).set("db-ally.error_description", event.error_description) - raise ValueError(f"Unsuported event: {type(event)}") + raise ValueError(f"Unsupported event: {type(event)}") async def event_end(self, event: Optional[Event], request_context: SpanHandler, event_context: SpanHandler) -> None: """ diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 887fbac5..2bd5d88e 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -131,13 +131,13 @@ def set_fallback(self, fallback_collection: "Collection") -> "Collection": Returns: The fallback collection to create chains call """ + self._fallback_collection = fallback_collection if fallback_collection._event_handlers != self._event_handlers: # pylint: disable=W0212 logging.warning( - "Global event handlers are override by fallback. New event handlers are: %s", + "Global event handlers are modified by fallback. New event handlers are: %s", fallback_collection._event_handlers, # pylint: disable=W0212 ) - self._fallback_collection = fallback_collection return fallback_collection def __rshift__(self, fallback_collection: "Collection"): @@ -279,6 +279,22 @@ async def _generate_textual_response( ) return textual_response + def get_all_event_handlers(self) -> List[EventHandler]: + """ + Retrieves all event handlers, including those from a fallback collection if available. + + This method returns a list of event handlers. If there is no fallback collection, + it simply returns the event handlers stored in the current object. If a fallback + collection is available, it combines the event handlers from both the current object + and the fallback collection, ensuring no duplicates. + + Returns: + A list of event handlers. 
+ """ + if not self._fallback_collection: + return self._event_handlers + return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers())) + async def _handle_fallback( self, question: str, @@ -363,10 +379,10 @@ async def ask( UnsupportedQueryError: if the question could not be answered IndexUpdateError: if index update failed """ - if not event_tracker: is_fallback_call = False - event_tracker = EventTracker.initialize_with_handlers(self._event_handlers) + event_handlers = self.get_all_event_handlers() + event_tracker = EventTracker.initialize_with_handlers(event_handlers) await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) else: is_fallback_call = True @@ -375,10 +391,18 @@ async def ask( try: start_time = time.monotonic() - selected_view_name = await self._select_view(question, event_tracker, llm_options) + selected_view_name = await self._select_view( + question=question, event_tracker=event_tracker, llm_options=llm_options + ) start_time_view = time.monotonic() - view_result = await self._ask_view(selected_view_name, question, event_tracker, llm_options, dry_run) + view_result = await self._ask_view( + selected_view_name=selected_view_name, + question=question, + event_tracker=event_tracker, + llm_options=llm_options, + dry_run=dry_run, + ) end_time_view = time.monotonic() natural_response = ( diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py index e69de29b..137581b6 100644 --- a/tests/unit/test_fallback_collection.py +++ b/tests/unit/test_fallback_collection.py @@ -0,0 +1,172 @@ +from typing import List, Optional +from unittest.mock import AsyncMock, Mock + +import pytest +from sqlalchemy import create_engine + +import dbally +from dbally.audit import CLIEventHandler, EventTracker, OtelEventHandler +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.collection import Collection, ViewExecutionResult +from dbally.iql_generator.prompt import UnsupportedQueryError +from dbally.llms import LLM +from dbally.llms.clients import LLMOptions +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockViewBase, MockViewSelector + +engine = create_engine("sqlite://", echo=True) + + +class MyText2SqlView(BaseText2SQLView): + """ + A Text2SQL view for the example. 
+ """ + + def get_tables(self) -> List[TableConfig]: + return [ + TableConfig( + name="mock_table", + columns=[ + ColumnConfig("mock_field1", "SERIAL PRIMARY KEY"), + ColumnConfig("mock_field2", "VARCHAR(255)"), + ], + ), + ] + + async def ask( + self, + query: str, + llm: LLM, + event_tracker: EventTracker, + n_retries: int = 3, + dry_run: bool = False, + llm_options: Optional[LLMOptions] = None, + ) -> ViewExecutionResult: + return ViewExecutionResult( + results=[{"mock_result": "fallback_result"}], context={"mock_context": "fallback_context"} + ) + + +class MockView1(MockViewBase): + """ + Mock view 1 + """ + + def execute(self, dry_run=False) -> ViewExecutionResult: + return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + + def get_iql_generator(self, *_, **__) -> MockIQLGenerator: + raise UnsupportedQueryError + + +class MockView2(MockViewBase): + """ + Mock view 2 + """ + + +@pytest.fixture(name="base_collection") +def mock_base_collection() -> Collection: + """ + Returns a collection with two mock views + """ + collection = dbally.create_collection( + "foo", + llm=MockLLM(), + view_selector=MockViewSelector("MockView1"), + nl_responder=AsyncMock(), + ) + collection.add(MockView1) + collection.add(MockView2) + return collection + + +@pytest.fixture(name="fallback_collection") +def mock_fallback_collection() -> Collection: + """ + Returns a collection with two mock views + """ + collection = dbally.create_collection( + "fallback_foo", + llm=MockLLM(), + view_selector=MockViewSelector("MyText2SqlView"), + nl_responder=AsyncMock(), + ) + collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) + return collection + + +async def test_no_fallback_collection(base_collection: Collection, fallback_collection: Collection): + with pytest.raises(UnsupportedQueryError) as exc_info: + result = await base_collection.ask("Mock fallback question") + print(result) + print(exc_info) + + +async def test_fallback_collection(base_collection: Collection, fallback_collection: Collection): + base_collection.set_fallback(fallback_collection) + result = await base_collection.ask("Mock fallback question") + assert result.results == [{"mock_result": "fallback_result"}] + assert result.context == {"mock_context": "fallback_context"} + + +def test_get_all_event_handlers_no_fallback(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + + collection = Collection( + name="test_collection", + llm=MockLLM(), + nl_responder=AsyncMock(), + view_selector=Mock(), + event_handlers=[handler1, handler2], + ) + + result = collection.get_all_event_handlers() + + assert result == [handler1, handler2] + + +def test_get_all_event_handlers_with_fallback(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + handler3 = OtelEventHandler() + + fallback_collection = Collection( + name="fallback_collection", view_selector=Mock(), llm=Mock(), nl_responder=Mock(), event_handlers=[handler3] + ) + + collection = Collection( + name="test_collection", + view_selector=Mock(), + llm=MockLLM(), + nl_responder=AsyncMock(), + event_handlers=[handler1, handler2], + fallback_collection=fallback_collection, + ) + + result = collection.get_all_event_handlers() + + assert set(result) == {handler1, handler2, handler3} + + +def test_get_all_event_handlers_with_duplicates(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + + fallback_collection = Collection( + name="fallback_collection", view_selector=Mock(), llm=Mock(), nl_responder=Mock(), event_handlers=[handler2] + ) + + 
collection = Collection( + name="test_collection", + view_selector=Mock(), + llm=Mock(), + nl_responder=Mock(), + event_handlers=[handler1, handler2], + fallback_collection=fallback_collection, + ) + + result = collection.get_all_event_handlers() + + assert set(result) == {handler1, handler2} From bdd3c28c8c007fd58949f603cdcfeff82e71a0bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:09:31 +0200 Subject: [PATCH 64/64] review: fallback collections (#75) --- src/dbally/collection/collection.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 2bd5d88e..542f78e4 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -134,8 +134,8 @@ def set_fallback(self, fallback_collection: "Collection") -> "Collection": self._fallback_collection = fallback_collection if fallback_collection._event_handlers != self._event_handlers: # pylint: disable=W0212 logging.warning( - "Global event handlers are modified by fallback. New event handlers are: %s", - fallback_collection._event_handlers, # pylint: disable=W0212 + "Event handlers of the fallback collection are different from the base collection. " + "Continuity of the audit trail is not guaranteed.", ) return fallback_collection @@ -303,7 +303,7 @@ async def _handle_fallback( llm_options: Optional[LLMOptions], selected_view_name: str, event_tracker: EventTracker, - caught_exception: HANDLED_EXCEPTION_TYPES, + caught_exception: Exception, ) -> ExecutionResult: """ Handle fallback if the main query fails. @@ -321,6 +321,8 @@ async def _handle_fallback( The result from the fallback collection. """ + if not self._fallback_collection: + raise caught_exception fallback_event = FallbackEvent( triggering_collection_name=self.name,
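To close the series, a compact sketch of how event handlers behave across a fallback chain after these changes: `ask()` builds its `EventTracker` from `get_all_event_handlers()`, which takes the union of the base and fallback handlers, and that is what the audit-trail warning added to `set_fallback` refers to. The `Mock`/`AsyncMock` stand-ins and handler choices below mirror the new unit tests and are assumptions for illustration, not a recommended production setup:

```python
from unittest.mock import AsyncMock, Mock

from dbally.audit import CLIEventHandler, OtelEventHandler
from dbally.collection import Collection

cli_handler = CLIEventHandler()
otel_handler = OtelEventHandler()

fallback = Collection(
    name="fallback_candidates",
    view_selector=Mock(),
    llm=Mock(),
    nl_responder=Mock(),
    event_handlers=[otel_handler],
)

base = Collection(
    name="candidates",
    view_selector=Mock(),
    llm=Mock(),
    nl_responder=AsyncMock(),
    event_handlers=[cli_handler],
    fallback_collection=fallback,
)

# ask() initializes its EventTracker from this merged, de-duplicated list, so the
# fallback's handlers also observe events for requests that start on the base collection.
assert set(base.get_all_event_handlers()) == {cli_handler, otel_handler}
```

Using a set union means a handler registered on both collections is reported only once.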