1+ using Microsoft . Extensions . Logging ;
2+ using Microsoft . Extensions . Logging . Abstractions ;
3+ using ModelContextProtocol . Protocol ;
4+ using System . Net ;
5+ using System . Threading . Channels ;
6+
7+ namespace ModelContextProtocol . Client ;
8+
9+ /// <summary>
10+ /// A transport that automatically detects whether to use Streamable HTTP or SSE transport
11+ /// by trying Streamable HTTP first and falling back to SSE if that fails.
12+ /// </summary>
13+ internal sealed partial class AutoDetectingClientSessionTransport : ITransport
14+ {
15+ private readonly SseClientTransportOptions _options ;
16+ private readonly HttpClient _httpClient ;
17+ private readonly ILoggerFactory ? _loggerFactory ;
18+ private readonly ILogger _logger ;
19+ private readonly string _name ;
20+ private readonly Channel < JsonRpcMessage > _messageChannel ;
21+
22+ public AutoDetectingClientSessionTransport ( SseClientTransportOptions transportOptions , HttpClient httpClient , ILoggerFactory ? loggerFactory , string endpointName )
23+ {
24+ Throw . IfNull ( transportOptions ) ;
25+ Throw . IfNull ( httpClient ) ;
26+
27+ _options = transportOptions ;
28+ _httpClient = httpClient ;
29+ _loggerFactory = loggerFactory ;
30+ _logger = ( ILogger ? ) loggerFactory ? . CreateLogger < AutoDetectingClientSessionTransport > ( ) ?? NullLogger . Instance ;
31+ _name = endpointName ;
32+
33+ // Same as TransportBase.cs.
34+ _messageChannel = Channel . CreateUnbounded < JsonRpcMessage > ( new UnboundedChannelOptions
35+ {
36+ SingleReader = true ,
37+ SingleWriter = false ,
38+ } ) ;
39+ }
40+
41+ /// <summary>
42+ /// Returns the active transport (either StreamableHttp or SSE)
43+ /// </summary>
44+ internal ITransport ? ActiveTransport { get ; private set ; }
45+
46+ public ChannelReader < JsonRpcMessage > MessageReader => _messageChannel . Reader ;
47+
48+ /// <inheritdoc/>
49+ public Task SendMessageAsync ( JsonRpcMessage message , CancellationToken cancellationToken = default )
50+ {
51+ if ( ActiveTransport is null )
52+ {
53+ return InitializeAsync ( message , cancellationToken ) ;
54+ }
55+
56+ return ActiveTransport . SendMessageAsync ( message , cancellationToken ) ;
57+ }
58+
59+ private async Task InitializeAsync ( JsonRpcMessage message , CancellationToken cancellationToken )
60+ {
61+ // Try StreamableHttp first
62+ var streamableHttpTransport = new StreamableHttpClientSessionTransport ( _name , _options , _httpClient , _messageChannel , _loggerFactory ) ;
63+
64+ try
65+ {
66+ LogAttemptingStreamableHttp ( _name ) ;
67+ using var response = await streamableHttpTransport . SendHttpRequestAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
68+
69+ if ( response . IsSuccessStatusCode )
70+ {
71+ LogUsingStreamableHttp ( _name ) ;
72+ ActiveTransport = streamableHttpTransport ;
73+ }
74+ else
75+ {
76+ // If the status code is not success, fall back to SSE
77+ LogStreamableHttpFailed ( _name , response . StatusCode ) ;
78+
79+ await streamableHttpTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
80+ await InitializeSseTransportAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
81+ }
82+ }
83+ catch
84+ {
85+ // If nothing threw inside the try block, we've either set streamableHttpTransport as the
86+ // ActiveTransport, or else we will have disposed it in the !IsSuccessStatusCode else block.
87+ await streamableHttpTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
88+ throw ;
89+ }
90+ }
91+
92+ private async Task InitializeSseTransportAsync ( JsonRpcMessage message , CancellationToken cancellationToken )
93+ {
94+ var sseTransport = new SseClientSessionTransport ( _name , _options , _httpClient , _messageChannel , _loggerFactory ) ;
95+
96+ try
97+ {
98+ LogAttemptingSSE ( _name ) ;
99+ await sseTransport . ConnectAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
100+ await sseTransport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
101+
102+ LogUsingSSE ( _name ) ;
103+ ActiveTransport = sseTransport ;
104+ }
105+ catch
106+ {
107+ await sseTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
108+ throw ;
109+ }
110+ }
111+
112+ public async ValueTask DisposeAsync ( )
113+ {
114+ try
115+ {
116+ if ( ActiveTransport is not null )
117+ {
118+ await ActiveTransport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
119+ }
120+ }
121+ finally
122+ {
123+ // In the majority of cases, either the Streamable HTTP transport or SSE transport has completed the channel by now.
124+ // However, this may not be the case if HttpClient throws during the initial request due to misconfiguration.
125+ _messageChannel . Writer . TryComplete ( ) ;
126+ }
127+ }
128+
129+ [ LoggerMessage ( Level = LogLevel . Debug , Message = "{EndpointName} attempting to connect using Streamable HTTP transport." ) ]
130+ private partial void LogAttemptingStreamableHttp ( string endpointName ) ;
131+
132+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport." ) ]
133+ private partial void LogStreamableHttpFailed ( string endpointName , HttpStatusCode statusCode ) ;
134+
135+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} using Streamable HTTP transport." ) ]
136+ private partial void LogUsingStreamableHttp ( string endpointName ) ;
137+
138+ [ LoggerMessage ( Level = LogLevel . Debug , Message = "{EndpointName} attempting to connect using SSE transport." ) ]
139+ private partial void LogAttemptingSSE ( string endpointName ) ;
140+
141+ [ LoggerMessage ( Level = LogLevel . Information , Message = "{EndpointName} using SSE transport." ) ]
142+ private partial void LogUsingSSE ( string endpointName ) ;
143+ }
0 commit comments