diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index fd2466ea..93559b7d 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -91,7 +91,7 @@ public override async Task SendMessageAsync( { Content = content, }; - StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) @@ -152,7 +152,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); using var response = await _httpClient.SendAsync( request, diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index d4d03480..14df5c35 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -30,6 +30,9 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private string? _negotiatedProtocolVersion; private Task? _getReceiveTask; + private readonly SemaphoreSlim _disposeLock = new(1, 1); + private bool _disposed; + public StreamableHttpClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, @@ -138,12 +141,26 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes public override async ValueTask DisposeAsync() { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + _disposed = true; + try { await _connectionCts.CancelAsync().ConfigureAwait(false); try { + // Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec. + if (!string.IsNullOrEmpty(SessionId)) + { + await SendDeleteRequest(); + } + if (_getReceiveTask != null) { await _getReceiveTask.ConfigureAwait(false); @@ -236,6 +253,22 @@ message is JsonRpcMessageWithId rpcResponseOrError && return null; } + private async Task SendDeleteRequest() + { + using var deleteRequest = new HttpRequestMessage(HttpMethod.Delete, _options.Endpoint); + CopyAdditionalHeaders(deleteRequest.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); + + try + { + // Do not validate we get a successful status code, because server support for the DELETE request is optional + (await _httpClient.SendAsync(deleteRequest, CancellationToken.None).ConfigureAwait(false)).Dispose(); + } + catch (Exception ex) + { + LogTransportShutdownFailed(Name, ex); + } + } + private void LogJsonException(JsonException ex, string data) { if (_logger.IsEnabled(LogLevel.Trace)) @@ -251,8 +284,8 @@ private void LogJsonException(JsonException ex, string data) internal static void CopyAdditionalHeaders( HttpRequestHeaders headers, IDictionary? additionalHeaders, - string? sessionId = null, - string? protocolVersion = null) + string? sessionId, + string? protocolVersion) { if (sessionId is not null) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index c5711020..119659ae 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -14,8 +14,10 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class StreamableHttpClientConformanceTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable { private WebApplication? _app; + private readonly List _deleteRequestSessionIds = []; - private async Task StartAsync() + // Don't add the delete endpoint by default to ensure the client still works with basic sessionless servers. + private async Task StartAsync(bool enableDelete = false) { Builder.Services.Configure(options => { @@ -28,7 +30,7 @@ private async Task StartAsync() Services = _app.Services, }); - _app.MapPost("/mcp", (JsonRpcMessage message) => + _app.MapPost("/mcp", (JsonRpcMessage message, HttpContext context) => { if (message is not JsonRpcRequest request) { @@ -36,6 +38,12 @@ private async Task StartAsync() return Results.Accepted(); } + if (enableDelete) + { + // Add a session ID to the response to enable session tracking + context.Response.Headers.Append("mcp-session-id", "test-session-123"); + } + if (request.Method == "initialize") { return Results.Json(new JsonRpcResponse @@ -87,6 +95,15 @@ private async Task StartAsync() throw new Exception("Unexpected message!"); }); + if (enableDelete) + { + _app.MapDelete("/mcp", context => + { + _deleteRequestSessionIds.Add(context.Request.Headers["mcp-session-id"].ToString()); + return Task.CompletedTask; + }); + } + await _app.StartAsync(TestContext.Current.CancellationToken); } @@ -136,6 +153,27 @@ public async Task CanCallToolConcurrently() await Task.WhenAll(echoTasks); } + [Fact] + public async Task SendsDeleteRequestOnDispose() + { + await StartAsync(enableDelete: true); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new("http://localhost/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + // Dispose should trigger DELETE request + await client.DisposeAsync(); + + // Verify DELETE request was sent with correct session ID + var sessionId = Assert.Single(_deleteRequestSessionIds); + Assert.Equal("test-session-123", sessionId); + } + private static async Task CallEchoAndValidateAsync(McpClientTool echoTool) { var response = await echoTool.CallAsync(new Dictionary() { ["message"] = "Hello world!" }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs index 0269ea7b..b7d2ce64 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs @@ -65,10 +65,21 @@ private class DuplexStream(IDuplexPipe duplexPipe, CancellationTokenSource conne public override bool CanWrite => true; public override bool CanSeek => false; - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => _readStream.ReadAsync(buffer, offset, count, cancellationToken); - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) - => _readStream.ReadAsync(buffer, cancellationToken); + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + // Normally, Kestrel will trigger RequestAborted when the connectionClosedCts fires causing it to gracefully close + // the connection. However, there's currently a race condition that can cause this to get missed. This at least + // unblocks HttpConnection.SendAsync when it disposes the underlying connection stream while awaiting the _readAheadTask + // as would happen with a real socket. https://github.com/dotnet/aspnetcore/pull/62385 + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, connectionClosedCts.Token); + return await _readStream.ReadAsync(buffer, offset, count, linkedTokenSource.Token); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, connectionClosedCts.Token); + return await _readStream.ReadAsync(buffer, linkedTokenSource.Token); + } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _writeStream.WriteAsync(buffer, offset, count, cancellationToken); @@ -81,7 +92,7 @@ public override Task FlushAsync(CancellationToken cancellationToken) protected override void Dispose(bool disposing) { // Signal to the server the the client has closed the connection, and dispose the client-half of the Pipes. - connectionClosedCts.Cancel(); + ThreadPool.UnsafeQueueUserWorkItem(static cts => ((CancellationTokenSource)cts!).Cancel(), connectionClosedCts); duplexPipe.Input.Complete(); duplexPipe.Output.Complete(); }