Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions dotnet/src/Agents/AzureAI/AzureAIAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,29 @@ public async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> In
cancellationToken);

// Return the chunks to the caller.
int messageIndex = 0;
await foreach (var result in invokeResults.ConfigureAwait(false))
{
// Notify the thread of any messages that were assembled from the streaming response during this iteration.
await NotifyMessagesAsync().ConfigureAwait(false);

yield return new(result, azureAIAgentThread);
}

// Notify the thread of any new messages that were assembled from the streaming response.
foreach (var newMessage in newMessagesReceiver)
{
await this.NotifyThreadOfNewMessage(azureAIAgentThread, newMessage, cancellationToken).ConfigureAwait(false);
// Notify the thread of any remaining messages that were assembled from the streaming response after all iterations are complete.
await NotifyMessagesAsync().ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
async Task NotifyMessagesAsync()
{
for (; messageIndex < newMessagesReceiver.Count; messageIndex++)
{
await options.OnIntermediateMessage(newMessage).ConfigureAwait(false);
ChatMessageContent newMessage = newMessagesReceiver[messageIndex];
await this.NotifyThreadOfNewMessage(azureAIAgentThread, newMessage, cancellationToken).ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(newMessage).ConfigureAwait(false);
}
}
}
}
Expand Down
99 changes: 62 additions & 37 deletions dotnet/src/Agents/AzureAI/Internal/AgentThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
// Evaluate status and process steps and messages, as encountered.
HashSet<string> processedStepIds = [];
Dictionary<string, FunctionResultContent[]> stepFunctionResults = [];
List<RunStep> stepsToProcess = [];
List<RunStep> messageCreationStepsToProcess = [];

FunctionCallsProcessor functionProcessor = new(logger);
// This matches current behavior. Will be configurable upon integrating with `FunctionChoice` (#6795/#5200)
Expand All @@ -401,7 +401,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
// Check for cancellation
cancellationToken.ThrowIfCancellationRequested();

stepsToProcess.Clear();
messageCreationStepsToProcess.Clear();

await foreach (StreamingUpdate update in asyncUpdates.ConfigureAwait(false))
{
Expand Down Expand Up @@ -440,9 +440,14 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
{
switch (stepUpdate.UpdateKind)
{
case StreamingUpdateReason.RunStepCompleted:
stepsToProcess.Add(stepUpdate.Value);
case StreamingUpdateReason.RunStepCompleted when stepUpdate.Value.StepDetails is RunStepToolCallDetails toolDetails:
ProcessToolCallStep(stepUpdate.Value, toolDetails, agent, messages, threadId, stepFunctionResults);
break;

case StreamingUpdateReason.RunStepCompleted when stepUpdate.Value.StepDetails is RunStepMessageCreationDetails:
messageCreationStepsToProcess.Add(stepUpdate.Value);
break;

default:
break;
}
Expand Down Expand Up @@ -510,55 +515,75 @@ await functionProcessor.InvokeFunctionCallsAsync(
}
}

if (stepsToProcess.Count > 0)
if (messageCreationStepsToProcess.Count > 0)
{
logger.LogAzureAIAgentProcessingRunMessages(nameof(InvokeAsync), run!.Id, threadId);

foreach (RunStep step in stepsToProcess)
foreach (RunStep step in messageCreationStepsToProcess)
{
if (step.StepDetails is RunStepMessageCreationDetails messageDetails)
{
PersistentThreadMessage? message =
await RetrieveMessageAsync(
client,
threadId,
messageDetails.MessageCreation.MessageId,
agent.PollingOptions.MessageSynchronizationDelay,
cancellationToken).ConfigureAwait(false);

if (message != null)
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step, logger);
messages?.Add(content);
}
}
else if (step.StepDetails is RunStepToolCallDetails toolDetails)
{
foreach (RunStepToolCall toolCall in toolDetails.ToolCalls)
{
if (toolCall is RunStepFunctionToolCall functionCall)
{
messages?.Add(GenerateFunctionResultContent(agent.GetName(), stepFunctionResults[step.Id], step));
stepFunctionResults.Remove(step.Id);
break;
}

if (toolCall is RunStepCodeInterpreterToolCall codeCall)
{
messages?.Add(GenerateCodeInterpreterContent(agent.GetName(), codeCall.Input, step));
}
}
await ProcessMessageCreationStepAsync(step, messageDetails, agent, client, messages, threadId, logger, cancellationToken).ConfigureAwait(false);
}
}

logger.LogAzureAIAgentProcessedRunMessages(nameof(InvokeAsync), stepsToProcess.Count, run!.Id, threadId);
logger.LogAzureAIAgentProcessedRunMessages(nameof(InvokeAsync), messageCreationStepsToProcess.Count, run!.Id, threadId);
}
}
while (run?.Status != RunStatus.Completed);

logger.LogAzureAIAgentCompletedRun(nameof(InvokeAsync), run?.Id ?? "Failed", threadId);
}

private static async Task ProcessMessageCreationStepAsync(
RunStep step,
RunStepMessageCreationDetails messageDetails,
AzureAIAgent agent,
PersistentAgentsClient client,
IList<ChatMessageContent>? messages,
string threadId,
ILogger logger,
CancellationToken cancellationToken)
{
PersistentThreadMessage? message =
await RetrieveMessageAsync(
client,
threadId,
messageDetails.MessageCreation.MessageId,
agent.PollingOptions.MessageSynchronizationDelay,
cancellationToken).ConfigureAwait(false);

if (message != null)
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step, logger);
messages?.Add(content);
}
}

private static void ProcessToolCallStep(
RunStep step,
RunStepToolCallDetails toolDetails,
AzureAIAgent agent,
IList<ChatMessageContent>? messages,
string threadId,
Dictionary<string, FunctionResultContent[]> stepFunctionResults)
{
foreach (RunStepToolCall toolCall in toolDetails.ToolCalls)
{
if (toolCall is RunStepFunctionToolCall functionCall)
{
messages?.Add(GenerateFunctionResultContent(agent.GetName(), stepFunctionResults[step.Id], step));
stepFunctionResults.Remove(step.Id);
break;
}

if (toolCall is RunStepCodeInterpreterToolCall codeCall)
{
messages?.Add(GenerateCodeInterpreterContent(agent.GetName(), codeCall.Input, step));
}
}
}

private static ChatMessageContent GenerateMessageContent(string? assistantName, PersistentThreadMessage message, RunStep? completedStep = null, ILogger? logger = null)
{
AuthorRole role = new(message.Role.ToString());
Expand Down
100 changes: 61 additions & 39 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
// Evaluate status and process steps and messages, as encountered.
HashSet<string> processedStepIds = [];
Dictionary<string, FunctionResultContent[]> stepFunctionResults = [];
List<RunStep> stepsToProcess = [];
List<RunStep> messageCreationStepsToProcess = [];
ThreadRun? run = null;

FunctionCallsProcessor functionProcessor = new(logger);
Expand All @@ -389,7 +389,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
// Check for cancellation
cancellationToken.ThrowIfCancellationRequested();

stepsToProcess.Clear();
messageCreationStepsToProcess.Clear();

await foreach (StreamingUpdate update in asyncUpdates.ConfigureAwait(false))
{
Expand Down Expand Up @@ -436,7 +436,15 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
switch (stepUpdate.UpdateKind)
{
case StreamingUpdateReason.RunStepCompleted:
stepsToProcess.Add(stepUpdate.Value);
if (!string.IsNullOrEmpty(stepUpdate.Value.Details.CreatedMessageId))
{
messageCreationStepsToProcess.Add(stepUpdate.Value);
}
else
{
ProcessToolCallStep(stepUpdate.Value, agent, messages, threadId, stepFunctionResults);
}

break;
default:
break;
Expand Down Expand Up @@ -504,55 +512,69 @@ await functionProcessor.InvokeFunctionCallsAsync(
}
}

if (stepsToProcess.Count > 0)
if (messageCreationStepsToProcess.Count > 0)
{
logger.LogOpenAIAssistantProcessingRunMessages(nameof(InvokeAsync), run!.Id, threadId);

foreach (RunStep step in stepsToProcess)
foreach (RunStep step in messageCreationStepsToProcess)
{
if (!string.IsNullOrEmpty(step.Details.CreatedMessageId))
{
ThreadMessage? message =
await RetrieveMessageAsync(
client,
threadId,
step.Details.CreatedMessageId,
agent.PollingOptions.MessageSynchronizationDelay,
cancellationToken).ConfigureAwait(false);

if (message != null)
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step);
messages?.Add(content);
}
}
else
{
foreach (RunStepToolCall toolCall in step.Details.ToolCalls)
{
if (toolCall.Kind == RunStepToolCallKind.Function)
{
messages?.Add(GenerateFunctionResultContent(agent.GetName(), stepFunctionResults[step.Id], step));
stepFunctionResults.Remove(step.Id);
break;
}

if (toolCall.Kind == RunStepToolCallKind.CodeInterpreter)
{
messages?.Add(GenerateCodeInterpreterContent(agent.GetName(), toolCall.CodeInterpreterInput, step));
}
}
}
await ProcessMessageCreationStepAsync(step, agent, client, messages, threadId, cancellationToken).ConfigureAwait(false);
}

logger.LogOpenAIAssistantProcessedRunMessages(nameof(InvokeAsync), stepsToProcess.Count, run!.Id, threadId);
logger.LogOpenAIAssistantProcessedRunMessages(nameof(InvokeAsync), messageCreationStepsToProcess.Count, run!.Id, threadId);
}
}
while (run?.Status != RunStatus.Completed);

logger.LogOpenAIAssistantCompletedRun(nameof(InvokeAsync), run?.Id ?? "Failed", threadId);
}

private static async Task ProcessMessageCreationStepAsync(
RunStep step,
OpenAIAssistantAgent agent,
AssistantClient client,
IList<ChatMessageContent>? messages,
string threadId,
CancellationToken cancellationToken)
{
ThreadMessage? message =
await RetrieveMessageAsync(
client,
threadId,
step.Details.CreatedMessageId,
agent.PollingOptions.MessageSynchronizationDelay,
cancellationToken).ConfigureAwait(false);

if (message != null)
{
ChatMessageContent content = GenerateMessageContent(agent.GetName(), message, step);
messages?.Add(content);
}
}

private static void ProcessToolCallStep(
RunStep step,
OpenAIAssistantAgent agent,
IList<ChatMessageContent>? messages,
string threadId,
Dictionary<string, FunctionResultContent[]> stepFunctionResults)
{
foreach (RunStepToolCall toolCall in step.Details.ToolCalls)
{
if (toolCall.Kind == RunStepToolCallKind.Function)
{
messages?.Add(GenerateFunctionResultContent(agent.GetName(), stepFunctionResults[step.Id], step));
stepFunctionResults.Remove(step.Id);
break;
}

if (toolCall.Kind == RunStepToolCallKind.CodeInterpreter)
{
messages?.Add(GenerateCodeInterpreterContent(agent.GetName(), toolCall.CodeInterpreterInput, step));
}
}
}

private static ChatMessageContent GenerateMessageContent(string? assistantName, ThreadMessage message, RunStep? completedStep = null, ILogger? logger = null)
{
AuthorRole role = new(message.Role.ToString());
Expand Down
22 changes: 16 additions & 6 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,29 @@ IAsyncEnumerable<StreamingChatMessageContent> InternalInvokeStreamingAsync()
}

// Return the chunks to the caller.
int messageIndex = 0;
await foreach (var result in invokeResults.ConfigureAwait(false))
{
// Notify the thread of any messages that were assembled from the streaming response during this iteration.
await NotifyMessagesAsync().ConfigureAwait(false);

yield return new(result, openAIAssistantAgentThread);
}

// Notify the thread of any new messages that were assembled from the streaming response.
foreach (var newMessage in newMessagesReceiver)
{
await this.NotifyThreadOfNewMessage(openAIAssistantAgentThread, newMessage, cancellationToken).ConfigureAwait(false);
// Notify the thread of any remaining messages that were assembled from the streaming response after all iterations are complete.
await NotifyMessagesAsync().ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
async Task NotifyMessagesAsync()
{
for (; messageIndex < newMessagesReceiver.Count; messageIndex++)
{
await options.OnIntermediateMessage(newMessage).ConfigureAwait(false);
ChatMessageContent newMessage = newMessagesReceiver[messageIndex];
await this.NotifyThreadOfNewMessage(openAIAssistantAgentThread, newMessage, cancellationToken).ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(newMessage).ConfigureAwait(false);
}
}
}
}
Expand Down
22 changes: 16 additions & 6 deletions dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
// Invoke responses with the updated chat history.
ChatHistory chatHistory = [.. messages];
int messageCount = chatHistory.Count;
int messageIndex = chatHistory.Count;
var invokeResults = ResponseThreadActions.InvokeStreamingAsync(
this,
chatHistory,
Expand All @@ -89,17 +90,26 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
// Return streaming chat message content to the caller.
await foreach (var result in invokeResults.ConfigureAwait(false))
{
// Notify the thread of any messages that were assembled from the streaming response during this iteration.
await NotifyMessagesAsync().ConfigureAwait(false);

yield return new(result, agentThread);
}

// Notify the thread of new messages
for (int i = messageCount; i < chatHistory.Count; i++)
{
await this.NotifyThreadOfNewMessage(agentThread, chatHistory[i], cancellationToken).ConfigureAwait(false);
// Notify the thread of any remaining messages that were assembled from the streaming response after all iterations are complete.
await NotifyMessagesAsync().ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
async Task NotifyMessagesAsync()
{
for (; messageIndex < chatHistory.Count; messageIndex++)
{
await options.OnIntermediateMessage(chatHistory[i]).ConfigureAwait(false);
ChatMessageContent newMessage = chatHistory[messageIndex];
await this.NotifyThreadOfNewMessage(agentThread, newMessage, cancellationToken).ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(newMessage).ConfigureAwait(false);
}
}
}
}
Expand Down
Loading
Loading