Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
add entity resolution prototype
Browse files Browse the repository at this point in the history
* do a preliminary infilling pass to generate "preliminary arguments"
* search an entity store using the "preliminary arguments" to get a list of relevant entities
* pass the list of relevant entities as a JSON Schema enum to the "final infilling" pass
  • Loading branch information
jordan-wu-97 committed Dec 11, 2023
1 parent 7e3bc71 commit 6bf1150
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 74 deletions.
16 changes: 8 additions & 8 deletions examples/fast-api-server/dummy-data/employees.csv
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
id,name,email
1,Richard Hendricks,[email protected]
2,Erlich Bachman,[email protected]
3,Dinesh Chugtai,[email protected]
4,Bertram Gilfoyle,[email protected]
5,Jared Dunn,[email protected]
6,Monica Hall,[email protected]
7,Gavin Belson,[email protected]
id,name,email,role
1,Richard Hendricks,[email protected],Pied Piper CEO
2,Erlich Bachman,[email protected],
3,Dinesh Chugtai,[email protected],
4,Bertram Gilfoyle,[email protected],
5,Jared Dunn,[email protected],Pied Piper COO
6,Monica Hall,[email protected],
7,Gavin Belson,[email protected],Hooli CEO
Original file line number Diff line number Diff line change
@@ -1,40 +1,67 @@
from typing import Sequence

import anyio
import pandas as pd
from openassistants.contrib.python_callable import PythonCallableFunction
from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput, TextOutput
from openassistants.functions.base import FunctionExecutionDependency
from openassistants.functions.base import (
Entity,
EntityConfig,
FunctionExecutionDependency,
)
from openassistants.functions.utils import AsyncStreamVersion


# Example function body: look up an employee's email address by exact name
# match in dummy-data/employees.csv and yield a single TextOutput describing
# the result.
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost. Where two near-identical lines sit back to back (the
# def header, the read_csv call, the email lookup), the first looks like the
# pre-commit line and the second its replacement -- TODO confirm.
async def find_email_by_name_callable(
async def _execute(
deps: FunctionExecutionDependency,
) -> AsyncStreamVersion[Sequence[FunctionOutput]]:
"""
user entities are:
name | email
richard | [email protected]
...
"""

# the employee name to look up is read from deps.arguments
name = deps.arguments["name"]

# load csv
df = pd.read_csv("dummy-data/employees.csv")
# same load, but handed to anyio.to_thread.run_sync instead of calling
# pd.read_csv directly inside this coroutine
df = await anyio.to_thread.run_sync(pd.read_csv, "dummy-data/employees.csv")

# find email where csv.name == name
email = df[df["name"] == name]["email"].iloc[0]
filtered = df[df["name"] == name]["email"]

# no matching row -> None, otherwise the email of the first matching row
email = None if len(filtered) == 0 else filtered.iloc[0]

if email is None:
yield [
TextOutput(text=f"Could not find email for: {name}"),
]

else:
yield [TextOutput(text=f"Found Email For: {name} ({email})")]


# Build the entity configuration for this function: one Entity per row of
# dummy-data/employees.csv, registered under the "name" parameter.
async def _get_entity_configs() -> dict[str, EntityConfig]:
# read the CSV through anyio.to_thread.run_sync rather than calling
# pd.read_csv directly inside this coroutine
df = await anyio.to_thread.run_sync(pd.read_csv, "dummy-data/employees.csv")

# one dict per CSV row, keyed by column name
records = df.to_dict("records")

# NOTE(review): the yield below references `name` and `email`, which are not
# defined in this function; it looks like a removed line of the previous
# function left over from the rendered diff -- TODO confirm.
yield [TextOutput(text=f"Found Email For: {name} ({email})")]
return {
"name": EntityConfig(
entities=[
Entity(
identity=row["name"],
# non-string roles become None; presumably empty CSV cells are loaded by
# pandas as NaN (a float) -- TODO confirm
description=row["role"] if isinstance(row["role"], str) else None,
)
for row in records
],
)
}


find_email_by_name_function = PythonCallableFunction(
id="find_email",
type="FindEmailFunction",
display_name="Find Email",
description="Find an email address",
sample_questions=["Find the email address for {employee}"],
sample_questions=[
"Find the email address for {employee}",
"What is {employee}'s email address?",
],
parameters=BaseJSONSchema(
json_schema={
"type": "object",
Expand All @@ -47,5 +74,6 @@ async def find_email_by_name_callable(
"required": ["name"],
}
),
callable=find_email_by_name_callable,
execute_callable=_execute,
get_entity_configs_callable=_get_entity_configs,
)
20 changes: 15 additions & 5 deletions packages/openassistants/openassistants/contrib/python_callable.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from typing import Callable, Sequence
from typing import Awaitable, Callable, Dict, Sequence

from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.base import (
BaseFunction,
EntityConfig,
FunctionExecutionDependency,
)
from openassistants.functions.utils import AsyncStreamVersion


# A BaseFunction whose execution and entity lookup are delegated to plain
# Python callables supplied as fields.
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost. The `callable` field line and the
# `async for version in self.callable(deps)` loop look like pre-commit lines
# shown next to their `execute_callable` replacements -- TODO confirm.
class PythonCallableFunction(BaseFunction):
callable: Callable[
# async generator invoked by execute(); receives the execution dependencies
execute_callable: Callable[
[FunctionExecutionDependency], AsyncStreamVersion[Sequence[FunctionOutput]]
]

# wrapper whose json_schema is returned by get_parameters_json_schema()
parameters: BaseJSONSchema

# zero-argument coroutine function awaited by get_entity_configs()
# NOTE(review): no default is declared, so every instance appears to need
# this field, while BaseFunction.get_entity_configs falls back to {} --
# confirm this is intended.
get_entity_configs_callable: Callable[[], Awaitable[dict[str, EntityConfig]]]

# Re-yield every item produced by the wrapped callable.
async def execute(
self, deps: FunctionExecutionDependency
) -> AsyncStreamVersion[Sequence[FunctionOutput]]:
async for version in self.callable(deps):
yield version
async for output in self.execute_callable(deps):
yield output

# Return the raw JSON schema held by `parameters`.
async def get_parameters_json_schema(self) -> dict:
return self.parameters.json_schema

# Await the user-supplied callable and return its mapping unchanged.
async def get_entity_configs(self) -> Dict[str, EntityConfig]:
return await self.get_entity_configs_callable()
46 changes: 39 additions & 7 deletions packages/openassistants/openassistants/core/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings

from openassistants.data_models.chat_messages import (
OpasAssistantMessage,
Expand All @@ -11,34 +13,52 @@
OpasUserMessage,
)
from openassistants.data_models.function_input import FunctionCall, FunctionInputRequest
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.base import (
BaseFunction,
Entity,
FunctionExecutionDependency,
)
from openassistants.functions.crud import FunctionCRUD, LocalCRUD
from openassistants.llm_function_calling.entity_resolution import resolve_entities
from openassistants.llm_function_calling.infilling import (
generate_argument_decisions,
generate_arguments,
)
from openassistants.llm_function_calling.selection import select_function
from openassistants.utils.async_utils import AsyncStreamVersion
from openassistants.utils.langchain_util import LangChainCachedEmbeddings


class Assistant:
function_identification: BaseChatModel
function_infilling: BaseChatModel
function_summarization: BaseChatModel
entity_embedding_model: Embeddings
function_libraries: List[FunctionCRUD]

_cached_all_functions: List[BaseFunction]

def __init__(
self,
libraries: List[str | FunctionCRUD],
function_identification: BaseChatModel = ChatOpenAI(model="gpt-3.5-turbo-16k"),
function_infilling: BaseChatModel = ChatOpenAI(model="gpt-3.5-turbo-16k"),
function_summarization: BaseChatModel = ChatOpenAI(model="gpt-3.5-turbo-16k"),
function_identification: Optional[BaseChatModel] = None,
function_infilling: Optional[BaseChatModel] = None,
function_summarization: Optional[BaseChatModel] = None,
entity_embedding_model: Optional[Embeddings] = None,
):
self.function_identification = function_identification
self.function_infilling = function_infilling
self.function_summarization = function_summarization
# instantiate dynamically vs as default args
self.function_identification = function_identification or ChatOpenAI(
model_name="gpt-3.5-turbo", temperature=0.0, max_tokens=128
)
self.function_infilling = function_infilling or ChatOpenAI(
model_name="gpt-3.5-turbo", temperature=0.0, max_tokens=128
)
self.function_summarization = function_summarization or ChatOpenAI(
model_name="gpt-3.5-turbo", temperature=0.0, max_tokens=128
)
self.entity_embedding_model = (
entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings())
)
self.function_libraries = [
library if isinstance(library, FunctionCRUD) else LocalCRUD(library)
for library in libraries
Expand Down Expand Up @@ -83,6 +103,7 @@ async def do_infilling(
message: OpasUserMessage,
selected_function: BaseFunction,
args_json_schema: dict,
entities_info: Dict[str, List[Entity]],
) -> Tuple[bool, dict]:
# Perform infilling and generate argument decisions in parallel
arguments_future = asyncio.create_task(
Expand All @@ -91,6 +112,7 @@ async def do_infilling(
self.function_infilling,
message.content,
dependencies.get("chat_history"),
entities_info,
)
)
argument_decisions_future = asyncio.create_task(
Expand Down Expand Up @@ -183,13 +205,23 @@ async def handle_user_plaintext(
await selected_function.get_parameters_json_schema()
)

# perform entity resolution
entities_info = await resolve_entities(
selected_function,
self.function_infilling,
self.entity_embedding_model,
message.content,
dependencies.get("chat_history"),
)

# perform argument infilling
if len(selected_function_arg_json_schema["properties"]) > 0:
complete, arguments = await self.do_infilling(
dependencies,
message,
selected_function,
selected_function_arg_json_schema,
entities_info,
)
else:
complete, arguments = True, {}
Expand Down
17 changes: 14 additions & 3 deletions packages/openassistants/openassistants/functions/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import abc
import dataclasses
import textwrap
from typing import List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence

from langchain.chat_models.base import BaseChatModel
from pydantic import BaseModel

from openassistants.data_models.chat_messages import OpasMessage
from openassistants.data_models.function_output import FunctionOutput
from openassistants.functions.utils import AsyncStreamVersion
from openassistants.utils.json_schema import PyRepr
from pydantic import BaseModel


@dataclasses.dataclass
Expand All @@ -19,6 +18,15 @@ class FunctionExecutionDependency:
summarization_chat_model: BaseChatModel


# One candidate value that a function argument can be resolved to.
class Entity(BaseModel):
# the entity's value; entity resolution uses it as the searchable text
identity: str
# optional extra context; when set, entity resolution appends it in
# parentheses to the searchable text
description: Optional[str] = None


# The set of entities available for a single function parameter.
class EntityConfig(BaseModel):
# every candidate entity; entity resolution indexes and searches this list
entities: List[Entity]


class BaseFunction(BaseModel, abc.ABC):
id: str
type: str
Expand Down Expand Up @@ -60,3 +68,6 @@ def {self.id}({params_repr}) -> pd.DataFrame:

def get_function_name(self) -> str:
    """Return the function's name, which is its ``id`` formatted as text."""
    function_name = format(self.id)
    return function_name

async def get_entity_configs(self) -> Dict[str, EntityConfig]:
    """Return the entity configs for this function, keyed by parameter name.

    The base implementation declares no entities and returns an empty
    mapping.
    """
    no_entities: Dict[str, EntityConfig] = {}
    return no_entities
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio
from typing import Any, Dict, List, Tuple

from langchain.chat_models.base import BaseChatModel
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores.usearch import USearch
from openassistants.data_models.chat_messages import OpasUserMessage
from openassistants.functions.base import BaseFunction, Entity, EntityConfig
from openassistants.llm_function_calling.infilling import generate_arguments


async def _vec_search(
    documents: List[Document],
    query: str,
    embeddings: Embeddings,
    k: int = 3,
) -> List[Document]:
    """Return up to ``k`` documents most similar to ``query``.

    A fresh USearch index is built from ``documents`` on every call and
    queried once.

    Args:
        documents: Candidate documents to index.
        query: Free-text query to search with.
        embeddings: Embedding model handed to the vector store.
        k: Maximum number of matches to request. Defaults to 3, the value
            that was previously hard-coded.

    Returns:
        The documents returned by the similarity search, or an empty list
        when there are no documents to index.
    """
    # An index cannot be built from zero documents, so short-circuit instead
    # of letting the vector store fail on the empty input.
    if not documents:
        return []

    search: USearch = await USearch.afrom_documents(
        embedding=embeddings,
        documents=documents,
    )
    return await search.asimilarity_search(query, k=k)


def entity_to_document(entity: Entity) -> Document:
    """Convert an entity into a ``Document`` for vector search.

    The page content is the entity's identity, followed by its description
    in parentheses when one is set. The full entity dump is stored as the
    document metadata.
    """
    suffix = f" ({entity.description})" if entity.description else ""
    return Document(
        metadata=entity.model_dump(),
        page_content=entity.identity + suffix,
    )


async def _get_entities(
    entity_cfg: EntityConfig,
    entity_key: str,
    preliminary_arguments: Dict[str, Any],
    embeddings: Embeddings,
) -> Tuple[str, List[Entity]]:
    """Find the configured entities closest to one preliminary argument.

    Args:
        entity_cfg: Entities available for the parameter.
        entity_key: Name of the parameter being resolved.
        preliminary_arguments: Arguments from the preliminary infilling pass;
            the value stored under ``entity_key`` is stringified and used as
            the search query.
        embeddings: Embedding model passed through to the vector search.

    Returns:
        ``(entity_key, matches)``, where ``matches`` are entities rebuilt
        from the metadata of the documents the search returned.

    Raises:
        KeyError: If ``entity_key`` is missing from ``preliminary_arguments``.
    """
    candidates = list(map(entity_to_document, entity_cfg.entities))
    search_text = str(preliminary_arguments[entity_key])

    hits = await _vec_search(candidates, search_text, embeddings)

    matches = [Entity(**hit.metadata) for hit in hits]
    return entity_key, matches


async def resolve_entities(
    function: BaseFunction,
    function_infilling_llm: BaseChatModel,
    embeddings: Embeddings,
    user_query: str,
    chat_history: List[OpasUserMessage],
) -> Dict[str, List[Entity]]:
    """Resolve the entities relevant to the user's query, per parameter.

    Runs a preliminary infilling pass with no entity hints, then searches
    each parameter's configured entities with the preliminary value.

    Args:
        function: Function whose entity configs are consulted.
        function_infilling_llm: Chat model used for the preliminary pass.
        embeddings: Embedding model used for the entity searches.
        user_query: The user's message text.
        chat_history: Prior messages, forwarded to the infilling pass.

    Returns:
        A mapping from parameter name to the entities found for it. It is
        empty when the function declares no entity configs. Parameters with
        no preliminary value are left out.
    """
    entity_configs = await function.get_entity_configs()

    # skip if no entity configs
    if not entity_configs:
        return {}

    preliminary_arguments = await generate_arguments(
        function,
        function_infilling_llm,
        user_query,
        chat_history,
        {},
    )

    # The preliminary pass may leave a parameter out or set it to null.
    # Searching for it would raise KeyError or query the literal string
    # "None", so only resolve parameters that actually received a value.
    searchable = {
        param_name: entity_cfg
        for param_name, entity_cfg in entity_configs.items()
        if preliminary_arguments.get(param_name) is not None
    }

    # Run the per-parameter searches concurrently.
    results = await asyncio.gather(
        *(
            _get_entities(entity_cfg, param_name, preliminary_arguments, embeddings)
            for param_name, entity_cfg in searchable.items()
        )
    )

    return dict(results)
Loading

0 comments on commit 6bf1150

Please sign in to comment.