diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt index 5bc61d2..a0b7837 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt @@ -14,6 +14,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlinx.serialization.json.JsonElement import kotlin.test.* +import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver by protocolDriver { @@ -653,6 +654,979 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver assertTrue(notification is SessionUpdate.AvailableCommandsUpdate) } + @Test + fun `load session should process updates that arrive before load response`() = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("loaded-session-id") + val replayText = "replay-before-load-response" + val notificationDeferred = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(replayText)) + ) + ) + + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val loadedSession = withTimeout(2000.milliseconds) { + client.loadSession(sessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + notificationDeferred.complete(notification) + } + } + } + } + + assertEquals(sessionId, loadedSession.sessionId) + + val notification = withTimeout(2000.milliseconds) { notificationDeferred.await() } + assertTrue(notification is SessionUpdate.AgentMessageChunk) + val text = ((notification as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text + assertEquals(replayText, text) + } + + @Test + fun `unknown session replay should not leak into later load after init cleanup`() = testWithProtocols { clientProtocol, agentProtocol -> + val createdSessionId = SessionId("created-session-id") + val unknownSessionId = SessionId("unknown-session-id") + val staleReplay = "stale-replay-from-unknown-session" + val freshReplay = "fresh-replay-from-load-session" + val received = mutableListOf() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = unknownSessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(staleReplay)) + ) + ) + + return object : AgentSession { + override val sessionId: SessionId = createdSessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(freshReplay)) + ) + ) + + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + } + } + } + } + + val loadedSession = withTimeout(2000.milliseconds) { + client.loadSession(unknownSessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + val text = (notification as? SessionUpdate.AgentMessageChunk)?.content as? ContentBlock.Text ?: return + received.add(text.text) + } + } + } + } + + assertEquals(unknownSessionId, loadedSession.sessionId) + assertContentEquals(listOf(freshReplay), received) + } + + @Test + fun `unknown session request during init should fail fast with not found`() = testWithProtocols { clientProtocol, agentProtocol -> + val unknownSessionId = SessionId("unknown-request-session-id") + val requestFailure = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + this@testWithProtocols.launch { + try { + AcpMethod.ClientMethods.FsReadTextFile( + agentProtocol, + ReadTextFileRequest(unknownSessionId, "/test/path", null, null, null) + ) + requestFailure.complete(AssertionError("Unknown session request should fail")) + } catch (t: Throwable) { + requestFailure.complete(t) + } + } + + delay(200.milliseconds) + return object : AgentSession { + override val sessionId: SessionId = SessionId("known-session-id") + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + } + } + } + } + + val failure = withTimeout(2000.milliseconds) { requestFailure.await() } + val message = failure.message ?: failure.toString() + assertTrue( + message.contains("Session $unknownSessionId not found"), + "Unexpected failure message: $message" + ) + } + + @Test + fun `unknown session replay should survive while another session still initializing`() = testWithProtocols { clientProtocol, agentProtocol -> + val newSessionId = SessionId("parallel-new-session-id") + val delayedLoadSessionId = SessionId("parallel-load-session-id") + val replayText = "replay-during-parallel-init" + val loadStarted = CompletableDeferred() + val allowLoadToComplete = CompletableDeferred() + val replayReceived = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + loadStarted.await() + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = delayedLoadSessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(replayText)) + ) + ) + return object : AgentSession { + override val sessionId: SessionId = newSessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + loadStarted.complete(Unit) + allowLoadToComplete.await() + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val loadedSessionDeferred = async { + client.loadSession(delayedLoadSessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + val text = (notification as? SessionUpdate.AgentMessageChunk)?.content as? ContentBlock.Text ?: return + if (!replayReceived.isCompleted) { + replayReceived.complete(text.text) + } + } + } + } + } + + withTimeout(2000.milliseconds) { loadStarted.await() } + + val createdSession = withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + } + } + } + } + assertEquals(newSessionId, createdSession.sessionId) + + allowLoadToComplete.complete(Unit) + val loadedSession = withTimeout(2000.milliseconds) { loadedSessionDeferred.await() } + + assertEquals(delayedLoadSessionId, loadedSession.sessionId) + assertEquals(replayText, withTimeout(2000.milliseconds) { replayReceived.await() }) + } + + @OptIn(UnstableApi::class) + @Test + fun `fork session should process updates that arrive before fork response`() = testWithProtocols { clientProtocol, agentProtocol -> + val sourceSessionId = SessionId("source-session-id") + val forkedSessionId = SessionId("forked-session-id") + val replayText = "replay-before-fork-response" + val notificationDeferred = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo( + clientInfo.protocolVersion, + capabilities = AgentCapabilities( + sessionCapabilities = SessionCapabilities(fork = SessionForkCapabilities()) + ) + ) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun forkSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + assertEquals(sourceSessionId, sessionId) + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = forkedSessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(replayText)) + ) + ) + return object : AgentSession { + override val sessionId: SessionId = forkedSessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val forkedSession = withTimeout(2000.milliseconds) { + client.forkSession(sourceSessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + notificationDeferred.complete(notification) + } + } + } + } + + assertEquals(forkedSessionId, forkedSession.sessionId) + val notification = withTimeout(2000.milliseconds) { notificationDeferred.await() } + assertTrue(notification is SessionUpdate.AgentMessageChunk) + val text = ((notification as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text + assertEquals(replayText, text) + } + + @OptIn(UnstableApi::class) + @Test + fun `resume session should process updates that arrive before resume response`() = testWithProtocols { clientProtocol, agentProtocol -> + val resumedSessionId = SessionId("resumed-session-id") + val replayText = "replay-before-resume-response" + val notificationDeferred = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo( + clientInfo.protocolVersion, + capabilities = AgentCapabilities( + sessionCapabilities = SessionCapabilities(resume = SessionResumeCapabilities()) + ) + ) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun resumeSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + assertEquals(resumedSessionId, sessionId) + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(replayText)) + ) + ) + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val resumedSession = withTimeout(2000.milliseconds) { + client.resumeSession(resumedSessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + notificationDeferred.complete(notification) + } + } + } + } + + assertEquals(resumedSessionId, resumedSession.sessionId) + val notification = withTimeout(2000.milliseconds) { notificationDeferred.await() } + assertTrue(notification is SessionUpdate.AgentMessageChunk) + val text = ((notification as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text + assertEquals(replayText, text) + } + + @Test + fun `failed session creation should cleanup holder and allow retry`() = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("retry-session-id") + var loadAttempt = 0 + val firstNotification = CompletableDeferred() + val secondNotification = CompletableDeferred() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + loadAttempt += 1 + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text("replay-attempt-$loadAttempt")) + ) + ) + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val firstFailure = runCatching { + withTimeout(2000.milliseconds) { + client.loadSession(sessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + throw IllegalStateException("operations factory failed") + } + } + }.exceptionOrNull() + + assertNotNull(firstFailure, "First attempt should fail") + assertTrue( + (firstFailure.message ?: "").contains("operations factory failed"), + "Unexpected first failure: $firstFailure" + ) + + val loadedSession = withTimeout(2000.milliseconds) { + client.loadSession(sessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + val text = (notification as? SessionUpdate.AgentMessageChunk)?.content as? ContentBlock.Text ?: return + if (!firstNotification.isCompleted) { + firstNotification.complete(text.text) + } else if (!secondNotification.isCompleted) { + secondNotification.complete(text.text) + } + } + } + } + } + + assertEquals(sessionId, loadedSession.sessionId) + assertEquals("replay-attempt-2", withTimeout(2000.milliseconds) { firstNotification.await() }) + assertNull(withTimeoutOrNull(300) { secondNotification.await() }, "Retry should not receive stale replay from failed attempt") + } + + @Test + fun `new session should cleanup holder when createClientOperations fails`() = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("new-session-retry-id") + var createAttempt = 0 + val received = mutableListOf() + + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + createAttempt += 1 + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text("replay-attempt-$createAttempt")) + ) + ) + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val firstFailure = runCatching { + withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + throw IllegalStateException("operations factory failed") + } + } + }.exceptionOrNull() + + assertNotNull(firstFailure, "First attempt should fail") + assertTrue( + (firstFailure.message ?: "").contains("operations factory failed"), + "Unexpected first failure: $firstFailure" + ) + + val getSessionFailure = runCatching { + client.getSession(sessionId) + }.exceptionOrNull() + assertNotNull(getSessionFailure, "Session holder should be removed after factory failure") + assertTrue( + (getSessionFailure.message ?: "").contains("Session $sessionId not found"), + "Unexpected getSession failure: $getSessionFailure" + ) + + val createdSession = withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + val text = (notification as? SessionUpdate.AgentMessageChunk)?.content as? ContentBlock.Text ?: return + received.add(text.text) + } + } + } + } + + assertEquals(sessionId, createdSession.sessionId) + assertContentEquals(listOf("replay-attempt-2"), received) + } + + @Test + fun `new session should process updates that arrive before new response`() = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("new-session-id") + val replayTexts = listOf("replay-before-new-response-1", "replay-before-new-response-2") + val notificationsDeferred = CompletableDeferred>() + val received = mutableListOf() + + val client = Client(protocol = clientProtocol) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + for (text in replayTexts) { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId = sessionId, + update = SessionUpdate.AgentMessageChunk(ContentBlock.Text(text)) + ) + ) + } + + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val createdSession = withTimeout(2000.milliseconds) { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + if (notification !is SessionUpdate.AgentMessageChunk) return + val text = (notification.content as ContentBlock.Text).text + received.add(text) + if (received.size == replayTexts.size) { + notificationsDeferred.complete(received.toList()) + } + } + } + } + } + + assertEquals(sessionId, createdSession.sessionId) + assertContentEquals(replayTexts, withTimeout(2000.milliseconds) { notificationsDeferred.await() }) + } + + @Test + fun `create session should not lose events client slower agent (new session)`() = `create session should not lose events`(agentEmitDelay = 100.milliseconds, clientNotifyDelay = 200.milliseconds, isLoad = false) + @Test + fun `create session should not lose events agent slower client (new session)`() = `create session should not lose events`(agentEmitDelay = 200.milliseconds, clientNotifyDelay = 100.milliseconds, isLoad = false) + @Test + fun `create session should not lose events client slower agent (load session)`() = `create session should not lose events`(agentEmitDelay = 100.milliseconds, clientNotifyDelay = 200.milliseconds, isLoad = true) + @Test + fun `create session should not lose events agent slower client (load session)`() = `create session should not lose events`(agentEmitDelay = 200.milliseconds, clientNotifyDelay = 100.milliseconds, isLoad = true) + @Test + fun `load session should handle huge replay updates when client notify is slow`() = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("load-session-capacity-stress-id") + val updatesCount = 1025 + val updates = (1..updatesCount).map { index -> + "msg-$index" + } + val receivedMessages = mutableListOf() + val allReceived = CompletableDeferred>() + + val client = Client(protocol = clientProtocol) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + TODO("Not yet implemented") + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + for (message in updates) { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification( + sessionId, + SessionUpdate.AgentMessageChunk(ContentBlock.Text(message)) + ) + ) + } + + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val loadedSession = withTimeout(30000.milliseconds) { + client.loadSession(sessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + delay(2.milliseconds) + val text = (notification as? SessionUpdate.AgentMessageChunk)?.content as? ContentBlock.Text ?: return + receivedMessages.add(text.text) + if (receivedMessages.size == updatesCount && !allReceived.isCompleted) { + allReceived.complete(receivedMessages.toList()) + } + } + } + } + } + + assertEquals(sessionId, loadedSession.sessionId) + assertContentEquals(updates, withTimeout(30000.milliseconds) { allReceived.await() }) + } + + private fun `create session should not lose events`(agentEmitDelay: Duration, clientNotifyDelay: Duration, isLoad: Boolean = false) = testWithProtocols { clientProtocol, agentProtocol -> + val sessionId = SessionId("slow-notify-load-session-id") + val replayUpdates = (1..10).map { index -> + if (index % 2 == 0) { + SessionUpdate.AgentMessageChunk(ContentBlock.Text("agent-$index")) + } else { + SessionUpdate.UserMessageChunk(ContentBlock.Text("user-$index")) + } + } + val postInitializeText = "post-initialize-agent" + val expectedMessages = replayUpdates.mapNotNull { update -> + when (update) { + is SessionUpdate.AgentMessageChunk -> "agent:${(update.content as ContentBlock.Text).text}" + is SessionUpdate.UserMessageChunk -> "user:${(update.content as ContentBlock.Text).text}" + else -> null + } + } + + val receivedMessages = mutableListOf() + + val client = Client(protocol = clientProtocol) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + return newSession(sessionId) + } + + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + return newSession(sessionId) + } + + private suspend fun newSession(sessionId: SessionId): AgentSession { + for (update in replayUpdates) { + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification(sessionId, update) + ) + delay(agentEmitDelay) + } + + return object : AgentSession { + override val sessionId: SessionId = sessionId + + override suspend fun postInitialize() { + delay(agentEmitDelay) + currentCoroutineContext().client.notify( + SessionUpdate.AgentMessageChunk(ContentBlock.Text(postInitializeText)) + ) + } + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + }) + + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val postInitializeDeferred = CompletableDeferred() + withTimeout(8000.milliseconds) { + fun createOperations( + clientNotifyDelay: Duration, + postInitializeText: String, + postInitializeDeferred: CompletableDeferred, + receivedMessages: MutableList, + ): ClientSessionOperations = object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + delay(clientNotifyDelay) + val message = when (notification) { + is SessionUpdate.AgentMessageChunk -> "agent:${(notification.content as ContentBlock.Text).text}" + is SessionUpdate.UserMessageChunk -> "user:${(notification.content as ContentBlock.Text).text}" + else -> return + } + if (message.contains(postInitializeText)) { + postInitializeDeferred.complete(message) + } else { + receivedMessages.add(message) + } + } + } + + if (isLoad) { + client.loadSession(sessionId, SessionCreationParameters("/test/path", emptyList())) { _, _ -> + createOperations(clientNotifyDelay, postInitializeText, postInitializeDeferred, receivedMessages) + } + } else { + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + createOperations(clientNotifyDelay, postInitializeText, postInitializeDeferred, receivedMessages) + } + } + } + + withTimeout(5000.milliseconds) { + postInitializeDeferred.await() + } + assertContentEquals(expectedMessages, receivedMessages) + } + + @OptIn(UnstableApi::class) @Test fun `list sessions returns paginated results`() = testWithProtocols { clientProtocol, agentProtocol -> @@ -1026,4 +2000,4 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver assertEquals("help", command.name) } -} \ No newline at end of file +} diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt index 39cbbb3..b56f6c9 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt @@ -13,13 +13,13 @@ import com.agentclientprotocol.util.PaginatedResponseToFlowAdapter import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.atomicfu.atomic import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Deferred import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.update import kotlinx.serialization.json.JsonElement private val logger = KotlinLogging.logger {} @@ -40,10 +40,86 @@ public typealias ClientInstance = Client public class Client( public val protocol: Protocol ) { - private val _sessions = atomic(persistentMapOf>()) + private class ClientSessionHolder { + private val sessionDeferred: CompletableDeferred = CompletableDeferred() + // Don't make the channel limited, because it leads to a deadlock also: + // when client side makes loadSession/newSession and an agent sends updates more than channel.capacity + // the message with call response suspends because protocol thread is suspended in handleNotification + // if to address it we have to somehow reorder events, that's not obvious on the protocol level, so we pay with memory right now to handle it + private val notifications = Channel>(capacity = Channel.UNLIMITED) + + val session: Deferred get() = sessionDeferred + + suspend fun drainEventsAndCompleteSession(session: ClientSessionImpl) { + @OptIn(ExperimentalCoroutinesApi::class) + notifications.close() + for ((notification, meta) in notifications) { + session.executeWithSession { + session.handleNotification(notification, meta) + } + } + + sessionDeferred.complete(session) + } + + fun completeExceptionally(cause: Throwable) { + notifications.close(cause) + sessionDeferred.completeExceptionally(cause) + } + + suspend fun handleOrQueue(notification: SessionUpdate, _meta: JsonElement?) { + val sendResult = notifications.trySend(Pair(notification, _meta)) + + // means that `close` was called in drain + if (!sendResult.isSuccess) { + // probably it will suspend for the period of loop with `handleNotification` above + val session = this@ClientSessionHolder.session.await() + session.executeWithSession { + session.handleNotification(notification, _meta) + } + } + } + } + + private data class SessionsStorage(val initializingSessionsCount: Int = 0, val sessions: PersistentMap = persistentMapOf()) + + private val _sessions = atomic(SessionsStorage()) + + /** + * Creates a new entry only if there are some currently initializing sessions. Otherwise, throws in the case of missing session. + */ + private fun getOrCreateSessionHolder(sessionId: SessionId): ClientSessionHolder { + val sessionsStorage = _sessions.value + val holder = sessionsStorage.sessions[sessionId] + if (holder != null) return holder + var clientSessionHolder: ClientSessionHolder? = null + _sessions.update { currentStorage -> + if (currentStorage.initializingSessionsCount > 0) { + val existingHolder = currentStorage.sessions[sessionId] + if (existingHolder != null) { + clientSessionHolder = existingHolder + currentStorage + } else { + val newHolder = ClientSessionHolder() + clientSessionHolder = newHolder + currentStorage.copy(sessions = currentStorage.sessions.put(sessionId, newHolder)) + } + } else { + clientSessionHolder = null + currentStorage + } + } + return clientSessionHolder ?: acpFail("Session $sessionId not found") + } + + private fun removeSessionHolder(sessionId: SessionId) { + _sessions.update { currentMap -> + currentMap.copy(sessions = currentMap.sessions.remove(sessionId)) + } + } + private val _clientInfo = CompletableDeferred() private val _agentInfo = CompletableDeferred() - private val _currentlyInitializingSessionsCount = MutableStateFlow(0) init { // Set up request handlers for incoming agent requests @@ -118,10 +194,8 @@ public class Client( } protocol.setNotificationHandler(AcpMethod.ClientMethods.SessionUpdate) { params: SessionNotification -> - val session = getSessionOrThrow(params.sessionId) - session.executeWithSession { - session.handleNotification(params.update, params._meta) - } + val sessionHolder = getOrCreateSessionHolder(params.sessionId) + sessionHolder.handleOrQueue(params.update, params._meta) } } @@ -293,49 +367,73 @@ public class Client( } } + /** + * After ClientSessionImpl is created the delayed notifications are drained and pushed into session.notify() + */ private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, sessionResponse: AcpCreatedSessionResponse, factory: ClientOperationsFactory): ClientSession { - val sessionDeferred = CompletableDeferred() + // doesn't throw if executing under `withInitializingSession` because creates a new entry + val sessionHolder = getOrCreateSessionHolder(sessionId) return runCatching { - _sessions.update { it.put(sessionId, sessionDeferred) } - val operations = factory.createClientOperations(sessionId, sessionResponse) - val session = ClientSessionImpl(this, sessionId, sessionParameters, operations, sessionResponse, protocol) - sessionDeferred.complete(session) + sessionHolder.drainEventsAndCompleteSession(session) session - }.getOrElse { - sessionDeferred.completeExceptionally(IllegalStateException("Failed to create session $sessionId", it)) - _sessions.update { it.remove(sessionId) } - throw it + }.getOrElse { throwable -> + // throw IllegalStateException to pass it as INTERNAL_ERROR to the other side (see in Protocol) + sessionHolder.completeExceptionally(IllegalStateException("Failed to create session $sessionId", throwable)) + // cleanup of this sessionId entry will be done in finally of withInitializingSession + throw throwable } } public fun getSession(sessionId: SessionId): ClientSession { - val completableDeferred = _sessions.value[sessionId] ?: error("Session $sessionId not found") - if (!completableDeferred.isCompleted) error("Session $sessionId not initialized yet") + val sessionHolder = _sessions.value.sessions[sessionId] ?: error("Session $sessionId not found") + if (!sessionHolder.session.isCompleted) error("Session $sessionId not initialized yet") @OptIn(ExperimentalCoroutinesApi::class) - return completableDeferred.getCompleted() + return sessionHolder.session.getCompleted() } private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl { - _sessions.value[sessionId]?.let { - return it.await() - } - // try to wait for all pending sessions to initialize - _currentlyInitializingSessionsCount.first { it == 0 } - // try to get the session again - _sessions.value[sessionId]?.let { - return it.await() - } - acpFail("Session $sessionId not found") + return getOrCreateSessionHolder(sessionId).session.await() } private suspend fun withInitializingSession(block: suspend () -> T): T { - _currentlyInitializingSessionsCount.update { it + 1 } + _sessions.update { it.copy(initializingSessionsCount = it.initializingSessionsCount + 1) } try { return block() } finally { - _currentlyInitializingSessionsCount.update { it - 1 } + var hangingSessions: Map? = null + _sessions.update { currentStorage -> + hangingSessions = null + if (currentStorage.initializingSessionsCount == 0) { + logger.error { "Assertion failed: initializingSessionsCount should be positive, got ${currentStorage.initializingSessionsCount}" } + return@update currentStorage + } + val newCount = currentStorage.initializingSessionsCount - 1 + return@update if (newCount == 0) { + // this means that currently no sessions can be in initializing state during to ongoing load/new/fork/resume calls + // so if on exit from these methods we observe any entries with not completed or failed state we assume that someone sent us events with non-existent session ids + // and we have to remove them and report errors + hangingSessions = currentStorage.sessions.filterValues { + @OptIn(ExperimentalCoroutinesApi::class) + !it.session.isCompleted || it.session.getCompletionExceptionOrNull() != null + } + var aliveSessions: PersistentMap = currentStorage.sessions + for ((id, _) in hangingSessions) { + aliveSessions = aliveSessions.remove(id) + } + currentStorage.copy(initializingSessionsCount = newCount, sessions = aliveSessions) + } else { + currentStorage.copy(initializingSessionsCount = newCount) + } + } + if (hangingSessions != null) { + for ((id, holder) in hangingSessions) { + logger.trace { "Removing hanging session $id" } + // report it as non existent session + holder.completeExceptionally(AcpExpectedError("Session $id not found")) + } + } } } } diff --git a/build.gradle.kts b/build.gradle.kts index 954c70a..76c9d89 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { private val buildNumber: String? = System.getenv("GITHUB_RUN_NUMBER") private val isReleasePublication = System.getenv("RELEASE_PUBLICATION")?.toBoolean() ?: false -private val baseVersion = "0.16.1" +private val baseVersion = "0.16.3" allprojects { group = "com.agentclientprotocol"