Skip to content

Commit

Permalink
fix langchain calls (#340)
Browse files Browse the repository at this point in the history
* error handling

* prototype
  • Loading branch information
piotrm0 authored Aug 1, 2023
1 parent d2d61bb commit bac458e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
11 changes: 7 additions & 4 deletions trulens_eval/examples/frameworks/langchain/langchain_async.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"outputs": [],
"source": [
"# Set up a language match feedback function.\n",
"\n",
"tru = Tru()\n",
"hugs = feedback.Huggingface()\n",
"f_lang_match = Feedback(hugs.language_match).on_input_output()\n",
Expand All @@ -80,7 +81,7 @@
"llm = ChatOpenAI(\n",
" temperature=0.0,\n",
" streaming=True, # important\n",
" callbacks=[callback]\n",
" # callbacks=[callback] # callback can be here or below in acall_with_record\n",
")\n",
"chain = LLMChain(llm=llm, prompt=prompt)\n",
"tc = tru.Chain(chain, feedbacks=[f_lang_match])\n",
Expand All @@ -91,6 +92,7 @@
"f_res_record = asyncio.create_task(\n",
" tc.acall_with_record(\n",
" inputs=dict(question=message),\n",
" callbacks=[callback]\n",
" )\n",
")\n",
"\n",
Expand All @@ -99,7 +101,7 @@
" print(token)\n",
"\n",
"# By now the acall_with_record results should be ready.\n",
"res, record = await f_res_record\n",
"res = await f_res_record\n",
"\n",
"res"
]
Expand Down Expand Up @@ -132,15 +134,16 @@
" prompt = PromptTemplate.from_template(\"Honestly answer this question: {question}.\")\n",
" llm = ChatOpenAI(\n",
" temperature=0.0,\n",
" streaming=True, # important\n",
" callbacks=[callback]\n",
" streaming=True# important\n",
" # callbacks=[callback]\n",
" )\n",
" chain = LLMChain(llm=llm, prompt=prompt)\n",
" tc = tru.Chain(chain, feedbacks=[f_lang_match])\n",
"\n",
" f_res_record = asyncio.create_task(\n",
" tc.acall_with_record(\n",
" inputs=dict(question=message),\n",
" callbacks=[callback]\n",
" )\n",
" )\n",
" \n",
Expand Down
4 changes: 1 addition & 3 deletions trulens_eval/trulens_eval/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from abc import abstractmethod
import logging
from pprint import PrettyPrinter
from typing import (
Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
)
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple

import pydantic
from pydantic import Field
Expand Down
17 changes: 16 additions & 1 deletion trulens_eval/trulens_eval/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class TP but a more complete solution may be the instrumentation of
import os
from pprint import PrettyPrinter
import threading as th
import traceback
from typing import Callable, Dict, Iterable, Optional, Sequence, Set

from pydantic import BaseModel
Expand Down Expand Up @@ -389,7 +390,6 @@ def find_root_methods(f):

if record is None:
logger.debug(f"{query}: no record found, not recording.")

return await func(*args, **kwargs)

# Otherwise keep track of inputs and outputs (or exception).
Expand All @@ -413,6 +413,8 @@ def find_instrumented(f):
start_time = None
end_time = None

bindings = dict()

try:
# Using sig bind here so we can produce a list of key-value
# pairs even if positional arguments were provided.
Expand All @@ -428,6 +430,9 @@ def find_instrumented(f):
error = e
error_str = str(e)

logger.error(f"Error calling wrapped function {func.__name__}.")
logger.error(traceback.format_exc())

# Don't include self in the recorded arguments.
nonself = {
k: jsonify(v)
Expand Down Expand Up @@ -498,6 +503,8 @@ def find_instrumented(f):
start_time = None
end_time = None

bindings = dict()

try:
# Using sig bind here so we can produce a list of key-value
# pairs even if positional arguments were provided.
Expand All @@ -511,6 +518,9 @@ def find_instrumented(f):
error = e
error_str = str(e)

logger.error(f"Error calling wrapped function {func.__name__}.")
logger.error(traceback.format_exc())

# Don't include self in the recorded arguments.
nonself = {
k: jsonify(v)
Expand Down Expand Up @@ -549,6 +559,11 @@ def find_instrumented(f):
# chain.
setattr(w, Instrument.PATH, query)

# NOTE(piotrm): This is important; langchain checks signatures to adjust
# behaviour and we need to match. Without this, wrapper signatures will
# show up only as *args, **kwargs .
w.__signature__ = inspect.signature(func)

return w

def instrument_object(self, obj, query: Query, done: Set[int] = None):
Expand Down
15 changes: 10 additions & 5 deletions trulens_eval/trulens_eval/tru_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
import logging
from pprint import PrettyPrinter
import traceback
from typing import Any, ClassVar, Dict, List, Sequence, Tuple, Union

# import nest_asyncio # NOTE(piotrm): disabling for now, need more investigation
Expand Down Expand Up @@ -51,7 +52,9 @@ class Default:
# Instrument only methods with these names and of these classes.
METHODS = {
"_call": lambda o: isinstance(o, langchain.chains.base.Chain),
"__call__": lambda o: isinstance(o, langchain.chains.base.Chain),
"_acall": lambda o: isinstance(o, langchain.chains.base.Chain),
"acall": lambda o: isinstance(o, langchain.chains.base.Chain),
"_get_relevant_documents":
lambda o: True, # VectorStoreRetriever, langchain >= 0.230
}
Expand Down Expand Up @@ -189,7 +192,7 @@ async def func_async(inputs, **kwargs):
return evl.run_until_complete(self._eval_async_root_method(func_async, inputs, **kwargs))
"""

# NOTE: Input signature compatible with langchain.chains.base.Chain._acall
# NOTE: Input signature compatible with langchain.chains.base.Chain.acall
async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tuple[Any, Record]:
"""
Run the chain and also return a record metadata object.
Expand Down Expand Up @@ -217,15 +220,16 @@ async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs)
try:
start_time = datetime.now()
ret, cost = await Endpoint.atrack_all_costs_tally(
lambda: self.app._acall(inputs=inputs, **kwargs)
lambda: self.app.acall(inputs=inputs, **kwargs)
)
end_time = datetime.now()

except BaseException as e:
end_time = datetime.now()
error = e
logger.error(f"App raised an exception: {e}")

logger.error(traceback.format_exc())

assert len(record) > 0, "No information recorded in call."

ret_record_args = dict()
Expand All @@ -249,7 +253,7 @@ async def acall_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs)

return ret, ret_record

# NOTE: Input signature compatible with langchain.chains.base.Chain._call
# NOTE: Input signature compatible with langchain.chains.base.Chain.__call__
def call_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tuple[Any, Record]:
"""
Run the chain and also return a record metadata object.
Expand Down Expand Up @@ -277,14 +281,15 @@ def call_with_record(self, inputs: Union[Dict[str, Any], Any], **kwargs) -> Tupl
try:
start_time = datetime.now()
ret, cost = Endpoint.track_all_costs_tally(
lambda: self.app._call(inputs=inputs, **kwargs)
lambda: self.app.__call__(inputs=inputs, **kwargs)
)
end_time = datetime.now()

except BaseException as e:
end_time = datetime.now()
error = e
logger.error(f"App raised an exception: {e}")
logger.error(traceback.format_exc())

assert len(record) > 0, "No information recorded in call."

Expand Down

0 comments on commit bac458e

Please sign in to comment.