diff --git a/Sources/AWSAppSyncApolloExtensions/Websocket/AppSyncWebSocketClient.swift b/Sources/AWSAppSyncApolloExtensions/Websocket/AppSyncWebSocketClient.swift index f727159..fc6aebb 100644 --- a/Sources/AWSAppSyncApolloExtensions/Websocket/AppSyncWebSocketClient.swift +++ b/Sources/AWSAppSyncApolloExtensions/Websocket/AppSyncWebSocketClient.swift @@ -79,36 +79,46 @@ public class AppSyncWebSocketClient: NSObject, ApolloWebSocket.WebSocketClient, } public func connect() { - AppSyncApolloLogger.debug("Calling Connect") - guard connection?.state != .running else { - AppSyncApolloLogger.debug("[AppSyncWebSocketClient] WebSocket is already in connecting state") - return - } - - subscribeToAppSyncResponse() - - Task { + taskQueue.async { [weak self] in + guard let self else { return } + AppSyncApolloLogger.debug("Calling Connect") + guard connection?.state != .running else { + AppSyncApolloLogger.debug("[AppSyncWebSocketClient] WebSocket is already in connecting state") + return + } + + subscribeToAppSyncResponse() + AppSyncApolloLogger.debug("[AppSyncWebSocketClient] Creating new connection and starting read") - self.connection = try await createWebSocketConnection() + self.connection = try await self.createWebSocketConnection() + // Perform reading from a WebSocket in a separate task recursively to avoid blocking the execution. - Task { await self.startReadMessage() } + Task { + await self.startReadMessage() + } + self.connection?.resume() } } public func disconnect(forceTimeout: TimeInterval?) { - AppSyncApolloLogger.debug("Calling Disconnect") - heartBeatMonitorCancellable?.cancel() - guard connection?.state == .running else { - AppSyncApolloLogger.debug("[AppSyncWebSocketClient] client should be in connected state to trigger disconnect") - return + taskQueue.async { [weak self] in + guard let self else { return } + AppSyncApolloLogger.debug("Calling Disconnect") + heartBeatMonitorCancellable?.cancel() + guard connection?.state == .running else { + AppSyncApolloLogger.debug("[AppSyncWebSocketClient] client should be in connected state to trigger disconnect") + return + } + + connection?.cancel(with: .goingAway, reason: nil) } - - connection?.cancel(with: .goingAway, reason: nil) } public func write(ping: Data, completion: (() -> Void)?) { - AppSyncApolloLogger.debug("Not called, not implemented.") + taskQueue.async { + AppSyncApolloLogger.debug("Not called, not implemented.") + } } public func write(string: String) { diff --git a/Tests/AWSAppSyncApolloExtensionsTests/Websocket/AppSyncWebSocketClientTests.swift b/Tests/AWSAppSyncApolloExtensionsTests/Websocket/AppSyncWebSocketClientTests.swift index 57d7790..75cc9a5 100644 --- a/Tests/AWSAppSyncApolloExtensionsTests/Websocket/AppSyncWebSocketClientTests.swift +++ b/Tests/AWSAppSyncApolloExtensionsTests/Websocket/AppSyncWebSocketClientTests.swift @@ -39,6 +39,31 @@ final class AppSyncWebSocketClientTests: XCTestCase { let webSocketClient = AppSyncWebSocketClient(endpointURL: endpoint, authorizer: MockAppSyncAuthorizer()) await verifyConnected(webSocketClient) } + + func testConnect_ConcurrentInvoke() async throws { + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + let webSocketClient = AppSyncWebSocketClient(endpointURL: endpoint, authorizer: MockAppSyncAuthorizer()) + let connectedExpectation = expectation(description: "WebSocket should received connected event only once") + connectedExpectation.expectedFulfillmentCount = 1 + let sink = webSocketClient.publisher.sink { event in + switch event { + case .connected: + connectedExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + } + + for _ in 1...100 { + let task = Task { + webSocketClient.connect() + } + } + await fulfillment(of: [connectedExpectation], timeout: 5) + } func testDisconnect_didDisconnectFromRemote() async throws { var cancellables = Set() diff --git a/Tests/IntegrationTestApp/IntegrationTestAppTests/APIKeyTests.swift b/Tests/IntegrationTestApp/IntegrationTestAppTests/APIKeyTests.swift index 5449809..61756fb 100644 --- a/Tests/IntegrationTestApp/IntegrationTestAppTests/APIKeyTests.swift +++ b/Tests/IntegrationTestApp/IntegrationTestAppTests/APIKeyTests.swift @@ -138,6 +138,9 @@ final class APIKeyTests: IntegrationTestBase { } func testMaxSubscriptionReached() async throws { + let subscriptionLimit = 200 + let failedSubscriptionCount = 5 + let configuration = try AWSAppSyncConfiguration(with: .amplifyOutputs) let store = ApolloStore(cache: InMemoryNormalizedCache()) let authorizer = APIKeyAuthorizer(apiKey: configuration.apiKey ?? "") @@ -148,8 +151,11 @@ final class APIKeyTests: IntegrationTestBase { let websocket = AppSyncWebSocketClient(endpointURL: configuration.endpoint, authorizer: authorizer) let receivedConnection = expectation(description: "received connection") + receivedConnection.expectedFulfillmentCount = subscriptionLimit + let receivedMaxSubscriptionsReachedError = expectation(description: "received MaxSubscriptionsReachedError") - receivedConnection.expectedFulfillmentCount = 100 + receivedMaxSubscriptionsReachedError.expectedFulfillmentCount = failedSubscriptionCount + let sink = websocket.publisher.sink { event in if case .string(let message) = event { if message.contains("start_ack") { @@ -167,13 +173,18 @@ final class APIKeyTests: IntegrationTestBase { webSocketNetworkTransport: webSocketTransport ) let client = ApolloClient(networkTransport: splitTransport, store: store) - - for _ in 1...101 { - _ = client.subscribe(subscription: OnCreateSubscription()) { _ in - } + + try await Task.sleep(nanoseconds: 5 * 1_000_000_000) // 5 seconds + + var cancellables = [Cancellable]() + for _ in 1...subscriptionLimit + failedSubscriptionCount { + cancellables.append(client.subscribe(subscription: OnCreateSubscription()) { _ in }) } - + await fulfillment(of: [receivedConnection, receivedMaxSubscriptionsReachedError], timeout: 10) + + for cancellable in cancellables { + cancellable.cancel() + } } - } diff --git a/Tests/IntegrationTestApp/IntegrationTestAppTests/AuthTokenTests.swift b/Tests/IntegrationTestApp/IntegrationTestAppTests/AuthTokenTests.swift index 986b278..f65d78d 100644 --- a/Tests/IntegrationTestApp/IntegrationTestAppTests/AuthTokenTests.swift +++ b/Tests/IntegrationTestApp/IntegrationTestAppTests/AuthTokenTests.swift @@ -90,7 +90,7 @@ final class AuthTokenTests: IntegrationTestBase { let receivedDisconnectError = expectation(description: "received disconnect") receivedDisconnectError.assertForOverFulfill = false let sink = websocket.publisher.sink { event in - if case .error(let error) = event, error.localizedDescription.contains("Socket is not connected") { + if case .disconnected(_, _) = event { receivedDisconnectError.fulfill() } } @@ -99,8 +99,10 @@ final class AuthTokenTests: IntegrationTestBase { uploadingNetworkTransport: transport, webSocketNetworkTransport: webSocketTransport ) + let apolloCUPInvalidToken = ApolloClient(networkTransport: splitTransport, store: store) + try await Task.sleep(nanoseconds: 5 * 1_000_000_000) // 5 seconds await fulfillment(of: [receivedDisconnectError], timeout: 10) }