Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dotnet/src/Agents/A2A/A2AAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
messages,
thread,
() => new A2AAgentThread(this.Client),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Invoke the agent.
Expand Down Expand Up @@ -78,6 +79,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
messages,
thread,
() => new A2AAgentThread(this.Client),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Invoke the agent.
Expand Down
10 changes: 9 additions & 1 deletion dotnet/src/Agents/Abstractions/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,20 +326,22 @@ public abstract IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>>

private ILogger? _logger;

#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
Comment thread
westey-m marked this conversation as resolved.
/// <summary>
/// Ensures that the thread exists, is of the expected type, and is active, plus adds the provided message to the thread.
/// </summary>
/// <typeparam name="TThreadType">The expected type of the thead.</typeparam>
/// <param name="messages">The messages to add to the thread once it is setup.</param>
/// <param name="thread">The thread to create if it's null, validate it's type if not null, and start if it is not active.</param>
/// <param name="constructThread">A callback to use to construct the thread if it's null.</param>
/// <param name="requiresThreadRetrieval">true if the thread must implement <see cref="IAgentThreadMessageProvider"/> to allow message retrieval.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An async task that completes once all update are complete.</returns>
/// <exception cref="KernelException"></exception>
protected virtual async Task<TThreadType> EnsureThreadExistsWithMessagesAsync<TThreadType>(
ICollection<ChatMessageContent> messages,
AgentThread? thread,
Func<TThreadType> constructThread,
bool requiresThreadRetrieval,
CancellationToken cancellationToken)
where TThreadType : AgentThread
{
Expand All @@ -348,6 +350,11 @@ protected virtual async Task<TThreadType> EnsureThreadExistsWithMessagesAsync<TT
thread = constructThread();
}

if (requiresThreadRetrieval && thread is not IAgentThreadMessageProvider)
{
throw new KernelException($"{this.GetType().Name} requires agent threads that implement {nameof(IAgentThreadMessageProvider)}.");
}

if (thread is not TThreadType concreteThreadType)
{
throw new KernelException($"{this.GetType().Name} currently only supports agent threads of type {typeof(TThreadType).Name}.");
Expand All @@ -367,6 +374,7 @@ protected virtual async Task<TThreadType> EnsureThreadExistsWithMessagesAsync<TT

return concreteThreadType;
}
#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

/// <summary>
/// Notfiy the given thread that a new message is available.
Expand Down
48 changes: 48 additions & 0 deletions dotnet/src/Agents/Abstractions/IAgentThreadMessageProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;

namespace Microsoft.SemanticKernel.Agents;

/// <summary>
/// Interface for any Semantic Kernel agent thread that allow the messages
/// contained in it to be passed to an agent.
/// </summary>
/// <remarks>
/// <para>
/// <see cref="AgentThread"/> types that implement this interface can
/// be used with Agents that do not maintain a server-side chat history, e.g. ChatCompletionAgent.
/// These agents are typically implemented using simple LLMs and therefore
/// require the entire chat history to be provided to the LLM for each invocation.
/// </para>
/// <para>
/// This is in contrast to agents that maintain a server-side chat history, e.g. AzureAIAgentThread,
/// where the chat history is stored on the server and managed by the agent service.
/// </para>
/// <para>
/// The set of messages returned may be truncated or processed
/// by the <see cref="AgentThread"/> as needed before passed to the
/// agent to achieve a scalable and performant solution.
/// </para>
/// <para>
/// This interface can be used to implement custom agent threads, that store messages
/// in a database or 3rd party service, instead of in-memory like done by ChatHistoryAgentThread.
/// </para>
/// </remarks>
[Experimental("SKEXP0110")]
public interface IAgentThreadMessageProvider
{
/// <summary>
/// Asynchronously retrieves all messages to be used for the agent invocation.
/// </summary>
/// <remarks>
/// Messages are returned in ascending chronological order.
/// </remarks>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The messages in the thread.</returns>
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(CancellationToken cancellationToken = default);
}
2 changes: 2 additions & 0 deletions dotnet/src/Agents/AzureAI/AzureAIAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ public async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> InvokeAsync
messages,
thread,
() => new AzureAIAgentThread(this.Client),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

Kernel kernel = this.GetKernel(options);
Expand Down Expand Up @@ -238,6 +239,7 @@ public async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> In
messages,
thread,
() => new AzureAIAgentThread(this.Client),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

Kernel kernel = this.GetKernel(options);
Expand Down
4 changes: 4 additions & 0 deletions dotnet/src/Agents/Bedrock/BedrockAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
messages,
thread,
() => new BedrockAgentThread(this.RuntimeClient),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Get the context contributions from the AIContextProviders.
Expand Down Expand Up @@ -200,6 +201,7 @@ public async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> InvokeAsync
[],
thread,
() => new BedrockAgentThread(this.RuntimeClient),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Configure the agent request with the provided options
Expand Down Expand Up @@ -259,6 +261,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
messages,
thread,
() => new BedrockAgentThread(this.RuntimeClient),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Get the context contributions from the AIContextProviders.
Expand Down Expand Up @@ -342,6 +345,7 @@ public async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> In
[],
thread,
() => new BedrockAgentThread(this.RuntimeClient),
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Configure the agent request with the provided options
Expand Down
2 changes: 2 additions & 0 deletions dotnet/src/Agents/Copilot/CopilotStudioAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
messages,
thread,
() => new CopilotStudioAgentThread(this.Client) { Logger = this.ActiveLoggerFactory.CreateLogger<CopilotStudioAgentThread>() },
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Invoke the agent
Expand Down Expand Up @@ -95,6 +96,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
messages,
thread,
() => new CopilotStudioAgentThread(this.Client) { Logger = this.ActiveLoggerFactory.CreateLogger<CopilotStudioAgentThread>() },
requiresThreadRetrieval: false,
cancellationToken).ConfigureAwait(false);

// Invoke the agent
Expand Down
36 changes: 23 additions & 13 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
{
Verify.NotNull(messages);

ChatHistoryAgentThread chatHistoryAgentThread = await this.EnsureThreadExistsWithMessagesAsync(
// Ensure the thread exists, is updated with our new messages, and is retrievable.
AgentThread safeAgentThread = await this.EnsureThreadExistsWithMessagesAsync<AgentThread>(
messages,
thread,
() => new ChatHistoryAgentThread(),
requiresThreadRetrieval: true,
cancellationToken).ConfigureAwait(false);
var retrievableAgentThread = (IAgentThreadMessageProvider)safeAgentThread;

Kernel kernel = this.GetKernel(options);
#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
Expand All @@ -81,7 +84,7 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
}

// Get the context contributions from the AIContextProviders.
AIContext providersContext = await chatHistoryAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false);
AIContext providersContext = await safeAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false);

// Check for compatibility AIContextProviders and the UseImmutableKernel setting.
if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel)
Expand All @@ -92,18 +95,20 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
kernel.Plugins.AddFromAIContext(providersContext, "Tools");
#pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

// Invoke Chat Completion with the updated chat history.
// Retrieve the chat history from the thread.
ChatHistory chatHistory = [];
await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
await foreach (var existingMessage in retrievableAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
chatHistory.Add(existingMessage);
}

// Invoke Chat Completion with the history that already contains our new messages.
var invokeResults = this.InternalInvokeAsync(
this.GetDisplayName(),
chatHistory,
async (m) =>
{
await this.NotifyThreadOfNewMessage(chatHistoryAgentThread, m, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessage(safeAgentThread, m, cancellationToken).ConfigureAwait(false);
if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(m).ConfigureAwait(false);
Expand Down Expand Up @@ -132,15 +137,15 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
// since the filter terminated the call, and therefore won't get executed.
if (!result.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
{
await this.NotifyThreadOfNewMessage(chatHistoryAgentThread, result, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessage(safeAgentThread, result, cancellationToken).ConfigureAwait(false);

if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(result).ConfigureAwait(false);
}
}

yield return new(result, chatHistoryAgentThread);
yield return new(result, safeAgentThread);
}
}

Expand Down Expand Up @@ -173,11 +178,14 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
{
Verify.NotNull(messages);

ChatHistoryAgentThread chatHistoryAgentThread = await this.EnsureThreadExistsWithMessagesAsync(
// Ensure the thread exists, is updated with our new messages, and is retrievable.
AgentThread safeAgentThread = await this.EnsureThreadExistsWithMessagesAsync<AgentThread>(
messages,
thread,
() => new ChatHistoryAgentThread(),
requiresThreadRetrieval: true,
cancellationToken).ConfigureAwait(false);
var retrievableAgentThread = (IAgentThreadMessageProvider)safeAgentThread;

Kernel kernel = this.GetKernel(options);
#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
Expand All @@ -187,7 +195,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
}

// Get the context contributions from the AIContextProviders.
AIContext providersContext = await chatHistoryAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false);
AIContext providersContext = await safeAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false);

// Check for compatibility AIContextProviders and the UseImmutableKernel setting.
if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel)
Expand All @@ -198,19 +206,21 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
kernel.Plugins.AddFromAIContext(providersContext, "Tools");
#pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

// Invoke Chat Completion with the updated chat history.
// Retrieve the chat history from the thread.
ChatHistory chatHistory = [];
await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
await foreach (var existingMessage in retrievableAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
chatHistory.Add(existingMessage);
}

// Invoke Chat Completion with the history that already contains our new messages.
string agentName = this.GetDisplayName();
var invokeResults = this.InternalInvokeStreamingAsync(
agentName,
chatHistory,
async (m) =>
{
await this.NotifyThreadOfNewMessage(chatHistoryAgentThread, m, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessage(safeAgentThread, m, cancellationToken).ConfigureAwait(false);
if (options?.OnIntermediateMessage is not null)
{
await options.OnIntermediateMessage(m).ConfigureAwait(false);
Expand All @@ -223,7 +233,7 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon

await foreach (var result in invokeResults.ConfigureAwait(false))
{
yield return new(result, chatHistoryAgentThread);
yield return new(result, safeAgentThread);
}
}

Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Core/ChatHistoryAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Agents;
/// <summary>
/// Represents a conversation thread based on an instance of <see cref="ChatHistory"/> that is managed inside this class.
/// </summary>
public sealed class ChatHistoryAgentThread : AgentThread
public sealed class ChatHistoryAgentThread : AgentThread, IAgentThreadMessageProvider
{
private readonly ChatHistory _chatHistory = new();

Expand Down
24 changes: 15 additions & 9 deletions dotnet/src/Agents/OpenAI/Internal/ResponseThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal static async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
if (!agent.StoreEnabled)
{
// Use the thread chat history
overrideHistory = [.. GetChatHistory(agentThread)];
overrideHistory = [.. await GetChatHistoryAsync(agentThread, cancellationToken).ConfigureAwait(false)];
}

var creationOptions = ResponseCreationOptionsFactory.CreateOptions(agent, agentThread, options);
Expand All @@ -61,7 +61,7 @@ internal static async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
}

var message = response.ToChatMessageContent();
overrideHistory.Add(message);
history.Add(message);
yield return message;

// Reached maximum auto invocations
Expand Down Expand Up @@ -108,7 +108,7 @@ await functionProcessor.InvokeFunctionCallsAsync(
Role = AuthorRole.Tool,
Items = items,
};
overrideHistory.Add(functionResultMessage);
history.Add(functionResultMessage);
yield return functionResultMessage;
}
}
Expand All @@ -126,7 +126,7 @@ internal static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
if (!agent.StoreEnabled)
{
// Use the thread chat history
overrideHistory = [.. GetChatHistory(agentThread)];
overrideHistory = [.. await GetChatHistoryAsync(agentThread, cancellationToken).ConfigureAwait(false)];
}

var inputItems = overrideHistory.Select(m => m.ToResponseItem()).ToList();
Expand Down Expand Up @@ -159,7 +159,7 @@ internal static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
case StreamingResponseCompletedUpdate completedUpdate:
response = completedUpdate.Response;
message = completedUpdate.Response.ToChatMessageContent();
overrideHistory.Add(message);
history.Add(message);
break;

case StreamingResponseOutputItemAddedUpdate outputItemAddedUpdate:
Expand Down Expand Up @@ -287,16 +287,22 @@ await functionProcessor.InvokeFunctionCallsAsync(
InnerContent = functionCallUpdateContent,
Items = [functionCallUpdateContent],
};
overrideHistory.Add(functionResultMessage);
history.Add(functionResultMessage);
yield return streamingFunctionResultMessage;
}
}

private static ChatHistory GetChatHistory(AgentThread agentThread)
private static async Task<ChatHistory> GetChatHistoryAsync(AgentThread agentThread, CancellationToken cancellationToken)
{
if (agentThread is ChatHistoryAgentThread chatHistoryAgentThread)
if (agentThread is IAgentThreadMessageProvider agentThreadRetrievable)
{
return chatHistoryAgentThread.ChatHistory;
ChatHistory chatHistory = [];
await foreach (var message in agentThreadRetrievable.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
chatHistory.Add(message);
}

return chatHistory;
}

throw new InvalidOperationException("The agent thread is not a ChatHistoryAgentThread.");
Expand Down
Loading
Loading