This repository has been archived by the owner on Aug 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* do a preliminary infilling pass to generate "preliminary arguments" * search a entities store using the "preliminary arguments" to get a list of relevant entities * pass list of relevant entities as json schema enum to "final infilling"
- Loading branch information
1 parent
7e3bc71
commit 6bf1150
Showing
11 changed files
with
348 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
id,name,email | ||
1,Richard Hendricks,[email protected] | ||
2,Erlich Bachman,[email protected] | ||
3,Dinesh Chugtai,[email protected] | ||
4,Bertram Gilfoyle,[email protected] | ||
5,Jared Dunn,[email protected] | ||
6,Monica Hall,[email protected] | ||
7,Gavin Belson,[email protected] | ||
id,name,email,role | ||
1,Richard Hendricks,[email protected],Pied Piper CEO | ||
2,Erlich Bachman,[email protected], | ||
3,Dinesh Chugtai,[email protected], | ||
4,Bertram Gilfoyle,[email protected], | ||
5,Jared Dunn,[email protected],Pied Piper COO | ||
6,Monica Hall,[email protected], | ||
7,Gavin Belson,[email protected],Hooli CEO |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,67 @@ | ||
from typing import Sequence | ||
|
||
import anyio | ||
import pandas as pd | ||
from openassistants.contrib.python_callable import PythonCallableFunction | ||
from openassistants.data_models.function_input import BaseJSONSchema | ||
from openassistants.data_models.function_output import FunctionOutput, TextOutput | ||
from openassistants.functions.base import FunctionExecutionDependency | ||
from openassistants.functions.base import ( | ||
Entity, | ||
EntityConfig, | ||
FunctionExecutionDependency, | ||
) | ||
from openassistants.functions.utils import AsyncStreamVersion | ||
|
||
|
||
async def find_email_by_name_callable( | ||
async def _execute( | ||
deps: FunctionExecutionDependency, | ||
) -> AsyncStreamVersion[Sequence[FunctionOutput]]: | ||
""" | ||
user entities are: | ||
name | email | ||
richard | [email protected] | ||
... | ||
""" | ||
|
||
name = deps.arguments["name"] | ||
|
||
# load csv | ||
df = pd.read_csv("dummy-data/employees.csv") | ||
df = await anyio.to_thread.run_sync(pd.read_csv, "dummy-data/employees.csv") | ||
|
||
# find email where csv.name == name | ||
email = df[df["name"] == name]["email"].iloc[0] | ||
filtered = df[df["name"] == name]["email"] | ||
|
||
email = None if len(filtered) == 0 else filtered.iloc[0] | ||
|
||
if email is None: | ||
yield [ | ||
TextOutput(text=f"Could not find email for: {name}"), | ||
] | ||
|
||
else: | ||
yield [TextOutput(text=f"Found Email For: {name} ({email})")] | ||
|
||
|
||
async def _get_entity_configs() -> dict[str, EntityConfig]: | ||
df = await anyio.to_thread.run_sync(pd.read_csv, "dummy-data/employees.csv") | ||
|
||
records = df.to_dict("records") | ||
|
||
yield [TextOutput(text=f"Found Email For: {name} ({email})")] | ||
return { | ||
"name": EntityConfig( | ||
entities=[ | ||
Entity( | ||
identity=row["name"], | ||
description=row["role"] if isinstance(row["role"], str) else None, | ||
) | ||
for row in records | ||
], | ||
) | ||
} | ||
|
||
|
||
find_email_by_name_function = PythonCallableFunction( | ||
id="find_email", | ||
type="FindEmailFunction", | ||
display_name="Find Email", | ||
description="Find an email address", | ||
sample_questions=["Find the email address for {employee}"], | ||
sample_questions=[ | ||
"Find the email address for {employee}", | ||
"What is {employee}'s email address?", | ||
], | ||
parameters=BaseJSONSchema( | ||
json_schema={ | ||
"type": "object", | ||
|
@@ -47,5 +74,6 @@ async def find_email_by_name_callable( | |
"required": ["name"], | ||
} | ||
), | ||
callable=find_email_by_name_callable, | ||
execute_callable=_execute, | ||
get_entity_configs_callable=_get_entity_configs, | ||
) |
20 changes: 15 additions & 5 deletions
20
packages/openassistants/openassistants/contrib/python_callable.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,32 @@ | ||
from typing import Callable, Sequence | ||
from typing import Awaitable, Callable, Dict, Sequence | ||
|
||
from openassistants.data_models.function_input import BaseJSONSchema | ||
from openassistants.data_models.function_output import FunctionOutput | ||
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency | ||
from openassistants.functions.base import ( | ||
BaseFunction, | ||
EntityConfig, | ||
FunctionExecutionDependency, | ||
) | ||
from openassistants.functions.utils import AsyncStreamVersion | ||
|
||
|
||
class PythonCallableFunction(BaseFunction): | ||
callable: Callable[ | ||
execute_callable: Callable[ | ||
[FunctionExecutionDependency], AsyncStreamVersion[Sequence[FunctionOutput]] | ||
] | ||
|
||
parameters: BaseJSONSchema | ||
|
||
get_entity_configs_callable: Callable[[], Awaitable[dict[str, EntityConfig]]] | ||
|
||
async def execute( | ||
self, deps: FunctionExecutionDependency | ||
) -> AsyncStreamVersion[Sequence[FunctionOutput]]: | ||
async for version in self.callable(deps): | ||
yield version | ||
async for output in self.execute_callable(deps): | ||
yield output | ||
|
||
async def get_parameters_json_schema(self) -> dict: | ||
return self.parameters.json_schema | ||
|
||
async def get_entity_configs(self) -> Dict[str, EntityConfig]: | ||
return await self.get_entity_configs_callable() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
packages/openassistants/openassistants/llm_function_calling/entity_resolution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import asyncio | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.embeddings.base import Embeddings | ||
from langchain.schema import Document | ||
from langchain.vectorstores.usearch import USearch | ||
from openassistants.data_models.chat_messages import OpasUserMessage | ||
from openassistants.functions.base import BaseFunction, Entity, EntityConfig | ||
from openassistants.llm_function_calling.infilling import generate_arguments | ||
|
||
|
||
async def _vec_search( | ||
documents: List[Document], | ||
query: str, | ||
embeddings: Embeddings, | ||
) -> List[Document]: | ||
search: USearch = await USearch.afrom_documents( | ||
embedding=embeddings, | ||
documents=documents, | ||
) | ||
results = await search.asimilarity_search( | ||
query, | ||
k=3, | ||
) | ||
return results | ||
|
||
|
||
def entity_to_document(entity: Entity) -> Document: | ||
doc = Document(metadata=entity.model_dump(), page_content=entity.identity) | ||
|
||
if entity.description: | ||
doc.page_content += f" ({entity.description})" | ||
|
||
return doc | ||
|
||
|
||
async def _get_entities( | ||
entity_cfg: EntityConfig, | ||
entity_key: str, | ||
preliminary_arguments: Dict[str, Any], | ||
embeddings: Embeddings, | ||
) -> Tuple[str, List[Entity]]: | ||
documents = [entity_to_document(entity) for entity in entity_cfg.entities] | ||
|
||
query = str(preliminary_arguments[entity_key]) | ||
|
||
vec_result = await _vec_search(documents, query, embeddings) | ||
|
||
return entity_key, [Entity(**r.metadata) for r in vec_result] | ||
|
||
|
||
async def resolve_entities( | ||
function: BaseFunction, | ||
function_infilling_llm: BaseChatModel, | ||
embeddings: Embeddings, | ||
user_query: str, | ||
chat_history: List[OpasUserMessage], | ||
) -> Dict[str, List[Entity]]: | ||
entity_configs = await function.get_entity_configs() | ||
|
||
# skip if no entity configs | ||
if len(entity_configs) == 0: | ||
return {} | ||
|
||
preliminary_arguments = await generate_arguments( | ||
function, | ||
function_infilling_llm, | ||
user_query, | ||
chat_history, | ||
{}, | ||
) | ||
|
||
results = await asyncio.gather( | ||
*[ | ||
_get_entities(entity_cfg, param_name, preliminary_arguments, embeddings) | ||
for param_name, entity_cfg in entity_configs.items() | ||
] | ||
) | ||
|
||
return {key: entities for key, entities in results} |
Oops, something went wrong.