Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions MCP/cpp-sdk/src/server/mcp_server_implement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,15 @@ void McpServerImplement::AddTool(const std::string& name, ToolFunc fn, AddToolOp
ServerTool tool(name, fn, params.title, params.description, params.inputSchema, params.outputSchema,
params.structuredOutput, params.annotations, params.icons);
toolManager_.AddTool(tool);
RefreshServerCapabilities();
}

void McpServerImplement::RemoveTool(const std::string& name)
{
CheckServerState();

toolManager_.RemoveTool(name);
RefreshServerCapabilities();
}

void McpServerImplement::AddPrompt(const std::string& name, RenderPromptFunc handler, AddPromptOptionalParams params)
Expand All @@ -436,13 +438,15 @@ void McpServerImplement::AddPrompt(const std::string& name, RenderPromptFunc han
}

promptManager_.AddPrompt(prompt, handler);
RefreshServerCapabilities();
}

void McpServerImplement::RemovePrompt(const std::string& name)
{
CheckServerState();

promptManager_.RemovePrompt(name);
RefreshServerCapabilities();
}

void McpServerImplement::AddResource(const std::string& uri, const std::string& name, ReadResourceFunc readFunc,
Expand Down Expand Up @@ -473,13 +477,15 @@ void McpServerImplement::AddResource(const std::string& uri, const std::string&
}

resourceManager_.AddResource(resource, readFunc);
RefreshServerCapabilities();
}

void McpServerImplement::RemoveResource(const std::string& uri)
{
CheckServerState();

resourceManager_.RemoveResource(uri);
RefreshServerCapabilities();
}

void McpServerImplement::AddResourceTemplate(const std::string& uriTemplate, const std::string& name,
Expand Down Expand Up @@ -507,13 +513,15 @@ void McpServerImplement::AddResourceTemplate(const std::string& uriTemplate, con
}

resourceManager_.AddResourceTemplate(resourceTemplate);
RefreshServerCapabilities();
}

void McpServerImplement::RemoveResourceTemplate(const std::string& uriTemplate)
{
CheckServerState();

resourceManager_.RemoveResourceTemplate(uriTemplate);
RefreshServerCapabilities();
}

bool McpServerImplement::ValidateStreamableHttpConfig(const StreamableHttpServerConfig& config)
Expand Down Expand Up @@ -574,6 +582,7 @@ bool McpServerImplement::InitializeServerManager()
[this](const RequestId& requestId, const Request& request, RequestContext& ctx) {
this->ReceiveIncomingMessages(requestId, request, ctx);
});
serverManager_->SetServerCapabilities(BuildServerCapabilities());
MCP_LOG(MCP_LOG_LEVEL_DEBUG, "ServerManager initialized successfully");
serverManager_->Start();
MCP_LOG(MCP_LOG_LEVEL_DEBUG, "ServerManager start successfully");
Expand All @@ -584,6 +593,39 @@ bool McpServerImplement::InitializeServerManager()
}
}

ServerCapabilities McpServerImplement::BuildServerCapabilities()
{
ServerCapabilities capabilities;

auto toolsResult = toolManager_.ListTools(std::nullopt);
if (!toolsResult.tools.empty()) {
capabilities.tools = ToolsCapabilities{};
}

auto promptsResult = promptManager_.ListPrompts();
if (!promptsResult.prompts.empty()) {
capabilities.prompts = PromptsCapabilities{};
}

auto resourcesResult = resourceManager_.ListResources(std::nullopt);
auto resourceTemplatesResult = resourceManager_.ListResourceTemplates();
if (!resourcesResult.resources.empty() || !resourceTemplatesResult.resourceTemplates.empty()) {
ResourcesCapabilities resourcesCaps{};
resourcesCaps.subscribe = true;
capabilities.resources = resourcesCaps;
}

return capabilities;
}

void McpServerImplement::RefreshServerCapabilities()
{
if (serverManager_ == nullptr) {
return;
}
serverManager_->SetServerCapabilities(BuildServerCapabilities());
}

static bool ValidateFilePathWithRealpath(const std::string& path, const char* what)
{
if (path.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions MCP/cpp-sdk/src/server/mcp_server_implement.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class McpServerImplement : public McpServer {
bool InitializeServerManager();
bool ValidateTlsConfig(const TlsConfig& config);
void CheckServerState() const;
ServerCapabilities BuildServerCapabilities();
void RefreshServerCapabilities();

ServerConfig config_;
StreamableHttpServerConfig streamableConfig_;
Expand Down
20 changes: 20 additions & 0 deletions MCP/cpp-sdk/src/server/server_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ void ServerManager::StdioServerManagerStart()
{
std::shared_ptr<ServerTransport> stdioTransport_ = std::make_shared<StdioServerTransport>();
stdioSession_ = std::make_shared<ServerSession>(stdioTransport_, config_, GenerateSessionId());
stdioSession_->SetServerCapabilities(serverCapabilities_);
// Set incoming request callback if provided
if (requestCallback_) {
stdioSession_->SetIncomingRequestCallback(requestCallback_);
Expand Down Expand Up @@ -285,6 +286,24 @@ void ServerManager::SetIncomingRequestCallback(IncomingRequestCallback callback)
requestCallback_ = callback;
}

void ServerManager::SetServerCapabilities(const ServerCapabilities& capabilities)
{
serverCapabilities_ = capabilities;

if (stdioSession_) {
stdioSession_->SetServerCapabilities(serverCapabilities_);
}

for (auto& sessions : threadSessions_) {
for (auto& [sessionId, session] : sessions) {
(void)sessionId;
if (session) {
session->SetServerCapabilities(serverCapabilities_);
}
}
}
}

std::shared_ptr<ServerSession> ServerManager::NewSession(const std::string& sessionId)
{
// Create transport
Expand All @@ -301,6 +320,7 @@ std::shared_ptr<ServerSession> ServerManager::NewSession(const std::string& sess
MCP_LOG(MCP_LOG_LEVEL_ERROR, "Failed to create session layer.");
throw std::runtime_error("Failed to create session layer.");
}
session->SetServerCapabilities(serverCapabilities_);
// Set incoming request callback if provided
if (requestCallback_) {
session->SetIncomingRequestCallback(requestCallback_);
Expand Down
2 changes: 2 additions & 0 deletions MCP/cpp-sdk/src/server/server_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ServerManager {
void Stop();
std::shared_ptr<ServerSession> GetSession(const std::string& sessionId);
void SetIncomingRequestCallback(IncomingRequestCallback callback);
void SetServerCapabilities(const ServerCapabilities& capabilities);

private:
void StdioServerManagerStart();
Expand Down Expand Up @@ -72,6 +73,7 @@ class ServerManager {
std::vector<std::shared_ptr<NotifyEventArg>> notifyArgs_;
std::shared_ptr<ServerSession> stdioSession_{nullptr};
std::shared_ptr<ServerTransport> stdioTransport_{nullptr};
ServerCapabilities serverCapabilities_;
};
} // namespace Mcp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "streamable_http_server_transport.h"

#include <algorithm>
#include <cctype>
#include <nlohmann/json.hpp>
#include <regex>
#include <sstream>
Expand All @@ -26,6 +27,31 @@ constexpr const char* GET_STREAM_KEY = "_GET_stream";
// Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
static const std::regex SESSION_ID_PATTERN("^[\\x21-\\x7E]+$");

static std::optional<std::string> GetHeaderValueCaseInsensitive(
const Http::HttpRequest& request, const std::string& headerName)
{
auto exactIt = request.headers.find(headerName);
if (exactIt != request.headers.end()) {
return exactIt->second;
}

for (const auto& [key, value] : request.headers) {
if (key.size() != headerName.size()) {
continue;
}

bool same = std::equal(key.begin(), key.end(), headerName.begin(),
[](unsigned char left, unsigned char right) {
return std::tolower(left) == std::tolower(right);
});
if (same) {
return value;
}
}

return std::nullopt;
}

StreamableHttpServerTransport::StreamableHttpServerTransport(const std::string& mcpSessionId,
bool isJsonResponseEnabled)
: mcpSessionId_(mcpSessionId),
Expand All @@ -46,10 +72,9 @@ void StreamableHttpServerTransport::SetCallback(std::shared_ptr<TransportCallbac

std::string StreamableHttpServerTransport::GetSessionId(const HttpRequest& request) const
{
// Extract the session ID from request headers
auto it = request.headers.find(Http::MCP_SESSION_ID_HEADER);
if (it != request.headers.end()) {
return it->second;
auto sessionId = GetHeaderValueCaseInsensitive(request, Http::MCP_SESSION_ID_HEADER);
if (sessionId.has_value()) {
return sessionId.value();
}
return "";
}
Expand All @@ -62,9 +87,9 @@ bool StreamableHttpServerTransport::ValidateProtocolVersion(RequestContext& ctx,

// Get the protocol version from the request headers
std::string protocolVersion{};
auto versionIt = request.headers.find(Http::MCP_PROTOCOL_VERSION_HEADER);
if (versionIt != request.headers.end()) {
protocolVersion = versionIt->second;
auto version = GetHeaderValueCaseInsensitive(request, Http::MCP_PROTOCOL_VERSION_HEADER);
if (version.has_value()) {
protocolVersion = version.value();
} else {
// If no protocol version provided, assume default version
protocolVersion = DEFAULT_PROTOCOL_VERSION;
Expand Down Expand Up @@ -177,11 +202,7 @@ void StreamableHttpServerTransport::HandleRequest(const HttpRequest& request, Re

bool StreamableHttpServerTransport::ValidatePostRequestHeaders(RequestContext& ctx, const HttpRequest& request)
{
auto acceptIt = request.headers.find(Http::ACCEPT_HEADER);
std::string acceptHeader{};
if (acceptIt != request.headers.end()) {
acceptHeader = acceptIt->second;
}
std::string acceptHeader = GetHeaderValueCaseInsensitive(request, Http::ACCEPT_HEADER).value_or("");
bool hasJson = acceptHeader.find(Http::CONTENT_TYPE_JSON) != std::string::npos;
bool hasSse = acceptHeader.find(Http::CONTENT_TYPE_SSE) != std::string::npos;
if (!hasJson || !hasSse) {
Expand All @@ -195,9 +216,8 @@ bool StreamableHttpServerTransport::ValidatePostRequestHeaders(RequestContext& c
ctx.httpSendFunc(response, ctx);
return false;
}
auto contentTypeIt = request.headers.find(Http::CONTENT_TYPE_HEADER);
if (contentTypeIt == request.headers.end() ||
contentTypeIt->second.find(Http::CONTENT_TYPE_JSON) == std::string::npos) {
auto contentType = GetHeaderValueCaseInsensitive(request, Http::CONTENT_TYPE_HEADER);
if (!contentType.has_value() || contentType->find(Http::CONTENT_TYPE_JSON) == std::string::npos) {
HttpResponse response = CreateErrorResponse("Unsupported Media Type: Content-Type must be application/json",
Http::HTTP_STATUS_UNSUPPORTED_MEDIA_TYPE,
static_cast<int>(JsonRpcErrorCode::INVALID_REQUEST));
Expand Down Expand Up @@ -317,11 +337,7 @@ void StreamableHttpServerTransport::HandleGetRequest(RequestContext& ctx, const
}

// Check Accept header
auto acceptIt = request.headers.find("accept");
std::string acceptHeader{};
if (acceptIt != request.headers.end()) {
acceptHeader = acceptIt->second;
}
std::string acceptHeader = GetHeaderValueCaseInsensitive(request, Http::ACCEPT_HEADER).value_or("");

bool hasSse = acceptHeader.find(Http::CONTENT_TYPE_SSE) != std::string::npos;
if (!hasSse) {
Expand Down
72 changes: 72 additions & 0 deletions MCP/cpp-sdk/src/shared/jsonrpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ struct adl_serializer<Mcp::InitializeResult> {
{
j["protocolVersion"] = r.protocolVersion;
j["capabilities"] = json::object();
if (r.capabilities.tools.has_value()) {
json toolsObj = json::object();
if (r.capabilities.tools->listChanged.has_value()) {
toolsObj["listChanged"] = r.capabilities.tools->listChanged.value();
}
j["capabilities"]["tools"] = std::move(toolsObj);
}
if (r.capabilities.prompts.has_value()) {
json promptsObj = json::object();
if (r.capabilities.prompts->listChanged.has_value()) {
promptsObj["listChanged"] = r.capabilities.prompts->listChanged.value();
}
j["capabilities"]["prompts"] = std::move(promptsObj);
}
if (r.capabilities.resources.has_value()) {
json resourcesObj = json::object();
if (r.capabilities.resources->subscribe.has_value()) {
resourcesObj["subscribe"] = r.capabilities.resources->subscribe.value();
}
if (r.capabilities.resources->listChanged.has_value()) {
resourcesObj["listChanged"] = r.capabilities.resources->listChanged.value();
}
j["capabilities"]["resources"] = std::move(resourcesObj);
}
if (r.capabilities.logging.has_value()) {
j["capabilities"]["logging"] = json::object();
}
if (r.capabilities.experimental.has_value()) {
j["capabilities"]["experimental"] = r.capabilities.experimental.value();
}
j["serverInfo"] = {
{"name", r.serverInfo.name},
{"version", r.serverInfo.version},
Expand All @@ -131,6 +161,48 @@ struct adl_serializer<Mcp::InitializeResult> {
r.protocolVersion = Mcp::DEFAULT_PROTOCOL_VERSION;
}
r.capabilities = Mcp::ServerCapabilities{};
if (j.contains("capabilities") && j.at("capabilities").is_object()) {
const auto& caps = j.at("capabilities");

if (caps.contains("tools") && caps.at("tools").is_object()) {
Mcp::ToolsCapabilities toolsCaps{};
const auto& toolsObj = caps.at("tools");
if (toolsObj.contains("listChanged") && toolsObj.at("listChanged").is_boolean()) {
toolsCaps.listChanged = toolsObj.at("listChanged").get<bool>();
}
r.capabilities.tools = toolsCaps;
}

if (caps.contains("prompts") && caps.at("prompts").is_object()) {
Mcp::PromptsCapabilities promptsCaps{};
const auto& promptsObj = caps.at("prompts");
if (promptsObj.contains("listChanged") && promptsObj.at("listChanged").is_boolean()) {
promptsCaps.listChanged = promptsObj.at("listChanged").get<bool>();
}
r.capabilities.prompts = promptsCaps;
}

if (caps.contains("resources") && caps.at("resources").is_object()) {
Mcp::ResourcesCapabilities resourcesCaps{};
const auto& resourcesObj = caps.at("resources");
if (resourcesObj.contains("subscribe") && resourcesObj.at("subscribe").is_boolean()) {
resourcesCaps.subscribe = resourcesObj.at("subscribe").get<bool>();
}
if (resourcesObj.contains("listChanged") && resourcesObj.at("listChanged").is_boolean()) {
resourcesCaps.listChanged = resourcesObj.at("listChanged").get<bool>();
}
r.capabilities.resources = resourcesCaps;
}

if (caps.contains("logging") && caps.at("logging").is_object()) {
r.capabilities.logging = Mcp::LoggingCapabilities{};
}

if (caps.contains("experimental") && caps.at("experimental").is_object()) {
r.capabilities.experimental =
caps.at("experimental").get<std::unordered_map<std::string, std::string>>();
}
}
if (j.contains("serverInfo")) {
const auto& si = j.at("serverInfo");
r.serverInfo.name = si.value("name", std::string{});
Expand Down