Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/copilot_sdk_flow/agent_arch/aoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def get_azure_openai_client(
azure_endpoint is not None or "AZURE_OPENAI_ENDPOINT" in os.environ
), "azure_endpoint is None, AZURE_OPENAI_ENDPOINT environment variable is required"

logging.info(f"Using Azure OpenAI API version: {api_version}")
# create an AzureOpenAI client using AAD or key based auth
if "AZURE_OPENAI_API_KEY" in os.environ:
logging.warning(
Expand All @@ -25,8 +26,7 @@ def get_azure_openai_client(
aoai_client = AzureOpenAI(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=api_version
or os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
api_version=api_version or os.getenv("AZURE_OPENAI_API_VERSION"),
)
else:
logging.info("Using Azure AD authentification [recommended]")
Expand All @@ -36,8 +36,7 @@ def get_azure_openai_client(
)
aoai_client = AzureOpenAI(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
api_version=api_version
or os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
api_version=api_version or os.getenv("AZURE_OPENAI_API_VERSION"),
azure_ad_token_provider=token_provider,
)
return aoai_client
21 changes: 19 additions & 2 deletions src/copilot_sdk_flow/agent_arch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
from typing import Optional
from typing import Dict
from pydantic import BaseModel
from distutils.util import strtobool


class Configuration(BaseModel):
AZURE_OPENAI_ENDPOINT: str
AZURE_OPENAI_ASSISTANT_ID: str
ORCHESTRATOR_MAX_WAITING_TIME: int = 60
AZURE_OPENAI_API_KEY: Optional[str] = None
AZURE_OPENAI_API_VERSION: Optional[str] = "2024-02-15-preview"
AZURE_OPENAI_API_VERSION: Optional[str] = "2024-05-01-preview"
COMPLETION_INSERT_NOTIFICATIONS: Optional[bool] = False
MAX_COMPLETION_TOKENS: Optional[int] = 1024
MAX_PROMPT_TOKENS: Optional[int] = 2048

@classmethod
def from_env_and_context(cls, context: Dict[str, str]):
Expand Down Expand Up @@ -38,6 +42,19 @@ def from_env_and_context(cls, context: Dict[str, str]):
),
AZURE_OPENAI_API_KEY=os.getenv("AZURE_OPENAI_API_KEY"),
AZURE_OPENAI_API_VERSION=os.getenv(
"AZURE_OPENAI_API_VERSION", "2024-02-15-preview"
"AZURE_OPENAI_API_VERSION", "2024-05-01-preview"
),
COMPLETION_INSERT_NOTIFICATIONS=strtobool(
os.getenv("COMPLETION_INSERT_NOTIFICATIONS", "False")
),
MAX_COMPLETION_TOKENS=int(
context.get("MAX_COMPLETION_TOKENS")
or os.getenv("MAX_COMPLETION_TOKENS")
or "2000"
),
MAX_PROMPT_TOKENS=int(
context.get("MAX_PROMPT_TOKENS")
or os.getenv("MAX_PROMPT_TOKENS")
or "2000"
),
)
35 changes: 35 additions & 0 deletions src/copilot_sdk_flow/agent_arch/event_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from opentelemetry.trace import get_tracer
from opentelemetry.sdk.trace.export import ConsoleSpanExporter


class EventLogger:
TIME_TO_FIRST_TOKEN = "time_to_first_token"
TIME_TO_FIRST_EXTENSION_CALL = "time_to_first_extension_call"
TIME_TO_RUN_LOOP = "time_to_start_run_loop"
TIME_TO_COMPLETE_RUN_LOOP = "time_to_complete_run_loop"

def __init__(self):
self.tracer = get_tracer(__name__)
self.spans = {}
self.completed_spans = {}

def start_span(self, name: str):
if name in self.spans:
return self.spans[name]
else:
span = self.tracer.start_span(name)
self.spans[name] = span
return span

def end_span(self, name: str):
if name in self.spans:
self.spans[name].end()
self.completed_spans[name] = self.spans[name]
del self.spans[name]

def report(self):
return {
name: span.to_json()
for name, span in self.completed_spans.items()
if span.is_recording()
}
Binary file modified src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "query_order_data",
"description": "Run a SQL query against table `order_data` and return the results in JSON format.\nOrder data is stored in a SQLite table with properties:\n\n# Number_of_Orders INTEGER \"the number of orders processed\"\n# Sum_of_Order_Value_USD REAL \"the total value of the orders processed in USD\"\n# Sum_of_Number_of_Items REAL \"the sum of items in the orders processed\"\n# Number_of_Orders_with_Discount INTEGER \"the number of orders that received a discount\"\n# Sum_of_Discount_Percentage REAL \"the sum of discount percentage -- useful to calculate average discounts given\"\n# Sum_of_Shipping_Cost_USD REAL \"the sum of shipping cost for the processed orders\"\n# Number_of_Orders_Returned INTEGER \"the number of orders returned by the customers\"\n# Number_of_Orders_Cancelled INTEGER \"the number or orders cancelled by the customers before they were sent out\"\n# Sum_of_Time_to_Fulfillment REAL \"the sum of time to fulfillment\"\n# Number_of_Orders_Repeat_Customers INTEGER \"number of orders that were placed by repeat customers\"\n# Year INTEGER\n# Month INTEGER\n# Day INTEGER\n# Date TIMESTAMP\n# Day_of_Week INTEGER in 0 based format, Monday is 0, Tuesday is 1, etc.\n# main_category TEXT\n# sub_category TEXT\n# product_type TEXT\n\nIn this table all numbers are already aggregated, so all queries will be some type of aggregation with group by.",
"description": "Run a SQL query against table `order_data` and return the results in JSON format.\nOrder data is stored in a SQLite table.\n",
"parameters": {
"type": "object",
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion src/copilot_sdk_flow/agent_arch/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def from_bytes(cls, content: bytes):

class StepNotification(BaseModel):
type: str
content: str
content: Any
96 changes: 64 additions & 32 deletions src/copilot_sdk_flow/agent_arch/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,24 @@
ExtensionReturnMessage,
StepNotification,
)
from agent_arch.event_log import EventLogger


class Orchestrator:
def __init__(self, config: Configuration, client, session, extensions):
def __init__(
self, config: Configuration, client, session, extensions, event_logger
):
self.client = client
self.config = config
self.session = session
self.extensions = extensions
self.event_logger = event_logger

# getting the Assistant API specific constructs
logging.info(
f"Retrieving assistant with id: {config.AZURE_OPENAI_ASSISTANT_ID}"
)
self.assistant = self.client.beta.assistants.retrieve(
self.assistant = trace(self.client.beta.assistants.retrieve)(
self.config.AZURE_OPENAI_ASSISTANT_ID
)
self.thread = self.session.thread
Expand All @@ -40,15 +44,23 @@ def __init__(self, config: Configuration, client, session, extensions):

@trace
def run_loop(self):
# purge previous messages before run started
self._check_messages(skip=True)

logging.info(f"Creating the run")
self.run = self.client.beta.threads.runs.create(
thread_id=self.thread.id, assistant_id=self.assistant.id
self.run = trace(self.client.beta.threads.runs.create)(
thread_id=self.thread.id,
assistant_id=self.assistant.id,
# max_completion_tokens=self.config.MAX_COMPLETION_TOKENS,
# max_prompt_tokens=self.config.MAX_PROMPT_TOKENS,
)
logging.info(f"Pre loop run status: {self.run.status}")

start_time = time.time()

# loop until max_waiting_time is reached
self.event_logger.end_span(EventLogger.TIME_TO_RUN_LOOP)
self.event_logger.start_span(EventLogger.TIME_TO_COMPLETE_RUN_LOOP)
while (time.time() - start_time) < self.config.ORCHESTRATOR_MAX_WAITING_TIME:
# checks the run regularly
self.run = self.client.beta.threads.runs.retrieve(
Expand All @@ -59,31 +71,17 @@ def run_loop(self):
)

# check if a step has been completed
run_steps = self.client.beta.threads.runs.steps.list(
thread_id=self.thread.id, run_id=self.run.id, after=self.last_step_id
)
for step in run_steps:
logging.info(
"The assistant has moved forward to step {}".format(step.id)
)
self.process_step(step)
self.last_step_id = step.id
# self._check_steps()

# check if there are messages
for message in self.client.beta.threads.messages.list(
thread_id=self.thread.id, order="asc", after=self.last_message_id
):
message = self.client.beta.threads.messages.retrieve(
thread_id=self.thread.id, message_id=message.id
)
self.process_message(message)
# self.session.send(message)
self.last_message_id = message.id
self._check_messages()

if self.run.status == "completed":
logging.info(f"Run completed.")
self.event_logger.end_span(EventLogger.TIME_TO_COMPLETE_RUN_LOOP)
return self.completed()
elif self.run.status == "requires_action":
self.event_logger.end_span(EventLogger.TIME_TO_FIRST_EXTENSION_CALL)
logging.info(f"Run requires action.")
self.requires_action()
elif self.run.status == "cancelled":
Expand All @@ -96,11 +94,33 @@ def run_loop(self):
)
elif self.run.status in ["in_progress", "queued"]:
time.sleep(0.25)
elif self.run.status == "incomplete":
raise ValueError(
f"Run incomplete: {self.run.status}, last_error: {self.run.last_error}"
)
else:
raise ValueError(f"Unknown run status: {self.run.status}")

# @trace
def _check_messages(self, skip=False):
# check if there are messages
for message in self.client.beta.threads.messages.list(
thread_id=self.thread.id, order="asc", after=self.last_message_id
):
if not skip:
message = trace(self.client.beta.threads.messages.retrieve)(
thread_id=self.thread.id, message_id=message.id
)
self._process_message(message)
# self.session.send(message)
self.last_message_id = message.id

@trace
def process_message(self, message):
def _process_message(self, message):
if message.content is None:
raise Exception("Message content is None")
if len(message.content) == 0:
raise Exception("Message content is empty []")
for entry in message.content:
if message.role == "user":
# this means a message we just added
Expand All @@ -117,22 +137,32 @@ def process_message(self, message):
else:
logging.critical("Unknown content type: {}".format(entry.type))

# @trace
def _check_steps(self):
"""Check if there are new steps to process"""
run_steps = self.client.beta.threads.runs.steps.list(
thread_id=self.thread.id, run_id=self.run.id, after=self.last_step_id
)
for step in run_steps:
if step.status == "completed":
logging.info(
"The assistant has moved forward to step {}".format(step.id)
)
self._process_completed_step(step)
self.last_step_id = step.id

@trace
def process_step(self, step):
def _process_completed_step(self, step):
"""Process a step from the run"""
if step.type == "tool_calls":
for tool_call in step.step_details.tool_calls:
if tool_call.type == "code":
if tool_call.type == "code_interpreter":
self.session.send(
StepNotification(
type=step.type, content=str(tool_call.model_dump())
)
StepNotification(type=tool_call.type, content=tool_call)
)
elif tool_call.type == "function":
self.session.send(
StepNotification(
type=step.type, content=str(tool_call.model_dump())
)
StepNotification(type=tool_call.type, content=tool_call)
)
else:
logging.error(f"Unsupported tool call type: {tool_call.type}")
Expand All @@ -142,6 +172,8 @@ def process_step(self, step):
@trace
def completed(self):
"""What to do when run.status == 'completed'"""
# self._check_steps()
self._check_messages()
self.session.close()

@trace
Expand Down Expand Up @@ -194,7 +226,7 @@ def requires_action(self):

if tool_call_outputs:
logging.info(f"Submitting tool outputs: {tool_call_outputs}")
_ = self.client.beta.threads.runs.submit_tool_outputs(
_ = trace(self.client.beta.threads.runs.submit_tool_outputs)(
thread_id=self.thread.id,
run_id=self.run.id,
tool_outputs=tool_call_outputs,
Expand Down
41 changes: 41 additions & 0 deletions src/copilot_sdk_flow/agent_arch/prompts/data_schema.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
### SQLite table with properties:
#
# Number_of_Orders INTEGER "the number of orders processed"
# Sum_of_Order_Value_USD REAL "the total value of the orders processed in USD"
# Sum_of_Number_of_Items REAL "the sum of items in the orders processed"
# Number_of_Orders_with_Discount INTEGER "the number of orders that received a discount"
# Sum_of_Discount_Percentage REAL "the sum of discount percentage -- useful to calculate average discounts given"
# Sum_of_Shipping_Cost_USD REAL "the sum of shipping cost for the processed orders"
# Number_of_Orders_Returned INTEGER "the number of orders returned by the customers"
# Number_of_Orders_Cancelled INTEGER "the number or orders cancelled by the customers before they were sent out"
# Sum_of_Time_to_Fulfillment REAL "the sum of time to fulfillment"
# Number_of_Orders_Repeat_Customers INTEGER "number of orders that were placed by repeat customers"
# Year INTEGER
# Month INTEGER
# Day INTEGER
# Date TIMESTAMP
# Day_of_Week INTEGER in 0 based format, Monday is 0, Tuesday is 1, etc.
# main_category TEXT
# sub_category TEXT
# product_type TEXT
# Region TEXT
#
In this table all numbers are already aggregated, so all queries will be some type of aggregation with group by.

in your reply only provide the query with no extra formatting
never use the AVG() function in SQL, always use SUM() / SUM() to get the average

Note that all categories, i.e. main_category, sub_category, product_type and Region all contain only UPPER CASE values.
So, whenever you are filtering or grouping by these values, make sure to provide the values in UPPER CASE.

When you query for a sub_category, make sure to always provide the main_category as well, for instance:
SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "APPAREL" AND sub_category = "MEN'S CLOTHING" AND Month = 5 AND Year = 2024

When you query for the product_type, make sure to provide the main_category and sub_category as well, for instance:
SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "TRAVEL" AND sub_category = "LUGGAGE & BAGS" AND product_type = "TRAVEL BACKPACKS" AND Month = 5 AND Year = 2024

To avoid issues with apostrophes, when referring to categories, always use double-quotes, for instance:
SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "APPAREL" AND sub_category = "MEN'S CLOTHING" AND Month = 5 AND Year = 2024

Here are the valid values for the Region:
[{"Region":"NORTH AMERICA"},{"Region":"EUROPE"},{"Region":"ASIA-PACIFIC"},{"Region":"AFRICA"},{"Region":"MIDDLE EAST"},{"Region":"SOUTH AMERICA"}]
16 changes: 16 additions & 0 deletions src/copilot_sdk_flow/agent_arch/prompts/system_message.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
You are a helpful assistant that helps the user potentially with the help of some functions.

If you are using multiple tools to solve a user's task, make sure to communicate
information learned from one tool to the next tool.
First, make a plan of how you will use the tools to solve the user's task and communicated
that plan to the user with the first response. Then execute the plan making sure to communicate
the required information between tools since tools only see the information passed to them;
They do not have access to the chat history.

Only use a tool when it is necessary to solve the user's task.
Don't use a tool if you can answer the user's question directly.
Only use the tools provided in the tools list -- don't make up tools!!
If you are not getting the right information from a tool, make sure to ask the user for clarification.
Do not just return the wrong information. Do not make up information.

Anything that would benefit from a tabular presentation should be returned as markup table.
Loading