diff --git a/src/copilot_sdk_flow/agent_arch/aoai.py b/src/copilot_sdk_flow/agent_arch/aoai.py index a7e0ec6..b3e42a0 100644 --- a/src/copilot_sdk_flow/agent_arch/aoai.py +++ b/src/copilot_sdk_flow/agent_arch/aoai.py @@ -17,6 +17,7 @@ def get_azure_openai_client( azure_endpoint is not None or "AZURE_OPENAI_ENDPOINT" in os.environ ), "azure_endpoint is None, AZURE_OPENAI_ENDPOINT environment variable is required" + logging.info(f"Using Azure OpenAI API version: {api_version}") # create an AzureOpenAI client using AAD or key based auth if "AZURE_OPENAI_API_KEY" in os.environ: logging.warning( @@ -25,8 +26,7 @@ def get_azure_openai_client( aoai_client = AzureOpenAI( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], api_key=os.environ["AZURE_OPENAI_API_KEY"], - api_version=api_version - or os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"), + api_version=api_version or os.getenv("AZURE_OPENAI_API_VERSION"), ) else: logging.info("Using Azure AD authentification [recommended]") @@ -36,8 +36,7 @@ def get_azure_openai_client( ) aoai_client = AzureOpenAI( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], - api_version=api_version - or os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"), + api_version=api_version or os.getenv("AZURE_OPENAI_API_VERSION"), azure_ad_token_provider=token_provider, ) return aoai_client diff --git a/src/copilot_sdk_flow/agent_arch/config.py b/src/copilot_sdk_flow/agent_arch/config.py index adf50bb..f2c2dec 100644 --- a/src/copilot_sdk_flow/agent_arch/config.py +++ b/src/copilot_sdk_flow/agent_arch/config.py @@ -3,6 +3,7 @@ from typing import Optional from typing import Dict from pydantic import BaseModel +from distutils.util import strtobool class Configuration(BaseModel): @@ -10,7 +11,10 @@ class Configuration(BaseModel): AZURE_OPENAI_ASSISTANT_ID: str ORCHESTRATOR_MAX_WAITING_TIME: int = 60 AZURE_OPENAI_API_KEY: Optional[str] = None - AZURE_OPENAI_API_VERSION: Optional[str] = "2024-02-15-preview" + AZURE_OPENAI_API_VERSION: Optional[str] = "2024-05-01-preview" + COMPLETION_INSERT_NOTIFICATIONS: Optional[bool] = False + MAX_COMPLETION_TOKENS: Optional[int] = 1024 + MAX_PROMPT_TOKENS: Optional[int] = 2048 @classmethod def from_env_and_context(cls, context: Dict[str, str]): @@ -38,6 +42,19 @@ def from_env_and_context(cls, context: Dict[str, str]): ), AZURE_OPENAI_API_KEY=os.getenv("AZURE_OPENAI_API_KEY"), AZURE_OPENAI_API_VERSION=os.getenv( - "AZURE_OPENAI_API_VERSION", "2024-02-15-preview" + "AZURE_OPENAI_API_VERSION", "2024-05-01-preview" + ), + COMPLETION_INSERT_NOTIFICATIONS=strtobool( + os.getenv("COMPLETION_INSERT_NOTIFICATIONS", "False") + ), + MAX_COMPLETION_TOKENS=int( + context.get("MAX_COMPLETION_TOKENS") + or os.getenv("MAX_COMPLETION_TOKENS") + or "2000" + ), + MAX_PROMPT_TOKENS=int( + context.get("MAX_PROMPT_TOKENS") + or os.getenv("MAX_PROMPT_TOKENS") + or "2000" ), ) diff --git a/src/copilot_sdk_flow/agent_arch/event_log.py b/src/copilot_sdk_flow/agent_arch/event_log.py new file mode 100644 index 0000000..956d02e --- /dev/null +++ b/src/copilot_sdk_flow/agent_arch/event_log.py @@ -0,0 +1,35 @@ +from opentelemetry.trace import get_tracer +from opentelemetry.sdk.trace.export import ConsoleSpanExporter + + +class EventLogger: + TIME_TO_FIRST_TOKEN = "time_to_first_token" + TIME_TO_FIRST_EXTENSION_CALL = "time_to_first_extension_call" + TIME_TO_RUN_LOOP = "time_to_start_run_loop" + TIME_TO_COMPLETE_RUN_LOOP = "time_to_complete_run_loop" + + def __init__(self): + self.tracer = get_tracer(__name__) + self.spans = {} + self.completed_spans = 
{} + + def start_span(self, name: str): + if name in self.spans: + return self.spans[name] + else: + span = self.tracer.start_span(name) + self.spans[name] = span + return span + + def end_span(self, name: str): + if name in self.spans: + self.spans[name].end() + self.completed_spans[name] = self.spans[name] + del self.spans[name] + + def report(self): + return { + name: span.to_json() + for name, span in self.completed_spans.items() + if span.is_recording() + } diff --git a/src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db b/src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db index e2d9d04..a37ffa5 100644 Binary files a/src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db and b/src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db differ diff --git a/src/copilot_sdk_flow/agent_arch/extensions/query_order_data.json b/src/copilot_sdk_flow/agent_arch/extensions/query_order_data.json index 0270ec7..5a26a07 100644 --- a/src/copilot_sdk_flow/agent_arch/extensions/query_order_data.json +++ b/src/copilot_sdk_flow/agent_arch/extensions/query_order_data.json @@ -1,6 +1,6 @@ { "name": "query_order_data", - "description": "Run a SQL query against table `order_data` and return the results in JSON format.\nOrder data is stored in a SQLite table with properties:\n\n# Number_of_Orders INTEGER \"the number of orders processed\"\n# Sum_of_Order_Value_USD REAL \"the total value of the orders processed in USD\"\n# Sum_of_Number_of_Items REAL \"the sum of items in the orders processed\"\n# Number_of_Orders_with_Discount INTEGER \"the number of orders that received a discount\"\n# Sum_of_Discount_Percentage REAL \"the sum of discount percentage -- useful to calculate average discounts given\"\n# Sum_of_Shipping_Cost_USD REAL \"the sum of shipping cost for the processed orders\"\n# Number_of_Orders_Returned INTEGER \"the number of orders returned by the customers\"\n# Number_of_Orders_Cancelled INTEGER \"the number or orders cancelled by the customers before they were sent out\"\n# Sum_of_Time_to_Fulfillment REAL \"the sum of time to fulfillment\"\n# Number_of_Orders_Repeat_Customers INTEGER \"number of orders that were placed by repeat customers\"\n# Year INTEGER\n# Month INTEGER\n# Day INTEGER\n# Date TIMESTAMP\n# Day_of_Week INTEGER in 0 based format, Monday is 0, Tuesday is 1, etc.\n# main_category TEXT\n# sub_category TEXT\n# product_type TEXT\n\nIn this table all numbers are already aggregated, so all queries will be some type of aggregation with group by.", + "description": "Run a SQL query against table `order_data` and return the results in JSON format.\nOrder data is stored in a SQLite table.\n", "parameters": { "type": "object", "properties": { diff --git a/src/copilot_sdk_flow/agent_arch/messages.py b/src/copilot_sdk_flow/agent_arch/messages.py index f40243a..adba527 100644 --- a/src/copilot_sdk_flow/agent_arch/messages.py +++ b/src/copilot_sdk_flow/agent_arch/messages.py @@ -33,4 +33,4 @@ def from_bytes(cls, content: bytes): class StepNotification(BaseModel): type: str - content: str + content: Any diff --git a/src/copilot_sdk_flow/agent_arch/orchestrator.py b/src/copilot_sdk_flow/agent_arch/orchestrator.py index 987d04c..7b0be4d 100644 --- a/src/copilot_sdk_flow/agent_arch/orchestrator.py +++ b/src/copilot_sdk_flow/agent_arch/orchestrator.py @@ -14,20 +14,24 @@ ExtensionReturnMessage, StepNotification, ) +from agent_arch.event_log import EventLogger class Orchestrator: - def __init__(self, config: Configuration, client, session, extensions): + def __init__( + 
self, config: Configuration, client, session, extensions, event_logger + ): self.client = client self.config = config self.session = session self.extensions = extensions + self.event_logger = event_logger # getting the Assistant API specific constructs logging.info( f"Retrieving assistant with id: {config.AZURE_OPENAI_ASSISTANT_ID}" ) - self.assistant = self.client.beta.assistants.retrieve( + self.assistant = trace(self.client.beta.assistants.retrieve)( self.config.AZURE_OPENAI_ASSISTANT_ID ) self.thread = self.session.thread @@ -40,15 +44,23 @@ def __init__(self, config: Configuration, client, session, extensions): @trace def run_loop(self): + # purge previous messages before run started + self._check_messages(skip=True) + logging.info(f"Creating the run") - self.run = self.client.beta.threads.runs.create( - thread_id=self.thread.id, assistant_id=self.assistant.id + self.run = trace(self.client.beta.threads.runs.create)( + thread_id=self.thread.id, + assistant_id=self.assistant.id, + # max_completion_tokens=self.config.MAX_COMPLETION_TOKENS, + # max_prompt_tokens=self.config.MAX_PROMPT_TOKENS, ) logging.info(f"Pre loop run status: {self.run.status}") start_time = time.time() # loop until max_waiting_time is reached + self.event_logger.end_span(EventLogger.TIME_TO_RUN_LOOP) + self.event_logger.start_span(EventLogger.TIME_TO_COMPLETE_RUN_LOOP) while (time.time() - start_time) < self.config.ORCHESTRATOR_MAX_WAITING_TIME: # checks the run regularly self.run = self.client.beta.threads.runs.retrieve( @@ -59,31 +71,17 @@ def run_loop(self): ) # check if a step has been completed - run_steps = self.client.beta.threads.runs.steps.list( - thread_id=self.thread.id, run_id=self.run.id, after=self.last_step_id - ) - for step in run_steps: - logging.info( - "The assistant has moved forward to step {}".format(step.id) - ) - self.process_step(step) - self.last_step_id = step.id + # self._check_steps() # check if there are messages - for message in self.client.beta.threads.messages.list( - thread_id=self.thread.id, order="asc", after=self.last_message_id - ): - message = self.client.beta.threads.messages.retrieve( - thread_id=self.thread.id, message_id=message.id - ) - self.process_message(message) - # self.session.send(message) - self.last_message_id = message.id + self._check_messages() if self.run.status == "completed": logging.info(f"Run completed.") + self.event_logger.end_span(EventLogger.TIME_TO_COMPLETE_RUN_LOOP) return self.completed() elif self.run.status == "requires_action": + self.event_logger.end_span(EventLogger.TIME_TO_FIRST_EXTENSION_CALL) logging.info(f"Run requires action.") self.requires_action() elif self.run.status == "cancelled": @@ -96,11 +94,33 @@ def run_loop(self): ) elif self.run.status in ["in_progress", "queued"]: time.sleep(0.25) + elif self.run.status == "incomplete": + raise ValueError( + f"Run incomplete: {self.run.status}, last_error: {self.run.last_error}" + ) else: raise ValueError(f"Unknown run status: {self.run.status}") + # @trace + def _check_messages(self, skip=False): + # check if there are messages + for message in self.client.beta.threads.messages.list( + thread_id=self.thread.id, order="asc", after=self.last_message_id + ): + if not skip: + message = trace(self.client.beta.threads.messages.retrieve)( + thread_id=self.thread.id, message_id=message.id + ) + self._process_message(message) + # self.session.send(message) + self.last_message_id = message.id + @trace - def process_message(self, message): + def _process_message(self, message): + if message.content is 
None:
+            raise Exception("Message content is None")
+        if len(message.content) == 0:
+            raise Exception("Message content is empty []")
         for entry in message.content:
             if message.role == "user":
                 # this means a message we just added
@@ -117,22 +137,32 @@
             else:
                 logging.critical("Unknown content type: {}".format(entry.type))
 
+    # @trace
+    def _check_steps(self):
+        """Check if there are new steps to process"""
+        run_steps = self.client.beta.threads.runs.steps.list(
+            thread_id=self.thread.id, run_id=self.run.id, after=self.last_step_id
+        )
+        for step in run_steps:
+            if step.status == "completed":
+                logging.info(
+                    "The assistant has moved forward to step {}".format(step.id)
+                )
+                self._process_completed_step(step)
+            self.last_step_id = step.id
+
     @trace
-    def process_step(self, step):
+    def _process_completed_step(self, step):
         """Process a step from the run"""
         if step.type == "tool_calls":
             for tool_call in step.step_details.tool_calls:
-                if tool_call.type == "code":
+                if tool_call.type == "code_interpreter":
                     self.session.send(
-                        StepNotification(
-                            type=step.type, content=str(tool_call.model_dump())
-                        )
+                        StepNotification(type=tool_call.type, content=tool_call)
                     )
                 elif tool_call.type == "function":
                     self.session.send(
-                        StepNotification(
-                            type=step.type, content=str(tool_call.model_dump())
-                        )
+                        StepNotification(type=tool_call.type, content=tool_call)
                     )
                 else:
                     logging.error(f"Unsupported tool call type: {tool_call.type}")
@@ -142,6 +172,8 @@ def completed(self):
     @trace
     def completed(self):
         """What to do when run.status == 'completed'"""
+        # self._check_steps()
+        self._check_messages()
         self.session.close()
 
     @trace
@@ -194,7 +226,7 @@ def requires_action(self):
 
         if tool_call_outputs:
             logging.info(f"Submitting tool outputs: {tool_call_outputs}")
-            _ = self.client.beta.threads.runs.submit_tool_outputs(
+            _ = trace(self.client.beta.threads.runs.submit_tool_outputs)(
                 thread_id=self.thread.id,
                 run_id=self.run.id,
                 tool_outputs=tool_call_outputs,
diff --git a/src/copilot_sdk_flow/agent_arch/prompts/data_schema.jinja2 b/src/copilot_sdk_flow/agent_arch/prompts/data_schema.jinja2
new file mode 100644
index 0000000..5fde56d
--- /dev/null
+++ b/src/copilot_sdk_flow/agent_arch/prompts/data_schema.jinja2
@@ -0,0 +1,41 @@
+### SQLite table with properties:
+ #
+ # Number_of_Orders INTEGER "the number of orders processed"
+ # Sum_of_Order_Value_USD REAL "the total value of the orders processed in USD"
+ # Sum_of_Number_of_Items REAL "the sum of items in the orders processed"
+ # Number_of_Orders_with_Discount INTEGER "the number of orders that received a discount"
+ # Sum_of_Discount_Percentage REAL "the sum of discount percentage -- useful to calculate average discounts given"
+ # Sum_of_Shipping_Cost_USD REAL "the sum of shipping cost for the processed orders"
+ # Number_of_Orders_Returned INTEGER "the number of orders returned by the customers"
+ # Number_of_Orders_Cancelled INTEGER "the number of orders cancelled by the customers before they were sent out"
+ # Sum_of_Time_to_Fulfillment REAL "the sum of time to fulfillment"
+ # Number_of_Orders_Repeat_Customers INTEGER "number of orders that were placed by repeat customers"
+ # Year INTEGER
+ # Month INTEGER
+ # Day INTEGER
+ # Date TIMESTAMP
+ # Day_of_Week INTEGER in 0 based format, Monday is 0, Tuesday is 1, etc.
+ # main_category TEXT
+ # sub_category TEXT
+ # product_type TEXT
+ # Region TEXT
+ #
+In this table all numbers are already aggregated, so all queries will be some type of aggregation with group by.
+
+In your reply, only provide the query with no extra formatting.
+Never use the AVG() function in SQL; always use SUM() / SUM() to get the average.
+
+Note that all categories, i.e. main_category, sub_category, product_type and Region, contain only UPPER CASE values.
+So, whenever you are filtering or grouping by these values, make sure to provide the values in UPPER CASE.
+
+When you query for a sub_category, make sure to always provide the main_category as well, for instance:
+SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "APPAREL" AND sub_category = "MEN'S CLOTHING" AND Month = 5 AND Year = 2024
+
+When you query for the product_type, make sure to provide the main_category and sub_category as well, for instance:
+SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "TRAVEL" AND sub_category = "LUGGAGE & BAGS" AND product_type = "TRAVEL BACKPACKS" AND Month = 5 AND Year = 2024
+
+To avoid issues with apostrophes, when referring to categories, always use double-quotes, for instance:
+SELECT SUM(Number_of_Orders) FROM order_data WHERE main_category = "APPAREL" AND sub_category = "MEN'S CLOTHING" AND Month = 5 AND Year = 2024
+
+Here are the valid values for the Region:
+[{"Region":"NORTH AMERICA"},{"Region":"EUROPE"},{"Region":"ASIA-PACIFIC"},{"Region":"AFRICA"},{"Region":"MIDDLE EAST"},{"Region":"SOUTH AMERICA"}]
diff --git a/src/copilot_sdk_flow/agent_arch/prompts/system_message.jinja2 b/src/copilot_sdk_flow/agent_arch/prompts/system_message.jinja2
new file mode 100644
index 0000000..6dd6615
--- /dev/null
+++ b/src/copilot_sdk_flow/agent_arch/prompts/system_message.jinja2
@@ -0,0 +1,16 @@
+You are a helpful assistant that helps the user, potentially with the aid of some functions.
+
+If you are using multiple tools to solve a user's task, make sure to communicate
+information learned from one tool to the next tool.
+First, make a plan of how you will use the tools to solve the user's task and communicate
+that plan to the user with the first response. Then execute the plan, making sure to communicate
+the required information between tools since tools only see the information passed to them;
+they do not have access to the chat history.
+
+Only use a tool when it is necessary to solve the user's task.
+Don't use a tool if you can answer the user's question directly.
+Only use the tools provided in the tools list -- don't make up tools!!
+If you are not getting the right information from a tool, make sure to ask the user for clarification.
+Do not just return the wrong information. Do not make up information.
+
+Anything that would benefit from a tabular presentation should be returned as a markdown table.
diff --git a/src/copilot_sdk_flow/agent_arch/sessions.py b/src/copilot_sdk_flow/agent_arch/sessions.py
index eb486ae..f762a2e 100644
--- a/src/copilot_sdk_flow/agent_arch/sessions.py
+++ b/src/copilot_sdk_flow/agent_arch/sessions.py
@@ -14,21 +14,24 @@
     TextResponse,
     ImageResponse,
 )
+from agent_arch.config import Configuration
 
 
 class Session:
     """Represents a session with the assistant."""
 
-    def __init__(self, thread: Thread, client: AzureOpenAI):
+    def __init__(self, thread: Thread, client: AzureOpenAI, config: Configuration):
         """Initializes a new session with the assistant.
 
         Args:
             thread (Thread): The thread associated with the session.
             client (AzureOpenAI): The AzureOpenAI client.
+            config (Configuration): The configuration.
""" self.id = thread.id self.thread = thread self.client = client + self.config = config self.output_queue = deque() self.open = True @@ -64,30 +67,43 @@ def send(self, message: Any): Args: message (Any): The message to send. """ - if isinstance(message, ExtensionCallMessage): + output_message = None # if nothing works, we do not output anything + + if ( + isinstance(message, ExtensionCallMessage) + and self.config.COMPLETION_INSERT_NOTIFICATIONS + ): if message.name == "query_order_data": output_message = f"_Calling extension `{message.name}` with SQL query:_\n```sql\n{message.args['sql_query']}\n```\n\n" else: output_message = f"_Calling extension `{message.name}`_\n\n" - elif isinstance(message, ExtensionReturnMessage): + elif ( + isinstance(message, ExtensionReturnMessage) + and self.config.COMPLETION_INSERT_NOTIFICATIONS + ): # output_message = f"_Extension `{message.name}` returned: `{message.content}`_\n\n" output_message = None - elif isinstance(message, StepNotification): + elif ( + isinstance(message, StepNotification) + and self.config.COMPLETION_INSERT_NOTIFICATIONS + ): + if message.type == "code_interpreter": + output_message = f"_Called extension `code_interpreter` with code:\n```python\n{message.content.code_interpreter.input}```_\n" + else: + output_message = None # output_message = f"_Agent moved forward with step: `{message.type}`: `{message.content}`_\n" - output_message = None elif isinstance(message, TextResponse): output_message = message.content elif isinstance(message, ImageResponse): output_message = "![image](" + message.content + ")\n\n" - else: - logging.critical(f"Unknown message type: {type(message)}") - output_message = f"`Unknown message type: {type(message)}`\n\n" + if output_message: logging.info( f"Queueing message type={message.__class__.__name__} len={len(output_message)}" ) self.output_queue.append(output_message) + @trace def close(self): """Closes the session.""" self.open = False @@ -96,20 +112,24 @@ def close(self): class SessionManager: """Manages assistant sessions.""" - def __init__(self, aoai_client: AzureOpenAI): + def __init__(self, aoai_client: AzureOpenAI, config: Configuration): """Initializes a new session manager. Args: aoai_client (AzureOpenAI): The AzureOpenAI client. 
""" self.aoai_client = aoai_client + self.config = config self.sessions = {} @trace def create_session(self) -> Session: """Creates a new session.""" - thread = self.aoai_client.beta.threads.create() - return Session(thread=thread, client=self.aoai_client) + thread = trace(self.aoai_client.beta.threads.create)() + self.sessions[thread.id] = Session( + thread=thread, client=self.aoai_client, config=self.config + ) + return self.sessions[thread.id] @trace def get_session(self, session_id: str) -> Union[Session, None]: @@ -118,14 +138,16 @@ def get_session(self, session_id: str) -> Union[Session, None]: return self.sessions[session_id] try: - thread = self.aoai_client.beta.threads.retrieve(session_id) + thread = trace(self.aoai_client.beta.threads.retrieve)(thread_id=session_id) except Exception as e: logging.critical( f"Error retrieving thread {session_id}: {traceback.format_exc()}" ) return None - self.sessions[session_id] = Session(thread=thread, client=self.aoai_client) + self.sessions[session_id] = Session( + thread=thread, client=self.aoai_client, config=self.config + ) return self.sessions[thread.id] diff --git a/src/copilot_sdk_flow/chat.py b/src/copilot_sdk_flow/chat.py index b80de55..826a3e1 100644 --- a/src/copilot_sdk_flow/chat.py +++ b/src/copilot_sdk_flow/chat.py @@ -15,6 +15,8 @@ from agent_arch.sessions import SessionManager from agent_arch.orchestrator import Orchestrator from agent_arch.extensions.manager import ExtensionsManager +from agent_arch.event_log import EventLogger +from agent_arch.messages import TextResponse @trace @@ -23,19 +25,28 @@ def chat_completion( stream: bool = False, context: dict[str, any] = {}, ): + event_logger = EventLogger() + event_logger.start_span(EventLogger.TIME_TO_FIRST_TOKEN) + event_logger.start_span(EventLogger.TIME_TO_FIRST_EXTENSION_CALL) + event_logger.start_span(EventLogger.TIME_TO_RUN_LOOP) + # a couple basic checks if not messages: - return {"error": "No messages provided."} + raise ValueError("No messages provided.") # loads the system config from the environment variables # with overrides from the context config = Configuration.from_env_and_context(context) # get the Azure OpenAI client - aoai_client = get_azure_openai_client(stream=False) # TODO: Assistants Streaming + aoai_client = get_azure_openai_client( + stream=False, + azure_endpoint=config.AZURE_OPENAI_ENDPOINT, + api_version=config.AZURE_OPENAI_API_VERSION, + ) # TODO: Assistants Streaming # the session manager is responsible for creating and storing sessions - session_manager = SessionManager(aoai_client) + session_manager = SessionManager(aoai_client, config) if "session_id" not in context: session = session_manager.create_session() @@ -53,12 +64,55 @@ def chat_completion( extensions.load() # the orchestrator is responsible for managing the assistant run - orchestrator = Orchestrator(config, aoai_client, session, extensions) - orchestrator.run_loop() + orchestrator = Orchestrator(config, aoai_client, session, extensions, event_logger) + try: + orchestrator.run_loop() + except Exception as e: + session.send(TextResponse(role="assistant", content=f"`Error: {e}`")) # for now we'll use this trick for outputs def output_queue_iterate(): while session.output_queue: yield session.output_queue.popleft() + event_logger.end_span(EventLogger.TIME_TO_FIRST_TOKEN) + + chat_completion_output = { + "context": context, + } + if context.get("return_spans", False): + chat_completion_output["context"]["spans"] = event_logger.report() + + if stream: + chat_completion_output["reply"] = 
output_queue_iterate() + else: + chat_completion_output["reply"] = "".join(list(output_queue_iterate())) + if chat_completion_output["reply"] == "": + chat_completion_output["reply"] = "No reply from the assistant." + + return chat_completion_output + + +if __name__ == "__main__": + from dotenv import load_dotenv + + load_dotenv() + + # from promptflow.tracing import start_trace + # start_trace() + + import logging + import json + + logging.basicConfig(level=logging.INFO) + # remove azure.core logging + logging.getLogger("azure.core").setLevel(logging.ERROR) + logging.getLogger("azure.identity").setLevel(logging.ERROR) + # logging.getLogger("httpx").setLevel(logging.ERROR) - return {"reply": output_queue_iterate(), "context": context} + # sample usage + messages = [ + # {"role": "user", "content": "plot avg monthly sales"}, + {"role": "user", "content": "avg sales in jan"}, + ] + result = chat_completion(messages, stream=False, context={"return_spans": True}) + print(json.dumps(result, indent=2)) diff --git a/src/copilot_sdk_flow/entry.py b/src/copilot_sdk_flow/entry.py index 23c91aa..f7ea243 100644 --- a/src/copilot_sdk_flow/entry.py +++ b/src/copilot_sdk_flow/entry.py @@ -38,17 +38,25 @@ def flow_entry_copilot_assistants( context = json.loads(context) if context else {} # refactor the whole chat_history thing - conversation = [ - { - "role": "user" if "inputs" in message else "assistant", - "content": ( - message["inputs"]["chat_input"] - if "inputs" in message - else message["outputs"]["chat_output"] - ), - } - for message in chat_history - ] + conversation = [] + for message in chat_history: + if "inputs" in message: + conversation.append( + { + "role": "user", + "content": message["inputs"]["chat_input"], + } + ) + elif "outputs" in message: + conversation.append( + { + "role": "assistant", + "content": message["outputs"]["reply"], + } + ) + else: + # ignore the ones not formatted as expected + pass # add the user input as last message in the conversation conversation.append({"role": "user", "content": chat_input}) diff --git a/src/copilot_sdk_flow/requirements.txt b/src/copilot_sdk_flow/requirements.txt index 26dc58c..0dca2ef 100644 --- a/src/copilot_sdk_flow/requirements.txt +++ b/src/copilot_sdk_flow/requirements.txt @@ -1,13 +1,13 @@ # those are the dependencies required only by chat.py # openai SDK -openai==1.13.3 +openai==1.30.1 # promptflow packages -promptflow[azure]==1.10.1 -promptflow-tracing==1.10.1 +promptflow[azure]==1.11.0 +promptflow-tracing==1.11.0 promptflow-tools==1.4.0 -promptflow-evals==0.2.0.dev0 +promptflow-evals==0.3.0 # azure dependencies (for authentication) azure-core==1.30.1 diff --git a/src/create_assistant.py b/src/create_assistant.py index 6222c33..a74d3de 100644 --- a/src/create_assistant.py +++ b/src/create_assistant.py @@ -70,6 +70,29 @@ def main(): azure_ad_token_provider=token_provider, ) + # read the various prompts + with open( + os.path.join( + os.path.dirname(__file__), + "copilot_sdk_flow", + "agent_arch", + "prompts", + "system_message.jinja2" + ) + ) as f: + system_message = f.read() + + with open( + os.path.join( + os.path.dirname(__file__), + "copilot_sdk_flow", + "agent_arch", + "prompts", + "data_schema.jinja2" + ) + ) as f: + data_schema = f.read() + with open( os.path.join( os.path.dirname(__file__), @@ -80,11 +103,12 @@ def main(): ) ) as f: custom_function_spec = json.load(f) + custom_function_spec["description"] += "\n" + data_schema logging.info(f"Creating assistant...") assistant = client.beta.assistants.create( name="Contoso Sales 
Assistant", - instructions="You are a helpful data analytics assistant helping user answer questions about the contoso sales data.", + instructions=system_message, model=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT"), tools=[ {"type": "code_interpreter"}, diff --git a/src/data/ood.jsonl b/src/data/ood.jsonl new file mode 100644 index 0000000..01cc0df --- /dev/null +++ b/src/data/ood.jsonl @@ -0,0 +1,20 @@ +{ "chat_input": "What is the average lifespan of a domestic cat?" } +{ "chat_input": "Do cats have a dominant paw, similar to humans being right or left-handed?" } +{ "chat_input": "What is the origin of the domestic cat?" } +{ "chat_input": "Can cats see in complete darkness?" } +{ "chat_input": "How many breeds of cats are recognized by the International Cat Association (TICA)?" } +{ "chat_input": "What is the technical term for a group of cats?" } +{ "chat_input": "Do cats have a collarbone?" } +{ "chat_input": "How many whiskers does the average cat have?" } +{ "chat_input": "Can cats be trained to do tricks, similar to dogs?" } +{ "chat_input": "What is the purpose of a cat's whiskers?" } +{ "chat_input": "Are all cats lactose intolerant?" } +{ "chat_input": "What is the scientific name for the domestic cat?" } +{ "chat_input": "Do cats purr when they are happy and content?" } +{ "chat_input": "What is the name of the protein in a cat's saliva that can cause allergies in some people?" } +{ "chat_input": "Do cats have a strong sense of smell?" } +{ "chat_input": "How many hours a day do cats typically sleep?" } +{ "chat_input": "What is the average body temperature of a cat?" } +{ "chat_input": "Do cats groom themselves to regulate body temperature?" } +{ "chat_input": "Can cats be left-pawed, right-pawed, or ambidextrous?" } +{ "chat_input": "What is the term for a female cat that has not been spayed?" 
} diff --git a/src/evaluate.py b/src/evaluate.py index 7eebc5b..e3787d5 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -17,6 +17,7 @@ QAEvaluator, # ChatEvaluator, ) +from promptflow.evals.evaluators import ContentSafetyEvaluator from tabulate import tabulate # local imports @@ -56,6 +57,15 @@ def get_model_config(evaluation_endpoint, evaluation_model): return model_config +def get_project_scope(): + """Return refs to project using env vars.""" + return { + "subscription_id": os.getenv("AZURE_SUBSCRIPTION_ID"), + "resource_group_name": os.getenv("AZURE_RESOURCE_GROUP"), + "project_name": os.getenv("AZUREAI_PROJECT_NAME"), + } + + def run_evaluation( evaluation_name, evaluation_model_config, @@ -120,8 +130,14 @@ def run_evaluation( "context": "${target.context}", "ground_truth": "${data.ground_truth}", } - elif metric_name == "latency": - raise NotImplementedError("Latency metric is not implemented yet") + elif metric_name == "content_safety": + evaluators[metric_name] = ContentSafetyEvaluator( + project_scope=get_project_scope() + ) + evaluators_config[metric_name] = { + "question": "${data.chat_input}", + "answer": "${target.reply}", + } else: raise ValueError(f"Unknown metric: {metric_name}") @@ -135,6 +151,7 @@ def run_evaluation( evaluators=evaluators, evaluator_config=evaluators_config, data=evaluation_data_path, + output_path=output_path, ) tabular_result = pd.DataFrame(result.get("rows")) @@ -181,10 +198,16 @@ def main(): "similarity", "qa", "chat", - "latency", + "content_safety", ], required=True, ) + parser.add_argument( + "--output-data", + type=str, + required=False, + help="Path to output data file (metrics and tabular result)", + ) args = parser.parse_args() # set logging @@ -211,6 +234,7 @@ def main(): evaluation_model_config=eval_model_config, evaluation_data_path=args.evaluation_data_path, metrics=args.metrics, + output_path=args.output_data, ) print("-----Summarized Metrics-----") diff --git a/src/requirements.txt b/src/requirements.txt index bb32811..1ae2892 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -2,13 +2,13 @@ # including all scripts for provisioning and deploying # openai SDK -openai==1.13.3 +openai==1.30.1 # promptflow packages -promptflow[azure]==1.10.1 -promptflow-tracing==1.10.1 +promptflow[azure]==1.11.0 +promptflow-tracing==1.11.0 promptflow-tools==1.4.0 -promptflow-evals==0.2.0.dev0 +promptflow-evals==0.3.0 # azure dependencies azure-core==1.30.1 @@ -18,8 +18,11 @@ azure-mgmt-search==9.1.0 azure-mgmt-cognitiveservices==13.5.0 azure-ai-ml==1.16.0 azure-storage-file-share>=12.10.0 +opencensus-ext-azure==1.1.13 +azureml-core==1.56.0 +azureml-mlflow==1.56.0 # utilities omegaconf-argparse==1.0.1 omegaconf==2.3.0 -pydantic>=2.6 +pydantic>=2.6 \ No newline at end of file
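Note on the new `Configuration` fields: settings resolve from the explicit context entry first, then the environment variable, then the built-in default, and the boolean flag is parsed with `distutils.util.strtobool`, which goes away with `distutils` in Python 3.12 (PEP 632). A minimal standalone sketch of the same resolution order with a hand-rolled parser; the `_env_bool` helper is illustrative only and not part of this change.

```python
import os


def _env_bool(value: str) -> bool:
    # Hypothetical stand-in for distutils.util.strtobool (removed together
    # with distutils in Python 3.12): parse the usual truthy/falsy strings.
    value = value.strip().lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def resolve_max_completion_tokens(context: dict) -> int:
    # Same precedence as Configuration.from_env_and_context in this diff:
    # explicit context entry, then environment variable, then the default.
    return int(
        context.get("MAX_COMPLETION_TOKENS")
        or os.getenv("MAX_COMPLETION_TOKENS")
        or "2000"
    )


if __name__ == "__main__":
    print(_env_bool(os.getenv("COMPLETION_INSERT_NOTIFICATIONS", "False")))  # False unless set
    print(resolve_max_completion_tokens({"MAX_COMPLETION_TOKENS": "512"}))   # 512
```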
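For orientation, a rough sketch of how the new `EventLogger` is driven across `chat.py` and the orchestrator: spans are started up front, ended at their milestones, and the report is attached to the response context when the caller passes `return_spans`. This assumes an OpenTelemetry SDK tracer provider is available (in the flow this normally comes from the promptflow tracing setup) and that `agent_arch` is importable from `src/copilot_sdk_flow`; it is a usage sketch, not a complete telemetry configuration.

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

from agent_arch.event_log import EventLogger

# Assumed: a real SDK tracer provider so the spans are recorded; in the flow
# this is normally configured by the promptflow tracing setup instead.
trace.set_tracer_provider(TracerProvider())

event_logger = EventLogger()
event_logger.start_span(EventLogger.TIME_TO_FIRST_TOKEN)
event_logger.start_span(EventLogger.TIME_TO_RUN_LOOP)

# ... the orchestrator creates the run and polls it here ...

event_logger.end_span(EventLogger.TIME_TO_RUN_LOOP)
event_logger.end_span(EventLogger.TIME_TO_FIRST_TOKEN)

# chat.py attaches event_logger.report() to chat_completion_output["context"]["spans"]
# when the caller sets context["return_spans"] = True.
print(event_logger.report())
```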
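The `query_order_data` tool description now carries only a short summary, with the schema and query conventions moved into `prompts/data_schema.jinja2` and appended to the tool description by `create_assistant.py`. As a quick standalone way to sanity-check those conventions against the bundled database, a sketch like the following could be used; the relative path and the `order_data` column names are taken from this diff and assumed unchanged.

```python
import sqlite3

# Path of the bundled SQLite file as laid out in this repo.
DB_PATH = "src/copilot_sdk_flow/agent_arch/extensions/data/order_data.db"

# Average order value per region, following the prompt's conventions:
# SUM()/SUM() instead of AVG(), and aggregation with GROUP BY.
QUERY = """
SELECT Region,
       SUM(Sum_of_Order_Value_USD) / SUM(Number_of_Orders) AS Avg_Order_Value_USD
FROM order_data
WHERE Year = 2024
GROUP BY Region
"""

with sqlite3.connect(DB_PATH) as conn:
    for region, avg_value in conn.execute(QUERY):
        print(region, avg_value)
```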