diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index c37a2e6dda1..240db25cb63 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -103,6 +103,7 @@ elseif(CMAKE_JS_VERSION) else() add_subdirectory(cli) add_subdirectory(bench) + add_subdirectory(mcp) add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) diff --git a/examples/mcp/CMakeLists.txt b/examples/mcp/CMakeLists.txt new file mode 100644 index 00000000000..92e28b51dfd --- /dev/null +++ b/examples/mcp/CMakeLists.txt @@ -0,0 +1,15 @@ +set(TARGET whisper-mcp-server) +add_executable(${TARGET} mcp-server.cpp stdio-transport.cpp mcp-handler.cpp) +target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +set(DEMO_TARGET mcp-demo) +add_executable(${DEMO_TARGET} mcp-demo.cpp stdio-client.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common json_cpp whisper ${CMAKE_THREAD_LIBS_INIT}) + +# mcp_client only needs json_cpp and threading not whisper. +target_link_libraries(${DEMO_TARGET} PRIVATE json_cpp ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} ${DEMO_TARGET} RUNTIME) diff --git a/examples/mcp/README.md b/examples/mcp/README.md new file mode 100644 index 00000000000..6441e4a61b2 --- /dev/null +++ b/examples/mcp/README.md @@ -0,0 +1,59 @@ +# whisper.cpp/examples/mcp +This directory contains an example of using the Model Context Protocol (MCP) with `whisper.cpp`. The transport +used in this example is the simple input/output (stdin/stdout) transport. When using the input/output transport, +the client is responsible for starting the server as a child process and the communication is done by reading +and writing data to stardard in/out. + +## Usage +The stdio client demo can be run using the following command: +``` +./build/bin/mcp-demo +``` +This will initalize the server using the [initialization] lifecycle phase. 
+Following that the client will send a request for the list of tools ([tools/list]) that the server supports. +It will then send a request to transcribe an audio file. + + +### Claude.ai Desktop integration +The Whisper.cpp MCP server can be integrated with the Claude.ai Desktop application. + +This requires adding a MCP server configuration to the Claude.ai Desktop: +```console +$ cat ~/Library/Application\ Support/Claude/claude_desktop_config.json +{ + "mcpServers": { + "whisper": { + "command": "/Users/danbev/work/ai/whisper.cpp/build/bin/whisper-mcp-server", + "args": [ + "--model", + "/Users/danbev/work/ai/whisper.cpp/models/ggml-base.en.bin" + ] + } + } +} +``` +Update the above paths to match your local system. And then restart the Claude.ai Desktop application. + +After that, clicking on "Connect apps" should show the following: + +![Claude.ai MCP integration screen](images/integration.png) + +And clicking on `[...]` should show the tools available: + +![Claude.ai MCP tools screen](images/tools.png) + +We should then be able to transribe an audio file by using a prompt like this: +```console +Can you transcribe the audio file at /Users/danbev/work/ai/whisper.cpp/samples/jfk.wav? 
+``` +And this will then prompt for accepting to run the transcription tool: + +![Claude.ai MCP accept screen](images/transcribe-accept.png) + +And this should result in a successful transcription: + +![Claude.ai MCP transcription result](images/transcribe-screenshot.png) + + +[initialization]: https://modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle#initialization +[tools/list]: https://modelcontextprotocol.io/specification/2025-03-26/server/tools#listing-tools diff --git a/examples/mcp/images/integration.png b/examples/mcp/images/integration.png new file mode 100644 index 00000000000..d770a0ebba8 Binary files /dev/null and b/examples/mcp/images/integration.png differ diff --git a/examples/mcp/images/tools.png b/examples/mcp/images/tools.png new file mode 100644 index 00000000000..7fc0da8cd9f Binary files /dev/null and b/examples/mcp/images/tools.png differ diff --git a/examples/mcp/images/transcribe-accept.png b/examples/mcp/images/transcribe-accept.png new file mode 100644 index 00000000000..af8f0bf64e1 Binary files /dev/null and b/examples/mcp/images/transcribe-accept.png differ diff --git a/examples/mcp/images/transcribe-screenshot.png b/examples/mcp/images/transcribe-screenshot.png new file mode 100644 index 00000000000..6ac136bf36a Binary files /dev/null and b/examples/mcp/images/transcribe-screenshot.png differ diff --git a/examples/mcp/mcp-demo.cpp b/examples/mcp/mcp-demo.cpp new file mode 100644 index 00000000000..578ca8345fe --- /dev/null +++ b/examples/mcp/mcp-demo.cpp @@ -0,0 +1,84 @@ +#include "stdio-client.hpp" + +#include +#include + +void pretty_print_json(const json & j) { + std::cout << j.dump(2) << std::endl; +} + +int main(int argc, char ** argv) { + std::string server_command = "build/bin/whisper-mcp-server"; + + if (argc > 1) { + server_command = argv[1]; + } + + printf("Starting MCP Demo\n"); + printf("Server command: %s\n", server_command.c_str()); + + try { + mcp::StdioClient client; + + // Start the server + 
printf("Starting server...\n"); + if (!client.start_server(server_command)) { + fprintf(stderr, "Failed to start server\n"); + return 1; + } + + if (!client.wait_for_server_ready(2000)) { + fprintf(stderr, "Server failed to start within timeout\n"); + return 1; + } + + client.read_server_logs(); + + // Initialize + printf("Initializing...\n"); + json init_response = client.initialize("mcp-demo-client", "1.0.0"); + printf("Initialize response:\n"); + pretty_print_json(init_response); + + if (init_response.contains("error")) { + fprintf(stderr, "Initialization failed!\n"); + return 1; + } + + // Send initialized notification + printf("Sending initialized notification...\n"); + client.send_initialized(); + client.read_server_logs(); + + // List tools + printf("Listing tools...\n"); + json tools_response = client.list_tools(); + printf("Tools list response:\n"); + pretty_print_json(tools_response); + + // Call transcribe tool + printf("Calling transcribe tool...\n"); + json transcribe_args = { + {"file", "samples/jfk.wav"} + }; + + json transcribe_response = client.call_tool("transcribe", transcribe_args); + printf("Transcribe response:\n"); + pretty_print_json(transcribe_response); + + // Call model info tool + printf("Calling model info tool...\n"); + json model_info_response = client.call_tool("model_info", json::object()); + printf("Model info response:\n"); + pretty_print_json(model_info_response); + + // Final logs + printf("Final server logs:\n"); + client.read_server_logs(); + } catch (const std::exception & e) { + fprintf(stderr, "Exception: %s\n", e.what()); + return 1; + } + + return 0; +} diff --git a/examples/mcp/mcp-handler.cpp b/examples/mcp/mcp-handler.cpp new file mode 100644 index 00000000000..622e94b1663 --- /dev/null +++ b/examples/mcp/mcp-handler.cpp @@ -0,0 +1,331 @@ +#include "mcp-handler.hpp" +#include "common.h" +#include "common-whisper.h" +#include +#include + +namespace mcp { + +// JSON-RPC 2.0 error codes +enum class MCPError : int { + // 
Standard JSON-RPC errors + PARSE_ERROR = -32700, + INVALID_REQUEST = -32600, + METHOD_NOT_FOUND = -32601, + INVALID_PARAMS = -32602, + INTERNAL_ERROR = -32603, + + // MCP-specific errors + MODEL_NOT_LOADED = 1001, + AUDIO_FILE_ERROR = 1002, + TRANSCRIPTION_FAILED = 1003 +}; + +Handler::Handler(Transport * transport, + const struct mcp_params & mparams, + const struct whisper_params & wparams, + const std::string & model_path) + : transport(transport), ctx(nullptr), model_path(model_path), model_loaded(false) + ,mparams(mparams), wparams(wparams) { + if (!transport) { + throw std::invalid_argument("Transport cannot be null"); + } +} + +Handler::~Handler() { + if (ctx) { + whisper_free(ctx); + } +} + +bool Handler::handle_message(const json & request) { + // Validate JSON-RPC 2.0 format + if (!request.contains("jsonrpc") || request["jsonrpc"] != "2.0") { + fprintf(stderr, "Invalid JSON-RPC format\n"); + return false; + } + + // Extract request ID (can be null for notifications) + json id = nullptr; + if (request.contains("id")) { + id = request["id"]; + } + + // Extract method + std::string method = request.value("method", ""); + if (method.empty()) { + send_error(id, static_cast(MCPError::INVALID_REQUEST), "Invalid request: missing method"); + return true; + } + + fprintf(stderr, "Processing method: %s\n", method.c_str()); + + // TODO: add missing methods from specification + try { + if (method == "initialize") { + handle_initialize(id, request.value("params", json::object())); + } + else if (method == "tools/list") { + handle_list_tools(id); + } + else if (method == "tools/call") { + handle_tool_call(id, request.value("params", json::object())); + } + else if (method == "notifications/initialized") { + handle_notification_initialized(); + } + else { + send_error(id, static_cast(MCPError::METHOD_NOT_FOUND), "Method not found: " + method); + } + return true; + } catch (const std::exception & e) { + fprintf(stderr, "Exception in message handler: %s\n", e.what()); + 
send_error(id, static_cast(MCPError::INTERNAL_ERROR), "Internal error: " + std::string(e.what())); + return true; + } +} + +void Handler::handle_initialize(const json & id, const json & params) { + fprintf(stderr, "Initializing whisper server with model: %s\n", model_path.c_str()); + + if (!load_model()) { + send_error(id, static_cast(MCPError::INTERNAL_ERROR), "Failed to load whisper model"); + + return; + } + + json result = { + {"protocolVersion", "2024-11-05"}, + {"capabilities", { + {"tools", json::object()} + }}, + {"serverInfo", { + {"name", "whisper-mcp-server"}, + {"version", "1.0.0"} + }} + }; + + send_result(id, result); +} + +void Handler::handle_list_tools(const json & id) { + fprintf(stderr, "Listing available tools\n"); + + json result = { + {"tools", json::array({ + { + {"name", "transcribe"}, + {"description", "Transcribe audio file using whisper.cpp"}, + {"inputSchema", { + {"type", "object"}, + {"properties", { + {"file", { + {"type", "string"}, + {"description", "Path to audio file"} + }}, + {"language", { + {"type", "string"}, + {"description", "Language code (optional, auto-detect if not specified)"}, + {"default", "auto"} + }}, + {"translate", { + {"type", "boolean"}, + {"description", "Translate to English"}, + {"default", false} + }} + }}, + {"required", json::array({"file"})} + }} + }, + { + {"name", "model_info"}, + {"description", "Get information about loaded model"}, + {"inputSchema", { + {"type", "object"}, + {"properties", json::object()} + }} + } + })} + }; + + send_result(id, result); +} + +void Handler::handle_tool_call(const json & id, const json & params) { + if (!params.contains("name")) { + send_error(id, static_cast(MCPError::INVALID_PARAMS), "Missing required parameter: name"); + return; + } + + std::string tool_name = params["name"]; + json arguments = params.value("arguments", json::object()); + + if (tool_name == "transcribe") { + json result = create_transcribe_result(arguments); + send_result(id, result); + } + else if 
(tool_name == "model_info") { + json result = create_model_info_result(); + send_result(id, result); + } + else { + send_error(id, static_cast(MCPError::METHOD_NOT_FOUND), "Unknown tool: " + tool_name); + } +} + +void Handler::handle_notification_initialized() { + fprintf(stderr, "Client initialization completed\n"); +} + +void Handler::send_result(const json & id, const json & result) { + json response = { + {"jsonrpc", "2.0"}, + {"result", result} + }; + + if (!id.is_null()) { + response["id"] = id; + } + + transport->send_response(response); +} + +void Handler::send_error(const json & id, int code, const std::string & message) { + json response = { + {"jsonrpc", "2.0"}, + {"id", id}, + {"error", { + {"code", code}, + {"message", message} + }} + }; + + transport->send_response(response); +} + +bool Handler::load_model() { + if (model_loaded) { + return true; + } + + fprintf(stderr, "Loading whisper model from: %s\n", model_path.c_str()); + + whisper_context_params cparams = whisper_context_default_params(); + ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); + + if (!ctx) { + fprintf(stderr, "Failed to load model: %s\n", model_path.c_str()); + return false; + } + + model_loaded = true; + fprintf(stderr, "Model loaded successfully!\n"); + return true; +} + +std::string Handler::transcribe_file(const std::string & filepath, + const std::string & language, + bool translate) { + if (!model_loaded) { + throw std::runtime_error("Model not loaded"); + } + + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + if (language != "auto" && whisper_lang_id(language.c_str()) == -1) { + throw std::runtime_error("Unknown language: " + language); + } + + if (language != "auto") { + wparams.language = language.c_str(); + } else { + wparams.language = "auto"; + } + + wparams.translate = translate; + wparams.print_progress = false; + wparams.print_timestamps = false; + + std::vector pcmf32; + if (!load_audio_file(filepath, 
pcmf32)) { + throw std::runtime_error("Failed to load audio file: " + filepath); + } + + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + throw std::runtime_error("Whisper inference failed"); + } + + std::string result; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + result += text; + } + + return result; +} + +bool Handler::load_audio_file(const std::string & fname_inp, std::vector & pcmf32) { + fprintf(stderr, "Loading audio file: %s\n", fname_inp.c_str()); + std::vector> pcmf32s; + + if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, wparams.diarize)) { + fprintf(stderr, "Failed to read audio file: %s\n", fname_inp.c_str()); + return false; + } + + fprintf(stderr, "Successfully loaded %s\n", fname_inp.c_str()); + return true; +} + +json Handler::create_transcribe_result(const json & arguments) { + try { + if (!arguments.contains("file")) { + throw std::runtime_error("Missing required parameter: file"); + } + + std::string file_path = arguments["file"]; + std::string language = arguments.value("language", "auto"); + bool translate = arguments.value("translate", false); + + std::string transcription = transcribe_file(file_path, language, translate); + + return json{ + {"content", json::array({ + { + {"type", "text"}, + {"text", transcription} + } + })} + }; + + } catch (const std::exception & e) { + throw std::runtime_error("Transcription failed: " + std::string(e.what())); + } +} + +json Handler::create_model_info_result() { + if (!model_loaded) { + throw std::runtime_error("No model loaded"); + } + + json model_info = { + {"model_path", model_path}, + {"model_loaded", model_loaded}, + {"vocab_size", whisper_n_vocab(ctx)}, + {"n_text_ctx", whisper_n_text_ctx(ctx)}, + {"n_audio_ctx", whisper_n_audio_ctx(ctx)}, + {"is_multilingual", whisper_is_multilingual(ctx)} + }; + + return json{ + {"content", json::array({ + { + {"type", "text"}, 
+ {"text", "Model Information:\n" + model_info.dump(2)} + } + })} + }; +} + +} // namespace mcp diff --git a/examples/mcp/mcp-handler.hpp b/examples/mcp/mcp-handler.hpp new file mode 100644 index 00000000000..624fb39876b --- /dev/null +++ b/examples/mcp/mcp-handler.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include "mcp-transport.hpp" +#include "mcp-params.hpp" + +#include "whisper.h" +#include +#include + +namespace mcp { + +class Handler { +public: + explicit Handler(mcp::Transport * transport, + const struct mcp_params & mparams, + const struct whisper_params & wparams, + const std::string & model_path); + ~Handler(); + + bool handle_message(const json & request); + +private: + // MCP protocol methods + void handle_initialize(const json & id, const json & params); + void handle_list_tools(const json & id); + void handle_tool_call(const json & id, const json & params); + void handle_notification_initialized(); + + // Response helpers + void send_result(const json & id, const json & result); + void send_error(const json & id, int code, const std::string & message); + + bool load_model(); + std::string transcribe_file(const std::string & filepath, + const std::string & language = "auto", + bool translate = false); + bool load_audio_file(const std::string & fname_inp, std::vector & pcmf32); + + json create_transcribe_result(const json & arguments); + json create_model_info_result(); + + bool model_loaded; + mcp::Transport * transport; + struct whisper_context * ctx; + std::string model_path; + struct mcp_params mparams; + struct whisper_params wparams; +}; + +} // namespace mcp diff --git a/examples/mcp/mcp-params.hpp b/examples/mcp/mcp-params.hpp new file mode 100644 index 00000000000..c55df68a96f --- /dev/null +++ b/examples/mcp/mcp-params.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +struct mcp_params { + bool ffmpeg_converter = false; +}; + +struct whisper_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + 
int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = 2; + int32_t beam_size = -1; + int32_t audio_ctx = 0; + + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float temperature = 0.00f; + float temperature_inc = 0.20f; + float no_speech_thold = 0.6f; + + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool use_gpu = true; + bool flash_attn = false; + bool suppress_nst = false; + bool no_context = false; + + std::string language = "en"; + std::string prompt = ""; + std::string model = "models/ggml-base.en.bin"; + std::string response_format = "json"; + std::string openvino_encode_device = "CPU"; + std::string dtw = ""; +}; diff --git a/examples/mcp/mcp-server.cpp b/examples/mcp/mcp-server.cpp new file mode 100644 index 00000000000..60d29a0322b --- /dev/null +++ b/examples/mcp/mcp-server.cpp @@ -0,0 +1,168 @@ +#include "stdio-transport.hpp" +#include "mcp-handler.hpp" + +#include "common.h" +#include "common-whisper.h" + +#include "whisper.h" +#include "json.hpp" + +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace { + +const std::string json_format = "json"; + +void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const mcp_params & mparams) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] \n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] 
number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? 
"true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + // mcp params + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", mparams.ffmpeg_converter ? "true" : "false"); + + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, "\n"); +} + +bool whisper_params_parse(int argc, char ** argv, whisper_params & params, mcp_params & mparams) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-h" || arg == "--help") { + whisper_print_usage(argc, argv, params, mparams); + exit(0); + } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == 
"--no-fallback") { params.no_fallback = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; } + + // mcp server params + else if ( arg == "--convert") { mparams.ffmpeg_converter = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params, mparams); + exit(0); + } + GGML_UNUSED(mparams); + } + + return true; +} + +void check_ffmpeg_availibility() { + int result = system("ffmpeg -version"); + + if (result == 0) { + printf("ffmpeg is available.\n"); + } else { + printf("ffmpeg is not available.\n"); + exit(0); + } +} + +} // namespace + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + whisper_params wparams; + mcp_params mparams; + + if (whisper_params_parse(argc, argv, wparams, mparams) == false) { + whisper_print_usage(argc, argv, wparams, mparams); + return 1; + } + + if (wparams.language != "auto" && whisper_lang_id(wparams.language.c_str()) == -1) { + fprintf(stderr, "error: unknown language '%s'\n", wparams.language.c_str()); + whisper_print_usage(argc, 
argv, wparams, mparams); + exit(0); + } + + fprintf(stderr, "Whisper MCP Server starting...\n"); + + if (mparams.ffmpeg_converter) { + check_ffmpeg_availibility(); + } + + try { + mcp::StdioTransport transport; + mcp::Handler handler(&transport, mparams, wparams, wparams.model); + + fprintf(stderr, "MCP Server ready, listening on stdin...\n"); + transport.run(&handler); + + fprintf(stderr, "MCP Server shutting down\n"); + + } catch (const std::exception& e) { + fprintf(stderr, "Fatal error: %s\n", e.what()); + return 1; + } + + return 0; +} diff --git a/examples/mcp/mcp-transport.hpp b/examples/mcp/mcp-transport.hpp new file mode 100644 index 00000000000..f1b587414ff --- /dev/null +++ b/examples/mcp/mcp-transport.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "json.hpp" + +using json = nlohmann::ordered_json; + +namespace mcp { + +class Transport { +public: + virtual ~Transport() = default; + virtual void send_response(const json & response) = 0; +}; + +} // namespace mcp diff --git a/examples/mcp/stdio-client.cpp b/examples/mcp/stdio-client.cpp new file mode 100644 index 00000000000..e690c0a875c --- /dev/null +++ b/examples/mcp/stdio-client.cpp @@ -0,0 +1,248 @@ +#include "stdio-client.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mcp { + +StdioClient::StdioClient() : server_pid(-1), server_stdin(nullptr), server_stdout(nullptr), + server_stderr(nullptr) ,request_id_counter(0) , server_running(false) { + stdin_pipe[0] = stdin_pipe[1] = -1; + stdout_pipe[0] = stdout_pipe[1] = -1; + stderr_pipe[0] = stderr_pipe[1] = -1; +} + +StdioClient::~StdioClient() { + cleanup(); +} + +void StdioClient::cleanup() { + if (server_stdin) { + fclose(server_stdin); + server_stdin = nullptr; + } + + if (server_stdout) { + fclose(server_stdout); + server_stdout = nullptr; + } + + if (server_stderr) { + fclose(server_stderr); + server_stderr = nullptr; + } + + if (server_running && server_pid > 0) { + kill(server_pid, 
SIGTERM);
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        int status;
+        if (waitpid(server_pid, &status, WNOHANG) == 0) {
+            kill(server_pid, SIGKILL);
+            waitpid(server_pid, &status, 0);
+        }
+        server_running = false;
+    }
+}
+
+bool StdioClient::start_server(const std::string & server_command, const std::vector<std::string> & args) {
+    if (server_running) {
+        return false; // Already running
+    }
+
+    // Create pipes
+    if (pipe(stdin_pipe) == -1 || pipe(stdout_pipe) == -1 || pipe(stderr_pipe) == -1) {
+        return false;
+    }
+
+    server_pid = fork();
+    if (server_pid == -1) {
+        return false;
+    }
+
+    if (server_pid == 0) {
+        // Child process - become the server
+        dup2(stdin_pipe[0], STDIN_FILENO);
+        dup2(stdout_pipe[1], STDOUT_FILENO);
+        dup2(stderr_pipe[1], STDERR_FILENO);
+
+        // Close all pipe ends
+        close(stdin_pipe[0]); close(stdin_pipe[1]);
+        close(stdout_pipe[0]); close(stdout_pipe[1]);
+        close(stderr_pipe[0]); close(stderr_pipe[1]);
+
+        // Prepare arguments for execvp
+        std::vector<char*> argv;
+        argv.push_back(const_cast<char*>(server_command.c_str()));
+
+        for (const auto& arg : args) {
+            argv.push_back(const_cast<char*>(arg.c_str()));
+        }
+        argv.push_back(nullptr);
+
+        execvp(server_command.c_str(), argv.data());
+        exit(1); // exec failed
+    }
+
+    // Parent process - set up communication
+    close(stdin_pipe[0]);
+    close(stdout_pipe[1]);
+    close(stderr_pipe[1]);
+
+    server_stdin = fdopen(stdin_pipe[1], "w");
+    server_stdout = fdopen(stdout_pipe[0], "r");
+    server_stderr = fdopen(stderr_pipe[0], "r");
+
+    if (!server_stdin || !server_stdout || !server_stderr) {
+        cleanup();
+        return false;
+    }
+
+    server_running = true;
+    return true;
+}
+
+void StdioClient::stop_server() {
+    cleanup();
+}
+
+json StdioClient::send_request(const json & request) {
+    if (!server_running) {
+        throw std::runtime_error("Server is not running");
+    }
+
+    std::string request_str = request.dump() + "\n";
+
+    if (fputs(request_str.c_str(), server_stdin) == EOF) {
+        throw std::runtime_error("Failed to send request to server");
+    }
+    fflush(server_stdin);
+
+    // For notifications (no id), don't wait for response
+    if (!request.contains("id")) {
+        return json{};
+    }
+
+    // Read response
+    char buffer[4096];
+    if (fgets(buffer, sizeof(buffer), server_stdout) == nullptr) {
+        throw std::runtime_error("Failed to read response from server");
+    }
+
+    std::string response_str(buffer);
+    if (!response_str.empty() && response_str.back() == '\n') {
+        response_str.pop_back();
+    }
+
+    return json::parse(response_str);
+}
+
+void StdioClient::read_server_logs() {
+    int flags = fcntl(fileno(server_stderr), F_GETFL, 0);
+    fcntl(fileno(server_stderr), F_SETFL, flags | O_NONBLOCK);
+
+    char buffer[1024];
+    while (fgets(buffer, sizeof(buffer), server_stderr) != nullptr) {
+        std::cout << "[SERVER LOG] " << buffer;
+    }
+
+    fcntl(fileno(server_stderr), F_SETFL, flags);
+}
+
+json StdioClient::initialize(const std::string & client_name, const std::string & client_version) {
+    json request = {
+        {"jsonrpc", "2.0"},
+        {"id", next_request_id()},
+        {"method", "initialize"},
+        {"params", {
+            {"protocolVersion", "2024-11-05"},
+            {"capabilities", {
+                {"tools", json::object()}
+            }},
+            {"clientInfo", {
+                {"name", client_name},
+                {"version", client_version}
+            }}
+        }}
+    };
+
+    return send_request(request);
+}
+
+void StdioClient::send_initialized() {
+    json notification = {
+        {"jsonrpc", "2.0"},
+        {"method", "notifications/initialized"}
+    };
+
+    send_request(notification);
+}
+
+json StdioClient::list_tools() {
+    json request = {
+        {"jsonrpc", "2.0"},
+        {"id", next_request_id()},
+        {"method", "tools/list"}
+    };
+
+    return send_request(request);
+}
+
+json StdioClient::call_tool(const std::string & tool_name, const json & arguments) {
+    json request = {
+        {"jsonrpc", "2.0"},
+        {"id", next_request_id()},
+        {"method", "tools/call"},
+        {"params", {
+            {"name", tool_name},
+            {"arguments", arguments}
+        }}
+    };
+
+    return send_request(request);
+}
+
+int StdioClient::next_request_id() {
+    return ++request_id_counter;
+}
+
+bool StdioClient::wait_for_server_ready(int timeout_ms) {
+    auto start = std::chrono::steady_clock::now();
+
+    while (std::chrono::duration_cast<std::chrono::milliseconds>(
+               std::chrono::steady_clock::now() - start).count() < timeout_ms) {
+
+        if (server_running) {
+            // Give server a moment to fully start up
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+            return true;
+        }
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    }
+
+    return false;
+}
+
+std::string StdioClient::get_last_server_logs() {
+    std::stringstream logs;
+
+    int flags = fcntl(fileno(server_stderr), F_GETFL, 0);
+    fcntl(fileno(server_stderr), F_SETFL, flags | O_NONBLOCK);
+
+    char buffer[1024];
+    while (fgets(buffer, sizeof(buffer), server_stderr) != nullptr) {
+        logs << buffer;
+    }
+
+    fcntl(fileno(server_stderr), F_SETFL, flags);
+    return logs.str();
+}
+
+} // namespace mcp
diff --git a/examples/mcp/stdio-client.hpp b/examples/mcp/stdio-client.hpp
new file mode 100644
index 00000000000..c7bac9f22ef
--- /dev/null
+++ b/examples/mcp/stdio-client.hpp
@@ -0,0 +1,52 @@
+#pragma once
+
+#include "json.hpp"
+#include <string>
+#include <vector>
+
+using json = nlohmann::json;
+
+namespace mcp {
+
+class StdioClient {
+private:
+    pid_t server_pid;
+    int stdin_pipe[2];
+    int stdout_pipe[2];
+    int stderr_pipe[2];
+    FILE* server_stdin;
+    FILE* server_stdout;
+    FILE* server_stderr;
+    int request_id_counter;
+    bool server_running;
+
+    void cleanup();
+
+public:
+    StdioClient();
+    ~StdioClient();
+
+    bool start_server(const std::string & server_command, const std::vector<std::string> & args = {});
+    void stop_server();
+    bool is_server_running() const {
+        return server_running;
+    }
+
+    // Core MCP communication
+    json send_request(const json& request);
+    void read_server_logs();
+
+    // MCP protocol methods
+    json initialize(const std::string& client_name = "mcp-test-client",
+                    const std::string& client_version = "1.0.0");
+    void send_initialized();
+    json list_tools();
+    json call_tool(const std::string& tool_name, const json& arguments);
+
+    // Utilities
+    int next_request_id();
+    bool wait_for_server_ready(int timeout_ms = 1000);
+    std::string get_last_server_logs();
+};
+
+} // namespace mcp
diff --git a/examples/mcp/stdio-transport.cpp b/examples/mcp/stdio-transport.cpp
new file mode 100644
index 00000000000..9489cc97ed0
--- /dev/null
+++ b/examples/mcp/stdio-transport.cpp
@@ -0,0 +1,35 @@
+#include "stdio-transport.hpp"
+#include "mcp-handler.hpp"
+
+#include <cstdio>
+#include <iostream>
+#include <string>
+
+namespace mcp {
+
+void StdioTransport::send_response(const json & response) {
+    std::cout << response.dump() << std::endl;
+    std::cout.flush();
+}
+
+void StdioTransport::run(Handler * handler) {
+    std::string line;
+    while (std::getline(std::cin, line)) {
+        if (line.empty()) {
+            continue;
+        }
+
+        fprintf(stderr, "Received: %s\n", line.c_str());
+
+        try {
+            json request = json::parse(line);
+            handler->handle_message(request);
+        } catch (const json::parse_error & e) {
+            fprintf(stderr, "JSON parse error: %s\n", e.what());
+        } catch (const std::exception & e) {
+            fprintf(stderr, "Error processing request: %s\n", e.what());
+        }
+    }
+}
+
+} // namespace mcp
diff --git a/examples/mcp/stdio-transport.hpp b/examples/mcp/stdio-transport.hpp
new file mode 100644
index 00000000000..ea3e090b033
--- /dev/null
+++ b/examples/mcp/stdio-transport.hpp
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "mcp-transport.hpp"
+
+namespace mcp {
+
+class Handler;
+
+class StdioTransport : public Transport {
+public:
+    StdioTransport() = default;
+    ~StdioTransport() = default;
+
+    void send_response(const json & response) override;
+
+    void run(Handler * handler);
+};
+
+} // namespace mcp
diff --git a/examples/mcp/test-server.sh b/examples/mcp/test-server.sh
new file mode 100755
index 00000000000..05f3eefe33d
--- /dev/null
+++ b/examples/mcp/test-server.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -e
+
+## Test initialize request
+echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}' | \
+    build/bin/whisper-mcp-server -m models/ggml-base.en.bin
+
+## Test tools list request
+echo '{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' | \
+    build/bin/whisper-mcp-server -m models/ggml-base.en.bin
+
+## Test transcribe
+#echo '{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"transcribe","arguments":{"file":"samples/jfk.wav","language":"en"}}}' | \
+    #build/bin/whisper-mcp-server -m models/ggml-base.en.bin
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index efa1bbe3fc8..43e41d485ae 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -103,3 +103,11 @@ target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../exa
 target_link_libraries(${VAD_TEST} PRIVATE common)
 add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
 set_tests_properties(${VAD_TARGET} PROPERTIES LABELS "base;en")
+
+# MCP Handler test
+set(MCP_TEST test-mcp)
+add_executable(${MCP_TEST} ${MCP_TEST}.cpp ../examples/mcp/stdio-client.cpp)
+target_include_directories(${MCP_TEST} PRIVATE ../include ../ggml/include ../examples ../examples/mcp)
+target_link_libraries(${MCP_TEST} PRIVATE common)
+add_test(NAME ${MCP_TEST} COMMAND ${MCP_TEST})
+set_tests_properties(${MCP_TEST} PROPERTIES LABELS "unit")
diff --git a/tests/test-mcp.cpp b/tests/test-mcp.cpp
new file mode 100644
index 00000000000..f97d3d60f5c
--- /dev/null
+++ b/tests/test-mcp.cpp
@@ -0,0 +1,52 @@
+#include "stdio-client.hpp"
+
+#include "whisper.h"
+#include "common-whisper.h"
+
+#include <cassert>
+
+template <typename T>
+void assert_json_equals(const json & j, const std::string & key, const T & expected) {
+    assert(j.contains(key));
+    assert(j.at(key) == expected);
+}
+
+void assert_initialized(const json & response) {
+    assert_json_equals(response, "id", 1);
+    assert_json_equals(response, "jsonrpc", "2.0");
+
+    json result = response.at("result");
+
+    json cap = result.at("capabilities");
+    assert(cap.at("tools").is_object());
+
+    assert_json_equals(result, "protocolVersion", "2024-11-05");
+
+    json server_info = result.at("serverInfo");
+    assert_json_equals(server_info, "name", "whisper-mcp-server");
+    assert_json_equals(server_info, "version", "1.0.0");
+}
+
+int main() {
+    std::string server_bin = "../../build/bin/whisper-mcp-server";
+    std::vector<std::string> args = {
+        "--model", "../../models/ggml-base.en.bin"
+    };
+    mcp::StdioClient client;
+
+    // Start server
+    assert(client.start_server(server_bin, args));
+    assert(client.wait_for_server_ready(2000));
+    assert(client.is_server_running());
+
+
+    // Send initialize request
+    assert_initialized(client.initialize("mcp-test-client", "1.0.0"));
+    // Send initialized notification
+    client.send_initialized();
+
+    // Read logs for debugging
+    client.read_server_logs();
+
+    return 0;
+}