diff --git a/MCP/cpp-sdk/src/server/mcp_server_implement.cpp b/MCP/cpp-sdk/src/server/mcp_server_implement.cpp index d32c82c..e72d1fd 100644 --- a/MCP/cpp-sdk/src/server/mcp_server_implement.cpp +++ b/MCP/cpp-sdk/src/server/mcp_server_implement.cpp @@ -407,6 +407,7 @@ void McpServerImplement::AddTool(const std::string& name, ToolFunc fn, AddToolOp ServerTool tool(name, fn, params.title, params.description, params.inputSchema, params.outputSchema, params.structuredOutput, params.annotations, params.icons); toolManager_.AddTool(tool); + RefreshServerCapabilities(); } void McpServerImplement::RemoveTool(const std::string& name) @@ -414,6 +415,7 @@ void McpServerImplement::RemoveTool(const std::string& name) CheckServerState(); toolManager_.RemoveTool(name); + RefreshServerCapabilities(); } void McpServerImplement::AddPrompt(const std::string& name, RenderPromptFunc handler, AddPromptOptionalParams params) @@ -436,6 +438,7 @@ void McpServerImplement::AddPrompt(const std::string& name, RenderPromptFunc han } promptManager_.AddPrompt(prompt, handler); + RefreshServerCapabilities(); } void McpServerImplement::RemovePrompt(const std::string& name) @@ -443,6 +446,7 @@ void McpServerImplement::RemovePrompt(const std::string& name) CheckServerState(); promptManager_.RemovePrompt(name); + RefreshServerCapabilities(); } void McpServerImplement::AddResource(const std::string& uri, const std::string& name, ReadResourceFunc readFunc, @@ -473,6 +477,7 @@ void McpServerImplement::AddResource(const std::string& uri, const std::string& } resourceManager_.AddResource(resource, readFunc); + RefreshServerCapabilities(); } void McpServerImplement::RemoveResource(const std::string& uri) @@ -480,6 +485,7 @@ void McpServerImplement::RemoveResource(const std::string& uri) CheckServerState(); resourceManager_.RemoveResource(uri); + RefreshServerCapabilities(); } void McpServerImplement::AddResourceTemplate(const std::string& uriTemplate, const std::string& name, @@ -507,6 +513,7 @@ void McpServerImplement::AddResourceTemplate(const std::string& uriTemplate, con } resourceManager_.AddResourceTemplate(resourceTemplate); + RefreshServerCapabilities(); } void McpServerImplement::RemoveResourceTemplate(const std::string& uriTemplate) @@ -514,6 +521,7 @@ void McpServerImplement::RemoveResourceTemplate(const std::string& uriTemplate) CheckServerState(); resourceManager_.RemoveResourceTemplate(uriTemplate); + RefreshServerCapabilities(); } bool McpServerImplement::ValidateStreamableHttpConfig(const StreamableHttpServerConfig& config) @@ -574,6 +582,7 @@ bool McpServerImplement::InitializeServerManager() [this](const RequestId& requestId, const Request& request, RequestContext& ctx) { this->ReceiveIncomingMessages(requestId, request, ctx); }); + serverManager_->SetServerCapabilities(BuildServerCapabilities()); MCP_LOG(MCP_LOG_LEVEL_DEBUG, "ServerManager initialized successfully"); serverManager_->Start(); MCP_LOG(MCP_LOG_LEVEL_DEBUG, "ServerManager start successfully"); @@ -584,6 +593,39 @@ bool McpServerImplement::InitializeServerManager() } } +ServerCapabilities McpServerImplement::BuildServerCapabilities() +{ + ServerCapabilities capabilities; + + auto toolsResult = toolManager_.ListTools(std::nullopt); + if (!toolsResult.tools.empty()) { + capabilities.tools = ToolsCapabilities{}; + } + + auto promptsResult = promptManager_.ListPrompts(); + if (!promptsResult.prompts.empty()) { + capabilities.prompts = PromptsCapabilities{}; + } + + auto resourcesResult = resourceManager_.ListResources(std::nullopt); + auto resourceTemplatesResult = resourceManager_.ListResourceTemplates(); + if (!resourcesResult.resources.empty() || !resourceTemplatesResult.resourceTemplates.empty()) { + ResourcesCapabilities resourcesCaps{}; + resourcesCaps.subscribe = true; + capabilities.resources = resourcesCaps; + } + + return capabilities; +} + +void McpServerImplement::RefreshServerCapabilities() +{ + if (serverManager_ == nullptr) { + return; + } + serverManager_->SetServerCapabilities(BuildServerCapabilities()); +} + static bool ValidateFilePathWithRealpath(const std::string& path, const char* what) { if (path.empty()) { diff --git a/MCP/cpp-sdk/src/server/mcp_server_implement.h b/MCP/cpp-sdk/src/server/mcp_server_implement.h index f94155b..abcc593 100644 --- a/MCP/cpp-sdk/src/server/mcp_server_implement.h +++ b/MCP/cpp-sdk/src/server/mcp_server_implement.h @@ -82,6 +82,8 @@ class McpServerImplement : public McpServer { bool InitializeServerManager(); bool ValidateTlsConfig(const TlsConfig& config); void CheckServerState() const; + ServerCapabilities BuildServerCapabilities(); + void RefreshServerCapabilities(); ServerConfig config_; StreamableHttpServerConfig streamableConfig_; diff --git a/MCP/cpp-sdk/src/server/server_manager.cpp b/MCP/cpp-sdk/src/server/server_manager.cpp index 7b13435..a075b9b 100644 --- a/MCP/cpp-sdk/src/server/server_manager.cpp +++ b/MCP/cpp-sdk/src/server/server_manager.cpp @@ -144,6 +144,7 @@ void ServerManager::StdioServerManagerStart() { std::shared_ptr stdioTransport_ = std::make_shared(); stdioSession_ = std::make_shared(stdioTransport_, config_, GenerateSessionId()); + stdioSession_->SetServerCapabilities(serverCapabilities_); // Set incoming request callback if provided if (requestCallback_) { stdioSession_->SetIncomingRequestCallback(requestCallback_); @@ -285,6 +286,24 @@ void ServerManager::SetIncomingRequestCallback(IncomingRequestCallback callback) requestCallback_ = callback; } +void ServerManager::SetServerCapabilities(const ServerCapabilities& capabilities) +{ + serverCapabilities_ = capabilities; + + if (stdioSession_) { + stdioSession_->SetServerCapabilities(serverCapabilities_); + } + + for (auto& sessions : threadSessions_) { + for (auto& [sessionId, session] : sessions) { + (void)sessionId; + if (session) { + session->SetServerCapabilities(serverCapabilities_); + } + } + } +} + std::shared_ptr ServerManager::NewSession(const std::string& sessionId) { // Create transport @@ -301,6 +320,7 @@ std::shared_ptr ServerManager::NewSession(const std::string& sess MCP_LOG(MCP_LOG_LEVEL_ERROR, "Failed to create session layer."); throw std::runtime_error("Failed to create session layer."); } + session->SetServerCapabilities(serverCapabilities_); // Set incoming request callback if provided if (requestCallback_) { session->SetIncomingRequestCallback(requestCallback_); diff --git a/MCP/cpp-sdk/src/server/server_manager.h b/MCP/cpp-sdk/src/server/server_manager.h index 0cf6667..6db3c82 100644 --- a/MCP/cpp-sdk/src/server/server_manager.h +++ b/MCP/cpp-sdk/src/server/server_manager.h @@ -41,6 +41,7 @@ class ServerManager { void Stop(); std::shared_ptr GetSession(const std::string& sessionId); void SetIncomingRequestCallback(IncomingRequestCallback callback); + void SetServerCapabilities(const ServerCapabilities& capabilities); private: void StdioServerManagerStart(); @@ -72,6 +73,7 @@ class ServerManager { std::vector> notifyArgs_; std::shared_ptr stdioSession_{nullptr}; std::shared_ptr stdioTransport_{nullptr}; + ServerCapabilities serverCapabilities_; }; } // namespace Mcp diff --git a/MCP/cpp-sdk/src/server/transport/streamable_http_server_transport.cpp b/MCP/cpp-sdk/src/server/transport/streamable_http_server_transport.cpp index c5a0657..d5d36af 100644 --- a/MCP/cpp-sdk/src/server/transport/streamable_http_server_transport.cpp +++ b/MCP/cpp-sdk/src/server/transport/streamable_http_server_transport.cpp @@ -5,6 +5,7 @@ #include "streamable_http_server_transport.h" #include +#include #include #include #include @@ -26,6 +27,31 @@ constexpr const char* GET_STREAM_KEY = "_GET_stream"; // Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) static const std::regex SESSION_ID_PATTERN("^[\\x21-\\x7E]+$"); +static std::optional GetHeaderValueCaseInsensitive( + const Http::HttpRequest& request, const std::string& headerName) +{ + auto exactIt = request.headers.find(headerName); + if (exactIt != request.headers.end()) { + return exactIt->second; + } + + for (const auto& [key, value] : request.headers) { + if (key.size() != headerName.size()) { + continue; + } + + bool same = std::equal(key.begin(), key.end(), headerName.begin(), + [](unsigned char left, unsigned char right) { + return std::tolower(left) == std::tolower(right); + }); + if (same) { + return value; + } + } + + return std::nullopt; +} + StreamableHttpServerTransport::StreamableHttpServerTransport(const std::string& mcpSessionId, bool isJsonResponseEnabled) : mcpSessionId_(mcpSessionId), @@ -46,10 +72,9 @@ void StreamableHttpServerTransport::SetCallback(std::shared_ptrsecond; + auto sessionId = GetHeaderValueCaseInsensitive(request, Http::MCP_SESSION_ID_HEADER); + if (sessionId.has_value()) { + return sessionId.value(); } return ""; } @@ -62,9 +87,9 @@ bool StreamableHttpServerTransport::ValidateProtocolVersion(RequestContext& ctx, // Get the protocol version from the request headers std::string protocolVersion{}; - auto versionIt = request.headers.find(Http::MCP_PROTOCOL_VERSION_HEADER); - if (versionIt != request.headers.end()) { - protocolVersion = versionIt->second; + auto version = GetHeaderValueCaseInsensitive(request, Http::MCP_PROTOCOL_VERSION_HEADER); + if (version.has_value()) { + protocolVersion = version.value(); } else { // If no protocol version provided, assume default version protocolVersion = DEFAULT_PROTOCOL_VERSION; @@ -177,11 +202,7 @@ void StreamableHttpServerTransport::HandleRequest(const HttpRequest& request, Re bool StreamableHttpServerTransport::ValidatePostRequestHeaders(RequestContext& ctx, const HttpRequest& request) { - auto acceptIt = request.headers.find(Http::ACCEPT_HEADER); - std::string acceptHeader{}; - if (acceptIt != request.headers.end()) { - acceptHeader = acceptIt->second; - } + std::string acceptHeader = GetHeaderValueCaseInsensitive(request, Http::ACCEPT_HEADER).value_or(""); bool hasJson = acceptHeader.find(Http::CONTENT_TYPE_JSON) != std::string::npos; bool hasSse = acceptHeader.find(Http::CONTENT_TYPE_SSE) != std::string::npos; if (!hasJson || !hasSse) { @@ -195,9 +216,8 @@ bool StreamableHttpServerTransport::ValidatePostRequestHeaders(RequestContext& c ctx.httpSendFunc(response, ctx); return false; } - auto contentTypeIt = request.headers.find(Http::CONTENT_TYPE_HEADER); - if (contentTypeIt == request.headers.end() || - contentTypeIt->second.find(Http::CONTENT_TYPE_JSON) == std::string::npos) { + auto contentType = GetHeaderValueCaseInsensitive(request, Http::CONTENT_TYPE_HEADER); + if (!contentType.has_value() || contentType->find(Http::CONTENT_TYPE_JSON) == std::string::npos) { HttpResponse response = CreateErrorResponse("Unsupported Media Type: Content-Type must be application/json", Http::HTTP_STATUS_UNSUPPORTED_MEDIA_TYPE, static_cast(JsonRpcErrorCode::INVALID_REQUEST)); @@ -317,11 +337,7 @@ void StreamableHttpServerTransport::HandleGetRequest(RequestContext& ctx, const } // Check Accept header - auto acceptIt = request.headers.find("accept"); - std::string acceptHeader{}; - if (acceptIt != request.headers.end()) { - acceptHeader = acceptIt->second; - } + std::string acceptHeader = GetHeaderValueCaseInsensitive(request, Http::ACCEPT_HEADER).value_or(""); bool hasSse = acceptHeader.find(Http::CONTENT_TYPE_SSE) != std::string::npos; if (!hasSse) { diff --git a/MCP/cpp-sdk/src/shared/jsonrpc.cpp b/MCP/cpp-sdk/src/shared/jsonrpc.cpp index 410d43f..78d44eb 100644 --- a/MCP/cpp-sdk/src/shared/jsonrpc.cpp +++ b/MCP/cpp-sdk/src/shared/jsonrpc.cpp @@ -111,6 +111,36 @@ struct adl_serializer { { j["protocolVersion"] = r.protocolVersion; j["capabilities"] = json::object(); + if (r.capabilities.tools.has_value()) { + json toolsObj = json::object(); + if (r.capabilities.tools->listChanged.has_value()) { + toolsObj["listChanged"] = r.capabilities.tools->listChanged.value(); + } + j["capabilities"]["tools"] = std::move(toolsObj); + } + if (r.capabilities.prompts.has_value()) { + json promptsObj = json::object(); + if (r.capabilities.prompts->listChanged.has_value()) { + promptsObj["listChanged"] = r.capabilities.prompts->listChanged.value(); + } + j["capabilities"]["prompts"] = std::move(promptsObj); + } + if (r.capabilities.resources.has_value()) { + json resourcesObj = json::object(); + if (r.capabilities.resources->subscribe.has_value()) { + resourcesObj["subscribe"] = r.capabilities.resources->subscribe.value(); + } + if (r.capabilities.resources->listChanged.has_value()) { + resourcesObj["listChanged"] = r.capabilities.resources->listChanged.value(); + } + j["capabilities"]["resources"] = std::move(resourcesObj); + } + if (r.capabilities.logging.has_value()) { + j["capabilities"]["logging"] = json::object(); + } + if (r.capabilities.experimental.has_value()) { + j["capabilities"]["experimental"] = r.capabilities.experimental.value(); + } j["serverInfo"] = { {"name", r.serverInfo.name}, {"version", r.serverInfo.version}, @@ -131,6 +161,48 @@ struct adl_serializer { r.protocolVersion = Mcp::DEFAULT_PROTOCOL_VERSION; } r.capabilities = Mcp::ServerCapabilities{}; + if (j.contains("capabilities") && j.at("capabilities").is_object()) { + const auto& caps = j.at("capabilities"); + + if (caps.contains("tools") && caps.at("tools").is_object()) { + Mcp::ToolsCapabilities toolsCaps{}; + const auto& toolsObj = caps.at("tools"); + if (toolsObj.contains("listChanged") && toolsObj.at("listChanged").is_boolean()) { + toolsCaps.listChanged = toolsObj.at("listChanged").get(); + } + r.capabilities.tools = toolsCaps; + } + + if (caps.contains("prompts") && caps.at("prompts").is_object()) { + Mcp::PromptsCapabilities promptsCaps{}; + const auto& promptsObj = caps.at("prompts"); + if (promptsObj.contains("listChanged") && promptsObj.at("listChanged").is_boolean()) { + promptsCaps.listChanged = promptsObj.at("listChanged").get(); + } + r.capabilities.prompts = promptsCaps; + } + + if (caps.contains("resources") && caps.at("resources").is_object()) { + Mcp::ResourcesCapabilities resourcesCaps{}; + const auto& resourcesObj = caps.at("resources"); + if (resourcesObj.contains("subscribe") && resourcesObj.at("subscribe").is_boolean()) { + resourcesCaps.subscribe = resourcesObj.at("subscribe").get(); + } + if (resourcesObj.contains("listChanged") && resourcesObj.at("listChanged").is_boolean()) { + resourcesCaps.listChanged = resourcesObj.at("listChanged").get(); + } + r.capabilities.resources = resourcesCaps; + } + + if (caps.contains("logging") && caps.at("logging").is_object()) { + r.capabilities.logging = Mcp::LoggingCapabilities{}; + } + + if (caps.contains("experimental") && caps.at("experimental").is_object()) { + r.capabilities.experimental = + caps.at("experimental").get>(); + } + } if (j.contains("serverInfo")) { const auto& si = j.at("serverInfo"); r.serverInfo.name = si.value("name", std::string{});