diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index a785a916..46a7b60d 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -8,9 +8,9 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V public final fun callTool (Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun callTool (Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun callTool (Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun complete (Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun complete$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 95a5bc5b..9929f92d 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -50,6 +50,8 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus import kotlinx.collections.immutable.persistentMapOf import kotlinx.collections.immutable.toPersistentSet +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject @@ -405,10 +407,14 @@ public open class Client(private val clientInfo: Implementation, options: Client ): EmptyRequestResult = request(request, options) /** - * Calls a tool on the server by name, passing the specified arguments. + * Calls a tool on the server by name, passing the specified arguments and metadata. * * @param name The name of the tool to call. * @param arguments A map of argument names to values for the tool. + * @param meta A map of metadata key-value pairs. Keys must follow MCP specification format. + * - Optional prefix: dot-separated labels followed by slash (e.g., "api.example.com/") + * - Name: alphanumeric start/end, may contain hyphens, underscores, dots, alphanumerics + * - Reserved prefixes starting with "mcp" or "modelcontextprotocol" are forbidden * @param compatibility Whether to use compatibility mode for older protocol versions. * @param options Optional request options. * @return The result of the tool call, or `null` if none. @@ -417,23 +423,19 @@ public open class Client(private val clientInfo: Implementation, options: Client public suspend fun callTool( name: String, arguments: Map, + meta: Map = emptyMap(), compatibility: Boolean = false, options: RequestOptions? = null, ): CallToolResultBase? { - val jsonArguments = arguments.mapValues { (_, value) -> - when (value) { - is String -> JsonPrimitive(value) - is Number -> JsonPrimitive(value) - is Boolean -> JsonPrimitive(value) - is JsonElement -> value - null -> JsonNull - else -> JsonPrimitive(value.toString()) - } - } + validateMetaKeys(meta.keys) + + val jsonArguments = convertToJsonMap(arguments) + val jsonMeta = convertToJsonMap(meta) val request = CallToolRequest( name = name, arguments = JsonObject(jsonArguments), + _meta = JsonObject(jsonMeta), ) return callTool(request, compatibility, options) } @@ -588,4 +590,137 @@ public open class Client(private val clientInfo: Implementation, options: Client val rootList = roots.value.values.toList() return ListRootsResult(rootList) } + + /** + * Validates meta keys according to MCP specification. + * + * Key format: [prefix/]name + * - Prefix (optional): dot-separated labels + slash + * - Reserved prefixes contain "modelcontextprotocol" or "mcp" as complete labels + * - Name: alphanumeric start/end, may contain hyphens, underscores, dots (empty allowed) + */ + private fun validateMetaKeys(keys: Set) { + for (key in keys) { + if (!isValidMetaKey(key)) { + throw Error("Invalid _meta key '$key'. Must follow format [prefix/]name with valid labels.") + } + } + } + + private fun isValidMetaKey(key: String): Boolean { + if (key.isEmpty()) return false + val parts = key.split('/', limit = 2) + return when (parts.size) { + 1 -> { + // No prefix, just validate name + isValidMetaName(parts[0]) + } + + 2 -> { + val (prefix, name) = parts + isValidMetaPrefix(prefix) && isValidMetaName(name) + } + + else -> false + } + } + + private fun isValidMetaPrefix(prefix: String): Boolean { + if (prefix.isEmpty()) return false + val labels = prefix.split('.') + + if (!labels.all { isValidLabel(it) }) { + return false + } + + return !labels.any { label -> + label.equals("modelcontextprotocol", ignoreCase = true) || + label.equals("mcp", ignoreCase = true) + } + } + + private fun isValidLabel(label: String): Boolean { + if (label.isEmpty()) return false + if (!label.first().isLetter() || !label.last().let { it.isLetter() || it.isDigit() }) { + return false + } + return label.all { it.isLetter() || it.isDigit() || it == '-' } + } + + private fun isValidMetaName(name: String): Boolean { + // Empty names are allowed per MCP specification + if (name.isEmpty()) return true + + if (!name.first().isLetterOrDigit() || !name.last().isLetterOrDigit()) { + return false + } + return name.all { it.isLetterOrDigit() || it in setOf('-', '_', '.') } + } + + private fun convertToJsonMap(map: Map): Map = map.mapValues { (key, value) -> + try { + convertToJsonElement(value) + } catch (e: Exception) { + logger.warn { "Failed to convert value for key '$key': ${e.message}. Using string representation." } + JsonPrimitive(value.toString()) + } + } + + @OptIn(ExperimentalUnsignedTypes::class, ExperimentalSerializationApi::class) + private fun convertToJsonElement(value: Any?): JsonElement = when (value) { + null -> JsonNull + + is Map<*, *> -> { + val jsonMap = value.entries.associate { (k, v) -> + k.toString() to convertToJsonElement(v) + } + JsonObject(jsonMap) + } + + is JsonElement -> value + + is String -> JsonPrimitive(value) + + is Number -> JsonPrimitive(value) + + is Boolean -> JsonPrimitive(value) + + is Char -> JsonPrimitive(value.toString()) + + is Enum<*> -> JsonPrimitive(value.name) + + is Collection<*> -> JsonArray(value.map { convertToJsonElement(it) }) + + is Array<*> -> JsonArray(value.map { convertToJsonElement(it) }) + + is IntArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is LongArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is FloatArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is DoubleArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is BooleanArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is ShortArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is ByteArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is CharArray -> JsonArray(value.map { JsonPrimitive(it.toString()) }) + + // ExperimentalUnsignedTypes + is UIntArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is ULongArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is UShortArray -> JsonArray(value.map { JsonPrimitive(it) }) + + is UByteArray -> JsonArray(value.map { JsonPrimitive(it) }) + + else -> { + logger.debug { "Converting unknown type ${value::class.simpleName} to string: $value" } + JsonPrimitive(value.toString()) + } + } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt new file mode 100644 index 00000000..2c37b3eb --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt @@ -0,0 +1,333 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Comprehensive test suite for MCP Client meta parameter functionality + * + * Tests cover: + * - Meta key validation according to MCP specification + * - JSON type conversion for various data types + * - Error handling for invalid meta keys + * - Integration with callTool method + */ +class ClientMetaParameterTest { + + private lateinit var client: Client + private lateinit var mockTransport: MockTransport + private val clientInfo = Implementation("test-client", "1.0.0") + + @BeforeTest + fun setup() = runTest { + mockTransport = MockTransport() + client = Client(clientInfo = clientInfo) + mockTransport.setupInitializationResponse() + client.connect(mockTransport) + } + + @Test + fun `should accept valid meta keys without throwing exception`() = runTest { + val validMeta = buildMap { + put("simple-key", "value1") + put("api.example.com/version", "1.0") + put("com.company.app/setting", "enabled") + put("retry_count", 3) + put("user.preference", true) + put("valid123", "alphanumeric") + put("multi.dot.name", "multiple-dots") + put("under_score", "underscore") + put("hyphen-dash", "hyphen") + put("org.apache.kafka/consumer-config", "complex-valid-prefix") + } + + val result = runCatching { + client.callTool("test-tool", mapOf("arg" to "value"), validMeta) + } + + assertTrue(result.isSuccess, "Valid meta keys should not cause exceptions") + } + + @Test + fun `should accept edge case valid prefixes and names`() = runTest { + val edgeCaseValidMeta = buildMap { + put("a/", "single-char-prefix-empty-name") // empty name is allowed + put("a1-b2/test", "alphanumeric-hyphen-prefix") + put("long.domain.name.here/config", "long-prefix") + put("x/a", "minimal-valid-key") + put("test123", "alphanumeric-name-only") + } + + val result = runCatching { + client.callTool("test-tool", emptyMap(), edgeCaseValidMeta) + } + + assertTrue(result.isSuccess, "Edge case valid meta keys should be accepted") + } + + @Test + fun `should reject mcp reserved prefix`() = runTest { + val invalidMeta = mapOf("mcp/internal" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject modelcontextprotocol reserved prefix`() = runTest { + val invalidMeta = mapOf("modelcontextprotocol/config" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject nested reserved prefixes`() = runTest { + val invalidKeys = listOf( + "api.mcp.io/setting", + "com.modelcontextprotocol.test/value", + "example.mcp/data", + "subdomain.mcp.com/config", + "app.modelcontextprotocol.dev/setting", + "test.mcp/value", + "service.modelcontextprotocol/data", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject nested reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject case-insensitive reserved prefixes`() = runTest { + val invalidKeys = listOf( + "MCP/internal", + "Mcp/config", + "mCp/setting", + "MODELCONTEXTPROTOCOL/data", + "ModelContextProtocol/value", + "modelContextProtocol/test", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject case-insensitive reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject invalid key formats`() = runTest { + val invalidKeys = listOf( + "", // empty key - not allowed at key level + "/invalid", // starts with slash + "-invalid", // starts with hyphen + ".invalid", // starts with dot + "in valid", // contains space + "api../test", // consecutive dots + "api./test", // label ends with dot + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject invalid key format: '$key'", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should convert various data types to JSON correctly`() = runTest { + val complexMeta = createComplexMetaData() + + val result = runCatching { + client.callTool( + "test-tool", + emptyMap(), + complexMeta, + ) + } + + assertTrue(result.isSuccess, "Complex data type conversion should not throw exceptions") + + mockTransport.lastJsonRpcRequest?.let { request -> + assertEquals("tools/call", request.method) + val params = request.params as JsonObject + assertTrue(params.containsKey("_meta"), "Request should contain _meta field") + } + } + + @Test + fun `should handle nested map structures correctly`() = runTest { + val nestedMeta = buildNestedConfiguration() + + val result = runCatching { + client.callTool("test-tool", emptyMap(), nestedMeta) + } + + assertTrue(result.isSuccess) + + mockTransport.lastJsonRpcRequest?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.containsKey("config")) + } + } + + @Test + fun `should include empty meta object when meta parameter not provided`() = runTest { + client.callTool("test-tool", mapOf("arg" to "value")) + + mockTransport.lastJsonRpcRequest?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.isEmpty(), "Meta field should be empty when not provided") + } + } + + private fun createComplexMetaData(): Map = buildMap { + put("string", "text") + put("number", 42) + put("boolean", true) + put("null_value", null) + put("list", listOf(1, 2, 3)) + put("map", mapOf("nested" to "value")) + put("enum", "STRING") + put("int_array", intArrayOf(1, 2, 3)) + } + + private fun buildNestedConfiguration(): Map = buildMap { + put( + "config", + buildMap { + put( + "database", + buildMap { + put("host", "localhost") + put("port", 5432) + }, + ) + put("features", listOf("feature1", "feature2")) + }, + ) + } +} + +class MockTransport : Transport { + private val _sentMessages = mutableListOf() + val sentMessages: List = _sentMessages + + private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseBlock: (() -> Unit)? = null + private var onErrorBlock: ((Throwable) -> Unit)? = null + + override suspend fun start() = Unit + + override suspend fun send(message: JSONRPCMessage) { + _sentMessages += message + + // Auto-respond to initialization and tool calls + when (message) { + is JSONRPCRequest -> { + when (message.method) { + "initialize" -> { + val initResponse = JSONRPCResponse( + id = message.id, + result = InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ), + ) + onMessageBlock?.invoke(initResponse) + } + + "tools/call" -> { + val toolResponse = JSONRPCResponse( + id = message.id, + result = CallToolResult( + content = listOf(), + isError = false, + ), + ) + onMessageBlock?.invoke(toolResponse) + } + } + } + + else -> { + // Handle other message types if needed + } + } + } + + override suspend fun close() { + onCloseBlock?.invoke() + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageBlock = block + } + + override fun onClose(block: () -> Unit) { + onCloseBlock = block + } + + override fun onError(block: (Throwable) -> Unit) { + onErrorBlock = block + } + + fun setupInitializationResponse() { + // This method helps set up the mock for proper initialization + } +} + +val MockTransport.lastJsonRpcRequest: JSONRPCRequest? + get() = sentMessages.lastOrNull() as? JSONRPCRequest