From bac458e9f1d3054bb98645fe15ac8c3be49c4e93 Mon Sep 17 00:00:00 2001 From: Piotr Mardziel Date: Tue, 1 Aug 2023 14:39:37 -0700 Subject: [PATCH] fix langchain calls (#340) * error handling * prototype --- .../frameworks/langchain/langchain_async.ipynb | 11 +++++++---- trulens_eval/trulens_eval/app.py | 4 +--- trulens_eval/trulens_eval/instruments.py | 17 ++++++++++++++++- trulens_eval/trulens_eval/tru_chain.py | 15 ++++++++++----- 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/trulens_eval/examples/frameworks/langchain/langchain_async.ipynb b/trulens_eval/examples/frameworks/langchain/langchain_async.ipynb index b50f2adaa..251715fea 100644 --- a/trulens_eval/examples/frameworks/langchain/langchain_async.ipynb +++ b/trulens_eval/examples/frameworks/langchain/langchain_async.ipynb @@ -68,6 +68,7 @@ "outputs": [], "source": [ "# Set up a language match feedback function.\n", + "\n", "tru = Tru()\n", "hugs = feedback.Huggingface()\n", "f_lang_match = Feedback(hugs.language_match).on_input_output()\n", @@ -80,7 +81,7 @@ "llm = ChatOpenAI(\n", " temperature=0.0,\n", " streaming=True, # important\n", - " callbacks=[callback]\n", + " # callbacks=[callback] # callback can be here or below in acall_with_record\n", ")\n", "chain = LLMChain(llm=llm, prompt=prompt)\n", "tc = tru.Chain(chain, feedbacks=[f_lang_match])\n", @@ -91,6 +92,7 @@ "f_res_record = asyncio.create_task(\n", " tc.acall_with_record(\n", " inputs=dict(question=message),\n", + " callbacks=[callback]\n", " )\n", ")\n", "\n", @@ -99,7 +101,7 @@ " print(token)\n", "\n", "# By now the acall_with_record results should be ready.\n", - "res, record = await f_res_record\n", + "res = await f_res_record\n", "\n", "res" ] @@ -132,8 +134,8 @@ " prompt = PromptTemplate.from_template(\"Honestly answer this question: {question}.\")\n", " llm = ChatOpenAI(\n", " temperature=0.0,\n", - " streaming=True, # important\n", - " callbacks=[callback]\n", + " streaming=True# important\n", + " # callbacks=[callback]\n", " )\n", " chain = LLMChain(llm=llm, prompt=prompt)\n", " tc = tru.Chain(chain, feedbacks=[f_lang_match])\n", @@ -141,6 +143,7 @@ " f_res_record = asyncio.create_task(\n", " tc.acall_with_record(\n", " inputs=dict(question=message),\n", + " callbacks=[callback]\n", " )\n", " )\n", " \n", diff --git a/trulens_eval/trulens_eval/app.py b/trulens_eval/trulens_eval/app.py index efc064a06..8cc21ac3c 100644 --- a/trulens_eval/trulens_eval/app.py +++ b/trulens_eval/trulens_eval/app.py @@ -6,9 +6,7 @@ from abc import abstractmethod import logging from pprint import PrettyPrinter -from typing import ( - Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple -) +from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple import pydantic from pydantic import Field diff --git a/trulens_eval/trulens_eval/instruments.py b/trulens_eval/trulens_eval/instruments.py index b847aeddc..6edca1c11 100644 --- a/trulens_eval/trulens_eval/instruments.py +++ b/trulens_eval/trulens_eval/instruments.py @@ -258,6 +258,7 @@ class TP but a more complete solution may be the instrumentation of import os from pprint import PrettyPrinter import threading as th +import traceback from typing import Callable, Dict, Iterable, Optional, Sequence, Set from pydantic import BaseModel @@ -389,7 +390,6 @@ def find_root_methods(f): if record is None: logger.debug(f"{query}: no record found, not recording.") - return await func(*args, **kwargs) # Otherwise keep track of inputs and outputs (or exception). @@ -413,6 +413,8 @@ def find_instrumented(f): start_time = None end_time = None + bindings = dict() + try: # Using sig bind here so we can produce a list of key-value # pairs even if positional arguments were provided. @@ -428,6 +430,9 @@ def find_instrumented(f): error = e error_str = str(e) + logger.error(f"Error calling wrapped function {func.__name__}.") + logger.error(traceback.format_exc()) + # Don't include self in the recorded arguments. nonself = { k: jsonify(v) @@ -498,6 +503,8 @@ def find_instrumented(f): start_time = None end_time = None + bindings = dict() + try: # Using sig bind here so we can produce a list of key-value # pairs even if positional arguments were provided. @@ -511,6 +518,9 @@ def find_instrumented(f): error = e error_str = str(e) + logger.error(f"Error calling wrapped function {func.__name__}.") + logger.error(traceback.format_exc()) + # Don't include self in the recorded arguments. nonself = { k: jsonify(v) @@ -549,6 +559,11 @@ def find_instrumented(f): # chain. setattr(w, Instrument.PATH, query) + # NOTE(piotrm): This is important; langchain checks signatures to adjust + # behaviour and we need to match. Without this, wrapper signatures will + # show up only as *args, **kwargs . + w.__signature__ = inspect.signature(func) + return w def instrument_object(self, obj, query: Query, done: Set[int] = None): diff --git a/trulens_eval/trulens_eval/tru_chain.py b/trulens_eval/trulens_eval/tru_chain.py index d7b2879d3..e082e0196 100644 --- a/trulens_eval/trulens_eval/tru_chain.py +++ b/trulens_eval/trulens_eval/tru_chain.py @@ -5,6 +5,7 @@ from datetime import datetime import logging from pprint import PrettyPrinter +import traceback from typing import Any, ClassVar, Dict, List, Sequence, Tuple, Union # import nest_asyncio # NOTE(piotrm): disabling for now, need more investigation @@ -51,7 +52,9 @@ class Default: # Instrument only methods with these names and of these classes. METHODS = { "_call": lambda o: isinstance(o, langchain.chains.base.Chain), + "__call__": lambda o: isinstance(o, langchain.chains.base.Chain), "_acall": lambda o: isinstance(o, langchain.chains.base.Chain), + "acall": lambda o: isinstance(o, langchain.chains.base.Chain), "_get_relevant_documents": lambda o: True, # VectorStoreRetriever, langchain >= 0.230 } @@ -189,7 +192,7 @@ async def func_async(inputs, **kwargs): return evl.run_until_complete(self._eval_async_root_method(func_async, inputs, **kwargs)) """ - # NOTE: Input signature compatible with langchain.chains.base.Chain._acall + # NOTE: Input signature compatible with langchain.chains.base.Chain.acall async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tuple[Any, Record]: """ Run the chain and also return a record metadata object. @@ -217,7 +220,7 @@ async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) try: start_time = datetime.now() ret, cost = await Endpoint.atrack_all_costs_tally( - lambda: self.app._acall(inputs=inputs, **kwargs) + lambda: self.app.acall(inputs=inputs, **kwargs) ) end_time = datetime.now() @@ -225,7 +228,8 @@ async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) end_time = datetime.now() error = e logger.error(f"App raised an exception: {e}") - + logger.error(traceback.format_exc()) + assert len(record) > 0, "No information recorded in call." ret_record_args = dict() @@ -249,7 +253,7 @@ async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) return ret, ret_record - # NOTE: Input signature compatible with langchain.chains.base.Chain._call + # NOTE: Input signature compatible with langchain.chains.base.Chain.__call__ def call_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tuple[Any, Record]: """ Run the chain and also return a record metadata object. @@ -277,7 +281,7 @@ def call_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tupl try: start_time = datetime.now() ret, cost = Endpoint.track_all_costs_tally( - lambda: self.app._call(inputs=inputs, **kwargs) + lambda: self.app.__call__(inputs=inputs, **kwargs) ) end_time = datetime.now() @@ -285,6 +289,7 @@ def call_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tupl end_time = datetime.now() error = e logger.error(f"App raised an exception: {e}") + logger.error(traceback.format_exc()) assert len(record) > 0, "No information recorded in call."