Skip to content

Commit

Permalink
Merge pull request wise-agents#310 from kabir/custom_yaml_loader
Browse files Browse the repository at this point in the history
Custom yaml loader
  • Loading branch information
fjuma authored Sep 18, 2024
2 parents e7a0eef + 28c3e50 commit bb209e1
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 33 deletions.
16 changes: 2 additions & 14 deletions src/wiseagents/cli/wise_agent_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import threading
import traceback
from typing import List
from wiseagents.yaml import WiseAgentsLoader

import yaml

Expand Down Expand Up @@ -60,23 +61,10 @@ def main():
file_path = input(f'Enter the file path (ENTER for default {default_file_path} ): ')
if not file_path:
file_path = default_file_path
with open(file_path) as stream:
try:
for token in yaml.scan(stream):
if type(token) is yaml.TagToken and token.value[0] == "!":
package_name = ""
for part in token.value[1].split(".")[:-1]:
package_name += part + "."
package_name = package_name[:-1]
print(f'importing {package_name}')
importlib.import_module(package_name)

except yaml.YAMLError as exc:
traceback.print_exc()
with open(file_path) as stream:
try:

for agent in yaml.load_all(stream, Loader=yaml.FullLoader):
for agent in yaml.load_all(stream, Loader=WiseAgentsLoader):
agent : WiseAgent
print(f'Loaded agent: {agent.name}')
if agent.name == "PassThroughClientAgent1":
Expand Down
4 changes: 4 additions & 0 deletions src/wiseagents/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from wiseagents.graphdb import WiseAgentGraphDB
from wiseagents.llm import OpenaiAPIWiseAgentLLM, WiseAgentLLM
from wiseagents.vectordb import WiseAgentVectorDB
from wiseagents.yaml import WiseAgentsLoader
from wiseagents.wise_agent_messaging import WiseAgentMessage, WiseAgentMessageType, WiseAgentTransport, WiseAgentEvent


Expand All @@ -28,6 +29,8 @@ class WiseAgentCollaborationType(Enum):
class WiseAgentTool(yaml.YAMLObject):
''' A WiseAgentTool is an abstract class that represents a tool that can be used by an agent to perform a specific task.'''
yaml_tag = u'!wiseagents.WiseAgentTool'
yaml_loader = WiseAgentsLoader

def __init__(self, name: str, description: str, agent_tool: bool, parameters_json_schema: dict = {},
call_back : Optional[Callable[...,str]] = None):
''' Initialize the tool with the given name, description, agent tool, parameters json schema, and call back.
Expand Down Expand Up @@ -826,6 +829,7 @@ class WiseAgent(yaml.YAMLObject):
''' A WiseAgent is an abstract class that represents an agent that can send and receive messages to and from other agents.
'''
yaml_tag = u'!wiseagents.WiseAgent'
yaml_loader = WiseAgentsLoader

def __new__(cls, *args, **kwargs):
'''Create a new instance of the class, setting default values for the instance variables.'''
Expand Down
2 changes: 2 additions & 0 deletions src/wiseagents/graphdb/wise_agent_graph_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import yaml
from pydantic import BaseModel, Field
from wiseagents.yaml import WiseAgentsLoader


class Entity(BaseModel):
Expand Down Expand Up @@ -68,6 +69,7 @@ class WiseAgentGraphDB(yaml.YAMLObject):
"""Abstract class to define the interface for a WiseAgentGraphDB."""

yaml_tag = u'!WiseAgentGraphDB'
yaml_loader = WiseAgentsLoader

@abstractmethod
def get_schema(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/wiseagents/llm/wise_agent_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import yaml
from openai.types.chat import ChatCompletionMessageParam, ChatCompletion, ChatCompletionToolParam
from wiseagents.yaml import WiseAgentsLoader


class WiseAgentLLM(yaml.YAMLObject):
"""Abstract class to define the interface for a WiseAgentLLM."""
yaml_tag = u'!WiseAgentLLM'
yaml_tag = u'!WiseAgentLLM'
yaml_loader = WiseAgentsLoader
def __init__(self, system_message, model_name):
'''Initialize the agent.
Expand Down
9 changes: 3 additions & 6 deletions src/wiseagents/transports/stomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class StompWiseAgentTransport(WiseAgentTransport):
yaml_tag = u'!wiseagents.transports.StompWiseAgentTransport'
request_conn : stomp.Connection = None
response_conn : stomp.Connection = None

def __init__(self, host: str, port: int, agent_name: str):
'''Initialize the transport.
Expand All @@ -71,14 +72,10 @@ def __repr__(self) -> str:

def __getstate__(self) -> object:
'''Return the state of the transport. Removing the instance variable chain to avoid it is serialized/deserialized by pyyaml.'''
state = self.__dict__.copy()
del state['_request_receiver']
del state['_response_receiver']
del state['_event_receiver']
del state['_error_receiver']
state = super().__getstate__()
del state['request_conn']
del state['response_conn']
return state
return state


def start(self):
Expand Down
4 changes: 4 additions & 0 deletions src/wiseagents/vectordb/wise_agent_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import yaml
from pydantic import BaseModel, Field
from wiseagents.yaml import WiseAgentsLoader

import wiseagents.yaml


class Document(BaseModel):
Expand All @@ -22,6 +25,7 @@ class Document(BaseModel):
class WiseAgentVectorDB(yaml.YAMLObject):
"""Abstract class to define the interface for a WiseAgentVectorDB."""
yaml_tag = u'!WiseAgentVectorDB'
yaml_loader = WiseAgentsLoader

@abstractmethod
def get_or_create_collection(self, collection_name: str):
Expand Down
15 changes: 13 additions & 2 deletions src/wiseagents/wise_agent_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Optional

from yaml import YAMLObject

from wiseagents.yaml import WiseAgentsLoader

class WiseAgentMessageType(Enum):
ACK = auto()
Expand Down Expand Up @@ -94,7 +94,8 @@ def route_response_to(self) -> str:
return self._route_response_to

class WiseAgentTransport(YAMLObject):

yaml_loader = WiseAgentsLoader

''' A transport for sending messages between agents. '''
def set_call_backs(self, request_receiver: Optional[Callable[[], WiseAgentMessage]] = None,
event_receiver: Optional[Callable[[], WiseAgentEvent]] = None,
Expand All @@ -112,6 +113,16 @@ def set_call_backs(self, request_receiver: Optional[Callable[[], WiseAgentMessag
self._event_receiver = event_receiver
self._error_receiver = error_receiver
self._response_receiver = response_receiver

def __getstate__(self) -> object:
'''Return the state of the transport. Removing the instance variable chain to avoid it is serialized/deserialized by pyyaml.'''
state = self.__dict__.copy()
del state['_request_receiver']
del state['_response_receiver']
del state['_event_receiver']
del state['_error_receiver']
return state


@abstractmethod
def start(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# This is the __init__.py file for the wiseagents.yaml_env package
# This is the __init__.py file for the wiseagents.yaml package

# Import any modules or subpackages here
from .yaml_utils import setup_yaml_for_env_vars
from .wise_yaml_loader import WiseAgentsLoader


# Define any necessary initialization code here

# Optionally, you can define __all__ to specify the public interface of the package
# __all__ = ['module1', 'module2', 'subpackage']
__all__ = ['setup_yaml_for_env_vars']
__all__ = ['setup_yaml_for_env_vars', 'WiseAgentsLoader']
55 changes: 55 additions & 0 deletions src/wiseagents/yaml/wise_yaml_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import importlib
import yaml

from yaml.reader import Reader
from yaml.scanner import Scanner
from yaml.parser import Parser
from yaml.composer import Composer
from yaml.constructor import FullConstructor
from yaml.resolver import Resolver


class WiseAgentsLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver):

def __init__(self, stream):
opened_file = False
try:
stream_copy = None
if isinstance(stream, str):
stream_copy = "" + stream
elif isinstance(stream, bytes):
stream_copy = b"" + stream
else:
opened_file = True
stream_copy = open(getattr(stream, 'name', "<file>"))

Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
FullConstructor.__init__(self)
Resolver.__init__(self)

seen_classes = {}
seen_packages = {}

for token in yaml.scan(stream_copy):
if type(token) is yaml.TagToken and token.value[0] == "!":
if token.value in seen_classes.keys():
continue
seen_classes[token.value] = True
package_name = ""
for part in token.value[1].split(".")[:-1]:
package_name += part + "."
package_name = package_name[:-1]
if package_name in seen_packages.values():
continue
seen_packages[package_name] = True
importlib.import_module(package_name)

finally:
if opened_file:
stream_copy.close()

def construct_document(self, node):
return super().construct_document(node)
File renamed without changes.
5 changes: 3 additions & 2 deletions tests/wiseagents/test_yaml_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import yaml

from wiseagents import WiseAgent, WiseAgentRegistry
from wiseagents.yaml import WiseAgentsLoader


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -19,7 +20,7 @@ def test_using_deserialized_agent():
# Create a WiseAgent object
with open(pathlib.Path().resolve() / "tests/wiseagents/test.yaml") as stream:
try:
deserialized_agent = yaml.load(stream, Loader=yaml.Loader)
deserialized_agent = yaml.load(stream, Loader=WiseAgentsLoader)
except yaml.YAMLError as exc:
print(exc)
# Assert that the serialized agent can be deserialized back to a WiseAgent object
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_using_multiple_deserialized_agents():
deserialized_agent = []
with open(pathlib.Path().resolve() / "tests/wiseagents/test-multiple.yaml") as stream:
try:
for agent in yaml.load_all(stream, Loader=yaml.Loader):
for agent in yaml.load_all(stream, Loader=WiseAgentsLoader):
deserialized_agent.append(agent)
except yaml.YAMLError as exc:
print(exc)
Expand Down
8 changes: 5 additions & 3 deletions tests/wiseagents/test_yaml_serializtion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from wiseagents.graphdb import Neo4jLangChainWiseAgentGraphDB
from wiseagents.llm import OpenaiAPIWiseAgentLLM
from wiseagents.vectordb import PGVectorLangChainWiseAgentVectorDB
from wiseagents.yaml import WiseAgentsLoader


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -16,6 +17,7 @@ def run_after_all_tests():


class DummyTransport(WiseAgentTransport):
yaml_tag = "!tests.wiseagents.test_yaml_serializtion.DummyTransport"
def __init__(self):
pass

Expand Down Expand Up @@ -60,7 +62,7 @@ def test_serialize_wise_agent(monkeypatch):
logging.debug(serialized_agent)

# Assert that the serialized agent can be deserialized back to a WiseAgent object
deserialized_agent = yaml.load(serialized_agent, Loader=yaml.Loader)
deserialized_agent = yaml.load(serialized_agent, Loader=WiseAgentsLoader)
assert isinstance(deserialized_agent, WiseAgent)
assert deserialized_agent.name == agent.name
assert deserialized_agent.description == agent.description
Expand Down Expand Up @@ -110,7 +112,7 @@ def test_using_deserialized_agent(monkeypatch):
logging.debug(serialized_agent)

# Assert that the serialized agent can be deserialized back to a WiseAgent object
deserialized_agent = yaml.load(serialized_agent, Loader=yaml.Loader)
deserialized_agent = yaml.load(serialized_agent, Loader=WiseAgentsLoader)
assert isinstance(deserialized_agent, WiseAgent)
assert deserialized_agent.name == agent.name
assert deserialized_agent.description == agent.description
Expand All @@ -135,6 +137,6 @@ def test_serialize_assistant():
assistant = AssistantAgent(name="Assistant1", description="This is a test assistant", transport=DummyTransport(), destination_agent_name="")
serialized_assistant = yaml.dump(assistant)
logging.info(serialized_assistant)
deserialized_agent = yaml.load(serialized_assistant, Loader=yaml.Loader)
deserialized_agent = yaml.load(serialized_assistant, Loader=WiseAgentsLoader)
finally:
assistant.stop_agent()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This is the __init__.py file for the wiseagents.yaml_env package
# This is the __init__.py file for the wiseagents.yaml package

# Import any modules or subpackages here

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import yaml

# This appears to be unused but actually does something!
import wiseagents.yaml_env
import wiseagents.yaml

wiseagents.yaml_env.setup_yaml_for_env_vars()
wiseagents.yaml.setup_yaml_for_env_vars()


def test_yaml_env_var_not_set():
Expand Down

0 comments on commit bb209e1

Please sign in to comment.