diff --git a/README.md b/README.md index 6755eba..8e7d2a4 100644 --- a/README.md +++ b/README.md @@ -83,13 +83,29 @@ For remote server communication: ```swift // Create a streaming HTTP transport +// This uses Server-Sent Events (SSE) for real-time updates when streaming is enabled. +// It aligns with the "Streamable HTTP transport" from the 2025-03-26 MCP specification. let transport = HTTPClientTransport( endpoint: URL(string: "http://localhost:8080")!, - streaming: true // Enable Server-Sent Events for real-time updates + streaming: true ) try await client.connect(transport: transport) ``` +#### SSE Transport (Legacy) + +For direct SSE communication, particularly with servers expecting the older HTTP with SSE transport mechanism: + +```swift +// Create an SSE client transport +// This implements the "HTTP with SSE" transport, an earlier specification. +// For new implementations, prefer HTTPClientTransport with streaming: true. +let sseTransport = SSEClientTransport( + endpoint: URL(string: "http://localhost:8080/sse")! // Ensure endpoint is SSE-specific if needed +) +try await client.connect(transport: sseTransport) +``` + ### Tools Tools represent functions that can be called by the client: @@ -644,7 +660,8 @@ The Swift SDK provides multiple built-in transports: | Transport | Description | Platforms | Best for | |-----------|-------------|-----------|----------| | [`StdioTransport`](/Sources/MCP/Base/Transports/StdioTransport.swift) | Implements [stdio transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#stdio) using standard input/output streams | Apple platforms, Linux with glibc | Local subprocesses, CLI tools | -| [`HTTPClientTransport`](/Sources/MCP/Base/Transports/HTTPClientTransport.swift) | Implements [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) using Foundation's URL Loading System | All platforms with Foundation | Remote servers, web applications | +| [`HTTPClientTransport`](/Sources/MCP/Base/Transports/HTTPClientTransport.swift) | Implements [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) using Foundation's URL Loading System. Can use Server-Sent Events (SSE) when `streaming` is enabled. | All platforms with Foundation | Remote servers, web applications | +| [`SSEClientTransport`](/Sources/MCP/Base/Transports/SSEClientTransport.swift) | Implements the [HTTP with SSE transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse) (an earlier specified mechanism). The current [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) (implemented by `HTTPClientTransport` with `streaming: true`) is now preferred. | All platforms with Foundation | Legacy systems or direct SSE communication with servers designed for the older SSE-specific mechanism. | | [`NetworkTransport`](/Sources/MCP/Base/Transports/NetworkTransport.swift) | Custom transport using Apple's Network framework for TCP/UDP connections | Apple platforms only | Low-level networking, custom protocols | ### Custom Transport Implementation diff --git a/Sources/MCP/Base/Transports/SSEClientTransport.swift b/Sources/MCP/Base/Transports/SSEClientTransport.swift new file mode 100644 index 0000000..2715872 --- /dev/null +++ b/Sources/MCP/Base/Transports/SSEClientTransport.swift @@ -0,0 +1,423 @@ +import Foundation +import Logging + +#if !os(Linux) + import EventSource + + /// An implementation of the MCP HTTP with SSE transport protocol. + /// + /// This transport implements the [HTTP with SSE transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse) + /// specification from the Model Context Protocol. + /// + /// It supports: + /// - Sending JSON-RPC messages via HTTP POST requests + /// - Receiving responses via SSE events + /// - Automatic handling of endpoint discovery + /// + /// ## Example Usage + /// + /// ```swift + /// import MCP + /// + /// // Create an SSE transport with the server endpoint + /// let transport = SSETransport( + /// endpoint: URL(string: "http://localhost:8080")!, + /// token: "your-auth-token" // Optional + /// ) + /// + /// // Initialize the client with the transport + /// let client = Client(name: "MyApp", version: "1.0.0") + /// try await client.connect(transport: transport) + /// + /// // The transport will automatically handle SSE events + /// // and deliver them through the client's notification handlers + /// ``` + public actor SSEClientTransport: Transport { + /// The server endpoint URL to connect to + public let endpoint: URL + + /// Logger instance for transport-related events + public nonisolated let logger: Logger + + /// Whether the transport is currently connected + public private(set) var isConnected: Bool = false + + /// The URL to send messages to, provided by the server in the 'endpoint' event + private var messageURL: URL? + + /// Authentication token for requests (if required) + private let token: String? + + /// The URLSession for network requests + private let session: URLSession + + /// Task for SSE streaming connection + private var streamingTask: Task? + + /// Used for async/await in connect() + private var connectionContinuation: CheckedContinuation? + + /// Stream for receiving messages + private let messageStream: AsyncThrowingStream + private let messageContinuation: AsyncThrowingStream.Continuation + + /// Creates a new SSE transport with the specified endpoint + /// + /// - Parameters: + /// - endpoint: The server URL to connect to + /// - token: Optional authentication token + /// - configuration: URLSession configuration to use (default: .default) + /// - logger: Optional logger instance for transport events + public init( + endpoint: URL, + token: String? = nil, + configuration: URLSessionConfiguration = .default, + logger: Logger? = nil + ) { + self.endpoint = endpoint + self.token = token + self.session = URLSession(configuration: configuration) + + // Create message stream + var continuation: AsyncThrowingStream.Continuation! + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + + self.logger = + logger + ?? Logger( + label: "mcp.transport.sse", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + } + + /// Establishes connection with the transport + /// + /// This creates an SSE connection to the server and waits for the 'endpoint' + /// event to receive the URL for sending messages. + public func connect() async throws { + guard !isConnected else { return } + + logger.info("Connecting to SSE endpoint: \(endpoint)") + + // Start listening for server events + streamingTask = Task { await listenForServerEvents() } + + // Wait for the endpoint URL to be received with a timeout + return try await withThrowingTaskGroup(of: Void.self) { group in + // Add the connection task + group.addTask { + try await self.waitForConnection() + } + + // Add the timeout task + group.addTask { + try await Task.sleep(for: .seconds(5)) // 5 second timeout + throw MCPError.internalError("Connection timeout waiting for endpoint URL") + } + + // Take the first result and cancel the other task + if let result = try await group.next() { + group.cancelAll() + return result + } + throw MCPError.internalError("Connection failed") + } + } + + /// Waits for the connection to be established + private func waitForConnection() async throws { + try await withCheckedThrowingContinuation { continuation in + self.connectionContinuation = continuation + } + } + + /// Disconnects from the transport + /// + /// This terminates the SSE connection and releases resources. + public func disconnect() async { + guard isConnected else { return } + + logger.info("Disconnecting from SSE endpoint") + + // Cancel the streaming task + streamingTask?.cancel() + streamingTask = nil + + // Clean up + isConnected = false + messageContinuation.finish() + + // If there's a pending connection continuation, fail it + if let continuation = connectionContinuation { + continuation.resume(throwing: MCPError.internalError("Connection closed")) + connectionContinuation = nil + } + + // Cancel any in-progress requests + session.invalidateAndCancel() + } + + /// Sends a JSON-RPC message to the server + /// + /// This sends data to the message endpoint provided by the server + /// during connection setup. + /// + /// - Parameter data: The JSON-RPC message to send + /// - Throws: MCPError if there's no message URL or if the request fails + public func send(_ data: Data) async throws { + guard isConnected else { + throw MCPError.internalError("Transport not connected") + } + + guard let messageURL = messageURL else { + throw MCPError.internalError("No message URL provided by server") + } + + logger.debug("Sending message", metadata: ["size": "\(data.count)"]) + + var request = URLRequest(url: messageURL) + request.httpMethod = "POST" + request.httpBody = data + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + // Add authorization if token is provided + if let token = token { + request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + } + + let (_, response) = try await session.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse else { + throw MCPError.internalError("Invalid HTTP response") + } + + guard (200..<300).contains(httpResponse.statusCode) else { + throw MCPError.internalError("HTTP error: \(httpResponse.statusCode)") + } + } + + /// Receives data in an async sequence + /// + /// This returns an AsyncThrowingStream that emits Data objects representing + /// each JSON-RPC message received from the server via SSE. + /// + /// - Returns: An AsyncThrowingStream of Data objects + public func receive() -> AsyncThrowingStream { + return messageStream + } + + // MARK: - Private Methods + + /// Main task that listens for server-sent events + private func listenForServerEvents() async { + let maxAttempts = 3 + var currentAttempt = 0 + var lastErrorEncountered: Swift.Error? + + while !Task.isCancelled && currentAttempt < maxAttempts { + currentAttempt += 1 + do { + logger.info( + "Attempting SSE connection (attempt \(currentAttempt)/\(maxAttempts)) to \(endpoint)" + ) + try await connectToSSEStream() + // If connectToSSEStream() returns without throwing, it means the stream of events finished. + // If connectionContinuation is still set at this point, it means we never got the 'endpoint' event. + if let continuation = self.connectionContinuation { + logger.error( + "SSE stream ended before 'endpoint' event was received during initial connection phase." + ) + let streamEndedError = MCPError.internalError( + "SSE stream ended before 'endpoint' event was received.") + continuation.resume(throwing: streamEndedError) + self.connectionContinuation = nil // Mark as handled + } + // If stream ended (either successfully resolving continuation or not), exit listenForServerEvents. + logger.debug( + "SSE stream processing completed or stream ended. Connection active: \(isConnected)" + ) + return + } catch { + // Check for cancellation immediately after an error. + if Task.isCancelled { + logger.info( + "SSE connection task cancelled after an error during attempt \(currentAttempt)." + ) + lastErrorEncountered = error // Store error that occurred before cancellation + break // Exit the retry loop; cancellation will be handled after the loop. + } + + lastErrorEncountered = error // Store the error from this attempt. + logger.warning( + "SSE connection attempt \(currentAttempt)/\(maxAttempts) failed: \(error.localizedDescription)" + ) + + if currentAttempt < maxAttempts && !Task.isCancelled { // If there are more attempts left + do { + let delay: TimeInterval + if currentAttempt == 1 { + delay = 0.5 + } // After 1st attempt fails + else { + delay = 1.0 + } // After 2nd attempt fails + + logger.info( + "Waiting \(delay) seconds before next SSE connection attempt (attempt \(currentAttempt + 1))." + ) + try await Task.sleep(for: .seconds(delay)) + } catch { // Catch cancellation of sleep + logger.info("SSE connection retry sleep was cancelled.") + // lastErrorEncountered is already set from the connection attempt. + // Task.isCancelled will be true, so the loop condition or post-loop check will handle it. + break // Exit the retry loop. + } + } + } + } // End of while loop + + // After the loop (due to Task.isCancelled or currentAttempt >= maxAttempts) + if let continuation = self.connectionContinuation { + // This continuation is still pending; means connection never established successfully. + if Task.isCancelled { + logger.info( + "SSE connection attempt was cancelled. Failing pending connection continuation." + ) + // Use lastErrorEncountered if cancellation happened after an error, otherwise a generic cancellation error. + let cancelError = + lastErrorEncountered + ?? MCPError.internalError("Connection attempt cancelled.") + continuation.resume(throwing: cancelError) + } else if currentAttempt >= maxAttempts { // This implies !Task.isCancelled + logger.error( + "All \(maxAttempts) SSE connection attempts failed. Failing pending connection continuation with last error: \(lastErrorEncountered?.localizedDescription ?? "N/A")" + ) + let finalError = + lastErrorEncountered + ?? MCPError.internalError( + "All SSE connection attempts failed after unknown error.") + continuation.resume(throwing: finalError) + } + self.connectionContinuation = nil // Ensure it's nilled after use. + } + logger.debug( + "listenForServerEvents task finished. Final connection state: \(isConnected). Message URL: \(String(describing: self.messageURL))" + ) + } + + /// Establishes the SSE stream connection + private func connectToSSEStream() async throws { + logger.debug("Starting SSE connection") + + var request = URLRequest(url: endpoint) + request.httpMethod = "GET" + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("no-cache", forHTTPHeaderField: "Cache-Control") + + // Add authorization if token is provided + if let token = token { + request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + } + + // On supported platforms, we use the EventSource implementation + let (byteStream, response) = try await session.bytes(for: request) + + guard let httpResponse = response as? HTTPURLResponse else { + throw MCPError.internalError("Invalid HTTP response") + } + + guard httpResponse.statusCode == 200 else { + throw MCPError.internalError("HTTP error: \(httpResponse.statusCode)") + } + + guard let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type"), + contentType.contains("text/event-stream") + else { + throw MCPError.internalError("Invalid content type for SSE stream") + } + + logger.debug("SSE connection established") + + // Process the SSE stream + for try await event in byteStream.events { + // Check if task has been cancelled + if Task.isCancelled { break } + + processServerSentEvent(event) + } + } + + /// Processes a server-sent event + private func processServerSentEvent(_ event: SSE) { + // Process event based on type + switch event.event { + case "endpoint": + if !event.data.isEmpty { + processEndpointURL(event.data) + } else { + logger.error("Received empty endpoint data") + } + + case "message", nil: // Default event type is "message" per SSE spec + if !event.data.isEmpty, + let messageData = event.data.data(using: .utf8) + { + messageContinuation.yield(messageData) + } else { + logger.warning("Received empty message data") + } + + default: + logger.warning("Received unknown event type: \(event.event ?? "nil")") + } + } + + /// Processes an endpoint URL string received from the server + private func processEndpointURL(_ endpoint: String) { + logger.debug("Received endpoint path: \(endpoint)") + + // Construct the full URL for sending messages + if let url = constructMessageURL(from: endpoint) { + messageURL = url + logger.info("Message URL set to: \(url)") + + // Mark as connected + isConnected = true + + // Resume the connection continuation if it exists + if let continuation = connectionContinuation { + continuation.resume() + connectionContinuation = nil + } + } else { + logger.error("Failed to construct message URL from path: \(endpoint)") + + // Fail the connection if we have a continuation + if let continuation = connectionContinuation { + continuation.resume(throwing: MCPError.internalError("Invalid endpoint URL")) + connectionContinuation = nil + } + } + } + + /// Constructs a message URL from a path or absolute URL + private func constructMessageURL(from path: String) -> URL? { + // Handle absolute URLs + if path.starts(with: "http://") || path.starts(with: "https://") { + return URL(string: path) + } + + // Handle relative paths + guard var components = URLComponents(url: endpoint, resolvingAgainstBaseURL: true) + else { + return nil + } + + // For relative paths, preserve the scheme, host, and port + let pathToUse = path.starts(with: "/") ? path : "/\(path)" + components.path = pathToUse + return components.url + } + } +#endif diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index c2149bb..f858b8e 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -33,100 +33,8 @@ import Testing static var httpClientTransportSetup: Self { Self() } } - // MARK: - Mock Handler Registry Actor - - actor RequestHandlerStorage { - private var requestHandler: - (@Sendable (URLRequest) async throws -> (HTTPURLResponse, Data))? - - func setHandler( - _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) - ) async { - requestHandler = handler - } - - func clearHandler() async { - requestHandler = nil - } - - func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { - guard let handler = requestHandler else { - throw NSError( - domain: "MockURLProtocolError", code: 0, - userInfo: [ - NSLocalizedDescriptionKey: "No request handler set" - ]) - } - return try await handler(request) - } - } - - // MARK: - Helper Methods - - extension URLRequest { - fileprivate func readBody() -> Data? { - if let httpBodyData = self.httpBody { - return httpBodyData - } - - guard let bodyStream = self.httpBodyStream else { return nil } - bodyStream.open() - defer { bodyStream.close() } - - let bufferSize: Int = 4096 - let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) - defer { buffer.deallocate() } - - var data = Data() - while bodyStream.hasBytesAvailable { - let bytesRead = bodyStream.read(buffer, maxLength: bufferSize) - data.append(buffer, count: bytesRead) - } - return data - } - } - // MARK: - Mock URL Protocol - final class MockURLProtocol: URLProtocol, @unchecked Sendable { - static let requestHandlerStorage = RequestHandlerStorage() - - static func setHandler( - _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) - ) async { - await requestHandlerStorage.setHandler { request in - try await handler(request) - } - } - - func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { - return try await Self.requestHandlerStorage.executeHandler(for: request) - } - - override class func canInit(with request: URLRequest) -> Bool { - return true - } - - override class func canonicalRequest(for request: URLRequest) -> URLRequest { - return request - } - - override func startLoading() { - Task { - do { - let (response, data) = try await self.executeHandler(for: request) - client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) - client?.urlProtocol(self, didLoad: data) - client?.urlProtocolDidFinishLoading(self) - } catch { - client?.urlProtocol(self, didFailWithError: error) - } - } - } - - override func stopLoading() {} - } - // MARK: - @Suite("HTTP Client Transport Tests", .serialized) diff --git a/Tests/MCPTests/Helpers/MockURLProtocol.swift b/Tests/MCPTests/Helpers/MockURLProtocol.swift new file mode 100644 index 0000000..80b0bd6 --- /dev/null +++ b/Tests/MCPTests/Helpers/MockURLProtocol.swift @@ -0,0 +1,92 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +final class MockURLProtocol: URLProtocol, @unchecked Sendable { + static let requestHandlerStorage = RequestHandlerStorage() + + static func setHandler( + _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) + ) async { + await requestHandlerStorage.setHandler { request in + try await handler(request) + } + } + + func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { + return try await Self.requestHandlerStorage.executeHandler(for: request) + } + + override class func canInit(with request: URLRequest) -> Bool { + return true + } + + override class func canonicalRequest(for request: URLRequest) -> URLRequest { + return request + } + + override func startLoading() { + Task { + do { + let (response, data) = try await self.executeHandler(for: request) + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: data) + client?.urlProtocolDidFinishLoading(self) + } catch { + client?.urlProtocol(self, didFailWithError: error) + } + } + } + + override func stopLoading() {} +} + +actor RequestHandlerStorage { + private var requestHandler: (@Sendable (URLRequest) async throws -> (HTTPURLResponse, Data))? + + func setHandler( + _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) + ) async { + requestHandler = handler + } + + func clearHandler() async { + requestHandler = nil + } + + func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { + guard let handler = requestHandler else { + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: "No request handler set" + ]) + } + return try await handler(request) + } +} + +extension URLRequest { + func readBody() -> Data? { + if let httpBodyData = self.httpBody { + return httpBodyData + } + + guard let bodyStream = self.httpBodyStream else { return nil } + bodyStream.open() + defer { bodyStream.close() } + + let bufferSize: Int = 4096 + let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) + defer { buffer.deallocate() } + + var data = Data() + while bodyStream.hasBytesAvailable { + let bytesRead = bodyStream.read(buffer, maxLength: bufferSize) + data.append(buffer, count: bytesRead) + } + return data + } +} diff --git a/Tests/MCPTests/SSEClientTransportTests.swift b/Tests/MCPTests/SSEClientTransportTests.swift new file mode 100644 index 0000000..93eb946 --- /dev/null +++ b/Tests/MCPTests/SSEClientTransportTests.swift @@ -0,0 +1,522 @@ +#if swift(>=6.1) && !os(Linux) + import EventSource + @preconcurrency import Foundation + import Logging + import Testing + + @testable import MCP + + final class MockSSEURLProtocol: URLProtocol, @unchecked Sendable { + static let requestHandlerStorage = RequestHandlerStorage() + private var loadingTask: Task? + + static func setHandler( + _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) + ) async { + await requestHandlerStorage.setHandler { request in + try await handler(request) + } + } + + func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { + return try await Self.requestHandlerStorage.executeHandler(for: request) + } + + override class func canInit(with request: URLRequest) -> Bool { + return true + } + + override class func canonicalRequest(for request: URLRequest) -> URLRequest { + return request + } + + override func startLoading() { + loadingTask = Task { + do { + let (response, data) = try await self.executeHandler(for: request) + + // For SSE GET requests, we need to simulate a streaming response + if request.httpMethod == "GET" + && request.value(forHTTPHeaderField: "Accept")?.contains( + "text/event-stream") == true + { + // Send the response headers + client?.urlProtocol( + self, didReceive: response, cacheStoragePolicy: .notAllowed) + + // Simulate SSE event data coming in chunks + if !data.isEmpty { + // Break the data into lines to simulate events coming one by one + let dataString = String(data: data, encoding: .utf8) ?? "" + let lines = dataString.split( + separator: "\n", omittingEmptySubsequences: false) + + // Simulate delay between events + for line in lines { + let lineData = Data("\(line)\n".utf8) + self.client?.urlProtocol(self, didLoad: lineData) + try await Task.sleep(for: .milliseconds(10)) + } + } + + // Complete the loading + self.client?.urlProtocolDidFinishLoading(self) + } else { + // For regular requests, just deliver the data all at once + client?.urlProtocol( + self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: data) + client?.urlProtocolDidFinishLoading(self) + } + } catch { + client?.urlProtocol(self, didFailWithError: error) + } + } + } + + override func stopLoading() { + // Cancel any ongoing tasks + loadingTask?.cancel() + loadingTask = nil + } + } + + // MARK: - Test trait + + /// A test trait that automatically manages the mock URL protocol handler for SSE transport tests. + struct SSETransportTestSetupTrait: TestTrait, TestScoping { + func provideScope( + for test: Test, testCase: Test.Case?, + performing function: @Sendable () async throws -> Void + ) async throws { + // Clear handler before test + await MockSSEURLProtocol.requestHandlerStorage.clearHandler() + + do { + // Execute the test + try await function() + } catch { + // Ensure handler is cleared even if test fails + await MockSSEURLProtocol.requestHandlerStorage.clearHandler() + throw error + } + + // Clear handler after test + await MockSSEURLProtocol.requestHandlerStorage.clearHandler() + } + } + + extension Trait where Self == SSETransportTestSetupTrait { + static var sseTransportSetup: Self { Self() } + } + + // MARK: - Test State Management + + actor CapturedRequest { + private var value: URLRequest? + + func setValue(_ newValue: URLRequest?) { + value = newValue + } + + func getValue() -> URLRequest? { + return value + } + } + + actor CapturedRequests { + private var sseRequest: URLRequest? + private var postRequest: URLRequest? + + func setSSERequest(_ request: URLRequest?) { + sseRequest = request + } + + func setPostRequest(_ request: URLRequest?) { + postRequest = request + } + + func getValues() -> (URLRequest?, URLRequest?) { + return (sseRequest, postRequest) + } + } + + // MARK: - Tests + + @Suite("SSE Client Transport Tests", .serialized) + struct SSEClientTransportTests { + let testEndpoint = URL(string: "http://localhost:8080/sse")! + + @Test("Connect and receive endpoint event", .sseTransportSetup) + func testConnectAndReceiveEndpoint() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + // Create transport with mocked URLSession + let transport = SSEClientTransport( + endpoint: testEndpoint, + configuration: configuration + ) + + // Setup mock response for the SSE connection + let endpointEvent = """ + event: endpoint + data: /messages/123 + + """ + + await MockSSEURLProtocol.setHandler { request in + guard request.httpMethod == "GET" else { + throw MCPError.internalError("Unexpected request method") + } + + let response = HTTPURLResponse( + url: request.url!, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"] + )! + + return (response, Data(endpointEvent.utf8)) + } + + // Connect should receive the endpoint event and complete + try await transport.connect() + + // Transport should now be connected + #expect(await transport.isConnected) + + // Disconnect to clean up + await transport.disconnect() + } + + @Test("Send message after connection", .sseTransportSetup) + func testSendMessage() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + // Create transport with mocked URLSession + let transport = SSEClientTransport( + endpoint: testEndpoint, + configuration: configuration + ) + + // Configure different responses based on the request + let messageURL = URL(string: "http://localhost:8080/messages/123")! + let capturedRequest = CapturedRequest() + + await MockSSEURLProtocol.setHandler { request in + if request.httpMethod == "GET" { + // SSE connection request + let response = HTTPURLResponse( + url: testEndpoint, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"] + )! + + // Send the endpoint event + return ( + response, + Data( + """ + event: endpoint + data: /messages/123 + + """.utf8) + ) + } else if request.httpMethod == "POST" { + // Message request + await capturedRequest.setValue(request) + let response = HTTPURLResponse( + url: messageURL, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"] + )! + return (response, Data()) + } else { + throw MCPError.internalError("Unexpected request method") + } + } + + // Connect first + try await transport.connect() + + // Now send a message + let messageData = #"{"jsonrpc":"2.0","method":"test","id":1}"#.data(using: .utf8)! + try await transport.send(messageData) + + // Verify the sent message + let capturedPostRequest = await capturedRequest.getValue() + #expect(capturedPostRequest != nil) + #expect(capturedPostRequest?.url?.path == "/messages/123") + #expect(capturedPostRequest?.httpMethod == "POST") + #expect(capturedPostRequest?.readBody() == messageData) + #expect( + capturedPostRequest?.value(forHTTPHeaderField: "Content-Type") == "application/json" + ) + + // Disconnect + await transport.disconnect() + } + + @Test("Receive message events", .sseTransportSetup) + func testReceiveMessageEvents() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + // Create transport with mocked URLSession + let transport = SSEClientTransport( + endpoint: testEndpoint, + configuration: configuration + ) + + // Configure mock response with both endpoint and message events + let eventStreamData = """ + event: endpoint + data: /messages/123 + + event: message + data: {"jsonrpc":"2.0","result":{"content":"Hello"},"id":1} + + event: message + data: {"jsonrpc":"2.0","result":{"content":"World"},"id":2} + + """ + + await MockSSEURLProtocol.setHandler { request in + guard request.httpMethod == "GET" else { + throw MCPError.internalError("Unexpected request method") + } + + let response = HTTPURLResponse( + url: request.url!, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"] + )! + + return (response, Data(eventStreamData.utf8)) + } + + // Start receiving before connecting + let receiverTask = Task { + var receivedMessages: [Data] = [] + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + // Try to get 2 messages + for _ in 0..<2 { + if let message = try await iterator.next() { + receivedMessages.append(message) + } + } + + return receivedMessages + } + + // Connect to trigger the event stream + try await transport.connect() + + // Wait for messages and check them + let receivedMessages = try await receiverTask.value + + #expect(receivedMessages.count == 2) + + // Check first message + let firstMessageString = String(data: receivedMessages[0], encoding: .utf8) + #expect(firstMessageString?.contains("Hello") == true) + + // Check second message + let secondMessageString = String(data: receivedMessages[1], encoding: .utf8) + #expect(secondMessageString?.contains("World") == true) + + // Disconnect + await transport.disconnect() + } + + @Test("Authentication token", .sseTransportSetup) + func testAuthenticationToken() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + let testToken = "test-auth-token" + + // Create transport with mocked URLSession and auth token + let transport = SSEClientTransport( + endpoint: testEndpoint, + token: testToken, + configuration: configuration + ) + + // Keep track of requests to verify auth headers + let capturedRequests = CapturedRequests() + + await MockSSEURLProtocol.setHandler { request in + if request.httpMethod == "GET" { + await capturedRequests.setSSERequest(request) + let response = HTTPURLResponse( + url: testEndpoint, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"] + )! + + return ( + response, + Data( + """ + event: endpoint + data: /messages/123 + + """.utf8) + ) + } else if request.httpMethod == "POST" { + await capturedRequests.setPostRequest(request) + let response = HTTPURLResponse( + url: URL(string: "http://localhost:8080/messages/123")!, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"] + )! + return (response, Data()) + } else { + throw MCPError.internalError("Unexpected request method") + } + } + + // Connect and send a message + try await transport.connect() + try await transport.send(Data()) + + // Verify auth tokens in both requests + let (sseRequest, postRequest) = await capturedRequests.getValues() + #expect(sseRequest?.value(forHTTPHeaderField: "Authorization") == "Bearer \(testToken)") + #expect( + postRequest?.value(forHTTPHeaderField: "Authorization") == "Bearer \(testToken)") + + await transport.disconnect() + } + + @Test("HTTP error handling", .sseTransportSetup) + func testHTTPErrorHandling() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + // Create transport with mocked URLSession + let transport = SSEClientTransport( + endpoint: testEndpoint, + configuration: configuration + ) + + // Configure mock to return 404 for SSE connection + await MockSSEURLProtocol.setHandler { request in + let response = HTTPURLResponse( + url: request.url!, + statusCode: 404, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/plain"] + )! + + return (response, Data("Not Found".utf8)) + } + + // Connection should fail with HTTP error + do { + try await transport.connect() + Issue.record("Connection should have failed with HTTP error") + } catch let error as MCPError { + guard case .internalError(let message) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(message?.contains("HTTP error: 404") ?? false) + } + } + + @Test("URL construction from different endpoint formats", .sseTransportSetup) + func testURLConstruction() async throws { + // Setup URLSession with mock protocol + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockSSEURLProtocol.self] + + // Test different endpoint formats + let testCases = [ + // Relative path without leading slash + "messages/123", + // Relative path with leading slash + "/messages/456", + // Absolute URL + "https://api.example.com/messages/789", + ] + + for endpoint in testCases { + // Create new transport for each test case + let transport = SSEClientTransport( + endpoint: testEndpoint, + configuration: configuration + ) + + // Configure mock to return the current endpoint + let capturedRequest = CapturedRequest() + + await MockSSEURLProtocol.setHandler { request in + if request.httpMethod == "GET" { + let response = HTTPURLResponse( + url: testEndpoint, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"] + )! + + return ( + response, + Data( + """ + event: endpoint + data: \(endpoint) + + """.utf8) + ) + } else if request.httpMethod == "POST" { + await capturedRequest.setValue(request) + let response = HTTPURLResponse( + url: request.url!, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"] + )! + return (response, Data()) + } else { + throw MCPError.internalError("Unexpected request method") + } + } + + // Connect and send a test message + try await transport.connect() + try await transport.send(Data()) + + // Verify the URL construction based on endpoint format + let capturedPostRequest = await capturedRequest.getValue() + if endpoint.starts(with: "http") { + // For absolute URLs + #expect(capturedPostRequest?.url?.absoluteString == endpoint) + } else { + // For relative paths + let expectedPath = endpoint.starts(with: "/") ? endpoint : "/\(endpoint)" + #expect(capturedPostRequest?.url?.path == expectedPath) + #expect(capturedPostRequest?.url?.host == testEndpoint.host) + #expect(capturedPostRequest?.url?.scheme == testEndpoint.scheme) + } + + await transport.disconnect() + } + } + } +#endif // swift(>=6.1) && !os(Linux)