From 4041684c27d5532f245eb0ce699413cfedc8bfca Mon Sep 17 00:00:00 2001 From: Piotr Mardziel Date: Wed, 18 Oct 2023 14:32:46 -0700 Subject: [PATCH] threading robustness and feedback retrieval (#480) * split off work on threading issues * work on dummy example * prototyping various thread robustness solutions * work * working on threading and feedback results * more ignores * nevermind that last gitignore addition * remove unneeded * added feedback result retrieval into langchain quickstart * don't use submit inside feedback functions * what the last thing said --- .vscode/settings.json | 3 +- .../examples/experimental/dummy_example.ipynb | 164 ++++++++++++++-- .../end2end_apps/custom_app/custom_app.py | 14 ++ .../trubot/trubot_populate_db.ipynb | 9 +- .../quickstart/langchain_quickstart.ipynb | 45 ++++- trulens_eval/trulens_eval/app.py | 86 ++++++--- .../trulens_eval/database/sqlalchemy_db.py | 26 +++ .../trulens_eval/feedback/feedback.py | 46 +++-- .../feedback/provider/endpoint/base.py | 106 ++++++++--- .../trulens_eval/feedback/provider/hugs.py | 61 ++++-- trulens_eval/trulens_eval/schema.py | 23 ++- trulens_eval/trulens_eval/tru.py | 113 ++++++----- trulens_eval/trulens_eval/utils/langchain.py | 6 +- trulens_eval/trulens_eval/utils/llama.py | 13 +- trulens_eval/trulens_eval/utils/python.py | 35 +++- trulens_eval/trulens_eval/utils/threading.py | 179 ++++++++++-------- 16 files changed, 676 insertions(+), 253 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0aab10655..e1d685831 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,4 @@ { - "python.formatting.provider": "yapf" + "python.formatting.provider": "yapf", + "python.analysis.typeCheckingMode": "basic" } diff --git a/trulens_eval/examples/experimental/dummy_example.ipynb b/trulens_eval/examples/experimental/dummy_example.ipynb index 13e4a5dc4..c3d83fac6 100644 --- a/trulens_eval/examples/experimental/dummy_example.ipynb +++ b/trulens_eval/examples/experimental/dummy_example.ipynb @@ -4,13 +4,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Dummy Example\n", + "# Dummy Provider Example and High Volume Robustness Testing\n", "\n", - "This notebook shows the use of the dummy feedback function provider which\n", - "behaves like the huggingface provider except it does not actually perform any\n", - "network calls and just produces constant results. It can be used to prototype\n", - "feedback function wiring for your apps before invoking potentially slow (to\n", - "run/to load) feedback functions." + "This notebook has two purposes: \n", + "\n", + "- Demostrate the dummy feedback function provider which behaves like the\n", + " huggingface provider except it does not actually perform any network calls and\n", + " just produces constant results. It can be used to prototype feedback function\n", + " wiring for your apps before invoking potentially slow (to run/to load)\n", + " feedback functions.\n", + "\n", + "- Test out high-volume record and feedback computation. To this end, we use the\n", + " custom app which is dummy in a sense that it produces useless answers without\n", + " making any API calls but otherwise behaves similarly to real apps, and the\n", + " dummy feedback function provider." 
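Condensing the notebook cells that follow, a minimal sketch of the prototyping idea described above: swap the network-backed `Huggingface` provider for `Dummy` when wiring feedback functions. The constructor parameters are the ones this patch adds to `Dummy`; the values are illustrative only.

```python
from trulens_eval import Feedback
from trulens_eval.feedback.provider.hugs import Dummy

# Behaves like the Huggingface provider but fabricates results locally,
# optionally simulating loading delays, errors, freezes and overload responses.
provider = Dummy(
    loading_prob=0.1,     # chance of a "model loading" style response
    error_prob=0.01,      # chance of a simulated API error
    freeze_prob=0.01,     # chance of a simulated hang
    overloaded_prob=0.1,  # chance of a simulated "overloaded" response
    rpm=6000,             # simulated requests-per-minute pacing
)

# Feedback functions are declared exactly as with a real provider, so the
# wiring can later be pointed at Huggingface() without other changes.
f_lang_match = Feedback(provider.language_match).on_input_output()
f_sentiment = Feedback(provider.positive_sentiment).on_output()
```

Because `Dummy` subclasses `Huggingface`, the same feedback definitions work unchanged once a real provider is substituted.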
] }, { @@ -34,12 +41,28 @@ "metadata": {}, "outputs": [], "source": [ + "from concurrent.futures import as_completed\n", + "from time import sleep\n", + "\n", "from examples.expositional.end2end_apps.custom_app.custom_app import CustomApp\n", + "from tqdm.auto import tqdm\n", "\n", "from trulens_eval import Feedback\n", "from trulens_eval import Tru\n", "from trulens_eval.feedback.provider.hugs import Dummy\n", + "from trulens_eval.schema import FeedbackMode\n", "from trulens_eval.tru_custom_app import TruCustomApp\n", + "from trulens_eval.utils.threading import TP\n", + "\n", + "tp = TP()\n", + "\n", + "d = Dummy(\n", + " loading_prob=0.1,\n", + " freeze_prob=0.01,\n", + " error_prob=0.01,\n", + " overloaded_prob=0.1,\n", + " rpm=6000\n", + ")\n", "\n", "tru = Tru()\n", "\n", @@ -57,24 +80,135 @@ "metadata": {}, "outputs": [], "source": [ + "f_dummy1 = Feedback(\n", + " d.language_match\n", + ").on_input_output()\n", + "\n", + "f_dummy2 = Feedback(\n", + " d.positive_sentiment\n", + ").on_output()\n", + "\n", + "f_dummy3 = Feedback(\n", + " d.positive_sentiment\n", + ").on_input()\n", + "\n", + "\n", "# Create custom app:\n", "ca = CustomApp()\n", "\n", "# Create trulens wrapper:\n", "ta = TruCustomApp(\n", " ca,\n", - " app_id=\"customapp\"\n", - ")\n", + " app_id=\"customapp\",\n", + " main_method = ca.respond_to_query,\n", + " feedbacks=[f_dummy1, f_dummy2, f_dummy3],\n", + " feedback_mode=FeedbackMode.WITH_APP_THREAD\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with ta as recorder:\n", + " res = ca.respond_to_query(f\"hello there\")\n", + "\n", + "rec = recorder.get()\n", + "print(rec.feedback_results)\n", + "for res in as_completed(rec.feedback_results):\n", + " print(res.result())\n", + " \n", + "print(rec.feedback_results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sequential app invocation.\n", + "\n", + "if True:\n", + " for i in tqdm(range(100), desc=\"invoking app\"):\n", + " with ta as recorder:\n", + " res = ca.respond_to_query(f\"hello {i}\")\n", + "\n", + " rec = recorder.get()\n", + " assert rec is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parallel feedback evaluation.\n", + "\n", + "futures = []\n", + "num_tests = 1000\n", + "good = 0\n", + "bad = 0\n", + "\n", + "def test_feedback(msg):\n", + " return msg, d.positive_sentiment(msg)\n", + "\n", + "for i in tqdm(range(num_tests), desc=\"starting feedback task\"):\n", + " futures.append(tp.submit(test_feedback, msg=f\"good\"))\n", + "\n", + "prog = tqdm(as_completed(futures), total=num_tests)\n", + "\n", + "for f in prog:\n", + " try:\n", + " res = f.result()\n", + " good += 1\n", + "\n", + " assert res[0] == \"good\"\n", + "\n", + " prog.set_description_str(f\"{good} / {bad}\")\n", + " except Exception as e:\n", + " bad += 1\n", + " prog.set_description_str(f\"{good} / {bad}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ feedback result positive_sentiment DONE)\n", + "✅ feedback result positive_sentiment DONE)\n", + "✅ feedback result positive_sentiment DONE)\n", + "✅ feedback result language_match DONE)\n" + ] + } + ], + "source": [ + "# Parallel app invocation.\n", + "\n", + "def run_query(q):\n", + "\n", + " with ta as recorder:\n", + " res = 
ca.respond_to_query(q)\n", "\n", - "# Must be set after above constructor as the instrumentation changes what\n", - "# CustomApp.respond_to_query points to.\n", - "ta.main_method = CustomApp.respond_to_query\n", - "ta.main_async_method = CustomApp.arespond_to_query\n", + " rec = recorder.get()\n", + " assert rec is not None\n", "\n", - "from trulens_eval.appui import AppUI\n", + " return f\"run_query {q} result\"\n", "\n", - "aui = AppUI(app=ta)\n", - "aui.d" + "for i in tqdm(range(100), desc=\"starting app task\"):\n", + " print(\n", + " tp.completed_tasks, \n", + " end=\"\\r\"\n", + " )\n", + " tp.submit(run_query, q=f\"hello {i}\")" ] }, { diff --git a/trulens_eval/examples/expositional/end2end_apps/custom_app/custom_app.py b/trulens_eval/examples/expositional/end2end_apps/custom_app/custom_app.py index 1b8fcd59e..030db74c5 100644 --- a/trulens_eval/examples/expositional/end2end_apps/custom_app/custom_app.py +++ b/trulens_eval/examples/expositional/end2end_apps/custom_app/custom_app.py @@ -1,5 +1,6 @@ import asyncio from asyncio import sleep +from concurrent.futures import wait from examples.expositional.end2end_apps.custom_app.custom_llm import CustomLLM from examples.expositional.end2end_apps.custom_app.custom_memory import \ @@ -8,6 +9,7 @@ CustomRetriever from trulens_eval.tru_custom_app import instrument +from trulens_eval.utils.threading import ThreadPoolExecutor instrument.method(CustomRetriever, "retrieve_chunks") instrument.method(CustomMemory, "remember") @@ -42,6 +44,18 @@ def retrieve_chunks(self, data): @instrument def respond_to_query(self, input): chunks = self.retrieve_chunks(input) + + # Creates a few threads to process chunks in parallel to test apps that + # make use of threads. + ex = ThreadPoolExecutor(max_workers=max(1, len(chunks))) + + futures = list( + ex.submit(lambda chunk: chunk + " processed", chunk=chunk) for chunk in chunks + ) + + wait(futures) + chunks = list(future.result() for future in futures) + self.memory.remember(input) answer = self.llm.generate(",".join(chunks)) diff --git a/trulens_eval/examples/expositional/end2end_apps/trubot/trubot_populate_db.ipynb b/trulens_eval/examples/expositional/end2end_apps/trubot/trubot_populate_db.ipynb index 46ed66bb8..0a8224402 100644 --- a/trulens_eval/examples/expositional/end2end_apps/trubot/trubot_populate_db.ipynb +++ b/trulens_eval/examples/expositional/end2end_apps/trubot/trubot_populate_db.ipynb @@ -106,7 +106,7 @@ "\n", "def test_bot(selector, question):\n", " print(selector, question)\n", - " app = get_or_make_app(cid=question + str(selector), selector=selector)#, feedback_mode=FeedbackMode.WITH_APP)\n", + " app = get_or_make_app(cid=question + str(selector), selector=selector, feedback_mode=FeedbackMode.DEFERRED)\n", " answer = get_answer(app=app, question=question)\n", " return answer\n", "\n", @@ -114,9 +114,7 @@ "\n", "for s in selectors:\n", " for m in messages:\n", - " # results.append(TP().promise(test_bot, selector=s, question=m))\n", - " # TP().finish()\n", - " test_bot(selector=s, question=m)\n" + " results.append(TP().submit(test_bot, selector=s, question=m))\n" ] }, { @@ -125,8 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "thread = Tru().start_evaluator(restart=True)\n", - "# TP().finish()" + "thread = Tru().start_evaluator(restart=True)" ] }, { diff --git a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb index 0594327b8..a31782f52 100644 --- a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb +++ 
b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb @@ -60,6 +60,7 @@ "\n", "# Imports main tools:\n", "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", + "from trulens_eval.schema import FeedbackResult\n", "tru = Tru()\n", "\n", "# Imports from langchain to build app. You may need to install langchain first\n", @@ -184,6 +185,48 @@ "display(llm_response)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Retrieve records and feedback" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The record of the ap invocation can be retrieved from the `recording`:\n", + "\n", + "rec = recording.get() # use .get if only one record\n", + "# recs = recording.records # use .records if multiple\n", + "\n", + "display(rec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The results of the feedback functions can be rertireved from the record. These\n", + "# are `Future` instances (see `concurrent.futures`). You can use `as_completed`\n", + "# to wait until they have finished evaluating.\n", + "\n", + "from concurrent.futures import as_completed\n", + "\n", + "for feedback_future in as_completed(rec.feedback_results):\n", + " feedback, feedback_result = feedback_future.result()\n", + " \n", + " feedback: Feedback\n", + " feedbac_result: FeedbackResult\n", + "\n", + " display(feedback.name, feedback_result.result)\n" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -285,7 +328,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.8.16" }, "vscode": { "interpreter": { diff --git a/trulens_eval/trulens_eval/app.py b/trulens_eval/trulens_eval/app.py index a0768220e..d69e95334 100644 --- a/trulens_eval/trulens_eval/app.py +++ b/trulens_eval/trulens_eval/app.py @@ -4,6 +4,8 @@ from abc import ABC from abc import abstractmethod +from concurrent import futures +from concurrent.futures import as_completed import contextvars import inspect from inspect import BoundArguments @@ -709,9 +711,16 @@ def _on_new_record(self, func) -> Iterable[RecordingContext]: # WithInstrumentCallbacks requirement def _on_add_record( - self, ctx: RecordingContext, func: Callable, sig: Signature, - bindings: BoundArguments, ret: Any, error: Any, perf: Perf, cost: Cost - ): + self, + ctx: RecordingContext, + func: Callable, + sig: Signature, + bindings: BoundArguments, + ret: Any, + error: Any, + perf: Perf, + cost: Cost + ) -> Record: """ Called by instrumented methods if they use _new_record to construct a record call list. @@ -735,25 +744,27 @@ def build_record(calls, record_metadata): meta=jsonify(record_metadata) ) + # tp = TP() + # Finishing record needs to be done in a thread lock, done there: record = ctx.finish_record(build_record) if error is not None: - if self.feedback_mode == FeedbackMode.WITH_APP: - self._handle_error(record=record, error=error) - - elif self.feedback_mode in [FeedbackMode.DEFERRED, - FeedbackMode.WITH_APP_THREAD]: - TP().runlater(self._handle_error, record=record, error=error) - + # May block on DB. 
+ self._handle_error(record=record, error=error) raise error - if self.feedback_mode == FeedbackMode.WITH_APP: - self._handle_record(record=record) + # Will block on DB, but not on feedback evaluation, depending on + # FeedbackMode: + record.feedback_results = self._handle_record(record=record) - elif self.feedback_mode in [FeedbackMode.DEFERRED, - FeedbackMode.WITH_APP_THREAD]: - TP().runlater(self._handle_record, record=record) + if record.feedback_results is None: + return record + + # If in blocking mode ("WITH_APP"), wait for feedbacks to finished + # evaluating before returning the record. + if self.feedback_mode in [FeedbackMode.WITH_APP]: + futures.wait(record.feedback_results) return record @@ -835,11 +846,13 @@ def with_(self, func, *args, **kwargs) -> Any: res, _ = self.with_record(func, *args, **kwargs) return res - def with_record(self, - func, - *args, - record_metadata: JSON = None, - **kwargs) -> Tuple[Any, Record]: + def with_record( + self, + func, + *args, + record_metadata: JSON = None, + **kwargs + ) -> Tuple[Any, Record]: """ Call the given `func` with the given `*args` and `**kwargs`, producing its results as well as a record of the execution. @@ -897,18 +910,30 @@ def _with_dep_message(self, method, is_async=False, with_record=False): """ ) - def _handle_record(self, record: Record): + def _add_future_feedback(self, future: 'Future[Feedback, FeedbackResult]'): + _, res = future.result() + self.tru.add_feedback(res) + + def _handle_record( + self, + record: Record + ) -> Optional[List['Future[Tuple[Feedback, FeedbackResult]]']]: """ - Write out record-related info to database if set. + Write out record-related info to database if set and schedule feedback + functions to be evaluated. """ if self.tru is None or self.feedback_mode is None: - return - + return None + + self.tru: Tru + self.db: DB + + # Need to add record to db before evaluating feedback functions. record_id = self.tru.add_record(record=record) if len(self.feedbacks) == 0: - return + return [] # Add empty (to run) feedback to db. 
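`_handle_record` now signals how feedback evaluation was scheduled through its return value: `None` when evaluation is deferred (or no `Tru`/database is configured), an empty list when the record has no feedback functions, and otherwise a list of futures resolving to `(Feedback, FeedbackResult)` pairs. A hedged sketch of how a caller might consume that convention; the helper name is hypothetical.

```python
from concurrent.futures import Future, wait
from typing import List, Optional

# Hypothetical helper illustrating the return-value convention of
# App._handle_record in this patch:
#   None            -> feedback deferred to the evaluator (or no Tru/db set)
#   []              -> record stored, but no feedback functions attached
#   list of futures -> feedback running now; each resolves to a
#                      (Feedback, FeedbackResult) pair
def wait_for_feedback(results: Optional[List[Future]], timeout: float = 30.0) -> None:
    if not results:  # covers both None (deferred) and [] (nothing to run)
        return
    done, not_done = wait(results, timeout=timeout)
    if not_done:
        print(f"{len(not_done)} feedback function(s) still running after {timeout}s")
```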
if self.feedback_mode == FeedbackMode.DEFERRED: @@ -921,15 +946,18 @@ def _handle_record(self, record: Record): ) ) + return None + elif self.feedback_mode in [FeedbackMode.WITH_APP, FeedbackMode.WITH_APP_THREAD]: - results = self.tru.run_feedback_functions( - record=record, feedback_functions=self.feedbacks, app=self + return self.tru._submit_feedback_functions( + record=record, + feedback_functions=self.feedbacks, + app=self, + on_done=self._add_future_feedback ) - for result in results: - self.tru.add_feedback(result) def _handle_error(self, record: Record, error: Exception): if self.db is None: diff --git a/trulens_eval/trulens_eval/database/sqlalchemy_db.py b/trulens_eval/trulens_eval/database/sqlalchemy_db.py index 7adac50d3..6f806184f 100644 --- a/trulens_eval/trulens_eval/database/sqlalchemy_db.py +++ b/trulens_eval/trulens_eval/database/sqlalchemy_db.py @@ -36,6 +36,7 @@ from trulens_eval.schema import FeedbackResultStatus from trulens_eval.schema import RecordID from trulens_eval.utils.serial import JSON +from trulens_eval.utils.text import UNICODE_CHECK, UNICODE_CLOCK, UNICODE_HOURGLASS, UNICODE_STOP logger = logging.getLogger(__name__) @@ -145,6 +146,9 @@ def insert_record(self, record: schema.Record) -> schema.RecordID: session.merge(_rec) # update existing else: session.merge(_rec) # add new record # .add was not thread safe + + print(f"{UNICODE_CHECK} added record {_rec.record_id}") + return _rec.record_id def get_app(self, app_id: str) -> Optional[JSON]: @@ -171,6 +175,8 @@ def insert_app(self, app: schema.AppDefinition) -> schema.AppID: ) session.merge(_app) # .add was not thread safe + print(f"{UNICODE_CHECK} added app {_app.app_id}") + return _app.app_id def insert_feedback_definition( @@ -189,6 +195,8 @@ def insert_feedback_definition( ) session.merge(_fb_def) # .add was not thread safe + print(f"{UNICODE_CHECK} added feedback definition {_fb_def.feedback_definition_id}") + return _fb_def.feedback_definition_id def get_feedback_defs( @@ -223,6 +231,24 @@ def insert_feedback( session.merge( _feedback_result ) # insert new result # .add was not thread safe + + status = FeedbackResultStatus(_feedback_result.status) + + if status == FeedbackResultStatus.DONE: + icon = UNICODE_CHECK + elif status == FeedbackResultStatus.RUNNING: + icon = UNICODE_HOURGLASS + elif status == FeedbackResultStatus.NONE: + icon = UNICODE_CLOCK + elif status == FeedbackResultStatus.FAILED: + icon = UNICODE_STOP + else: + icon = "???" + + print( + f"{icon} feedback result {_feedback_result.name} {status.name} {_feedback_result.feedback_result_id}" + ) + return _feedback_result.feedback_result_id def get_feedback( diff --git a/trulens_eval/trulens_eval/feedback/feedback.py b/trulens_eval/trulens_eval/feedback/feedback.py index 88640c132..50ec2860c 100644 --- a/trulens_eval/trulens_eval/feedback/feedback.py +++ b/trulens_eval/trulens_eval/feedback/feedback.py @@ -6,7 +6,8 @@ import logging import pprint import traceback -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +import warnings import numpy as np import pydantic @@ -213,7 +214,7 @@ def _default_selectors(self): self.selectors = selectors @staticmethod - def evaluate_deferred(tru: 'Tru') -> int: + def evaluate_deferred(tru: 'Tru') -> List['Future[Tuple[Feedback, FeedbackResult]]']: """ Evaluates feedback functions that were specified to be deferred. Returns an integer indicating how many evaluates were run. 
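The deferred pipeline around `evaluate_deferred` now deals in futures rather than a started-task count; a minimal sketch of driving it, assuming records were produced with `FeedbackMode.DEFERRED` (the waiting pattern mirrors the evaluator-loop changes to `tru.py` later in this patch):

```python
from concurrent.futures import wait

from trulens_eval import Tru
from trulens_eval.feedback import Feedback

tru = Tru()

# Option 1: drive one evaluation pass directly. With this patch,
# evaluate_deferred returns a list of futures, each yielding a
# (Feedback, FeedbackResult) pair, instead of a count of started tasks.
futures = Feedback.evaluate_deferred(tru=tru)
wait(futures)

# Option 2: let the background evaluator thread poll and wait instead.
tru.start_evaluator(restart=True)
```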
@@ -228,7 +229,7 @@ def prepare_feedback(row): app_json = row.app_json feedback = Feedback(**row.feedback_json) - feedback.run_and_log( + return feedback, feedback.run_and_log( record=record, app=app_json, tru=tru, @@ -237,7 +238,9 @@ def prepare_feedback(row): feedbacks = db.get_feedback() - started_count = 0 + tp = TP() + + futures: List['Future[Tuple[Feedback, FeedbackResult]]'] = [] for i, row in feedbacks.iterrows(): feedback_ident = f"{row.fname} for app {row.app_json['app_id']}, record {row.record_id}" @@ -248,8 +251,7 @@ def prepare_feedback(row): f"{UNICODE_YIELD} Feedback task starting: {feedback_ident}" ) - TP().runlater(prepare_feedback, row) - started_count += 1 + futures.append(tp.submit(prepare_feedback, row)) elif row.status in [FeedbackResultStatus.RUNNING]: now = datetime.now().timestamp() @@ -258,8 +260,7 @@ def prepare_feedback(row): f"{UNICODE_YIELD} Feedback task last made progress over 30 seconds ago. " f"Retrying: {feedback_ident}" ) - TP().runlater(prepare_feedback, row) - started_count += 1 + futures.append(tp.submit(prepare_feedback, row)) else: print( @@ -274,8 +275,7 @@ def prepare_feedback(row): f"{UNICODE_YIELD} Feedback task last made progress over 5 minutes ago. " f"Retrying: {feedback_ident}" ) - TP().runlater(prepare_feedback, row) - started_count += 1 + futures.append(tp.submit(prepare_feedback, row)) else: print( @@ -286,7 +286,7 @@ def prepare_feedback(row): elif row.status == FeedbackResultStatus.DONE: pass - return started_count + return futures def __call__(self, *args, **kwargs) -> Any: assert self.imp is not None, "Feedback definition needs an implementation to call." @@ -411,6 +411,19 @@ def on(self, *args, **kwargs): name=self.supplied_name ) + def run_robust( + self, + app: Union[AppDefinition, JSON], + record: Record, + timeout: float = 30, + retries: int = 3 + ): + """ + Same as `run` but will try multiple times upon non-user errors. + """ + + + def run( self, app: Union[AppDefinition, JSON], record: Record ) -> FeedbackResult: @@ -461,10 +474,9 @@ def run( ) cost += part_cost except Exception as e: - print( + raise RuntimeError( f"Evaluation of {self.name} failed on inputs: \n{pp.pformat(ins)[0:128]}\n{e}." ) - continue if isinstance(result_and_meta, Tuple): # If output is a tuple of two, we assume it is the float/multifloat and the metadata. @@ -506,8 +518,10 @@ def run( feedback_calls.append(feedback_call) if len(result_vals) == 0: - logger.warning( - f"Feedback function {self.supplied_name if self.supplied_name is not None else self.name} with aggregation {self.agg} had no inputs." 
+ warnings.warn( + f"Feedback function {self.supplied_name if self.supplied_name is not None else self.name} with aggregation {self.agg} had no inputs.", + UserWarning, + stacklevel=1 ) result = np.nan @@ -559,7 +573,7 @@ def run_and_log( tru: 'Tru', app: Union[AppDefinition, JSON] = None, feedback_result_id: Optional[FeedbackResultID] = None - ) -> FeedbackResult: + ) -> Optional[FeedbackResult]: record_id = record.record_id app_id = record.app_id diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py index 7a3ee6e44..f0faf9e45 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py @@ -8,15 +8,17 @@ from types import AsyncGeneratorType from types import ModuleType from typing import ( - Any, Awaitable, Dict, Optional, Sequence, Tuple, Type, TypeVar + Any, Awaitable, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar ) +import warnings import pydantic import requests from trulens_eval.keys import ApiKeyError from trulens_eval.schema import Cost -from trulens_eval.utils.python import get_first_local_in_call_stack +from trulens_eval.utils.threading import DEFAULT_NETWORK_TIMEOUT +from trulens_eval.utils.python import get_first_local_in_call_stack, locals_except from trulens_eval.utils.python import SingletonPerName from trulens_eval.utils.python import Thunk from trulens_eval.utils.serial import JSON @@ -144,7 +146,7 @@ def pace_me(self): return def post( - self, url: str, payload: JSON, timeout: Optional[int] = None + self, url: str, payload: JSON, timeout: float = DEFAULT_NETWORK_TIMEOUT ) -> Any: self.pace_me() ret = requests.post( @@ -826,27 +828,64 @@ class DummyEndpoint(Endpoint): Endpoint for testing purposes. Should not make any network calls. """ - # Pretend the model we are querying is loading as is in huggingface. - is_loading: bool = True + # Simulated result parameters below. - def __new__(cls, *args, **kwargs): - return super(Endpoint, cls).__new__(cls, name="dummyendpoint") + # How often to produce the "model loading" response. + loading_prob: float + # How much time to indicate as needed to load the model in the above response. + loading_time: Callable[[], float] = \ + pydantic.Field(exclude=True, default_factory=lambda: lambda: random.uniform(0.73, 3.7)) - def __init__(self, name: str = "dummyendpoint", **kwargs): + # How often to produce an error response. + error_prob: float + + # How often to produce freeze instead of producing a response. + freeze_prob: float + + # How often to produce the overloaded message. 
+ overloaded_prob: float + + def __new__( + cls, + *args, + **kwargs + ): + + return super(Endpoint, cls).__new__( + cls, + name="dummyendpoint" + ) + + def __init__( + self, + name: str = "dummyendpoint", + error_prob: float = 1/100, + freeze_prob: float = 1/100, + overloaded_prob: float = 1/100, + loading_prob: float = 1/100, + rpm: float = DEFAULT_RPM * 10, + **kwargs + ): if hasattr(self, "callback_class"): # Already created with SingletonPerName mechanism return + assert error_prob + freeze_prob + overloaded_prob + loading_prob <= 1.0 + assert rpm > 0 + kwargs['name'] = name kwargs['callback_class'] = EndpointCallback - kwargs['rpm'] = DEFAULT_RPM * 10 - super().__init__(**kwargs) + super().__init__(**kwargs, **locals_except("self", "name", "kwargs", "__class__")) + + print(f"Using DummyEndpoint with {locals_except('self', 'name', 'kwargs', '__class__')}") + # TODO: make a robust version of POST or use tenacity def post( - self, url: str, payload: JSON, timeout: Optional[int] = None + self, url: str, payload: JSON, timeout: float = DEFAULT_NETWORK_TIMEOUT ) -> Any: - # classification results only, like from huggingface + # Classification results only, like from huggingface. Simulates + # overloaded, model loading, frozen, error. self.pace_me() @@ -857,16 +896,38 @@ def post( ) """ - if self.is_loading: - # "model loading message" - j = dict(estimated_time=1.2345) - self.is_loading = False - elif random.randint(a=0, b=50) == 0: - # randomly overloaded + r = random.random() + j: Optional[JSON] = None + + if r < self.freeze_prob: + # Simulated freeze outcome. + + while True: + sleep(timeout) + raise TimeoutError() + + r -= self.freeze_prob + + if r < self.error_prob: + # Simulated error outcome. + + raise RuntimeError("Simulated error happened.") + r -= self.error_prob + + if r < self.loading_prob: + # Simulated loading model outcome. + + j = dict(estimated_time=self.loading_time()) + r -= self.loading_prob + + if r < self.overloaded_prob: + # Simulated overloaded outcome. + j = dict(error="overloaded") + r -= self.overloaded_prob - else: - # otherwise a constant success + if j is None: + # Otherwise a simulated success outcome with some constant results. 
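`DummyEndpoint.post` picks one simulated outcome per call by drawing a single uniform sample and subtracting each configured probability in turn; a standalone, simplified sketch of that draw-and-subtract pattern (illustrative, not part of the endpoint itself):

```python
import random
from typing import Dict

def pick_outcome(probs: Dict[str, float]) -> str:
    """Illustrative version of the single-draw outcome selection used above."""
    assert sum(probs.values()) <= 1.0
    r = random.random()
    for outcome, p in probs.items():
        if r < p:
            return outcome
        r -= p  # shift the sample past this outcome's probability mass
    return "success"  # remaining probability mass

# e.g. pick_outcome(dict(freeze=0.01, error=0.01, loading=0.1, overloaded=0.1))
```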
j = [ [ @@ -889,15 +950,14 @@ def post( # how long to wait: if "estimated_time" in j: wait_time = j['estimated_time'] - logger.error(f"Waiting for {j} ({wait_time}) second(s).") + warnings.warn(f"Waiting for {j} ({wait_time}) second(s).", ResourceWarning, stacklevel=2) sleep(wait_time + 2) return self.post(url, payload) if isinstance(j, Dict) and "error" in j: error = j['error'] - logger.error(f"API error: {j}.") if error == "overloaded": - logger.error("Waiting for overloaded API before trying again.") + warnings.warn("Waiting for overloaded API before trying again.", ResourceWarning, stacklevel=2) sleep(10) return self.post(url, payload) else: diff --git a/trulens_eval/trulens_eval/feedback/provider/hugs.py b/trulens_eval/trulens_eval/feedback/provider/hugs.py index 67f2aad49..bb40ca4f2 100644 --- a/trulens_eval/trulens_eval/feedback/provider/hugs.py +++ b/trulens_eval/trulens_eval/feedback/provider/hugs.py @@ -1,6 +1,8 @@ +from concurrent.futures import Future +from concurrent.futures import wait import logging from multiprocessing.pool import AsyncResult -from typing import Dict +from typing import Dict, Optional, Tuple import numpy as np @@ -8,7 +10,8 @@ from trulens_eval.feedback.provider.endpoint import HuggingfaceEndpoint from trulens_eval.feedback.provider.endpoint.base import DummyEndpoint from trulens_eval.feedback.provider.endpoint.base import Endpoint -from trulens_eval.utils.threading import TP +from trulens_eval.utils.python import locals_except +from trulens_eval.utils.threading import TP, ThreadPoolExecutor logger = logging.getLogger(__name__) @@ -67,7 +70,7 @@ class Huggingface(Provider): endpoint: Endpoint - def __init__(self, name: str = None, endpoint=None, **kwargs): + def __init__(self, name: Optional[str] = None, endpoint=None, **kwargs): # NOTE(piotrm): pydantic adds endpoint to the signature of this # constructor if we don't include it explicitly, even though we set it # down below. Adding it as None here as a temporary hack. @@ -105,7 +108,7 @@ def __init__(self, name: str = None, endpoint=None, **kwargs): # TODEP @_tci - def language_match(self, text1: str, text2: str) -> float: + def language_match(self, text1: str, text2: str) -> Tuple[float, Dict]: """ Uses Huggingface's papluca/xlm-roberta-base-language-detection model. 
A function that uses language detection on `text1` and `text2` and @@ -141,23 +144,26 @@ def get_scores(text): ) return {r['label']: r['score'] for r in hf_response} - max_length = 500 - scores1: AsyncResult[Dict] = TP().promise( - get_scores, text=text1[:max_length] - ) - scores2: AsyncResult[Dict] = TP().promise( - get_scores, text=text2[:max_length] - ) + with ThreadPoolExecutor(max_workers=2) as tpool: + max_length = 500 + f_scores1: Future[Dict] = tpool.submit( + get_scores, text=text1[:max_length] + ) + f_scores2: Future[Dict] = tpool.submit( + get_scores, text=text2[:max_length] + ) - scores1: Dict = scores1.get() - scores2: Dict = scores2.get() + wait([f_scores1, f_scores2]) + + scores1: Dict = f_scores1.result() + scores2: Dict = f_scores2.result() langs = list(scores1.keys()) prob1 = np.array([scores1[k] for k in langs]) prob2 = np.array([scores2[k] for k in langs]) diff = prob1 - prob2 - l1 = 1.0 - (np.linalg.norm(diff, ord=1)) / 2.0 + l1: float = float(1.0 - (np.linalg.norm(diff, ord=1)) / 2.0) return l1, dict(text1_scores=scores1, text2_scores=scores2) @@ -197,7 +203,9 @@ def positive_sentiment(self, text: str) -> float: for label in hf_response: if label['label'] == 'LABEL_2': - return label['score'] + return float(label['score']) + + raise RuntimeError("LABEL_2 not found in huggingface api response.") # TODEP @_tci @@ -238,6 +246,8 @@ def not_toxic(self, text: str) -> float: for label in hf_response: if label['label'] == 'toxic': return label['score'] + + raise RuntimeError("LABEL_2 not found in huggingface api response.") # TODEP @_tci @@ -262,6 +272,8 @@ def _summarized_groundedness(self, premise: str, hypothesis: str) -> float: for label in hf_response: if label['label'] == 'entailment': return label['score'] + + raise RuntimeError("LABEL_2 not found in huggingface api response.") # TODEP @_tci @@ -377,7 +389,6 @@ def pii_detection_with_cot_reasons(self, text: str): if not isinstance(hf_response, list): raise ValueError("Unexpected response from Huggingface API: response should be a list or a dictionary") - # Iterate through the entities and extract "word" and "score" for "NAME" entities for i, entity in enumerate(hf_response): reasons[f"{entity.get('entity_group')} detected: {entity['word']}"] = f"Score: {entity['score']}" @@ -401,8 +412,20 @@ def pii_detection_with_cot_reasons(self, text: str): class Dummy(Huggingface): - def __init__(self, name: str = None, **kwargs): + def __init__( + self, + name: Optional[str] = None, + error_prob: float = 1/100, + loading_prob: float = 1/100, + freeze_prob: float = 1/100, + overloaded_prob: float = 1/100, + rpm: float = 600, + **kwargs + ): kwargs['name'] = name or "dummyhugs" - kwargs['endpoint'] = DummyEndpoint(name="dummyendhugspoint") + kwargs['endpoint'] = DummyEndpoint( + name="dummyendhugspoint", + **locals_except("self", "name", "kwargs") + ) - super().__init__(**kwargs) + super().__init__(**kwargs) \ No newline at end of file diff --git a/trulens_eval/trulens_eval/schema.py b/trulens_eval/trulens_eval/schema.py index 92b229dda..6adb5f7f1 100644 --- a/trulens_eval/trulens_eval/schema.py +++ b/trulens_eval/trulens_eval/schema.py @@ -27,7 +27,7 @@ from pathlib import Path from pprint import PrettyPrinter from typing import ( - Any, Callable, ClassVar, Dict, Optional, Sequence, TypeVar, Union + Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Type, TypeVar, Union ) import dill @@ -163,8 +163,8 @@ class Record(SerialModel): record_id: RecordID app_id: AppID - cost: Optional[Cost] = None # 
pydantic.Field(default_factory=Cost) - perf: Optional[Perf] = None # pydantic.Field(default_factory=Perf) + cost: Optional[Cost] = None + perf: Optional[Perf] = None ts: datetime = pydantic.Field(default_factory=lambda: datetime.now()) @@ -180,7 +180,13 @@ class Record(SerialModel): # via `layout_calls_as_app`. calls: Sequence[RecordAppCall] = [] + # Feedback results only filled for records that were just produced. Will not + # be filled in when read from database. Also, will not fill in when using + # `FeedbackMode.DEFERRED`. + feedback_results: Optional[List['Future[FeedbackResult]']] = pydantic.Field(exclude=True) + def __init__(self, record_id: Optional[RecordID] = None, **kwargs): + # Fixed record_id for obj_id_of_id below. super().__init__(record_id="temporary", **kwargs) if record_id is None: @@ -204,16 +210,17 @@ def layout_calls_as_app(self) -> JSON: would be in the AppDefinitionstructure. """ - # TODO: problem: collissions ret = Bunch(**self.dict()) for call in self.calls: - frame_info = call.top( - ) # info about the method call is at the top of the stack + # Info about the method call is at the top of the stack + frame_info = call.top() + + # Adds another attribute to path, from method name: path = frame_info.path._append( GetItemOrAttribute(item_or_attribute=frame_info.method.name) - ) # adds another attribute to path, from method name - # TODO: append if already there + ) + ret = path.set_or_append(obj=ret, val=call) return ret diff --git a/trulens_eval/trulens_eval/tru.py b/trulens_eval/trulens_eval/tru.py index 149ba9e1f..4acef4852 100644 --- a/trulens_eval/trulens_eval/tru.py +++ b/trulens_eval/trulens_eval/tru.py @@ -1,3 +1,4 @@ +from concurrent.futures import as_completed, wait import logging from multiprocessing import Process import os @@ -7,12 +8,13 @@ import threading from threading import Thread from time import sleep -from typing import Iterable, List, Optional, Sequence, Union +from typing import Callable, Iterable, List, Optional, Sequence, Tuple, Union import warnings import pkg_resources from trulens_eval.database.sqlalchemy_db import SqlAlchemyDB +from trulens_eval.db import DB from trulens_eval.db import JSON from trulens_eval.feedback import Feedback from trulens_eval.schema import AppDefinition @@ -114,9 +116,8 @@ def __init__( if database_file: warnings.warn( - DeprecationWarning( - "`database_file` is deprecated, use `database_url` instead as in `database_url='sqlite:///filename'." - ) + "`database_file` is deprecated, use `database_url` instead as in `database_url='sqlite:///filename'.", + DeprecationWarning, stacklevel=2 ) if database_url is None: @@ -178,33 +179,19 @@ def add_record(self, record: Optional[Record] = None, **kwargs): update_record = add_record - def run_feedback_functions( + def _submit_feedback_functions( self, record: Record, feedback_functions: Sequence[Feedback], app: Optional[AppDefinition] = None, - ) -> Sequence[JSON]: - """ - Run a collection of feedback functions and report their result. - - Parameters: - - record (Record): The record on which to evaluate the feedback - functions. - - app (App, optional): The app that produced the given record. - If not provided, it is looked up from the given database `db`. - - feedback_functions (Sequence[Feedback]): A collection of feedback - functions to evaluate. - - Returns nothing. 
- """ - + on_done: Optional[Callable[['Future[Tuple[Feedback,FeedbackResult]]'], None]] = None + ) -> List['Future[Tuple[Feedback,FeedbackResult]]']: app_id = record.app_id + self.db: DB + if app is None: - app = self.db.get_app(app_id=app_id) + app = AppDefinition.parse_obj(self.db.get_app(app_id=app_id)) if app is None: raise RuntimeError( "App {app_id} not present in db. " @@ -220,16 +207,53 @@ def run_feedback_functions( ) self.add_app(app=app) - evals = [] + futures = [] - for func in feedback_functions: - evals.append( - TP().promise(lambda f: f.run(app=app, record=record), func) - ) + tp: TP = TP() - evals = map(lambda p: p.get(), evals) + for ffunc in feedback_functions: + fut: 'Future[Tuple[Feedback,FeedbackResult]]' = \ + tp.submit(lambda f: (f, f.run(app=app, record=record)), ffunc) + + if on_done is not None: + fut.add_done_callback(on_done) + + futures.append(fut) - return list(evals) + return futures + + def run_feedback_functions( + self, + record: Record, + feedback_functions: Sequence[Feedback], + app: Optional[AppDefinition] = None, + ) -> Iterable[FeedbackResult]: + """ + Run a collection of feedback functions and report their result. + + Parameters: + + record (Record): The record on which to evaluate the feedback + functions. + + app (App, optional): The app that produced the given record. + If not provided, it is looked up from the given database `db`. + + feedback_functions (Sequence[Feedback]): A collection of feedback + functions to evaluate. + + Yields `FeedbackResult`, one for each element of `feedback_functions` + potentially in random order. + """ + + for res in as_completed( + self._submit_feedback_functions( + record=record, + feedback_functions=feedback_functions, + app=app + ) + ): + yield res.result() def add_app(self, app: AppDefinition) -> None: """ @@ -307,9 +331,11 @@ def get_leaderboard(self, app_ids: List[str]): return leaderboard - def start_evaluator(self, - restart=False, - fork=False) -> Union[Process, Thread]: + def start_evaluator( + self, + restart=False, + fork=False + ) -> Union[Process, Thread]: """ Start a deferred feedback function evaluation thread. """ @@ -324,24 +350,18 @@ def start_evaluator(self, "Evaluator is already running in this process." ) - from trulens_eval.feedback import Feedback - if not fork: self.evaluator_stop = threading.Event() def runloop(): while fork or not self.evaluator_stop.is_set(): - #print( - # "Looking for things to do. Stop me with `tru.stop_evaluator()`.", - # end='' - #) - started_count = Feedback.evaluate_deferred(tru=self) + futures = Feedback.evaluate_deferred(tru=self) - if started_count > 0: + if len(futures) > 0: print( - f"{UNICODE_YIELD}{UNICODE_YIELD}{UNICODE_YIELD} Started {started_count} deferred feedback functions." + f"{UNICODE_YIELD}{UNICODE_YIELD}{UNICODE_YIELD} Started {len(futures)} deferred feedback functions." ) - TP().finish() + wait(futures) print( f"{UNICODE_CHECK}{UNICODE_CHECK}{UNICODE_CHECK} Finished evaluating deferred feedback functions." ) @@ -431,6 +451,11 @@ def stop_dashboard(self, force: bool = False) -> None: Tru.dashboard_proc = None def run_dashboard_in_jupyter(self): + """ + Experimental approach to attempt to display the dashboard inside a + jupyter notebook. Relies on the `streamlit_jupyter` package. 
+ """ + # EXPERIMENTAL # TODO: check for jupyter logger.warning( diff --git a/trulens_eval/trulens_eval/utils/langchain.py b/trulens_eval/trulens_eval/utils/langchain.py index a08602714..31563f035 100644 --- a/trulens_eval/trulens_eval/utils/langchain.py +++ b/trulens_eval/trulens_eval/utils/langchain.py @@ -8,6 +8,8 @@ from typing import List, Type +from concurrent.futures import wait + from trulens_eval import app from trulens_eval.feedback import Feedback from trulens_eval.utils.containers import first @@ -128,7 +130,9 @@ def _get_relevant_documents(self, query: str, *, ) for doc in docs ) - results = list((doc, promise.result()) for (doc, promise) in futures) + wait([future for (_, future) in futures]) + + results = list((doc, future.result()) for (doc, future) in futures) filtered = map(first, filter(second, results)) # Return only the filtered ones. diff --git a/trulens_eval/trulens_eval/utils/llama.py b/trulens_eval/trulens_eval/utils/llama.py index e12d7f206..6726e80dd 100644 --- a/trulens_eval/trulens_eval/utils/llama.py +++ b/trulens_eval/trulens_eval/utils/llama.py @@ -17,7 +17,7 @@ from trulens_eval.utils.imports import REQUIREMENT_LLAMA from trulens_eval.utils.pyschema import Class from trulens_eval.utils.serial import JSON -from trulens_eval.utils.threading import TP +from trulens_eval.utils.threading import ThreadPoolExecutor with OptionalImports(message=REQUIREMENT_LLAMA): from llama_index.indices.query.schema import QueryBundle @@ -142,10 +142,12 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: # Get relevant docs using super class: nodes = super()._retrieve(query_bundle) + ex = ThreadPoolExecutor(max_workers=max(1, len(nodes))) + # Evaluate the filter on each, in parallel. - promises = ( + futures = ( ( - node, TP().promise( + node, ex.submit( lambda query, node: self.feedback( query.query_str, node.node.get_text() ) > self.threshold, @@ -154,7 +156,10 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: ) ) for node in nodes ) - results = ((node, promise.get()) for (node, promise) in promises) + + wait([future for (_, future) in futures]) + + results = ((node, future.result()) for (node, future) in futures) filtered = map(first, filter(second, results)) # Return only the filtered ones. diff --git a/trulens_eval/trulens_eval/utils/python.py b/trulens_eval/trulens_eval/utils/python.py index be72eee80..22962b5a0 100644 --- a/trulens_eval/trulens_eval/utils/python.py +++ b/trulens_eval/trulens_eval/utils/python.py @@ -9,7 +9,7 @@ import logging from pprint import PrettyPrinter from queue import Queue -from typing import Any, Callable, Dict, Hashable, Optional, Sequence, TypeVar +from typing import Any, Callable, Dict, Generic, Hashable, Iterator, Optional, Sequence, Type, TypeVar, Union logger = logging.getLogger(__name__) pp = PrettyPrinter() @@ -17,6 +17,28 @@ T = TypeVar("T") Thunk = Callable[[], T] + +# Function utilities. + +def code_line(func) -> Optional[str]: + """ + Get a string representation of the location of the given function `func`. + """ + if hasattr(func, "__code__"): + code = func.__code__ + return f"{code.co_filename}:{code.co_firstlineno}" + else: + return None + +def locals_except(*exceptions): + """ + Get caller's locals except for the named exceptions. + """ + + locs = caller_frame(offset=1).f_locals # 1 to skip this call + + return {k: v for k, v in locs.items() if k not in exceptions} + # Python call stack utilities # Attribute name for storing a callstack in asyncio tasks. 
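The two helpers added to `utils/python.py` above are used elsewhere in this patch (`locals_except` to forward constructor arguments in `DummyEndpoint` and `Dummy`, `code_line` to identify timed-out tasks in `TP`); a small usage sketch with a hypothetical function:

```python
from trulens_eval.utils.python import code_line, locals_except

def make_config(host: str, port: int = 8080, _secret: str = "do not forward"):
    # Capture this call's locals minus the excluded names, mirroring how
    # DummyEndpoint.__init__ forwards its parameters via locals_except.
    return locals_except("_secret")

print(make_config("localhost"))  # {'host': 'localhost', 'port': 8080}
print(code_line(make_config))    # e.g. "/path/to/example.py:3"
```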
@@ -139,6 +161,8 @@ def _future_target_wrapper(stack, func, *args, **kwargs): the stack and need to do this to the frames prior to thread starts. """ + # TODO: See if threading.stack_size([size]) can be used instead. + # Keep this for looking up via get_first_local_in_call_stack . pre_start_stack = stack @@ -251,8 +275,9 @@ class above. # Class utilities +T = TypeVar("T") -class SingletonPerName(): +class SingletonPerName(Generic[T]): """ Class for creating singleton instances except there being one instance max, there is one max per different `name` argument. If `name` is never given, @@ -262,7 +287,7 @@ class SingletonPerName(): # Hold singleton instances here. instances: Dict[Hashable, 'SingletonPerName'] = dict() - def __new__(cls, *args, name: str = None, **kwargs): + def __new__(cls: Type[SingletonPerName[T]], *args, name: Optional[str] = None, **kwargs) -> SingletonPerName[T]: """ Create the singleton instance if it doesn't already exist and return it. """ @@ -275,4 +300,6 @@ def __new__(cls, *args, name: str = None, **kwargs): ) SingletonPerName.instances[k] = super().__new__(cls) - return SingletonPerName.instances[k] + obj: cls = SingletonPerName.instances[k] + + return obj diff --git a/trulens_eval/trulens_eval/utils/threading.py b/trulens_eval/trulens_eval/utils/threading.py index 6299dd91f..e2e3891fb 100644 --- a/trulens_eval/trulens_eval/utils/threading.py +++ b/trulens_eval/trulens_eval/utils/threading.py @@ -2,26 +2,31 @@ Multi-threading utilities. """ + +from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor as fThreadPoolExecutor +from concurrent.futures import TimeoutError from inspect import stack import logging -from multiprocessing.pool import AsyncResult -from multiprocessing.pool import ThreadPool -from queue import Queue -from time import sleep -from typing import Callable, List, Optional, TypeVar +import threading -import pandas as pd +from typing import Callable, TypeVar -from trulens_eval.utils.python import _future_target_wrapper +from trulens_eval.utils.python import _future_target_wrapper, code_line from trulens_eval.utils.python import SingletonPerName logger = logging.getLogger(__name__) T = TypeVar("T") +DEFAULT_NETWORK_TIMEOUT: float = 10.0 # seconds + class ThreadPoolExecutor(fThreadPoolExecutor): + """ + A ThreadPoolExecutor that keeps track of the stack prior to each thread's + invocation. + """ def submit(self, fn, /, *args, **kwargs): present_stack = stack() @@ -30,91 +35,101 @@ def submit(self, fn, /, *args, **kwargs): ) -class TP(SingletonPerName): # "thread processing" +class TP(SingletonPerName['TP']): # "thread processing" # Store here stacks of calls to various thread starting methods so that we # can retrieve the trace of calls that caused a thread to start. - # pre_run_stacks = dict() + MAX_THREADS = 128 + + # How long to wait for any task before restarting it. + DEBUG_TIMEOUT = 600.0 # 5 minutes def __init__(self): if hasattr(self, "thread_pool"): # Already initialized as per SingletonPerName mechanism. return - # TODO(piotrm): if more tasks than `processes` get added, future ones - # will block and earlier ones may never start executing. 
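Two pieces of this hunk in one hedged sketch: the `ThreadPoolExecutor` subclass captures the submitting stack so instrumentation that walks the call stack inside a worker can still see the frames that scheduled the task, and `TP` remains a `SingletonPerName`, so repeated construction yields the same pool:

```python
from trulens_eval.utils.threading import TP, ThreadPoolExecutor

# TP is a SingletonPerName: constructing it twice returns the same object.
assert TP() is TP()

# The subclassed executor records stack() at submit time and hands it to
# _future_target_wrapper so pre-submit frames remain discoverable from
# inside the worker thread.
ex = ThreadPoolExecutor(max_workers=2)
fut = ex.submit(lambda x: x * 2, 21)
print(fut.result())  # 42
ex.shutdown()
```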
- self.thread_pool = ThreadPool(processes=64) - self.running = 0 - self.promises = Queue(maxsize=64) - - def runrepeatedly(self, func: Callable, rpm: float = 6, *args, **kwargs): - - def runner(): - while True: - func(*args, **kwargs) - sleep(60 / rpm) - - self.runlater(runner) - - def _thread_starter(self, func, args, kwargs): - present_stack = stack() - - prom = self.thread_pool.apply_async( - _future_target_wrapper, - args=(present_stack, func) + args, - kwds=kwargs + # Run tasks started with this class using this pool. + self.thread_pool = fThreadPoolExecutor( + max_workers=TP.MAX_THREADS, + thread_name_prefix="TP.submit" ) - return prom - - def finish_if_full(self): - if self.promises.full(): - print("Task queue full. Finishing existing tasks.") - self.finish() - - def runlater(self, func: Callable, *args, **kwargs) -> None: - self.finish_if_full() - - prom = self._thread_starter(func, args, kwargs) - - self.promises.put(prom) - - def promise(self, func: Callable[..., T], *args, **kwargs) -> AsyncResult: - self.finish_if_full() - - prom = self._thread_starter(func, args, kwargs) - - self.promises.put(prom) - - return prom - - def finish(self, timeout: Optional[float] = None) -> int: - logger.debug(f"Finishing {self.promises.qsize()} task(s).") - - timeouts = [] - - while not self.promises.empty(): - prom = self.promises.get() - try: - prom.get(timeout=timeout) - except TimeoutError: - timeouts.append(prom) - - for prom in timeouts: - self.promises.put(prom) - - if len(timeouts) == 0: - logger.debug("Done.") - else: - logger.debug("Some tasks timed out.") - - return len(timeouts) - - def _status(self) -> List[str]: - rows = [] - - for p in self.thread_pool._pool: - rows.append([p.is_alive(), str(p)]) + # Keep a seperate pool for threads whose function is only to wait for + # the tasks executed in the above pool. Keeping this seperate to prevent + # the deadlock whereas the wait thread waits for a tasks which will + # never be run because the thread pool is filled with wait threads. + self.thread_pool_debug_tasks = ThreadPoolExecutor( + max_workers=TP.MAX_THREADS, + thread_name_prefix="TP.submit with debug timeout" + ) - return pd.DataFrame(rows, columns=["alive", "thread"]) + self.completed_tasks = 0 + self.timedout_tasks = 0 + self.failed_tasks = 0 + + def _run_with_timeout( + self, + func: Callable[..., T], + *args, + timeout: float = DEBUG_TIMEOUT, + **kwargs + ) -> T: + + fut: 'Future[T]' = self.thread_pool.submit(func, *args, **kwargs) + + try: + res: T = fut.result(timeout=timeout) + return res + + except TimeoutError as e: + logger.error( + f"Run of {func.__name__} in {threading.current_thread()} timed out after {TP.DEBUG_TIMEOUT} second(s).\n" + f"{code_line(func)}" + ) + + raise e + + except Exception as e: + logger.warning( + f"Run of {func.__name__} in {threading.current_thread()} failed with: {e}" + ) + raise e + + def submit( + self, + func: Callable[..., T], + *args, + timeout: float = DEBUG_TIMEOUT, + **kwargs + ) -> 'Future[T]': + + # TODO(piotrm): need deadlock fixes here. If submit or _submit was called + # earlier in the stack, do not use a threadpool to evaluate this task + # and instead create a new thread for it. This prevents tasks in a + # threadpool adding sub-tasks in the same threadpool which can lead to + # deadlocks. Alternatively just raise an exception in those cases. 
+
+        return self._submit(func, *args, timeout=timeout, **kwargs)
+
+    def _submit(
+        self,
+        func: Callable[..., T],
+        *args,
+        timeout: float = DEBUG_TIMEOUT,
+        **kwargs
+    ) -> 'Future[T]':
+
+        # Submit a concurrent task to run `func` with the given `args` and
+        # `kwargs`, but stop with an error if it ever takes too long. This is
+        # only meant for debugging purposes as we expect all concurrent tasks
+        # to have their own retry/timeout capabilities.
+
+        return self.thread_pool_debug_tasks.submit(
+            self._run_with_timeout,
+            func,
+            *args,
+            timeout=timeout,
+            **kwargs
+        )
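To round out the threading changes, a short hedged usage sketch of the reworked `TP`: `submit` routes every task through `_run_with_timeout`, so callers get ordinary futures plus a debug watchdog (values illustrative):

```python
from concurrent.futures import as_completed
from time import sleep

from trulens_eval.utils.threading import TP

tp = TP()

def slow_square(x: int) -> int:
    sleep(0.1)
    return x * x

# Each task runs in TP's worker pool; a separate debug pool waits on it and
# logs (then re-raises) if it exceeds the given timeout.
futures = [tp.submit(slow_square, i, timeout=5.0) for i in range(4)]

for fut in as_completed(futures):
    print(fut.result())
```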