diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ae2b61d..0143b65 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -10,8 +10,17 @@ env: CARGO_TERM_COLOR: always jobs: + check_format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Check formatting + run: cargo fmt --all -- --check + clippy_check: runs-on: ubuntu-latest + needs: check_format + steps: - uses: actions/checkout@v4 - name: Run all features diff --git a/CHANGELOG.md b/CHANGELOG.md index 7530432..5f1d306 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/). +## 0.2.8 + +### Fixed +* Fixed JSON-RPC 2.0 protocol violation: server no longer sends a response to client notifications (§4 — notifications must never be replied to) +* Fixed `notifications/cancelled`: request cancellation now actually fires for both stdio and Streamable HTTP transports +* Fixed Streamable HTTP transport silently dropping notifications without processing them + ## 0.2.7 ### Added * JSON-RPC Batch Support for client and server -## Fixed -* Fixed broken Streamable HTTP server implementation \ No newline at end of file +### Fixed +* Fixed broken Streamable HTTP server implementation diff --git a/Cargo.toml b/Cargo.toml index e43db5f..ab7f67e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ exclude = [ ] [workspace.package] -version = "0.2.7" +version = "0.2.8" license = "MIT" edition = "2024" rust-version = "1.90.0" @@ -20,7 +20,7 @@ repository = "https://github.com/RomanEmreis/neva" documentation = "https://docs.rs/neva" [workspace.dependencies] -neva_macros = { path = "neva_macros", version = "0.2.7" } +neva_macros = { path = "neva_macros", version = "0.2.8" } [workspace.lints.rust] unsafe_code = "forbid" diff --git a/README.md b/README.md index 12558c4..d25ed36 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ Blazingly fast and easily configurable [Model Context Protocol (MCP)](https://mo With simple configuration and ergonomic APIs, it provides everything you need to quickly build MCP clients and servers, fully aligned with the latest MCP specification. -[![latest](https://img.shields.io/badge/latest-0.2.7-d8eb34)](https://crates.io/crates/neva) -[![latest](https://img.shields.io/badge/rustc-1.90+-964B00)](https://crates.io/crates/neva) +[![latest](https://img.shields.io/badge/latest-0.2.8-d8eb34)](https://crates.io/crates/neva) +[![latest](https://img.shields.io/badge/rustc-1.90+-964B00)](https://releases.rs/docs/1.90.0/) [![License: MIT](https://img.shields.io/badge/License-MIT-624bd1.svg)](https://github.com/RomanEmreis/neva/blob/main/LICENSE) [![CI](https://github.com/RomanEmreis/neva/actions/workflows/rust.yml/badge.svg)](https://github.com/RomanEmreis/neva/actions/workflows/rust.yml) [![Release](https://github.com/RomanEmreis/neva/actions/workflows/release.yml/badge.svg)](https://github.com/RomanEmreis/neva/actions/workflows/release.yml) @@ -28,7 +28,7 @@ fully aligned with the latest MCP specification. #### Dependencies ```toml [dependencies] -neva = { version = "0.2.7", features = ["client-full"] } +neva = { version = "0.2.8", features = ["client-full"] } tokio = { version = "1", features = ["full"] } ``` @@ -66,7 +66,7 @@ async fn main() -> Result<(), Error> { #### Dependencies ```toml [dependencies] -neva = { version = "0.2.7", features = ["server-full"] } +neva = { version = "0.2.8", features = ["server-full"] } tokio = { version = "1", features = ["full"] } ``` #### Code diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 06b6e63..bef764f 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -4,8 +4,8 @@ //! cargo run -p example-client //! ``` -use std::time::Duration; use neva::prelude::*; +use std::time::Duration; use tracing_subscriber::prelude::*; #[allow(dead_code)] @@ -21,28 +21,28 @@ async fn main() -> Result<(), Error> { tracing_subscriber::registry() .with(tracing_subscriber::fmt::layer()) .init(); - - let mut client = Client::new() - .with_options(|opt| opt - .with_stdio("npx", ["-y", "@modelcontextprotocol/server-everything"]) + + let mut client = Client::new().with_options(|opt| { + opt.with_stdio("npx", ["-y", "@modelcontextprotocol/server-everything"]) .with_roots(|roots| roots.with_list_changed()) .with_timeout(Duration::from_secs(5)) - .with_mcp_version("2025-11-25")); - + .with_mcp_version("2025-11-25") + }); + client.connect().await?; - + // Ping command tracing::info!("--- PING ---"); let resp = client.ping().await?; tracing::info!("{:?}", resp); - + // List tools tracing::info!("--- LIST TOOLS ---"); let tools = client.list_tools(None).await?; for tool in tools.tools.iter() { tracing::info!("- {}", tool.name); } - + // Call a tool tracing::info!("--- CALL TOOL ---"); let args = ("message", "Hello MCP!"); @@ -54,11 +54,9 @@ async fn main() -> Result<(), Error> { let tool = tools.get("get-structured-content").unwrap(); let args = ("location", "New York"); let result = client.call_tool(&tool.name, args).await?; - let weather: Weather = tool - .validate(&result) - .and_then(|res| res.as_json())?; + let weather: Weather = tool.validate(&result).and_then(|res| res.as_json())?; tracing::info!("{:?}", weather); - + // List resources tracing::info!("--- LIST RESOURCES ---"); let resources = client.list_resources(None).await?; @@ -73,7 +71,7 @@ async fn main() -> Result<(), Error> { for res in resources.resources { tracing::info!("- {}: {:?}", res.name, res.uri); } - + // List templates tracing::info!("--- LIST RESOURCE TEMPLATES ---"); let templates = client.list_resource_templates(None).await?; @@ -83,27 +81,26 @@ async fn main() -> Result<(), Error> { // Read resource tracing::info!("--- READ RESOURCE ---"); - let resource = client.read_resource("demo://resource/static/document/architecture.md").await?; + let resource = client + .read_resource("demo://resource/static/document/architecture.md") + .await?; tracing::info!("{:?}", resource.contents); - + // List prompts tracing::info!("--- LIST PROMPTS ---"); let prompts = client.list_prompts(None).await?; for prompt in prompts.prompts { tracing::info!("- {}, {:?}", prompt.name, prompt.args); } - + // Get prompt tracing::info!("--- GET PROMPT ---"); - let args = [ - ("city", "New York"), - ("state", "NY") - ]; + let args = [("city", "New York"), ("state", "NY")]; let prompt = client.get_prompt("args-prompt", args).await?; tracing::info!("{:?}: {:?}", prompt.descr, prompt.messages); - + // This can be uncommented to check the log notifications from MCP server //tokio::time::sleep(Duration::from_secs(60)).await; - + client.disconnect().await } diff --git a/examples/http/src/main.rs b/examples/http/src/main.rs index 3515fca..87f55c1 100644 --- a/examples/http/src/main.rs +++ b/examples/http/src/main.rs @@ -1,12 +1,12 @@ //! Run with: -//! +//! //! ```no_rust //! npx @modelcontextprotocol/inspector -//! +//! //! cargo run -p example-http //! ``` use neva::prelude::*; -use tracing_subscriber::{filter, reload, prelude::*}; +use tracing_subscriber::{filter, prelude::*, reload}; #[tool] async fn remote_tool(name: String, mut ctx: Context) { @@ -22,13 +22,12 @@ async fn main() { .with(filter) .with(notification::fmt::layer()) .init(); - + App::new() - .with_options(|opt| opt - .with_http(|http| http - .bind("127.0.0.1:3000") - .with_endpoint("/mcp")) - .with_logging(handle)) + .with_options(|opt| { + opt.with_http(|http| http.bind("127.0.0.1:3000").with_endpoint("/mcp")) + .with_logging(handle) + }) .run() .await; } diff --git a/examples/large_resources_server/src/main.rs b/examples/large_resources_server/src/main.rs index 8bcb543..f6f75e6 100644 --- a/examples/large_resources_server/src/main.rs +++ b/examples/large_resources_server/src/main.rs @@ -1,7 +1,7 @@ //! Run with: //! //! ```no_rust -//! npx @modelcontextprotocol/inspector +//! npx @modelcontextprotocol/inspector //! //! cargo run -p large_resources_server //! ``` @@ -12,16 +12,14 @@ use neva::prelude::*; async fn resource_meta(name: String) -> ResourceContents { let uri: Uri = format!("file://{name}").into(); let res = get_res_info(uri.clone(), name.clone()); - - ResourceContents::new(uri) - .with_title(name) - .with_json(res) + + ResourceContents::new(uri).with_title(name).with_json(res) } #[resource(uri = "file://{name}")] async fn resource_data(uri: Uri, name: String) -> ResourceContents { // get resource from somewhere - + ResourceContents::new(uri.clone()) .with_title(name.clone()) .with_blob("large file") diff --git a/examples/logging/src/main.rs b/examples/logging/src/main.rs index 8e7dc7e..155c2b9 100644 --- a/examples/logging/src/main.rs +++ b/examples/logging/src/main.rs @@ -5,7 +5,7 @@ //! ``` use neva::prelude::*; -use tracing_subscriber::{filter, reload, prelude::*}; +use tracing_subscriber::{filter, prelude::*, reload}; #[tool] async fn trace_tool() { @@ -18,19 +18,19 @@ async fn trace_tool() { async fn main() { // Configure logging filter let (filter, handle) = reload::Layer::new(filter::LevelFilter::DEBUG); - + // Configure logging tracing_subscriber::registry() - .with(filter) // Specify the default logging level - .with(tracing_subscriber::fmt::layer() - .event_format(notification::NotificationFormatter)) // Specify the MCP notification formatter + .with(filter) // Specify the default logging level + .with(tracing_subscriber::fmt::layer().event_format(notification::NotificationFormatter)) // Specify the MCP notification formatter .init(); - + App::new() - .with_options(|opt| opt - .with_stdio() - .with_mcp_version("2024-11-05") - .with_logging(handle)) + .with_options(|opt| { + opt.with_stdio() + .with_mcp_version("2024-11-05") + .with_logging(handle) + }) .run() .await; } diff --git a/examples/middlewares/src/main.rs b/examples/middlewares/src/main.rs index c6b95e2..73b5732 100644 --- a/examples/middlewares/src/main.rs +++ b/examples/middlewares/src/main.rs @@ -5,7 +5,7 @@ //! ``` use neva::prelude::*; -use tracing_subscriber::{prelude::*, filter, reload}; +use tracing_subscriber::{filter, prelude::*, reload}; #[tool(middleware = [specific_middleware])] async fn greeter(name: String) -> String { @@ -19,20 +19,17 @@ async fn hello_world() -> &'static str { #[resource(uri = "res://{name}")] async fn resource(name: String) -> ResourceContents { - ResourceContents::new(name) - .with_text("Hello, world!") + ResourceContents::new(name).with_text("Hello, world!") } #[prompt(middleware = [specific_middleware])] async fn prompt(topic: String) -> PromptMessage { - PromptMessage::user() - .with(format!("Sample prompt of {topic}")) + PromptMessage::user().with(format!("Sample prompt of {topic}")) } #[prompt] async fn another_prompt(topic: String) -> PromptMessage { - PromptMessage::user() - .with(format!("Another sample prompt of {topic}")) + PromptMessage::user().with(format!("Another sample prompt of {topic}")) } #[handler(command = "ping", middleware = [specific_middleware])] @@ -43,9 +40,9 @@ async fn ping_handler() { async fn logging_middleware(ctx: MwContext, next: Next) -> Response { let id = ctx.id(); tracing::info!("Request start: {id:?}"); - + let resp = next(ctx).await; - + tracing::info!("Request end: {id:?}"); resp } @@ -55,7 +52,7 @@ async fn global_tool_middleware(ctx: MwContext, next: Next) -> Response { next(ctx).await } -// Wraps all requests for the "greeter" tool, "prompt" prompt and ping handler +// Wraps all requests for the "greeter" tool, "prompt" prompt and ping handler async fn specific_middleware(ctx: MwContext, next: Next) -> Response { tracing::info!("Hello from specific middleware"); next(ctx).await @@ -65,16 +62,13 @@ async fn specific_middleware(ctx: MwContext, next: Next) -> Response { async fn main() { let (filter, handle) = reload::Layer::new(filter::LevelFilter::DEBUG); tracing_subscriber::registry() - .with(filter) - .with(tracing_subscriber::fmt::layer() - .event_format(notification::NotificationFormatter)) + .with(filter) + .with(tracing_subscriber::fmt::layer().event_format(notification::NotificationFormatter)) .init(); - + App::new() - .with_options(|opt| opt - .with_stdio() - .with_logging(handle)) - .wrap(logging_middleware) // Wraps all requests that pass through the server + .with_options(|opt| opt.with_stdio().with_logging(handle)) + .wrap(logging_middleware) // Wraps all requests that pass through the server .wrap_tools(global_tool_middleware) // Wraps all tools/call requests that pass through the server .run() .await; diff --git a/examples/pagination/src/main.rs b/examples/pagination/src/main.rs index 4021c0a..dfa5443 100644 --- a/examples/pagination/src/main.rs +++ b/examples/pagination/src/main.rs @@ -2,12 +2,12 @@ //! //! ```no_rust //! npx @modelcontextprotocol/inspector -//! +//! //! cargo run -p example-pagination //! ``` -use std::sync::Arc; use neva::prelude::*; +use std::sync::Arc; #[tool] async fn validate_resource(ctx: Context, uri: Uri) -> Result { @@ -21,15 +21,21 @@ async fn get_resource(name: String, repo: Dc) -> (String, S } #[resources] -async fn list_resources(params: ListResourcesRequestParams, repo: Dc) -> ListResourcesResult { +async fn list_resources( + params: ListResourcesRequestParams, + repo: Dc, +) -> ListResourcesResult { repo.get_resources(params.cursor).await } #[completion] -async fn filter_resources(params: CompleteRequestParams, repo: Dc) -> Completion { +async fn filter_resources( + params: CompleteRequestParams, + repo: Dc, +) -> Completion { let resources = &repo.resources; let filter = params.arg.value; - + let mut matched = Vec::new(); let mut total = 0; @@ -49,10 +55,10 @@ async fn filter_resources(params: CompleteRequestParams, repo: Dc (String, String) { ( format!("res://{name}"), - format!("Some details about resource: {name}") + format!("Some details about resource: {name}"), ) } async fn get_resources(&self, cursor: Option) -> ListResourcesResult { self.resources.paginate(cursor, 10).into() } -} \ No newline at end of file +} diff --git a/examples/progress/src/main.rs b/examples/progress/src/main.rs index caceec8..36f09cb 100644 --- a/examples/progress/src/main.rs +++ b/examples/progress/src/main.rs @@ -1,8 +1,8 @@ //! Run with: //! //! ```no_rust -//! npx @modelcontextprotocol/inspector -//! +//! npx @modelcontextprotocol/inspector +//! //! cargo run -p example-progress //! ``` @@ -12,21 +12,21 @@ use tracing_subscriber::prelude::*; #[tool] async fn long_running_task(token: Meta, command: String) { tracing::info!("Starting {command}"); - + let mut progress = 0; // Simulating a long-running task loop { if progress == 100 { break; } - + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; progress += 5; - + tracing::info!( - target: "progress", - token = %token, - value = progress, + target: "progress", + token = %token, + value = progress, total = 100 ); } @@ -42,9 +42,7 @@ async fn main() { .init(); App::new() - .with_options(|opt| opt - .with_tasks(|tasks| tasks.with_all()) - .with_default_http()) + .with_options(|opt| opt.with_tasks(|tasks| tasks.with_all()).with_default_http()) .run() .await; } diff --git a/examples/protected-server/src/main.rs b/examples/protected-server/src/main.rs index f75b74f..a6d199a 100644 --- a/examples/protected-server/src/main.rs +++ b/examples/protected-server/src/main.rs @@ -6,7 +6,7 @@ //! JWT_SECRET=a-string-secret-at-least-256-bits-long cargo run -p protected-server //! ``` use neva::prelude::*; -use tracing_subscriber::{filter, reload, prelude::*}; +use tracing_subscriber::{filter, prelude::*, reload}; /// A tool that allowed to everyone #[tool] @@ -36,8 +36,7 @@ async fn restricted_resource(uri: Uri, name: String) -> (String, String) { #[tokio::main] async fn main() { - let secret = std::env::var("JWT_SECRET") - .expect("JWT_SECRET must be set"); + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); let (filter, handle) = reload::Layer::new(filter::LevelFilter::DEBUG); tracing_subscriber::registry() @@ -46,14 +45,17 @@ async fn main() { .init(); App::new() - .with_options(|opt| opt - .with_http(|http| http - .with_auth(|auth| auth - .validate_exp(false) - .with_aud(["some aud"]) - .with_iss(["some issuer"]) - .set_decoding_key(secret.as_bytes()))) - .with_logging(handle)) + .with_options(|opt| { + opt.with_http(|http| { + http.with_auth(|auth| { + auth.validate_exp(false) + .with_aud(["some aud"]) + .with_iss(["some issuer"]) + .set_decoding_key(secret.as_bytes()) + }) + }) + .with_logging(handle) + }) .run() .await; } diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 027fffd..86f88d6 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -6,9 +6,9 @@ use neva::prelude::*; -mod tools; -mod resources; mod prompts; +mod resources; +mod tools; #[handler(command = "ping")] async fn ping_handler() { @@ -18,13 +18,13 @@ async fn ping_handler() { #[tokio::main] async fn main() { App::new() - .with_options(|opt| opt - .with_stdio() - .with_mcp_version("2025-06-18") - .with_name("Sample MCP Server") - .with_version("0.1.0.0") - .with_tools(|tools| tools - .with_list_changed())) + .with_options(|opt| { + opt.with_stdio() + .with_mcp_version("2025-06-18") + .with_name("Sample MCP Server") + .with_version("0.1.0.0") + .with_tools(|tools| tools.with_list_changed()) + }) .run() .await; } diff --git a/examples/server/src/prompts.rs b/examples/server/src/prompts.rs index e57f873..eabbd76 100644 --- a/examples/server/src/prompts.rs +++ b/examples/server/src/prompts.rs @@ -14,11 +14,10 @@ use neva::prelude::*; ]"# )] async fn hello_world_code(lang: String) -> PromptMessage { - PromptMessage::user() - .with(format!("Write a hello-world function on {lang}")) + PromptMessage::user().with(format!("Write a hello-world function on {lang}")) } #[prompt(descr = "A prompt that return error")] async fn prompt_err() -> Result { Err(Error::from(ErrorCode::InvalidRequest)) -} \ No newline at end of file +} diff --git a/examples/server/src/resources.rs b/examples/server/src/resources.rs index dee28f9..1fbe292 100644 --- a/examples/server/src/resources.rs +++ b/examples/server/src/resources.rs @@ -27,10 +27,11 @@ async fn list_resources(_params: ListResourcesRequestParams) -> impl Into TextResourceContents { TextResourceContents::new( format!("res://{name}"), - format!("Some details about resource: {name}")) + format!("Some details about resource: {name}"), + ) } #[resource(uri = "res://err/{uri}")] async fn err_resource(_uri: Uri) -> Result { Err(Error::from(ErrorCode::ResourceNotFound)) -} \ No newline at end of file +} diff --git a/examples/server/src/tools.rs b/examples/server/src/tools.rs index 3b220ed..1acfa6e 100644 --- a/examples/server/src/tools.rs +++ b/examples/server/src/tools.rs @@ -54,7 +54,9 @@ async fn say_hello_to(name: String) -> String { }"# )] async fn say_json(arg: Json) -> Json { - let result = Results { message: format!("{}, {}!", arg.say, arg.name) }; + let result = Results { + message: format!("{}, {}!", arg.say, arg.name), + }; result.into() } @@ -70,7 +72,9 @@ async fn say_json(arg: Json) -> Json { }"# )] async fn say_out_json(say: String, name: String) -> Json { - let result = Results { message: format!("{say}, {name}!") }; + let result = Results { + message: format!("{say}, {name}!"), + }; result.into() } @@ -91,7 +95,8 @@ async fn tool_error() -> Result { #[tool(descr = "Resource metadata")] async fn read_resource(ctx: Context, res: Uri) -> Result { let result = ctx.resource(res).await?; - let resource = result.contents + let resource = result + .contents .into_iter() .next() .expect("No resource contents"); diff --git a/examples/subscription/src/main.rs b/examples/subscription/src/main.rs index 9ed3b00..77c57a9 100644 --- a/examples/subscription/src/main.rs +++ b/examples/subscription/src/main.rs @@ -12,23 +12,23 @@ async fn main() -> Result<(), Error> { tracing_subscriber::registry() .with(notification::fmt::layer()) .init(); - - let mut client = Client::new() - .with_options(|opt| opt - .with_stdio("cargo", ["run", "-p", "example-updates"]) - .with_mcp_version("2024-11-05")); - + + let mut client = Client::new().with_options(|opt| { + opt.with_stdio("cargo", ["run", "-p", "example-updates"]) + .with_mcp_version("2024-11-05") + }); + client.connect().await?; client.on_resources_changed(|_| async move { tracing::info!("Resources has been updated"); }); - + client.on_resource_changed(|n| async move { let params = n.params::().unwrap(); - tracing::info!("Resource: {} has been updated", params.uri); + tracing::info!("Resource: {} has been updated", params.uri); }); - + let uri = "res://test_999"; let params = ("uri", uri); let _ = client.call_tool("add_resource", params).await?; diff --git a/examples/updates/src/main.rs b/examples/updates/src/main.rs index 41fd567..bf00ae7 100644 --- a/examples/updates/src/main.rs +++ b/examples/updates/src/main.rs @@ -31,13 +31,11 @@ async fn get_resource(uri: Uri) -> ResourceContents { #[tokio::main] async fn main() { - let mut app = App::new() - .with_options(|opt| opt - .with_stdio() - .with_resources(|res| res - .with_subscribe() - .with_list_changed()) - .with_mcp_version("2024-11-05")); + let mut app = App::new().with_options(|opt| { + opt.with_stdio() + .with_resources(|res| res.with_subscribe().with_list_changed()) + .with_mcp_version("2024-11-05") + }); for i in 0..10 { app.add_resource(format!("res://test_{i}"), format!("test_{i}")); diff --git a/neva/src/app.rs b/neva/src/app.rs index 6c2b167..3bacf9c 100644 --- a/neva/src/app.rs +++ b/neva/src/app.rs @@ -1,58 +1,55 @@ -//! Represents an MCP application +//! Represents an MCP application -use tokio_util::sync::CancellationToken; -use self::{context::{Context, ServerRuntime}, options::{McpOptions, RuntimeMcpOptions}}; -use crate::error::{Error, ErrorCode}; -use crate::transport::{Receiver, Sender, Transport}; -use crate::shared; -use crate::middleware::{MwContext, Next, make_fn::make_mw}; +use self::{ + context::{Context, ServerRuntime}, + options::{McpOptions, RuntimeMcpOptions}, +}; use crate::app::handler::{ - FromHandlerParams, - GenericHandler, - ListResourcesHandler, - CompletionHandler, - HandlerParams, - RequestFunc, - RequestHandler + CompletionHandler, FromHandlerParams, GenericHandler, HandlerParams, ListResourcesHandler, + RequestFunc, RequestHandler, }; +use crate::error::{Error, ErrorCode}; +use crate::middleware::{MwContext, Next, make_fn::make_mw}; +use crate::shared; +use crate::transport::{Receiver, Sender, Transport}; use crate::types::{ - InitializeResult, InitializeRequestParams, IntoResponse, Response, Request, RequestId, Message, - MessageEnvelope, MessageBatch, - CompleteResult, ListToolsRequestParams, CallToolRequestParams, ListToolsResult, CallToolResponse, Tool, ToolHandler, - ListResourceTemplatesRequestParams, ListResourceTemplatesResult, ResourceTemplate, - ListResourcesRequestParams, ListResourcesResult, ReadResourceRequestParams, ReadResourceResult, - SubscribeRequestParams, UnsubscribeRequestParams, Resource, resource::template::ResourceFunc, - ListPromptsRequestParams, ListPromptsResult, GetPromptRequestParams, GetPromptResult, PromptHandler, Prompt, - notification::{Notification, CancelledNotificationParams}, - cursor::Pagination, Uri + CallToolRequestParams, CallToolResponse, CompleteResult, GetPromptRequestParams, + GetPromptResult, InitializeRequestParams, InitializeResult, IntoResponse, + ListPromptsRequestParams, ListPromptsResult, ListResourceTemplatesRequestParams, + ListResourceTemplatesResult, ListResourcesRequestParams, ListResourcesResult, + ListToolsRequestParams, ListToolsResult, Message, MessageBatch, MessageEnvelope, Prompt, + PromptHandler, ReadResourceRequestParams, ReadResourceResult, Request, Resource, + ResourceTemplate, Response, SubscribeRequestParams, Tool, ToolHandler, + UnsubscribeRequestParams, Uri, + cursor::Pagination, + notification::{CancelledNotificationParams, Notification}, + resource::template::ResourceFunc, }; +use tokio_util::sync::CancellationToken; #[cfg(feature = "tasks")] use crate::types::{ - ListTasksRequestParams, ListTasksResult, CancelTaskRequestParams, - GetTaskRequestParams, GetTaskPayloadRequestParams, Task, TaskPayload, + CancelTaskRequestParams, GetTaskPayloadRequestParams, GetTaskRequestParams, + ListTasksRequestParams, ListTasksResult, Task, TaskPayload, }; #[cfg(feature = "tasks")] use context::ToolOrTaskResponse; use std::{ - fmt::{Debug, Formatter}, collections::HashMap, + fmt::{Debug, Formatter}, sync::Arc, }; -#[cfg(feature = "tracing")] -use { - crate::types::notification::SetLevelRequestParams, - tracing::Instrument -}; #[cfg(feature = "di")] use volga_di::{Container, ContainerBuilder}; +#[cfg(feature = "tracing")] +use {crate::types::notification::SetLevelRequestParams, tracing::Instrument}; -pub mod options; +mod collection; pub mod context; pub(crate) mod handler; -mod collection; +pub mod options; const DEFAULT_PAGE_SIZE: usize = 10; @@ -63,11 +60,11 @@ type RequestHandlers = HashMap>; pub struct App { /// MCP server options pub(super) options: McpOptions, - + /// DI container #[cfg(feature = "di")] pub(super) container: ContainerBuilder, - + /// MCP server request handlers handlers: RequestHandlers, } @@ -82,7 +79,7 @@ impl Debug for App { impl App { /// Initializes a new MCP app pub fn new() -> Self { - let mut app = Self { + let mut app = Self { options: McpOptions::default(), handlers: HashMap::new(), #[cfg(feature = "di")] @@ -90,22 +87,31 @@ impl App { }; app.map_handler(crate::commands::INIT, Self::init); - app.map_handler(crate::types::completion::commands::COMPLETE, Self::completion); - + app.map_handler( + crate::types::completion::commands::COMPLETE, + Self::completion, + ); + app.map_handler(crate::types::tool::commands::LIST, Self::tools); app.map_handler(crate::types::tool::commands::CALL, Self::tool); - + app.map_handler(crate::types::resource::commands::LIST, Self::resources); - app.map_handler(crate::types::resource::commands::TEMPLATES_LIST, Self::resource_templates); + app.map_handler( + crate::types::resource::commands::TEMPLATES_LIST, + Self::resource_templates, + ); app.map_handler(crate::types::resource::commands::READ, Self::resource); - app.map_handler(crate::types::resource::commands::SUBSCRIBE, Self::resource_subscribe); - app.map_handler(crate::types::resource::commands::UNSUBSCRIBE, Self::resource_unsubscribe); - + app.map_handler( + crate::types::resource::commands::SUBSCRIBE, + Self::resource_subscribe, + ); + app.map_handler( + crate::types::resource::commands::UNSUBSCRIBE, + Self::resource_unsubscribe, + ); + app.map_handler(crate::types::prompt::commands::LIST, Self::prompts); app.map_handler(crate::types::prompt::commands::GET, Self::prompt); - - app.map_handler(crate::types::notification::commands::INITIALIZED, Self::notifications_init); - app.map_handler(crate::types::notification::commands::CANCELLED, Self::notifications_cancel); #[cfg(feature = "tasks")] { @@ -114,12 +120,15 @@ impl App { app.map_handler(crate::types::task::commands::CANCEL, Self::cancel_task); app.map_handler(crate::types::task::commands::RESULT, Self::task_result); } - + app.map_handler(crate::commands::PING, Self::ping); #[cfg(feature = "tracing")] - app.map_handler(crate::types::notification::commands::SET_LOG_LEVEL, Self::set_log_level); - + app.map_handler( + crate::types::notification::commands::SET_LOG_LEVEL, + Self::set_log_level, + ); + app } @@ -146,7 +155,9 @@ impl App { /// ``` pub fn run_blocking(self) { if tokio::runtime::Handle::try_current().is_ok() { - panic!("`App::run_blocking()` cannot be called inside an existing Tokio runtime. Use `run().await` instead."); + panic!( + "`App::run_blocking()` cannot be called inside an existing Tokio runtime. Use `run().await` instead." + ); } let runtime = match tokio::runtime::Builder::new_multi_thread() @@ -163,45 +174,45 @@ impl App { } }; - runtime.block_on(async { - self.run().await - }); + runtime.block_on(async { self.run().await }); } - + /// Run the MCP server - /// + /// /// # Example /// ```no_run /// use neva::App; - /// + /// /// # #[tokio::main] /// # async fn main() { /// let mut app = App::new(); - /// + /// /// // configure tools, resources, prompts - /// + /// /// app.run().await; /// # } /// ``` pub async fn run(mut self) { #[cfg(feature = "macros")] self.register_methods(); - + #[cfg(feature = "tracing")] - self.options.add_middleware(make_mw(Self::tracing_middleware)); - self.options.add_middleware(make_mw(Self::message_middleware)); - + self.options + .add_middleware(make_mw(Self::tracing_middleware)); + self.options + .add_middleware(make_mw(Self::message_middleware)); + let mut transport = self.options.transport(); let cancellation_token = transport.start(); self.wait_for_shutdown_signal(cancellation_token.clone()); - + let (sender, mut receiver) = transport.split(); let runtime = ServerRuntime::new( - sender, - self.options, + sender, + self.options, self.handlers, #[cfg(feature = "di")] - self.container.build() + self.container.build(), ); loop { tokio::select! { @@ -227,11 +238,11 @@ impl App { } } } - + /// Configure MCP server options pub fn with_options(mut self, config: F) -> Self - where - F: FnOnce(McpOptions) -> McpOptions + where + F: FnOnce(McpOptions) -> McpOptions, { self.options = config(self.options); self @@ -242,12 +253,12 @@ impl App { /// # Example /// ```no_run /// use neva::App; - /// + /// /// # #[tokio::main] /// # async fn main() { /// let mut app = App::new(); /// - /// app.map_handler("ping", || async { + /// app.map_handler("ping", || async { /// "pong" /// }); /// @@ -255,7 +266,7 @@ impl App { /// # } /// ``` pub fn map_handler(&mut self, name: impl Into, handler: F) -> &mut Self - where + where F: GenericHandler, R: IntoResponse + Send + 'static, Args: FromHandlerParams + Send + Sync + 'static, @@ -276,7 +287,7 @@ impl App { /// # async fn main() { /// let mut app = App::new(); /// - /// app.map_tool("hello", |name: String| async move { + /// app.map_tool("hello", |name: String| async move { /// format!("Hello, {name}") /// }); /// @@ -291,9 +302,13 @@ impl App { { self.options.add_tool(Tool::new(name, handler)) } - + /// Adds a known resource - pub fn add_resource, S: Into>(&mut self, uri: U, name: S) -> &mut Resource { + pub fn add_resource, S: Into>( + &mut self, + uri: U, + name: S, + ) -> &mut Resource { let resource = Resource::new(uri, name); self.options.add_resource(resource) } @@ -316,10 +331,10 @@ impl App { /// # } /// ``` pub fn map_resource( - &mut self, - uri: impl Into, - name: impl Into, - handler: F + &mut self, + uri: impl Into, + name: impl Into, + handler: F, ) -> &mut ResourceTemplate where F: GenericHandler, @@ -329,7 +344,7 @@ impl App { { let handler = ResourceFunc::new(handler); let template = ResourceTemplate::new(uri, name); - + self.options.add_resource_template(template, handler) } @@ -384,7 +399,7 @@ impl App { where F: ListResourcesHandler + Clone + Send + Sync + 'static, Args: FromHandlerParams + Send + Sync + 'static, - R: Into + R: Into, { let handler = move |params, args| { let handler = handler.clone(); @@ -415,7 +430,7 @@ impl App { where F: CompletionHandler + Clone + Send + Sync + 'static, Args: FromHandlerParams + Send + Sync + 'static, - R: Into + R: Into, { let handler = move |params, args| { let handler = handler.clone(); @@ -427,8 +442,8 @@ impl App { /// Connection initialization handler async fn init( - options: RuntimeMcpOptions, - _params: InitializeRequestParams + options: RuntimeMcpOptions, + _params: InitializeRequestParams, ) -> Result { Ok(InitializeResult::new(&options)) } @@ -438,13 +453,11 @@ impl App { // return default as its non-optional capability so far CompleteResult::default() } - + /// Tools request handler - async fn tools( - options: RuntimeMcpOptions, - params: ListToolsRequestParams - ) -> ListToolsResult { - options.list_tools() + async fn tools(options: RuntimeMcpOptions, params: ListToolsRequestParams) -> ListToolsResult { + options + .list_tools() .await .paginate(params.cursor, DEFAULT_PAGE_SIZE) .into() @@ -453,9 +466,10 @@ impl App { /// Resources request handler async fn resources( options: RuntimeMcpOptions, - params: ListResourcesRequestParams + params: ListResourcesRequestParams, ) -> ListResourcesResult { - options.list_resources() + options + .list_resources() .await .paginate(params.cursor, DEFAULT_PAGE_SIZE) .into() @@ -463,26 +477,28 @@ impl App { /// Resource templates request handler async fn resource_templates( - options: RuntimeMcpOptions, - params: ListResourceTemplatesRequestParams + options: RuntimeMcpOptions, + params: ListResourceTemplatesRequestParams, ) -> ListResourceTemplatesResult { - options.list_resource_templates() + options + .list_resource_templates() .await .paginate(params.cursor, DEFAULT_PAGE_SIZE) .into() } - + /// Prompts request handler async fn prompts( - options: RuntimeMcpOptions, - params: ListPromptsRequestParams + options: RuntimeMcpOptions, + params: ListPromptsRequestParams, ) -> ListPromptsResult { - options.list_prompts() + options + .list_prompts() .await .paginate(params.cursor, DEFAULT_PAGE_SIZE) .into() } - + /// A tool call request handler #[cfg(not(feature = "tasks"))] async fn tool(ctx: Context, params: CallToolRequestParams) -> Result { @@ -491,60 +507,53 @@ impl App { /// A tool call request handler #[cfg(feature = "tasks")] - async fn tool(ctx: Context, params: CallToolRequestParams) -> Result { + async fn tool( + ctx: Context, + params: CallToolRequestParams, + ) -> Result { ctx.call_tool_with_task(params).await } /// A read resource request handler - async fn resource(ctx: Context, params: ReadResourceRequestParams) -> Result { + async fn resource( + ctx: Context, + params: ReadResourceRequestParams, + ) -> Result { ctx.read_resource(params).await } - + /// A get prompt request handler - async fn prompt(ctx: Context, params: GetPromptRequestParams) -> Result { + async fn prompt( + ctx: Context, + params: GetPromptRequestParams, + ) -> Result { ctx.get_prompt(params).await } /// Ping request handler async fn ping() {} - - /// A notification initialization request handler - async fn notifications_init() {} - - /// A notification cancel request handler - async fn notifications_cancel( - options: RuntimeMcpOptions, - params: CancelledNotificationParams - ) { - options.cancel_request(¶ms.request_id); - } - + /// A subscription to a resource change request handler - async fn resource_subscribe( - mut ctx: Context, - params: SubscribeRequestParams - ) { + async fn resource_subscribe(mut ctx: Context, params: SubscribeRequestParams) { ctx.subscribe_to_resource(params.uri); } /// An unsubscription to from resource change request handler - async fn resource_unsubscribe( - mut ctx: Context, - params: UnsubscribeRequestParams - ) { + async fn resource_unsubscribe(mut ctx: Context, params: UnsubscribeRequestParams) { ctx.unsubscribe_from_resource(¶ms.uri); } - + /// Tasks request handler #[cfg(feature = "tasks")] async fn tasks( options: RuntimeMcpOptions, - params: ListTasksRequestParams + params: ListTasksRequestParams, ) -> Result { - if !options.is_tasks_list_supported() { + if !options.is_tasks_list_supported() { return Err(Error::new( - ErrorCode::InvalidRequest, - "Server does not support support tasks/list requests.")); + ErrorCode::InvalidRequest, + "Server does not support support tasks/list requests.", + )); } Ok(options .list_tasks() @@ -556,23 +565,21 @@ impl App { #[cfg(feature = "tasks")] async fn cancel_task( options: RuntimeMcpOptions, - params: CancelTaskRequestParams + params: CancelTaskRequestParams, ) -> Result { if options.is_tasks_cancellation_supported() { options.cancel_task(¶ms.id) } else { Err(Error::new( ErrorCode::InvalidRequest, - "Server does not support support tasks/cancel requests.")) + "Server does not support support tasks/cancel requests.", + )) } } /// A task status retrieval request handler #[cfg(feature = "tasks")] - async fn task( - options: RuntimeMcpOptions, - params: GetTaskRequestParams - ) -> Result { + async fn task(options: RuntimeMcpOptions, params: GetTaskRequestParams) -> Result { options.get_task_status(¶ms.id) } @@ -580,32 +587,32 @@ impl App { #[cfg(feature = "tasks")] async fn task_result( options: RuntimeMcpOptions, - params: GetTaskPayloadRequestParams + params: GetTaskPayloadRequestParams, ) -> Result { options.get_task_result(¶ms.id).await } - + /// Sets the logging level #[cfg(feature = "tracing")] async fn set_log_level( options: RuntimeMcpOptions, - params: SetLevelRequestParams + params: SetLevelRequestParams, ) -> Result<(), Error> { let current_level = options.log_level(); tracing::debug!( - logger = "neva", - "Logging level has been changed from {:?} to {:?}", current_level, params.level + logger = "neva", + "Logging level has been changed from {:?} to {:?}", + current_level, + params.level ); - + options.set_log_level(params.level) } - + #[cfg(feature = "tracing")] async fn tracing_middleware(ctx: MwContext, next: Next) -> Response { let span = create_tracing_span(ctx.session_id().cloned()); - next(ctx) - .instrument(span) - .await + next(ctx).instrument(span).await } #[inline] @@ -614,8 +621,8 @@ impl App { } async fn execute_batch(batch: MessageBatch, runtime: ServerRuntime) { - use futures_util::future::join_all; use crate::transport::TransportProtoSender; + use futures_util::future::join_all; // Capture the incoming batch's correlation and HTTP-context fields. // `id` + `session_id` are needed so the response batch can be routed @@ -681,7 +688,7 @@ impl App { .await; } MessageEnvelope::Notification(notification) => { - Self::handle_notification(notification).await; + Self::handle_notification(notification, runtime.clone()).await; } MessageEnvelope::Response(mut resp) => { // Apply the batch's session context so that @@ -748,7 +755,11 @@ impl App { Err(_err) => { // Unreachable in practice: envelopes are non-empty above. #[cfg(feature = "tracing")] - tracing::error!(logger = "neva", "Failed to construct batch response: {:?}", _err); + tracing::error!( + logger = "neva", + "Failed to construct batch response: {:?}", + _err + ); return; } }; @@ -765,60 +776,69 @@ impl App { } async fn message_middleware(ctx: MwContext, _: Next) -> Response { - let MwContext { - msg, + let MwContext { + msg, runtime, #[cfg(feature = "di")] - scope + scope, } = ctx; let id = msg.id(); let mut sender = runtime.sender(); - - let resp = Self::handle_message( - msg, + + if let Some(resp) = Self::handle_message( + msg, runtime, #[cfg(feature = "di")] - scope - ).await; - - if let Err(_err) = sender.send(resp.into()).await { + scope, + ) + .await + && let Err(_err) = sender.send(resp.into()).await + { #[cfg(feature = "tracing")] tracing::error!( - logger = "neva", - error = format!("Error sending response: {:?}", _err)); + logger = "neva", + error = format!("Error sending response: {:?}", _err) + ); } - + Response::empty(id) } - + #[inline] async fn handle_message( - msg: Message, + msg: Message, runtime: ServerRuntime, - #[cfg(feature = "di")] - scope: Container - ) -> Response { + #[cfg(feature = "di")] scope: Container, + ) -> Option { match msg { - Message::Request(req) => Self::handle_request( - req, - runtime, - #[cfg(feature = "di")] - scope - ).await, - Message::Response(resp) => Self::handle_response(resp, runtime).await, - Message::Notification(notification) => Self::handle_notification(notification).await, + Message::Request(req) => Some( + Self::handle_request( + req, + runtime, + #[cfg(feature = "di")] + scope, + ) + .await, + ), + Message::Response(resp) => Some(Self::handle_response(resp, runtime).await), + Message::Notification(notification) => { + // JSON-RPC 2.0 §4: notifications must never receive a response. + Self::handle_notification(notification, runtime).await; + None + } Message::Batch(_) => { // Batches are dispatched via execute_batch before reaching handle_message - unreachable!("Message::Batch should be intercepted in App::run before handle_message") + unreachable!( + "Message::Batch should be intercepted in App::run before handle_message" + ) } } } - + async fn handle_request( - req: Request, + req: Request, runtime: ServerRuntime, - #[cfg(feature = "di")] - scope: Container + #[cfg(feature = "di")] scope: Container, ) -> Response { #[cfg(feature = "http-server")] let mut req = req; @@ -828,19 +848,17 @@ impl App { #[cfg(not(feature = "http-server"))] let context = runtime.context(session_id); - + #[cfg(feature = "http-server")] let context = { let headers = std::mem::take(&mut req.headers); - let claims = req.claims - .take() - .map(|c| *c); + let claims = req.claims.take().map(|c| *c); runtime.context(session_id, headers, claims) }; - + #[cfg(feature = "di")] let context = context.with_scope(scope); - + let options = runtime.options(); let handlers = runtime.request_handlers(); let token = options.track_request(&full_id); @@ -849,18 +867,18 @@ impl App { tracing::trace!(logger = "neva", "Received: {:?}", req); let resp = if let Some(handler) = handlers.get(&req.method) { tokio::select! { - resp = handler.call(HandlerParams::Request(context, req)) => { - options.complete_request(&full_id); - resp - } - _ = token.cancelled() => { - #[cfg(feature = "tracing")] - tracing::debug!( - logger = "neva", - "The request with ID: {} has been cancelled", full_id); - Err(Error::from(ErrorCode::RequestCancelled)) - } + resp = handler.call(HandlerParams::Request(context, req)) => { + options.complete_request(&full_id); + resp + } + _ = token.cancelled() => { + #[cfg(feature = "tracing")] + tracing::debug!( + logger = "neva", + "The request with ID: {} has been cancelled", full_id); + Err(Error::from(ErrorCode::RequestCancelled)) } + } } else { Err(Error::from(ErrorCode::MethodNotFound)) }; @@ -871,16 +889,12 @@ impl App { } resp } - + async fn handle_response(resp: Response, runtime: ServerRuntime) -> Response { let resp_id = resp.id().clone(); - let session_id = resp - .session_id() - .cloned(); - - runtime - .pending_requests() - .complete(resp); + let session_id = resp.session_id().cloned(); + + runtime.pending_requests().complete(resp); let mut resp = Response::empty(resp_id); if let Some(session_id) = session_id { @@ -888,14 +902,24 @@ impl App { } resp } - + #[inline] - async fn handle_notification(notification: Notification) -> Response { - if let crate::types::notification::commands::MESSAGE = notification.method.as_str() { - #[cfg(feature = "tracing")] - notification.write(); + async fn handle_notification(notification: Notification, runtime: ServerRuntime) { + match notification.method.as_str() { + crate::types::notification::commands::CANCELLED => { + if let Some(params) = notification.params + && let Ok(params) = + serde_json::from_value::(params) + { + runtime.options().cancel_request(¶ms.request_id); + } + } + crate::types::notification::commands::MESSAGE => { + #[cfg(feature = "tracing")] + notification.write(); + } + _ => {} } - Response::empty(RequestId::default()) } #[inline] @@ -925,7 +949,8 @@ mod tests { let batch = MessageBatch::new(vec![ MessageEnvelope::Notification(Notification::new("notifications/foo", None)), MessageEnvelope::Notification(Notification::new("notifications/bar", None)), - ]).expect("non-empty batch must be constructable"); + ]) + .expect("non-empty batch must be constructable"); // Replicate the filter logic from execute_batch: // Request → Some(response slot), Notification/Response → None @@ -953,7 +978,8 @@ mod tests { let batch = MessageBatch::new(vec![ MessageEnvelope::Request(req1), MessageEnvelope::Request(req2), - ]).expect("non-empty batch must be constructable"); + ]) + .expect("non-empty batch must be constructable"); // Replicate the filter: only Request envelopes produce response slots let response_slots: Vec = batch @@ -964,6 +990,10 @@ mod tests { }) .collect(); - assert_eq!(response_slots.len(), 2, "two requests must produce two response slots"); + assert_eq!( + response_slots.len(), + 2, + "two requests must produce two response slots" + ); } -} \ No newline at end of file +} diff --git a/neva/src/app/collection.rs b/neva/src/app/collection.rs index ce176c3..3dabf5e 100644 --- a/neva/src/app/collection.rs +++ b/neva/src/app/collection.rs @@ -1,15 +1,15 @@ //! Represents a generic-collection implementation that can be mutated during runtime +use crate::error::{Error, ErrorCode}; use std::collections::HashMap; use tokio::sync::RwLock; -use crate::error::{Error, ErrorCode}; /// Generic collection with 2 states: /// - [`Collection::Init`] - initialization state can be mutated without blocking /// - [`Collection::Runtime`] - runtime state, the collection can be read by multiple readers and will blocked by only one writer pub(crate) enum Collection { Init(HashMap), - Runtime(RwLock>) + Runtime(RwLock>), } impl Collection { @@ -21,7 +21,7 @@ impl Collection { /// Turns the [`Collection`] into [`Collection::Runtime`] state #[inline] pub(crate) fn into_runtime(self) -> Self { - if let Self::Init(map) = self { + if let Self::Init(map) = self { Self::Runtime(RwLock::new(map)) } else { self @@ -33,29 +33,23 @@ impl Collection { pub(crate) async fn get(&self, key: &str) -> Option { match self { Self::Init(map) => map.get(key).cloned(), - Self::Runtime(lock) => { - lock.read() - .await - .get(key) - .cloned() - } + Self::Runtime(lock) => lock.read().await.get(key).cloned(), } } /// Inserts a key-value pair into this [`Collection`] when it in [`Collection::Runtime`] state. - /// + /// /// For the [`Collection::Init`] state - use the `as_mut().insert()` method. #[inline] pub(crate) async fn insert(&self, key: String, value: T) -> Result<(), Error> { match self { - Self::Init(_) => return Err(Error::new( - ErrorCode::InternalError, - "Attempt to insert a value during runtime when collection is in the init state")), - Self::Runtime(lock) => { - lock.write() - .await - .insert(key, value) + Self::Init(_) => { + return Err(Error::new( + ErrorCode::InternalError, + "Attempt to insert a value during runtime when collection is in the init state", + )); } + Self::Runtime(lock) => lock.write().await.insert(key, value), }; Ok(()) } @@ -66,14 +60,13 @@ impl Collection { #[inline] pub(crate) async fn remove(&self, key: &str) -> Result, Error> { let value = match self { - Self::Init(_) => return Err(Error::new( - ErrorCode::InternalError, - "Attempt to remove a value during runtime when collection is in the init state")), - Self::Runtime(lock) => { - lock.write() - .await - .remove(key) + Self::Init(_) => { + return Err(Error::new( + ErrorCode::InternalError, + "Attempt to remove a value during runtime when collection is in the init state", + )); } + Self::Runtime(lock) => lock.write().await.remove(key), }; Ok(value) } @@ -82,14 +75,8 @@ impl Collection { #[inline] pub(crate) async fn values(&self) -> Vec { match self { - Self::Init(map) => map - .values() - .cloned() - .collect(), - Self::Runtime(lock) => lock.read().await - .values() - .cloned() - .collect() + Self::Init(map) => map.values().cloned().collect(), + Self::Runtime(lock) => lock.read().await.values().cloned().collect(), } } } @@ -114,4 +101,4 @@ impl AsRef> for Collection { unreachable!() } } -} \ No newline at end of file +} diff --git a/neva/src/app/context.rs b/neva/src/app/context.rs index 81263b4..3d98089 100644 --- a/neva/src/app/context.rs +++ b/neva/src/app/context.rs @@ -1,53 +1,53 @@ -//! Server runtime context utilities +//! Server runtime context utilities -use tokio::time::timeout; +use super::{ + handler::RequestHandler, + options::{McpOptions, RuntimeMcpOptions}, +}; use crate::error::{Error, ErrorCode}; use crate::transport::Sender; -use super::{options::{McpOptions, RuntimeMcpOptions}, handler::RequestHandler}; use crate::{ - shared::{IntoArgs, RequestQueue}, - middleware::{MwContext, Next}, - transport::TransportProtoSender, + middleware::{MwContext, Next}, + shared::{IntoArgs, RequestQueue}, + transport::TransportProtoSender, types::{ - Tool, CallToolRequestParams, CallToolResponse, - ToolUse, ToolResult, - Resource, ReadResourceRequestParams, ReadResourceResult, - Prompt, GetPromptRequestParams, GetPromptResult, - RequestId, Request, Response, Uri, - Message, + CallToolRequestParams, CallToolResponse, GetPromptRequestParams, GetPromptResult, Message, + Prompt, ReadResourceRequestParams, ReadResourceResult, Request, RequestId, Resource, + Response, Tool, ToolResult, ToolUse, Uri, + elicitation::{ElicitRequestParams, ElicitResult, ElicitationCompleteParams}, notification::Notification, - root::{ListRootsRequestParams, ListRootsResult}, resource::SubscribeRequestParams, + root::{ListRootsRequestParams, ListRootsResult}, sampling::{CreateMessageRequestParams, CreateMessageResult}, - elicitation::{ElicitRequestParams, ElicitResult, ElicitationCompleteParams} - } + }, }; use std::{ - fmt::{Debug, Formatter}, collections::HashMap, + fmt::{Debug, Formatter}, + sync::Arc, time::Duration, - sync::Arc }; +use tokio::time::timeout; -#[cfg(feature = "http-server")] -use { - crate::transport::http::server::{validate_roles, validate_permissions}, - crate::auth::DefaultClaims, - volga::headers::HeaderMap -}; -#[cfg(feature = "di")] -use volga_di::Container; -#[cfg(feature = "tasks")] -use serde::de::DeserializeOwned; #[cfg(feature = "tasks")] use crate::{ shared::Either, types::{ - Task, TaskPayload, CreateTaskResult, tool::TaskSupport, - ListTasksRequestParams,ListTasksResult, Cursor, - CancelTaskRequestParams, GetTaskPayloadRequestParams, GetTaskRequestParams, + CancelTaskRequestParams, CreateTaskResult, Cursor, GetTaskPayloadRequestParams, + GetTaskRequestParams, ListTasksRequestParams, ListTasksResult, Task, TaskPayload, + tool::TaskSupport, }, }; +#[cfg(feature = "tasks")] +use serde::de::DeserializeOwned; +#[cfg(feature = "di")] +use volga_di::Container; +#[cfg(feature = "http-server")] +use { + crate::auth::DefaultClaims, + crate::transport::http::server::{validate_permissions, validate_roles}, + volga::headers::HeaderMap, +}; #[cfg(feature = "tasks")] pub(crate) type ToolOrTaskResponse = Either; @@ -59,19 +59,19 @@ type RequestHandlers = HashMap>; pub(crate) struct ServerRuntime { /// Represents MCP server options options: RuntimeMcpOptions, - + /// Represents registered request handlers handlers: Arc, - + /// Represents a queue of pending requests pending: RequestQueue, - + /// Represents a sender that depends on selected transport protocol sender: TransportProtoSender, - + /// Global middlewares entrypoint mw_start: Option, - + /// Represents a DI container #[cfg(feature = "di")] pub(crate) container: Container, @@ -82,24 +82,24 @@ pub(crate) struct ServerRuntime { pub struct Context { /// Represents current session id pub session_id: Option, - + /// Represents HTTP headers of the current request #[cfg(feature = "http-server")] pub headers: HeaderMap, - + /// Represents JWT claims of the current request #[cfg(feature = "http-server")] pub(crate) claims: Option, - + /// Represents MCP server options pub(crate) options: RuntimeMcpOptions, - + /// Represents a queue of pending requests pending: RequestQueue, - + /// Represents a sender that depends on selected transport protocol sender: TransportProtoSender, - + /// Represents a timeout for the current request timeout: Duration, @@ -121,11 +121,10 @@ impl Debug for Context { impl ServerRuntime { /// Creates a new server runtime pub(crate) fn new( - sender: TransportProtoSender, + sender: TransportProtoSender, mut options: McpOptions, handlers: RequestHandlers, - #[cfg(feature = "di")] - container: Container + #[cfg(feature = "di")] container: Container, ) -> Self { let middlewares = options.middlewares.take(); Self { @@ -138,14 +137,14 @@ impl ServerRuntime { container, } } - + /// Provides a [`RuntimeMcpOptions`] - pub(crate) fn options(&self) -> RuntimeMcpOptions { + pub(crate) fn options(&self) -> RuntimeMcpOptions { self.options.clone() } /// Provides the current connections sender - pub(crate) fn sender(&self) -> TransportProtoSender { + pub(crate) fn sender(&self) -> TransportProtoSender { self.sender.clone() } @@ -157,12 +156,12 @@ impl ServerRuntime { self.sender = sender; self } - + /// Provides a hash map of registered request handlers - pub(crate) fn request_handlers(&self) -> Arc { + pub(crate) fn request_handlers(&self) -> Arc { self.handlers.clone() } - + /// Creates a new MCP request [`Context`] #[cfg(not(feature = "http-server"))] pub(crate) fn context(&self, session_id: Option) -> Context { @@ -180,10 +179,10 @@ impl ServerRuntime { /// Creates a new MCP request [`Context`] #[cfg(feature = "http-server")] pub(crate) fn context( - &self, - session_id: Option, - headers: HeaderMap, - claims: Option + &self, + session_id: Option, + headers: HeaderMap, + claims: Option, ) -> Context { Context { session_id, @@ -197,12 +196,12 @@ impl ServerRuntime { scope: None, } } - + /// Provides a "queue" of pending requests pub(crate) fn pending_requests(&self) -> &RequestQueue { &self.pending } - + /// Starts the middleware pipeline #[inline] pub(crate) async fn execute(self, msg: Message) { @@ -217,7 +216,7 @@ impl Context { pub async fn tools(&self) -> Vec { self.options.tools.values().await } - + /// Finds a tool by `name` pub async fn find_tool(&self, name: &str) -> Option { self.options.tools.get(name).await @@ -226,20 +225,18 @@ impl Context { /// Returns a list of tools by name. /// If some tools requested in `names` are missing, they won't be in the result list. pub async fn find_tools(&self, names: impl IntoIterator) -> Vec { - futures_util::future::join_all( - names.into_iter() - .map(|name| self.options.tools.get(name))) + futures_util::future::join_all(names.into_iter().map(|name| self.options.tools.get(name))) .await .into_iter() .flatten() .collect() } - - /// Initiates a tool call once a [`ToolUse`] request received from assistant + + /// Initiates a tool call once a [`ToolUse`] request received from assistant /// withing a sampling window. /// /// For multiple [`ToolUse`] requests, use the [`Context::use_tools`] method. - /// + /// /// # Example /// ```no_run /// # #[cfg(feature = "server-macros")] { @@ -254,26 +251,24 @@ impl Context { /// /// # Ok(()) /// } - /// + /// /// #[tool] /// async fn get_weather(city: String) -> String { /// // ... - /// + /// /// format!("Sunny in {city}") /// } /// # } /// ``` pub async fn use_tool(&self, tool: ToolUse) -> ToolResult { let id = tool.id.clone(); - let res = self.clone() - .call_tool(tool.into()) - .await; + let res = self.clone().call_tool(tool.into()).await; match res { Ok(res) => ToolResult::new(id, res), - Err(err) => ToolResult::error(id, err) + Err(err) => ToolResult::error(id, err), } } - + /// Initiates a parallel tool calls for multiple [`ToolUse`] requests. /// /// For a single [`ToolUse`] use the [`Context::use_tool`] method. @@ -297,14 +292,12 @@ impl Context { /// # } /// ``` pub async fn use_tools(&self, tools: I) -> Vec - where - I : IntoIterator + where + I: IntoIterator, { - futures_util::future::join_all( - tools.into_iter().map(|t| self.use_tool(t))) - .await + futures_util::future::join_all(tools.into_iter().map(|t| self.use_tool(t))).await } - + /// Gets the prompt by name /// /// # Example @@ -328,11 +321,7 @@ impl Context { /// } /// # } /// ``` - pub async fn prompt( - &self, - name: N, - args: Args - ) -> Result + pub async fn prompt(&self, name: N, args: Args) -> Result where N: Into, Args: IntoArgs, @@ -340,13 +329,11 @@ impl Context { let params = GetPromptRequestParams { name: name.into(), args: args.into_args(), - meta: None + meta: None, }; - self.clone() - .get_prompt(params) - .await + self.clone().get_prompt(params).await } - + /// Reads a resource contents /// /// # Example @@ -362,94 +349,81 @@ impl Context { /// /// # Ok(()) /// } - /// + /// /// #[resource(uri = "file://{name}")] /// async fn get_doc(name: String) -> TextResourceContents { /// // read the doc - /// - /// # TextResourceContents::new("", "") + /// + /// # TextResourceContents::new("", "") /// } /// # } /// ``` pub async fn resource(&self, uri: impl Into) -> Result { let uri = uri.into(); let params = ReadResourceRequestParams::from(uri); - self.clone() - .read_resource(params) - .await + self.clone().read_resource(params).await } - + /// Adds a new resource and notifies clients pub async fn add_resource(&mut self, res: impl Into) -> Result<(), Error> { let res: Resource = res.into(); - self.options - .resources - .insert(res.name.clone(), res) - .await?; + self.options.resources.insert(res.name.clone(), res).await?; if self.options.is_resource_list_changed_supported() { - self.send_notification( - crate::types::resource::commands::LIST_CHANGED, - None - ).await - } else { + self.send_notification(crate::types::resource::commands::LIST_CHANGED, None) + .await + } else { Ok(()) } } /// Removes a resource and notifies clients - pub async fn remove_resource(&mut self, uri: impl Into) -> Result, Error> { - let removed = self.options - .resources - .remove(&uri.into()) - .await?; + pub async fn remove_resource( + &mut self, + uri: impl Into, + ) -> Result, Error> { + let removed = self.options.resources.remove(&uri.into()).await?; if removed.is_some() && self.options.is_resource_list_changed_supported() { - self.send_notification( - crate::types::resource::commands::LIST_CHANGED, - None - ).await?; + self.send_notification(crate::types::resource::commands::LIST_CHANGED, None) + .await?; } - + Ok(removed) } - + /// Sends a [`Notification`] that the resource with the `uri` has been updated pub async fn resource_updated(&mut self, uri: impl Into) -> Result<(), Error> { - if !self.options.is_resource_subscription_supported() { + if !self.options.is_resource_subscription_supported() { return Err(Error::new( - ErrorCode::MethodNotFound, - "Server does not support sending resource/updated notifications")) + ErrorCode::MethodNotFound, + "Server does not support sending resource/updated notifications", + )); } - + let uri = uri.into(); if self.is_subscribed(&uri) { let params = serde_json::to_value(SubscribeRequestParams::from(uri)).ok(); - self.send_notification(crate::types::resource::commands::UPDATED, params).await - } else { + self.send_notification(crate::types::resource::commands::UPDATED, params) + .await + } else { Ok(()) } } /// Adds a subscription to the resource with the [`Uri`] pub fn subscribe_to_resource(&mut self, uri: impl Into) { - self.options - .resource_subscriptions - .insert(uri.into()); + self.options.resource_subscriptions.insert(uri.into()); } - + /// Removes a subscription to the resource with the [`Uri`] pub fn unsubscribe_from_resource(&mut self, uri: &Uri) { - self.options - .resource_subscriptions - .remove(uri); + self.options.resource_subscriptions.remove(uri); } - + /// Returns `true` if there is a subscription to changes of the resource with the [`Uri`] pub fn is_subscribed(&self, uri: &Uri) -> bool { - self.options - .resource_subscriptions - .contains(uri) + self.options.resource_subscriptions.contains(uri) } /// Adds a new prompt and notifies clients @@ -460,27 +434,23 @@ impl Context { .await?; if self.options.is_prompts_list_changed_supported() { - self.send_notification( - crate::types::prompt::commands::LIST_CHANGED, - None - ).await + self.send_notification(crate::types::prompt::commands::LIST_CHANGED, None) + .await } else { Ok(()) } } /// Removes a prompt and notifies clients - pub async fn remove_prompt(&mut self, name: impl Into) -> Result, Error> { - let removed = self.options - .prompts - .remove(&name.into()) - .await?; + pub async fn remove_prompt( + &mut self, + name: impl Into, + ) -> Result, Error> { + let removed = self.options.prompts.remove(&name.into()).await?; if removed.is_some() && self.options.is_prompts_list_changed_supported() { - self.send_notification( - crate::types::prompt::commands::LIST_CHANGED, - None - ).await?; + self.send_notification(crate::types::prompt::commands::LIST_CHANGED, None) + .await?; } Ok(removed) @@ -488,16 +458,11 @@ impl Context { /// Adds a new prompt and notifies clients pub async fn add_tool(&mut self, tool: Tool) -> Result<(), Error> { - self.options - .tools - .insert(tool.name.clone(), tool) - .await?; + self.options.tools.insert(tool.name.clone(), tool).await?; if self.options.is_tools_list_changed_supported() { - self.send_notification( - crate::types::tool::commands::LIST_CHANGED, - None - ).await + self.send_notification(crate::types::tool::commands::LIST_CHANGED, None) + .await } else { Ok(()) } @@ -505,47 +470,45 @@ impl Context { /// Removes a tool and notifies clients pub async fn remove_tool(&mut self, name: impl Into) -> Result, Error> { - let removed = self.options - .tools - .remove(&name.into()) - .await?; + let removed = self.options.tools.remove(&name.into()).await?; if removed.is_some() && self.options.is_tools_list_changed_supported() { - self.send_notification( - crate::types::tool::commands::LIST_CHANGED, - None - ).await?; + self.send_notification(crate::types::tool::commands::LIST_CHANGED, None) + .await?; } Ok(removed) } - + #[inline] - pub(crate) async fn read_resource(self, params: ReadResourceRequestParams) -> Result { + pub(crate) async fn read_resource( + self, + params: ReadResourceRequestParams, + ) -> Result { let opt = self.options.clone(); match opt.read_resource(¶ms.uri) { Some((handler, args)) => { #[cfg(feature = "http-server")] { - let template = opt.resources_templates - .get(&handler.template) - .await; + let template = opt.resources_templates.get(&handler.template).await; self.validate_claims( template.as_ref().and_then(|t| t.roles.as_deref()), - template.as_ref().and_then(|t| t.permissions.as_deref())) + template.as_ref().and_then(|t| t.permissions.as_deref()), + ) }?; - handler.call(params - .with_args(args) - .with_context(self) - .into() - ).await - }, + handler + .call(params.with_args(args).with_context(self).into()) + .await + } _ => Err(Error::from(ErrorCode::ResourceNotFound)), } } #[inline] - pub(crate) async fn get_prompt(self, params: GetPromptRequestParams) -> Result { + pub(crate) async fn get_prompt( + self, + params: GetPromptRequestParams, + ) -> Result { match self.options.get_prompt(¶ms.name).await { None => Err(Error::new(ErrorCode::InvalidParams, "Prompt not found")), Some(prompt) => { @@ -557,7 +520,10 @@ impl Context { } #[inline] - pub(crate) async fn call_tool(self, params: CallToolRequestParams) -> Result { + pub(crate) async fn call_tool( + self, + params: CallToolRequestParams, + ) -> Result { match self.options.get_tool(¶ms.name).await { None => Err(Error::new(ErrorCode::InvalidParams, "Tool not found")), Some(tool) => { @@ -570,20 +536,23 @@ impl Context { #[inline] #[cfg(feature = "tasks")] - pub(crate) async fn call_tool_with_task(self, params: CallToolRequestParams) -> Result { + pub(crate) async fn call_tool_with_task( + self, + params: CallToolRequestParams, + ) -> Result { match self.options.get_tool(¶ms.name).await { None => Err(Error::new(ErrorCode::InvalidParams, "Tool not found")), Some(tool) => { #[cfg(feature = "http-server")] self.validate_claims(tool.roles.as_deref(), tool.permissions.as_deref())?; - + let task_support = tool.task_support(); if let Some(task_meta) = params.task { self.ensure_tool_augmentation_support(task_support)?; let task = Task::from(task_meta); let handle = self.options.track_task(task.clone()); - + let opt = self.options.clone(); let task_id = task.id.clone(); tokio::spawn(async move { @@ -611,7 +580,8 @@ impl Context { } else if task_support.is_some_and(|ts| ts == TaskSupport::Required) { Err(Error::new( ErrorCode::MethodNotFound, - "Tool required task augmented call")) + "Tool required task augmented call", + )) } else { tool.call(params.with_context(self).into()) .await @@ -622,7 +592,7 @@ impl Context { } /// Requests a list of available roots from a client - /// + /// /// # Example /// ```no_run /// # #[cfg(feature = "server-macros")] { @@ -643,22 +613,21 @@ impl Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(ListRootsRequestParams::default())); - - self.send_request(req) - .await? - .into_result() + Some(ListRootsRequestParams::default()), + ); + + self.send_request(req).await?.into_result() } - + /// Sends the sampling request to the client /// /// # Example /// ```no_run /// # #[cfg(feature = "server-macros")] { /// use neva::{ - /// Context, - /// error::Error, - /// types::sampling::CreateMessageRequestParams, + /// Context, + /// error::Error, + /// types::sampling::CreateMessageRequestParams, /// tool /// }; /// @@ -667,23 +636,25 @@ impl Context { /// let params = CreateMessageRequestParams::new() /// .with_message(format!("Write a short poem about {topic}")) /// .with_sys_prompt("You are a talented poet who writes concise, evocative verses."); - /// + /// /// let result = ctx.sample(params).await?; /// Ok(format!("{:?}", result.content)) /// } /// # } /// ``` #[cfg(not(feature = "tasks"))] - pub async fn sample(&mut self, params: CreateMessageRequestParams) -> Result { + pub async fn sample( + &mut self, + params: CreateMessageRequestParams, + ) -> Result { let method = crate::types::sampling::commands::CREATE; let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); + Some(params), + ); - self.send_request(req) - .await? - .into_result() + self.send_request(req).await?.into_result() } /// Sends the sampling request to the client @@ -692,9 +663,9 @@ impl Context { /// ```no_run /// # #[cfg(feature = "server-macros")] { /// use neva::{ - /// Context, - /// error::Error, - /// types::sampling::CreateMessageRequestParams, + /// Context, + /// error::Error, + /// types::sampling::CreateMessageRequestParams, /// tool /// }; /// @@ -703,22 +674,27 @@ impl Context { /// let params = CreateMessageRequestParams::new() /// .with_message(format!("Write a short poem about {topic}")) /// .with_sys_prompt("You are a talented poet who writes concise, evocative verses."); - /// + /// /// let result = ctx.sample(params).await?; /// Ok(format!("{:?}", result.content)) /// } /// # } /// ``` #[cfg(feature = "tasks")] - pub async fn sample(&mut self, params: CreateMessageRequestParams) -> Result { + pub async fn sample( + &mut self, + params: CreateMessageRequestParams, + ) -> Result { let method = crate::types::sampling::commands::CREATE; let is_task_aug = params.task.is_some(); let req = Request::new( - Some(RequestId::Uuid(uuid::Uuid::new_v4())), - method, - Some(params)); + Some(RequestId::Uuid(uuid::Uuid::new_v4())), + method, + Some(params), + ); - self.send_maybe_task_augmented_request(req, is_task_aug).await + self.send_maybe_task_augmented_request(req, is_task_aug) + .await } /// Sends the elicitation request to the client @@ -727,9 +703,9 @@ impl Context { /// ```no_run /// # #[cfg(feature = "serve-macros")] { /// use neva::{ - /// Context, - /// error::Error, - /// types::elicitation::ElicitRequestParams, + /// Context, + /// error::Error, + /// types::elicitation::ElicitRequestParams, /// tool /// }; /// @@ -748,11 +724,10 @@ impl Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); + Some(params), + ); - self.send_request(req) - .await? - .into_result() + self.send_request(req).await?.into_result() } /// Sends the elicitation request to the client @@ -761,9 +736,9 @@ impl Context { /// ```no_run /// # #[cfg(feature = "serve-macros")] { /// use neva::{ - /// Context, - /// error::Error, - /// types::elicitation::ElicitRequestParams, + /// Context, + /// error::Error, + /// types::elicitation::ElicitRequestParams, /// tool /// }; /// @@ -779,18 +754,20 @@ impl Context { #[cfg(feature = "tasks")] pub async fn elicit(&mut self, params: ElicitRequestParams) -> Result { let related_task = params.related_task(); - + if let Some(related_task) = related_task { let task_id = related_task.id; - let mut id = task_id.as_str().parse::() + let mut id = task_id + .as_str() + .parse::() .expect("Invalid task id"); - - if let Some(session_id) = self.session_id { + + if let Some(session_id) = self.session_id { id = id.concat(session_id.into()); - } - + } + let receiver = self.pending.push(&id); - + self.options.tasks.set_result(&task_id, params); self.options.tasks.require_input(&task_id); @@ -798,17 +775,20 @@ impl Context { Ok(Ok(resp)) => resp, Ok(Err(_)) => { self.options.tasks.fail(&task_id); - return Err(Error::new(ErrorCode::InternalError, "Response channel closed")) - }, + return Err(Error::new( + ErrorCode::InternalError, + "Response channel closed", + )); + } Err(_) => { _ = self.pending.pop(&id); self.options.tasks.fail(&task_id); return Err(Error::new(ErrorCode::Timeout, "Request timed out")); } }; - + self.options.tasks.reset(&task_id); - + return resp.into_result(); } @@ -817,17 +797,17 @@ impl Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); + Some(params), + ); - self.send_maybe_task_augmented_request(req, is_task_aug).await + self.send_maybe_task_augmented_request(req, is_task_aug) + .await } - + /// Notifies the client that the elicitation with the `id` has been completed pub async fn complete_elicitation(&mut self, id: impl Into) -> Result<(), Error> { let params = serde_json::to_value(ElicitationCompleteParams::new(id)).ok(); - self.send_notification( - crate::types::elicitation::commands::COMPLETE, - params) + self.send_notification(crate::types::elicitation::commands::COMPLETE, params) .await } @@ -836,9 +816,7 @@ impl Context { pub async fn task_changed(&mut self, id: &str) -> Result<(), Error> { let task = self.options.tasks.get_status(id)?; let params = serde_json::to_value(task).ok(); - self.send_notification( - crate::types::task::commands::STATUS, - params) + self.send_notification(crate::types::task::commands::STATUS, params) .await } @@ -849,9 +827,9 @@ impl Context { self.scope = Some(scope); self } - - /// Resolves a service and returns a cloned instance. - /// `T` must implement `Clone` otherwise + + /// Resolves a service and returns a cloned instance. + /// `T` must implement `Clone` otherwise /// use resolve_shared method that returns a shared pointer. #[inline] #[cfg(feature = "di")] @@ -873,11 +851,15 @@ impl Context { .resolve_shared::() .map_err(Into::into) } - + #[inline] #[cfg(feature = "http-server")] - fn validate_claims(&self, roles: Option<&[String]>, permissions: Option<&[String]>) -> Result<(), Error> { - let claims = self.claims.as_ref(); + fn validate_claims( + &self, + roles: Option<&[String]>, + permissions: Option<&[String]>, + ) -> Result<(), Error> { + let claims = self.claims.as_ref(); validate_roles(claims, roles)?; validate_permissions(claims, permissions)?; Ok(()) @@ -885,24 +867,27 @@ impl Context { #[inline] #[cfg(feature = "tasks")] - fn ensure_tool_augmentation_support(&self, task_support: Option) -> Result<(), Error> { + fn ensure_tool_augmentation_support( + &self, + task_support: Option, + ) -> Result<(), Error> { if !self.options.is_task_augmented_tool_call_supported() { - return Err( - Error::new( - ErrorCode::MethodNotFound, - "Server does not support task augmented tool calls")); + return Err(Error::new( + ErrorCode::MethodNotFound, + "Server does not support task augmented tool calls", + )); } let Some(task_support) = task_support else { - return Err( - Error::new( - ErrorCode::MethodNotFound, - "Tool does not support task augmented calls")); + return Err(Error::new( + ErrorCode::MethodNotFound, + "Tool does not support task augmented calls", + )); }; if task_support == TaskSupport::Forbidden { - return Err( - Error::new( - ErrorCode::MethodNotFound, - "Tool forbid task augmented calls")); + return Err(Error::new( + ErrorCode::MethodNotFound, + "Tool forbid task augmented calls", + )); } Ok(()) } @@ -912,21 +897,17 @@ impl Context { async fn send_maybe_task_augmented_request( &mut self, req: Request, - is_task_aug: bool + is_task_aug: bool, ) -> Result { if is_task_aug { - let result = self.send_request(req) - .await? - .into_result()?; + let result = self.send_request(req).await?.into_result()?; crate::shared::wait_to_completion(self, result).await } else { - self.send_request(req) - .await? - .into_result() + self.send_request(req).await?.into_result() } } - + /// Sends a [`Request`] to a client #[inline] async fn send_request(&mut self, mut req: Request) -> Result { @@ -940,7 +921,10 @@ impl Context { match timeout(self.timeout, receiver).await { Ok(Ok(resp)) => Ok(resp), - Ok(Err(_)) => Err(Error::new(ErrorCode::InternalError, "Response channel closed")), + Ok(Err(_)) => Err(Error::new( + ErrorCode::InternalError, + "Response channel closed", + )), Err(_) => { _ = self.pending.pop(&id); Err(Error::new(ErrorCode::Timeout, "Request timed out")) @@ -951,9 +935,9 @@ impl Context { /// Sends a notification to a client #[inline] async fn send_notification( - &mut self, - method: &str, - params: Option + &mut self, + method: &str, + params: Option, ) -> Result<(), Error> { let mut notification = Notification::new(method, params); if let Some(session_id) = self.session_id { @@ -967,19 +951,18 @@ impl Context { impl crate::shared::TaskApi for Context { /// Retrieve task result from the client. If the task is not completed yet, waits until it completes or cancels. async fn get_task_result(&mut self, id: impl Into) -> Result - where - T: DeserializeOwned + where + T: DeserializeOwned, { let params = GetTaskPayloadRequestParams { id: id.into() }; let method = crate::types::task::commands::RESULT; let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); + Some(params), + ); - self.send_request(req) - .await? - .into_result() + self.send_request(req).await?.into_result() } /// Retrieve task status from the client @@ -989,19 +972,19 @@ impl crate::shared::TaskApi for Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); - - self.send_request(req) - .await? - .into_result() + Some(params), + ); + + self.send_request(req).await?.into_result() } - + /// Cancels a task that is currently running on the client - async fn cancel_task(&mut self, id: impl Into) -> Result { + async fn cancel_task(&mut self, id: impl Into) -> Result { if !self.options.is_tasks_cancellation_supported() { return Err(Error::new( - ErrorCode::InvalidRequest, - "Server does not support cancelling tasks.")); + ErrorCode::InvalidRequest, + "Server does not support cancelling tasks.", + )); } let params = CancelTaskRequestParams { id: id.into() }; @@ -1009,20 +992,19 @@ impl crate::shared::TaskApi for Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); - - self.send_request(req) - .await? - .into_result() + Some(params), + ); + + self.send_request(req).await?.into_result() } /// Retrieves a list of tasks from the client async fn list_tasks(&mut self, cursor: Option) -> Result { - if !self.options.is_tasks_list_supported() { return Err(Error::new( - ErrorCode::InvalidRequest, - "Server does not support retrieving a task list.")); + ErrorCode::InvalidRequest, + "Server does not support retrieving a task list.", + )); } let params = ListTasksRequestParams { cursor }; @@ -1030,16 +1012,15 @@ impl crate::shared::TaskApi for Context { let req = Request::new( Some(RequestId::Uuid(uuid::Uuid::new_v4())), method, - Some(params)); - - self.send_request(req) - .await? - .into_result() + Some(params), + ); + + self.send_request(req).await?.into_result() } async fn handle_input(&mut self, _id: &str, _params: TaskPayload) -> Result<(), Error> { - // Reserved, there are no cases so far, for the server + // Reserved, there are no cases so far, for the server // to handle input requests from client. Ok(()) } -} \ No newline at end of file +} diff --git a/neva/src/app/handler.rs b/neva/src/app/handler.rs index 26ec185..58088f0 100644 --- a/neva/src/app/handler.rs +++ b/neva/src/app/handler.rs @@ -1,34 +1,25 @@ -//! Handler utilities for resources, tools and prompts +//! Handler utilities for resources, tools and prompts -use std::future::Future; -use std::sync::Arc; -use futures_util::future::BoxFuture; -use crate::error::{Error, ErrorCode}; -use crate::app::options::RuntimeMcpOptions; use crate::Context; +use crate::app::options::RuntimeMcpOptions; +use crate::error::{Error, ErrorCode}; use crate::types::{ - ListResourcesRequestParams, - CompleteRequestParams, - CallToolRequestParams, - ReadResourceRequestParams, - GetPromptRequestParams, - IntoResponse, Response, - Request, RequestId + CallToolRequestParams, CompleteRequestParams, GetPromptRequestParams, IntoResponse, + ListResourcesRequestParams, ReadResourceRequestParams, Request, RequestId, Response, }; +use futures_util::future::BoxFuture; +use std::future::Future; +use std::sync::Arc; /// Represents a specific registered handler -pub(crate) type RequestHandler = Arc< - dyn Handler - + Send - + Sync ->; +pub(crate) type RequestHandler = Arc + Send + Sync>; #[derive(Debug)] pub enum HandlerParams { Request(Context, Request), Tool(CallToolRequestParams), Resource(ReadResourceRequestParams), - Prompt(GetPromptRequestParams) + Prompt(GetPromptRequestParams), } impl From for HandlerParams { @@ -63,17 +54,17 @@ pub trait FromHandlerParams: Sized { } /// Represents a generic handler -pub trait GenericHandler: Clone + Send + Sync + 'static { +pub trait GenericHandler: Clone + Send + Sync + 'static { /// Output type type Output; /// Output future type Future: Future + Send; - + fn call(&self, args: Args) -> Self::Future; } /// Represents a generic handler for list resources -pub trait ListResourcesHandler: Clone + Send + Sync + 'static { +pub trait ListResourcesHandler: Clone + Send + Sync + 'static { /// Output type type Output; /// Output future @@ -84,7 +75,7 @@ pub trait ListResourcesHandler: Clone + Send + Sync + 'static { } /// Represents a generic completion handler. -pub trait CompletionHandler: Clone + Send + Sync + 'static { +pub trait CompletionHandler: Clone + Send + Sync + 'static { /// Output type type Output; /// Output future @@ -95,23 +86,26 @@ pub trait CompletionHandler: Clone + Send + Sync + 'static { } pub(crate) struct RequestFunc -where +where F: GenericHandler, R: IntoResponse, Args: FromHandlerParams, { func: F, - _marker: std::marker::PhantomData, + _marker: std::marker::PhantomData, } impl RequestFunc where F: GenericHandler, R: IntoResponse, - Args: FromHandlerParams + Args: FromHandlerParams, { pub(crate) fn new(func: F) -> Arc { - let func = Self { func, _marker: std::marker::PhantomData }; + let func = Self { + func, + _marker: std::marker::PhantomData, + }; Arc::new(func) } } @@ -120,17 +114,14 @@ impl Handler for RequestFunc where F: GenericHandler, R: IntoResponse, - Args: FromHandlerParams + Send + Sync + Args: FromHandlerParams + Send + Sync, { #[inline] fn call(&self, params: HandlerParams) -> BoxFuture<'_, Result> { Box::pin(async move { let id = RequestId::from_params(¶ms)?; let args = Args::from_params(¶ms)?; - Ok(self.func - .call(args) - .await - .into_response(id)) + Ok(self.func.call(args).await.into_response(id)) }) } } @@ -153,7 +144,10 @@ impl FromHandlerParams for Context { fn from_params(params: &HandlerParams) -> Result { match params { HandlerParams::Request(context, _) => Ok(context.clone()), - _ => Err(Error::new(ErrorCode::InternalError, "invalid handler parameters")) + _ => Err(Error::new( + ErrorCode::InternalError, + "invalid handler parameters", + )), } } } @@ -163,7 +157,10 @@ impl FromHandlerParams for RuntimeMcpOptions { fn from_params(params: &HandlerParams) -> Result { match params { HandlerParams::Request(ctx, _) => Ok(ctx.options.clone()), - _ => Err(Error::new(ErrorCode::InternalError, "invalid handler parameters")) + _ => Err(Error::new( + ErrorCode::InternalError, + "invalid handler parameters", + )), } } } @@ -173,7 +170,10 @@ impl FromHandlerParams for Request { fn from_params(params: &HandlerParams) -> Result { match params { HandlerParams::Request(_, req) => Ok(req.clone()), - _ => Err(Error::new(ErrorCode::InternalError, "invalid handler parameters")) + _ => Err(Error::new( + ErrorCode::InternalError, + "invalid handler parameters", + )), } } } @@ -251,6 +251,4 @@ impl_generic_handler! { T1 T2 T3 T4 } impl_generic_handler! { T1 T2 T3 T4 T5 } #[cfg(test)] -mod tests { - -} \ No newline at end of file +mod tests {} diff --git a/neva/src/app/options.rs b/neva/src/app/options.rs index c7cda6e..8fb35a1 100644 --- a/neva/src/app/options.rs +++ b/neva/src/app/options.rs @@ -1,44 +1,32 @@ -//! MCP server options +//! MCP server options +use crate::app::{collection::Collection, handler::RequestHandler}; +#[cfg(feature = "http-server")] +use crate::transport::HttpServer; +use crate::transport::{StdIoServer, TransportProto}; use dashmap::{DashMap, DashSet}; -use std::{sync::Arc, time::Duration}; use std::fmt::{Debug, Formatter}; +use std::{sync::Arc, time::Duration}; use tokio_util::sync::CancellationToken; -use crate::transport::{StdIoServer, TransportProto}; -#[cfg(feature = "http-server")] -use crate::transport::HttpServer; -use crate::app::{handler::RequestHandler, collection::Collection}; use crate::middleware::{Middleware, Middlewares}; use crate::PROTOCOL_VERSIONS; use crate::types::{ - RequestId, - Implementation, - Tool, - Resource, Uri, ReadResourceResult, ResourceTemplate, + Implementation, Prompt, PromptsCapability, ReadResourceResult, RequestId, Resource, + ResourceTemplate, ResourcesCapability, Tool, ToolsCapability, Uri, resource::{Route, route::ResourceHandler}, - Prompt, - ResourcesCapability, ToolsCapability, PromptsCapability }; +#[cfg(feature = "tasks")] +use crate::shared::{TaskHandle, TaskTracker}; #[cfg(feature = "tracing")] use crate::types::notification::LoggingLevel; #[cfg(feature = "tasks")] -use crate::shared::{TaskTracker, TaskHandle}; -#[cfg(feature = "tasks")] -use crate::types::{ - ServerTasksCapability, - TaskPayload, - Task, -}; +use crate::types::{ServerTasksCapability, Task, TaskPayload}; #[cfg(feature = "tracing")] -use tracing_subscriber::{ - filter::LevelFilter, - reload::Handle, - Registry -}; +use tracing_subscriber::{Registry, filter::LevelFilter, reload::Handle}; #[cfg(any(feature = "tracing", feature = "tasks"))] use crate::error::Error; @@ -52,7 +40,7 @@ pub type RuntimeMcpOptions = Arc; pub struct McpOptions { /// Information of current server's implementation pub(crate) implementation: Implementation, - + /// Timeout for the requests from server to a client pub(crate) request_timeout: Duration, @@ -70,7 +58,7 @@ pub struct McpOptions { /// Holds current subscriptions to resource changes pub(super) resource_subscriptions: DashSet, - + /// An ordered list of middlewares pub(super) middlewares: Option, @@ -86,26 +74,26 @@ pub struct McpOptions { /// Server tasks capability options #[cfg(feature = "tasks")] tasks_capability: Option, - + /// The last logging level set by the client #[cfg(feature = "tracing")] log_level: Option>, /// An MCP version that server supports protocol_ver: Option<&'static str>, - + /// Current transport protocol that this server uses proto: Option, /// A resource template routing data structure resource_routes: Route, - + /// Currently running requests requests: DashMap, /// Currently running tasks #[cfg(feature = "tasks")] - pub(super) tasks: TaskTracker + pub(super) tasks: TaskTracker, } impl Debug for McpOptions { @@ -119,13 +107,13 @@ impl Debug for McpOptions { .field("resources_capability", &self.resources_capability) .field("prompts_capability", &self.prompts_capability) .field("protocol_ver", &self.protocol_ver); - + #[cfg(feature = "tasks")] dbg.field("tasks_capability", &self.tasks_capability); - + #[cfg(feature = "tracing")] dbg.field("log_level", &self.log_level); - + dbg.finish() } } @@ -172,16 +160,18 @@ impl McpOptions { self.proto = Some(TransportProto::HttpServer(Box::new(http))); self } - + /// Sets Streamable HTTP as a transport protocol #[cfg(feature = "http-server")] pub fn with_http HttpServer>(mut self, config: F) -> Self { - self.proto = Some(TransportProto::HttpServer(Box::new(config(HttpServer::default())))); + self.proto = Some(TransportProto::HttpServer(Box::new(config( + HttpServer::default(), + )))); self } /// Sets Streamable HTTP as a transport protocol with default configuration - /// + /// /// Default: /// * __IP__: 127.0.0.1 /// * __PORT__: 3000 @@ -190,7 +180,7 @@ impl McpOptions { pub fn with_default_http(self) -> Self { self.with_http(|http| http) } - + /// Specifies MCP server name pub fn with_name(mut self, name: &str) -> Self { self.implementation.name = name.into(); @@ -204,7 +194,7 @@ impl McpOptions { } /// Specifies Model Context Protocol version - /// + /// /// Default: last available protocol version pub fn with_mcp_version(mut self, ver: &'static str) -> Self { self.protocol_ver = Some(ver); @@ -212,9 +202,9 @@ impl McpOptions { } /// Configures tools capability - pub fn with_tools(mut self, config: F) -> Self - where - F: FnOnce(ToolsCapability) -> ToolsCapability + pub fn with_tools(mut self, config: F) -> Self + where + F: FnOnce(ToolsCapability) -> ToolsCapability, { self.tools_capability = Some(config(Default::default())); self @@ -223,7 +213,7 @@ impl McpOptions { /// Configures resources capability pub fn with_resources(mut self, config: F) -> Self where - F: FnOnce(ResourcesCapability) -> ResourcesCapability + F: FnOnce(ResourcesCapability) -> ResourcesCapability, { self.resources_capability = Some(config(Default::default())); self @@ -232,7 +222,7 @@ impl McpOptions { /// Configures prompts capability pub fn with_prompts(mut self, config: F) -> Self where - F: FnOnce(PromptsCapability) -> PromptsCapability + F: FnOnce(PromptsCapability) -> PromptsCapability, { self.prompts_capability = Some(config(Default::default())); self @@ -242,7 +232,7 @@ impl McpOptions { #[cfg(feature = "tasks")] pub fn with_tasks(mut self, config: F) -> Self where - F: FnOnce(ServerTasksCapability) -> ServerTasksCapability + F: FnOnce(ServerTasksCapability) -> ServerTasksCapability, { self.tasks_capability = Some(config(Default::default())); self @@ -255,14 +245,14 @@ impl McpOptions { self.request_timeout = timeout; self } - + /// Configures [`LogLevelHandle`] that allow to change the [`LoggingLevel`] in runtime #[cfg(feature = "tracing")] pub fn with_logging(mut self, log_handle: Handle) -> Self { self.log_level = Some(log_handle); self } - + /// Sets the [`LoggingLevel`] #[cfg(feature = "tracing")] pub fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> { @@ -277,21 +267,19 @@ impl McpOptions { /// Returns current log level #[cfg(feature = "tracing")] pub(crate) fn log_level(&self) -> Option { - match &self.log_level { + match &self.log_level { None => None, - Some(handle) => handle - .clone_current() - .map(|x| x.into()), + Some(handle) => handle.clone_current().map(|x| x.into()), } } - + /// Tracks the request with `req_id` and returns the [`CancellationToken`] for this request pub(crate) fn track_request(&self, req_id: &RequestId) -> CancellationToken { let token = CancellationToken::new(); self.requests.insert(req_id.clone(), token.clone()); token } - + /// Cancels the request with `req_id` if it is present pub(crate) fn cancel_request(&self, req_id: &RequestId) { if let Some((_, token)) = self.requests.remove(req_id) { @@ -301,8 +289,7 @@ impl McpOptions { /// Completes the request with `req_id` if it is present pub(crate) fn complete_request(&self, req_id: &RequestId) { - self.requests - .remove(req_id); + self.requests.remove(req_id); } /// Returns a list of currently running tasks @@ -310,7 +297,7 @@ impl McpOptions { pub(crate) fn list_tasks(&self) -> Vec { self.tasks.tasks() } - + /// Tacks the task and returns the [`CancellationToken`] for this task #[cfg(feature = "tasks")] pub(crate) fn track_task(&self, task: Task) -> TaskHandle { @@ -323,7 +310,7 @@ impl McpOptions { self.tasks.cancel(task_id) } - /// Retrieves the task status + /// Retrieves the task status #[cfg(feature = "tasks")] pub(crate) fn get_task_status(&self, task_id: &str) -> Result { self.tasks.get_status(task_id) @@ -334,23 +321,18 @@ impl McpOptions { pub(crate) async fn get_task_result(&self, task_id: &str) -> Result { self.tasks.get_result(task_id).await } - + /// Adds a tool pub(crate) fn add_tool(&mut self, tool: Tool) -> &mut Tool { - self.tools_capability - .get_or_insert_default(); + self.tools_capability.get_or_insert_default(); - self.tools - .as_mut() - .entry(tool.name.clone()) - .or_insert(tool) + self.tools.as_mut().entry(tool.name.clone()).or_insert(tool) } /// Adds a resource pub(crate) fn add_resource(&mut self, resource: Resource) -> &mut Resource { - self.resources_capability - .get_or_insert_default(); - + self.resources_capability.get_or_insert_default(); + self.resources .as_mut() .entry(resource.uri.to_string()) @@ -359,16 +341,16 @@ impl McpOptions { /// Adds a resource template pub(crate) fn add_resource_template( - &mut self, - template: ResourceTemplate, - handler: RequestHandler + &mut self, + template: ResourceTemplate, + handler: RequestHandler, ) -> &mut ResourceTemplate { - self.resources_capability - .get_or_insert_default(); - + self.resources_capability.get_or_insert_default(); + let name = template.name.clone(); - - self.resource_routes.insert(&template.uri_template, name.clone(), handler); + + self.resource_routes + .insert(&template.uri_template, name.clone(), handler); self.resources_templates .as_mut() .entry(name) @@ -377,15 +359,14 @@ impl McpOptions { /// Adds a prompt pub(crate) fn add_prompt(&mut self, prompt: Prompt) -> &mut Prompt { - self.prompts_capability - .get_or_insert_default(); - + self.prompts_capability.get_or_insert_default(); + self.prompts .as_mut() .entry(prompt.name.clone()) .or_insert(prompt) } - + /// Registers a middleware #[inline] pub(crate) fn add_middleware(&mut self, middleware: Middleware) { @@ -393,28 +374,28 @@ impl McpOptions { .get_or_insert_with(Middlewares::new) .add(middleware); } - + /// Returns a Model Context Protocol version that this server supports #[inline] pub(crate) fn protocol_ver(&self) -> &'static str { - match self.protocol_ver { + match self.protocol_ver { Some(ver) => ver, - None => PROTOCOL_VERSIONS.last().unwrap() + None => PROTOCOL_VERSIONS.last().unwrap(), } } - + /// Returns current transport protocol pub(crate) fn transport(&mut self) -> TransportProto { let transport = self.proto.take(); transport.unwrap_or_default() } - + /// Returns a tool by its name #[inline] pub(crate) async fn get_tool(&self, name: &str) -> Option { self.tools.get(name).await } - + /// Returns a list of available tools #[inline] pub(crate) async fn list_tools(&self) -> Vec { @@ -473,7 +454,7 @@ impl McpOptions { } /// Returns [`ServerTasksCapability`] if configured. - /// + /// /// Otherwise, returns `None`. #[cfg(feature = "tasks")] pub(crate) fn tasks_capability(&self) -> Option { @@ -540,7 +521,7 @@ impl McpOptions { .and_then(|req| req.tools.as_ref()) .is_some_and(|tools| tools.call.is_some()) } - + /// Turns [`McpOptions`] into [`RuntimeMcpOptions`] pub(crate) fn into_runtime(mut self) -> RuntimeMcpOptions { self.tools = self.tools.into_runtime(); @@ -553,17 +534,19 @@ impl McpOptions { #[cfg(test)] mod tests { - use crate::error::{Error, ErrorCode}; + use super::*; use crate::SDK_NAME; - use crate::types::resource::template::ResourceFunc; + use crate::error::{Error, ErrorCode}; use crate::types::resource::Uri; - use crate::types::{GetPromptRequestParams, PromptMessage, ReadResourceRequestParams, ResourceContents, Role}; - use super::*; - + use crate::types::resource::template::ResourceFunc; + use crate::types::{ + GetPromptRequestParams, PromptMessage, ReadResourceRequestParams, ResourceContents, Role, + }; + #[test] fn it_creates_default_options() { let options = McpOptions::default(); - + assert_eq!(options.implementation.name, SDK_NAME); assert_eq!(options.implementation.version, env!("CARGO_PKG_VERSION")); assert_eq!(options.tools.as_ref().len(), 0); @@ -576,44 +559,41 @@ mod tests { #[test] fn it_takes_none_transport_by_default() { let mut options = McpOptions::default(); - + let transport = options.transport(); - + assert!(matches!(transport, TransportProto::None)); } - + #[test] fn it_sets_and_takes_stdio_transport() { - let mut options = McpOptions::default() - .with_stdio(); - + let mut options = McpOptions::default().with_stdio(); + let transport = options.transport(); assert!(matches!(transport, TransportProto::StdIoServer(_))); } - + #[test] fn it_sets_server_name() { - let options = McpOptions::default() - .with_name("name"); - + let options = McpOptions::default().with_name("name"); + assert_eq!(options.implementation.name, "name"); } #[test] fn it_sets_server_version() { - let options = McpOptions::default() - .with_version("1"); + let options = McpOptions::default().with_version("1"); assert_eq!(options.implementation.version, "1"); } - + #[tokio::test] async fn it_adds_and_gets_tool() { let mut options = McpOptions::default(); - + options.add_tool(Tool::new("tool", || async { "test" })); - + let tool = options.get_tool("tool").await.unwrap(); assert_eq!(tool.name, "tool"); } @@ -647,17 +627,18 @@ mod tests { .with_mime("text/plain") .with_text("some text") }; - + options.add_resource_template( ResourceTemplate::new("res://res", "test"), - ResourceFunc::new(handler)); + ResourceFunc::new(handler), + ); let req = ReadResourceRequestParams { uri: "res://res".into(), meta: None, - args: None + args: None, }; - + let res = options.read_resource(&req.uri).unwrap(); let res = res.0.call(req.into()).await.unwrap(); assert_eq!(res.contents.len(), 1); @@ -673,12 +654,13 @@ mod tests { options.add_resource_template( ResourceTemplate::new("res://res", "test"), - ResourceFunc::new(handler)); + ResourceFunc::new(handler), + ); let req = ReadResourceRequestParams { uri: "res://res".into(), meta: None, - args: None + args: None, }; let res = options.read_resource(&req.uri).unwrap(); @@ -698,7 +680,8 @@ mod tests { options.add_resource_template( ResourceTemplate::new("res://res", "test"), - ResourceFunc::new(handler)); + ResourceFunc::new(handler), + ); let resources = options.list_resource_templates().await; assert_eq!(resources.len(), 1); @@ -708,9 +691,7 @@ mod tests { async fn it_adds_and_gets_prompt() { let mut options = McpOptions::default(); - options.add_prompt(Prompt::new("test", || async { - [("test", Role::User)] - })); + options.add_prompt(Prompt::new("test", || async { [("test", Role::User)] })); let prompt = options.get_prompt("test").await.unwrap(); assert_eq!(prompt.name, "test"); @@ -732,7 +713,7 @@ mod tests { async fn it_adds_and_gets_prompt_with_error() { let mut options = McpOptions::default(); - options.add_prompt(Prompt::new("test", || async { + options.add_prompt(Prompt::new("test", || async { Err::(Error::from(ErrorCode::InternalError)) })); @@ -754,21 +735,18 @@ mod tests { async fn it_returns_prompts() { let mut options = McpOptions::default(); - options.add_prompt(Prompt::new("test", || async { - [("test", Role::User)] - })); + options.add_prompt(Prompt::new("test", || async { [("test", Role::User)] })); let prompts = options.list_prompts().await; assert_eq!(prompts.len(), 1); } - + #[test] fn it_returns_some_tool_capabilities_if_configured() { - let options = McpOptions::default() - .with_tools(|tools| tools.with_list_changed()); - + let options = McpOptions::default().with_tools(|tools| tools.with_list_changed()); + let tools_capability = options.tools_capability().unwrap(); - + assert!(tools_capability.list_changed); } @@ -776,7 +754,7 @@ mod tests { fn it_returns_some_tool_capabilities_if_there_are_tools() { let mut options = McpOptions::default(); options.add_tool(Tool::new("tool", || async { "test" })); - + let tools_capability = options.tools_capability().unwrap(); assert!(!tools_capability.list_changed); @@ -791,8 +769,7 @@ mod tests { #[test] fn it_returns_some_resource_capabilities_if_configured() { - let options = McpOptions::default() - .with_resources(|res| res.with_list_changed()); + let options = McpOptions::default().with_resources(|res| res.with_list_changed()); let resources_capability = options.resources_capability().unwrap(); @@ -816,10 +793,11 @@ mod tests { let handler = |_: Uri| async move { Err::(Error::from(ErrorCode::ResourceNotFound)) }; - + options.add_resource_template( - ResourceTemplate::new("res://test", "test"), - ResourceFunc::new(handler)); + ResourceTemplate::new("res://test", "test"), + ResourceFunc::new(handler), + ); let resources_capability = options.resources_capability().unwrap(); @@ -835,8 +813,7 @@ mod tests { #[test] fn it_returns_some_prompts_capability_if_configured() { - let options = McpOptions::default() - .with_prompts(|prompts| prompts.with_list_changed()); + let options = McpOptions::default().with_prompts(|prompts| prompts.with_list_changed()); let prompts_capability = options.prompts_capability().unwrap(); @@ -861,4 +838,4 @@ mod tests { assert!(options.prompts_capability().is_none()); } -} \ No newline at end of file +} diff --git a/neva/src/client.rs b/neva/src/client.rs index 6cb5190..3cdc012 100644 --- a/neva/src/client.rs +++ b/neva/src/client.rs @@ -1,65 +1,62 @@ -//! Utilities for the MCP client +//! Utilities for the MCP client -use std::{future::Future, sync::Arc}; -use std::fmt::{Debug, Formatter}; -use options::McpOptions; -use serde::Serialize; -use tokio_util::sync::CancellationToken; -use handler::RequestHandler; use crate::error::{Error, ErrorCode}; use crate::shared; use crate::transport::Transport; use crate::types::{ - ListToolsRequestParams, ListToolsResult, CallToolRequestParams, CallToolResponse, - ListResourcesRequestParams, ListResourcesResult, ReadResourceRequestParams, ReadResourceResult, - ListResourceTemplatesRequestParams, ListResourceTemplatesResult, Uri, - ListPromptsRequestParams, ListPromptsResult, GetPromptRequestParams, GetPromptResult, - ServerCapabilities, ClientCapabilities, Implementation, InitializeRequestParams, InitializeResult, - Request, RequestId, Response, RequestParamsMeta, MessageEnvelope, + CallToolRequestParams, CallToolResponse, ClientCapabilities, GetPromptRequestParams, + GetPromptResult, Implementation, InitializeRequestParams, InitializeResult, + ListPromptsRequestParams, ListPromptsResult, ListResourceTemplatesRequestParams, + ListResourceTemplatesResult, ListResourcesRequestParams, ListResourcesResult, + ListToolsRequestParams, ListToolsResult, MessageEnvelope, ReadResourceRequestParams, + ReadResourceResult, Request, RequestId, RequestParamsMeta, Response, Root, ServerCapabilities, + Uri, cursor::Cursor, + elicitation::{ElicitRequestParams, ElicitResult, ElicitationHandler}, notification::Notification, resource::{SubscribeRequestParams, UnsubscribeRequestParams}, sampling::{CreateMessageRequestParams, CreateMessageResult, SamplingHandler}, - elicitation::{ElicitRequestParams, ElicitResult, ElicitationHandler}, - Root }; +use handler::RequestHandler; +use options::McpOptions; +use serde::Serialize; +use std::fmt::{Debug, Formatter}; +use std::{future::Future, sync::Arc}; +use tokio_util::sync::CancellationToken; #[cfg(feature = "tasks")] use serde::de::DeserializeOwned; #[cfg(feature = "tasks")] use crate::types::{ - Task, TaskPayload, ListTasksRequestParams, ListTasksResult, - GetTaskPayloadRequestParams, - CancelTaskRequestParams, - GetTaskRequestParams, - TaskMetadata, + CancelTaskRequestParams, GetTaskPayloadRequestParams, GetTaskRequestParams, + ListTasksRequestParams, ListTasksResult, Task, TaskMetadata, TaskPayload, }; +pub mod batch; mod handler; mod notification_handler; pub mod options; -pub mod batch; pub mod subscribe; pub use batch::BatchBuilder; -/// Represents an MCP client app +/// Represents an MCP client app pub struct Client { /// MCP client options. options: McpOptions, /// Capabilities supported by the connected server. server_capabilities: Option, - + /// Implementation information of the connected server. server_info: Option, - + /// A [`CancellationToken`] that cancels transport background processes. cancellation_token: Option, - + /// Request handler - handler: Option + handler: Option, } impl Debug for Client { @@ -88,21 +85,21 @@ impl Client { server_capabilities: None, server_info: None, cancellation_token: None, - handler: None + handler: None, } } /// Configure MCP client options pub fn with_options(mut self, config: F) -> Self where - F: FnOnce(McpOptions) -> McpOptions + F: FnOnce(McpOptions) -> McpOptions, { self.options = config(self.options); self } - + /// Adds a new Root - /// + /// /// # Example /// ```no_run /// use neva::client::Client; @@ -139,7 +136,7 @@ impl Client { /// # } /// ``` pub fn add_roots(&mut self, roots: I) -> &mut Self - where + where T: Into, I: IntoIterator, { @@ -147,7 +144,7 @@ impl Client { self.publish_roots_changed(); self } - + /// Sends the "notifications/roots/list_changed" notification to the server pub fn publish_roots_changed(&mut self) { if let Some(handler) = self.handler.as_mut() { @@ -179,38 +176,38 @@ impl Client { self.options.add_elicitation_handler(handler); self } - + /// Connects the MCP client to the MCP server - /// + /// /// # Example /// ```no_run /// use neva::client::Client; /// use neva::error::Error; - /// + /// /// #[tokio::main] /// async fn main() -> Result<(), Error> { /// let mut client = Client::new(); - /// + /// /// client.connect().await?; - /// + /// /// // call tools, read resources, etc. - /// + /// /// client.disconnect().await /// } /// ``` pub async fn connect(&mut self) -> Result<(), Error> { #[cfg(feature = "macros")] self.register_methods(); - + let mut transport = self.options.transport(); let token = transport.start(); - + #[cfg(feature = "tracing")] self.register_tracing_notification_handlers(); - + self.cancellation_token = Some(token); self.handler = Some(RequestHandler::new(transport, &self.options)); - + self.wait_for_shutdown_signal(); self.init().await } @@ -234,14 +231,15 @@ impl Client { /// } /// ``` pub async fn disconnect(mut self) -> Result<(), Error> { - self.send_notification(crate::types::notification::commands::CANCELLED, None).await?; + self.send_notification(crate::types::notification::commands::CANCELLED, None) + .await?; if let Some(token) = self.cancellation_token { token.cancel(); } tokio::time::sleep(std::time::Duration::from_millis(100)).await; Ok(()) } - + /// Sends `initialize` request to an MCP server pub async fn init(&mut self) -> Result<(), Error> { let params = InitializeRequestParams { @@ -254,14 +252,15 @@ impl Client { #[cfg(feature = "tasks")] tasks: self.options.tasks_capability(), experimental: None, - }) + }), }; let req = Request::new( - Some(RequestId::Uuid(uuid::Uuid::new_v4())), - crate::commands::INIT, - Some(params)); - + Some(RequestId::Uuid(uuid::Uuid::new_v4())), + crate::commands::INIT, + Some(params), + ); + let resp = self.send_request(req).await?; let init_result = resp.into_result::()?; @@ -271,7 +270,10 @@ impl Client { self.cancel_transport(); return Err(Error::new( ErrorCode::InvalidRequest, - format!("Unsupported server protocol version: {}", init_result.protocol_ver), + format!( + "Unsupported server protocol version: {}", + init_result.protocol_ver + ), )); } if server_ver != self.options.protocol_ver() { @@ -285,29 +287,30 @@ impl Client { ), )); } - + self.server_capabilities = Some(init_result.capabilities); self.server_info = Some(init_result.server_info); - self.send_notification(crate::types::notification::commands::INITIALIZED, None).await + self.send_notification(crate::types::notification::commands::INITIALIZED, None) + .await } - + /// Sends a ping to the MCP server pub async fn ping(&mut self) -> Result { self.command::<()>(crate::commands::PING, None).await } - + /// Sends a command to the MCP server - /// + /// /// # Example /// ```no_run /// use neva::prelude::*; - /// + /// /// #[derive(serde::Serialize)] /// struct MyCommandParams { /// param: String, /// } - /// + /// /// #[tokio::main] /// async fn main() -> Result<(), Error> { /// let mut client = Client::new(); @@ -322,15 +325,15 @@ impl Client { /// ``` #[inline] pub async fn command( - &mut self, - command: impl Into, - params: Option + &mut self, + command: impl Into, + params: Option, ) -> Result { let id = self.generate_id()?; let request = Request::new(Some(id), command, params); self.send_request(request).await } - + /// Requests a list of tools that MCP server provides /// /// # Example @@ -352,14 +355,14 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` + /// ``` pub async fn list_tools(&mut self, cursor: Option) -> Result { let params = ListToolsRequestParams { cursor }; self.command(crate::types::tool::commands::LIST, Some(params)) .await? .into_result() } - + /// Requests a list of resources that MCP server provides /// /// # Example @@ -381,8 +384,11 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` - pub async fn list_resources(&mut self, cursor: Option) -> Result { + /// ``` + pub async fn list_resources( + &mut self, + cursor: Option, + ) -> Result { let params = ListResourcesRequestParams { cursor }; self.command(crate::types::resource::commands::LIST, Some(params)) .await? @@ -410,12 +416,18 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` - pub async fn list_resource_templates(&mut self, cursor: Option) -> Result { + /// ``` + pub async fn list_resource_templates( + &mut self, + cursor: Option, + ) -> Result { let params = ListResourceTemplatesRequestParams { cursor }; - self.command(crate::types::resource::commands::TEMPLATES_LIST, Some(params)) - .await? - .into_result() + self.command( + crate::types::resource::commands::TEMPLATES_LIST, + Some(params), + ) + .await? + .into_result() } /// Requests a list of prompts that MCP server provides @@ -439,8 +451,11 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` - pub async fn list_prompts(&mut self, cursor: Option) -> Result { + /// ``` + pub async fn list_prompts( + &mut self, + cursor: Option, + ) -> Result { let params = ListPromptsRequestParams { cursor }; self.command(crate::types::prompt::commands::LIST, Some(params)) .await? @@ -460,7 +475,7 @@ impl Client { /// /// client.connect().await?; /// - /// let args = [("message", "Hello MCP!")]; // or let args = ("message", "Hello MCP!"); + /// let args = [("message", "Hello MCP!")]; // or let args = ("message", "Hello MCP!"); /// let result = client.call_tool("echo", args).await?; /// // Do something with the result /// @@ -471,7 +486,7 @@ impl Client { /// # Structured output /// ```no_run /// use neva::prelude::*; - /// + /// /// #[json_schema(de)] /// struct Weather { /// conditions: String, @@ -486,14 +501,14 @@ impl Client { /// client.connect().await?; /// /// let tools = client.list_tools(None).await?; - /// + /// /// // Get the tool by name /// let tool: &Tool = tools.get("weather-forecast") /// .expect("Weather forecast tool not found"); - /// + /// /// let args = ("location", "London"); /// let result = client.call_tool("weather-forecast", args).await?; - /// + /// /// // Validate the output structure and deserialize the result /// let weather: Weather = tool /// .validate(&result) @@ -505,32 +520,30 @@ impl Client { /// } /// ``` pub async fn call_tool( - &mut self, - name: N, - args: Args + &mut self, + name: N, + args: Args, ) -> Result where N: Into, - Args: shared::IntoArgs + Args: shared::IntoArgs, { let params = CallToolRequestParams { name: name.into(), meta: None, args: args.into_args(), #[cfg(feature = "tasks")] - task: None + task: None, }; - - self.call_tool_raw(params) - .await? - .into_result() + + self.call_tool_raw(params).await?.into_result() } /// Calls a task-augmented tool that MCP server supports /// /// # Panics /// If the server does not support task-augmented tool calls - /// + /// /// # Example /// ```no_run /// use neva::client::Client; @@ -542,7 +555,7 @@ impl Client { /// /// client.connect().await?; /// - /// let args = [("message", "Hello MCP!")]; // or let args = ("message", "Hello MCP!"); + /// let args = [("message", "Hello MCP!")]; // or let args = ("message", "Hello MCP!"); /// let result = client.call_tool_as_task("echo", args, None).await?; /// // Do something with the result /// @@ -591,43 +604,42 @@ impl Client { &mut self, name: N, args: Args, - ttl: Option + ttl: Option, ) -> Result where N: Into, - Args: shared::IntoArgs + Args: shared::IntoArgs, { assert!( - self.is_server_support_call_tool_with_tasks(), - "Server does not support call tool with tasks."); - + self.is_server_support_call_tool_with_tasks(), + "Server does not support call tool with tasks." + ); + let params = CallToolRequestParams { name: name.into(), meta: None, args: args.into_args(), - task: Some(TaskMetadata { ttl }) + task: Some(TaskMetadata { ttl }), }; - let result = self - .call_tool_raw(params) - .await? - .into_result()?; + let result = self.call_tool_raw(params).await?.into_result()?; shared::wait_to_completion(self, result).await } - + /// Calls a tool #[inline] pub async fn call_tool_raw( - &mut self, - params: CallToolRequestParams + &mut self, + params: CallToolRequestParams, ) -> Result { let id = self.generate_id()?; - + let request = Request::new( Some(id.clone()), crate::types::tool::commands::CALL, - Some(params.with_meta(RequestParamsMeta::new(&id)))); + Some(params.with_meta(RequestParamsMeta::new(&id))), + ); self.send_request(request).await } @@ -650,8 +662,11 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` - pub async fn read_resource(&mut self, uri: impl Into) -> Result { + /// ``` + pub async fn read_resource( + &mut self, + uri: impl Into, + ) -> Result { let id = self.generate_id()?; let request = Request::new( Some(id.clone()), @@ -660,13 +675,11 @@ impl Client { uri: uri.into(), meta: Some(RequestParamsMeta::new(&id)), #[cfg(feature = "server")] - args: None - }) + args: None, + }), ); - self.send_request(request) - .await? - .into_result() + self.send_request(request).await?.into_result() } /// Gets a prompt that MCP server provides @@ -691,15 +704,15 @@ impl Client { /// /// client.disconnect().await /// } - /// ``` + /// ``` pub async fn get_prompt( - &mut self, + &mut self, name: N, - args: Args + args: Args, ) -> Result where N: Into, - Args: shared::IntoArgs + Args: shared::IntoArgs, { let id = self.generate_id()?; let request = Request::new( @@ -708,15 +721,13 @@ impl Client { Some(GetPromptRequestParams { name: name.into(), meta: Some(RequestParamsMeta::new(&id)), - args: args.into_args() - }) + args: args.into_args(), + }), ); - self.send_request(request) - .await? - .into_result() + self.send_request(request).await?.into_result() } - + /// Subscribes to a resource on the server to receive notifications when it changes. pub async fn subscribe_to_resource(&mut self, uri: impl Into) -> Result<(), Error> { if !self.is_resource_subscription_supported() { @@ -725,12 +736,12 @@ impl Client { "Server does not support resource subscriptions", )); } - + let params = SubscribeRequestParams::from(uri); let resp = self .command(crate::types::resource::commands::SUBSCRIBE, Some(params)) .await?; - + match resp { Response::Ok(_) => Ok(()), Response::Err(err) => Err(err.error.into()), @@ -750,31 +761,31 @@ impl Client { let resp = self .command(crate::types::resource::commands::UNSUBSCRIBE, Some(params)) .await?; - + match resp { Response::Ok(_) => Ok(()), Response::Err(err) => Err(err.error.into()), } } - + /// Maps the `handler` to a specific `event` pub fn subscribe(&mut self, event: E, handler: F) where E: Into, F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { self.options .notification_handler .get_or_insert_default() .subscribe(event, handler); } - + /// Unsubscribe a handler from the `event` pub fn unsubscribe(&mut self, event: impl AsRef) { if let Some(notification_handler) = &self.options.notification_handler { notification_handler.unsubscribe(event); - } + } } /// Returns whether the server is configured to send the "notifications/resources/updated" @@ -816,18 +827,14 @@ impl Client { /// Returns whether the client has elicitation capabilities #[inline] fn is_elicitation_supported(&self) -> bool { - self.options.elicitation_capability - .as_ref() - .is_some() + self.options.elicitation_capability.as_ref().is_some() } /// Returns whether the client has task augmentation capabilities #[inline] #[cfg(feature = "tasks")] fn is_client_supports_tasks(&self) -> bool { - self.options.tasks_capability - .as_ref() - .is_some() + self.options.tasks_capability.as_ref().is_some() } /// Returns whether the server has task augmentation capabilities @@ -843,7 +850,8 @@ impl Client { #[inline] #[cfg(feature = "tasks")] fn is_client_support_cancelling_tasks(&self) -> bool { - self.options.tasks_capability + self.options + .tasks_capability .as_ref() .is_some_and(|c| c.cancel.is_some()) } @@ -872,7 +880,8 @@ impl Client { #[inline] #[cfg(feature = "tasks")] fn is_client_support_task_list(&self) -> bool { - self.options.tasks_capability + self.options + .tasks_capability .as_ref() .is_some_and(|c| c.list.is_some()) } @@ -892,7 +901,8 @@ impl Client { /// Sends a request to the MCP server #[inline] async fn send_request(&mut self, req: Request) -> Result { - self.handler.as_mut() + self.handler + .as_mut() .ok_or_else(|| Error::new(ErrorCode::InternalError, "Connection closed"))? .send_request(req) .await @@ -946,7 +956,8 @@ impl Client { ) -> Result, Error> { use futures_util::future::join_all; - let handler = self.handler + let handler = self + .handler .as_mut() .ok_or_else(|| Error::new(ErrorCode::InternalError, "Connection closed"))?; @@ -959,7 +970,10 @@ impl Client { async move { match tokio::time::timeout(request_timeout, rx).await { Ok(Ok(resp)) => Ok(resp), - Ok(Err(_)) => Err(Error::new(ErrorCode::InternalError, "Response channel closed")), + Ok(Err(_)) => Err(Error::new( + ErrorCode::InternalError, + "Response channel closed", + )), Err(_) => { let _ = pending.pop(&id); Err(Error::new(ErrorCode::Timeout, "Batch request timed out")) @@ -980,22 +994,24 @@ impl Client { #[inline] #[cfg(feature = "tasks")] async fn send_response(&mut self, req: Response) -> Result<(), Error> { - self.handler.as_mut() + self.handler + .as_mut() .ok_or_else(|| Error::new(ErrorCode::InternalError, "Connection closed"))? .send_response(req) .await; Ok(()) } - + /// Sends a notification to the MCP server #[inline] async fn send_notification( &mut self, method: &str, - params: Option + params: Option, ) -> Result<(), Error> { let notification = Notification::new(method, params); - self.handler.as_mut() + self.handler + .as_mut() .ok_or_else(|| Error::new(ErrorCode::InternalError, "Connection closed"))? .send_notification(notification) .await @@ -1004,12 +1020,12 @@ impl Client { #[cfg(feature = "tracing")] fn register_tracing_notification_handlers(&mut self) { use crate::types::notification::commands::*; - + self.subscribe(MESSAGE, Self::default_notification_handler); self.subscribe(STDERR, Self::default_notification_handler); self.subscribe(PROGRESS, Self::default_notification_handler); } - + #[cfg(feature = "tracing")] async fn default_notification_handler(notification: Notification) { notification.write(); @@ -1018,11 +1034,12 @@ impl Client { /// Generates a new [`RequestId`] #[inline] fn generate_id(&self) -> Result { - self.handler.as_ref() + self.handler + .as_ref() .ok_or_else(|| Error::new(ErrorCode::InternalError, "Connection closed")) .map(|h| h.next_id()) } - + /// Cancels the transport and clears connection state without sending a /// notification. Used when initialization fails after the transport has /// already been started (e.g. protocol version mismatch in `init()`). @@ -1040,7 +1057,7 @@ impl Client { shared::wait_for_shutdown_signal(token); }; } - + #[inline(always)] #[cfg(feature = "tasks")] fn ensure_tasks_supported(&self) { @@ -1060,8 +1077,8 @@ impl Client { impl shared::TaskApi for Client { /// Retrieves task result. If the task is not completed yet, waits until it completes or cancels. async fn get_task_result(&mut self, id: impl Into) -> Result - where - T: DeserializeOwned + where + T: DeserializeOwned, { let params = GetTaskPayloadRequestParams { id: id.into() }; self.command(crate::types::task::commands::RESULT, Some(params)) @@ -1069,29 +1086,29 @@ impl shared::TaskApi for Client { .into_result() } - /// Retrieve task status + /// Retrieve task status async fn get_task(&mut self, id: impl Into) -> Result { let params = GetTaskRequestParams { id: id.into() }; self.command(crate::types::task::commands::GET, Some(params)) .await? .into_result() } - + /// Cancels a task that is currently running - /// + /// /// # Panics /// If the client or server does not support cancelling tasks async fn cancel_task(&mut self, id: impl Into) -> Result { assert!( - self.is_client_support_cancelling_tasks(), + self.is_client_support_cancelling_tasks(), "Client does not support cancelling tasks. You may configure it with `Client::with_options(|opt| opt.with_tasks(...))` method." ); - + assert!( - self.is_server_support_cancelling_tasks(), + self.is_server_support_cancelling_tasks(), "Server does not support cancelling tasks." ); - + let params = CancelTaskRequestParams { id: id.into() }; self.command(crate::types::task::commands::CANCEL, Some(params)) .await? @@ -1099,20 +1116,20 @@ impl shared::TaskApi for Client { } /// Retrieves a list of tasks - /// + /// /// # Panics /// If the client or server does not support retrieving a task list async fn list_tasks(&mut self, cursor: Option) -> Result { assert!( - self.is_client_support_task_list(), + self.is_client_support_task_list(), "Client does not support retrieving a task list. You may configure it with `Client::with_options(|opt| opt.with_tasks(...))` method." ); - + assert!( - self.is_server_support_task_list(), + self.is_server_support_task_list(), "Server does not support retrieving a task list." ); - + let params = ListTasksRequestParams { cursor }; self.command(crate::types::task::commands::LIST, Some(params)) .await? @@ -1124,13 +1141,10 @@ impl shared::TaskApi for Client { if let Some(handler) = &self.options.elicitation_handler { use crate::types::IntoResponse; - let result = handler(params) - .await - .with_related_task(id); - - let id = id.parse::() - .expect("Invalid Request Id"); - + let result = handler(params).await.with_related_task(id); + + let id = id.parse::().expect("Invalid Request Id"); + self.send_response(result.into_response(id)).await?; } Ok(()) @@ -1152,11 +1166,8 @@ where }) } -type Handler = Arc< - dyn Fn(P) -> std::pin::Pin + Send>> - + Send - + Sync ->; +type Handler = + Arc std::pin::Pin + Send>> + Send + Sync>; #[cfg(test)] mod tests { @@ -1166,6 +1177,9 @@ mod tests { async fn call_batch_requires_connected_client() { let mut client = Client::new(); let result = client.call_batch(vec![]).await; - assert!(result.is_err(), "disconnected client should return an error"); + assert!( + result.is_err(), + "disconnected client should return an error" + ); } } diff --git a/neva/src/client/batch.rs b/neva/src/client/batch.rs index f273cfb..db2753d 100644 --- a/neva/src/client/batch.rs +++ b/neva/src/client/batch.rs @@ -6,8 +6,7 @@ use crate::{ shared::IntoArgs, types::{ MessageEnvelope, Request, RequestId, RequestParamsMeta, Response, - notification::Notification, - resource::Uri, + notification::Notification, resource::Uri, }, }; @@ -68,7 +67,10 @@ impl<'a> BatchBuilder<'a> { /// Enqueues a `tools/list` request (first page; cursor is always `None`). pub fn list_tools(mut self) -> Self { use crate::types::{ListToolsRequestParams, tool::commands}; - self.push_request(commands::LIST, Some(ListToolsRequestParams { cursor: None })); + self.push_request( + commands::LIST, + Some(ListToolsRequestParams { cursor: None }), + ); self } @@ -91,7 +93,10 @@ impl<'a> BatchBuilder<'a> { /// Enqueues a `resources/list` request (first page; cursor is always `None`). pub fn list_resources(mut self) -> Self { use crate::types::{ListResourcesRequestParams, resource::commands}; - self.push_request(commands::LIST, Some(ListResourcesRequestParams { cursor: None })); + self.push_request( + commands::LIST, + Some(ListResourcesRequestParams { cursor: None }), + ); self } @@ -113,14 +118,20 @@ impl<'a> BatchBuilder<'a> { /// Enqueues a `resources/templates/list` request. pub fn list_resource_templates(mut self) -> Self { use crate::types::{ListResourceTemplatesRequestParams, resource::commands}; - self.push_request(commands::TEMPLATES_LIST, Some(ListResourceTemplatesRequestParams { cursor: None })); + self.push_request( + commands::TEMPLATES_LIST, + Some(ListResourceTemplatesRequestParams { cursor: None }), + ); self } /// Enqueues a `prompts/list` request (first page; cursor is always `None`). pub fn list_prompts(mut self) -> Self { use crate::types::{ListPromptsRequestParams, prompt::commands}; - self.push_request(commands::LIST, Some(ListPromptsRequestParams { cursor: None })); + self.push_request( + commands::LIST, + Some(ListPromptsRequestParams { cursor: None }), + ); self } @@ -171,7 +182,9 @@ impl<'a> BatchBuilder<'a> { /// or a local sequential fallback if the client is not connected. /// (A disconnected client will fail on [`send`](Self::send) anyway.) fn next_id(&self) -> RequestId { - self.client.generate_id().unwrap_or(RequestId::Number(self.items.len() as i64)) + self.client + .generate_id() + .unwrap_or(RequestId::Number(self.items.len() as i64)) } /// Enqueues a generic request with the given method and params. diff --git a/neva/src/client/handler.rs b/neva/src/client/handler.rs index 28e9953..4726693 100644 --- a/neva/src/client/handler.rs +++ b/neva/src/client/handler.rs @@ -1,37 +1,33 @@ -//! Request handling utilities +//! Request handling utilities -use std::sync::Arc; -use tokio::{sync::RwLock, time::timeout}; -use std::{time::Duration, sync::atomic::{AtomicI64, Ordering}}; use crate::client::notification_handler::NotificationsHandler; use crate::{ client::options::McpOptions, error::{Error, ErrorCode}, shared::RequestQueue, transport::{ - Receiver, Sender, - Transport, TransportProto, TransportProtoReceiver, TransportProtoSender + Receiver, Sender, Transport, TransportProto, TransportProtoReceiver, TransportProtoSender, }, types::{ - IntoResponse, Response, Message, MessageBatch, MessageEnvelope, - RequestId, Request, - notification::Notification, - Root, root::ListRootsResult, + IntoResponse, Message, MessageBatch, MessageEnvelope, Request, RequestId, Response, Root, + elicitation::ElicitationHandler, notification::Notification, root::ListRootsResult, sampling::SamplingHandler, - elicitation::ElicitationHandler - } + }, }; +use std::sync::Arc; +use std::{ + sync::atomic::{AtomicI64, Ordering}, + time::Duration, +}; +use tokio::{sync::RwLock, time::timeout}; #[cfg(feature = "tasks")] use crate::{ shared::TaskTracker, types::{ - Task, Pagination, CreateTaskResult, - CreateMessageRequestParams, - ElicitRequestParams, - ListTasksRequestParams, ListTasksResult, - CancelTaskRequestParams, - GetTaskPayloadRequestParams, GetTaskRequestParams + CancelTaskRequestParams, CreateMessageRequestParams, CreateTaskResult, ElicitRequestParams, + GetTaskPayloadRequestParams, GetTaskRequestParams, ListTasksRequestParams, ListTasksResult, + Pagination, Task, }, }; @@ -41,7 +37,7 @@ const DEFAULT_PAGE_SIZE: usize = 10; struct Roots { /// Cached list of [`Root`] inner: Arc>>, - + /// Notifier for Roots cache updates sender: Option>>, } @@ -49,7 +45,7 @@ struct Roots { pub(super) struct RequestHandler { /// Request counter counter: AtomicI64, - + /// Request timeout timeout: Duration, @@ -58,7 +54,7 @@ pub(super) struct RequestHandler { /// Current transport sender handle sender: TransportProtoSender, - + /// Cached list of [`Root`] roots: Roots, @@ -73,20 +69,23 @@ pub(super) struct RequestHandler { /// Task tracker for client sampling tasks. #[cfg(feature = "tasks")] - tasks: Arc + tasks: Arc, } impl Roots { fn new(options: &McpOptions, notifications_sender: &TransportProtoSender) -> Self { let mut roots = Self { inner: Arc::new(RwLock::new(options.roots())), - sender: None + sender: None, }; - if options.roots_capability().is_some_and(|roots| roots.list_changed) { + if options + .roots_capability() + .is_some_and(|roots| roots.list_changed) + { let (tx, mut rx) = tokio::sync::mpsc::channel::>(1); - roots.sender = Some(tx); - + roots.sender = Some(tx); + let roots = roots.inner.clone(); let mut sender = notifications_sender.clone(); tokio::spawn(async move { @@ -94,9 +93,8 @@ impl Roots { let mut current_roots = roots.write().await; *current_roots = new_roots; - let changed = Notification::new( - crate::types::root::commands::LIST_CHANGED, - None); + let changed = + Notification::new(crate::types::root::commands::LIST_CHANGED, None); if let Err(_err) = sender.send(changed.into()).await { #[cfg(feature = "tracing")] tracing::error!("Error sending notification: {:?}", _err); @@ -104,10 +102,10 @@ impl Roots { } }); } - + roots } - + fn update(&mut self, roots: Vec) { match self.sender.as_mut() { None => (), @@ -115,16 +113,16 @@ impl Roots { _ = sender .try_send(roots) .map_err(|err| Error::new(ErrorCode::InternalError, err)) - }, + } } - } + } } impl RequestHandler { /// Creates a new [`RequestHandler`] pub(super) fn new(transport: TransportProto, options: &McpOptions) -> Self { let (tx, rx) = transport.split(); - + let handler = Self { roots: Roots::new(options, &tx), counter: AtomicI64::new(1), @@ -135,9 +133,9 @@ impl RequestHandler { elicitation_handler: options.elicitation_handler.clone(), notification_handler: options.notification_handler.clone(), #[cfg(feature = "tasks")] - tasks: Arc::new(TaskTracker::new()) + tasks: Arc::new(TaskTracker::new()), }; - + handler.start(rx) } @@ -169,7 +167,10 @@ impl RequestHandler { match timeout(self.timeout, receiver).await { Ok(Ok(resp)) => Ok(resp), - Ok(Err(_)) => Err(Error::new(ErrorCode::InternalError, "Response channel closed")), + Ok(Err(_)) => Err(Error::new( + ErrorCode::InternalError, + "Response channel closed", + )), Err(_) => { _ = self.pending.pop(&id); Err(Error::new(ErrorCode::Timeout, "Request timed out")) @@ -223,13 +224,16 @@ impl RequestHandler { pub(super) async fn send_response(&mut self, resp: Response) { send_response_impl(&mut self.sender, resp).await; } - + /// Sends a notification to MCP server #[inline] - pub(super) async fn send_notification(&mut self, notification: Notification) -> Result<(), Error> { + pub(super) async fn send_notification( + &mut self, + notification: Notification, + ) -> Result<(), Error> { self.sender.send(notification.into()).await } - + /// Updates [`Root`] cache pub(super) fn notify_roots_changed(&mut self, roots: Vec) { self.roots.update(roots); @@ -246,7 +250,7 @@ impl RequestHandler { #[cfg(feature = "tasks")] let tasks = self.tasks.clone(); - + tokio::task::spawn(async move { while let Ok(msg) = rx.recv().await { match msg { @@ -259,12 +263,13 @@ impl RequestHandler { &elicitation_handler, #[cfg(feature = "tasks")] &tasks, - ).await; + ) + .await; send_response_impl(&mut sender, resp).await; - }, + } Message::Notification(notification) => { dispatch_notification(notification, ¬ification_handler).await; - }, + } Message::Batch(batch) => { // JSON-RPC 2.0 §6 allows either peer to send a batch // containing any mix of Requests, Notifications, and @@ -298,19 +303,21 @@ impl RequestHandler { &elicitation_handler, #[cfg(feature = "tasks")] &tasks, - ).await; + ) + .await; responses.push(MessageEnvelope::Response(resp)); - }, + } MessageEnvelope::Notification(notification) => { - dispatch_notification(notification, ¬ification_handler).await; - }, + dispatch_notification(notification, ¬ification_handler) + .await; + } } } // MessageBatch::new returns Err for an empty vec (all // items were notifications), in which case no reply is // sent — correct per JSON-RPC 2.0 §6. if let Ok(batch) = MessageBatch::new(responses) - && let Err(_err) = sender.send(Message::Batch(batch)).await + && let Err(_err) = sender.send(Message::Batch(batch)).await { #[cfg(feature = "tracing")] tracing::error!("Error sending batch response: {_err:?}"); @@ -341,23 +348,28 @@ async fn dispatch_request( roots: &Arc>>, sampling_handler: &Option, elicitation_handler: &Option, - #[cfg(feature = "tasks")] - tasks: &Arc, + #[cfg(feature = "tasks")] tasks: &Arc, ) -> Response { let req_id = req.id(); match req.method.as_str() { - crate::types::sampling::commands::CREATE => handle_sampling( - req, - sampling_handler, - #[cfg(feature = "tasks")] - tasks, - ).await, - crate::types::elicitation::commands::CREATE => handle_elicitation( - req, - elicitation_handler, - #[cfg(feature = "tasks")] - tasks, - ).await, + crate::types::sampling::commands::CREATE => { + handle_sampling( + req, + sampling_handler, + #[cfg(feature = "tasks")] + tasks, + ) + .await + } + crate::types::elicitation::commands::CREATE => { + handle_elicitation( + req, + elicitation_handler, + #[cfg(feature = "tasks")] + tasks, + ) + .await + } crate::types::root::commands::LIST => handle_roots(req, roots).await, #[cfg(feature = "tasks")] crate::types::task::commands::RESULT => get_task_result(req, tasks).await, @@ -410,22 +422,24 @@ async fn handle_sampling(req: Request, handler: &Option) -> Res result.into_response(id) } else { Response::error( - id, + id, Error::new( - ErrorCode::MethodNotFound, - "Client does not support sampling requests")) + ErrorCode::MethodNotFound, + "Client does not support sampling requests", + ), + ) } } #[inline] #[cfg(feature = "tasks")] async fn handle_sampling( - req: Request, + req: Request, handler: &Option, - tasks: &Arc + tasks: &Arc, ) -> Response { let id = req.id(); - if let Some(handler) = &handler { + if let Some(handler) = &handler { let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); }; @@ -455,8 +469,12 @@ async fn handle_sampling( } } else { Response::error( - id, - Error::new(ErrorCode::MethodNotFound, "Client does not support sampling requests")) + id, + Error::new( + ErrorCode::MethodNotFound, + "Client does not support sampling requests", + ), + ) } } @@ -464,7 +482,7 @@ async fn handle_sampling( #[cfg(not(feature = "tasks"))] async fn handle_elicitation(req: Request, handler: &Option) -> Response { let id = req.id(); - if let Some(handler) = &handler { + if let Some(handler) = &handler { let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); }; @@ -476,27 +494,32 @@ async fn handle_elicitation(req: Request, handler: &Option) } else { Response::error( id, - Error::new(ErrorCode::MethodNotFound, "Client does not support elicitation requests")) + Error::new( + ErrorCode::MethodNotFound, + "Client does not support elicitation requests", + ), + ) } } #[inline] #[cfg(feature = "tasks")] async fn handle_elicitation( - req: Request, + req: Request, handler: &Option, - tasks: &Arc + tasks: &Arc, ) -> Response { let id = req.id(); - if let Some(handler) = &handler { + if let Some(handler) = &handler { let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); }; let Ok(params) = serde_json::from_value(params) else { return Response::error(id, Error::from(ErrorCode::ParseError)); }; - if let ElicitRequestParams::Url(url_params) = ¶ms && - let Some(task_meta) = &url_params.task { + if let ElicitRequestParams::Url(url_params) = ¶ms + && let Some(task_meta) = &url_params.task + { let task = Task::from(*task_meta); let handle = tasks.track(task.clone()); @@ -520,36 +543,33 @@ async fn handle_elicitation( } else { Response::error( id, - Error::new(ErrorCode::MethodNotFound, "Client does not support elicitation requests")) + Error::new( + ErrorCode::MethodNotFound, + "Client does not support elicitation requests", + ), + ) } } - #[inline] #[cfg(feature = "tasks")] -fn handle_list_tasks( - req: Request, - tasks: &Arc -) -> Response { +fn handle_list_tasks(req: Request, tasks: &Arc) -> Response { let id = req.id(); let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); }; let params: Option = serde_json::from_value(params).ok(); - ListTasksResult::from(tasks - .tasks() - .paginate( - params.and_then(|p| p.cursor), - DEFAULT_PAGE_SIZE)) - .into_response(id) + ListTasksResult::from( + tasks + .tasks() + .paginate(params.and_then(|p| p.cursor), DEFAULT_PAGE_SIZE), + ) + .into_response(id) } #[inline] #[cfg(feature = "tasks")] -fn cancel_task( - req: Request, - tasks: &Arc -) -> Response { +fn cancel_task(req: Request, tasks: &Arc) -> Response { let id = req.id(); let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); @@ -559,18 +579,13 @@ fn cancel_task( }; match tasks.cancel(¶ms.id) { Ok(task) => task.into_response(id), - Err(err) => Response::error( - id, - Error::new(ErrorCode::InvalidParams, err.to_string())) + Err(err) => Response::error(id, Error::new(ErrorCode::InvalidParams, err.to_string())), } } #[inline] #[cfg(feature = "tasks")] -fn get_task( - req: Request, - tasks: &Arc -) -> Response { +fn get_task(req: Request, tasks: &Arc) -> Response { let id = req.id(); let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); @@ -580,18 +595,13 @@ fn get_task( }; match tasks.get_status(¶ms.id) { Ok(task) => task.into_response(id), - Err(err) => Response::error( - id, - Error::new(ErrorCode::InvalidParams, err.to_string())) + Err(err) => Response::error(id, Error::new(ErrorCode::InvalidParams, err.to_string())), } } #[inline] #[cfg(feature = "tasks")] -async fn get_task_result( - req: Request, - tasks: &Arc -) -> Response { +async fn get_task_result(req: Request, tasks: &Arc) -> Response { let id = req.id(); let Some(params) = req.params else { return Response::error(id, Error::from(ErrorCode::InvalidParams)); @@ -601,9 +611,7 @@ async fn get_task_result( }; match tasks.get_result(¶ms.id).await { Ok(task) => task.into_response(id), - Err(err) => Response::error( - id, - Error::new(ErrorCode::InvalidParams, err.to_string())) + Err(err) => Response::error(id, Error::new(ErrorCode::InvalidParams, err.to_string())), } } @@ -619,7 +627,9 @@ async fn get_task_result( fn validate_batch_ids(items: &[MessageEnvelope]) -> Result<(), Error> { let mut seen = std::collections::HashSet::new(); for envelope in items { - if let MessageEnvelope::Request(req) = envelope && !seen.insert(req.id()) { + if let MessageEnvelope::Request(req) = envelope + && !seen.insert(req.id()) + { return Err(Error::new( ErrorCode::InvalidRequest, "batch contains duplicate request IDs", @@ -635,9 +645,9 @@ mod tests { #[tokio::test] async fn batch_responses_are_distributed_individually() { - use tokio::time::{timeout, Duration}; use crate::types::MessageBatch; use serde_json::json; + use tokio::time::{Duration, timeout}; let queue = RequestQueue::default(); @@ -656,7 +666,8 @@ mod tests { MessageEnvelope::Response(resp1), MessageEnvelope::Request(dummy_req), MessageEnvelope::Response(resp2), - ]).expect("batch must not be empty"); + ]) + .expect("batch must not be empty"); // Simulate the batch receive arm for envelope in batch { @@ -677,9 +688,13 @@ mod tests { #[test] fn validate_batch_ids_rejects_duplicate_request_ids() { - let req = |id: i64| MessageEnvelope::Request( - Request::new(Some(RequestId::Number(id)), "ping", None::<()>) - ); + let req = |id: i64| { + MessageEnvelope::Request(Request::new( + Some(RequestId::Number(id)), + "ping", + None::<()>, + )) + }; // Unique IDs — should pass assert!(validate_batch_ids(&[req(1), req(2), req(3)]).is_ok()); @@ -691,12 +706,11 @@ mod tests { #[test] fn validate_batch_ids_ignores_notifications() { - let notif = MessageEnvelope::Notification( - crate::types::notification::Notification::new("foo", None) - ); - let req = MessageEnvelope::Request( - Request::new(Some(RequestId::Number(1)), "ping", None::<()>) - ); + let notif = MessageEnvelope::Notification(crate::types::notification::Notification::new( + "foo", None, + )); + let req = + MessageEnvelope::Request(Request::new(Some(RequestId::Number(1)), "ping", None::<()>)); // Two notifications with no ID fields — should not trigger duplicate check assert!(validate_batch_ids(&[notif.clone(), req, notif]).is_ok()); } @@ -710,13 +724,12 @@ mod tests { // Simulate what send_batch does for a [Notification, Request, Notification] batch let notification_1 = MessageEnvelope::Notification( - crate::types::notification::Notification::new("foo", None) - ); - let request = MessageEnvelope::Request( - Request::new(Some(req_id.clone()), "ping", None::<()>) + crate::types::notification::Notification::new("foo", None), ); + let request = + MessageEnvelope::Request(Request::new(Some(req_id.clone()), "ping", None::<()>)); let notification_2 = MessageEnvelope::Notification( - crate::types::notification::Notification::new("bar", None) + crate::types::notification::Notification::new("bar", None), ); let items = vec![notification_1, request, notification_2]; @@ -729,7 +742,11 @@ mod tests { } } - assert_eq!(receivers.len(), 1, "exactly one receiver for the one Request"); + assert_eq!( + receivers.len(), + 1, + "exactly one receiver for the one Request" + ); assert_eq!(receivers[0].0, req_id, "receiver ID matches request ID"); } -} \ No newline at end of file +} diff --git a/neva/src/client/notification_handler.rs b/neva/src/client/notification_handler.rs index 2b1ea9b..43ea2f0 100644 --- a/neva/src/client/notification_handler.rs +++ b/neva/src/client/notification_handler.rs @@ -1,29 +1,24 @@ -//! Utilities for handling notifications from server +//! Utilities for handling notifications from server +use crate::types::notification::Notification; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use tokio::sync::RwLock; -use crate::types::notification::Notification; /// Represents a notification handler function -pub(crate) type NotificationsHandlerFunc = Arc< - dyn Fn(Notification) -> Pin< - Box + Send + 'static>, - > - + Send - + Sync ->; +pub(crate) type NotificationsHandlerFunc = + Arc Pin + Send + 'static>> + Send + Sync>; /// Represents a notification handler #[derive(Default)] pub(super) struct NotificationsHandler { - handlers: RwLock> + handlers: RwLock>, } impl NotificationsHandler { - /// Subscribes to the `event` with the `handler` + /// Subscribes to the `event` with the `handler` pub(super) fn subscribe(&self, event: E, handler: F) where E: Into, @@ -32,24 +27,22 @@ impl NotificationsHandler { { let handler: NotificationsHandlerFunc = Arc::new(move |params| { let handler = handler.clone(); - Box::pin(async move { handler(params).await; }) + Box::pin(async move { + handler(params).await; + }) }); tokio::task::block_in_place(|| { - self.handlers - .blocking_write() - .insert(event.into(), handler); + self.handlers.blocking_write().insert(event.into(), handler); }); } - + /// Unsubscribes from the `event` pub(super) fn unsubscribe(&self, event: impl AsRef) { tokio::task::block_in_place(|| { - self.handlers - .blocking_write() - .remove(event.as_ref()); + self.handlers.blocking_write().remove(event.as_ref()); }); } - + /// Calls an appropriate notifications handler pub(super) async fn notify(&self, notification: Notification) { let guard = self.handlers.read().await; @@ -58,4 +51,4 @@ impl NotificationsHandler { handler(notification).await; } } -} \ No newline at end of file +} diff --git a/neva/src/client/options.rs b/neva/src/client/options.rs index 6af33c1..66763d2 100644 --- a/neva/src/client/options.rs +++ b/neva/src/client/options.rs @@ -1,22 +1,17 @@ -//! MCP client options +//! MCP client options -use std::collections::HashMap; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; -use std::time::Duration; use crate::PROTOCOL_VERSIONS; -use crate::transport::{StdIoClient, stdio::options::StdIoOptions, TransportProto}; use crate::client::notification_handler::NotificationsHandler; -use crate::types::sampling::SamplingHandler; +use crate::transport::{StdIoClient, TransportProto, stdio::options::StdIoOptions}; use crate::types::elicitation::ElicitationHandler; +use crate::types::sampling::SamplingHandler; use crate::types::{ - Root, - Implementation, - Uri, - RootsCapability, - SamplingCapability, - ElicitationCapability, + ElicitationCapability, Implementation, Root, RootsCapability, SamplingCapability, Uri, }; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; #[cfg(feature = "tasks")] use crate::types::ClientTasksCapability; @@ -30,13 +25,13 @@ const DEFAULT_REQUEST_TIMEOUT: u64 = 10; // 10 seconds pub struct McpOptions { /// Information of current client's implementation pub(crate) implementation: Implementation, - + /// Request timeout pub(super) timeout: Duration, - + /// Roots capability options pub(super) roots_capability: Option, - + /// Sampling capability options pub(super) sampling_capability: Option, @@ -52,16 +47,16 @@ pub struct McpOptions { /// Represents a handler function that runs when received a "elicitation/create" request pub(super) elicitation_handler: Option, - + /// Represents a hash map of notification handlers pub(super) notification_handler: Option>, - + /// An MCP version that a client supports protocol_ver: Option<&'static str>, /// Current transport protocol that the server uses proto: Option, - + /// Represents a list of roots that the client supports roots: HashMap, } @@ -80,7 +75,7 @@ impl Debug for McpOptions { #[cfg(feature = "tasks")] dbg.field("tasks_capability", &self.tasks_capability); - + dbg.finish() } } @@ -110,9 +105,11 @@ impl McpOptions { /// Sets stdio as a transport protocol pub fn with_stdio(mut self, command: &'static str, args: T) -> Self where - T: IntoIterator + T: IntoIterator, { - self.proto = Some(TransportProto::StdioClient(StdIoClient::new(StdIoOptions::new(command, args)))); + self.proto = Some(TransportProto::StdioClient(StdIoClient::new( + StdIoOptions::new(command, args), + ))); self } @@ -153,11 +150,11 @@ impl McpOptions { self.protocol_ver = Some(ver); self } - + /// Configures Roots capability pub fn with_roots(mut self, config: T) -> Self - where - T: FnOnce(RootsCapability) -> RootsCapability + where + T: FnOnce(RootsCapability) -> RootsCapability, { self.roots_capability = Some(config(Default::default())); self @@ -166,7 +163,7 @@ impl McpOptions { /// Configures Sampling capability pub fn with_sampling(mut self, config: T) -> Self where - T: FnOnce(SamplingCapability) -> SamplingCapability + T: FnOnce(SamplingCapability) -> SamplingCapability, { self.sampling_capability = Some(config(Default::default())); self @@ -175,7 +172,7 @@ impl McpOptions { /// Configures Elicitation capability pub fn with_elicitation(mut self, config: T) -> Self where - T: FnOnce(ElicitationCapability) -> ElicitationCapability + T: FnOnce(ElicitationCapability) -> ElicitationCapability, { self.elicitation_capability = Some(config(Default::default())); self @@ -185,7 +182,7 @@ impl McpOptions { #[cfg(feature = "tasks")] pub fn with_tasks(mut self, config: T) -> Self where - T: FnOnce(ClientTasksCapability) -> ClientTasksCapability + T: FnOnce(ClientTasksCapability) -> ClientTasksCapability, { self.tasks_capability = Some(config(Default::default())); self @@ -204,7 +201,7 @@ impl McpOptions { pub(crate) fn protocol_ver(&self) -> &'static str { match self.protocol_ver { Some(ver) => ver, - None => PROTOCOL_VERSIONS.last().unwrap() + None => PROTOCOL_VERSIONS.last().unwrap(), } } @@ -213,38 +210,31 @@ impl McpOptions { let transport = self.proto.take(); transport.unwrap_or_default() } - + /// Adds a root pub fn add_root(&mut self, root: Root) -> &mut Root { - self.roots - .entry(root.uri.clone()) - .or_insert(root) + self.roots.entry(root.uri.clone()).or_insert(root) } /// Adds multiple roots pub fn add_roots(&mut self, roots: I) -> &mut Self where T: Into, - I: IntoIterator + I: IntoIterator, { - let roots = roots - .into_iter() - .map(|item| { - let root: Root = item.into(); - (root.uri.clone(), root) - }); + let roots = roots.into_iter().map(|item| { + let root: Root = item.into(); + (root.uri.clone(), root) + }); self.roots.extend(roots); - self + self } - + /// Returns a list of defined Roots pub fn roots(&self) -> Vec { - self.roots - .values() - .cloned() - .collect() + self.roots.values().cloned().collect() } - + /// Registers a handler for sampling requests pub(crate) fn add_sampling_handler(&mut self, handler: SamplingHandler) { self.sampling_handler = Some(handler); @@ -283,7 +273,7 @@ impl McpOptions { } /// Returns [`ClientTasksCapability`] if configured. - /// + /// /// Otherwise, returns `None`. #[cfg(feature = "tasks")] pub(crate) fn tasks_capability(&self) -> Option { diff --git a/neva/src/client/subscribe.rs b/neva/src/client/subscribe.rs index 7e990a9..bd57ea1 100644 --- a/neva/src/client/subscribe.rs +++ b/neva/src/client/subscribe.rs @@ -1,92 +1,78 @@ -//! Additional helper methods for [`Client`] for variety notification subscription +//! Additional helper methods for [`Client`] for variety notification subscription -use std::future::Future; use super::Client; use crate::types::notification::Notification; +use std::future::Future; impl Client { /// Maps a `handler` to the `notifications/resources/updated` event pub fn on_resource_changed(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { assert!( - self.is_resource_subscription_supported(), + self.is_resource_subscription_supported(), "Server does not support resource subscriptions" ); - - self.subscribe( - crate::types::resource::commands::UPDATED, - handler - ); + + self.subscribe(crate::types::resource::commands::UPDATED, handler); } - + /// Maps a `handler` to the `notifications/resources/list_changed` event pub fn on_resources_changed(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { assert!( self.is_resource_list_changed_supported(), "Server does not support resource list changed events" ); - - self.subscribe( - crate::types::resource::commands::LIST_CHANGED, - handler - ); + + self.subscribe(crate::types::resource::commands::LIST_CHANGED, handler); } /// Maps a `handler` to the `notifications/tools/list_changed` event pub fn on_tools_changed(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { assert!( self.is_tools_list_changed_supported(), "Server does not support tools list changed events" ); - - self.subscribe( - crate::types::tool::commands::LIST_CHANGED, - handler - ); + + self.subscribe(crate::types::tool::commands::LIST_CHANGED, handler); } /// Maps a `handler` to the `notifications/prompts/list_changed` event pub fn on_prompts_changed(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { assert!( self.is_prompts_list_changed_supported(), "Server does not support prompts list changed events" ); - - self.subscribe( - crate::types::prompt::commands::LIST_CHANGED, - handler - ); + + self.subscribe(crate::types::prompt::commands::LIST_CHANGED, handler); } /// Maps a `handler` to the `notifications/elicitation/completed` event pub fn on_elicitation_completed(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { assert!( self.is_elicitation_supported(), "Client does not support elicitation. You may configure it with `Client::with_options(|opt| opt.with_elicitation())` method." ); - self.subscribe( - crate::types::elicitation::commands::COMPLETE, - handler); + self.subscribe(crate::types::elicitation::commands::COMPLETE, handler); } /// Maps a `handler` to the `notifications/tasks/status` event @@ -94,12 +80,10 @@ impl Client { pub fn on_task_status(&mut self, handler: F) where F: Fn(Notification) -> R + Clone + Send + Sync + 'static, - R: Future + Send + R: Future + Send, { self.ensure_tasks_supported(); - - self.subscribe( - crate::types::task::commands::STATUS, - handler); + + self.subscribe(crate::types::task::commands::STATUS, handler); } -} \ No newline at end of file +} diff --git a/neva/src/commands.rs b/neva/src/commands.rs index e2aae4e..1aaa464 100644 --- a/neva/src/commands.rs +++ b/neva/src/commands.rs @@ -4,4 +4,4 @@ pub const INIT: &str = "initialize"; /// Command name for pinging the server -pub const PING: &str = "ping"; \ No newline at end of file +pub const PING: &str = "ping"; diff --git a/neva/src/di.rs b/neva/src/di.rs index dac1d6d..2d55009 100644 --- a/neva/src/di.rs +++ b/neva/src/di.rs @@ -1,12 +1,7 @@ //! Types and utilities for dependency injection use super::error::{Error, ErrorCode}; -pub use volga_di::{ - ContainerBuilder, - Container, - GenericFactory, - Inject -}; +pub use volga_di::{Container, ContainerBuilder, GenericFactory, Inject}; pub use volga_di::error::Error as DiError; @@ -68,7 +63,7 @@ impl super::App { /// Registers scoped service that required to be resolved via factory /// - /// > **Note:** Provided factory function will be called once per scope + /// > **Note:** Provided factory function will be called once per scope /// > and the result will be available and reused per this scope lifetime. /// /// # Example @@ -91,7 +86,7 @@ impl super::App { where T: Send + Sync + 'static, F: GenericFactory, - Args: Inject + Args: Inject, { self.container.register_scoped_factory(factory); self @@ -99,7 +94,7 @@ impl super::App { /// Registers scoped service that required to be resolved as [`Default`] /// - /// > **Note:** the [`Default::default`] method will be called once per scope + /// > **Note:** the [`Default::default`] method will be called once per scope /// > and the result will be available and reused per this scope lifetime. /// /// # Example @@ -145,7 +140,7 @@ impl super::App { /// Registers transient service that required to be resolved via factory /// - /// > **Note:** Provided factory function will be called + /// > **Note:** Provided factory function will be called /// > every time once this service is requested. /// /// # Example @@ -168,7 +163,7 @@ impl super::App { where T: Send + Sync + 'static, F: GenericFactory, - Args: Inject + Args: Inject, { self.container.register_transient_factory(factory); self @@ -176,7 +171,7 @@ impl super::App { /// Registers transient service that required to be resolved as [`Default`] /// - /// > **Note:** the [`Default::default`] method will be called + /// > **Note:** the [`Default::default`] method will be called /// > every time once this service is requested. /// /// # Example @@ -201,8 +196,8 @@ impl super::App { #[cfg(feature = "server")] #[cfg(test)] mod tests { - use volga_di::{Container, Inject}; use super::super::App; + use volga_di::{Container, Inject}; #[derive(Default)] struct TestDependency; @@ -289,4 +284,4 @@ mod tests { assert!(dep.is_ok()); } -} \ No newline at end of file +} diff --git a/neva/src/di/dc.rs b/neva/src/di/dc.rs index 4cc0658..03d90f0 100644 --- a/neva/src/di/dc.rs +++ b/neva/src/di/dc.rs @@ -1,11 +1,13 @@ //! Extractors for Dependency Injection +use crate::app::handler::{FromHandlerParams, HandlerParams}; use crate::error::{Error, ErrorCode}; -use crate::app::handler::{HandlerParams, FromHandlerParams}; -use crate::types::{helpers::extract::RequestArgument, resource::ResourceArgument, RequestParamsMeta}; +use crate::types::{ + RequestParamsMeta, helpers::extract::RequestArgument, resource::ResourceArgument, +}; use std::{ ops::{Deref, DerefMut}, - sync::Arc + sync::Arc, }; /// `Dc` stands for Dependency Container. @@ -14,7 +16,7 @@ use std::{ /// /// # Example /// ```no_run -/// +/// /// ``` #[derive(Debug, Clone)] pub struct Dc(Arc); @@ -86,7 +88,10 @@ impl FromHandlerParams for Dc { fn from_params(params: &HandlerParams) -> Result { match params { HandlerParams::Request(context, _) => context.resolve_shared().map(Dc), - _ => Err(Error::new(ErrorCode::InternalError, "invalid handler parameters")) + _ => Err(Error::new( + ErrorCode::InternalError, + "invalid handler parameters", + )), } } } @@ -97,4 +102,4 @@ fn make_dc(meta: Option<&RequestParamsMeta>) -> Result .ok_or(Error::new(ErrorCode::InvalidParams, "Missing MCP context"))? .resolve_shared() .map(Dc) -} \ No newline at end of file +} diff --git a/neva/src/error.rs b/neva/src/error.rs index 5cbbe5d..5645af4 100644 --- a/neva/src/error.rs +++ b/neva/src/error.rs @@ -1,19 +1,15 @@ -//! Represents an error +//! Represents an error use std::convert::Infallible; -use std::fmt; use std::error::Error as StdError; +use std::fmt; use std::io::Error as IoError; pub use error_code::ErrorCode; pub mod error_code; -type BoxError = Box< - dyn StdError - + Send - + Sync ->; +type BoxError = Box; /// Represents MCP server error #[derive(Debug)] @@ -36,18 +32,18 @@ impl StdError for Error { impl From for Error { fn from(err: serde_json::Error) -> Error { - Self { + Self { inner: err.into(), - code: ErrorCode::ParseError - } + code: ErrorCode::ParseError, + } } } impl From for Error { fn from(err: IoError) -> Error { - Self { + Self { inner: err.into(), - code: ErrorCode::InternalError + code: ErrorCode::InternalError, } } } @@ -62,17 +58,12 @@ impl Error { /// Creates a new [`Error`] #[inline] pub fn new(code: impl TryInto, err: impl Into) -> Self { - Self { + Self { inner: err.into(), - code: code - .try_into() - .unwrap_or_default() + code: code.try_into().unwrap_or_default(), } } } #[cfg(test)] -mod tests { - -} - +mod tests {} diff --git a/neva/src/error/error_code.rs b/neva/src/error/error_code.rs index c914e84..40b4957 100644 --- a/neva/src/error/error_code.rs +++ b/neva/src/error/error_code.rs @@ -1,8 +1,8 @@ //! Represents error code tools -use std::fmt::Display; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::error::Error; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::Display; /// Standard JSON-RPC error codes as defined in the MCP specification. #[derive(Default, Debug, Copy, Clone, Eq, PartialEq)] @@ -28,7 +28,7 @@ pub enum ErrorCode { /// The URL mode elicitation is required. UrlElicitationRequiredError = -32042, - + /// [Internal code] The request has been canceled RequestCancelled = -99999, @@ -80,20 +80,19 @@ impl<'de> Deserialize<'de> for ErrorCode { D: Deserializer<'de>, { let value = i32::deserialize(deserializer)?; - ErrorCode::try_from(value).map_err(|_| { - serde::de::Error::custom(format!("Invalid error code: {value}")) - }) + ErrorCode::try_from(value) + .map_err(|_| serde::de::Error::custom(format!("Invalid error code: {value}"))) } } impl Display for ErrorCode { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { + match self { ErrorCode::ParseError => write!(f, "Parse error"), ErrorCode::InvalidRequest => write!(f, "Invalid request"), ErrorCode::MethodNotFound => write!(f, "Method not found"), - ErrorCode::InvalidParams => write!(f, "Invalid parameters"), + ErrorCode::InvalidParams => write!(f, "Invalid parameters"), ErrorCode::InternalError => write!(f, "Internal error"), ErrorCode::ResourceNotFound => write!(f, "Resource not found"), ErrorCode::UrlElicitationRequiredError => write!(f, "URL elicitation required error"), @@ -116,7 +115,7 @@ mod tests { #[test] fn it_converts_to_i32() { let codes = [ - (-32700, ErrorCode::ParseError), + (-32700, ErrorCode::ParseError), (-32600, ErrorCode::InvalidRequest), (-32601, ErrorCode::MethodNotFound), (-32602, ErrorCode::InvalidParams), @@ -139,7 +138,7 @@ mod tests { #[test] fn it_serializes_error_codes() { let codes = [ - ("-32700", ErrorCode::ParseError), + ("-32700", ErrorCode::ParseError), ("-32600", ErrorCode::InvalidRequest), ("-32601", ErrorCode::MethodNotFound), ("-32602", ErrorCode::InvalidParams), @@ -158,4 +157,4 @@ mod tests { assert_eq!(error_code, val); } } -} \ No newline at end of file +} diff --git a/neva/src/lib.rs b/neva/src/lib.rs index cbbef66..1dc3b1d 100644 --- a/neva/src/lib.rs +++ b/neva/src/lib.rs @@ -1,30 +1,30 @@ //! # Neva //! Easy configurable MCP server and client SDK for Rust -//! +//! //! ## Dependencies //! ```toml //! [dependencies] //! neva = { version = "0.2.1", features = ["full"] } //! tokio = { version = "1", features = ["full"] } //! ``` -//! +//! //! ## Example Server //! ```no_run //! # #[cfg(feature = "server")] { //! use neva::App; -//! +//! //! #[tokio::main] //! async fn main() { //! let mut app = App::new() //! .with_options(|opt| opt //! .with_stdio()); -//! -//! app.map_tool("hello", |name: String| async move { +//! +//! app.map_tool("hello", |name: String| async move { //! format!("Hello, {name}!") //! }); -//! +//! //! app.run().await; -//! } +//! } //! # } //! ``` //! # Example Client @@ -32,20 +32,20 @@ //! # #[cfg(feature = "client")] { //! use std::time::Duration; //! use neva::{Client, error::Error}; -//! +//! //! #[tokio::main] //! async fn main() -> Result<(), Error> { //! let mut client = Client::new() //! .with_options(|opt| opt //! .with_stdio("npx", ["-y", "@modelcontextprotocol/server-everything"])); -//! +//! //! client.connect().await?; -//! +//! //! // Call a tool //! let args = [("message", "Hello MCP!")]; //! let result = client.call_tool("echo", Some(args)).await?; //! println!("{:?}", result.content); -//! +//! //! client.disconnect().await //! } //! # } @@ -56,85 +56,81 @@ pub use app::{App, context::Context}; #[cfg(feature = "client")] pub use client::Client; -pub mod types; -#[cfg(any(feature = "server", feature = "client"))] -pub mod transport; -pub mod error; -pub mod shared; #[cfg(feature = "server")] pub mod app; #[cfg(feature = "client")] pub mod client; +pub mod commands; +#[cfg(feature = "di")] +pub mod di; +pub mod error; #[cfg(feature = "macros")] pub mod macros; -pub mod commands; #[cfg(feature = "server")] pub mod middleware; -#[cfg(feature = "di")] -pub mod di; +pub mod shared; +#[cfg(any(feature = "server", feature = "client"))] +pub mod transport; +pub mod types; -#[cfg(feature = "server-macros")] -pub use neva_macros::{tool, prompt, resource, resources, handler, completion}; -#[cfg(feature = "client-macros")] -pub use neva_macros::{sampling, elicitation}; #[cfg(feature = "macros")] pub use neva_macros::json_schema; +#[cfg(feature = "server-macros")] +pub use neva_macros::{completion, handler, prompt, resource, resources, tool}; +#[cfg(feature = "client-macros")] +pub use neva_macros::{elicitation, sampling}; pub(crate) const SDK_NAME: &str = "neva"; #[cfg(any(feature = "server", feature = "client"))] -pub(crate) const PROTOCOL_VERSIONS: [&str; 4] = [ - "2024-11-05", - "2025-03-26", - "2025-06-18", - "2025-11-25" -]; +pub(crate) const PROTOCOL_VERSIONS: [&str; 4] = + ["2024-11-05", "2025-03-26", "2025-06-18", "2025-11-25"]; #[cfg(feature = "http-server")] pub mod auth { //! Authentication utilities - - pub use volga::auth::{Algorithm, Authorizer, Claims}; + pub use crate::transport::http::server::{AuthConfig, DefaultClaims}; + pub use volga::auth::{Algorithm, Authorizer, Claims}; } pub mod json { //! JSON utilities - - pub use schemars::JsonSchema; + #[doc(hidden)] pub use schemars; + pub use schemars::JsonSchema; } pub mod prelude { //! Prelude with commonly used items - - pub use crate::types::*; + pub use crate::error::*; pub use crate::json::*; + pub use crate::types::*; #[cfg(feature = "http-server")] pub use crate::transport::HttpServer; #[cfg(all(feature = "http-server", feature = "server-tls"))] - pub use crate::transport::http::{TlsConfig, DevCertMode}; - + pub use crate::transport::http::{DevCertMode, TlsConfig}; + #[cfg(feature = "server")] pub use crate::app::{App, context::Context, options}; #[cfg(feature = "server")] pub use crate::middleware::{MwContext, Next}; - + #[cfg(feature = "client")] pub use crate::client::Client; - - #[cfg(feature = "server-macros")] - pub use crate::{tool, prompt, resource, resources, handler, completion}; - #[cfg(feature = "client-macros")] - pub use crate::{sampling, elicitation}; + #[cfg(feature = "macros")] pub use crate::json_schema; + #[cfg(feature = "server-macros")] + pub use crate::{completion, handler, prompt, resource, resources, tool}; + #[cfg(feature = "client-macros")] + pub use crate::{elicitation, sampling}; #[cfg(feature = "http-server")] pub use crate::auth::*; - + #[cfg(feature = "di")] pub use crate::di::Dc; diff --git a/neva/src/macros.rs b/neva/src/macros.rs index 52c3c4d..d9ab134 100644 --- a/neva/src/macros.rs +++ b/neva/src/macros.rs @@ -2,7 +2,7 @@ pub use inventory; -#[cfg(feature = "server")] -pub mod server; #[cfg(feature = "client")] pub mod client; +#[cfg(feature = "server")] +pub mod server; diff --git a/neva/src/macros/client.rs b/neva/src/macros/client.rs index 6f75de8..e2e36e2 100644 --- a/neva/src/macros/client.rs +++ b/neva/src/macros/client.rs @@ -1,7 +1,7 @@ //! Macros for a client -use crate::Client; use super::inventory; +use crate::Client; /// Registrar unit for tools, resources, templates and prompts #[derive(Debug)] @@ -22,4 +22,4 @@ impl Client { registrar.register(self); } } -} \ No newline at end of file +} diff --git a/neva/src/macros/server.rs b/neva/src/macros/server.rs index 8b9e4d7..2a1c558 100644 --- a/neva/src/macros/server.rs +++ b/neva/src/macros/server.rs @@ -1,7 +1,7 @@ //! Macros for server -use crate::App; use super::inventory; +use crate::App; /// Registrar unit for tools, resources, templates and prompts #[derive(Debug)] @@ -22,4 +22,4 @@ impl App { registrar.register(self); } } -} \ No newline at end of file +} diff --git a/neva/src/middleware.rs b/neva/src/middleware.rs index 551552a..4f6a625 100644 --- a/neva/src/middleware.rs +++ b/neva/src/middleware.rs @@ -1,18 +1,15 @@ //! MCP Server middleware utilities -use std::fmt::Debug; -use std::sync::Arc; -use futures_util::future::BoxFuture; use crate::{ - app::context::ServerRuntime, - types::{Message, RequestId, Request, Response, notification::Notification} + app::context::ServerRuntime, + types::{Message, Request, RequestId, Response, notification::Notification}, }; +use futures_util::future::BoxFuture; +use std::fmt::Debug; +use std::sync::Arc; #[cfg(feature = "di")] -use { - volga_di::Container, - crate::error::Error -}; +use {crate::error::Error, volga_di::Container}; pub(super) mod make_fn; pub mod wrap; @@ -23,10 +20,10 @@ const DEFAULT_MW_CAPACITY: usize = 8; pub struct MwContext { /// Current JSON-RPC message pub msg: Message, - + /// Server runtime reference pub(super) runtime: ServerRuntime, - + /// Dependency injection container scope. #[cfg(feature = "di")] pub(super) scope: Container, @@ -35,30 +32,21 @@ pub struct MwContext { impl Debug for MwContext { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MwContext") - .field("msg", &self.msg) - .finish() + f.debug_struct("MwContext").field("msg", &self.msg).finish() } } /// A reference to the next middleware in the chain -pub type Next = Arc< - dyn Fn(MwContext) -> BoxFuture<'static, Response> - + Send - + Sync ->; +pub type Next = Arc BoxFuture<'static, Response> + Send + Sync>; /// Middleware function wrapper -pub(super) type Middleware = Arc< - dyn Fn(MwContext, Next) -> BoxFuture<'static, Response> - + Send - + Sync ->; +pub(super) type Middleware = + Arc BoxFuture<'static, Response> + Send + Sync>; /// MCP middleware pipeline. #[derive(Clone)] pub(super) struct Middlewares { - pub(super) pipeline: Vec + pub(super) pipeline: Vec, } impl MwContext { @@ -67,11 +55,11 @@ impl MwContext { pub(super) fn msg(msg: Message, runtime: ServerRuntime) -> Self { #[cfg(feature = "di")] let scope = runtime.container.create_scope(); - Self { - msg, + Self { + msg, runtime, #[cfg(feature = "di")] - scope + scope, } } @@ -153,31 +141,29 @@ impl MwContext { } } - /// Resolves a service and returns a cloned instance. - /// `T` must implement `Clone` otherwise + /// Resolves a service and returns a cloned instance. + /// `T` must implement `Clone` otherwise /// use resolve_shared method that returns a shared pointer. #[inline] #[cfg(feature = "di")] pub fn resolve(&self) -> Result { - self.scope - .resolve::() - .map_err(Into::into) + self.scope.resolve::().map_err(Into::into) } /// Resolves a service and returns a shared pointer #[inline] #[cfg(feature = "di")] pub fn resolve_shared(&self) -> Result, Error> { - self.scope - .resolve_shared::() - .map_err(Into::into) + self.scope.resolve_shared::().map_err(Into::into) } } impl Middlewares { /// Initializes a new middleware pipeline pub(super) fn new() -> Self { - Self { pipeline: Vec::with_capacity(DEFAULT_MW_CAPACITY) } + Self { + pipeline: Vec::with_capacity(DEFAULT_MW_CAPACITY), + } } /// Adds middleware function to the pipeline @@ -192,15 +178,14 @@ impl Middlewares { return None; } - let request_handler = self.pipeline - .last() - .unwrap() - .clone(); - - let mut next: Next = Arc::new(move |ctx| request_handler( - ctx, - Arc::new(|ctx| Box::pin(async move { Response::empty(ctx.id()) })) - )); + let request_handler = self.pipeline.last().unwrap().clone(); + + let mut next: Next = Arc::new(move |ctx| { + request_handler( + ctx, + Arc::new(|ctx| Box::pin(async move { Response::empty(ctx.id()) })), + ) + }); for mw in self.pipeline.iter().rev().skip(1) { let current_mw: Middleware = mw.clone(); let prev_next: Next = next.clone(); diff --git a/neva/src/middleware/make_fn.rs b/neva/src/middleware/make_fn.rs index f323c8c..675e374 100644 --- a/neva/src/middleware/make_fn.rs +++ b/neva/src/middleware/make_fn.rs @@ -1,10 +1,10 @@ //! Middleware factory functions -use std::{future::Future, sync::Arc}; use crate::{ middleware::{Middleware, MwContext, Next}, - types::{Message, Response} + types::{Message, Response}, }; +use std::{future::Future, sync::Arc}; /// Turns a closure into middleware #[inline] @@ -13,18 +13,16 @@ where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - Arc::new(move |ctx: MwContext, next: Next| { - Box::pin(f(ctx, next)) - }) + Arc::new(move |ctx: MwContext, next: Next| Box::pin(f(ctx, next))) } -/// Turns a closure into middleware that runs only +/// Turns a closure into middleware that runs only /// if the MCP server received a message that satisfies the condition. #[inline] -pub(super) fn make_on(f: F, p: P) -> Middleware +pub(super) fn make_on(f: F, p: P) -> Middleware where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, - P: Fn(&Message) -> bool + Clone + Send + Sync + 'static, + P: Fn(&Message) -> bool + Clone + Send + Sync + 'static, R: Future + Send + 'static, { let mw = move |ctx: MwContext, next: Next| { @@ -41,10 +39,10 @@ where make_mw(mw) } -/// Turns a closure into middleware that runs only +/// Turns a closure into middleware that runs only /// if the MCP server received a message that satisfies the condition. #[inline] -pub(super) fn make_on_command(f: F, command: &'static str) -> Middleware +pub(super) fn make_on_command(f: F, command: &'static str) -> Middleware where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, @@ -56,4 +54,4 @@ where false } }) -} \ No newline at end of file +} diff --git a/neva/src/middleware/wrap.rs b/neva/src/middleware/wrap.rs index 08f7042..30bb6b2 100644 --- a/neva/src/middleware/wrap.rs +++ b/neva/src/middleware/wrap.rs @@ -1,11 +1,14 @@ //! MCP middleware wrappers -use std::future::Future; use crate::{ - middleware::{make_fn::{make_mw, make_on, make_on_command}, MwContext, Next}, + App, + middleware::{ + MwContext, Next, + make_fn::{make_mw, make_on, make_on_command}, + }, types::{Message, Response}, - App }; +use std::future::Future; impl App { /// Registers a global middleware @@ -18,46 +21,43 @@ impl App { self } - /// Registers a global middleware that runs only + /// Registers a global middleware that runs only /// if the MCP server received a notification message pub fn wrap_notification(mut self, middleware: F) -> Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on( - middleware, - |msg| msg.is_notification())); + self.options + .add_middleware(make_on(middleware, |msg| msg.is_notification())); self } - /// Registers a global middleware that runs only + /// Registers a global middleware that runs only /// if the MCP server received a request message pub fn wrap_request(mut self, middleware: F) -> Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on( - middleware, - |msg| msg.is_request())); + self.options + .add_middleware(make_on(middleware, |msg| msg.is_request())); self } - /// Registers a global middleware that runs only + /// Registers a global middleware that runs only /// if the MCP server received a response message pub fn wrap_response(mut self, middleware: F) -> Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on( - middleware, - |msg| msg.is_response())); + self.options + .add_middleware(make_on(middleware, |msg| msg.is_response())); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `tools/call` request pub fn wrap_tools(mut self, middleware: F) -> Self where @@ -66,11 +66,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::tool::commands::CALL)); + crate::types::tool::commands::CALL, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `prompts/get` request pub fn wrap_prompts(mut self, middleware: F) -> Self where @@ -79,11 +80,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::prompt::commands::GET)); + crate::types::prompt::commands::GET, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `resources/read` request pub fn wrap_resources(mut self, middleware: F) -> Self where @@ -92,11 +94,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::resource::commands::READ)); + crate::types::resource::commands::READ, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `resources/list` request pub fn wrap_list_resources(mut self, middleware: F) -> Self where @@ -105,11 +108,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::resource::commands::LIST)); + crate::types::resource::commands::LIST, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `resources/templates/list` request pub fn wrap_list_resource_templates(mut self, middleware: F) -> Self where @@ -118,11 +122,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::resource::commands::TEMPLATES_LIST)); + crate::types::resource::commands::TEMPLATES_LIST, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `tools/list` request pub fn wrap_list_tools(mut self, middleware: F) -> Self where @@ -131,11 +136,12 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::tool::commands::LIST)); + crate::types::tool::commands::LIST, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `prompts/list` request pub fn wrap_list_prompts(mut self, middleware: F) -> Self where @@ -144,79 +150,74 @@ impl App { { self.options.add_middleware(make_on_command( middleware, - crate::types::prompt::commands::LIST)); + crate::types::prompt::commands::LIST, + )); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received an `initialize` request pub fn wrap_init(mut self, middleware: F) -> Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on_command( - middleware, - crate::commands::INIT)); + self.options + .add_middleware(make_on_command(middleware, crate::commands::INIT)); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received the command with the `name` request pub fn wrap_command(&mut self, name: &'static str, middleware: F) -> &mut Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on_command( - middleware, - name)); + self.options + .add_middleware(make_on_command(middleware, name)); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `tools/call` request pub fn wrap_tool(&mut self, name: &'static str, middleware: F) -> &mut Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on( - middleware, - move |msg| { - if let Message::Request(req) = msg { - req.method == crate::types::tool::commands::CALL && - req.params - .as_ref() - .is_some_and(|p| p.get("name") - .is_some_and(|n| n == name)) - } else { - false - } - })); + self.options.add_middleware(make_on(middleware, move |msg| { + if let Message::Request(req) = msg { + req.method == crate::types::tool::commands::CALL + && req + .params + .as_ref() + .is_some_and(|p| p.get("name").is_some_and(|n| n == name)) + } else { + false + } + })); self } - /// Registers a middleware that runs only + /// Registers a middleware that runs only /// if the MCP server received a `prompt/get` request pub fn wrap_prompt(&mut self, name: &'static str, middleware: F) -> &mut Self where F: Fn(MwContext, Next) -> R + Clone + Send + Sync + 'static, R: Future + Send + 'static, { - self.options.add_middleware(make_on( - middleware, - move |msg| { - if let Message::Request(req) = msg { - req.method == crate::types::prompt::commands::GET && - req.params - .as_ref() - .is_some_and(|p| p.get("name") - .is_some_and(|n| n == name)) - } else { - false - } - })); + self.options.add_middleware(make_on(middleware, move |msg| { + if let Message::Request(req) = msg { + req.method == crate::types::prompt::commands::GET + && req + .params + .as_ref() + .is_some_and(|p| p.get("name").is_some_and(|n| n == name)) + } else { + false + } + })); self } -} \ No newline at end of file +} diff --git a/neva/src/shared.rs b/neva/src/shared.rs index 2901bf4..9c03029 100644 --- a/neva/src/shared.rs +++ b/neva/src/shared.rs @@ -1,43 +1,43 @@ -//! Shared utilities for server and client +//! Shared utilities for server and client #[cfg(any(feature = "server", feature = "client"))] use tokio_util::sync::CancellationToken; -#[cfg(any(feature = "server", feature = "client"))] -pub(crate) use requests_queue::RequestQueue; #[cfg(any(feature = "http-server", feature = "tracing"))] pub(crate) use message_registry::MessageRegistry; -#[cfg(feature = "tasks")] -pub(crate) use task_tracker::TaskTracker; +#[cfg(any(feature = "server", feature = "client"))] +pub(crate) use requests_queue::RequestQueue; #[cfg(all(feature = "tasks", feature = "server"))] pub(crate) use task_tracker::TaskHandle; +#[cfg(feature = "tasks")] +pub(crate) use task_tracker::TaskTracker; -pub(crate) use arc_str::ArcStr; pub(crate) use arc_slice::ArcSlice; +pub(crate) use arc_str::ArcStr; pub(crate) use memchr::MemChr; -pub use one_or_many::OneOrMany; pub use either::Either; pub use into_args::IntoArgs; +pub use one_or_many::OneOrMany; #[cfg(feature = "tasks")] pub use task_api::{TaskApi, wait_to_completion}; -#[cfg(feature = "http-client")] -pub mod mt; -#[cfg(any(feature = "server", feature = "client"))] -mod requests_queue; -#[cfg(any(feature = "http-server", feature = "tracing"))] -mod message_registry; -mod arc_str; mod arc_slice; +mod arc_str; +mod either; mod into_args; mod memchr; +#[cfg(any(feature = "http-server", feature = "tracing"))] +mod message_registry; +#[cfg(feature = "http-client")] +pub mod mt; mod one_or_many; -mod either; -#[cfg(feature = "tasks")] -mod task_tracker; +#[cfg(any(feature = "server", feature = "client"))] +mod requests_queue; #[cfg(feature = "tasks")] mod task_api; +#[cfg(feature = "tasks")] +mod task_tracker; #[inline] #[cfg(any(feature = "server", feature = "client"))] @@ -48,9 +48,11 @@ pub(crate) fn wait_for_shutdown_signal(token: CancellationToken) { #[cfg(feature = "tracing")] Err(err) => tracing::error!( logger = "neva", - "Unable to listen for shutdown signal: {}", err), + "Unable to listen for shutdown signal: {}", + err + ), #[cfg(not(feature = "tracing"))] - Err(_) => () + Err(_) => (), } token.cancel(); }); diff --git a/neva/src/shared/arc_slice.rs b/neva/src/shared/arc_slice.rs index e5d81bd..f505cc0 100644 --- a/neva/src/shared/arc_slice.rs +++ b/neva/src/shared/arc_slice.rs @@ -1,10 +1,10 @@ -//! Types and utilities for serializable [`Arc<[T]>`] +//! Types and utilities for serializable [`Arc<[T]>`] +use serde::{Deserialize, Serialize}; use std::fmt; use std::fmt::Display; use std::ops::Deref; use std::sync::Arc; -use serde::{Deserialize, Serialize}; /// Represents a serializable [`Arc<[T]>`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -16,7 +16,7 @@ impl Display for ArcSlice { for (i, seg) in self.iter().enumerate() { if i > 0 { write!(f, "/{seg}")?; - } else { + } else { write!(f, "{seg}")?; } } @@ -47,21 +47,21 @@ impl From<[T; N]> for ArcSlice { } } -impl From> for ArcSlice { +impl From> for ArcSlice { #[inline] fn from(arc: Arc<[T]>) -> Self { - Self(arc) + Self(arc) } } impl Serialize for ArcSlice -where - T: Display +where + T: Display, { #[inline] fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer + S: serde::Serializer, { serializer.serialize_str(&self.to_string()) } @@ -75,7 +75,7 @@ where #[inline] fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> + D: serde::Deserializer<'de>, { let s: &str = Deserialize::deserialize(deserializer)?; let parsed = s @@ -88,10 +88,10 @@ where #[cfg(test)] mod tests { - use std::sync::Arc; - use uuid::Uuid; use super::*; use crate::{shared::ArcStr, types::RequestId}; + use std::sync::Arc; + use uuid::Uuid; #[test] fn it_tests_display() { @@ -111,10 +111,7 @@ mod tests { ])); let json = serde_json::to_string(&slice).unwrap(); - assert_eq!( - json, - "\"b9d3c680-bb27-4d7d-9e76-111111111111/1/abc\"" - ); + assert_eq!(json, "\"b9d3c680-bb27-4d7d-9e76-111111111111/1/abc\""); } #[test] @@ -144,4 +141,4 @@ mod tests { assert_eq!(original, decoded); } -} \ No newline at end of file +} diff --git a/neva/src/shared/arc_str.rs b/neva/src/shared/arc_str.rs index cb16d06..16ea09b 100644 --- a/neva/src/shared/arc_str.rs +++ b/neva/src/shared/arc_str.rs @@ -1,9 +1,9 @@ -//! Types and utilities for serializable [`Arc`] +//! Types and utilities for serializable [`Arc`] +use serde::{Deserialize, Serialize}; use std::fmt; -use std::sync::Arc; use std::ops::Deref; -use serde::{Serialize, Deserialize}; +use std::sync::Arc; /// Represents a serializable [`Arc`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -18,7 +18,7 @@ impl fmt::Display for ArcStr { impl Deref for ArcStr { type Target = str; - + #[inline] fn deref(&self) -> &Self::Target { &self.0 @@ -49,8 +49,8 @@ impl From> for ArcStr { impl Serialize for ArcStr { #[inline] fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer + where + S: serde::Serializer, { serializer.serialize_str(self) } @@ -60,7 +60,7 @@ impl<'de> Deserialize<'de> for ArcStr { #[inline] fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> + D: serde::Deserializer<'de>, { let s: &str = Deserialize::deserialize(deserializer)?; Ok(ArcStr::from(s)) @@ -79,7 +79,9 @@ mod tests { #[test] fn serialize_arcstr() { - let w = Wrapper { id: ArcStr::from("hello") }; + let w = Wrapper { + id: ArcStr::from("hello"), + }; let json = serde_json::to_string(&w).unwrap(); assert_eq!(json, r#"{"id":"hello"}"#); } diff --git a/neva/src/shared/either.rs b/neva/src/shared/either.rs index a7910d8..8bf597e 100644 --- a/neva/src/shared/either.rs +++ b/neva/src/shared/either.rs @@ -1,7 +1,7 @@ //! Types and utilities for the "either" pattern -use serde::{Serialize, Deserialize, Serializer}; use crate::types::{IntoResponse, RequestId, Response}; +use serde::{Deserialize, Serialize, Serializer}; /// Represents a value of one of two types #[derive(Debug, Clone, Deserialize)] @@ -9,38 +9,38 @@ use crate::types::{IntoResponse, RequestId, Response}; pub enum Either { /// Left value Left(L), - + /// Right value Right(R), } -impl Serialize for Either +impl Serialize for Either where L: Serialize, - R: Serialize + R: Serialize, { #[inline] fn serialize(&self, serializer: S) -> Result where S: Serializer, { - match self { + match self { Either::Left(l) => l.serialize(serializer), - Either::Right(r) => r.serialize(serializer) + Either::Right(r) => r.serialize(serializer), } - } + } } impl IntoResponse for Either where L: IntoResponse, - R: IntoResponse + R: IntoResponse, { #[inline] fn into_response(self, req_id: RequestId) -> Response { - match self { + match self { Either::Left(l) => l.into_response(req_id), - Either::Right(r) => r.into_response(req_id) + Either::Right(r) => r.into_response(req_id), } } -} \ No newline at end of file +} diff --git a/neva/src/shared/into_args.rs b/neva/src/shared/into_args.rs index d54b5b5..86ecf6a 100644 --- a/neva/src/shared/into_args.rs +++ b/neva/src/shared/into_args.rs @@ -1,9 +1,9 @@ //! Utilities for conversion various types into tool or prompt arguments use crate::types::Json; -use std::collections::HashMap; use serde::Serialize; use serde_json::Value; +use std::collections::HashMap; /// A trait describes arguments for tools and prompts pub trait IntoArgs { @@ -32,9 +32,10 @@ where { #[inline] fn into_args(self) -> Option> { - Some(HashMap::from([ - (self.0.into(), serde_json::to_value(self.1).unwrap()) - ])) + Some(HashMap::from([( + self.0.into(), + serde_json::to_value(self.1).unwrap(), + )])) } } @@ -75,9 +76,7 @@ impl IntoArgs for Value { #[inline] fn into_args(self) -> Option> { match self { - Value::Object(map) => Some(map - .into_iter() - .collect()), + Value::Object(map) => Some(map.into_iter().collect()), _ => None, } } @@ -86,11 +85,9 @@ impl IntoArgs for Value { impl IntoArgs for Json { #[inline] fn into_args(self) -> Option> { - serde_json::to_value(self.0) - .ok() - .into_args() + serde_json::to_value(self.0).ok().into_args() } -} +} /// Creates arguments for tools and prompts from iterator #[inline] @@ -100,9 +97,10 @@ where K: Into, T: Serialize, { - HashMap::from_iter(args - .into_iter() - .map(|(k, v)| (k.into(), serde_json::to_value(v).unwrap()))) + HashMap::from_iter( + args.into_iter() + .map(|(k, v)| (k.into(), serde_json::to_value(v).unwrap())), + ) } #[cfg(test)] @@ -131,27 +129,21 @@ mod tests { #[test] fn it_converts_tuple_into_single_key_value_pair() { - let args = ("answer", 42) - .into_args() - .unwrap(); + let args = ("answer", 42).into_args().unwrap(); assert_eq!(args.len(), 1); assert_eq!(args.get("answer"), Some(&json!(42))); } #[test] fn it_converts_array_of_pairs_into_hashmap() { - let args = [("a", 1), ("b", 2)] - .into_args() - .unwrap(); + let args = [("a", 1), ("b", 2)].into_args().unwrap(); assert_eq!(args.get("a"), Some(&json!(1))); assert_eq!(args.get("b"), Some(&json!(2))); } #[test] fn it_converts_vec_of_pairs_into_hashmap() { - let args = vec![("x", true), ("y", false)] - .into_args() - .unwrap(); + let args = vec![("x", true), ("y", false)].into_args().unwrap(); assert_eq!(args.get("x"), Some(&json!(true))); assert_eq!(args.get("y"), Some(&json!(false))); } @@ -178,47 +170,42 @@ mod tests { #[test] fn it_overwrites_duplicate_keys_in_later_entries() { - let args = vec![("k", 1), ("k", 2)] - .into_args() - .unwrap(); + let args = vec![("k", 1), ("k", 2)].into_args().unwrap(); assert_eq!(args.get("k"), Some(&json!(2))); } #[test] fn it_supports_keys_that_are_not_str_directly() { - let args = vec![(String::from("id"), 99)] - .into_args() - .unwrap(); + let args = vec![(String::from("id"), 99)].into_args().unwrap(); assert_eq!(args.get("id"), Some(&json!(99))); } #[test] fn it_creates_empty_hashmap_for_empty_vec() { let args: Vec<(String, i32)> = vec![]; - let result = args - .into_args() - .unwrap(); + let result = args.into_args().unwrap(); assert!(result.is_empty()); } #[test] fn it_handles_serde_value() { - let args = json!({ "name": "Alice", "age": 30 }) - .into_args() - .unwrap(); + let args = json!({ "name": "Alice", "age": 30 }).into_args().unwrap(); assert_eq!(args.get("name"), Some(&json!("Alice"))); assert_eq!(args.get("age"), Some(&json!(30))); } #[test] fn it_handles_json_value() { - let args = Json(User { name: "Alice".into(), age: 30 }) - .into_args() - .unwrap(); + let args = Json(User { + name: "Alice".into(), + age: 30, + }) + .into_args() + .unwrap(); assert_eq!(args.get("name"), Some(&json!("Alice"))); assert_eq!(args.get("age"), Some(&json!(30))); } - + #[derive(Serialize)] struct User { name: String, diff --git a/neva/src/shared/memchr.rs b/neva/src/shared/memchr.rs index 23bdaff..26e836c 100644 --- a/neva/src/shared/memchr.rs +++ b/neva/src/shared/memchr.rs @@ -46,7 +46,10 @@ mod tests { #[test] fn it_splits_on_char1() { - assert_eq!(MemChr::split("res://a/b", b'/').collect::>(), ["res:", "a", "b"]); + assert_eq!( + MemChr::split("res://a/b", b'/').collect::>(), + ["res:", "a", "b"] + ); } #[test] @@ -66,6 +69,9 @@ mod tests { #[test] fn it_splits_on_empty() { - assert_eq!(MemChr::split("", b'/').collect::>(), Vec::<&str>::new()); + assert_eq!( + MemChr::split("", b'/').collect::>(), + Vec::<&str>::new() + ); } -} \ No newline at end of file +} diff --git a/neva/src/shared/message_registry.rs b/neva/src/shared/message_registry.rs index d098d88..0490474 100644 --- a/neva/src/shared/message_registry.rs +++ b/neva/src/shared/message_registry.rs @@ -1,15 +1,15 @@ -//! Tools for binding message channels with MCP Sessions +//! Tools for binding message channels with MCP Sessions +use crate::error::{Error, ErrorCode}; +use crate::types::Message; use dashmap::DashMap; use tokio::sync::mpsc::UnboundedSender; use uuid::Uuid; -use crate::error::{Error, ErrorCode}; -use crate::types::Message; /// A concurrent message registry that bounds the MCP session ID and related message channel #[derive(Default)] pub(crate) struct MessageRegistry { - inner: DashMap> + inner: DashMap>, } #[allow(dead_code)] @@ -17,9 +17,11 @@ impl MessageRegistry { /// Creates a new [`MessageRegistry`] #[inline] pub(crate) fn new() -> Self { - Self { inner: DashMap::new() } + Self { + inner: DashMap::new(), + } } - + /// Registers MCP session channel #[inline] pub(crate) fn register(&self, key: Uuid, sender: UnboundedSender) { @@ -35,10 +37,8 @@ impl MessageRegistry { /// Sends a message into an appropriate channel #[inline] pub(crate) fn send(&self, message: Message) -> Result<(), Error> { - let session_id = message - .session_id() - .ok_or(ErrorCode::InvalidParams)?; - + let session_id = message.session_id().ok_or(ErrorCode::InvalidParams)?; + if let Some(sender) = self.inner.get(session_id) { sender .send(message) @@ -52,9 +52,9 @@ impl MessageRegistry { #[cfg(test)] mod tests { use super::*; - use tokio::sync::mpsc; use crate::types::Message; use crate::types::notification::Notification; + use tokio::sync::mpsc; #[test] fn it_creates_new_registry() { @@ -92,8 +92,8 @@ mod tests { registry.register(session_id, tx); // Create a test message - let test_message = Message::Notification(Notification::new("test", None)) - .set_session_id(session_id); + let test_message = + Message::Notification(Notification::new("test", None)).set_session_id(session_id); // Send the message let send_result = registry.send(test_message); @@ -111,8 +111,8 @@ mod tests { let session_id = Uuid::new_v4(); // Create a test message for a non-existent session - let test_message = Message::Notification(Notification::new("test", None)) - .set_session_id(session_id); + let test_message = + Message::Notification(Notification::new("test", None)).set_session_id(session_id); // Attempt to send a message let send_result = registry.send(test_message); @@ -132,4 +132,4 @@ mod tests { assert!(send_result.is_err()); assert_eq!(send_result.unwrap_err().code, ErrorCode::InvalidParams); } -} \ No newline at end of file +} diff --git a/neva/src/shared/mt.rs b/neva/src/shared/mt.rs index 4d2dbf7..f93db88 100644 --- a/neva/src/shared/mt.rs +++ b/neva/src/shared/mt.rs @@ -1,4 +1,4 @@ -//! Utilities and helpers for cross-platform multithreading. +//! Utilities and helpers for cross-platform multithreading. /// Cooperatively yields to the Tokio scheduler before spawning a new task, /// ensuring fair task scheduling especially under high load or on platforms @@ -80,4 +80,4 @@ pub async fn yield_fair() { tokio::task::yield_now().await; #[cfg(target_os = "windows")] tokio::time::sleep(std::time::Duration::from_millis(1)).await; -} \ No newline at end of file +} diff --git a/neva/src/shared/one_or_many.rs b/neva/src/shared/one_or_many.rs index b7f5a4e..41186a4 100644 --- a/neva/src/shared/one_or_many.rs +++ b/neva/src/shared/one_or_many.rs @@ -1,7 +1,7 @@ //! Type representing either a vector or a single value if `T` -use std::{ops::{Deref, DerefMut}}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +use std::ops::{Deref, DerefMut}; /// Type representing either a vector or a single value if `T` #[derive(Debug, Clone, Serialize, Deserialize)] @@ -9,9 +9,9 @@ use serde::{Serialize, Deserialize}; pub enum OneOrMany { /// Represents a single value. One(T), - + /// Represents a vector of values. - Many(Vec) + Many(Vec), } impl From for OneOrMany { @@ -25,9 +25,11 @@ impl From> for OneOrMany { #[inline] fn from(v: Vec) -> Self { if v.len() == 1 { - Self::One(v.into_iter() - .next() - .expect("Expected at least one element in vector, but got an empty vector.")) + Self::One( + v.into_iter() + .next() + .expect("Expected at least one element in vector, but got an empty vector."), + ) } else { Self::Many(v) } @@ -46,7 +48,6 @@ impl IntoIterator for OneOrMany { } } - impl Deref for OneOrMany { type Target = [T]; @@ -68,17 +69,17 @@ impl Default for OneOrMany { fn default() -> Self { Self::new() } -} +} impl OneOrMany { /// Creates an empty [`OneOrMany`]. - /// + /// /// Hold the [`OneOrMany::Many`] with empty vector. #[inline] pub fn new() -> Self { Self::Many(Vec::new()) } - + /// Returns a slice of the underlying data. #[inline] pub fn as_slice(&self) -> &[T] { @@ -139,14 +140,14 @@ impl OneOrMany { /// Otherwise, returns `false`. #[inline(always)] pub fn is_empty(&self) -> bool { - match self { + match self { Self::One(_) => false, Self::Many(v) => v.is_empty(), } } /// Appends a value onto the back of a collection. - /// + /// /// If it's called on [`OneOrMany::One`], /// it became [`OneOrMany::Many`] with the old value in the front. #[inline] @@ -161,15 +162,15 @@ impl OneOrMany { vec.push(old); vec.push(value); } - }, + } OneOrMany::Many(vec) => { vec.push(value); match vec.len() { - 0 => {}, // leave Many([]) + 0 => {} // leave Many([]) 1 => { let only = vec.pop().unwrap(); *self = OneOrMany::One(only); - }, + } _ => {} } } @@ -192,11 +193,11 @@ impl OneOrMany { OneOrMany::Many(vec) => { let value = vec.pop(); match vec.len() { - 0 => {}, // leave Many([]) + 0 => {} // leave Many([]) 1 => { let only = vec.pop().unwrap(); *self = OneOrMany::One(only); - }, + } _ => {} } value @@ -204,12 +205,12 @@ impl OneOrMany { } } - /// Removes and returns the element at position `index` within the vector, + /// Removes and returns the element at position `index` within the vector, /// shifting all elements after it to the left. /// /// If it's called on [`OneOrMany::One`], /// it became [`OneOrMany::Many`] with an empty vector. - /// + /// /// # Panics /// Panics if index is out of bounds. #[inline] @@ -217,21 +218,21 @@ impl OneOrMany { match self { OneOrMany::One(_) => { assert!(index < 1, "Index out of bounds"); - + if let OneOrMany::One(v) = std::mem::replace(self, OneOrMany::Many(Vec::new())) { v } else { unreachable!() } - }, + } OneOrMany::Many(vec) => { let value = vec.remove(index); match vec.len() { - 0 => {}, // leave Many([]) + 0 => {} // leave Many([]) 1 => { let only = vec.pop().unwrap(); *self = OneOrMany::One(only); - }, + } _ => {} } @@ -250,8 +251,8 @@ mod tests { let empty = OneOrMany::::new(); assert!(matches!(empty, OneOrMany::Many(_))); assert_eq!(empty.len(), 0); - } - + } + #[test] fn it_can_be_created_from_single_value() { let one = OneOrMany::from(1); @@ -405,12 +406,12 @@ mod tests { _ => panic!("Expected Many"), } } - + #[test] fn it_pushes_to_one() { let mut one = OneOrMany::from(1); one.push(2); - + assert!(matches!(one, OneOrMany::Many(_))); assert_eq!(one.as_slice(), &[1, 2]); } @@ -423,21 +424,21 @@ mod tests { assert!(matches!(one, OneOrMany::Many(_))); assert_eq!(one.as_slice(), &[1, 2, 3, 4]); } - + #[test] fn it_pushes_to_many_with_empty_one() { let mut one = OneOrMany::::from(vec![]); one.push(1); - + assert!(matches!(one, OneOrMany::One(_))); assert_eq!(one.as_one(), Some(&1)); } - + #[test] fn it_pops_from_many() { let mut many = OneOrMany::from(vec![1, 2]); assert_eq!(many.pop(), Some(2)); - + assert!(matches!(many, OneOrMany::One(_))); assert_eq!(many.as_one(), Some(&1)); } @@ -450,35 +451,35 @@ mod tests { assert!(matches!(many, OneOrMany::Many(_))); assert_eq!(many.len(), 0); } - + #[test] fn it_removes_from_many() { let mut many = OneOrMany::::from(vec![1, 2]); assert_eq!(many.remove(0), 1); - + assert!(matches!(many, OneOrMany::One(_))); assert_eq!(many.as_one(), Some(&2)); } - + #[test] fn it_removes_from_one() { let mut one = OneOrMany::from(1); assert_eq!(one.remove(0), 1); - + assert!(matches!(one, OneOrMany::Many(_))); assert_eq!(one.len(), 0); } - + #[test] fn it_can_be_indexed() { let one = OneOrMany::::from(vec![1, 2, 3]); assert_eq!(one[1], 2); } - + #[test] fn it_can_be_indexed_mutably() { let mut one = OneOrMany::::from(vec![1, 2, 3]); one[1] = 4; assert_eq!(one.as_slice(), &[1, 4, 3]); } -} \ No newline at end of file +} diff --git a/neva/src/shared/requests_queue.rs b/neva/src/shared/requests_queue.rs index 8be5d0c..d569126 100644 --- a/neva/src/shared/requests_queue.rs +++ b/neva/src/shared/requests_queue.rs @@ -1,28 +1,31 @@ -//! Utilities for tracking requests +//! Utilities for tracking requests +use crate::types::{RequestId, Response}; use dashmap::DashMap; use std::sync::Arc; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; -use crate::types::{RequestId, Response}; /// Represents a request handle pub(crate) struct RequestHandle { sender: oneshot::Sender, - _cancellation_token: CancellationToken + _cancellation_token: CancellationToken, } /// Represents a request tracking "queue" that holds a hash map of [`oneshot::Sender`] for requests /// that are awaiting responses. #[derive(Default, Clone)] pub(crate) struct RequestQueue { - pending: Arc> + pending: Arc>, } impl RequestHandle { /// Creates a new [`RequestHandle`] pub(super) fn new(sender: oneshot::Sender) -> Self { - Self { sender, _cancellation_token: CancellationToken::new() } + Self { + sender, + _cancellation_token: CancellationToken::new(), + } } /// Sends a [`Response`] to MCP server @@ -33,14 +36,16 @@ impl RequestHandle { #[cfg(feature = "tracing")] tracing::error!( logger = "neva", - "Request handler failed to send response: {:?}", _err); + "Request handler failed to send response: {:?}", + _err + ); } }; } } impl RequestQueue { - /// Pushes a request with [`RequestId`] to the "queue" + /// Pushes a request with [`RequestId`] to the "queue" /// and returns a [`oneshot::Receiver`] for the response. #[inline] pub(crate) fn push(&self, id: &RequestId) -> oneshot::Receiver { @@ -52,11 +57,9 @@ impl RequestQueue { /// Pops the [`RequestHandle`] by [`RequestId`] and removes it from the queue #[inline] pub(crate) fn pop(&self, id: &RequestId) -> Option { - self.pending - .remove(id) - .map(|(_, handle)| handle) + self.pending.remove(id).map(|(_, handle)| handle) } - + /// Takes a [`Response`] and completes the request if it's still pending #[inline] pub(crate) fn complete(&self, resp: Response) { @@ -69,8 +72,8 @@ impl RequestQueue { #[cfg(test)] mod tests { use super::*; - use tokio::time::{timeout, Duration}; use serde_json::json; + use tokio::time::{Duration, timeout}; #[test] fn it_pushes_and_pops_request() { @@ -81,7 +84,10 @@ mod tests { let handle = queue.pop(&id); assert!(handle.is_some(), "Expected handle to exist"); - assert!(queue.pop(&id).is_none(), "Handle should be removed after pop"); + assert!( + queue.pop(&id).is_none(), + "Handle should be removed after pop" + ); drop(receiver); // Avoid warning for unused receiver } @@ -96,13 +102,16 @@ mod tests { let expected = Response::success(id, json!({ "content": "done" })); handle.send(expected.clone()); - let Response::Ok(expected) = expected else { unreachable!() }; + let Response::Ok(expected) = expected else { + unreachable!() + }; let Response::Ok(actual) = timeout(Duration::from_secs(1), receiver) .await .expect("Receiver should complete") - .expect("Sender should send") else { - unreachable!() + .expect("Sender should send") + else { + unreachable!() }; assert_eq!(actual.result, expected.result); @@ -118,14 +127,17 @@ mod tests { let response = Response::success(id, json!({ "content": "done" })); queue.complete(response.clone()); - - let Response::Ok(response) = response else { unreachable!() }; - + + let Response::Ok(response) = response else { + unreachable!() + }; + let Response::Ok(actual) = timeout(Duration::from_secs(1), receiver) .await .expect("Should receive within timeout") - .expect("Should receive response") else { - unreachable!() + .expect("Should receive response") + else { + unreachable!() }; assert_eq!(actual.result, response.result); @@ -143,4 +155,4 @@ mod tests { // Nothing to assert really, just verifying it doesn't panic or error } -} \ No newline at end of file +} diff --git a/neva/src/shared/task_api.rs b/neva/src/shared/task_api.rs index 0a0812e..36d1e68 100644 --- a/neva/src/shared/task_api.rs +++ b/neva/src/shared/task_api.rs @@ -1,8 +1,8 @@ //! Utilities and types for handling tasks -use crate::types::{Task, TaskPayload, TaskStatus, CreateTaskResult, ListTasksResult, Cursor}; -use crate::error::{Error, ErrorCode}; use super::Either; +use crate::error::{Error, ErrorCode}; +use crate::types::{CreateTaskResult, Cursor, ListTasksResult, Task, TaskPayload, TaskStatus}; use serde::de::DeserializeOwned; use std::time::Duration; @@ -11,30 +11,40 @@ const DEFAULT_POLL_INTERVAL: usize = 5000; // 5 seconds /// A trait for requestor types pub trait TaskApi { /// Retrieve task result from the client. If the task is not completed yet, waits until it completes or cancels. - fn get_task_result(&mut self, id: impl Into) -> impl Future>; + fn get_task_result( + &mut self, + id: impl Into, + ) -> impl Future>; /// Retrieve task status from the client fn get_task(&mut self, id: impl Into) -> impl Future>; - + /// Cancels a task that is currently running on the client fn cancel_task(&mut self, id: impl Into) -> impl Future>; /// Retrieves a list of tasks from the client - fn list_tasks(&mut self, cursor: Option) -> impl Future>; + fn list_tasks( + &mut self, + cursor: Option, + ) -> impl Future>; /// Input callback - fn handle_input(&mut self, id: &str, params: TaskPayload) -> impl Future>; + fn handle_input( + &mut self, + id: &str, + params: TaskPayload, + ) -> impl Future>; } /// Polls receiver with `tasks/get` until it completed, failed, cancelled or expired. /// Call `tasks/result` if it completed or failed and `tasks/cancel` if expired. pub async fn wait_to_completion( - api: &mut A, - result: Either + api: &mut A, + result: Either, ) -> Result -where +where A: TaskApi, - T: DeserializeOwned + T: DeserializeOwned, { let mut task = match result { Either::Right(result) => return Ok(result), @@ -42,48 +52,48 @@ where }; let mut elapsed = 0; - + loop { if task.ttl <= elapsed { #[cfg(feature = "tracing")] tracing::trace!(logger = "neva", "Task TTL expired. Cancelling task."); - + let _ = api.cancel_task(&task.id).await?; - return Err(Error::new(ErrorCode::InvalidRequest, "Task was cancelled: TTL expired")); + return Err(Error::new( + ErrorCode::InvalidRequest, + "Task was cancelled: TTL expired", + )); } - + task = api.get_task(&task.id).await?; - + match task.status { - TaskStatus::Completed | TaskStatus::Failed => return api - .get_task_result(&task.id) - .await, - TaskStatus::Cancelled => return Err( - Error::new(ErrorCode::InvalidRequest, "Task was cancelled") - ), + TaskStatus::Completed | TaskStatus::Failed => { + return api.get_task_result(&task.id).await; + } + TaskStatus::Cancelled => { + return Err(Error::new(ErrorCode::InvalidRequest, "Task was cancelled")); + } TaskStatus::InputRequired => { #[cfg(feature = "tracing")] tracing::trace!(logger = "neva", "Task input required. Providing input."); - - let params: TaskPayload = api - .get_task_result(&task.id) - .await?; + + let params: TaskPayload = api.get_task_result(&task.id).await?; api.handle_input(&task.id, params).await?; - }, + } _ => { - let poll_interval = task - .poll_interval - .unwrap_or(DEFAULT_POLL_INTERVAL); + let poll_interval = task.poll_interval.unwrap_or(DEFAULT_POLL_INTERVAL); elapsed += poll_interval; #[cfg(feature = "tracing")] tracing::trace!( - logger = "neva", - "Waiting for task to complete. Elapsed: {elapsed}ms"); - + logger = "neva", + "Waiting for task to complete. Elapsed: {elapsed}ms" + ); + tokio::time::sleep(Duration::from_millis(poll_interval as u64)).await; } } } -} \ No newline at end of file +} diff --git a/neva/src/shared/task_tracker.rs b/neva/src/shared/task_tracker.rs index 8318a06..1624ae5 100644 --- a/neva/src/shared/task_tracker.rs +++ b/neva/src/shared/task_tracker.rs @@ -1,14 +1,14 @@ //! Types and utilities for tracking tasks -use serde::Serialize; -use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; -use tokio::sync::watch::{channel, Sender, Receiver}; use crate::error::{Error, ErrorCode}; use crate::types::{Task, TaskPayload, TaskStatus}; +use serde::Serialize; +use tokio::sync::watch::{Receiver, Sender, channel}; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; #[derive(Default)] pub(crate) struct TaskTracker { - tasks: dashmap::DashMap + tasks: dashmap::DashMap, } /// Alias for [`Option`] @@ -34,10 +34,10 @@ impl TaskTracker { #[inline] pub(crate) fn new() -> Self { Self { - tasks: dashmap::DashMap::new() + tasks: dashmap::DashMap::new(), } } - + /// Returns a list of currently running tasks. pub(crate) fn tasks(&self) -> Vec { self.tasks @@ -50,15 +50,18 @@ impl TaskTracker { pub(crate) fn track(&self, task: Task) -> TaskHandle { let token = CancellationToken::new(); let (tx, rx) = channel(None); - - self.tasks.insert(task.id.clone(), TaskEntry { - token: token.clone(), - #[cfg(feature = "server")] - tx: tx.clone(), - task, - rx, - }); - + + self.tasks.insert( + task.id.clone(), + TaskEntry { + token: token.clone(), + #[cfg(feature = "server")] + tx: tx.clone(), + task, + rx, + }, + ); + TaskHandle { token, tx } } @@ -70,7 +73,8 @@ impl TaskTracker { } else { Err(Error::new( ErrorCode::InvalidParams, - format!("Could not find task with id: {id}"))) + format!("Could not find task with id: {id}"), + )) } } @@ -102,9 +106,7 @@ impl TaskTracker { pub(crate) fn reset(&self, id: &str) { if let Some(mut entry) = self.tasks.get_mut(id) { entry.task.reset(); - let _ = entry - .tx - .send(None); + let _ = entry.tx.send(None); } } @@ -116,43 +118,36 @@ impl TaskTracker { Ok(result) => result, Err(_err) => { #[cfg(feature = "tracing")] - tracing::error!( - logger = "neva", - "Unable to serialize task result: {_err:?}"); + tracing::error!(logger = "neva", "Unable to serialize task result: {_err:?}"); return; } }; - let _ = entry - .tx - .send(Some(TaskPayload(result))); + let _ = entry.tx.send(Some(TaskPayload(result))); } } - /// Retrieves the task status + /// Retrieves the task status pub(crate) fn get_status(&self, id: &str) -> Result { - self.tasks - .get(id) - .map(|t| t.task.clone()) - .ok_or_else(|| Error::new( + self.tasks.get(id).map(|t| t.task.clone()).ok_or_else(|| { + Error::new( ErrorCode::InvalidParams, - format!("Could not find task with id: {id}"))) + format!("Could not find task with id: {id}"), + ) + }) } - /// Returns the task result if it is present, + /// Returns the task result if it is present, /// otherwise waits until the result is available or the task will be canceled. pub(crate) async fn get_result(&self, id: &str) -> Result { let (status, mut result_rx, token) = { - let entry = self.tasks - .get(id) - .ok_or_else(|| Error::new( + let entry = self.tasks.get(id).ok_or_else(|| { + Error::new( ErrorCode::InvalidParams, - format!("Could not find task with id: {id}")))?; + format!("Could not find task with id: {id}"), + ) + })?; - ( - entry.task.status, - entry.rx.clone(), - entry.token.clone(), - ) + (entry.task.status, entry.rx.clone(), entry.token.clone()) }; if let Some(ref result) = *result_rx.borrow_and_update() { @@ -192,9 +187,7 @@ impl TaskHandle { Ok(result) => result, Err(_err) => { #[cfg(feature = "tracing")] - tracing::error!( - logger = "neva", - "Unable to serialize task result: {_err:?}"); + tracing::error!(logger = "neva", "Unable to serialize task result: {_err:?}"); return; } }; @@ -221,12 +214,11 @@ impl TaskHandle { } } - #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; use crate::types::TaskStatus; + use std::sync::Arc; #[cfg(feature = "server")] use crate::types::CallToolResponse; @@ -420,7 +412,7 @@ mod tests { let task_id = task.id.clone(); let _handle = tracker.track(task.clone()); - + let tracker = Arc::new(tracker); tokio::spawn({ @@ -493,7 +485,7 @@ mod tests { let handle = tracker.track(task.clone()); let tracker = Arc::new(tracker); - + tokio::spawn({ let tracker = tracker.clone(); let task_id = task_id.clone(); @@ -566,4 +558,4 @@ mod tests { let status = tracker.get_status(&task_id).unwrap(); assert_eq!(status.status, TaskStatus::Completed); } -} \ No newline at end of file +} diff --git a/neva/src/transport.rs b/neva/src/transport.rs index a4e70fb..dc14477 100644 --- a/neva/src/transport.rs +++ b/neva/src/transport.rs @@ -1,25 +1,25 @@ -//! Transport protocols and utilities for communicating between server and client +//! Transport protocols and utilities for communicating between server and client -use std::future::Future; -use tokio_util::sync::CancellationToken; use crate::error::{Error, ErrorCode}; use crate::types::Message; +use std::future::Future; +use tokio_util::sync::CancellationToken; -#[cfg(feature = "server")] -pub(crate) use stdio::StdIoServer; #[cfg(feature = "http-server")] pub use http::HttpServer; +#[cfg(feature = "server")] +pub(crate) use stdio::StdIoServer; -#[cfg(feature = "client")] -pub(crate) use stdio::StdIoClient; #[cfg(feature = "http-client")] pub(crate) use http::HttpClient; +#[cfg(feature = "client")] +pub(crate) use stdio::StdIoClient; -pub(crate) mod stdio; #[cfg(any(feature = "http-server", feature = "http-client"))] pub(crate) mod http; +pub(crate) mod stdio; -/// Describes a sender that can send messages to a client +/// Describes a sender that can send messages to a client pub(crate) trait Sender { /// Sends messages to a client fn send(&mut self, resp: Message) -> impl Future>; @@ -35,10 +35,10 @@ pub(crate) trait Receiver { pub(crate) trait Transport { type Sender: Sender; type Receiver: Receiver; - + /// Starts the server with the current transport protocol fn start(&mut self) -> CancellationToken; - + /// Splits transport into [`Sender`] and [`Receiver`] that can be used in a different threads fn split(self) -> (Self::Sender, Self::Receiver); } @@ -88,7 +88,7 @@ pub(crate) enum TransportProtoReceiver { None, Stdio(stdio::StdIoReceiver), #[cfg(any(feature = "http-server", feature = "http-client"))] - Http(http::HttpReceiver) + Http(http::HttpReceiver), } impl Default for TransportProto { @@ -107,23 +107,24 @@ impl Sender for TransportProtoSender { TransportProtoSender::Http(http) => http.send(resp).await, TransportProtoSender::None => Err(Error::new( ErrorCode::InternalError, - "Transport protocol must be specified" + "Transport protocol must be specified", )), #[cfg(feature = "server")] - TransportProtoSender::BatchCollect { real_sender, responses } => { - match resp { - Message::Response(response) => { - if let Ok(mut guard) = responses.lock() { - guard.push(crate::types::MessageEnvelope::Response(response)); - } - Ok(()) - } - other => { - let mut guard = real_sender.lock().await; - Box::pin(guard.send(other)).await + TransportProtoSender::BatchCollect { + real_sender, + responses, + } => match resp { + Message::Response(response) => { + if let Ok(mut guard) = responses.lock() { + guard.push(crate::types::MessageEnvelope::Response(response)); } + Ok(()) } - } + other => { + let mut guard = real_sender.lock().await; + Box::pin(guard.send(other)).await + } + }, } } } @@ -137,7 +138,7 @@ impl Receiver for TransportProtoReceiver { TransportProtoReceiver::Http(http) => http.recv().await, TransportProtoReceiver::None => Err(Error::new( ErrorCode::InternalError, - "Transport protocol must be specified" + "Transport protocol must be specified", )), } } @@ -146,7 +147,7 @@ impl Receiver for TransportProtoReceiver { impl Transport for TransportProto { type Sender = TransportProtoSender; type Receiver = TransportProtoReceiver; - + #[inline] fn start(&mut self) -> CancellationToken { match self { @@ -161,29 +162,41 @@ impl Transport for TransportProto { TransportProto::None => CancellationToken::new(), } } - + fn split(self) -> (Self::Sender, Self::Receiver) { match self { #[cfg(feature = "server")] TransportProto::StdIoServer(stdio) => { let (tx, rx) = stdio.split(); - (TransportProtoSender::Stdio(tx), TransportProtoReceiver::Stdio(rx)) - }, + ( + TransportProtoSender::Stdio(tx), + TransportProtoReceiver::Stdio(rx), + ) + } #[cfg(feature = "http-server")] TransportProto::HttpServer(http) => { let (tx, rx) = http.split(); - (TransportProtoSender::Http(tx), TransportProtoReceiver::Http(rx)) - }, + ( + TransportProtoSender::Http(tx), + TransportProtoReceiver::Http(rx), + ) + } #[cfg(feature = "client")] TransportProto::StdioClient(stdio) => { let (tx, rx) = stdio.split(); - (TransportProtoSender::Stdio(tx), TransportProtoReceiver::Stdio(rx)) - }, + ( + TransportProtoSender::Stdio(tx), + TransportProtoReceiver::Stdio(rx), + ) + } #[cfg(feature = "http-client")] TransportProto::HttpClient(http) => { let (tx, rx) = http.split(); - (TransportProtoSender::Http(tx), TransportProtoReceiver::Http(rx)) - }, + ( + TransportProtoSender::Http(tx), + TransportProtoReceiver::Http(rx), + ) + } TransportProto::None => (TransportProtoSender::None, TransportProtoReceiver::None), } } diff --git a/neva/src/transport/http.rs b/neva/src/transport/http.rs index 7ee8f1d..16027e9 100644 --- a/neva/src/transport/http.rs +++ b/neva/src/transport/http.rs @@ -5,39 +5,34 @@ use reqwest::header::HeaderMap; #[cfg(feature = "http-server")] use { + server::{AuthConfig, DefaultClaims}, volga::{auth::AuthClaims, headers::HeaderMap}, - server::{AuthConfig, DefaultClaims} }; -use futures_util::TryFutureExt; -use std::{borrow::Cow, fmt::Display}; -use tokio_util::sync::CancellationToken; -use tokio::sync::{mpsc::{self, Receiver, Sender}}; use crate::{ error::{Error, ErrorCode}, shared::MemChr, - types::Message + types::Message, }; +use futures_util::TryFutureExt; +use std::{borrow::Cow, fmt::Display}; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio_util::sync::CancellationToken; -use super::{ - Transport, - Sender as TransportSender, - Receiver as TransportReceiver -}; +use super::{Receiver as TransportReceiver, Sender as TransportSender, Transport}; #[cfg(all(feature = "http-server", feature = "server-tls"))] -pub use volga::tls::{TlsConfig, DevCertMode}; +pub use volga::tls::{DevCertMode, TlsConfig}; #[cfg(all(feature = "http-client", feature = "client-tls"))] use crate::transport::http::client::tls_config::{ - ClientTlsConfig, - TlsConfig as McpClientTlsConfig + ClientTlsConfig, TlsConfig as McpClientTlsConfig, }; -#[cfg(feature = "http-server")] -pub(crate) mod server; #[cfg(feature = "http-client")] pub(crate) mod client; +#[cfg(feature = "http-server")] +pub(crate) mod server; pub(super) const MCP_SESSION_ID: &str = "Mcp-Session-Id"; const DEFAULT_ADDR: &str = "127.0.0.1:3000"; @@ -56,7 +51,7 @@ pub(super) fn get_mcp_session_id(headers: &HeaderMap) -> Option { pub(crate) enum HttpProto { Http, #[cfg(any(feature = "server-tls", feature = "client-tls"))] - Https + Https, } /// Represents HTTP server transport @@ -116,7 +111,7 @@ pub(crate) struct HttpSender { /// Represents HTTP receiver pub(crate) struct HttpReceiver { tx: Sender>, - rx: Receiver> + rx: Receiver>, } #[cfg(feature = "http-server")] @@ -139,7 +134,7 @@ impl Default for HttpServer { #[cfg(feature = "server-tls")] tls_config: None, receiver: HttpReceiver::new(), - sender: HttpSender::new() + sender: HttpSender::new(), } } } @@ -154,7 +149,7 @@ impl Default for HttpClient { #[cfg(feature = "client-tls")] tls_config: None, receiver: HttpReceiver::new(), - sender: HttpSender::new() + sender: HttpSender::new(), } } } @@ -179,7 +174,7 @@ impl ServiceUrl { impl Display for HttpProto { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self { + match &self { HttpProto::Http => f.write_str("http"), #[cfg(any(feature = "server-tls", feature = "client-tls"))] HttpProto::Https => f.write_str("https"), @@ -250,15 +245,15 @@ impl HttpServer { pub fn new(addr: &'static str) -> Self { Self::default().bind(addr) } - + /// Binds HTTP serve to address and port pub fn bind(mut self, addr: &'static str) -> Self { self.url.addr = addr; self } - + /// Sets the MCP endpoint - /// + /// /// Default: `/mcp` pub fn with_endpoint(mut self, prefix: &'static str) -> Self { self.url.endpoint = prefix; @@ -269,25 +264,28 @@ impl HttpServer { #[cfg(feature = "server-tls")] pub fn with_tls(mut self, config: F) -> Self where - F: FnOnce(TlsConfig) -> TlsConfig + F: FnOnce(TlsConfig) -> TlsConfig, { self.tls_config = Some(config(Default::default())); self.url.proto = HttpProto::Https; self } - + /// Configures authentication and authorization pub fn with_auth(mut self, config: F) -> Self - where - F: FnOnce(AuthConfig) -> AuthConfig + where + F: FnOnce(AuthConfig) -> AuthConfig, { self.auth = Some(config(AuthConfig::default())); - self + self } - + fn runtime(&mut self) -> Result { let Some(sender_rx) = self.sender.rx.take() else { - return Err(Error::new(ErrorCode::InternalError, "The HTTP writer is already in use")); + return Err(Error::new( + ErrorCode::InternalError, + "The HTTP writer is already in use", + )); }; Ok(HttpRuntimeContext { url: self.url, @@ -320,16 +318,16 @@ impl HttpClient { #[cfg(feature = "client-tls")] pub fn with_tls(mut self, config: F) -> Self where - F: FnOnce(McpClientTlsConfig) -> McpClientTlsConfig + F: FnOnce(McpClientTlsConfig) -> McpClientTlsConfig, { self.tls_config = Some(config(Default::default())); self.url.proto = HttpProto::Https; self } - + /// Set the bearer token for requests /// - ///Default: `None` + ///Default: `None` pub fn with_auth(mut self, access_token: impl Into) -> Self { self.access_token = Some(access_token.into().into_bytes().into_boxed_slice()); self @@ -337,21 +335,22 @@ impl HttpClient { fn runtime(&mut self) -> Result { let Some(sender_rx) = self.sender.rx.take() else { - return Err(Error::new(ErrorCode::InternalError, "The HTTP writer is already in use")); + return Err(Error::new( + ErrorCode::InternalError, + "The HTTP writer is already in use", + )); }; - + #[cfg(feature = "client-tls")] - let tls_config = self.tls_config.take() - .map(|tls| tls.build()) - .transpose()?; - + let tls_config = self.tls_config.take().map(|tls| tls.build()).transpose()?; + Ok(ClientRuntimeContext { url: self.url, tx: self.receiver.tx.clone(), rx: sender_rx, access_token: self.access_token.take(), #[cfg(feature = "client-tls")] - tls_config + tls_config, }) } } @@ -367,10 +366,12 @@ impl TransportSender for HttpSender { impl TransportReceiver for HttpReceiver { async fn recv(&mut self) -> Result { - self.rx - .recv() - .await - .unwrap_or_else(|| Err(Error::new(ErrorCode::InvalidRequest, "Unexpected end of stream"))) + self.rx.recv().await.unwrap_or_else(|| { + Err(Error::new( + ErrorCode::InvalidRequest, + "Unexpected end of stream", + )) + }) } } @@ -389,11 +390,8 @@ impl Transport for HttpServer { return token; } }; - tokio::spawn(server::serve( - runtime, - token.clone()) - ); - + tokio::spawn(server::serve(runtime, token.clone())); + token } @@ -418,11 +416,8 @@ impl Transport for HttpClient { return token; } }; - tokio::spawn(client::connect( - runtime, - token.clone() - )); - + tokio::spawn(client::connect(runtime, token.clone())); + token } @@ -432,6 +427,4 @@ impl Transport for HttpClient { } #[cfg(test)] -mod test { - -} \ No newline at end of file +mod test {} diff --git a/neva/src/transport/http/client.rs b/neva/src/transport/http/client.rs index 5d54369..7a8ef8f 100644 --- a/neva/src/transport/http/client.rs +++ b/neva/src/transport/http/client.rs @@ -1,41 +1,41 @@ -//! HTTP client implementation +//! HTTP client implementation -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio_util::sync::CancellationToken; -use futures_util::{TryStreamExt, StreamExt}; -use reqwest::{RequestBuilder, header::{CONTENT_TYPE, CACHE_CONTROL, ACCEPT}}; use self::mcp_session::McpSession; use crate::{ - transport::http::{ClientRuntimeContext, get_mcp_session_id, MCP_SESSION_ID}, + error::{Error, ErrorCode}, + transport::http::{ClientRuntimeContext, MCP_SESSION_ID, get_mcp_session_id}, types::Message, - error::{Error, ErrorCode} }; +use futures_util::{StreamExt, TryStreamExt}; +use reqwest::{ + RequestBuilder, + header::{ACCEPT, CACHE_CONTROL, CONTENT_TYPE}, +}; +use std::sync::Arc; #[cfg(feature = "client-tls")] use tls_config::ClientTlsConfig; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; pub(super) mod mcp_session; #[cfg(feature = "client-tls")] pub(crate) mod tls_config; -pub(super) async fn connect( - rt: ClientRuntimeContext, - token: CancellationToken -) { +pub(super) async fn connect(rt: ClientRuntimeContext, token: CancellationToken) { let session = Arc::new(McpSession::new(rt.url, token)); let access_token: Option> = rt.access_token.map(|t| t.into()); tokio::join!( handle_connection( - session.clone(), - rt.rx, - rt.tx.clone(), + session.clone(), + rt.rx, + rt.tx.clone(), access_token.clone(), #[cfg(feature = "client-tls")] rt.tls_config.clone() ), start_sse_connection( - session.clone(), - rt.tx.clone(), + session.clone(), + rt.tx.clone(), access_token.clone(), #[cfg(feature = "client-tls")] rt.tls_config.clone() @@ -48,8 +48,7 @@ async fn handle_connection( mut sender_rx: mpsc::Receiver, recv_tx: mpsc::Sender>, access_token: Option>, - #[cfg(feature = "client-tls")] - tls_config: Option, + #[cfg(feature = "client-tls")] tls_config: Option, ) { #[cfg(not(feature = "client-tls"))] let client = match create_client() { @@ -70,7 +69,7 @@ async fn handle_connection( return; } }; - + let token = session.cancellation_token(); loop { tokio::select! { @@ -91,11 +90,11 @@ async fn handle_connection( if let Some(session_id) = session.session_id() { resp = resp.header(MCP_SESSION_ID, session_id.to_string()) } - + if let Some(access_token) = &access_token { resp = resp.bearer_auth(String::from_utf8_lossy(access_token)) } - + crate::spawn_fair!(send_request( session.clone(), resp, @@ -111,9 +110,9 @@ async fn send_request( session: Arc, resp: RequestBuilder, req: Message, - resp_tx: mpsc::Sender> + resp_tx: mpsc::Sender>, ) { - let resp = match resp.send().await { + let resp = match resp.send().await { Ok(resp) => resp, Err(_err) => { #[cfg(feature = "tracing")] @@ -129,23 +128,27 @@ async fn send_request( // A notification-only batch also produces no server response (HTTP 202, // empty body). Attempting resp.json() on an empty body would be a parse // error that gets pushed into recv_tx and breaks the receive loop. - if let Message::Batch(ref batch) = req && !batch.has_requests() { + if let Message::Batch(ref batch) = req + && !batch.has_requests() + { return; } - + if !session.has_session_id() - && let Some(session_id) = get_mcp_session_id(resp.headers()) { + && let Some(session_id) = get_mcp_session_id(resp.headers()) + { session.set_session_id(session_id); } if let Message::Request(r) = req - && r.method == crate::commands::INIT { + && r.method == crate::commands::INIT + { session.notify_session_initialized(); session.sse_ready().await; } let resp = resp.json::().await; - + if let Err(_err) = resp_tx.send(resp.map_err(Error::from)).await { #[cfg(feature = "tracing")] tracing::error!(logger = "neva", "Failed to send response: {}", _err); @@ -156,8 +159,7 @@ async fn start_sse_connection( session: Arc, resp_tx: mpsc::Sender>, access_token: Option>, - #[cfg(feature = "client-tls")] - tls_config: Option, + #[cfg(feature = "client-tls")] tls_config: Option, ) { let token = session.cancellation_token(); tokio::select! { @@ -165,12 +167,12 @@ async fn start_sse_connection( _ = token.cancelled() => (), _ = session.initialized() => { tokio::spawn(handle_sse_connection( - session.clone(), - resp_tx, + session.clone(), + resp_tx, access_token, #[cfg(feature = "client-tls")] tls_config - )); + )); } } } @@ -179,8 +181,7 @@ async fn handle_sse_connection( session: Arc, resp_tx: mpsc::Sender>, access_token: Option>, - #[cfg(feature = "client-tls")] - tls_config: Option, + #[cfg(feature = "client-tls")] tls_config: Option, ) { #[cfg(not(feature = "client-tls"))] let client = match create_client() { @@ -191,7 +192,7 @@ async fn handle_sse_connection( return; } }; - + #[cfg(feature = "client-tls")] let client = match create_client(tls_config) { Ok(client) => client, @@ -201,7 +202,7 @@ async fn handle_sse_connection( return; } }; - + let mut resp = client .get(session.url().as_str().as_ref()) .header(ACCEPT, "application/json, text/event-stream") @@ -223,14 +224,14 @@ async fn handle_sse_connection( return; } }; - + let mut stream = sse_stream::SseStream::from_byte_stream(resp.bytes_stream()) .fuse() .map_ok(|event| handle_event(event, &resp_tx)) .map_err(handle_error); - + session.notify_sse_initialized(); - + let token = session.cancellation_token(); loop { tokio::select! { @@ -251,7 +252,7 @@ async fn handle_sse_connection( async fn handle_event(event: sse_stream::Sse, resp_tx: &mpsc::Sender>) { if event.is_message() { handle_msg(event, resp_tx).await - } else { + } else { #[cfg(feature = "tracing")] tracing::debug!(logger = "neva", event = ?event); } @@ -275,36 +276,28 @@ async fn handle_msg(event: sse_stream::Sse, resp_tx: &mpsc::Sender Result { - reqwest::Client::builder() - .build() - .map_err(Error::from) + reqwest::Client::builder().build().map_err(Error::from) } #[inline] #[cfg(feature = "client-tls")] fn create_client(mut tls_config: Option) -> Result { let mut builder = reqwest::ClientBuilder::new(); - if let Some(ca_cert) = tls_config - .as_mut() - .and_then(|tls| tls.ca.take()) { + if let Some(ca_cert) = tls_config.as_mut().and_then(|tls| tls.ca.take()) { builder = builder.add_root_certificate(ca_cert); } - if let Some(identity) = tls_config - .as_mut() - .and_then(|tls| tls.identity.take()) { + if let Some(identity) = tls_config.as_mut().and_then(|tls| tls.identity.take()) { builder = builder.identity(identity); } - if tls_config.is_some_and(|tls| !tls.certs_verification) { - builder = builder.danger_accept_invalid_certs(true); - } - builder - .build() - .map_err(Error::from) + if tls_config.is_some_and(|tls| !tls.certs_verification) { + builder = builder.danger_accept_invalid_certs(true); + } + builder.build().map_err(Error::from) } impl From for Error { #[inline] fn from(err: reqwest::Error) -> Self { - Error::new(ErrorCode::ParseError, err.to_string()) + Error::new(ErrorCode::ParseError, err.to_string()) } } diff --git a/neva/src/transport/http/client/mcp_session.rs b/neva/src/transport/http/client/mcp_session.rs index 43ae2b4..4f5ff52 100644 --- a/neva/src/transport/http/client/mcp_session.rs +++ b/neva/src/transport/http/client/mcp_session.rs @@ -1,9 +1,9 @@ //! Tools and utilities for MCP Client Session +use crate::transport::http::ServiceUrl; use once_cell::sync::OnceCell; use tokio::sync::Notify; use tokio_util::sync::CancellationToken; -use crate::transport::http::ServiceUrl; /// Represents current MCP Session pub(super) struct McpSession { @@ -11,7 +11,7 @@ pub(super) struct McpSession { sse_ready: Notify, url: ServiceUrl, session_id: OnceCell, - cancellation_token: CancellationToken + cancellation_token: CancellationToken, } impl McpSession { @@ -22,7 +22,7 @@ impl McpSession { sse_ready: Notify::new(), session_id: OnceCell::new(), cancellation_token: token, - url + url, } } @@ -53,7 +53,7 @@ impl McpSession { tracing::info!("MCP Session Id already set"); } } - + /// Sends a signal that this MCP Session has been initialized #[inline] pub(super) fn notify_session_initialized(&self) { @@ -81,12 +81,12 @@ impl McpSession { #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; - use uuid::Uuid; - use tokio::time::{timeout, Duration}; - use tokio_util::sync::CancellationToken; use crate::transport::http::HttpProto; + use std::sync::Arc; + use tokio::time::{Duration, timeout}; + use tokio_util::sync::CancellationToken; + use uuid::Uuid; fn create_session() -> McpSession { let url = ServiceUrl { diff --git a/neva/src/transport/http/client/tls_config.rs b/neva/src/transport/http/client/tls_config.rs index a2bdcd4..3a34230 100644 --- a/neva/src/transport/http/client/tls_config.rs +++ b/neva/src/transport/http/client/tls_config.rs @@ -1,8 +1,8 @@ //! MCP client TLS configuration -use std::path::PathBuf; -use reqwest::{Certificate, Identity}; use crate::error::Error; +use reqwest::{Certificate, Identity}; +use std::path::PathBuf; /// Represents TLS configuration for an MCP client #[derive(Debug)] @@ -37,16 +37,16 @@ impl TlsConfig { self.cert_path = Some(cert.into()); self } - + /// Sets the path to the CA (Client Authority) file pub fn with_ca(mut self, path: impl Into) -> Self { self.ca_path = Some(path.into()); self } - + /// Controls the use of certificate validation. /// Setting this to `false` disables TLS certificate validation. - /// + /// /// Default: `true`. /// /// # Warning @@ -60,7 +60,7 @@ impl TlsConfig { self.certs_verification = certs_verification; self } - + /// Creates MCP client TLS config pub(crate) fn build(self) -> Result { let ca = if let Some(ca_path) = self.ca_path { @@ -68,7 +68,7 @@ impl TlsConfig { .map_err(Error::from) .and_then(|b| Certificate::from_pem(&b).map_err(Into::into))?; Some(ca) - } else { + } else { None }; @@ -81,6 +81,10 @@ impl TlsConfig { None }; - Ok(ClientTlsConfig { ca, identity, certs_verification: self.certs_verification }) + Ok(ClientTlsConfig { + ca, + identity, + certs_verification: self.certs_verification, + }) } -} \ No newline at end of file +} diff --git a/neva/src/transport/http/server.rs b/neva/src/transport/http/server.rs index 11273f9..207e3fa 100644 --- a/neva/src/transport/http/server.rs +++ b/neva/src/transport/http/server.rs @@ -1,26 +1,28 @@ //! HTTP server implementation -use std::sync::Arc; +use super::{HttpRuntimeContext, MCP_SESSION_ID, ServiceUrl, get_mcp_session_id}; +#[cfg(feature = "tracing")] +use crate::types::notification::fmt::LOG_REGISTRY; +use crate::{ + error::{Error, ErrorCode}, + shared::MessageRegistry, + types::{Message, RequestId, Response}, +}; use dashmap::DashMap; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use tokio_stream::{StreamExt, wrappers::UnboundedReceiverStream}; use tokio_util::sync::CancellationToken; -use super::{HttpRuntimeContext, ServiceUrl, MCP_SESSION_ID, get_mcp_session_id}; -use crate::{ - shared::MessageRegistry, - types::{RequestId, Message, Response}, - error::{Error, ErrorCode} -}; -use volga::{ - App, HttpResult, HttpRequest, di::Dc, status, ok, - auth::{BearerTokenService, Bearer}, - http::sse::Message as SseMessage, sse, - headers::{HeaderMap, AUTHORIZATION} -}; #[cfg(feature = "server-tls")] use volga::tls::TlsConfig; -#[cfg(feature = "tracing")] -use crate::types::notification::fmt::LOG_REGISTRY; +use volga::{ + App, HttpRequest, HttpResult, + auth::{Bearer, BearerTokenService}, + di::Dc, + headers::{AUTHORIZATION, HeaderMap}, + http::sse::Message as SseMessage, + ok, sse, status, +}; pub use auth_config::{AuthConfig, DefaultClaims}; pub(crate) use auth_config::{validate_permissions, validate_roles}; @@ -36,10 +38,7 @@ struct RequestManager { sender: mpsc::Sender>, } -pub(super) async fn serve( - rt: HttpRuntimeContext, - token: CancellationToken -) { +pub(super) async fn serve(rt: HttpRuntimeContext, token: CancellationToken) { let pending = Arc::new(DashMap::new()); let registry = Arc::new(MessageRegistry::new()); let manager = RequestManager { @@ -50,12 +49,13 @@ pub(super) async fn serve( tokio::join!( dispatch(pending.clone(), registry.clone(), rt.rx, token.clone()), handle( - rt.url, + rt.url, rt.auth, #[cfg(feature = "server-tls")] - rt.tls_config, - manager, - token.clone()) + rt.tls_config, + manager, + token.clone() + ) ); } @@ -63,7 +63,7 @@ async fn dispatch( pending: RequestMap, msg_registry: Arc, mut sender_rx: mpsc::Receiver, - token: CancellationToken + token: CancellationToken, ) { loop { tokio::select! { @@ -88,17 +88,16 @@ async fn dispatch( async fn handle( service_url: ServiceUrl, auth: Option, - #[cfg(feature = "server-tls")] - tls: Option, + #[cfg(feature = "server-tls")] tls: Option, manager: RequestManager, - token: CancellationToken + token: CancellationToken, ) { let root = "/"; let mut server = App::new() .bind(service_url.addr) .with_no_delay() .without_greeter(); - + if let Some(auth) = auth { let (auth, rules) = auth.into_parts(); server = server.with_bearer_auth(|_| auth); @@ -109,16 +108,16 @@ async fn handle( if let Some(tls) = tls { server = server.set_tls(tls); } - + server .add_singleton(manager) .map_err(handle_http_error) - .group(service_url.endpoint, |mcp| { + .group(service_url.endpoint, |mcp| { mcp.map_get(root, handle_connection); mcp.map_post(root, handle_message); - mcp.map_delete(root, handle_session_end); + mcp.map_delete(root, handle_session_end); }); - + if let Err(_e) = server.run().await { #[cfg(feature = "tracing")] tracing::error!(logger = "neva", "HTTP Server was shutdown: {:?}", _e); @@ -132,32 +131,33 @@ async fn handle_session_end(req: HttpRequest) -> HttpResult { }; let manager: Dc = req.extract()?; - + #[cfg(feature = "tracing")] LOG_REGISTRY.unregister(&id); manager.msg_registry.unregister(&id); - + ok!([(MCP_SESSION_ID, id.to_string())]) } async fn handle_connection(req: HttpRequest) -> HttpResult { - let Some(id) = get_mcp_session_id(req.headers()) else { + let Some(id) = get_mcp_session_id(req.headers()) else { return status!(400); }; let manager: Dc = req.extract()?; - + let (_log_tx, log_rx) = mpsc::unbounded_channel::(); let (msg_tx, msg_rx) = mpsc::unbounded_channel::(); - + #[cfg(feature = "tracing")] LOG_REGISTRY.register(id, _log_tx); manager.msg_registry.register(id, msg_tx); - + let stream = futures_util::stream::select( - UnboundedReceiverStream::new(log_rx), - UnboundedReceiverStream::new(msg_rx)); - + UnboundedReceiverStream::new(log_rx), + UnboundedReceiverStream::new(msg_rx), + ); + sse!(stream.map(handle_sse_message); [ (MCP_SESSION_ID, id.to_string()) ]) @@ -180,7 +180,12 @@ async fn handle_message(req: HttpRequest) -> HttpResult { return ok!(resp; [(MCP_SESSION_ID, id.to_string())]); } }; - if let Message::Notification(_) = msg { + if matches!(msg, Message::Notification(_)) { + // JSON-RPC 2.0 §4 / MCP Streamable HTTP: respond 202 immediately, + // but still forward the notification to the app so it can act on it + // (e.g. cancel a pending request on notifications/cancelled). + let msg = msg.set_session_id(id); + let _ = manager.sender.send(Ok(msg)).await; return status!(202; [ (MCP_SESSION_ID, id.to_string()) ]); @@ -191,7 +196,10 @@ async fn handle_message(req: HttpRequest) -> HttpResult { // pending entry — otherwise the oneshot receiver would hang forever. // session_id must still be set so Response envelopes inside the batch // resolve pending entries by the full session_id/request_id key. - if let Message::Batch(ref batch) = msg && !batch.has_requests() && !batch.has_error_responses() { + if let Message::Batch(ref batch) = msg + && !batch.has_requests() + && !batch.has_error_responses() + { let msg = msg.set_session_id(id); manager.sender.send(Ok(msg)).await.map_err(sender_error)?; return status!(202; [ @@ -201,28 +209,25 @@ async fn handle_message(req: HttpRequest) -> HttpResult { let claims = bts .and_then(|bts| { - headers.get(AUTHORIZATION) + headers + .get(AUTHORIZATION) .and_then(|bearer| Bearer::try_from(bearer).ok()) .and_then(|bearer| bts.decode::(bearer).ok()) }) .unwrap_or_default(); - + headers.remove(AUTHORIZATION); - + let msg = msg .set_session_id(id) .set_claims(claims) .set_headers(headers); - + let (resp_tx, resp_rx) = oneshot::channel::(); manager.pending.insert(msg.full_id(), resp_tx); - manager.sender.send(Ok(msg)) - .await - .map_err(sender_error)?; - let resp = resp_rx - .await - .map_err(receiver_error)?; - + manager.sender.send(Ok(msg)).await.map_err(sender_error)?; + let resp = resp_rx.await.map_err(receiver_error)?; + ok!(resp; [ (MCP_SESSION_ID, id.to_string()) ]) @@ -249,11 +254,10 @@ async fn read_message(req: HttpRequest) -> Result { // Two-step decode per JSON-RPC 2.0 §5.1: // 1. Validate JSON syntax → ParseError on failure. // 2. Validate message shape → InvalidRequest on failure. - let value: serde_json::Value = serde_json::from_slice(&buf) - .map_err(|_| ErrorCode::ParseError)?; + let value: serde_json::Value = + serde_json::from_slice(&buf).map_err(|_| ErrorCode::ParseError)?; - serde_json::from_value::(value) - .map_err(|_| ErrorCode::InvalidRequest) + serde_json::from_value::(value).map_err(|_| ErrorCode::InvalidRequest) } async fn handle_http_error(_err: volga::error::Error) { diff --git a/neva/src/transport/http/server/auth_config.rs b/neva/src/transport/http/server/auth_config.rs index cce17a6..7c1ed17 100644 --- a/neva/src/transport/http/server/auth_config.rs +++ b/neva/src/transport/http/server/auth_config.rs @@ -1,12 +1,9 @@ //! Authentication and Authorization configuration tools -use std::fmt::Debug; use crate::error::{Error, ErrorCode}; use serde::Deserialize; -use volga::auth::{ - BearerAuthConfig, DecodingKey, Algorithm, Authorizer, - predicate, AuthClaims -}; +use std::fmt::Debug; +use volga::auth::{Algorithm, AuthClaims, Authorizer, BearerAuthConfig, DecodingKey, predicate}; const ERR_NO_CLAIMS: &str = "Claims are not provided"; const ERR_UNAUTHORIZED: &str = "Subject is not authorized to invoke this"; @@ -17,27 +14,27 @@ pub struct DefaultClaims { /// Subject #[serde(skip_serializing_if = "Option::is_none")] pub sub: Option, - + /// Issuer #[serde(skip_serializing_if = "Option::is_none")] pub iss: Option, - + /// Audience #[serde(skip_serializing_if = "Option::is_none")] pub aud: Option, - + /// Expiration time #[serde(skip_serializing_if = "Option::is_none")] pub exp: Option, - + /// Not before time #[serde(skip_serializing_if = "Option::is_none")] pub nbf: Option, - + /// Issued at time #[serde(skip_serializing_if = "Option::is_none")] pub iat: Option, - + /// JWT ID #[serde(skip_serializing_if = "Option::is_none")] pub jti: Option, @@ -75,7 +72,7 @@ impl AuthClaims for DefaultClaims { /// Represents authentication and authorization configuration pub struct AuthConfig { inner: BearerAuthConfig, - authorizer: Authorizer + authorizer: Authorizer, } impl Debug for AuthConfig { @@ -90,7 +87,7 @@ impl Default for AuthConfig { fn default() -> Self { Self { inner: BearerAuthConfig::default(), - authorizer: default_auth_rules() + authorizer: default_auth_rules(), } } } @@ -105,12 +102,14 @@ impl From for BearerAuthConfig { impl AuthConfig { /// Specifies a security key to validate a JWT from a secret pub fn set_decoding_key(mut self, secret: &[u8]) -> Self { - self.inner = self.inner.set_decoding_key(DecodingKey::from_secret(secret)); + self.inner = self + .inner + .set_decoding_key(DecodingKey::from_secret(secret)); self } - + /// Specifies the algorithm supported for verifying JWTs - /// + /// /// Default: [`Algorithm::HS256`] /// # Example /// ```no_run @@ -124,7 +123,7 @@ impl AuthConfig { /// ``` pub fn with_alg(mut self, alg: Algorithm) -> Self { self.inner = self.inner.with_alg(alg); - self + self } /// Sets one or more acceptable audience members @@ -142,7 +141,7 @@ impl AuthConfig { pub fn with_aud(mut self, aud: I) -> Self where T: ToString, - I: AsRef<[T]> + I: AsRef<[T]>, { self.inner = self.inner.with_aud(aud); self @@ -163,15 +162,15 @@ impl AuthConfig { pub fn with_iss(mut self, iss: I) -> Self where T: ToString, - I: AsRef<[T]> + I: AsRef<[T]>, { self.inner = self.inner.with_iss(iss); self } /// Specifies whether to validate the `aud` field or not. - /// - /// It will return an error if the aud field is not a member of the audience provided. + /// + /// It will return an error if the aud field is not a member of the audience provided. /// Validation only happens if the aud claim is present in the token. /// /// Default: `true` @@ -209,12 +208,12 @@ impl AuthConfig { /// ``` pub fn validate_exp(mut self, validate: bool) -> Self { self.inner = self.inner.validate_exp(validate); - self + self } /// Specifies whether to validate the `nbf` field or not. /// - /// It will return an error if the current timestamp is before the time in the `nbf` field. + /// It will return an error if the current timestamp is before the time in the `nbf` field. /// Validation only happens if the `nbf` claim is present in the token. /// /// Default: `false` @@ -231,7 +230,7 @@ impl AuthConfig { /// ``` pub fn validate_nbf(mut self, validate: bool) -> Self { self.inner = self.inner.validate_nbf(validate); - self + self } /// Deconstructs into [`Authorizer`] and [`BearerAuthConfig`] @@ -242,9 +241,12 @@ impl AuthConfig { /// Validates JWT claims against required permissions #[inline] -pub(crate) fn validate_permissions(claims: Option<&C>, required: Option<&[String]>) -> Result<(), Error> { +pub(crate) fn validate_permissions( + claims: Option<&C>, + required: Option<&[String]>, +) -> Result<(), Error> { required.map_or(Ok(()), |req| { - let claims = claims.ok_or_else(claims_missing)?; + let claims = claims.ok_or_else(claims_missing)?; contains_any(claims.permissions(), req) .then_some(()) .ok_or_else(unauthorized) @@ -253,7 +255,10 @@ pub(crate) fn validate_permissions(claims: Option<&C>, required: /// Validates JWT claims against required roles #[inline] -pub(crate) fn validate_roles(claims: Option<&C>, required: Option<&[String]>) -> Result<(), Error> { +pub(crate) fn validate_roles( + claims: Option<&C>, + required: Option<&[String]>, +) -> Result<(), Error> { required.map_or(Ok(()), |req| { let claims = claims.ok_or_else(claims_missing)?; (contains(claims.role(), req) || contains_any(claims.roles(), req)) @@ -286,4 +291,4 @@ fn claims_missing() -> Error { #[inline] pub(super) fn default_auth_rules() -> Authorizer { predicate(|_| true) -} \ No newline at end of file +} diff --git a/neva/src/transport/stdio.rs b/neva/src/transport/stdio.rs index 853af8a..8b49daa 100644 --- a/neva/src/transport/stdio.rs +++ b/neva/src/transport/stdio.rs @@ -1,35 +1,27 @@ -//! stdio transport implementation +//! stdio transport implementation -use futures_util::TryFutureExt; -use tokio_util::sync::CancellationToken; use crate::error::{Error, ErrorCode}; +use crate::transport::{Receiver as TransportReceiver, Sender as TransportSender, Transport}; use crate::types::Message; +use futures_util::TryFutureExt; use tokio::{ - io::{ - AsyncWrite, AsyncWriteExt, - AsyncRead, AsyncBufReadExt, - BufReader, BufWriter - }, + io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, sync::mpsc::{self, Receiver, Sender}, }; -use crate::transport::{ - Transport, - Sender as TransportSender, - Receiver as TransportReceiver -}; +use tokio_util::sync::CancellationToken; #[cfg(feature = "server")] use tokio::io::{Stdin, Stdout}; -#[cfg(feature = "client")] -use tokio::process::{ChildStdin, ChildStdout}; #[cfg(feature = "client")] use self::options::StdIoOptions; +#[cfg(feature = "client")] +use tokio::process::{ChildStdin, ChildStdout}; -#[cfg(all(feature = "client", target_os = "windows"))] -mod windows; #[cfg(all(feature = "client", target_os = "linux"))] mod linux; +#[cfg(all(feature = "client", target_os = "windows"))] +mod windows; #[cfg(feature = "client")] pub(crate) mod options; @@ -58,7 +50,7 @@ pub(crate) struct StdIoSender { /// Represents stdio receiver pub(crate) struct StdIoReceiver { tx: Sender>, - rx: Receiver> + rx: Receiver>, } impl Clone for StdIoSender { @@ -77,19 +69,19 @@ impl StdIoSender { let (tx, rx) = mpsc::channel(100); Self { tx, rx: Some(rx) } } - + /// Starts a new thread that writes to stdout asynchronously pub(crate) fn start( - &mut self, - mut writer: BufWriter, - token: CancellationToken + &mut self, + mut writer: BufWriter, + token: CancellationToken, ) { let Some(mut receiver) = self.rx.take() else { #[cfg(feature = "tracing")] tracing::error!(logger = "neva", "The stdout writer already in use"); return; }; - + tokio::spawn(async move { loop { tokio::select! { @@ -104,7 +96,7 @@ impl StdIoSender { if let Err(_err) = writer.write_all(&json_bytes).await { #[cfg(feature = "tracing")] tracing::error!( - logger = "neva", + logger = "neva", "stdout write error: {:?}", _err); } let _ = writer.flush().await; @@ -112,7 +104,7 @@ impl StdIoSender { Err(_err) => { #[cfg(feature = "tracing")] tracing::error!( - logger = "neva", + logger = "neva", "Serialization error: {:?}", _err); } } @@ -132,12 +124,12 @@ impl StdIoReceiver { let (tx, rx) = mpsc::channel(100); Self { tx, rx } } - + /// Starts a new thread that reads from stdin asynchronously pub(crate) fn start( - &self, - mut reader: BufReader, - token: CancellationToken + &self, + mut reader: BufReader, + token: CancellationToken, ) { let tx = self.tx.clone(); tokio::spawn(async move { @@ -169,7 +161,7 @@ impl StdIoReceiver { } break; } - }; + }; } } } @@ -189,14 +181,17 @@ impl StdIoClient { } /// Handshakes stdio between client and server apps - fn handshake(&self, token: CancellationToken) -> (BufReader, BufWriter) { - let options = &self.options; + fn handshake( + &self, + token: CancellationToken, + ) -> (BufReader, BufWriter) { + let options = &self.options; #[cfg(target_os = "linux")] - let (job, mut child) = linux::Job::new(options.command, &options.args) - .expect("Failed to handshake"); + let (job, mut child) = + linux::Job::new(options.command, &options.args).expect("Failed to handshake"); #[cfg(target_os = "windows")] - let (job, mut child) = windows::Job::new(options.command, &options.args) - .expect("Failed to handshake"); + let (job, mut child) = + windows::Job::new(options.command, &options.args).expect("Failed to handshake"); #[cfg(all(not(target_os = "windows"), not(target_os = "linux")))] let mut child = tokio::process::Command::new(options.command) .args(options.args) @@ -205,10 +200,12 @@ impl StdIoClient { .spawn() .expect("Failed to handshake"); - let stdin = child.stdin + let stdin = child + .stdin .take() .expect("Failed to handshake: Inaccessible stdin"); - let stdout = child.stdout + let stdout = child + .stdout .take() .expect("Failed to handshake: Inaccessible stdout"); @@ -225,7 +222,7 @@ impl StdIoClient { if let Err(_e) = child.kill().await { #[cfg(feature = "tracing")] tracing::warn!( - logger = "neva", + logger = "neva", pid = child_id, "Failed to kill child process: {:?}", _e); } else { @@ -250,13 +247,16 @@ impl StdIoServer { pub(crate) fn new() -> Self { Self { receiver: StdIoReceiver::new(), - sender: StdIoSender::new() + sender: StdIoSender::new(), } } /// Initializes and Returns references to `stdin` and `stdout` pub(crate) fn init() -> (BufReader, BufWriter) { - (BufReader::new(tokio::io::stdin()), BufWriter::new(tokio::io::stdout())) + ( + BufReader::new(tokio::io::stdin()), + BufWriter::new(tokio::io::stdout()), + ) } } @@ -271,10 +271,12 @@ impl TransportSender for StdIoSender { impl TransportReceiver for StdIoReceiver { async fn recv(&mut self) -> Result { - self.rx - .recv() - .await - .unwrap_or_else(|| Err(Error::new(ErrorCode::InvalidRequest, "Unexpected end of stream"))) + self.rx.recv().await.unwrap_or_else(|| { + Err(Error::new( + ErrorCode::InvalidRequest, + "Unexpected end of stream", + )) + }) } } @@ -286,7 +288,7 @@ impl Transport for StdIoClient { fn start(&mut self) -> CancellationToken { let token = CancellationToken::new(); let (reader, writer) = self.handshake(token.clone()); - + self.receiver.start(reader, token.clone()); self.sender.start(writer, token.clone()); @@ -329,22 +331,27 @@ mod tests { #[tokio::test] #[cfg(all(feature = "client", target_os = "windows"))] async fn it_tests_handshake() { - use tokio_util::sync::CancellationToken; - use crate::transport::StdIoClient; use super::options::StdIoOptions; - - let client = StdIoClient::new(StdIoOptions::new("cmd.exe", ["/c", "ping", "127.0.0.1", "-t"])); + use crate::transport::StdIoClient; + use tokio_util::sync::CancellationToken; + + let client = StdIoClient::new(StdIoOptions::new( + "cmd.exe", + ["/c", "ping", "127.0.0.1", "-t"], + )); let token = CancellationToken::new(); let (_, _) = client.handshake(token.clone()); tokio::time::sleep(std::time::Duration::from_secs(1)).await; - + token.cancel(); let result = tokio::time::timeout( std::time::Duration::from_secs(2), - tokio::process::Command::new("tasklist").output() - ).await.unwrap(); + tokio::process::Command::new("tasklist").output(), + ) + .await + .unwrap(); assert!( !String::from_utf8_lossy(&result.unwrap().stdout).contains("ping.exe"), @@ -355,9 +362,9 @@ mod tests { #[tokio::test] #[cfg(all(feature = "client", target_os = "linux"))] async fn it_tests_handshake() { - use tokio_util::sync::CancellationToken; - use crate::transport::StdIoClient; use super::options::StdIoOptions; + use crate::transport::StdIoClient; + use tokio_util::sync::CancellationToken; let client = StdIoClient::new(StdIoOptions::new("sh", ["-c", "sleep 300"])); let token = CancellationToken::new(); @@ -366,14 +373,14 @@ mod tests { token.cancel(); tokio::time::sleep(std::time::Duration::from_secs(1)).await; - + let output = tokio::process::Command::new("pgrep") .arg("-f") .arg("sleep 300") .output() .await .unwrap(); - + assert!(output.stdout.is_empty(), "Process still running"); - } + } } diff --git a/neva/src/transport/stdio/linux.rs b/neva/src/transport/stdio/linux.rs index bcb63a0..d3f790e 100644 --- a/neva/src/transport/stdio/linux.rs +++ b/neva/src/transport/stdio/linux.rs @@ -1,10 +1,10 @@ //! Linux-specific implementation details -use tokio::process::{Child, Command}; use nix::{ - sys::signal::{killpg, Signal}, + sys::signal::{Signal, killpg}, unistd::Pid, }; +use tokio::process::{Child, Command}; /// Process group wrapper for automatic handle closing pub(super) struct Job(i32); @@ -26,7 +26,10 @@ impl Drop for Job { /// Creates a process in a new group with automatic termination #[inline] -pub(super) fn create_process_group(command: &str, args: &Vec<&str>) -> std::io::Result<(i32, Child)> { +pub(super) fn create_process_group( + command: &str, + args: &Vec<&str>, +) -> std::io::Result<(i32, Child)> { let child = Command::new(command) .args(args) .stdin(std::process::Stdio::piped()) @@ -35,7 +38,7 @@ pub(super) fn create_process_group(command: &str, args: &Vec<&str>) -> std::io:: .spawn()?; let group_pid = child.id().expect("Failed to get process id"); - + Ok((group_pid as i32, child)) } @@ -47,15 +50,12 @@ mod tests { #[tokio::test] async fn it_tests_process_group_kill() { - let (job, _) = create_process_group( - "sh", - &vec!["-c", "sleep 300 & sleep 300"] - ).unwrap(); - + let (job, _) = create_process_group("sh", &vec!["-c", "sleep 300 & sleep 300"]).unwrap(); + let job = Job(job); tokio::time::sleep(Duration::from_millis(100)).await; - + drop(job); let output = Command::new("pgrep") @@ -67,4 +67,4 @@ mod tests { assert!(output.stdout.is_empty(), "Processes still running"); } -} \ No newline at end of file +} diff --git a/neva/src/transport/stdio/options.rs b/neva/src/transport/stdio/options.rs index c96779e..fd9c472 100644 --- a/neva/src/transport/stdio/options.rs +++ b/neva/src/transport/stdio/options.rs @@ -1,4 +1,4 @@ -//! stdio transport options +//! stdio transport options /// Represents stdio transport options pub(crate) struct StdIoOptions { @@ -9,12 +9,12 @@ pub(crate) struct StdIoOptions { impl StdIoOptions { /// Creates new stdio options pub(crate) fn new(command: &'static str, args: T) -> Self - where - T: IntoIterator + where + T: IntoIterator, { Self { args: args.into_iter().collect(), - command + command, } } -} \ No newline at end of file +} diff --git a/neva/src/transport/stdio/windows.rs b/neva/src/transport/stdio/windows.rs index 99afa59..9772844 100644 --- a/neva/src/transport/stdio/windows.rs +++ b/neva/src/transport/stdio/windows.rs @@ -1,27 +1,26 @@ -//! Windows-specific implementation details +//! Windows-specific implementation details -use tokio::process::{Command, Child}; +use tokio::process::{Child, Command}; use windows::{ - core::{Result, Error}, Win32::{ Foundation::{CloseHandle, HANDLE}, System::{ - Threading::{ - OpenThread, OpenProcess, ResumeThread, - PROCESS_ALL_ACCESS, THREAD_SUSPEND_RESUME, CREATE_SUSPENDED - }, Diagnostics::ToolHelp::{ - CreateToolhelp32Snapshot, Thread32First, Thread32Next, - TH32CS_SNAPTHREAD, - THREADENTRY32, + CreateToolhelp32Snapshot, TH32CS_SNAPTHREAD, THREADENTRY32, Thread32First, + Thread32Next, }, JobObjects::{ - AssignProcessToJobObject, CreateJobObjectW, SetInformationJobObject, - JobObjectExtendedLimitInformation, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, - JOBOBJECT_EXTENDED_LIMIT_INFORMATION, + AssignProcessToJobObject, CreateJobObjectW, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation, + SetInformationJobObject, + }, + Threading::{ + CREATE_SUSPENDED, OpenProcess, OpenThread, PROCESS_ALL_ACCESS, ResumeThread, + THREAD_SUSPEND_RESUME, }, }, }, + core::{Error, Result}, }; const CMD: &str = "cmd"; @@ -47,8 +46,8 @@ impl Job { (CMD, win_args) } else { (command, args.clone()) - }; - + }; + let (job_handle, child) = create_job_object_with_kill_on_close(command, args)?; let job = Self(job_handle); Ok((job, child)) @@ -63,7 +62,9 @@ impl Drop for Job { // - The handle is owned by this `Job` wrapper, and not aliased elsewhere. // - This is the only place where the handle is closed (via `Drop`), ensuring it is closed exactly once. // - `CloseHandle` is safe to call on a valid handle, and we ignore the result to prevent panicking during drop. - unsafe { _ = CloseHandle(self.0); } + unsafe { + _ = CloseHandle(self.0); + } } } @@ -123,7 +124,7 @@ fn create_job_object_with_kill_on_close(command: &str, args: Vec<&str>) -> Resul CloseHandle(thread_handle)?; CloseHandle(process_handle)?; - + match result { Ok(_) => Ok((job, child)), Err(_) => Err(Error::from_thread()), @@ -172,16 +173,18 @@ unsafe fn get_main_thread_id(process_id: u32) -> Option { #[cfg(test)] mod tests { - use tokio::process::Command; + use crate::transport::stdio::windows::{ + create_job_object_with_kill_on_close, get_main_thread_id, + }; use std::time::Duration; + use tokio::process::Command; use windows::Win32::System::Threading::CREATE_SUSPENDED; - use crate::transport::stdio::windows::{create_job_object_with_kill_on_close, get_main_thread_id}; #[tokio::test] async fn it_tests_job_object_kills_children() -> Result<(), Box> { let (_job, mut child) = create_job_object_with_kill_on_close( "cmd.exe", - vec!["/c", "ping", "127.0.0.1", "-n", "5", "-w", "1000"] + vec!["/c", "ping", "127.0.0.1", "-n", "5", "-w", "1000"], )?; tokio::time::sleep(Duration::from_secs(1)).await; @@ -220,4 +223,4 @@ mod tests { child.kill().await.unwrap(); } -} \ No newline at end of file +} diff --git a/neva/src/types.rs b/neva/src/types.rs index 18f4799..e5f1347 100644 --- a/neva/src/types.rs +++ b/neva/src/types.rs @@ -1,12 +1,12 @@ -//! Types used by the MCP protocol -//! +//! Types used by the MCP protocol +//! //! See the [specification](https://github.com/modelcontextprotocol/specification) for details -use std::fmt::Display; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; use crate::SDK_NAME; use crate::types::notification::Notification; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; #[cfg(feature = "server")] use crate::{ @@ -19,175 +19,99 @@ use crate::{ pub use request::FromRequest; #[cfg(feature = "http-server")] -use { - crate::auth::DefaultClaims, - volga::headers::HeaderMap -}; +use {crate::auth::DefaultClaims, volga::headers::HeaderMap}; -pub use helpers::{Json, Meta, PropertyType}; -pub use request::{RequestId, Request, RequestParamsMeta}; -pub use response::{IntoResponse, Response, ErrorDetails}; -pub use reference::Reference; -pub use completion::{Completion, CompleteRequestParams, Argument, CompleteResult}; -pub use cursor::{Cursor, Page, Pagination}; +pub use capabilities::{ + ClientCapabilities, CompletionsCapability, ElicitationCapability, ElicitationFormCapability, + ElicitationUrlCapability, LoggingCapability, PromptsCapability, ResourcesCapability, + RootsCapability, SamplingCapability, SamplingContextCapability, SamplingToolsCapability, + ServerCapabilities, ToolsCapability, +}; +pub use completion::{Argument, CompleteRequestParams, CompleteResult, Completion}; pub use content::{ - Content, - TextContent, - AudioContent, - ImageContent, - ResourceLink, - EmbeddedResource, + AudioContent, Content, EmbeddedResource, ImageContent, ResourceLink, TextContent, ToolResult, ToolUse, - ToolResult -}; -pub use capabilities::{ - ClientCapabilities, - ServerCapabilities, - ToolsCapability, - ResourcesCapability, - PromptsCapability, - LoggingCapability, - CompletionsCapability, - ElicitationCapability, - ElicitationFormCapability, - ElicitationUrlCapability, - SamplingCapability, - SamplingContextCapability, - SamplingToolsCapability, - RootsCapability }; +pub use cursor::{Cursor, Page, Pagination}; +pub use helpers::{Json, Meta, PropertyType}; +pub use reference::Reference; +pub use request::{Request, RequestId, RequestParamsMeta}; +pub use response::{ErrorDetails, IntoResponse, Response}; #[cfg(feature = "tasks")] pub use capabilities::{ - ServerTasksCapability, - ClientTasksCapability, - TaskListCapability, - TaskCancellationCapability, - ClientTaskRequestsCapability, - ServerTaskRequestsCapability, - ToolsTaskCapability, - ToolsCallTaskCapability, - SamplingTaskCapability, - SamplingCreateMessageTaskCapability, - ElicitationTaskCapability, - ElicitationCreateTaskCapability + ClientTaskRequestsCapability, ClientTasksCapability, ElicitationCreateTaskCapability, + ElicitationTaskCapability, SamplingCreateMessageTaskCapability, SamplingTaskCapability, + ServerTaskRequestsCapability, ServerTasksCapability, TaskCancellationCapability, + TaskListCapability, ToolsCallTaskCapability, ToolsTaskCapability, }; pub use tool::{ - ListToolsRequestParams, - CallToolRequestParams, - CallToolResponse, - Tool, - ToolSchema, - ToolAnnotations, - ListToolsResult + CallToolRequestParams, CallToolResponse, ListToolsRequestParams, ListToolsResult, Tool, + ToolAnnotations, ToolSchema, }; #[cfg(feature = "server")] pub use tool::ToolHandler; -pub use resource::{ - Uri, - ListResourcesRequestParams, - ListResourceTemplatesRequestParams, - ListResourcesResult, - ListResourceTemplatesResult, - Resource, - ResourceTemplate, - ResourceContents, - TextResourceContents, - BlobResourceContents, - ReadResourceResult, - ReadResourceRequestParams, - SubscribeRequestParams, - UnsubscribeRequestParams, +pub use elicitation::{ + ElicitRequestFormParams, ElicitRequestParams, ElicitRequestUrlParams, ElicitResult, + ElicitationAction, ElicitationCompleteParams, ElicitationMode, UrlElicitationRequiredError, }; pub use prompt::{ - ListPromptsRequestParams, - ListPromptsResult, - Prompt, - GetPromptRequestParams, - GetPromptResult, - PromptArgument, - PromptMessage, + GetPromptRequestParams, GetPromptResult, ListPromptsRequestParams, ListPromptsResult, Prompt, + PromptArgument, PromptMessage, +}; +pub use resource::{ + BlobResourceContents, ListResourceTemplatesRequestParams, ListResourceTemplatesResult, + ListResourcesRequestParams, ListResourcesResult, ReadResourceRequestParams, ReadResourceResult, + Resource, ResourceContents, ResourceTemplate, SubscribeRequestParams, TextResourceContents, + UnsubscribeRequestParams, Uri, }; pub use sampling::{ - CreateMessageRequestParams, - CreateMessageResult, - SamplingMessage, - StopReason, + CreateMessageRequestParams, CreateMessageResult, SamplingMessage, StopReason, ToolChoice, ToolChoiceMode, - ToolChoice -}; -pub use elicitation::{ - UrlElicitationRequiredError, - ElicitationCompleteParams, - ElicitRequestParams, - ElicitRequestFormParams, - ElicitRequestUrlParams, - ElicitationAction, - ElicitationMode, - ElicitResult }; pub use schema::{ - Schema, - StringSchema, - StringFormat, - NumberSchema, - BooleanSchema, - TitledMultiSelectEnumSchema, - TitledSingleSelectEnumSchema, - UntitledMultiSelectEnumSchema, - UntitledSingleSelectEnumSchema, + BooleanSchema, NumberSchema, Schema, StringFormat, StringSchema, TitledMultiSelectEnumSchema, + TitledSingleSelectEnumSchema, UntitledMultiSelectEnumSchema, UntitledSingleSelectEnumSchema, }; -pub use icon::{ - Icon, - IconSize, - IconTheme, -}; +pub use icon::{Icon, IconSize, IconTheme}; #[cfg(feature = "tasks")] pub use task::{ - GetTaskPayloadRequestParams, - GetTaskRequestParams, - ListTasksRequestParams, - ListTasksResult, - CancelTaskRequestParams, - CreateTaskResult, - RelatedTaskMetadata, - TaskMetadata, - TaskPayload, + CancelTaskRequestParams, CreateTaskResult, GetTaskPayloadRequestParams, GetTaskRequestParams, + ListTasksRequestParams, ListTasksResult, RelatedTaskMetadata, Task, TaskMetadata, TaskPayload, TaskStatus, - Task, }; #[cfg(feature = "server")] pub use prompt::PromptHandler; -pub use root::Root; pub use progress::ProgressToken; +pub use root::Root; -mod request; -mod response; mod capabilities; -mod reference; +pub mod completion; mod content; +pub mod cursor; +pub mod elicitation; +pub(crate) mod helpers; +mod icon; +pub mod notification; mod progress; -mod schema; -pub mod tool; -pub mod resource; pub mod prompt; -pub mod completion; -pub mod notification; -pub mod cursor; +mod reference; +mod request; +pub mod resource; +mod response; pub mod root; pub mod sampling; -pub mod elicitation; +mod schema; #[cfg(feature = "tasks")] pub mod task; -mod icon; -pub(crate) mod helpers; +pub mod tool; pub(super) const JSONRPC_VERSION: &str = "2.0"; @@ -333,7 +257,9 @@ impl MessageBatch { #[inline] #[cfg(any(feature = "http-server", feature = "http-client"))] pub(crate) fn has_requests(&self) -> bool { - self.items.iter().any(|e| matches!(e, MessageEnvelope::Request(_))) + self.items + .iter() + .any(|e| matches!(e, MessageEnvelope::Request(_))) } /// Returns `true` if the batch contains at least one [`MessageEnvelope::Response`] @@ -347,7 +273,9 @@ impl MessageBatch { #[inline] #[cfg(feature = "http-server")] pub(crate) fn has_error_responses(&self) -> bool { - self.items.iter().any(|e| matches!(e, MessageEnvelope::Response(Response::Err(_)))) + self.items + .iter() + .any(|e| matches!(e, MessageEnvelope::Response(Response::Err(_)))) } } @@ -374,7 +302,9 @@ impl<'de> Deserialize<'de> for MessageBatch { // Invalid Request responses, not a top-level failure). let raw = Vec::::deserialize(deserializer)?; if raw.is_empty() { - return Err(serde::de::Error::custom("JSON-RPC batch array must not be empty")); + return Err(serde::de::Error::custom( + "JSON-RPC batch array must not be empty", + )); } // Every item either deserializes cleanly or produces an error response. @@ -415,60 +345,60 @@ impl<'de> Deserialize<'de> for MessageBatch { } /// Parameters for an initialization request sent to the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InitializeRequestParams { /// The version of the Model Context Protocol that the client is to use. #[serde(rename = "protocolVersion")] pub protocol_ver: String, - + /// The client's capabilities. pub capabilities: Option, - + /// Information about the client implementation. #[serde(rename = "clientInfo")] pub client_info: Option, } /// Result of the initialization request sent to the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InitializeResult { /// The version of the Model Context Protocol that the server is to use. #[serde(rename = "protocolVersion")] pub protocol_ver: String, - + /// The server's capabilities. pub capabilities: ServerCapabilities, - + /// Information about the server implementation. #[serde(rename = "serverInfo")] - pub server_info : Implementation, - + pub server_info: Implementation, + /// Optional instructions for using the server and its features. #[serde(skip_serializing_if = "Option::is_none")] - pub instructions: Option + pub instructions: Option, } /// Describes the name and version of an MCP implementation. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Implementation { /// Name of the implementation. pub name: String, - + /// Version of the implementation. pub version: String, - + /// Optional set of sized icons that the client can display in a user interface. - /// + /// /// Clients that support rendering icons **MUST** support at least the following MIME types: /// - `image/png` - PNG images (safe, universal compatibility) /// - `image/jpeg` (and `image/jpg`) - JPEG images (safe, universal compatibility) - /// + /// /// Clients that support rendering icons **SHOULD** also support: /// - `image/svg+xml` - SVG images (scalable but requires security precautions) /// - `image/webp` - WebP images (modern, efficient format) @@ -485,39 +415,39 @@ pub enum Role { /// Corresponds to the user in the conversation. User, /// Corresponds to the AI in the conversation. - Assistant + Assistant, } /// Represents annotations that can be attached to content. /// The client can use annotations to inform how objects are used or displayed -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct Annotations { /// Describes who the intended customer of this object or data is. audience: Vec, - + /// The moment the resource was last modified, as an ISO 8601 formatted string. - /// + /// /// Should be an ISO 8601 formatted string (e.g., **"2025-01-12T15:00:58Z"**). - /// + /// /// **Examples:** last activity timestamp in an open file, timestamp when the resource /// was attached, etc. #[serde(rename = "lastModified", skip_serializing_if = "Option::is_none")] last_modified: Option>, - + /// Describes how important this data is for operating the server (0 to 1). - /// + /// /// A value of 1 means **most important** and indicates that the data is /// effectively required, while 0 means **least important** and indicates that /// the data is entirely optional. - priority: f32 + priority: f32, } impl Display for Role { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { + match self { Role::User => write!(f, "user"), Role::Assistant => write!(f, "assistant"), } @@ -527,10 +457,10 @@ impl Display for Role { impl From<&str> for Role { #[inline] fn from(role: &str) -> Self { - match role { + match role { "user" => Self::User, "assistant" => Self::Assistant, - _ => Self::User + _ => Self::User, } } } @@ -541,7 +471,7 @@ impl From for Role { match role.as_str() { "user" => Self::User, "assistant" => Self::Assistant, - _ => Self::User + _ => Self::User, } } } @@ -561,7 +491,7 @@ impl IntoResponse for InitializeResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -609,14 +539,14 @@ impl Message { pub fn is_batch(&self) -> bool { matches!(self, Message::Batch(_)) } - + /// Returns [`Message`] ID #[inline] pub fn id(&self) -> RequestId { match self { Message::Request(req) => req.id(), Message::Response(resp) => resp.id().clone(), - Message::Notification(_) | Message::Batch(_) => RequestId::default() + Message::Notification(_) | Message::Batch(_) => RequestId::default(), } } @@ -651,7 +581,7 @@ impl Message { } self } - + /// Sets HTTP headers for [`Request`], [`Response`], or [`MessageBatch`] message #[cfg(feature = "http-server")] pub fn set_headers(mut self, headers: HeaderMap) -> Self { @@ -659,7 +589,7 @@ impl Message { Message::Request(ref mut req) => req.headers = headers, Message::Response(resp) => self = Message::Response(resp.set_headers(headers)), Message::Batch(ref mut batch) => batch.headers = headers, - _ => () + _ => (), } self } @@ -670,7 +600,7 @@ impl Message { match self { Message::Request(ref mut req) => req.claims = Some(Box::new(claims)), Message::Batch(ref mut batch) => batch.claims = Some(Box::new(claims)), - _ => () + _ => (), } self } @@ -680,22 +610,21 @@ impl Annotations { /// Deserializes a new [`Annotations`] from a JSON string #[inline] pub fn from_json_str(json: &str) -> Self { - serde_json::from_str(json) - .expect("Annotations: Incorrect JSON string provided") + serde_json::from_str(json).expect("Annotations: Incorrect JSON string provided") } - + /// Adds audience pub fn with_audience>(mut self, role: T) -> Self { self.audience.push(role.into()); self } - + /// Sets the priority pub fn with_priority(mut self, priority: f32) -> Self { self.priority = priority; self } - + /// Sets the moment the object was last modified pub fn with_last_modified(mut self, last_modified: DateTime) -> Self { self.last_modified = Some(last_modified); @@ -725,10 +654,10 @@ impl InitializeResult { completions: Some(CompletionsCapability::default()), #[cfg(feature = "tasks")] tasks: options.tasks_capability(), - experimental: None + experimental: None, }, server_info: options.implementation.clone(), - instructions: None + instructions: None, } } } @@ -789,7 +718,10 @@ mod tests { assert!(matches!(resp, Response::Err(_))); // The id must serialize as JSON null. let serialized = serde_json::to_string(resp).unwrap(); - assert!(serialized.contains(r#""id":null"#), "expected null id, got: {serialized}"); + assert!( + serialized.contains(r#""id":null"#), + "expected null id, got: {serialized}" + ); } #[test] @@ -810,7 +742,10 @@ mod tests { }; assert!(matches!(resp, Response::Err(_))); let serialized = serde_json::to_string(resp).unwrap(); - assert!(serialized.contains(r#""id":2"#), "expected id 2, got: {serialized}"); + assert!( + serialized.contains(r#""id":2"#), + "expected id 2, got: {serialized}" + ); } #[test] @@ -827,7 +762,10 @@ mod tests { }; assert!(matches!(resp, Response::Err(_))); let serialized = serde_json::to_string(resp).unwrap(); - assert!(serialized.contains(r#""id":null"#), "expected null id, got: {serialized}"); + assert!( + serialized.contains(r#""id":null"#), + "expected null id, got: {serialized}" + ); } } -} \ No newline at end of file +} diff --git a/neva/src/types/capabilities.rs b/neva/src/types/capabilities.rs index 12c473e..1901dfb 100644 --- a/neva/src/types/capabilities.rs +++ b/neva/src/types/capabilities.rs @@ -1,29 +1,29 @@ -//! Types that describes server and client capabilities +//! Types that describes server and client capabilities -use std::collections::HashMap; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; /// Represents the capabilities that a client may support. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientCapabilities { /// Gets or sets the client's roots capability, which are entry points for resource navigation. /// - /// > **Note:** When `roots` is `Some`, the client indicates that it can respond to + /// > **Note:** When `roots` is `Some`, the client indicates that it can respond to /// > server requests for listing root URIs. Root URIs serve as entry points for resource navigation in the protocol. - /// > + /// > /// > The server can use `RequestRoots` to request the list of /// > available roots from the client, which will trigger the client's `RootsHandler`. #[serde(skip_serializing_if = "Option::is_none")] pub roots: Option, - /// Gets or sets the client's sampling capability, which indicates whether the client + /// Gets or sets the client's sampling capability, which indicates whether the client /// supports issuing requests to an LLM on behalf of the server. #[serde(skip_serializing_if = "Option::is_none")] pub sampling: Option, - /// Gets or sets the client's elicitation capability, which indicates whether the client + /// Gets or sets the client's elicitation capability, which indicates whether the client /// supports elicitation of additional information from the user on behalf of the server. #[serde(skip_serializing_if = "Option::is_none")] pub elicitation: Option, @@ -32,15 +32,15 @@ pub struct ClientCapabilities { #[cfg(feature = "tasks")] #[serde(skip_serializing_if = "Option::is_none")] pub tasks: Option, - + /// Gets or sets experimental, non-standard capabilities that the client supports. /// - /// > **Note:** The `experimental` map allows clients to advertise support for features that are not yet - /// > standardized in the Model Context Protocol specification. This extension mechanism enables + /// > **Note:** The `experimental` map allows clients to advertise support for features that are not yet + /// > standardized in the Model Context Protocol specification. This extension mechanism enables /// > future protocol enhancements while maintaining backward compatibility. - /// > - /// > Values in this map are implementation-specific and should be coordinated between client - /// > and server implementations. Servers should not assume the presence of any experimental capability + /// > + /// > Values in this map are implementation-specific and should be coordinated between client + /// > and server implementations. Servers should not assume the presence of any experimental capability /// > without checking for it first. #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option>, @@ -51,20 +51,20 @@ pub struct ClientCapabilities { /// > **Note:** When present in [`ClientCapabilities`], it indicates that the client supports listing /// > root URIs that serve as entry points for resource navigation. /// > -/// > The roots capability establishes a mechanism for servers to discover and access the hierarchical +/// > The roots capability establishes a mechanism for servers to discover and access the hierarchical /// > structure of resources provided by a client. Root URIs represent top-level entry points from which /// > servers can navigate to access specific resources. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct RootsCapability { /// Indicates whether the client supports notifications for changes to the roots list. /// - /// > **Note:** When set to `true`, the client can notify servers when roots are added, + /// > **Note:** When set to `true`, the client can notify servers when roots are added, /// > removed, or modified, allowing servers to refresh their roots cache accordingly. /// > This enables servers to stay synchronized with client-side changes to available roots. #[serde(default, rename = "listChanged")] - pub list_changed: bool + pub list_changed: bool, } /// Represents the capability for a client to generate text or other content using an AI model. @@ -83,11 +83,11 @@ pub struct SamplingCapability { /// Indicates whether the client supports tool use via `tools` and `toolChoice` parameters. #[serde(skip_serializing_if = "Option::is_none")] - tools: Option + tools: Option, } /// Represents the sampling context capability. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct SamplingContextCapability { @@ -95,7 +95,7 @@ pub struct SamplingContextCapability { } /// Represents the sampling tools capability. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct SamplingToolsCapability { @@ -103,12 +103,12 @@ pub struct SamplingToolsCapability { } /// Represents the capability for a client to provide server-requested additional information during interactions. -/// +/// /// > **Note:** This capability enables the MCP client to respond to elicitation requests from an MCP server. /// > /// > When this capability is enabled, an MCP server can request the client to provide additional information /// > during interactions. The client must set a to process these requests. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ElicitationCapability { @@ -118,11 +118,11 @@ pub struct ElicitationCapability { /// Indicates whether the client supports `url` mode elicitation. #[serde(skip_serializing_if = "Option::is_none")] - pub url: Option + pub url: Option, } /// Represents elicitation form capability. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ElicitationFormCapability { @@ -130,7 +130,7 @@ pub struct ElicitationFormCapability { } /// Represents elicitation URL capability -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ElicitationUrlCapability { @@ -141,7 +141,7 @@ pub struct ElicitationUrlCapability { /// /// > **Note:** Server capabilities define the features and functionality available when clients connect. /// > These capabilities are advertised to clients during the initialize handshake. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerCapabilities { @@ -152,15 +152,15 @@ pub struct ServerCapabilities { /// Present if the server offers any prompt templates. #[serde(skip_serializing_if = "Option::is_none")] pub prompts: Option, - + /// Present if the server offers any resources to read. #[serde(skip_serializing_if = "Option::is_none")] pub resources: Option, - + /// Present if the server supports sending log messages to the client. #[serde(skip_serializing_if = "Option::is_none")] pub logging: Option, - + /// Present if the server supports argument autocompletion suggestions. #[serde(skip_serializing_if = "Option::is_none")] pub completions: Option, @@ -172,39 +172,39 @@ pub struct ServerCapabilities { /// Indicates experimental, non-standard capabilities that the server supports. /// - /// > **Note:** The `experimental` map allows servers to advertise support for features that are not yet - /// > standardized in the Model Context Protocol specification. This extension mechanism enables + /// > **Note:** The `experimental` map allows servers to advertise support for features that are not yet + /// > standardized in the Model Context Protocol specification. This extension mechanism enables /// > future protocol enhancements while maintaining backward compatibility. - /// > - /// > Values in this dictionary are implementation-specific and should be coordinated between client - /// > and server implementations. Clients should not assume the presence of any experimental capability + /// > + /// > Values in this dictionary are implementation-specific and should be coordinated between client + /// > and server implementations. Clients should not assume the presence of any experimental capability /// > without checking for it first. #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option>, } /// Represents the tools capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ToolsCapability { /// Indicates whether this server supports notifications for changes to the tool list. #[serde(default, rename = "listChanged")] - pub list_changed: bool + pub list_changed: bool, } /// Represents the prompts capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct PromptsCapability { /// Indicates whether this server supports notifications for changes to the prompt list. #[serde(default, rename = "listChanged")] - pub list_changed: bool + pub list_changed: bool, } /// Represents the resources capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ResourcesCapability { @@ -213,11 +213,11 @@ pub struct ResourcesCapability { pub list_changed: bool, /// Indicates whether this server supports subscribing to resource updates. - pub subscribe: bool + pub subscribe: bool, } /// Represents the logging capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct LoggingCapability { @@ -225,7 +225,7 @@ pub struct LoggingCapability { } /// Represents the completions capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct CompletionsCapability { @@ -233,7 +233,7 @@ pub struct CompletionsCapability { } /// Represents task-augmented requests capability configuration for a server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[cfg(feature = "tasks")] #[derive(Default, Debug, Clone, Serialize, Deserialize)] @@ -241,14 +241,14 @@ pub struct ServerTasksCapability { /// Indicates whether this server supports `tasks/cancel`. #[serde(skip_serializing_if = "Option::is_none")] pub cancel: Option, - + /// Indicates whether this server supports `tasks/list`. #[serde(skip_serializing_if = "Option::is_none")] pub list: Option, /// Specifies which request types can be augmented with tasks. #[serde(skip_serializing_if = "Option::is_none")] - pub requests: Option + pub requests: Option, } /// Represents task-augmented requests capability configuration for a client. @@ -267,11 +267,11 @@ pub struct ClientTasksCapability { /// Specifies which request types can be augmented with tasks. #[serde(skip_serializing_if = "Option::is_none")] - pub requests: Option + pub requests: Option, } /// Represents task cancellation capability configuration. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[cfg(feature = "tasks")] #[derive(Default, Debug, Clone, Serialize, Deserialize)] @@ -296,7 +296,7 @@ pub struct TaskListCapability { pub struct ServerTaskRequestsCapability { /// Specifies task support for tool-related requests. #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option + pub tools: Option, } /// Specifies which request types can be augmented with tasks. @@ -311,7 +311,7 @@ pub struct ClientTaskRequestsCapability { /// Specifies task support for sampling-related requests. #[serde(skip_serializing_if = "Option::is_none")] - pub sampling: Option + pub sampling: Option, } /// Specifies task support for tool-related requests. @@ -322,7 +322,7 @@ pub struct ClientTaskRequestsCapability { pub struct ToolsTaskCapability { /// Indicates whether the server supports task-augmented `tools/call` requests. #[serde(skip_serializing_if = "Option::is_none")] - pub call: Option + pub call: Option, } /// Specifies task support for elicitation-related requests. @@ -333,7 +333,7 @@ pub struct ToolsTaskCapability { pub struct ElicitationTaskCapability { /// Indicates whether the client supports task-augmented `elicitation/create` requests. #[serde(skip_serializing_if = "Option::is_none")] - pub create: Option + pub create: Option, } /// Specifies task support for sampling-related requests. @@ -344,7 +344,7 @@ pub struct ElicitationTaskCapability { pub struct SamplingTaskCapability { /// Indicates whether the client supports task-augmented `sampling/createMessage` requests. #[serde(rename = "createMessage", skip_serializing_if = "Option::is_none")] - pub create: Option + pub create: Option, } /// Represents task support configuration for `tools/call` requests. @@ -396,7 +396,7 @@ impl ResourcesCapability { } /// Specifies whether this server supports subscribing to resource updates. - /// + /// /// Default: _false_ pub fn with_subscribe(mut self) -> Self { self.subscribe = true; @@ -429,7 +429,7 @@ impl RootsCapability { #[cfg(feature = "client")] impl SamplingCapability { /// Specifies whether this client supports context inclusion. - /// + /// /// Default: `None` pub fn with_context(mut self) -> Self { self.context = Some(SamplingContextCapability {}); @@ -437,7 +437,7 @@ impl SamplingCapability { } /// Specifies whether this client supports the tool use feature. - /// + /// /// Default: `None` pub fn with_tools(mut self) -> Self { self.tools = Some(SamplingToolsCapability {}); @@ -448,7 +448,7 @@ impl SamplingCapability { #[cfg(feature = "client")] impl ElicitationCapability { /// Specifies whether this client supports `form` elicitation mode. - /// + /// /// Default: `None` pub fn with_form(mut self) -> Self { self.form = Some(ElicitationFormCapability {}); @@ -456,7 +456,7 @@ impl ElicitationCapability { } /// Specifies whether this client supports `url` elicitation mode. - /// + /// /// Default: `None` pub fn with_url(mut self) -> Self { self.url = Some(ElicitationUrlCapability {}); @@ -481,7 +481,7 @@ impl ServerTasksCapability { /// Specifies whether this server supports task-augmented requests pub fn with_requests(mut self, config: F) -> Self where - F: FnOnce(ServerTaskRequestsCapability) -> ServerTaskRequestsCapability + F: FnOnce(ServerTaskRequestsCapability) -> ServerTaskRequestsCapability, { self.requests = Some(config(Default::default())); self @@ -491,12 +491,10 @@ impl ServerTasksCapability { pub fn with_tools(self) -> Self { self.with_requests(|req| req.with_tools()) } - + /// Specifies whether this server supports all task-augmented capabilities pub fn with_all(self) -> Self { - self.with_cancel() - .with_list() - .with_tools() + self.with_cancel().with_list().with_tools() } } @@ -517,7 +515,7 @@ impl ClientTasksCapability { /// Specifies whether this client supports task-augmented requests pub fn with_requests(mut self, config: F) -> Self where - F: FnOnce(ClientTaskRequestsCapability) -> ClientTaskRequestsCapability + F: FnOnce(ClientTaskRequestsCapability) -> ClientTaskRequestsCapability, { self.requests = Some(config(Default::default())); self @@ -547,7 +545,7 @@ impl ServerTaskRequestsCapability { /// Specifies task support for tool-related requests. pub fn with_tools(mut self) -> Self { self.tools = Some(ToolsTaskCapability { - call: Some(ToolsCallTaskCapability {}) + call: Some(ToolsCallTaskCapability {}), }); self } @@ -558,16 +556,16 @@ impl ClientTaskRequestsCapability { /// Specifies task support for elicitation-related requests. pub fn with_elicitation(mut self) -> Self { self.elicitation = Some(ElicitationTaskCapability { - create: Some(ElicitationCreateTaskCapability {}) + create: Some(ElicitationCreateTaskCapability {}), }); self - } - + } + /// Specifies task support for sampling-related requests. pub fn with_sampling(mut self) -> Self { self.sampling = Some(SamplingTaskCapability { - create: Some(SamplingCreateMessageTaskCapability {}) + create: Some(SamplingCreateMessageTaskCapability {}), }); self } -} \ No newline at end of file +} diff --git a/neva/src/types/completion.rs b/neva/src/types/completion.rs index 550641a..50deae2 100644 --- a/neva/src/types/completion.rs +++ b/neva/src/types/completion.rs @@ -1,16 +1,16 @@ -//! Completion request types +//! Completion request types -use serde::{Deserialize, Serialize}; use super::Reference; #[cfg(feature = "server")] use crate::error::Error; +use serde::{Deserialize, Serialize}; #[cfg(feature = "server")] -use crate::types::request::FromRequest; +use super::{IntoResponse, Request, RequestId, Response}; #[cfg(feature = "server")] use crate::app::handler::{FromHandlerParams, HandlerParams}; #[cfg(feature = "server")] -use super::{IntoResponse, RequestId, Response, Request}; +use crate::types::request::FromRequest; /// List of commands for Completion pub mod commands { @@ -19,18 +19,18 @@ pub mod commands { } /// Represents a completion object in the server's response -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details #[derive(Debug, Serialize, Deserialize)] pub struct Completion { /// An array of completion values. Must not exceed 100 items. pub values: Vec, - - /// The total number of completion options available. + + /// The total number of completion options available. /// This can exceed the number of values actually sent in the response. #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, - + /// Indicates whether there are additional completion options beyond those provided /// in the current response, even if the exact total is unknown. #[serde(skip_serializing_if = "Option::is_none")] @@ -38,33 +38,33 @@ pub struct Completion { } /// A request from the client to the server to ask for completion options. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct CompleteRequestParams { /// The reference's information #[serde(rename = "ref")] pub r#ref: Reference, - + /// The argument's information #[serde(rename = "argument")] pub arg: Argument, } /// Used for completion requests to provide additional context for the completion options. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct Argument { /// The name of the argument. pub name: String, - + /// The value of the argument to use for completion matching. pub value: String, } /// The server's response to a completion/complete request -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct CompleteResult { @@ -97,14 +97,11 @@ impl Completion { /// Creates a new empty [`Completion`] object #[inline] pub fn new(values: T, total: usize) -> Self - where + where T: IntoIterator, V: Into, { - let values: Vec = values - .into_iter() - .map(Into::into) - .collect(); + let values: Vec = values.into_iter().map(Into::into).collect(); Self { total: Some(total), has_more: Some(total > values.len()), @@ -128,7 +125,7 @@ impl IntoResponse for CompleteResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -137,10 +134,10 @@ impl IntoResponse for CompleteResult { impl From for Completion { #[inline] fn from(val: String) -> Self { - Self { - values: vec![val], + Self { + values: vec![val], total: None, - has_more: None + has_more: None, } } } @@ -152,7 +149,7 @@ impl From<&str> for Completion { Self { values: vec![val.into()], total: None, - has_more: None + has_more: None, } } } @@ -161,7 +158,7 @@ impl From<&str> for Completion { impl TryFrom> for CompleteResult where T: Into, - E: Into + E: Into, { type Error = E; @@ -169,7 +166,7 @@ where fn try_from(value: Result) -> Result { match value { Ok(ok) => Ok(ok.into()), - Err(err) => Err(err) + Err(err) => Err(err), } } } @@ -177,7 +174,7 @@ where #[cfg(feature = "server")] impl From for CompleteResult where - T: Into + T: Into, { #[inline] fn from(val: T) -> Self { @@ -186,15 +183,15 @@ where } #[cfg(feature = "server")] -impl From> for CompleteResult +impl From> for CompleteResult where - T: Into + T: Into, { #[inline] fn from(value: Option) -> Self { - match value { + match value { Some(val) => CompleteResult::new(val.into()), - None => CompleteResult::default() + None => CompleteResult::default(), } } } @@ -220,10 +217,7 @@ impl From> for Completion { Self { total: Some(len), has_more: Some(false), - values: vec - .into_iter() - .map(String::from) - .collect(), + values: vec.into_iter().map(String::from).collect(), } } } @@ -249,10 +243,7 @@ impl From<[&str; N]> for Completion { Self { total: Some(len), has_more: Some(false), - values: arr - .into_iter() - .map(String::from) - .collect(), + values: arr.into_iter().map(String::from).collect(), } } } @@ -261,11 +252,11 @@ impl From<[&str; N]> for Completion { #[cfg(feature = "server")] mod tests { use super::*; - + #[test] fn it_creates_default_completion() { let completion = Completion::default(); - + assert_eq!(completion.values.len(), 0); assert_eq!(completion.total, Some(0)); assert_eq!(completion.has_more, Some(false)); @@ -279,22 +270,25 @@ mod tests { assert_eq!(completion.total, Some(5)); assert_eq!(completion.has_more, Some(true)); } - + #[test] fn it_converts_complete_result_into_response() { let result = CompleteResult::default(); - + let resp = result.into_response(RequestId::default()); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"completion":{"has_more":false,"total":0,"values":[]}}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"completion":{"has_more":false,"total":0,"values":[]}}}"# + ); } - + #[test] fn it_converts_vec_into_completion() { let vec = vec!["1", "2", "3"]; let completion: Completion = vec.into(); - + assert_eq!(completion.values.len(), 3); assert_eq!(completion.total, Some(3)); assert_eq!(completion.has_more, Some(false)); @@ -329,4 +323,4 @@ mod tests { assert_eq!(completion.completion.total, Some(3)); assert_eq!(completion.completion.has_more, Some(false)); } -} \ No newline at end of file +} diff --git a/neva/src/types/content.rs b/neva/src/types/content.rs index 1e1f7f7..7c71047 100644 --- a/neva/src/types/content.rs +++ b/neva/src/types/content.rs @@ -1,28 +1,22 @@ -//! Any Text, Image, Audio, Video content utilities +//! Any Text, Image, Audio, Video content utilities -use std::collections::HashMap; -use serde::{Deserialize, Serialize}; -use serde::de::DeserializeOwned; +use crate::error::{Error, ErrorCode}; +use crate::shared; use bytes::Bytes; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::shared; -use crate::error::{Error, ErrorCode}; +use std::collections::HashMap; use crate::types::helpers::{deserialize_base64_as_bytes, serialize_bytes_as_base64}; use crate::types::{ - CallToolResponse, - CallToolRequestParams, - Annotations, - Resource, - ResourceContents, - Icon, - Uri + Annotations, CallToolRequestParams, CallToolResponse, Icon, Resource, ResourceContents, Uri, }; const CHUNK_SIZE: usize = 8192; /// Represents the content of the response. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] @@ -30,19 +24,19 @@ pub enum Content { /// Audio content #[serde(rename = "audio")] Audio(AudioContent), - + /// Image content #[serde(rename = "image")] Image(ImageContent), - + /// Text content #[serde(rename = "text")] Text(TextContent), - + /// Resource link #[serde(rename = "resource_link")] ResourceLink(ResourceLink), - + /// Embedded resource #[serde(rename = "resource")] Resource(EmbeddedResource), @@ -54,7 +48,7 @@ pub enum Content { /// Tool result content #[serde(rename = "tool_result")] ToolResult(ToolResult), - + /// Empty content #[serde(rename = "empty")] Empty(EmptyContent), @@ -65,13 +59,13 @@ pub enum Content { pub struct EmptyContent; /// Text provided to or from an LLM. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TextContent { /// The text content of the message. pub text: String, - + /// Optional annotations for the client. #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, @@ -82,7 +76,7 @@ pub struct TextContent { } /// Audio provided to or from an LLM. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AudioContent { @@ -91,13 +85,14 @@ pub struct AudioContent { /// **Note:** will be serialized as a base64-encoded string #[serde( serialize_with = "serialize_bytes_as_base64", - deserialize_with = "deserialize_base64_as_bytes")] + deserialize_with = "deserialize_base64_as_bytes" + )] pub data: Bytes, /// The MIME type of the audio content, e.g. "audio/mpeg" or "audio/wav". #[serde(rename = "mimeType")] pub mime: String, - + /// Optional annotations for the client. #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, @@ -108,16 +103,17 @@ pub struct AudioContent { } /// An image provided to or from an LLM. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ImageContent { /// Raw image data. - /// + /// /// **Note:** will be serialized as a base64-encoded string #[serde( - serialize_with = "serialize_bytes_as_base64", - deserialize_with = "deserialize_base64_as_bytes")] + serialize_with = "serialize_bytes_as_base64", + deserialize_with = "deserialize_base64_as_bytes" + )] pub data: Bytes, /// The MIME type of the audio content, e.g. "image/jpg" or "image/png". @@ -134,15 +130,15 @@ pub struct ImageContent { } /// A resource that the server is capable of reading, included in a prompt or tool call result. -/// +/// /// **Note:** resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResourceLink { /// The URI of this resource. pub uri: Uri, - + /// Intended for programmatic or logical use /// but used as a display name in past specs or fallback (if a title isn't present). pub name: String, @@ -154,10 +150,10 @@ pub struct ResourceLink { /// The MIME type of the resource. If known. #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")] pub mime: Option, - + /// Intended for UI and end-user contexts - optimized to be human-readable and easily understood, /// even by those unfamiliar with domain-specific terminology. - /// + /// /// If not provided, the name should be used for display (except for Tool, /// where `annotations.title` should be given precedence over using `name`, if present). #[serde(skip_serializing_if = "Option::is_none")] @@ -166,7 +162,7 @@ pub struct ResourceLink { /// A description of what this resource represents. #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// Optional annotations for the client. #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, @@ -189,16 +185,16 @@ pub struct ResourceLink { } /// The contents of a resource, embedded into a prompt or tool call result. -/// +/// /// It is up to the client how best to render embedded resources for the benefit /// of the LLM and/or the user. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbeddedResource { /// The resource content of the message. pub resource: ResourceContents, - + /// Optional annotations for the client. #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, @@ -209,12 +205,12 @@ pub struct EmbeddedResource { } /// Represents a request from the assistant to call a tool. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolUse { /// A unique identifier for this tool use. - /// + /// /// This ID is used to match tool results to their corresponding tool uses. pub id: String, @@ -230,31 +226,31 @@ pub struct ToolUse { } /// Represents the result of a tool use, provided by the user back to the assistant. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResult { /// The ID of the tool use this result corresponds to. - /// + /// /// This **MUST** match the ID from a previous [`ToolUse`]. #[serde(rename = "toolUseId")] pub tool_use_id: String, /// The unstructured result content of the tool use. - /// + /// /// This has the same format as [`CallToolResponse::content`] and can include text, images, audio, resource links, and embedded resources. pub content: Vec, - + /// An optional JSON object that represents the structured result of the tool call. - /// + /// /// If the tool defined an `outputSchema`, this **SHOULD** conform to that schema. #[serde(rename = "structuredContent", skip_serializing_if = "Option::is_none")] pub struct_content: Option, /// Whether the tool call was unsuccessful. - /// + /// /// If true, the content typically describes the error that occurred. - /// + /// /// Default: `false` #[serde(default, rename = "isError")] pub is_error: bool, @@ -409,7 +405,7 @@ impl TryFrom for EmbeddedResource { #[inline] fn try_from(value: Content) -> Result { - match value { + match value { Content::Resource(res) => Ok(res), _ => Err(Error::new(ErrorCode::InternalError, "Invalid content type")), } @@ -433,9 +429,9 @@ impl From for Content { impl From for CallToolRequestParams { #[inline] fn from(value: ToolUse) -> Self { - Self { - name: value.name, - args: value.input, + Self { + name: value.name, + args: value.input, meta: None, #[cfg(feature = "tasks")] task: None, @@ -448,7 +444,7 @@ impl TryFrom for ToolUse { #[inline] fn try_from(value: Content) -> Result { - match value { + match value { Content::ToolUse(tool_use) => Ok(tool_use), _ => Err(Error::new(ErrorCode::InternalError, "Invalid content type")), } @@ -460,7 +456,7 @@ impl TryFrom for ToolResult { #[inline] fn try_from(value: Content) -> Result { - match value { + match value { Content::ToolResult(tool_result) => Ok(tool_result), _ => Err(Error::new(ErrorCode::InternalError, "Invalid content type")), } @@ -480,7 +476,7 @@ impl Content { let json = serde_json::to_value(json).unwrap(); Self::from(json) } - + /// Creates an image [`Content`] #[inline] pub fn image(data: impl Into) -> Self { @@ -492,13 +488,13 @@ impl Content { pub fn audio(data: impl Into) -> Self { Self::Audio(AudioContent::new(data)) } - + /// Creates an embedded resource [`Content`] #[inline] pub fn resource(resource: impl Into) -> Self { Self::Resource(EmbeddedResource::new(resource)) } - + /// Creates a resource link [`Content`] #[inline] pub fn link(resource: impl Into) -> Self { @@ -516,7 +512,7 @@ impl Content { pub fn tool_use(name: N, args: Args) -> Self where N: Into, - Args: shared::IntoArgs + Args: shared::IntoArgs, { Self::ToolUse(ToolUse::new(name, args)) } @@ -526,11 +522,11 @@ impl Content { pub fn empty() -> Self { Self::Empty(EmptyContent) } - + /// Returns the type of the content. #[inline] pub fn get_type(&self) -> &str { - match self { + match self { Self::Empty(_) => "empty", Self::Audio(_) => "audio", Self::Image(_) => "image", @@ -538,25 +534,25 @@ impl Content { Self::ResourceLink(_) => "resource_link", Self::Resource(_) => "resource", Self::ToolUse(_) => "tool_use", - Self::ToolResult(_) => "tool_result" + Self::ToolResult(_) => "tool_result", } } - + /// Returns the content as a text content. #[inline] pub fn as_text(&self) -> Option<&TextContent> { match self { Self::Text(c) => Some(c), - _ => None + _ => None, } } - + /// Returns the content as a deserialized struct #[inline] pub fn as_json(&self) -> Option { - match self { + match self { Self::Text(c) => serde_json::from_str(&c.text).ok(), - _ => None + _ => None, } } @@ -565,7 +561,7 @@ impl Content { pub fn as_audio(&self) -> Option<&AudioContent> { match self { Self::Audio(c) => Some(c), - _ => None + _ => None, } } @@ -574,7 +570,7 @@ impl Content { pub fn as_image(&self) -> Option<&ImageContent> { match self { Self::Image(c) => Some(c), - _ => None + _ => None, } } @@ -583,7 +579,7 @@ impl Content { pub fn as_link(&self) -> Option<&ResourceLink> { match self { Self::ResourceLink(c) => Some(c), - _ => None + _ => None, } } @@ -592,7 +588,7 @@ impl Content { pub fn as_resource(&self) -> Option<&EmbeddedResource> { match self { Self::Resource(c) => Some(c), - _ => None + _ => None, } } @@ -601,7 +597,7 @@ impl Content { pub fn as_tool(&self) -> Option<&ToolUse> { match self { Self::ToolUse(c) => Some(c), - _ => None + _ => None, } } @@ -610,7 +606,7 @@ impl Content { pub fn as_result(&self) -> Option<&ToolResult> { match self { Self::ToolResult(c) => Some(c), - _ => None + _ => None, } } } @@ -622,14 +618,14 @@ impl TextContent { Self { text: text.into(), annotations: None, - meta: None + meta: None, } } /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -644,7 +640,7 @@ impl AudioContent { data: data.into(), mime: "audio/wav".into(), annotations: None, - meta: None + meta: None, } } @@ -657,12 +653,12 @@ impl AudioContent { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self } - + /// Returns audio data as a slice of bytes pub fn as_slice(&self) -> &[u8] { &self.data @@ -689,7 +685,7 @@ impl ImageContent { data: data.into(), mime: "image/jpg".into(), annotations: None, - meta: None + meta: None, } } @@ -702,7 +698,7 @@ impl ImageContent { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -746,7 +742,7 @@ impl EmbeddedResource { Self { resource: resource.into(), annotations: None, - meta: None + meta: None, } } } @@ -757,13 +753,13 @@ impl ToolUse { pub fn new(name: N, args: Args) -> Self where N: Into, - Args: shared::IntoArgs + Args: shared::IntoArgs, { Self { id: uuid::Uuid::new_v4().to_string(), name: name.into(), input: args.into_args(), - meta: None + meta: None, } } } @@ -777,7 +773,7 @@ impl ToolResult { content: resp.content, struct_content: resp.struct_content, is_error: resp.is_error, - meta: None + meta: None, } } @@ -789,7 +785,7 @@ impl ToolResult { content: vec![Content::text(error.to_string())], struct_content: None, is_error: true, - meta: None + meta: None, } } } @@ -798,26 +794,25 @@ impl ToolResult { mod test { use super::*; use futures_util::StreamExt; - + #[derive(Deserialize)] struct Test { name: String, - age: u32 + age: u32, } - + #[test] fn it_serializes_text_content_to_json() { let content = Content::text("hello world"); let json = serde_json::to_string(&content).unwrap(); - + assert_eq!(json, r#"{"type":"text","text":"hello world"}"#); } #[test] fn it_deserializes_text_content_to_json() { let json = r#"{"type":"text","text":"hello world"}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); assert_eq!(content.as_text().unwrap().text, "hello world"); } @@ -825,11 +820,10 @@ mod test { #[test] fn it_deserializes_structures_text_content_to_json() { let json = r#"{"type":"text","text":"{\"name\":\"John\",\"age\":30}"}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); let user: Test = content.as_json().unwrap(); - + assert_eq!(user.name, "John"); assert_eq!(user.age, 30); } @@ -839,16 +833,21 @@ mod test { let content = Content::audio("hello world"); let json = serde_json::to_string(&content).unwrap(); - assert_eq!(json, r#"{"type":"audio","data":"aGVsbG8gd29ybGQ=","mimeType":"audio/wav"}"#); + assert_eq!( + json, + r#"{"type":"audio","data":"aGVsbG8gd29ybGQ=","mimeType":"audio/wav"}"# + ); } #[test] fn it_deserializes_audio_content_to_json() { let json = r#"{"type":"audio","data":"aGVsbG8gd29ybGQ=","mimeType":"audio/wav"}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); - assert_eq!(String::from_utf8_lossy(content.as_audio().unwrap().as_slice()), "hello world"); + assert_eq!( + String::from_utf8_lossy(content.as_audio().unwrap().as_slice()), + "hello world" + ); assert_eq!(content.as_audio().unwrap().mime, "audio/wav"); } @@ -857,45 +856,52 @@ mod test { let content = Content::image("hello world"); let json = serde_json::to_string(&content).unwrap(); - assert_eq!(json, r#"{"type":"image","data":"aGVsbG8gd29ybGQ=","mimeType":"image/jpg"}"#); + assert_eq!( + json, + r#"{"type":"image","data":"aGVsbG8gd29ybGQ=","mimeType":"image/jpg"}"# + ); } #[test] fn it_deserializes_image_content_to_json() { let json = r#"{"type":"image","data":"aGVsbG8gd29ybGQ=","mimeType":"image/jpg"}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); - assert_eq!(String::from_utf8_lossy(content.as_image().unwrap().as_slice()), "hello world"); + assert_eq!( + String::from_utf8_lossy(content.as_image().unwrap().as_slice()), + "hello world" + ); assert_eq!(content.as_image().unwrap().mime, "image/jpg"); } #[test] #[cfg(feature = "server")] fn it_serializes_resource_content_to_json() { - let content = Content::resource(ResourceContents::new("res://resource") - .with_text("hello world") - .with_title("some resource") - .with_annotations(|a| a - .with_audience("user") - .with_priority(1.0))); - + let content = Content::resource( + ResourceContents::new("res://resource") + .with_text("hello world") + .with_title("some resource") + .with_annotations(|a| a.with_audience("user").with_priority(1.0)), + ); + let json = serde_json::to_string(&content).unwrap(); - assert_eq!(json, r#"{"type":"resource","resource":{"uri":"res://resource","text":"hello world","title":"some resource","mimeType":"text/plain","annotations":{"audience":["user"],"priority":1.0}}}"#); + assert_eq!( + json, + r#"{"type":"resource","resource":{"uri":"res://resource","text":"hello world","title":"some resource","mimeType":"text/plain","annotations":{"audience":["user"],"priority":1.0}}}"# + ); } #[test] #[cfg(feature = "server")] fn it_deserializes_resource_content_to_json() { use crate::types::Role; - + let json = r#"{"type":"resource","resource":{"uri":"res://resource","text":"hello world","title":"some resource","mimeType":"text/plain","annotations":{"audience":["user"],"priority":1.0}}}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); let res = &content.as_resource().unwrap().resource; - + assert_eq!(res.uri().to_string(), "res://resource"); assert_eq!(res.mime().unwrap(), "text/plain"); assert_eq!(res.text().unwrap(), "hello world"); @@ -907,27 +913,29 @@ mod test { #[test] #[cfg(feature = "server")] fn it_serializes_resource_link_content_to_json() { - let content = Content::link(Resource::new("res://resource", "some resource") - .with_title("some resource") - .with_descr("some resource") - .with_size(2) - .with_annotations(|a| a - .with_audience("user") - .with_priority(1.0))); + let content = Content::link( + Resource::new("res://resource", "some resource") + .with_title("some resource") + .with_descr("some resource") + .with_size(2) + .with_annotations(|a| a.with_audience("user").with_priority(1.0)), + ); let json = serde_json::to_string(&content).unwrap(); - assert_eq!(json, r#"{"type":"resource_link","uri":"res://resource","name":"some resource","size":2,"title":"some resource","description":"some resource","annotations":{"audience":["user"],"priority":1.0}}"#); + assert_eq!( + json, + r#"{"type":"resource_link","uri":"res://resource","name":"some resource","size":2,"title":"some resource","description":"some resource","annotations":{"audience":["user"],"priority":1.0}}"# + ); } #[test] #[cfg(feature = "server")] fn it_deserializes_resource_link_content_to_json() { use crate::types::Role; - + let json = r#"{"type":"resource_link","uri":"res://resource","name":"some resource","size":2,"title":"some resource","description":"some resource","annotations":{"audience":["user"],"priority":1.0}}"#; - let content = serde_json::from_str::(json) - .unwrap(); + let content = serde_json::from_str::(json).unwrap(); let res = content.as_link().unwrap(); @@ -975,7 +983,10 @@ mod test { let result_string = String::from_utf8(collected_data).expect("Should be valid UTF-8"); assert_eq!(result_string, test_data); - assert!(chunk_count > 1, "Should have multiple chunks for large data"); + assert!( + chunk_count > 1, + "Should have multiple chunks for large data" + ); } #[tokio::test] @@ -1007,7 +1018,10 @@ mod test { let result_string = String::from_utf8(collected_data).expect("Should be valid UTF-8"); assert_eq!(result_string, test_data); - assert_eq!(chunk_count, 1, "Exactly CHUNK_SIZE data should produce one chunk"); + assert_eq!( + chunk_count, 1, + "Exactly CHUNK_SIZE data should produce one chunk" + ); } #[tokio::test] @@ -1044,7 +1058,10 @@ mod test { let result_string = String::from_utf8(collected_data).expect("Should be valid UTF-8"); assert_eq!(result_string, test_data); - assert!(chunk_count > 1, "Should have multiple chunks for large data"); + assert!( + chunk_count > 1, + "Should have multiple chunks for large data" + ); } #[tokio::test] @@ -1063,7 +1080,10 @@ mod test { max_chunk_size = max_chunk_size.max(chunk_size); // Each chunk should not exceed CHUNK_SIZE - assert!(chunk_size <= CHUNK_SIZE, "Chunk size should not exceed CHUNK_SIZE"); + assert!( + chunk_size <= CHUNK_SIZE, + "Chunk size should not exceed CHUNK_SIZE" + ); } assert_eq!(total_size, test_data.len()); @@ -1121,4 +1141,4 @@ mod test { assert_eq!(audio.data, deserialized.data); assert_eq!(audio.mime, deserialized.mime); } -} \ No newline at end of file +} diff --git a/neva/src/types/cursor.rs b/neva/src/types/cursor.rs index 0a0395e..6370b60 100644 --- a/neva/src/types/cursor.rs +++ b/neva/src/types/cursor.rs @@ -1,8 +1,8 @@ //! Cursor-based pagination utilities +use base64::{Engine as _, engine::general_purpose}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::ops::{Deref, DerefMut}; -use serde::{Serialize, Serializer, Deserialize, Deserializer}; -use base64::{engine::general_purpose, Engine as _}; /// An opaque token representing the pagination position after the last returned result. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] @@ -31,17 +31,16 @@ impl<'de> Deserialize<'de> for Cursor { let decoded = general_purpose::STANDARD .decode(&encoded) .map_err(serde::de::Error::custom)?; - - let index: usize = - serde_json::from_slice(&decoded).map_err(serde::de::Error::custom)?; - + + let index: usize = serde_json::from_slice(&decoded).map_err(serde::de::Error::custom)?; + Ok(Cursor(index)) } } impl Deref for Cursor { type Target = usize; - + #[inline] fn deref(&self) -> &Self::Target { &self.0 @@ -158,9 +157,11 @@ mod tests { let page = data.paginate(cursor, 2); collected.extend_from_slice(page.items); cursor = page.next_cursor; - if cursor.is_none() { break; } + if cursor.is_none() { + break; + } } assert_eq!(collected, data); } -} \ No newline at end of file +} diff --git a/neva/src/types/elicitation.rs b/neva/src/types/elicitation.rs index d9de223..1de4787 100644 --- a/neva/src/types/elicitation.rs +++ b/neva/src/types/elicitation.rs @@ -1,44 +1,44 @@ //! Utilities for Elicitation -#[cfg(feature = "client")] -use std::{future::Future, pin::Pin, sync::Arc}; -use std::collections::HashMap; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Value; -use schemars::JsonSchema; use crate::{ - types::{IntoResponse, PropertyType, RequestId, Response, Schema, ErrorDetails}, - types::notification::Notification, error::{Error, ErrorCode}, + types::notification::Notification, + types::{ErrorDetails, IntoResponse, PropertyType, RequestId, Response, Schema}, }; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::Value; +use std::collections::HashMap; +#[cfg(feature = "client")] +use std::{future::Future, pin::Pin, sync::Arc}; use crate::types::Uri; #[cfg(feature = "tasks")] -use crate::types::{TaskMetadata, RelatedTaskMetadata}; +use crate::types::{RelatedTaskMetadata, TaskMetadata}; /// List of commands for Elicitation pub mod commands { /// Command name for creating a new elicitation request pub const CREATE: &str = "elicitation/create"; - + /// Notification name for indicates the completion of elicitation pub const COMPLETE: &str = "notifications/elicitation/complete"; } /// Represents a message issued from the server to elicit additional information from the user via the client. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ElicitRequestParams { /// Elicitation request parameters for a form Form(ElicitRequestFormParams), - + /// Elicitation request parameters for a URL - Url(ElicitRequestUrlParams) + Url(ElicitRequestUrlParams), } -/// Represents the parameters for a request to elicit non-sensitive information from the user +/// Represents the parameters for a request to elicit non-sensitive information from the user /// via a form in the client. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details @@ -69,27 +69,27 @@ pub struct ElicitRequestFormParams { /// Additional metadata to attach to the request. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option + pub meta: Option, } -/// Represents the parameters for a request to elicit information from the user +/// Represents the parameters for a request to elicit information from the user /// via a URL in the client. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ElicitRequestUrlParams { /// The ID of the elicitation, which must be unique within the context of the server. - /// + /// /// The client **MUST** treat this ID as an opaque value. #[serde(rename = "elicitationId")] pub id: String, - + /// The message to present to the user. pub message: String, /// The elicitation mode pub mode: ElicitationMode, - + /// The URL that the user should navigate to. pub url: Uri, @@ -102,14 +102,14 @@ pub struct ElicitRequestUrlParams { #[cfg(feature = "tasks")] #[serde(skip_serializing_if = "Option::is_none")] pub task: Option, - + /// Additional metadata to attach to the request. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option + pub meta: Option, } /// Represents elicitation mode. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -118,18 +118,18 @@ pub enum ElicitationMode { Form, /// `url` elicitation mode - Url + Url, } /// Represents a JSON Schema that can be used to validate the content of an elicitation request. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RequestSchema { /// The type of the schema. - /// + /// /// > **Note:** always "object". #[serde(rename = "type", default)] pub r#type: PropertyType, - + /// The properties of the schema. pub properties: HashMap, @@ -142,20 +142,20 @@ pub struct RequestSchema { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ElicitResult { /// The user action in response to the elicitation. - /// + /// /// * "accept" - User submitted the form/confirmed the action. /// * "cancel" - User dismissed without making an explicit choice. /// * "decline" - User explicitly declined the action. pub action: ElicitationAction, - + /// The submitted form data. - /// + /// /// > **Note:** This is typically omitted if the action is "cancel" or "decline". pub content: Option, /// Additional metadata to attach to the result. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option + pub meta: Option, } /// Represents the user's action in response to an elicitation request. @@ -164,27 +164,27 @@ pub struct ElicitResult { pub enum ElicitationAction { /// User submitted the form/confirmed the action Accept, - + /// User dismissed without making an explicit choice Cancel, - + /// User explicitly declined the action - Decline + Decline, } /// Represents an error response that indicates that the server requires the client /// to provide additional information via an elicitation request. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UrlElicitationRequiredError { /// A list of required elicitations - pub elicitations: Vec + pub elicitations: Vec, } -/// Represents an optional notification from the server to the client, informing it of a completion +/// Represents an optional notification from the server to the client, informing it of a completion /// of an out-of-band elicitation request. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ElicitationCompleteParams { @@ -226,13 +226,13 @@ impl Default for RequestSchema { impl Validator { /// Creates a new [`Validator`] - #[inline] + #[inline] pub fn new(params: ElicitRequestFormParams) -> Self { Self { schema: params.schema, } } - + /// Validates the elicitation content against the schema #[inline] pub fn validate(&self, content: T) -> Result { @@ -247,12 +247,15 @@ impl Validator { fn validate_schema_compatibility(&self, source: &schemars::Schema) -> Result<(), Error> { const PROP: &str = "properties"; const REQ: &str = "required"; - + let target = &self.schema; let source_props = source .get(PROP) .and_then(|v| v.as_object()) - .ok_or(Error::new(ErrorCode::InvalidParams, "Source schema missing properties"))?; + .ok_or(Error::new( + ErrorCode::InvalidParams, + "Source schema missing properties", + ))?; let source_required = source .get(REQ) @@ -264,8 +267,9 @@ impl Validator { for prop_name in target.properties.keys() { if !source_props.contains_key(prop_name) { return Err(Error::new( - ErrorCode::InvalidParams, - format!("Missing property: {prop_name}"))); + ErrorCode::InvalidParams, + format!("Missing property: {prop_name}"), + )); } } @@ -274,8 +278,9 @@ impl Validator { for required_prop in target_required { if !source_required.contains(&required_prop.as_str()) { return Err(Error::new( - ErrorCode::InvalidParams, - format!("Required property not marked as required: {required_prop}"))); + ErrorCode::InvalidParams, + format!("Required property not marked as required: {required_prop}"), + )); } } } @@ -286,17 +291,19 @@ impl Validator { /// Validates content against schema constraints fn validate_content_constraints(&self, content: &Value) -> Result<(), Error> { let schema = &self.schema; - let content_obj = content - .as_object() - .ok_or(Error::new(ErrorCode::InvalidParams, "Content is not an object"))?; + let content_obj = content.as_object().ok_or(Error::new( + ErrorCode::InvalidParams, + "Content is not an object", + ))?; // Check required properties if let Some(required) = &schema.required { for required_prop in required { if !content_obj.contains_key(required_prop) { return Err(Error::new( - ErrorCode::InvalidParams, - format!("Missing required property: {required_prop}"))); + ErrorCode::InvalidParams, + format!("Missing required property: {required_prop}"), + )); } } } @@ -337,7 +344,7 @@ impl ElicitRequestParams { mode: None, meta: None, #[cfg(feature = "tasks")] - task: None + task: None, } } @@ -351,47 +358,53 @@ impl ElicitRequestParams { mode: ElicitationMode::Url, meta: None, #[cfg(feature = "tasks")] - task: None + task: None, } } - /// Returns a reference to the underlying [`ElicitRequestFormParams`] if the request is a form, + /// Returns a reference to the underlying [`ElicitRequestFormParams`] if the request is a form, /// otherwise returns `None` #[inline] pub fn as_form(&self) -> Option<&ElicitRequestFormParams> { - match self { + match self { Self::Form(params) => Some(params), - _ => None + _ => None, } } - - /// Returns a reference to the underlying [`ElicitRequestUrlParams`] if the request is a URL, + + /// Returns a reference to the underlying [`ElicitRequestUrlParams`] if the request is a URL, /// otherwise returns `None` #[inline] pub fn as_url(&self) -> Option<&ElicitRequestUrlParams> { - match self { + match self { Self::Url(params) => Some(params), - _ => None - } + _ => None, + } } - - /// Converts the request into a form request. + + /// Converts the request into a form request. /// Returns an error if the request is not a form request. #[inline] pub fn into_form(self) -> Result { match self { Self::Form(params) => Ok(params), - _ => Err(Error::new(ErrorCode::InvalidRequest, "Request is not a form request")) + _ => Err(Error::new( + ErrorCode::InvalidRequest, + "Request is not a form request", + )), } } - /// Converts the request into a URL request. + /// Converts the request into a URL request. /// Returns an error if the request is not a URL request. #[inline] pub fn into_url(self) -> Result { match self { Self::Url(params) => Ok(params), - _ => Err(Error::new(ErrorCode::InvalidRequest, "Request is not a URL request")) + _ => Err(Error::new( + ErrorCode::InvalidRequest, + "Request is not a URL request", + )), } } @@ -401,7 +414,7 @@ impl ElicitRequestParams { pub fn with_related_task(self, task_id: impl Into) -> Self { match self { Self::Form(form) => form.with_related_task(task_id).into(), - Self::Url(url) => url.with_related_task(task_id).into() + Self::Url(url) => url.with_related_task(task_id).into(), } } @@ -410,8 +423,7 @@ impl ElicitRequestParams { #[inline] #[cfg(feature = "tasks")] pub fn is_task_augmented(&self) -> bool { - self.as_url() - .is_some_and(|p| p.task.is_some()) + self.as_url().is_some_and(|p| p.task.is_some()) } /// Returns the [`RelatedTaskMetadata`] if it's specified @@ -420,7 +432,7 @@ impl ElicitRequestParams { pub fn related_task(&self) -> Option { match self { Self::Form(form) => form.related_task(), - Self::Url(url) => url.related_task() + Self::Url(url) => url.related_task(), } } } @@ -429,16 +441,14 @@ impl ElicitRequestFormParams { /// Adds a single optional property to the schema #[inline] pub fn with_prop(mut self, prop: &str, schema: impl Into) -> Self { - self.schema = self.schema - .with_prop(prop, schema); + self.schema = self.schema.with_prop(prop, schema); self } /// Adds a single required property to the schema #[inline] pub fn with_required(mut self, prop: &str, schema: impl Into) -> Self { - self.schema = self.schema - .with_required(prop, schema); + self.schema = self.schema.with_required(prop, schema); self } @@ -456,7 +466,8 @@ impl ElicitRequestFormParams { let meta: RelatedTaskMetadata = task.into(); let meta = serde_json::to_value(meta).unwrap(); - self.meta.get_or_insert_with(|| serde_json::json!({})) + self.meta + .get_or_insert_with(|| serde_json::json!({})) .as_object_mut() .unwrap() .insert(crate::types::task::RELATED_TASK_KEY.into(), meta); @@ -478,7 +489,7 @@ impl ElicitRequestFormParams { #[cfg(feature = "tasks")] impl ElicitRequestUrlParams { /// Makes the request task-augmented with TTL. - /// + /// /// Default: `None` pub fn with_ttl(mut self, ttl: Option) -> Self { self.task = Some(TaskMetadata { ttl }); @@ -492,7 +503,8 @@ impl ElicitRequestUrlParams { let meta: RelatedTaskMetadata = task.into(); let meta = serde_json::to_value(meta).unwrap(); - self.meta.get_or_insert_with(|| serde_json::json!({})) + self.meta + .get_or_insert_with(|| serde_json::json!({})) .as_object_mut() .unwrap() .insert(crate::types::task::RELATED_TASK_KEY.into(), meta); @@ -515,20 +527,16 @@ impl RequestSchema { /// Creates a new [`RequestSchema`] without properties #[inline] pub fn new() -> Self { - Self::default() + Self::default() } - + /// Creates a new [`RequestSchema`] from a type that implements [`Default`] and [`Serialize`] - #[inline] + #[inline] pub fn of() -> Self { let mut schema = Self::default(); let json_schema = schemars::schema_for!(T); - let required = json_schema - .get("required") - .and_then(|v| v.as_array()); - if let Some(props) = json_schema - .get("properties") - .and_then(|v| v.as_object()) { + let required = json_schema.get("required").and_then(|v| v.as_array()); + if let Some(props) = json_schema.get("properties").and_then(|v| v.as_object()) { for (field, def) in props { let req = required .map(|arr| !arr.iter().any(|v| v == field)) @@ -549,14 +557,12 @@ impl RequestSchema { self.properties.insert(prop.into(), schema.into()); self } - + /// Creates a new [`RequestSchema`] with a single required property #[inline] pub fn with_required(mut self, prop: &str, schema: impl Into) -> Self { self = self.with_prop(prop, schema); - self.required - .get_or_insert_with(Vec::new) - .push(prop.into()); + self.required.get_or_insert_with(Vec::new).push(prop.into()); self } } @@ -581,7 +587,7 @@ impl ElicitResult { meta: None, } } - + /// Creates a new canceled [`ElicitResult`] #[inline] pub fn cancel() -> Self { @@ -591,32 +597,32 @@ impl ElicitResult { meta: None, } } - + /// Sets the content of the [`ElicitResult`] - #[inline] + #[inline] pub fn with_content(mut self, content: T) -> Self { self.content = Some(serde_json::to_value(&content).unwrap()); self } - + /// Deserializes the content of the [`ElicitResult`] - #[inline] + #[inline] pub fn content(&self) -> Option { self.content .as_ref() .and_then(|content| serde_json::from_value(content.clone()).ok()) } - + /// Returns _true_ if the [`ElicitResult`] is accepted pub fn is_accepted(&self) -> bool { self.action == ElicitationAction::Accept } - + /// Returns _true_ if the [`ElicitResult`] is canceled pub fn is_canceled(&self) -> bool { self.action == ElicitationAction::Cancel } - + /// Returns _true_ if the [`ElicitResult`] is declined pub fn is_declined(&self) -> bool { self.action == ElicitationAction::Decline @@ -634,7 +640,10 @@ impl ElicitResult { .ok_or_else(|| Error::new(ErrorCode::ParseError, "Failed to parse content")) .map(f) } else { - Err(Error::new(ErrorCode::InvalidRequest, "User rejected the request")) + Err(Error::new( + ErrorCode::InvalidRequest, + "User rejected the request", + )) } } @@ -660,7 +669,8 @@ impl ElicitResult { let meta: RelatedTaskMetadata = task.into(); let meta = serde_json::to_value(meta).unwrap(); - self.meta.get_or_insert_with(|| serde_json::json!({})) + self.meta + .get_or_insert_with(|| serde_json::json!({})) .as_object_mut() .unwrap() .insert(crate::types::task::RELATED_TASK_KEY.into(), meta); @@ -683,11 +693,13 @@ impl UrlElicitationRequiredError { /// Creates a new [`UrlElicitationRequiredError`] #[inline] pub fn new(elicitations: impl IntoIterator) -> Self { - Self { elicitations: elicitations.into_iter().collect() } + Self { + elicitations: elicitations.into_iter().collect(), + } } - + /// Converts into JSONRPC error response - #[inline] + #[inline] pub fn to_error(self, message: impl Into) -> Error { let err = match serde_json::to_value(self) { Ok(data) => ErrorDetails { @@ -699,7 +711,7 @@ impl UrlElicitationRequiredError { code: ErrorCode::InternalError, message: err.to_string(), data: None, - } + }, }; err.into() } @@ -718,7 +730,8 @@ impl TryFrom for ElicitationCompleteParams { #[inline] fn try_from(value: Notification) -> Result { - let params = value.params + let params = value + .params .ok_or_else(|| Error::new(ErrorCode::InvalidParams, "Missing params"))?; serde_json::from_value(params).map_err(Error::from) } @@ -738,7 +751,7 @@ impl IntoResponse for ElicitResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -746,21 +759,16 @@ impl IntoResponse for ElicitResult { /// Represents a dynamic handler for handling sampling requests #[cfg(feature = "client")] pub(crate) type ElicitationHandler = Arc< - dyn Fn(ElicitRequestParams) -> Pin< - Box + Send + 'static> - > - + Send - + Sync + dyn Fn(ElicitRequestParams) -> Pin + Send + 'static>> + + Send + + Sync, >; #[cfg(test)] mod tests { use super::*; use crate::types::{ - StringSchema, StringFormat, - NumberSchema, - BooleanSchema, - UntitledSingleSelectEnumSchema + BooleanSchema, NumberSchema, StringFormat, StringSchema, UntitledSingleSelectEnumSchema, }; use schemars::JsonSchema; @@ -782,7 +790,7 @@ mod tests { min_length: Some(2), max_length: Some(50), format: None, - }) + }), ); schema.properties.insert( "age".to_string(), @@ -792,11 +800,11 @@ mod tests { descr: None, min: Some(0.0), max: Some(120.0), - }) + }), ); schema.properties.insert( "active".to_string(), - Schema::Boolean(BooleanSchema::default()) + Schema::Boolean(BooleanSchema::default()), ); schema.required = Some(vec!["name".to_string(), "age".to_string()]); schema @@ -849,7 +857,7 @@ mod tests { let mut schema = create_test_schema(); schema.properties.insert( "missing_prop".to_string(), - Schema::String(StringSchema::default()) + Schema::String(StringSchema::default()), ); let params = create_form_params_with_schema(schema); @@ -872,7 +880,11 @@ mod tests { #[test] fn it_validates_missing_required_property() { let mut schema = create_test_schema(); - schema.required = Some(vec!["name".to_string(), "age".to_string(), "missing_required".to_string()]); + schema.required = Some(vec![ + "name".to_string(), + "age".to_string(), + "missing_required".to_string(), + ]); let params = create_form_params_with_schema(schema); let validator = Validator::new(params); @@ -888,7 +900,11 @@ mod tests { let error = result.unwrap_err(); assert_eq!(error.code, ErrorCode::InvalidParams); - assert!(error.to_string().contains("Required property not marked as required")); + assert!( + error + .to_string() + .contains("Required property not marked as required") + ); } #[test] @@ -1118,9 +1134,13 @@ mod tests { r#type: PropertyType::String, title: None, descr: None, - r#enum: vec!["active".to_string(), "inactive".to_string(), "pending".to_string()], + r#enum: vec![ + "active".to_string(), + "inactive".to_string(), + "pending".to_string(), + ], default: None, - }) + }), ); schema.required = Some(vec!["status".to_string()]); @@ -1146,7 +1166,7 @@ mod tests { descr: None, r#enum: vec!["active".to_string(), "inactive".to_string()], default: None, - }) + }), ); schema.required = Some(vec!["status".to_string()]); @@ -1161,7 +1181,11 @@ mod tests { assert!(result.is_err()); let error = result.unwrap_err(); - assert!(error.to_string().contains("Invalid enum value: invalid_status")); + assert!( + error + .to_string() + .contains("Invalid enum value: invalid_status") + ); } #[test] @@ -1175,7 +1199,7 @@ mod tests { descr: None, r#enum: vec!["active".to_string(), "inactive".to_string()], default: None, - }) + }), ); schema.required = Some(vec!["status".to_string()]); @@ -1205,7 +1229,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Email), - }) + }), ); schema.required = Some(vec!["email".to_string()]); @@ -1232,7 +1256,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Email), - }) + }), ); schema.required = Some(vec!["email".to_string()]); @@ -1262,7 +1286,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Uri), - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1272,7 +1296,7 @@ mod tests { "http://example.com", "https://example.com", "file://path/to/file", - "res://resource_1" + "res://resource_1", ]; for uri in test_cases { @@ -1297,7 +1321,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Uri), - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1326,7 +1350,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Date), - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1352,17 +1376,17 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::Date), - }) + }), ); let params = create_form_params_with_schema(schema); let validator = Validator::new(params); let test_cases = vec![ - "1990/05/15", // Wrong separators - "90-05-15", // Wrong year format - "1990-5-15", // Missing zero padding - "not-a-date", // Invalid format + "1990/05/15", // Wrong separators + "90-05-15", // Wrong year format + "1990-5-15", // Missing zero padding + "not-a-date", // Invalid format ]; for invalid_date in test_cases { @@ -1371,7 +1395,11 @@ mod tests { }); let result = validator.validate_content_constraints(&content_json); - assert!(result.is_err(), "Should fail for invalid date: {}", invalid_date); + assert!( + result.is_err(), + "Should fail for invalid date: {}", + invalid_date + ); let error = result.unwrap_err(); assert!(error.to_string().contains("Invalid date format")); @@ -1390,7 +1418,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::DateTime), - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1416,7 +1444,7 @@ mod tests { min_length: None, max_length: None, format: Some(StringFormat::DateTime), - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1445,7 +1473,7 @@ mod tests { min_length: None, max_length: None, format: None, - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1465,11 +1493,11 @@ mod tests { let mut schema = RequestSchema::new(); schema.properties.insert( "required_field".to_string(), - Schema::String(StringSchema::default()) + Schema::String(StringSchema::default()), ); schema.properties.insert( "optional_field".to_string(), - Schema::String(StringSchema::default()) + Schema::String(StringSchema::default()), ); schema.required = Some(vec!["required_field".to_string()]); @@ -1499,7 +1527,7 @@ mod tests { let mut schema = RequestSchema::new(); schema.properties.insert( "optional_field".to_string(), - Schema::String(StringSchema::default()) + Schema::String(StringSchema::default()), ); // No required fields schema.required = None; @@ -1568,7 +1596,7 @@ mod tests { descr: None, r#enum: vec![], default: None, - }) + }), ); let params = create_form_params_with_schema(schema); @@ -1584,4 +1612,4 @@ mod tests { let error = result.unwrap_err(); assert!(error.to_string().contains("Invalid enum value")); } -} \ No newline at end of file +} diff --git a/neva/src/types/helpers.rs b/neva/src/types/helpers.rs index b76fd85..3a20958 100644 --- a/neva/src/types/helpers.rs +++ b/neva/src/types/helpers.rs @@ -1,7 +1,10 @@ -//! A set of helpers for types +//! A set of helpers for types -use crate::json::{JsonSchema, schemars::{schema_for, Schema}}; -use base64::{engine::general_purpose, Engine}; +use crate::json::{ + JsonSchema, + schemars::{Schema, schema_for}, +}; +use base64::{Engine, engine::general_purpose}; use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -10,16 +13,16 @@ use std::{ ops::{Deref, DerefMut}, }; -#[cfg(feature = "server")] -pub(crate) mod macros; #[cfg(feature = "server")] pub(crate) mod extract; +#[cfg(feature = "server")] +pub(crate) mod macros; -/// Serializes bytes as base64 string +/// Serializes bytes as base64 string #[inline] pub(crate) fn serialize_bytes_as_base64(bytes: &Bytes, serializer: S) -> Result where - S: serde::Serializer + S: serde::Serializer, { let encoded = general_purpose::STANDARD.encode(bytes); serializer.serialize_str(&encoded) @@ -29,10 +32,11 @@ where #[inline] pub(crate) fn deserialize_base64_as_bytes<'de, D>(deserializer: D) -> Result where - D: serde::Deserializer<'de> + D: serde::Deserializer<'de>, { let s = String::deserialize(deserializer)?; - let decoded = general_purpose::STANDARD.decode(&s) + let decoded = general_purpose::STANDARD + .decode(&s) .map_err(serde::de::Error::custom)?; Ok(Bytes::from(decoded)) } @@ -42,12 +46,10 @@ pub(crate) fn serialize_value_as_string(value: &Value, serializer: S) -> Resu where S: serde::Serializer, { - let json_str = serde_json::to_string(value) - .map_err(serde::ser::Error::custom)?; + let json_str = serde_json::to_string(value).map_err(serde::ser::Error::custom)?; serializer.serialize_str(&json_str) } - #[inline] pub(crate) fn deserialize_value_from_string<'de, D>(deserializer: D) -> Result where @@ -63,23 +65,23 @@ pub enum PropertyType { /// Unknown type. #[serde(rename = "none")] None, - + /// Array type #[serde(rename = "array")] Array, - + /// String type #[serde(rename = "string")] String, - + /// Number type #[serde(rename = "number", alias = "integer")] Number, - + /// Boolean type #[serde(rename = "boolean")] Bool, - + /// Object type. #[serde(rename = "object")] Object, @@ -95,7 +97,7 @@ impl Default for PropertyType { impl From<&str> for PropertyType { #[inline] fn from(s: &str) -> Self { - match s { + match s { "array" => PropertyType::Array, "string" => PropertyType::String, "number" | "integer" => PropertyType::Number, @@ -117,7 +119,7 @@ impl From for PropertyType { impl Display for PropertyType { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { + match self { PropertyType::Array => write!(f, "array"), PropertyType::String => write!(f, "string"), PropertyType::Number => write!(f, "number"), @@ -232,20 +234,22 @@ mod tests { #[test] fn it_serializes_serde_json_value_as_str() { - let v = Test2 { value: serde_json::json!({ "x": 5, "y": 10 }) }; + let v = Test2 { + value: serde_json::json!({ "x": 5, "y": 10 }), + }; let json = serde_json::to_string(&v).unwrap(); assert_eq!(json, r#"{"value":"{\"x\":5,\"y\":10}"}"#); } - + #[test] fn it_deserializes_serde_json_value_as_str() { let s = r#"{"value":"{\"x\":5,\"y\":10}"}"#; let v: Test2 = serde_json::from_str(s).unwrap(); - + assert_eq!(v.value, serde_json::json!({ "x": 5, "y": 10 })); } - + #[test] fn it_returns_category_for_string() { assert_eq!(String::category(), PropertyType::String); @@ -325,20 +329,20 @@ mod tests { fn it_returns_category_for_f64() { assert_eq!(f64::category(), PropertyType::Number); } - + #[test] fn it_returns_category_for_json() { assert_eq!(Json::::category(), PropertyType::Object); } - + struct Test; - + #[derive(Serialize, Deserialize)] struct Test2 { #[serde( serialize_with = "serialize_value_as_string", deserialize_with = "deserialize_value_from_string" )] - value: Value + value: Value, } } diff --git a/neva/src/types/helpers/extract.rs b/neva/src/types/helpers/extract.rs index e56e1e5..40fe0a5 100644 --- a/neva/src/types/helpers/extract.rs +++ b/neva/src/types/helpers/extract.rs @@ -1,11 +1,11 @@ -//! Traits and helpers for type extraction from request arguments +//! Traits and helpers for type extraction from request arguments -use std::collections::hash_map::Iter; -use serde::de::DeserializeOwned; use crate::Context; use crate::error::{Error, ErrorCode}; -use crate::types::{Meta, ProgressToken}; use crate::types::request::RequestParamsMeta; +use crate::types::{Meta, ProgressToken}; +use serde::de::DeserializeOwned; +use std::collections::hash_map::Iter; #[cfg(feature = "tasks")] use crate::types::RelatedTaskMetadata; @@ -96,7 +96,10 @@ impl RequestArgument for Meta { let meta = payload.expect_meta(); meta.as_ref() .and_then(|meta| meta.progress_token.clone()) - .ok_or(Error::new(ErrorCode::InvalidParams, "Missing progress token")) + .ok_or(Error::new( + ErrorCode::InvalidParams, + "Missing progress token", + )) .map(Meta) } @@ -115,7 +118,10 @@ impl RequestArgument for Meta { let meta = payload.expect_meta(); meta.as_ref() .and_then(|meta| meta.task.clone()) - .ok_or(Error::new(ErrorCode::InvalidParams, "Missing progress token")) + .ok_or(Error::new( + ErrorCode::InvalidParams, + "Missing progress token", + )) .map(Meta) } @@ -145,13 +151,18 @@ impl RequestArgument for Context { #[inline] pub(crate) fn extract_arg>( meta: &Option, - iter: &mut Iter<'_, String, serde_json::Value> + iter: &mut Iter<'_, String, serde_json::Value>, ) -> Result { match T::source() { Source::Meta => T::extract(Payload::Meta(meta)), - Source::Args => T::extract(Payload::Args(iter - .next() - .ok_or(Error::new(ErrorCode::InvalidParams, "Invalid param provided"))? - .1.clone())), + Source::Args => T::extract(Payload::Args( + iter.next() + .ok_or(Error::new( + ErrorCode::InvalidParams, + "Invalid param provided", + ))? + .1 + .clone(), + )), } } diff --git a/neva/src/types/helpers/macros.rs b/neva/src/types/helpers/macros.rs index cde1f2c..23d4a8c 100644 --- a/neva/src/types/helpers/macros.rs +++ b/neva/src/types/helpers/macros.rs @@ -1,13 +1,10 @@ -//! Helper Macros +//! Helper Macros use super::{PropertyType, TypeCategory}; -use serde_json::Value; use crate::types::{ - Json, Meta, Uri, - CallToolRequestParams, - ReadResourceRequestParams, - GetPromptRequestParams + CallToolRequestParams, GetPromptRequestParams, Json, Meta, ReadResourceRequestParams, Uri, }; +use serde_json::Value; macro_rules! impl_type_category { ($t:ty, $cat:expr) => { @@ -67,4 +64,4 @@ impl_type_category!(Meta, T, PropertyType::None); impl_type_category!(crate::Context, PropertyType::None); impl_type_category!(Value, PropertyType::Object); -impl_type_category!(Json, T, PropertyType::Object); \ No newline at end of file +impl_type_category!(Json, T, PropertyType::Object); diff --git a/neva/src/types/icon.rs b/neva/src/types/icon.rs index 7d03458..bcc4abf 100644 --- a/neva/src/types/icon.rs +++ b/neva/src/types/icon.rs @@ -1,43 +1,43 @@ //! Types and utilities for icons -use serde::{Serialize, Deserialize, Serializer, Deserializer}; use crate::types::Uri; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// Represents an optionally sized icon that can be displayed in a user interface. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details. #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct Icon { /// Optional MIME type override if the source MIME type is missing or generic. - /// + /// /// For example, `"image/png"`, `"image/jpeg"`, or `"image/svg+xml"`. #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")] pub mime: Option, - + /// Optional array of strings that specify sizes at which the icon can be used. - /// Each string should be in WxH format (e.g., `"48x48"`, `"96x96"`) or `"any"` + /// Each string should be in WxH format (e.g., `"48x48"`, `"96x96"`) or `"any"` /// for scalable formats like SVG. - /// + /// /// If not provided, the client should assume that the icon can be used at any size. #[serde(skip_serializing_if = "Option::is_none")] pub sizes: Option>, - - /// A standard URI pointing to an icon resource. Maybe an HTTP/HTTPS URL or a + + /// A standard URI pointing to an icon resource. Maybe an HTTP/HTTPS URL or a /// `data:` URI with Base64-encoded image data. - /// - /// Consumers **SHOULD** take steps to ensure URLs serving icons are from the + /// + /// Consumers **SHOULD** take steps to ensure URLs serving icons are from the /// same domain as the client/server or a trusted domain. - /// + /// /// Consumers **SHOULD** take appropriate precautions when consuming SVGs as they can contain /// executable JavaScript. pub src: Uri, - + /// Optional specifier for the theme this icon is designed for. `light` indicates /// the icon is designed to be used with a light background, and `dark` indicates /// the icon is designed to be used with a dark background. - /// + /// /// If not provided, the client should assume the icon can be used with any theme. - pub theme: Option + pub theme: Option, } /// Represents the theme the icon is designed for. @@ -46,9 +46,9 @@ pub struct Icon { pub enum IconTheme { /// The icon is designed for use with a dark background. Dark, - + /// The icon is designed for use with a light background. - Light + Light, } /// Represents the size of an icon. @@ -56,20 +56,20 @@ pub enum IconTheme { pub struct IconSize { /// The width of the icon in pixels. width: usize, - + /// The height of the icon in pixels. height: usize, - + /// Indicates whether the icon is scalable (e.g., SVG). /// If `true` the `width` and `height` fields should be ignored. - is_any: bool + is_any: bool, } impl Serialize for IconSize { #[inline] fn serialize(&self, serializer: S) -> Result - where - S: Serializer + where + S: Serializer, { if self.is_any { serializer.serialize_str("any") @@ -101,18 +101,17 @@ impl From<&str> for IconSize { #[inline] fn from(value: &str) -> Self { match value { - "any" => Self { width: 0, height: 0, is_any: true }, + "any" => Self { + width: 0, + height: 0, + is_any: true, + }, s => { let mut parts = s.split('x'); Self { - width: parts.next().map(|p| p - .parse() - .unwrap_or(0)) - .unwrap_or(0), - height: parts.next().map(|p| p.parse() - .unwrap_or(0)) - .unwrap_or(0), - is_any: false + width: parts.next().map(|p| p.parse().unwrap_or(0)).unwrap_or(0), + height: parts.next().map(|p| p.parse().unwrap_or(0)).unwrap_or(0), + is_any: false, } } } @@ -126,17 +125,17 @@ impl IconSize { Self { width, height, - is_any: false + is_any: false, } } - + /// Creates a new [`IconSize`] that can be used with any size #[inline] pub fn any() -> Self { Self { width: 0, height: 0, - is_any: true + is_any: true, } } } @@ -149,10 +148,10 @@ impl Icon { src: url.into(), mime: None, sizes: None, - theme: None + theme: None, } } - + /// Sets the MIME type #[inline] pub fn with_mime(mut self, mime: impl Into) -> Self { @@ -178,20 +177,20 @@ impl Icon { #[cfg(test)] mod tests { use super::*; - + fn create_test_icon() -> Icon { Icon { mime: Some("image/png".into()), sizes: Some(vec![IconSize::new(48, 48)]), src: Uri::from("https://example.com/icon.png"), - theme: Some(IconTheme::Dark) + theme: Some(IconTheme::Dark), } } - + #[test] fn it_converts_icon_size_from_str() { let size = IconSize::from("48x48"); - + assert_eq!(size.width, 48); assert_eq!(size.height, 48); assert!(!size.is_any); @@ -223,47 +222,51 @@ mod tests { assert_eq!(size.height, 0); assert!(!size.is_any); } - + #[test] fn it_serializes_icon_sizes() { let size = [ IconSize::any(), IconSize::new(48, 48), - IconSize::new(128, 128) + IconSize::new(128, 128), ]; - + let serialized = serde_json::to_string(&size).unwrap(); - + assert_eq!(serialized, r#"["any","48x48","128x128"]"#); } - + #[test] fn it_deserializes_icon_sizes() { - let deserialized: Vec = serde_json::from_str(r#"["any","48x48","128x128"]"#) - .unwrap(); - - assert_eq!(deserialized, [ - IconSize::any(), - IconSize::new(48, 48), - IconSize::new(128, 128) - ]); + let deserialized: Vec = + serde_json::from_str(r#"["any","48x48","128x128"]"#).unwrap(); + + assert_eq!( + deserialized, + [ + IconSize::any(), + IconSize::new(48, 48), + IconSize::new(128, 128) + ] + ); } - + #[test] fn it_serializes_icon() { let icon = create_test_icon(); let serialized = serde_json::to_string(&icon).unwrap(); - + assert_eq!( - serialized, - r#"{"mimeType":"image/png","sizes":["48x48"],"src":"https://example.com/icon.png","theme":"dark"}"#); + serialized, + r#"{"mimeType":"image/png","sizes":["48x48"],"src":"https://example.com/icon.png","theme":"dark"}"# + ); } - + #[test] fn it_deserializes_icon() { let json = r#"{"mimeType":"image/png","sizes":["48x48"],"src":"https://example.com/icon.png","theme":"dark"}"#; let deserialized: Icon = serde_json::from_str(json).unwrap(); - - assert_eq!(deserialized, create_test_icon()) + + assert_eq!(deserialized, create_test_icon()) } -} \ No newline at end of file +} diff --git a/neva/src/types/notification.rs b/neva/src/types/notification.rs index 06137a0..9be886a 100644 --- a/neva/src/types/notification.rs +++ b/neva/src/types/notification.rs @@ -1,16 +1,15 @@ -//! Utilities for Notifications +//! Utilities for Notifications -use serde::{Serialize, Deserialize}; -use serde::de::DeserializeOwned; -use crate::types::{RequestId, Message, JSONRPC_VERSION}; +use crate::types::{JSONRPC_VERSION, Message, RequestId}; #[cfg(feature = "server")] -use crate::{error::Error, types::{FromRequest, Request}}; - -pub use log_message::{ - LogMessage, - LoggingLevel, - SetLevelRequestParams +use crate::{ + error::Error, + types::{FromRequest, Request}, }; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +pub use log_message::{LogMessage, LoggingLevel, SetLevelRequestParams}; #[cfg(feature = "server")] use crate::app::handler::{FromHandlerParams, HandlerParams}; @@ -20,30 +19,30 @@ pub use progress::ProgressNotification; #[cfg(feature = "tracing")] pub use formatter::NotificationFormatter; -mod progress; -mod log_message; -#[cfg(feature = "tracing")] -mod formatter; #[cfg(feature = "tracing")] pub mod fmt; +#[cfg(feature = "tracing")] +mod formatter; +mod log_message; +mod progress; /// List of commands for Notifications pub mod commands { /// Notification name that indicates that the notifications have initialized. pub const INITIALIZED: &str = "notifications/initialized"; - + /// Notification name that indicates that notifications have been canceled. pub const CANCELLED: &str = "notifications/cancelled"; - + /// Notification name that indicates that a new log message has been received. pub const MESSAGE: &str = "notifications/message"; - + /// Notification name that indicates that a progress notification has been received. pub const PROGRESS: &str = "notifications/progress"; - + /// Notification name that indicates that a log message has been received on stderr. pub const STDERR: &str = "notifications/stderr"; - + /// Command name that sets the log level. pub const SET_LOG_LEVEL: &str = "logging/setLevel"; } @@ -51,7 +50,7 @@ pub mod commands { /// A notification which does not expect a response. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Notification { - /// JSON-RPC protocol version. + /// JSON-RPC protocol version. /// /// > Note: always 2.0. pub jsonrpc: String, @@ -68,25 +67,25 @@ pub struct Notification { pub session_id: Option, } -/// This notification can be sent by either side to indicate that it is cancelling +/// This notification can be sent by either side to indicate that it is cancelling /// a previously-issued request. -/// -/// The request **SHOULD** still be in-flight, but due to communication latency, +/// +/// The request **SHOULD** still be in-flight, but due to communication latency, /// it is always possible that this notification **MAY** arrive after the request has already finished. -/// -/// This notification indicates that the result will be unused, +/// +/// This notification indicates that the result will be unused, /// so any associated processing **SHOULD** cease. -/// +/// /// A client **MUST NOT** attempt to cancel its `initialize` request. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CancelledNotificationParams { /// The ID of the request to cancel. - /// + /// /// This **MUST** correspond to the ID of a request previously issued in the same direction. #[serde(rename = "requestId")] pub request_id: RequestId, - - /// An optional string describing the reason for the cancellation. + + /// An optional string describing the reason for the cancellation. /// This **MAY** be logged or presented to the user. #[serde(skip_serializing_if = "Option::is_none")] pub reason: Option, @@ -112,11 +111,11 @@ impl Notification { /// Create a new [`Notification`] #[inline] pub fn new(method: &str, params: Option) -> Self { - Self { + Self { jsonrpc: JSONRPC_VERSION.into(), session_id: None, - method: method.into(), - params + method: method.into(), + params, } } @@ -129,38 +128,40 @@ impl Notification { id } } - + /// Parses [`Notification`] params into specified type #[inline] pub fn params(&self) -> Option { - match self.params { + match self.params { Some(ref params) => serde_json::from_value(params.clone()).ok(), None => None, } } - + /// Writes the [`Notification`] #[inline] #[cfg(feature = "tracing")] pub fn write(self) { let is_stderr = self.is_stderr(); - let Some(params) = self.params else { return; }; + let Some(params) = self.params else { + return; + }; if is_stderr { Self::write_err_internal(params); } else { - match serde_json::from_value::(params.clone()) { + match serde_json::from_value::(params.clone()) { Ok(log) => log.write(), Err(err) => tracing::error!(logger = "neva", "{}", err), } } } - + /// Returns `true` is the [`Notification`] received with method `notifications/stderr` #[inline] pub fn is_stderr(&self) -> bool { self.method.as_str() == commands::STDERR } - + /// Writes the [`Notification`] as [`LoggingLevel::Error`] #[inline] #[cfg(feature = "tracing")] @@ -169,36 +170,34 @@ impl Notification { Self::write_err_internal(params) } } - + /// Serializes the [`Notification`] to JSON string pub fn to_json(self) -> String { serde_json::to_string(&self).unwrap() } - + #[inline] #[cfg(feature = "tracing")] fn write_err_internal(params: serde_json::Value) { - let err = params - .get("content") - .unwrap_or(¶ms); + let err = params.get("content").unwrap_or(¶ms); tracing::error!("{}", err); } } #[cfg(test)] mod tests { - use serde_json::json; use super::*; - + use serde_json::json; + #[test] fn it_creates_new_notification() { let notification = Notification::new("test", Some(json!({ "param": "value" }))); - + assert_eq!(notification.jsonrpc, "2.0"); assert_eq!(notification.method, "test"); - + let params_json = serde_json::to_string(¬ification.params.unwrap()).unwrap(); - + assert_eq!(params_json, r#"{"param":"value"}"#); } -} \ No newline at end of file +} diff --git a/neva/src/types/notification/fmt.rs b/neva/src/types/notification/fmt.rs index ca1d657..3f4ecdc 100644 --- a/neva/src/types/notification/fmt.rs +++ b/neva/src/types/notification/fmt.rs @@ -1,35 +1,31 @@ -//! A generic tracing/logging formatting layer for notifications +//! A generic tracing/logging formatting layer for notifications -use std::io::{self, Write}; -use once_cell::sync::Lazy; -use tokio::sync::mpsc::{channel, Sender}; use crate::shared::MessageRegistry; -use crate::types::notification::{ - Notification, - formatter::build_notification -}; +use crate::types::notification::{Notification, formatter::build_notification}; +use once_cell::sync::Lazy; +use std::io::{self, Write}; +use tokio::sync::mpsc::{Sender, channel}; use tracing::{ - {Event, Id, Subscriber, field::Visit}, + field::Field, span::Attributes, - field::Field + {Event, Id, Subscriber, field::Visit}, }; use tracing_subscriber::{ - {layer::Context, Layer}, - registry::LookupSpan + registry::LookupSpan, + {Layer, layer::Context}, }; const MCP_SESSION_ID: &str = "mcp_session_id"; -pub(crate) static LOG_REGISTRY: Lazy = - Lazy::new(MessageRegistry::new); +pub(crate) static LOG_REGISTRY: Lazy = Lazy::new(MessageRegistry::new); /// Creates a custom tracing layer that delivers messages to MCP Client -/// +/// /// # Example /// ```no_run /// use tracing_subscriber::prelude::*; /// use neva::types::notification; -/// +/// /// tracing_subscriber::registry() /// .with(notification::fmt::layer()) /// .init(); @@ -42,11 +38,11 @@ pub fn layer() -> MpscLayer { } }); MpscLayer { - sender: NotificationSender::new(tx) + sender: NotificationSender::new(tx), } } -/// Keeps a [`Sender`] +/// Keeps a [`Sender`] #[derive(Debug)] struct NotificationSender { sender: Sender, @@ -75,34 +71,30 @@ impl NotificationSender { /// ``` #[derive(Debug)] pub struct MpscLayer { - sender: NotificationSender + sender: NotificationSender, } impl Layer for MpscLayer where - S: Subscriber + for<'a> LookupSpan<'a>, + S: Subscriber + for<'a> LookupSpan<'a>, { #[inline] fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { let mut visitor = SpanVisitor { session_id: None }; attrs.record(&mut visitor); if let Some(span) = ctx.span(id) - && let Some(mcp_session_id) = visitor.session_id { - span - .extensions_mut() - .insert(mcp_session_id); + && let Some(mcp_session_id) = visitor.session_id + { + span.extensions_mut().insert(mcp_session_id); } } - + #[inline] fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) { let mut notification = build_notification(event); if let Some(span) = ctx.event_span(event) { - notification.session_id = span - .extensions() - .get::() - .cloned(); - self.sender.send_notification(notification); + notification.session_id = span.extensions().get::().cloned(); + self.sender.send_notification(notification); } else { let mut stderr = io::stderr(); let json = serde_json::to_string(¬ification).unwrap(); @@ -119,11 +111,12 @@ impl Visit for SpanVisitor { #[inline] fn record_str(&mut self, field: &Field, value: &str) { if field.name() == MCP_SESSION_ID - && let Ok(session_id) = uuid::Uuid::parse_str(value) { + && let Ok(session_id) = uuid::Uuid::parse_str(value) + { self.session_id = Some(session_id); } } - + fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { // fallback if id was passed as %mcp_session_id or something else if field.name() == MCP_SESSION_ID && self.session_id.is_none() { @@ -132,10 +125,10 @@ impl Visit for SpanVisitor { .strip_prefix('"') .and_then(|s| s.strip_suffix('"')) .unwrap_or(&formatted); - - if let Ok(session_id) = uuid::Uuid::parse_str(stripped) { + + if let Ok(session_id) = uuid::Uuid::parse_str(stripped) { self.session_id = Some(session_id); - } + } } } } diff --git a/neva/src/types/notification/formatter.rs b/neva/src/types/notification/formatter.rs index 3457eaa..03362ff 100644 --- a/neva/src/types/notification/formatter.rs +++ b/neva/src/types/notification/formatter.rs @@ -1,25 +1,16 @@ -//! A tracing/logging formatter for notifications +//! A tracing/logging formatter for notifications use std::collections::BTreeMap; -use tracing::{Event, Subscriber, Level}; use tracing::level_filters::LevelFilter; +use tracing::{Event, Level, Subscriber}; use tracing_subscriber::{ - fmt::{ - format::FormatFields, - FormatEvent, - FmtContext, - format::Writer - }, field::Visit, + fmt::{FmtContext, FormatEvent, format::FormatFields, format::Writer}, registry::LookupSpan, }; -use crate::types::notification::{ - LogMessage, - LoggingLevel, - Notification -}; use crate::types::ProgressToken; +use crate::types::notification::{LogMessage, LoggingLevel, Notification}; /// A formatter that formats tracing events into MCP notification logs #[allow(missing_debug_implementations)] @@ -28,12 +19,12 @@ pub struct NotificationFormatter; impl From<&Level> for LoggingLevel { #[inline] fn from(level: &Level) -> Self { - match *level { + match *level { Level::ERROR => LoggingLevel::Error, Level::WARN => LoggingLevel::Warning, Level::INFO => LoggingLevel::Info, Level::DEBUG => LoggingLevel::Debug, - Level::TRACE => LoggingLevel::Debug + Level::TRACE => LoggingLevel::Debug, } } } @@ -47,7 +38,7 @@ impl From for LoggingLevel { LevelFilter::INFO => LoggingLevel::Info, LevelFilter::DEBUG => LoggingLevel::Debug, LevelFilter::TRACE => LoggingLevel::Debug, - _ => LoggingLevel::Info + _ => LoggingLevel::Info, } } } @@ -106,7 +97,7 @@ pub(super) fn build_notification(event: &Event<'_>) -> Notification { let meta = event.metadata(); let level = meta.level(); let fields = extract_fields(event); - + match meta.target() { "progress" => { let token = fields @@ -121,10 +112,8 @@ pub(super) fn build_notification(event: &Event<'_>) -> Notification { .get("value") .map(|v| v.to_string().replace("\"", "").parse().unwrap()); - token.unwrap() - .notify(value.unwrap(), total) - .into() - }, + token.unwrap().notify(value.unwrap(), total).into() + } _ => { let logger = fields .get("logger") @@ -179,9 +168,9 @@ impl Visit for Visitor<'_> { // Only use this if nothing else handled it if !self.map.contains_key(field.name()) { let formatted = format!("{value:?}"); - let value = serde_json::to_value(&formatted) - .unwrap_or(serde_json::Value::String(formatted)); + let value = + serde_json::to_value(&formatted).unwrap_or(serde_json::Value::String(formatted)); self.map.insert(field.name(), value); } } -} \ No newline at end of file +} diff --git a/neva/src/types/notification/log_message.rs b/neva/src/types/notification/log_message.rs index 357cfc0..c844dbb 100644 --- a/neva/src/types/notification/log_message.rs +++ b/neva/src/types/notification/log_message.rs @@ -1,71 +1,71 @@ -//! Utilities for log messages +//! Utilities for log messages -use serde::{Serialize, Deserialize}; +#[cfg(feature = "server")] +use crate::app::handler::{FromHandlerParams, HandlerParams}; use crate::error::Error; -use crate::types::response::ErrorDetails; use crate::types::notification::Notification; +use crate::types::response::ErrorDetails; #[cfg(feature = "server")] -use crate::types::{Request, FromRequest}; +use crate::types::{FromRequest, Request}; +use serde::{Deserialize, Serialize}; #[cfg(feature = "tracing")] use tracing::Level; -#[cfg(feature = "server")] -use crate::app::handler::{FromHandlerParams, HandlerParams}; /// The severity of a log message. -/// This map to syslog message severities, as specified in +/// This map to syslog message severities, as specified in /// [RFC-5424](https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1): #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum LoggingLevel { /// Detailed debug information, typically only valuable to developers. Debug, - + /// Normal operational messages that require no action. Info, - + /// Warning conditions that don't represent an error but indicate potential issues. Warning, - + /// Error conditions that should be addressed but don't require immediate action. Error, - + /// Normal but significant events that might deserve attention. Notice, - + /// Critical conditions that require immediate attention. Critical, - + /// Action must be taken immediately to address the condition. Alert, - + /// System is unusable and requires immediate attention. - Emergency + Emergency, } /// Sent from the server as the payload of "notifications/message" notifications whenever a log message is generated. /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LogMessage { /// The severity of this log message. pub level: LoggingLevel, - + /// An optional name of the logger issuing this message. #[serde(skip_serializing_if = "Option::is_none")] pub logger: Option, - + /// The data to be logged, such as a string message or an object. #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, } /// A request from the client to the server, to enable or adjust logging. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SetLevelRequestParams { - /// The level of logging that the client wants to receive from the server. + /// The level of logging that the client wants to receive from the server. /// The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/message. pub level: LoggingLevel, } @@ -85,10 +85,7 @@ impl From for LogMessage { impl From for Notification { #[inline] fn from(log: LogMessage) -> Self { - Self::new( - super::commands::MESSAGE, - serde_json::to_value(log).ok() - ) + Self::new(super::commands::MESSAGE, serde_json::to_value(log).ok()) } } @@ -105,13 +102,17 @@ impl LogMessage { /// Creates a new [`LogMessage`] #[inline] pub fn new( - level: LoggingLevel, - logger: Option, - data: Option + level: LoggingLevel, + logger: Option, + data: Option, ) -> Self { - Self { level, logger, data } + Self { + level, + logger, + data, + } } - + /// Writes a log message #[inline] #[cfg(feature = "tracing")] @@ -128,4 +129,4 @@ impl LogMessage { LoggingLevel::Debug => tracing::event!(Level::DEBUG, %data), }; } -} \ No newline at end of file +} diff --git a/neva/src/types/notification/progress.rs b/neva/src/types/notification/progress.rs index d83193c..8b04ff8 100644 --- a/neva/src/types/notification/progress.rs +++ b/neva/src/types/notification/progress.rs @@ -1,23 +1,23 @@ -//! Progress notification +//! Progress notification -use serde::{Serialize, Deserialize}; -use crate::types::notification::Notification; use crate::types::ProgressToken; +use crate::types::notification::Notification; +use serde::{Deserialize, Serialize}; /// An out-of-band notification used to inform the receiver of a progress update for a long-running request. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProgressNotification { - /// The progress token which was given in the initial request, + /// The progress token which was given in the initial request, /// used to associate this notification with the request that is proceeding. #[serde(rename = "progressToken")] pub progress_token: ProgressToken, - - /// The progress thus far. This should increase every time progress is made, + + /// The progress thus far. This should increase every time progress is made, /// even if the total is unknown. pub progress: f64, - + /// Total number of items to a process (or total progress required), if known. #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, @@ -27,8 +27,8 @@ impl From for Notification { #[inline] fn from(progress: ProgressNotification) -> Self { Self::new( - super::commands::PROGRESS, - serde_json::to_value(progress).ok() + super::commands::PROGRESS, + serde_json::to_value(progress).ok(), ) } -} \ No newline at end of file +} diff --git a/neva/src/types/progress.rs b/neva/src/types/progress.rs index 12b874d..1ff8aba 100644 --- a/neva/src/types/progress.rs +++ b/neva/src/types/progress.rs @@ -1,11 +1,11 @@ -//! Utilities for tracking operation's progress +//! Utilities for tracking operation's progress +use crate::shared::{ArcSlice, ArcStr, MemChr}; +use crate::types::notification::ProgressNotification; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; use std::fmt::Display; use std::str::FromStr; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::shared::{ArcSlice, ArcStr, MemChr}; -use crate::types::notification::ProgressNotification; const SEPARATOR: u8 = b'/'; @@ -14,15 +14,15 @@ const SEPARATOR: u8 = b'/'; pub enum ProgressToken { /// Represents a numeric progress token. Number(i64), - + /// Represents a UUID progress token. Uuid(uuid::Uuid), - + /// Represents a string progress token. String(ArcStr), - + /// Represents a slash-separated progress token. - Slice(ArcSlice) + Slice(ArcSlice), } impl Display for ProgressToken { @@ -32,7 +32,7 @@ impl Display for ProgressToken { ProgressToken::Number(n) => write!(f, "{n}"), ProgressToken::Uuid(u) => write!(f, "{u}"), ProgressToken::String(s) => write!(f, "{s}"), - ProgressToken::Slice(s) => write!(f, "{s}") + ProgressToken::Slice(s) => write!(f, "{s}"), } } } @@ -121,7 +121,7 @@ impl ProgressToken { ProgressNotification { progress_token: self.clone(), progress, - total + total, } } } @@ -132,10 +132,13 @@ mod tests { #[test] fn it_serializes_and_deserializes_slice_through_str_request_id() { - let expected_id = ProgressToken::Slice([ - ProgressToken::String("user".into()), - ProgressToken::Number(1) - ].into()); + let expected_id = ProgressToken::Slice( + [ + ProgressToken::String("user".into()), + ProgressToken::Number(1), + ] + .into(), + ); let json = serde_json::to_string(&expected_id).unwrap(); let new_id: ProgressToken = serde_json::from_str(&json).unwrap(); @@ -145,10 +148,13 @@ mod tests { #[test] fn it_serializes_and_deserializes_slice_through_value_request_id() { - let expected_id = ProgressToken::Slice([ - ProgressToken::String("user".into()), - ProgressToken::Number(1) - ].into()); + let expected_id = ProgressToken::Slice( + [ + ProgressToken::String("user".into()), + ProgressToken::Number(1), + ] + .into(), + ); let json = serde_json::to_value(&expected_id).unwrap(); let new_id: ProgressToken = serde_json::from_value(json).unwrap(); @@ -215,4 +221,4 @@ mod tests { assert_eq!(expected_id, new_id); } -} \ No newline at end of file +} diff --git a/neva/src/types/prompt.rs b/neva/src/types/prompt.rs index 943aa03..41a8cbe 100644 --- a/neva/src/types/prompt.rs +++ b/neva/src/types/prompt.rs @@ -1,31 +1,31 @@ -//! Represents an MCP prompt +//! Represents an MCP prompt -use std::collections::HashMap; -use std::fmt::{Debug, Formatter}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use crate::shared; -use crate::types::{Cursor, Icon}; -use crate::types::request::RequestParamsMeta; -#[cfg(feature = "server")] -use std::sync::Arc; -#[cfg(feature = "server")] -use std::future::Future; -#[cfg(feature = "server")] -use futures_util::future::BoxFuture; #[cfg(feature = "server")] use super::helpers::TypeCategory; #[cfg(feature = "server")] use crate::error::{Error, ErrorCode}; +use crate::shared; #[cfg(feature = "server")] use crate::types::FromRequest; +use crate::types::request::RequestParamsMeta; +use crate::types::{Cursor, Icon}; #[cfg(feature = "server")] use crate::types::{IntoResponse, Page, PropertyType, Request, RequestId, Response}; +#[cfg(feature = "server")] +use futures_util::future::BoxFuture; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +#[cfg(feature = "server")] +use std::future::Future; +#[cfg(feature = "server")] +use std::sync::Arc; #[cfg(feature = "server")] use crate::app::{ context::Context, - handler::{FromHandlerParams, HandlerParams, GenericHandler, Handler, RequestHandler} + handler::{FromHandlerParams, GenericHandler, Handler, HandlerParams, RequestHandler}, }; pub use get_prompt_result::{GetPromptResult, PromptMessage}; @@ -38,16 +38,16 @@ mod get_prompt_result; pub mod commands { /// Command name that returns a list of prompts the server has. pub const LIST: &str = "prompts/list"; - + /// Notification name that indicates that the list of prompts has changed. pub const LIST_CHANGED: &str = "notifications/prompts/list_changed"; - + /// Command name that returns a prompt provided by the server. pub const GET: &str = "prompts/get"; } /// A prompt or prompt template that the server offers. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Clone, Serialize, Deserialize)] pub struct Prompt { @@ -81,11 +81,11 @@ pub struct Prompt { /// - `image/webp` - WebP images (modern, efficient format) #[serde(skip_serializing_if = "Option::is_none")] pub icons: Option>, - + /// Metadata reserved by MCP for protocol-level metadata. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option, - + /// A get prompt handler #[serde(skip)] #[cfg(feature = "server")] @@ -103,7 +103,7 @@ pub struct Prompt { } /// Describes an argument that a prompt can accept. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PromptArgument { @@ -131,13 +131,13 @@ pub struct ListPromptsRequestParams { } /// Used by the client to get a prompt provided by the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct GetPromptRequestParams { /// The name of the prompt or prompt template. pub name: String, - + /// Arguments to use for templating the prompt. #[serde(rename = "arguments", skip_serializing_if = "Option::is_none")] pub args: Option>, @@ -151,7 +151,7 @@ pub struct GetPromptRequestParams { } /// The server's response to a prompts/list request from the client. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListPromptsResult { @@ -174,7 +174,7 @@ impl IntoResponse for ListPromptsResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -185,7 +185,7 @@ impl From> for ListPromptsResult { fn from(prompts: Vec) -> Self { Self { next_cursor: None, - prompts + prompts, } } } @@ -196,7 +196,7 @@ impl From> for ListPromptsResult { fn from(page: Page<'_, Prompt>) -> Self { Self { next_cursor: page.next_cursor, - prompts: page.items.to_vec() + prompts: page.items.to_vec(), } } } @@ -233,9 +233,9 @@ impl From<&str> for PromptArgument { #[inline] fn from(name: &str) -> Self { Self { - name: name.into(), + name: name.into(), descr: None, - required: Some(true) + required: Some(true), } } } @@ -247,7 +247,7 @@ impl From> for PromptArgument { Self { name: name.into(), descr: None, - required: Some(true) + required: Some(true), } } } @@ -256,8 +256,7 @@ impl From> for PromptArgument { impl From for PromptArgument { #[inline] fn from(json: Value) -> Self { - serde_json::from_value(json) - .expect("A correct PromptArgument value must be provided") + serde_json::from_value(json).expect("A correct PromptArgument value must be provided") } } @@ -321,11 +320,14 @@ where F: PromptHandler, R: TryInto, R::Error: Into, - Args: TryFrom + Args: TryFrom, { /// Creates a new [`PromptFunc`] wrapped into [`Arc`] pub(crate) fn new(func: F) -> Arc { - let func = Self { func, _marker: std::marker::PhantomData }; + let func = Self { + func, + _marker: std::marker::PhantomData, + }; Arc::new(func) } } @@ -336,7 +338,7 @@ where F: PromptHandler, R: TryInto, R::Error: Into, - Args: TryFrom + Send + Sync + Args: TryFrom + Send + Sync, { #[inline] fn call(&self, params: HandlerParams) -> BoxFuture<'_, Result> { @@ -345,11 +347,7 @@ where }; Box::pin(async move { let args = Args::try_from(params)?; - self.func - .call(args) - .await - .try_into() - .map_err(Into::into) + self.func.call(args).await.try_into().map_err(Into::into) }) } } @@ -360,7 +358,7 @@ impl GetPromptRequestParams { Self { name: name.into(), args: None, - meta: None + meta: None, } } @@ -401,12 +399,12 @@ impl Prompt { F: PromptHandler, R: TryInto + Send + 'static, R::Error: Into, - Args: TryFrom + Send + Sync + 'static, + Args: TryFrom + Send + Sync + 'static, { let handler = PromptFunc::new(handler); let args = F::args(); - Self { - name: name.into(), + Self { + name: name.into(), title: None, descr: None, meta: None, @@ -416,35 +414,32 @@ impl Prompt { roles: None, #[cfg(feature = "http-server")] permissions: None, - icons: None + icons: None, } } - + /// Sets a [`Prompt`] title pub fn with_title(&mut self, title: impl Into) -> &mut Self { self.title = Some(title.into()); self } - + /// Sets a [`Prompt`] description pub fn with_description(&mut self, descr: impl Into) -> &mut Self { self.descr = Some(descr.into()); self } - + /// Sets arguments for the [`Prompt`] pub fn with_args(&mut self, args: T) -> &mut Self where T: IntoIterator, A: Into, { - self.args = Some(args - .into_iter() - .map(Into::into) - .collect()); + self.args = Some(args.into_iter().map(Into::into).collect()); self } - + /// Sets the [`Prompt`] icons pub fn with_icons(&mut self, icons: impl IntoIterator) -> &mut Self { self.icons = Some(icons.into_iter().collect()); @@ -456,12 +451,9 @@ impl Prompt { pub fn with_roles(&mut self, roles: T) -> &mut Self where T: IntoIterator, - I: Into + I: Into, { - self.roles = Some(roles - .into_iter() - .map(Into::into) - .collect()); + self.roles = Some(roles.into_iter().map(Into::into).collect()); self } @@ -470,12 +462,9 @@ impl Prompt { pub fn with_permissions(&mut self, permissions: T) -> &mut Self where T: IntoIterator, - I: Into + I: Into, { - self.permissions = Some(permissions - .into_iter() - .map(Into::into) - .collect()); + self.permissions = Some(permissions.into_iter().map(Into::into).collect()); self } @@ -484,7 +473,10 @@ impl Prompt { pub(crate) async fn call(&self, params: HandlerParams) -> Result { match self.handler { Some(ref handler) => handler.call(params).await, - None => Err(Error::new(ErrorCode::InternalError, "Prompt handler not specified")) + None => Err(Error::new( + ErrorCode::InternalError, + "Prompt handler not specified", + )), } } } @@ -499,14 +491,13 @@ impl PromptArguments { /// Deserializes a [`Vec`] of [`PromptArgument`] from a JSON string #[inline] pub fn from_json_str(json: &str) -> Vec { - serde_json::from_str(json) - .expect("PromptArgument: Incorrect JSON string provided") + serde_json::from_str(json).expect("PromptArgument: Incorrect JSON string provided") } } #[cfg(feature = "server")] impl PromptArgument { - /// Creates a new [`PromptArgument`] + /// Creates a new [`PromptArgument`] pub(crate) fn new() -> Self { Self { name: std::any::type_name::().into(), @@ -514,7 +505,7 @@ impl PromptArgument { required: Some(true), } } - + /// Creates a new required [`PromptArgument`] pub fn required>(name: T, descr: T) -> Self { Self { @@ -547,12 +538,12 @@ macro_rules! impl_generic_prompt_handler ({ $($param:ident)* } => { let mut args = Vec::new(); $( { - if $param::category() != PropertyType::None { + if $param::category() != PropertyType::None { args.push(PromptArgument::new::<$param>()); - } + } } )* - if args.len() == 0 { + if args.len() == 0 { None } else { Some(args) @@ -566,4 +557,4 @@ impl_generic_prompt_handler! { T1 } impl_generic_prompt_handler! { T1 T2 } impl_generic_prompt_handler! { T1 T2 T3 } impl_generic_prompt_handler! { T1 T2 T3 T4 } -impl_generic_prompt_handler! { T1 T2 T3 T4 T5 } \ No newline at end of file +impl_generic_prompt_handler! { T1 T2 T3 T4 T5 } diff --git a/neva/src/types/prompt/from_request.rs b/neva/src/types/prompt/from_request.rs index c92807f..a6c60aa 100644 --- a/neva/src/types/prompt/from_request.rs +++ b/neva/src/types/prompt/from_request.rs @@ -1,5 +1,5 @@ -use crate::error::Error; use super::GetPromptRequestParams; +use crate::error::Error; use crate::types::helpers::extract::{RequestArgument, extract_arg}; impl TryFrom for () { @@ -15,7 +15,7 @@ macro_rules! impl_from_get_prompt_params { ($($T: ident),*) => { impl<$($T: RequestArgument),+> TryFrom for ($($T,)+) { type Error = Error; - + #[inline] fn try_from(params: GetPromptRequestParams) -> Result { let args = params.args.unwrap_or_default(); @@ -35,4 +35,4 @@ impl_from_get_prompt_params! { T1 } impl_from_get_prompt_params! { T1, T2 } impl_from_get_prompt_params! { T1, T2, T3 } impl_from_get_prompt_params! { T1, T2, T3, T4 } -impl_from_get_prompt_params! { T1, T2, T3, T4, T5 } \ No newline at end of file +impl_from_get_prompt_params! { T1, T2, T3, T4, T5 } diff --git a/neva/src/types/prompt/get_prompt_result.rs b/neva/src/types/prompt/get_prompt_result.rs index 2d21fa4..97e06d4 100644 --- a/neva/src/types/prompt/get_prompt_result.rs +++ b/neva/src/types/prompt/get_prompt_result.rs @@ -1,10 +1,13 @@ -//! Types and utils for prompt request results +//! Types and utils for prompt request results -use serde::{Serialize, Deserialize}; use crate::types::{Content, Role}; #[cfg(feature = "server")] -use crate::{error::Error, types::{IntoResponse, RequestId, Response}}; - +use crate::{ + error::Error, + types::{IntoResponse, RequestId, Response}, +}; +use serde::{Deserialize, Serialize}; + /// The server's response to a prompts/get request from the client. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details @@ -20,7 +23,7 @@ pub struct GetPromptResult { /// Describes a message returned as part of a prompt. /// -/// This is similar to `SamplingMessage`, but also supports the embedding of +/// This is similar to `SamplingMessage`, but also supports the embedding of /// resources from the MCP server. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details @@ -38,16 +41,16 @@ impl IntoResponse for GetPromptResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } #[cfg(feature = "server")] impl From<(T1, T2)> for PromptMessage -where +where T1: Into, - T2: Into + T2: Into, { #[inline] fn from((msg, role): (T1, T2)) -> Self { @@ -57,20 +60,23 @@ where #[cfg(feature = "server")] impl From for GetPromptResult -where - T: Into +where + T: Into, { #[inline] fn from(msg: T) -> Self { - Self { descr: None, messages: vec![msg.into()] } + Self { + descr: None, + messages: vec![msg.into()], + } } } #[cfg(feature = "server")] impl TryFrom> for GetPromptResult -where +where T: Into, - E: Into + E: Into, { type Error = E; @@ -78,7 +84,7 @@ where fn try_from(value: Result) -> Result { match value { Ok(ok) => Ok(ok.into()), - Err(err) => Err(err) + Err(err) => Err(err), } } } @@ -86,16 +92,13 @@ where #[cfg(feature = "server")] impl From> for GetPromptResult where - T: Into + T: Into, { #[inline] fn from(iter: Vec) -> Self { Self { descr: None, - messages: iter - .into_iter() - .map(Into::into) - .collect(), + messages: iter.into_iter().map(Into::into).collect(), } } } @@ -103,16 +106,13 @@ where #[cfg(feature = "server")] impl From<[T; N]> for GetPromptResult where - T: Into + T: Into, { #[inline] fn from(iter: [T; N]) -> Self { Self { descr: None, - messages: iter - .into_iter() - .map(Into::into) - .collect(), + messages: iter.into_iter().map(Into::into).collect(), } } } @@ -122,22 +122,22 @@ impl PromptMessage { /// Creates a new [`PromptMessage`] #[inline] pub fn new(role: impl Into) -> Self { - Self { - content: Content::empty(), - role: role.into() + Self { + content: Content::empty(), + role: role.into(), } } - + /// Creates a new [`PromptMessage`] with the user role pub fn user() -> Self { Self::new(Role::User) } - + /// Creates a new [`PromptMessage`] with the assistant role pub fn assistant() -> Self { Self::new(Role::Assistant) } - + /// Sets the content of [`PromptMessage`] pub fn with>(mut self, content: T) -> Self { self.content = content.into(); @@ -150,12 +150,12 @@ impl GetPromptResult { /// Creates a new [`GetPromptResult`] #[inline] pub fn new() -> Self { - Self { + Self { messages: Vec::with_capacity(8), - descr: None + descr: None, } } - + /// Sets the description of the result pub fn with_descr>(mut self, descr: T) -> Self { self.descr = Some(descr.into()); @@ -167,20 +167,17 @@ impl GetPromptResult { self.messages.push(message.into()); self } - + /// Adds multiple messages to the result pub fn with_messages(mut self, messages: T) -> Self - where + where T: IntoIterator, - I: Into + I: Into, { - self.messages - .extend(messages.into_iter().map(Into::into)); + self.messages.extend(messages.into_iter().map(Into::into)); self } } #[cfg(test)] -mod tests { - -} \ No newline at end of file +mod tests {} diff --git a/neva/src/types/reference.rs b/neva/src/types/reference.rs index 4dc570b..fa6d5f4 100644 --- a/neva/src/types/reference.rs +++ b/neva/src/types/reference.rs @@ -1,22 +1,22 @@ -//! Types and utils for references +//! Types and utils for references +use serde::{Deserialize, Serialize}; use std::fmt::Display; -use serde::{Serialize, Deserialize}; /// Represents a reference to a resource or prompt. /// Umbrella type for both ResourceReference and PromptReference from the spec schema. -/// -/// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details +/// +/// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details #[derive(Debug, Serialize, Deserialize)] pub struct Reference { /// The type of content. Can be ref/resource or ref/prompt. #[serde(rename = "type")] pub r#type: String, - + /// The URI or URI template of the resource. #[serde(skip_serializing_if = "Option::is_none")] pub uri: Option, - + /// The name of the prompt or prompt template. #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, @@ -40,7 +40,7 @@ impl Reference { Self { r#type: "ref/resource".into(), uri: Some(uri.into()), - name: None + name: None, } } @@ -50,16 +50,16 @@ impl Reference { Self { r#type: "ref/prompt".into(), name: Some(name.into()), - uri: None + uri: None, } } - + /// Validates the reference object. - /// + /// /// # Example /// ```no_run /// use neva::types::Reference; - /// + /// /// // valid ref/resource /// let reference = Reference::resource("file://test"); /// assert!(reference.validate().is_none()); @@ -70,9 +70,21 @@ impl Reference { /// ``` pub fn validate(&self) -> Option { match self.r#type.as_ref() { - "ref/resource" => if self.uri.is_none() { Some("uri is required for ref/resource".into()) } else { None }, - "ref/prompt" => if self.name.is_none() { Some("name is required for ref/prompt".into()) } else { None }, + "ref/resource" => { + if self.uri.is_none() { + Some("uri is required for ref/resource".into()) + } else { + None + } + } + "ref/prompt" => { + if self.name.is_none() { + Some("name is required for ref/prompt".into()) + } else { + None + } + } _ => Some(format!("unknown reference type: {}", self.r#type)), } } -} \ No newline at end of file +} diff --git a/neva/src/types/request.rs b/neva/src/types/request.rs index 691ff30..4a4746b 100644 --- a/neva/src/types/request.rs +++ b/neva/src/types/request.rs @@ -1,18 +1,15 @@ -//! Represents a request from an MCP client +//! Represents a request from an MCP client +use super::{JSONRPC_VERSION, Message, ProgressToken}; +use serde::{Deserialize, Serialize}; use std::fmt; use std::fmt::{Debug, Formatter}; -use serde::{Serialize, Deserialize}; -use super::{ProgressToken, Message, JSONRPC_VERSION}; #[cfg(feature = "server")] use crate::Context; #[cfg(feature = "http-server")] -use { - crate::auth::DefaultClaims, - volga::headers::HeaderMap -}; +use {crate::auth::DefaultClaims, volga::headers::HeaderMap}; #[cfg(feature = "tasks")] use crate::types::RelatedTaskMetadata; @@ -28,21 +25,21 @@ mod request_id; /// A request in the JSON-RPC protocol. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Request { - /// JSON-RPC protocol version. + /// JSON-RPC protocol version. /// /// > **Note:** always 2.0. pub jsonrpc: String, /// Request identifier. Must be a string or number and unique within the session. pub id: RequestId, - + /// Name of the method to invoke. pub method: String, - + /// Optional parameters for the method. #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, - + /// Current MCP Session ID #[serde(skip)] pub session_id: Option, @@ -59,28 +56,31 @@ pub struct Request { } /// Provides metadata related to the request that provides additional protocol-level information. -/// +/// /// > **Note:** This class contains properties that are used by the Model Context Protocol /// > for features like progress tracking and other protocol-specific capabilities. #[derive(Default, Clone, Deserialize, Serialize)] pub struct RequestParamsMeta { /// An opaque token that will be attached to any subsequent progress notifications. - /// + /// /// > **Note:** The receiver is not obligated to provide these notifications. #[serde(rename = "progressToken", skip_serializing_if = "Option::is_none")] pub progress_token: Option, - + /// Represents metadata for associating messages with a task. - /// + /// /// > **Note:** Include this in the _meta field under the key `io.modelcontextprotocol/related-task`. - #[serde(rename = "io.modelcontextprotocol/related-task", skip_serializing_if = "Option::is_none")] + #[serde( + rename = "io.modelcontextprotocol/related-task", + skip_serializing_if = "Option::is_none" + )] #[cfg(feature = "tasks")] pub(crate) task: Option, /// MCP request context #[serde(skip)] #[cfg(feature = "server")] - pub(crate) context: Option + pub(crate) context: Option, } impl Debug for RequestParamsMeta { @@ -107,14 +107,18 @@ impl RequestParamsMeta { #[cfg(feature = "tasks")] task: None, #[cfg(feature = "server")] - context: None + context: None, } } } impl Request { /// Creates a new [`Request`] - pub fn new(id: Option, method: impl Into, params: Option) -> Self { + pub fn new( + id: Option, + method: impl Into, + params: Option, + ) -> Self { Self { jsonrpc: JSONRPC_VERSION.into(), session_id: None, @@ -144,10 +148,11 @@ impl Request { id } } - + /// Returns [`Request`] params metadata pub fn meta(&self) -> Option { - self.params.as_ref()? + self.params + .as_ref()? .get("_meta") .cloned() .and_then(|meta| serde_json::from_value(meta).ok()) @@ -155,6 +160,4 @@ impl Request { } #[cfg(test)] -mod tests { - -} \ No newline at end of file +mod tests {} diff --git a/neva/src/types/request/from_request.rs b/neva/src/types/request/from_request.rs index ac43956..e413278 100644 --- a/neva/src/types/request/from_request.rs +++ b/neva/src/types/request/from_request.rs @@ -1,8 +1,8 @@ -//! Utilities for extraction params from Request +//! Utilities for extraction params from Request -use serde::de::DeserializeOwned; use crate::error::{Error, ErrorCode}; use crate::types::Request; +use serde::de::DeserializeOwned; /// A trait that helps the extract typed _params_ from request pub trait FromRequest: Sized { @@ -19,4 +19,4 @@ impl FromRequest for T { let params = serde_json::from_value(params)?; Ok(params) } -} \ No newline at end of file +} diff --git a/neva/src/types/request/request_id.rs b/neva/src/types/request/request_id.rs index 23a564a..80ae6f0 100644 --- a/neva/src/types/request/request_id.rs +++ b/neva/src/types/request/request_id.rs @@ -1,12 +1,12 @@ -//! Generic identity data structure for requests. +//! Generic identity data structure for requests. -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::shared::{ArcSlice, ArcStr, MemChr}; use crate::types::ProgressToken; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::{ - convert::Infallible, - fmt::{self, Display, Formatter}, - str::FromStr + convert::Infallible, + fmt::{self, Display, Formatter}, + str::FromStr, }; const SEPARATOR: u8 = b'/'; @@ -120,11 +120,13 @@ impl From<&RequestId> for ProgressToken { // Null ids only appear in error responses; they are never used // as progress tokens in practice. RequestId::Null => ProgressToken::String("null".into()), - RequestId::Slice(slice) => ProgressToken::Slice(slice - .iter() - .map(Into::into) - .collect::>() - .into()) + RequestId::Slice(slice) => ProgressToken::Slice( + slice + .iter() + .map(Into::into) + .collect::>() + .into(), + ), } } } @@ -199,7 +201,7 @@ impl serde::de::Visitor<'_> for RequestIdVisitor { } impl RequestId { - /// Consumes the current [`RequestId`], concatenates it with another one + /// Consumes the current [`RequestId`], concatenates it with another one /// and returns a new [`RequestId::Slice`] pub fn concat(self, request_id: RequestId) -> RequestId { let slice = [self, request_id]; @@ -235,16 +237,19 @@ mod tests { #[test] fn it_converts_slice_request_id_to_progress_token() { - let id = RequestId::Slice([ - RequestId::String("user".into()), - RequestId::Number(1) - ].into()); + let id = RequestId::Slice([RequestId::String("user".into()), RequestId::Number(1)].into()); let token = ProgressToken::from(&id); - assert_eq!(token, ProgressToken::Slice([ - ProgressToken::String("user".into()), - ProgressToken::Number(1) - ].into())); + assert_eq!( + token, + ProgressToken::Slice( + [ + ProgressToken::String("user".into()), + ProgressToken::Number(1) + ] + .into() + ) + ); } #[test] @@ -259,10 +264,8 @@ mod tests { #[test] fn it_serializes_and_deserializes_slice_through_str_request_id() { - let expected_id = RequestId::Slice([ - RequestId::String("user".into()), - RequestId::Number(1) - ].into()); + let expected_id = + RequestId::Slice([RequestId::String("user".into()), RequestId::Number(1)].into()); let json = serde_json::to_string(&expected_id).unwrap(); let new_id: RequestId = serde_json::from_str(&json).unwrap(); @@ -272,10 +275,8 @@ mod tests { #[test] fn it_serializes_and_deserializes_slice_through_value_request_id() { - let expected_id = RequestId::Slice([ - RequestId::String("user".into()), - RequestId::Number(1) - ].into()); + let expected_id = + RequestId::Slice([RequestId::String("user".into()), RequestId::Number(1)].into()); let json = serde_json::to_value(&expected_id).unwrap(); let new_id: RequestId = serde_json::from_value(json).unwrap(); diff --git a/neva/src/types/resource.rs b/neva/src/types/resource.rs index e2557ca..ab28741 100644 --- a/neva/src/types/resource.rs +++ b/neva/src/types/resource.rs @@ -1,78 +1,75 @@ -//! Represents an MCP resource +//! Represents an MCP resource -use serde::{Deserialize, Serialize}; -use crate::types::{Annotations, Cursor, Icon}; -use crate::types::request::RequestParamsMeta; +#[cfg(feature = "server")] +use crate::app::{ + context::Context, + handler::{FromHandlerParams, HandlerParams}, +}; #[cfg(feature = "server")] use crate::error::Error; #[cfg(feature = "server")] use crate::types::request::FromRequest; +use crate::types::request::RequestParamsMeta; +use crate::types::{Annotations, Cursor, Icon}; #[cfg(feature = "server")] -use crate::types::{RequestId, Response, IntoResponse, Request, Page}; -#[cfg(feature = "server")] -use crate::app::{context::Context, handler::{FromHandlerParams, HandlerParams}}; +use crate::types::{IntoResponse, Page, Request, RequestId, Response}; +use serde::{Deserialize, Serialize}; -pub use uri::Uri; pub use read_resource_result::{ - ReadResourceResult, - ResourceContents, - TextResourceContents, - BlobResourceContents, - JsonResourceContents, - EmptyResourceContents + BlobResourceContents, EmptyResourceContents, JsonResourceContents, ReadResourceResult, + ResourceContents, TextResourceContents, }; pub use template::{ - ResourceTemplate, - ListResourceTemplatesResult, - ListResourceTemplatesRequestParams + ListResourceTemplatesRequestParams, ListResourceTemplatesResult, ResourceTemplate, }; +pub use uri::Uri; #[cfg(all(feature = "server", feature = "di"))] -pub(crate) use from_request::{ResourceArgument, Payload, Source}; +pub(crate) use from_request::{Payload, ResourceArgument, Source}; #[cfg(feature = "server")] pub(crate) use route::Route; +#[cfg(feature = "server")] +mod from_request; mod read_resource_result; -mod uri; -pub(crate) mod template; #[cfg(feature = "server")] pub(crate) mod route; -#[cfg(feature = "server")] -mod from_request; +pub(crate) mod template; +mod uri; /// List of commands for Resources pub mod commands { /// Command name that returns a list of resources available on MCP server pub const LIST: &str = "resources/list"; - + /// Notification name that indicates that the list of resources has changed. pub const LIST_CHANGED: &str = "notifications/resources/list_changed"; - + /// Command name that returns a list of resource templates available on MCP server. pub const TEMPLATES_LIST: &str = "resources/templates/list"; - + /// Command name that returns the resource data pub const READ: &str = "resources/read"; - + /// Command name that subscribes to resource updates. pub const SUBSCRIBE: &str = "resources/subscribe"; - + /// Command name that unsubscribes from resource updates. pub const UNSUBSCRIBE: &str = "resources/unsubscribe"; - + /// Notification name that indicates that the resource has been updated. pub const UPDATED: &str = "notifications/resources/updated"; } /// Represents a known resource that the server is capable of reading. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Resource { /// The URI of this resource. pub uri: Uri, - + /// A human-readable name for this resource. pub name: String, @@ -118,7 +115,7 @@ pub struct Resource { } /// Sent from the client to request a list of resources the server has. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListResourcesRequestParams { @@ -129,11 +126,11 @@ pub struct ListResourcesRequestParams { } /// Sent from the client to the server to read a specific resource URI. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct ReadResourceRequestParams { - /// The URI of the resource to read. The URI can use any protocol; + /// The URI of the resource to read. The URI can use any protocol; /// it is up to the server how to interpret it. pub uri: Uri, @@ -143,15 +140,15 @@ pub struct ReadResourceRequestParams { /// > that are not part of the primary request parameters. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option, - + /// Path arguments extracted from [`Uri`] #[serde(skip)] #[cfg(feature = "server")] - pub(crate) args: Option> + pub(crate) args: Option>, } /// The server's response to a resources/list request from the client. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListResourcesResult { @@ -168,25 +165,25 @@ pub struct ListResourcesResult { pub next_cursor: Option, } -/// Sent from the client to request resources/updated notifications +/// Sent from the client to request resources/updated notifications /// from the server whenever a particular resource changes. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct SubscribeRequestParams { - /// The URI of the resource to subscribe to. + /// The URI of the resource to subscribe to. /// The URI can use any protocol; it is up to the server how to interpret it. pub uri: Uri, } -/// Sent from the client to request not receiving updated notifications +/// Sent from the client to request not receiving updated notifications /// from the server whenever a primitive resource changes. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Serialize, Deserialize)] pub struct UnsubscribeRequestParams { - /// The URI of the resource to unsubscribe from. - /// The URI can use any protocol; it is up to the server how to interpret it. + /// The URI of the resource to unsubscribe from. + /// The URI can use any protocol; it is up to the server how to interpret it. pub uri: Uri, } @@ -210,7 +207,7 @@ impl IntoResponse for ListResourcesResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -221,7 +218,7 @@ impl From<[Resource; N]> for ListResourcesResult { fn from(resources: [Resource; N]) -> Self { Self { next_cursor: None, - resources: resources.to_vec() + resources: resources.to_vec(), } } } @@ -232,7 +229,7 @@ impl From> for ListResourcesResult { fn from(resources: Vec) -> Self { Self { next_cursor: None, - resources + resources, } } } @@ -243,7 +240,7 @@ impl From> for ListResourcesResult { fn from(page: Page<'_, Resource>) -> Self { Self { next_cursor: page.next_cursor, - resources: page.items.to_vec() + resources: page.items.to_vec(), } } } @@ -305,7 +302,7 @@ impl From for Resource { annotations: None, meta: None, icons: None, - uri + uri, } } } @@ -334,7 +331,7 @@ impl From for ReadResourceRequestParams { meta: None, #[cfg(feature = "server")] args: None, - uri + uri, } } } @@ -346,8 +343,8 @@ impl ReadResourceRequestParams { self.meta.get_or_insert_default().context = Some(ctx); self } - - /// Includes path arguments extracted from [`Uri`] + + /// Includes path arguments extracted from [`Uri`] #[cfg(feature = "server")] pub(crate) fn with_args(mut self, args: Box<[String]>) -> Self { self.args = Some(args); @@ -360,11 +357,11 @@ impl Resource { /// Creates a new [`Resource`] #[inline] pub fn new, S: Into>(uri: U, name: S) -> Self { - Self { - uri: uri.into(), + Self { + uri: uri.into(), name: name.into(), title: None, - descr: None, + descr: None, mime: None, size: None, annotations: None, @@ -372,7 +369,7 @@ impl Resource { meta: None, } } - + /// Sets a name for a resource pub fn with_name(mut self, name: impl Into) -> Self { self.name = name.into(); @@ -400,12 +397,12 @@ impl Resource { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self } - + /// Sets a title for a resource pub fn with_title(mut self, title: impl Into) -> Self { self.title = Some(title.into()); @@ -420,6 +417,4 @@ impl Resource { } #[cfg(test)] -mod tests { - -} \ No newline at end of file +mod tests {} diff --git a/neva/src/types/resource/from_request.rs b/neva/src/types/resource/from_request.rs index c1f349e..b1a9ada 100644 --- a/neva/src/types/resource/from_request.rs +++ b/neva/src/types/resource/from_request.rs @@ -1,18 +1,18 @@ -use std::path::PathBuf; +use super::{ReadResourceRequestParams, Uri}; use crate::Context; use crate::error::{Error, ErrorCode}; -use crate::types::{Meta, ProgressToken}; use crate::types::request::RequestParamsMeta; -use super::{Uri, ReadResourceRequestParams}; +use crate::types::{Meta, ProgressToken}; +use std::path::PathBuf; /// Represents a payload that needs the type to be extracted from pub(crate) enum Payload<'a> { /// Resource URI Uri(&'a Uri), - + /// Resource URI part UriPart(String), - + /// Request metadata ("_meta") Meta(&'a Option), } @@ -118,7 +118,7 @@ macro_rules! impl_resource_argument { } impl_resource_argument! { - bool, + bool, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, @@ -151,7 +151,10 @@ impl ResourceArgument for Meta { let meta = payload.expect_meta(); meta.as_ref() .and_then(|meta| meta.progress_token.clone()) - .ok_or(Error::new(ErrorCode::InvalidParams, "Missing progress token")) + .ok_or(Error::new( + ErrorCode::InvalidParams, + "Missing progress token", + )) .map(Meta) } @@ -182,14 +185,15 @@ impl ResourceArgument for Context { pub(crate) fn extract_arg>( uri: &Uri, meta: &Option, - iter: &mut impl Iterator + iter: &mut impl Iterator, ) -> Result { match T::source() { Source::Meta => T::extract(Payload::Meta(meta)), Source::Uri => T::extract(Payload::Uri(uri)), - Source::UriPart => T::extract(Payload::UriPart(iter - .next() - .ok_or(Error::new(ErrorCode::InvalidParams, "Invalid URI param provided"))?)) + Source::UriPart => T::extract(Payload::UriPart(iter.next().ok_or(Error::new( + ErrorCode::InvalidParams, + "Invalid URI param provided", + ))?)), } } @@ -197,15 +201,15 @@ macro_rules! impl_from_read_resource_params { ($($T: ident),*) => { impl<$($T: ResourceArgument),+> TryFrom for ($($T,)+) { type Error = Error; - + #[inline] fn try_from(params: ReadResourceRequestParams) -> Result { let uri = params.uri; let mut iter = params.args.into_iter().flatten(); let tuple = ( $( - extract_arg::<$T>(&uri, ¶ms.meta, &mut iter)?, - )* + extract_arg::<$T>(&uri, ¶ms.meta, &mut iter)?, + )* ); Ok(tuple) } @@ -217,4 +221,4 @@ impl_from_read_resource_params! { T1 } impl_from_read_resource_params! { T1, T2 } impl_from_read_resource_params! { T1, T2, T3 } impl_from_read_resource_params! { T1, T2, T3, T4 } -impl_from_read_resource_params! { T1, T2, T3, T4, T5 } \ No newline at end of file +impl_from_read_resource_params! { T1, T2, T3, T4, T5 } diff --git a/neva/src/types/resource/read_resource_result.rs b/neva/src/types/resource/read_resource_result.rs index e4159f6..94073dd 100644 --- a/neva/src/types/resource/read_resource_result.rs +++ b/neva/src/types/resource/read_resource_result.rs @@ -1,18 +1,19 @@ -//! Types and utils for handling read resource results +//! Types and utils for handling read resource results -use serde::{Deserialize, Serialize}; -use bytes::Bytes; -use crate::types::{Annotations, Uri}; use crate::types::helpers::{ - deserialize_base64_as_bytes, - serialize_bytes_as_base64, - deserialize_value_from_string, - serialize_value_as_string + deserialize_base64_as_bytes, deserialize_value_from_string, serialize_bytes_as_base64, + serialize_value_as_string, }; +use crate::types::{Annotations, Uri}; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; #[cfg(feature = "server")] use { - crate::{error::{Error, ErrorCode}, types::{IntoResponse, RequestId, Response}}, - serde::de::DeserializeOwned + crate::{ + error::{Error, ErrorCode}, + types::{IntoResponse, RequestId, Response}, + }, + serde::de::DeserializeOwned, }; const CHUNK_SIZE: usize = 8192; @@ -23,7 +24,7 @@ const CHUNK_SIZE: usize = 8192; #[derive(Debug, Serialize, Deserialize)] pub struct ReadResourceResult { /// A list of ResourceContents that this resource contains. - pub contents: Vec + pub contents: Vec, } /// Represents the content of a resource. @@ -34,19 +35,19 @@ pub struct ReadResourceResult { pub enum ResourceContents { /// Represents a text resource content Text(TextResourceContents), - + /// Represents a JSON resource content Json(JsonResourceContents), - + /// Represents a blob resource content Blob(BlobResourceContents), - + /// Represents an empty/unknown resource content Empty(EmptyResourceContents), } /// Represents a blob resource content -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BlobResourceContents { @@ -58,7 +59,8 @@ pub struct BlobResourceContents { /// **Note:** will be serialized as a base64-encoded string #[serde( serialize_with = "serialize_bytes_as_base64", - deserialize_with = "deserialize_base64_as_bytes")] + deserialize_with = "deserialize_base64_as_bytes" + )] pub blob: Bytes, /// Intended for UI and end-user contexts - optimized to be human-readable and easily understood, @@ -83,7 +85,7 @@ pub struct BlobResourceContents { } /// Represents a text resource content -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TextResourceContents { @@ -115,7 +117,7 @@ pub struct TextResourceContents { } /// Represents a JSON resource content -/// +/// /// > **Note:** This is a specialization of [`TextResourceContents`] for JSON content. /// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details @@ -186,7 +188,7 @@ impl IntoResponse for ReadResourceResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -195,7 +197,7 @@ impl Default for ReadResourceResult { #[inline] fn default() -> Self { Self { - contents: Vec::with_capacity(8) + contents: Vec::with_capacity(8), } } } @@ -222,10 +224,10 @@ impl From for ResourceContents { } #[cfg(feature = "server")] -impl From<(T1, T2)> for TextResourceContents -where +impl From<(T1, T2)> for TextResourceContents +where T1: Into, - T2: Into + T2: Into, { #[inline] fn from((uri, text): (T1, T2)) -> Self { @@ -245,7 +247,7 @@ impl From<(T1, T2, T3)> for TextResourceContents where T1: Into, T2: Into, - T3: Into + T3: Into, { #[inline] fn from((uri, mime, text): (T1, T2, T3)) -> Self { @@ -264,7 +266,7 @@ where impl From<(T1, T2)> for ResourceContents where T1: Into, - T2: Into + T2: Into, { #[inline] fn from(pair: (T1, T2)) -> Self { @@ -277,7 +279,7 @@ impl From<(T1, T2, T3)> for ResourceContents where T1: Into, T2: Into, - T3: Into + T3: Into, { #[inline] fn from(triplet: (T1, T2, T3)) -> Self { @@ -287,20 +289,22 @@ where #[cfg(feature = "server")] impl From for ReadResourceResult -where - T: Into +where + T: Into, { #[inline] fn from(content: T) -> Self { - Self { contents: vec![content.into()] } + Self { + contents: vec![content.into()], + } } } #[cfg(feature = "server")] impl TryFrom> for ReadResourceResult -where +where T: Into, - E: Into + E: Into, { type Error = E; @@ -308,39 +312,33 @@ where fn try_from(value: Result) -> Result { match value { Ok(ok) => Ok(ok.into()), - Err(err) => Err(err) + Err(err) => Err(err), } } } #[cfg(feature = "server")] impl From> for ReadResourceResult -where - T: Into +where + T: Into, { #[inline] fn from(vec: Vec) -> Self { Self { - contents: vec - .into_iter() - .map(Into::into) - .collect(), + contents: vec.into_iter().map(Into::into).collect(), } } } #[cfg(feature = "server")] impl From<[T; N]> for ReadResourceResult -where - T: Into +where + T: Into, { #[inline] fn from(vec: [T; N]) -> Self { Self { - contents: vec - .into_iter() - .map(Into::into) - .collect(), + contents: vec.into_iter().map(Into::into).collect(), } } } @@ -350,24 +348,23 @@ impl ReadResourceResult { /// Creates a new read resource result #[inline] pub fn new() -> Self { - Self::default() + Self::default() } - + /// Add a resource content to the result #[inline] pub fn with_content(mut self, content: impl Into) -> Self { self.contents.push(content.into()); self } - + /// Add multiple resource contents to the result pub fn with_contents(mut self, contents: T) -> Self - where + where T: IntoIterator, - I: Into + I: Into, { - self.contents - .extend(contents.into_iter().map(Into::into)); + self.contents.extend(contents.into_iter().map(Into::into)); self } } @@ -379,7 +376,7 @@ impl ResourceContents { pub fn new(uri: impl Into) -> Self { Self::Empty(EmptyResourceContents::new(uri)) } - + /// Returns the URI of the resource content #[inline] pub fn uri(&self) -> &Uri { @@ -387,7 +384,7 @@ impl ResourceContents { Self::Text(text) => &text.uri, Self::Json(json) => &json.uri, Self::Blob(blob) => &blob.uri, - Self::Empty(empty) => &empty.uri + Self::Empty(empty) => &empty.uri, } } @@ -398,7 +395,7 @@ impl ResourceContents { Self::Text(text) => Some(&text.text), Self::Json(json) => json.value.as_str(), Self::Blob(_) => None, - Self::Empty(_) => None + Self::Empty(_) => None, } } @@ -409,7 +406,7 @@ impl ResourceContents { Self::Text(text) => text.title.as_deref(), Self::Json(json) => json.title.as_deref(), Self::Blob(blob) => blob.title.as_deref(), - Self::Empty(empty) => empty.title.as_deref() + Self::Empty(empty) => empty.title.as_deref(), } } @@ -420,7 +417,7 @@ impl ResourceContents { Self::Text(text) => text.annotations.as_ref(), Self::Json(json) => json.annotations.as_ref(), Self::Blob(blob) => blob.annotations.as_ref(), - Self::Empty(empty) => empty.annotations.as_ref() + Self::Empty(empty) => empty.annotations.as_ref(), } } @@ -431,7 +428,7 @@ impl ResourceContents { Self::Blob(blob) => Some(&blob.blob), Self::Json(_) => None, Self::Text(_) => None, - Self::Empty(_) => None + Self::Empty(_) => None, } } @@ -439,26 +436,30 @@ impl ResourceContents { #[inline] pub fn json(&self) -> Result { match self { - Self::Text(text) => serde_json::from_str(&text.text) - .map_err(Error::from), - Self::Json(json) => serde_json::from_value(json.value.clone()) - .map_err(Error::from), - Self::Blob(_) => Err(Error::new(ErrorCode::InvalidRequest, "Cannot deserialize blob")), - Self::Empty(_) => Err(Error::new(ErrorCode::InvalidRequest, "Cannot empty resource")) + Self::Text(text) => serde_json::from_str(&text.text).map_err(Error::from), + Self::Json(json) => serde_json::from_value(json.value.clone()).map_err(Error::from), + Self::Blob(_) => Err(Error::new( + ErrorCode::InvalidRequest, + "Cannot deserialize blob", + )), + Self::Empty(_) => Err(Error::new( + ErrorCode::InvalidRequest, + "Cannot empty resource", + )), } } /// Returns the mime type of the resource content #[inline] pub fn mime(&self) -> Option<&str> { - match self { + match self { Self::Text(text) => text.mime.as_deref(), Self::Json(json) => json.mime.as_deref(), Self::Blob(blob) => blob.mime.as_deref(), - Self::Empty(empty) => empty.mime.as_deref() + Self::Empty(empty) => empty.mime.as_deref(), } } - + /// Sets the mime type of the resource content #[inline] pub fn with_mime(mut self, mime: impl Into) -> Self { @@ -487,7 +488,7 @@ impl ResourceContents { #[inline] pub fn with_annotations(self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { match self { Self::Text(text) => Self::Text(text.with_annotations(config)), @@ -496,7 +497,7 @@ impl ResourceContents { Self::Empty(empty) => Self::Empty(empty.with_annotations(config)), } } - + /// Sets the text of the resource content and make it [`TextResourceContents`] #[inline] pub fn with_text(self, text: impl Into) -> Self { @@ -533,7 +534,7 @@ impl ResourceContents { annotations: content.annotations, meta: None, text, - }) + }), } } @@ -573,15 +574,14 @@ impl ResourceContents { annotations: content.annotations, meta: None, blob, - }) + }), } } /// Sets the JSON text of the resource content and make it [`TextResourceContents`] #[inline] pub fn with_json(self, data: T) -> Self { - let value = serde_json::to_value(data) - .expect("Failed to serialize JSON"); + let value = serde_json::to_value(data).expect("Failed to serialize JSON"); match self { Self::Text(content) => Self::Json(JsonResourceContents { uri: content.uri, @@ -614,7 +614,7 @@ impl ResourceContents { annotations: content.annotations, meta: None, value, - }) + }), } } } @@ -639,7 +639,7 @@ impl TextResourceContents { self.mime = Some(mime.into()); self } - + /// Sets the title of the resource #[inline] pub fn with_title(mut self, title: impl Into) -> Self { @@ -650,7 +650,7 @@ impl TextResourceContents { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -688,7 +688,7 @@ impl JsonResourceContents { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -726,7 +726,7 @@ impl BlobResourceContents { /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -775,11 +775,11 @@ impl EmptyResourceContents { self.title = Some(title.into()); self } - + /// Sets annotations for the client pub fn with_annotations(mut self, config: F) -> Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -789,15 +789,15 @@ impl EmptyResourceContents { #[cfg(test)] #[cfg(feature = "server")] mod tests { - use futures_util::StreamExt; use super::*; - + use futures_util::StreamExt; + #[derive(Serialize, Deserialize, Debug, PartialEq)] struct User { name: String, age: u8, } - + #[test] fn it_creates_result_from_array_of_contents() { let result = ReadResourceResult::from([ @@ -806,60 +806,83 @@ mod tests { .with_text("test 1"), ResourceContents::new("/res1") .with_mime("plain/text") - .with_text("test 1") + .with_text("test 1"), ]); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"plain/text"},{"uri":"/res1","text":"test 1","mimeType":"plain/text"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"plain/text"},{"uri":"/res1","text":"test 1","mimeType":"plain/text"}]}"# + ); } #[test] fn it_creates_result_from_array_of_str_tuples1() { let result = ReadResourceResult::from([ TextResourceContents::new("/res1", "test 1"), - TextResourceContents::new("/res1", "test 1") + TextResourceContents::new("/res1", "test 1"), ]); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"text/plain"},{"uri":"/res1","text":"test 1","mimeType":"text/plain"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"text/plain"},{"uri":"/res1","text":"test 1","mimeType":"text/plain"}]}"# + ); } #[test] fn it_creates_result_from_array_of_str_tuples2() { let result = ReadResourceResult::from([ TextResourceContents::from(("/res1", "json", "test 1")), - TextResourceContents::from(("/res1", "json", "test 1")) + TextResourceContents::from(("/res1", "json", "test 1")), ]); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"json"},{"uri":"/res1","text":"test 1","mimeType":"json"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"json"},{"uri":"/res1","text":"test 1","mimeType":"json"}]}"# + ); } #[test] fn it_creates_result_from_array_of_string_tuples1() { let result = ReadResourceResult::from([ (String::from("/res1"), String::from("test 1")), - (String::from("/res1"), String::from("test 1")) + (String::from("/res1"), String::from("test 1")), ]); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"text/plain"},{"uri":"/res1","text":"test 1","mimeType":"text/plain"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"text/plain"},{"uri":"/res1","text":"test 1","mimeType":"text/plain"}]}"# + ); } #[test] fn it_creates_result_from_array_of_string_tuples2() { let result = ReadResourceResult::from([ - (String::from("/res1"), String::from("json"), String::from("test 1")), - (String::from("/res1"), String::from("json"), String::from("test 1")) + ( + String::from("/res1"), + String::from("json"), + String::from("test 1"), + ), + ( + String::from("/res1"), + String::from("json"), + String::from("test 1"), + ), ]); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"json"},{"uri":"/res1","text":"test 1","mimeType":"json"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res1","text":"test 1","mimeType":"json"},{"uri":"/res1","text":"test 1","mimeType":"json"}]}"# + ); } #[test] @@ -869,19 +892,27 @@ mod tests { let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res","text":"test","mimeType":"text/plain"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res","text":"test","mimeType":"text/plain"}]}"# + ); } #[test] fn it_creates_result_from_objet_content() { - let content = ResourceContents::new("/res") - .with_json(User { name: "John".into(), age: 33 }); - + let content = ResourceContents::new("/res").with_json(User { + name: "John".into(), + age: 33, + }); + let result: ReadResourceResult = content.into(); let json = serde_json::to_string(&result).unwrap(); - assert_eq!(json, r#"{"contents":[{"uri":"/res","text":"{\"age\":33,\"name\":\"John\"}","mimeType":"application/json"}]}"#); + assert_eq!( + json, + r#"{"contents":[{"uri":"/res","text":"{\"age\":33,\"name\":\"John\"}","mimeType":"application/json"}]}"# + ); } #[test] @@ -889,7 +920,8 @@ mod tests { let blob = BlobResourceContents::new("file://hello", "hello world"); let json = serde_json::to_string(&blob).expect("Should serialize"); - let deserialized: BlobResourceContents = serde_json::from_str(&json).expect("Should deserialize"); + let deserialized: BlobResourceContents = + serde_json::from_str(&json).expect("Should deserialize"); assert_eq!(blob.blob, deserialized.blob); assert_eq!(blob.mime, deserialized.mime); @@ -931,7 +963,10 @@ mod tests { let result_string = String::from_utf8(collected_data).expect("Should be valid UTF-8"); assert_eq!(result_string, test_data); - assert!(chunk_count > 1, "Should have multiple chunks for large data"); + assert!( + chunk_count > 1, + "Should have multiple chunks for large data" + ); } #[tokio::test] @@ -963,6 +998,9 @@ mod tests { let result_string = String::from_utf8(collected_data).expect("Should be valid UTF-8"); assert_eq!(result_string, test_data); - assert_eq!(chunk_count, 1, "Exactly CHUNK_SIZE data should produce one chunk"); + assert_eq!( + chunk_count, 1, + "Exactly CHUNK_SIZE data should produce one chunk" + ); } -} \ No newline at end of file +} diff --git a/neva/src/types/resource/route.rs b/neva/src/types/resource/route.rs index 9be995d..ad85331 100644 --- a/neva/src/types/resource/route.rs +++ b/neva/src/types/resource/route.rs @@ -1,4 +1,4 @@ -//! A set of route handling tools +//! A set of route handling tools use super::ReadResourceResult; use crate::app::handler::RequestHandler; @@ -11,21 +11,21 @@ const CLOSE_BRACKET: char = '}'; /// Represents route path node pub(super) struct RouteNode { path: Box, - node: Box + node: Box, } /// A data structure for easy insert and search handler by route template pub(crate) struct Route { static_routes: Vec, dynamic_route: Option, - handler: Option + handler: Option, } /// A handler function for a resource route pub(crate) struct ResourceHandler { #[cfg(feature = "http-server")] pub(crate) template: String, - handler: RequestHandler + handler: RequestHandler, } impl RouteNode { @@ -34,16 +34,14 @@ impl RouteNode { fn new(path: &str) -> Self { Self { node: Box::new(Route::new()), - path: path.into() + path: path.into(), } } /// Compares two route entries #[inline(always)] fn cmp(&self, path: &str) -> std::cmp::Ordering { - self.path - .as_ref() - .cmp(path) + self.path.as_ref().cmp(path) } } @@ -73,17 +71,16 @@ impl Route { dynamic_route: None, } } - + /// Inserts a route handler pub(crate) fn insert( &mut self, path: &Uri, _template: String, - handler: RequestHandler + handler: RequestHandler, ) { let mut current = self; - let path_segments = path.parts() - .expect("URI parts should be present"); + let path_segments = path.parts().expect("URI parts should be present"); for segment in path_segments { if is_dynamic_segment(segment) { @@ -121,7 +118,8 @@ impl Route { return None; } - current.handler + current + .handler .as_ref() .map(|h| (h, params.into_boxed_slice())) } @@ -139,8 +137,7 @@ impl Route { #[inline(always)] fn insert_dynamic_node(&mut self, segment: &str) -> &mut Self { - self - .dynamic_route + self.dynamic_route .get_or_insert_with(|| RouteNode::new(segment)) .node .as_mut() @@ -149,16 +146,15 @@ impl Route { #[inline(always)] fn is_dynamic_segment(segment: &str) -> bool { - segment.starts_with(OPEN_BRACKET) && - segment.ends_with(CLOSE_BRACKET) + segment.starts_with(OPEN_BRACKET) && segment.ends_with(CLOSE_BRACKET) } #[cfg(test)] mod tests { + use super::*; use crate::types::resource::template::ResourceFunc; use crate::types::{ResourceContents, Uri}; - use super::*; - + #[test] fn it_inserts_and_finds() { let uri1: Uri = "res://path/to/{resource}".into(); @@ -174,12 +170,12 @@ mod tests { .with_mime("text/plain") .with_text("some text 2") }); - + let mut route = Route::default(); route.insert(&uri1, "templ_1".into(), handler1); route.insert(&uri2, "templ_2".into(), handler2); - + assert!(route.find(&uri1).is_some()); assert!(route.find(&uri2).is_some()); } -} \ No newline at end of file +} diff --git a/neva/src/types/resource/template.rs b/neva/src/types/resource/template.rs index 97fe3df..f48f774 100644 --- a/neva/src/types/resource/template.rs +++ b/neva/src/types/resource/template.rs @@ -1,40 +1,33 @@ -//! Utilities for Resource templates +//! Utilities for Resource templates -use std::fmt::Debug; -use serde::{Deserialize, Serialize}; #[cfg(feature = "server")] -use std::sync::Arc; +use crate::app::handler::{FromHandlerParams, GenericHandler, Handler, HandlerParams}; +#[cfg(feature = "server")] +use crate::error::Error; #[cfg(feature = "server")] use futures_util::future::BoxFuture; +use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::fmt::Debug; #[cfg(feature = "server")] -use crate::error::Error; -#[cfg(feature = "server")] -use crate::app::handler::{ - FromHandlerParams, - GenericHandler, - Handler, - HandlerParams -}; +use std::sync::Arc; use crate::types::{ - resource::Uri, Annotations, IntoResponse, - RequestId, Response, - Cursor, Page, Icon + Annotations, Cursor, Icon, IntoResponse, Page, RequestId, Response, resource::Uri, }; #[cfg(feature = "server")] use crate::types::{FromRequest, ReadResourceRequestParams, ReadResourceResult, Request}; /// Represents a known resource template that the server is capable of reading. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Clone, Serialize, Deserialize)] pub struct ResourceTemplate { /// The URI template that identifies this resource template. #[serde(rename = "uriTemplate")] pub uri_template: Uri, - + /// A human-readable name for this resource template. pub name: String, @@ -72,7 +65,7 @@ pub struct ResourceTemplate { /// Metadata reserved by MCP for protocol-level metadata. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option, - + /// A list of roles that are allowed to read the resource #[serde(skip)] #[cfg(feature = "http-server")] @@ -96,7 +89,7 @@ pub struct ListResourceTemplatesRequestParams { } /// The server's response to a resources/templates/list request from the client. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListResourceTemplatesResult { @@ -119,7 +112,7 @@ impl IntoResponse for ListResourceTemplatesResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -127,9 +120,9 @@ impl IntoResponse for ListResourceTemplatesResult { impl From> for ListResourceTemplatesResult { #[inline] fn from(templates: Vec) -> Self { - Self { + Self { next_cursor: None, - templates + templates, } } } @@ -139,7 +132,7 @@ impl From> for ListResourceTemplatesResult { fn from(page: Page<'_, ResourceTemplate>) -> Self { Self { next_cursor: page.next_cursor, - templates: page.items.to_vec() + templates: page.items.to_vec(), } } } @@ -167,28 +160,31 @@ pub(crate) struct ResourceFunc where F: GenericHandler, R: TryInto, - Args: TryFrom + Args: TryFrom, { func: F, _marker: std::marker::PhantomData, } #[cfg(feature = "server")] -impl ResourceFunc +impl ResourceFunc where F: GenericHandler, R: TryInto, - Args: TryFrom + Args: TryFrom, { /// Creates a new [`ResourceFunc`] wrapped into [`Arc`] pub(crate) fn new(func: F) -> Arc { - let func = Self { func, _marker: std::marker::PhantomData }; + let func = Self { + func, + _marker: std::marker::PhantomData, + }; Arc::new(func) } } #[cfg(feature = "server")] -impl Handler for ResourceFunc +impl Handler for ResourceFunc where F: GenericHandler, R: TryInto, @@ -203,11 +199,7 @@ where Box::pin(async move { //let mut iter = params.args.into_iter().flatten().next(); let args = Args::try_from(params)?; - self.func - .call(args) - .await - .try_into() - .map_err(Into::into) + self.func.call(args).await.try_into().map_err(Into::into) }) } } @@ -247,7 +239,7 @@ impl ResourceTemplate { permissions: None, } } - + /// Sets a title for a resource template pub fn with_title(&mut self, title: impl Into) -> &mut Self { self.title = Some(title.into()); @@ -265,11 +257,11 @@ impl ResourceTemplate { self.mime = Some(mime.into()); self } - + /// Sets annotations for the resource template pub fn with_annotations(&mut self, config: F) -> &mut Self where - F: FnOnce(Annotations) -> Annotations + F: FnOnce(Annotations) -> Annotations, { self.annotations = Some(config(Default::default())); self @@ -280,12 +272,9 @@ impl ResourceTemplate { pub fn with_roles(&mut self, roles: T) -> &mut Self where T: IntoIterator, - I: Into + I: Into, { - self.roles = Some(roles - .into_iter() - .map(Into::into) - .collect()); + self.roles = Some(roles.into_iter().map(Into::into).collect()); self } @@ -294,12 +283,9 @@ impl ResourceTemplate { pub fn with_permissions(&mut self, permissions: T) -> &mut Self where T: IntoIterator, - I: Into + I: Into, { - self.permissions = Some(permissions - .into_iter() - .map(Into::into) - .collect()); + self.permissions = Some(permissions.into_iter().map(Into::into).collect()); self } @@ -311,6 +297,4 @@ impl ResourceTemplate { } #[cfg(test)] -mod tests { - -} +mod tests {} diff --git a/neva/src/types/resource/uri.rs b/neva/src/types/resource/uri.rs index 2654023..cd85a00 100644 --- a/neva/src/types/resource/uri.rs +++ b/neva/src/types/resource/uri.rs @@ -1,9 +1,9 @@ -//! URI helpers and utilities +//! URI helpers and utilities -use serde::{Serialize, Deserialize}; -use std::ops::{Deref, DerefMut}; -use std::fmt::{Display, Formatter}; use crate::shared::MemChr; +use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; +use std::ops::{Deref, DerefMut}; const PATH_SEPARATOR: char = '/'; const SCHEME_SEPARATOR: [u8; 3] = [b':', b'/', b'/']; @@ -55,13 +55,13 @@ impl Uri { pub fn into_inner(self) -> String { self.0 } - + /// Splits the URI into scheme and path parts - /// + /// /// # Example /// ```rust /// use neva::types::Uri; - /// + /// /// let uri = Uri::from("res://test1/test2"); /// /// assert_eq!(uri.parts().unwrap().collect::>(), ["res", "test1", "test2"]); @@ -73,27 +73,29 @@ impl Uri { rest = rest.trim_start_matches(PATH_SEPARATOR); - Some(std::iter::once(scheme) - .chain(MemChr::split(rest, PATH_SEPARATOR as u8))) + Some(std::iter::once(scheme).chain(MemChr::split(rest, PATH_SEPARATOR as u8))) } } #[cfg(test)] mod test { use super::*; - + #[test] fn it_converts_from_str() { let uri = Uri::from("res://test1"); - + assert_eq!(uri.to_string(), "res://test1"); } #[test] fn it_splits_scheme_and_path() { let uri = Uri::from("res://test1/test2"); - - assert_eq!(uri.parts().unwrap().collect::>(), ["res", "test1", "test2"]); + + assert_eq!( + uri.parts().unwrap().collect::>(), + ["res", "test1", "test2"] + ); } #[test] @@ -107,20 +109,29 @@ mod test { fn it_splits_scheme_and_path_with_double_slash() { let uri = Uri::from("res://test1//test2"); - assert_eq!(uri.parts().unwrap().collect::>(), ["res", "test1", "test2"]); + assert_eq!( + uri.parts().unwrap().collect::>(), + ["res", "test1", "test2"] + ); } #[test] fn it_splits_scheme_and_path_with_trailing_slash() { let uri = Uri::from("res://test1/test2/"); - assert_eq!(uri.parts().unwrap().collect::>(), ["res", "test1", "test2"]); + assert_eq!( + uri.parts().unwrap().collect::>(), + ["res", "test1", "test2"] + ); } #[test] fn it_splits_scheme_and_path_with_leading_slash() { let uri = Uri::from("res:///test1/test2"); - assert_eq!(uri.parts().unwrap().collect::>(), ["res", "test1", "test2"]); + assert_eq!( + uri.parts().unwrap().collect::>(), + ["res", "test1", "test2"] + ); } -} \ No newline at end of file +} diff --git a/neva/src/types/response.rs b/neva/src/types/response.rs index 777f8ef..c35055c 100644 --- a/neva/src/types/response.rs +++ b/neva/src/types/response.rs @@ -1,10 +1,10 @@ -//! Represents a response that MCP server provides +//! Represents a response that MCP server provides use crate::error::Error; -use serde::{Deserialize, Serialize}; +use crate::types::{JSONRPC_VERSION, Message, RequestId}; use serde::de::DeserializeOwned; -use serde_json::{json, Value}; -use crate::types::{RequestId, Message, JSONRPC_VERSION}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; #[cfg(feature = "http-server")] use volga::headers::HeaderMap; @@ -21,23 +21,23 @@ mod into_response; pub enum Response { /// A successful response. Ok(OkResponse), - + /// A response that indicates an error occurred. - Err(ErrorResponse) + Err(ErrorResponse), } /// A successful response message in the JSON-RPC protocol. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OkResponse { - /// JSON-RPC protocol version. - /// + /// JSON-RPC protocol version. + /// /// > Note: always 2.0. pub jsonrpc: String, - + /// Request identifier matching the original request. #[serde(default)] pub id: RequestId, - + /// The result of the method invocation. pub result: Value, @@ -48,13 +48,13 @@ pub struct OkResponse { /// HTTP headers #[serde(skip)] #[cfg(feature = "http-server")] - pub headers: HeaderMap + pub headers: HeaderMap, } /// A response to a request that indicates an error occurred. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ErrorResponse { - /// JSON-RPC protocol version. + /// JSON-RPC protocol version. /// /// > Note: always 2.0. pub jsonrpc: String, @@ -73,8 +73,8 @@ pub struct ErrorResponse { /// HTTP headers #[serde(skip)] #[cfg(feature = "http-server")] - pub headers: HeaderMap -} + pub headers: HeaderMap, +} impl From for Message { #[inline] @@ -92,7 +92,7 @@ impl Response { #[cfg(feature = "http-server")] headers: HeaderMap::with_capacity(8), id, - result + result, }) } @@ -104,7 +104,7 @@ impl Response { #[cfg(feature = "http-server")] headers: HeaderMap::new(), id, - result: json!({}) + result: json!({}), }) } @@ -119,15 +119,15 @@ impl Response { error: error.into(), }) } - + /// Returns [`Response`] ID pub fn id(&self) -> &RequestId { match &self { Response::Ok(ok) => &ok.id, - Response::Err(err) => &err.id + Response::Err(err) => &err.id, } } - + /// Returns the full id (session_id?/response_id) pub fn full_id(&self) -> RequestId { let id = self.id().clone(); @@ -137,12 +137,12 @@ impl Response { id } } - + /// Set the `id` for the response pub fn set_id(mut self, id: RequestId) -> Self { match &mut self { Response::Ok(ok) => ok.id = id, - Response::Err(err) => err.id = id + Response::Err(err) => err.id = id, } self } @@ -160,7 +160,7 @@ impl Response { pub fn set_session_id(mut self, id: uuid::Uuid) -> Self { match &mut self { Response::Ok(ok) => ok.session_id = Some(id), - Response::Err(err) => err.session_id = Some(id) + Response::Err(err) => err.session_id = Some(id), } self } @@ -170,33 +170,31 @@ impl Response { pub fn set_headers(mut self, headers: HeaderMap) -> Self { match &mut self { Response::Ok(ok) => ok.headers = headers, - Response::Err(err) => err.headers = headers + Response::Err(err) => err.headers = headers, } self } - + /// Unwraps the [`Response`] into either result of `T` or [`Error`] pub fn into_result(self) -> Result { match self { Response::Ok(ok) => serde_json::from_value::(ok.result).map_err(Into::into), - Response::Err(err) => Err(err.error.into()) + Response::Err(err) => Err(err.error.into()), } } } #[cfg(test)] mod tests { - use crate::{error::Error, types::RequestId}; use super::Response; + use crate::{error::Error, types::RequestId}; #[test] fn it_deserializes_successful_response_with_int_id_to_json() { - let resp = Response::success( - RequestId::Number(42), - serde_json::json!({ "key": "test" })); - + let resp = Response::success(RequestId::Number(42), serde_json::json!({ "key": "test" })); + let json = serde_json::to_string(&resp).unwrap(); - + assert_eq!(json, r#"{"jsonrpc":"2.0","id":42,"result":{"key":"test"}}"#); } @@ -204,10 +202,14 @@ mod tests { fn it_deserializes_error_response_with_string_id_to_json() { let resp = Response::error( RequestId::String("id".into()), - Error::new(-32603, "some error message")); + Error::new(-32603, "some error message"), + ); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"id","error":{"code":-32603,"message":"some error message","data":null}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"id","error":{"code":-32603,"message":"some error message","data":null}}"# + ); } } diff --git a/neva/src/types/response/error_details.rs b/neva/src/types/response/error_details.rs index b265de3..25fb6f7 100644 --- a/neva/src/types/response/error_details.rs +++ b/neva/src/types/response/error_details.rs @@ -1,7 +1,7 @@ //! Represents error details utils for JSON-RPC responses -use serde::{Deserialize, Serialize}; use crate::error::{Error, ErrorCode}; +use serde::{Deserialize, Serialize}; /// Detailed error information #[derive(Debug, Clone, Serialize, Deserialize)] @@ -13,7 +13,7 @@ pub struct ErrorDetails { pub message: String, /// Optional additional error data. - pub data: Option + pub data: Option, } impl Default for ErrorDetails { @@ -22,7 +22,7 @@ impl Default for ErrorDetails { Self { code: ErrorCode::InternalError, message: "Unknown error".into(), - data: None + data: None, } } } @@ -30,10 +30,10 @@ impl Default for ErrorDetails { impl From for ErrorDetails { #[inline] fn from(err: Error) -> Self { - Self { - code: err.code, - message: err.to_string(), - data: None + Self { + code: err.code, + message: err.to_string(), + data: None, } } } @@ -49,10 +49,10 @@ impl ErrorDetails { /// Creates a new [`ErrorDetails`] #[inline] pub fn new(err: impl Into) -> Self { - Self { - code: ErrorCode::InternalError, - message: err.into(), - data: None + Self { + code: ErrorCode::InternalError, + message: err.into(), + data: None, } } -} \ No newline at end of file +} diff --git a/neva/src/types/response/into_response.rs b/neva/src/types/response/into_response.rs index 02d9a31..810bb17 100644 --- a/neva/src/types/response/into_response.rs +++ b/neva/src/types/response/into_response.rs @@ -1,12 +1,8 @@ -//! Tools for converting any type into MCP server response +//! Tools for converting any type into MCP server response -use serde::Serialize; use crate::error::{Error, ErrorCode}; -use crate::types::{ - RequestId, - Response, - Json -}; +use crate::types::{Json, RequestId, Response}; +use serde::Serialize; /// A trait for converting any return type into MCP response pub trait IntoResponse { @@ -48,7 +44,7 @@ impl IntoResponse for Json { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -68,13 +64,13 @@ impl IntoResponse for () { } impl IntoResponse for Result -where +where T: IntoResponse, - E: IntoResponse + E: IntoResponse, { #[inline] fn into_response(self, req_id: RequestId) -> Response { - match self { + match self { Ok(value) => value.into_response(req_id), Err(err) => err.into_response(req_id), } @@ -103,14 +99,17 @@ impl_into_response! { #[cfg(test)] mod tests { use super::*; - + #[test] fn it_converts_str_into_response() { let resp = "test".into_response(RequestId::default()); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":"test"}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":"test"}}"# + ); } #[test] @@ -119,7 +118,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":"test"}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":"test"}}"# + ); } #[test] @@ -128,7 +130,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] fn it_converts_i16_into_response() { @@ -136,7 +141,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -145,7 +153,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -154,7 +165,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -163,7 +177,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -172,7 +189,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -181,7 +201,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -190,7 +213,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -199,7 +225,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -208,7 +237,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -217,7 +249,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -226,7 +261,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1}}"# + ); } #[test] @@ -235,7 +273,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1.5}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1.5}}"# + ); } #[test] @@ -244,26 +285,37 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1.5}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":1.5}}"# + ); } - + #[test] fn it_converts_bool_into_response() { let resp = true.into_response(RequestId::default()); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":true}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"result":true}}"# + ); } #[test] fn it_converts_json_into_response() { - let json = Json::from(Test { name: "test".into() }); + let json = Json::from(Test { + name: "test".into(), + }); let resp = json.into_response(RequestId::default()); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"name":"test"}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"name":"test"}}"# + ); } #[test] @@ -273,11 +325,14 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"jsonrpc":"2.0","id":"(no id)","result":{"some":"prop"}}"#); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"(no id)","result":{"some":"prop"}}"# + ); } - + #[derive(Serialize)] struct Test { - name: String + name: String, } } diff --git a/neva/src/types/root.rs b/neva/src/types/root.rs index f6bb41b..a8afbf9 100644 --- a/neva/src/types/root.rs +++ b/neva/src/types/root.rs @@ -1,13 +1,13 @@ //! Represents MCP Roots. -use serde::{Serialize, Deserialize}; -use crate::types::{Uri, request::RequestParamsMeta, IntoResponse, RequestId, Response}; +use crate::types::{IntoResponse, RequestId, Response, Uri, request::RequestParamsMeta}; +use serde::{Deserialize, Serialize}; /// List of commands for Roots pub mod commands { /// Command name that requests a list of roots available from the client. pub const LIST: &str = "roots/list"; - + /// Notification name that indicates that the list of roots has changed. pub const LIST_CHANGED: &str = "notifications/roots/list_changed"; } @@ -18,13 +18,13 @@ pub mod commands { /// > top-level directories or container resources that can be accessed and traversed. /// > Roots provide a hierarchical structure for organizing and accessing resources within the protocol. /// > Each root has a URI that uniquely identifies it and optional metadata like a human-readable name. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Root { /// The URI of the root. pub uri: Uri, - + /// A human-readable name for the root. pub name: String, @@ -34,9 +34,9 @@ pub struct Root { } /// Represents the parameters used to request a list of roots available from the client. -/// +/// /// > **Note:** The client responds with a ['ListRootsResult'] containing the client's roots. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ListRootsRequestParams { @@ -49,9 +49,9 @@ pub struct ListRootsRequestParams { } /// Represents the client's response to a `roots/list` request from the server. -/// This result contains an array of Root objects, each representing a root directory +/// This result contains an array of Root objects, each representing a root directory /// or file that the server can operate on. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListRootsResult { @@ -60,7 +60,7 @@ pub struct ListRootsResult { /// > **Note:** This collection contains all available root URIs and their associated metadata. /// > Each root serves as an entry point for resource navigation in the Model Context Protocol. pub roots: Vec, - + /// An additional metadata for the result. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option, @@ -71,7 +71,7 @@ impl IntoResponse for ListRootsResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -79,10 +79,7 @@ impl IntoResponse for ListRootsResult { impl From> for ListRootsResult { #[inline] fn from(roots: Vec) -> Self { - Self { - roots, - meta: None, - } + Self { roots, meta: None } } } @@ -101,13 +98,13 @@ where impl Root { /// Creates a new [`Root`] pub fn new(uri: impl Into, name: impl Into) -> Self { - Self { - uri: uri.into(), + Self { + uri: uri.into(), name: name.into(), meta: None, } } - + /// Split [`Root`] into parts of URI and name pub fn into_parts(self) -> (Uri, String) { (self.uri, self.name) diff --git a/neva/src/types/sampling.rs b/neva/src/types/sampling.rs index 0e5a821..b256715 100644 --- a/neva/src/types/sampling.rs +++ b/neva/src/types/sampling.rs @@ -1,22 +1,16 @@ -//! Utilities for Sampling +//! Utilities for Sampling -use serde::{Serialize, Deserialize, Serializer, Deserializer}; -use crate::shared::{OneOrMany, IntoArgs}; +use crate::shared::{IntoArgs, OneOrMany}; use crate::types::{ - Tool, ToolUse, ToolResult, - Content, TextContent, ImageContent, AudioContent, - ResourceLink, EmbeddedResource, - PromptMessage, - Role, - RequestId, - Response, - IntoResponse + AudioContent, Content, EmbeddedResource, ImageContent, IntoResponse, PromptMessage, RequestId, + ResourceLink, Response, Role, TextContent, Tool, ToolResult, ToolUse, }; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[cfg(feature = "client")] -use std::{pin::Pin, sync::Arc, future::Future}; #[cfg(feature = "tasks")] use crate::types::TaskMetadata; +#[cfg(feature = "client")] +use std::{future::Future, pin::Pin, sync::Arc}; const DEFAULT_MESSAGE_MAX_TOKENS: i32 = 512; @@ -27,64 +21,64 @@ pub mod commands { } /// Represents a message issued to or received from an LLM API within the Model Context Protocol. -/// +/// /// > **Note:** A [`SamplingMessage`] encapsulates content sent to or received from AI models in the Model Context Protocol. /// > Each message has a specific role [`Role::User`] or [`Role::Assistant`] and contains content which can be text or images. -/// > +/// > /// > [`SamplingMessage`] objects are typically used in collections within [`CreateMessageRequestParams`] /// > to represent prompts or queries for LLM sampling. They form the core data structure for text generation requests /// > within the Model Context Protocol. -/// > +/// > /// > While similar, to [`PromptMessage`], the [`SamplingMessage`] is focused on direct LLM sampling /// > operations rather than the enhanced resource embedding capabilities provided by [`PromptMessage`]. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SamplingMessage { /// The role of the message sender, indicating whether it's from a _user_ or an _assistant_. pub role: Role, - + /// The content of the message. - pub content: OneOrMany + pub content: OneOrMany, } -/// Represents the parameters used with a _"sampling/createMessage"_ +/// Represents the parameters used with a _"sampling/createMessage"_ /// request from a server to sample an LLM via the client. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreateMessageRequestParams { /// The messages requested by the server to be included in the prompt. pub messages: Vec, - + /// The maximum number of tokens to generate in the LLM response, as requested by the server. /// - /// > **Note:** A token is generally a word or part of a word in the text. Setting this value helps control + /// > **Note:** A token is generally a word or part of a word in the text. Setting this value helps control /// > response length and computation time. The client may choose to sample fewer tokens than requested. #[serde(rename = "maxTokens")] pub max_tokens: i32, - + /// Represents an indication as to which server contexts should be included in the prompt. - /// + /// /// > **Note:** The client may ignore this request. #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")] pub include_context: Option, - + /// An optional metadata to pass through to the LLM provider. /// /// > **Note:** The format of this metadata is provider-specific and can include model-specific settings or - /// > configuration that isn't covered by standard parameters. This allows for passing custom parameters + /// > configuration that isn't covered by standard parameters. This allows for passing custom parameters /// > that are specific to certain AI models or providers. #[serde(rename = "metadata", skip_serializing_if = "Option::is_none")] pub meta: Option, - + /// Represents the server's preferences for which model to select. /// /// > **Note:** The client may ignore these preferences. - /// > + /// > /// > These preferences help the client make an appropriate model selection based on the server's priorities /// > for cost, speed, intelligence, and specific model hints. - /// > + /// > /// > When multiple dimensions are specified (cost, speed, intelligence), the client should balance these /// > based on their relative values. If specific model hints are provided, the client should evaluate them /// > in order and prioritize them over numeric priorities. @@ -100,13 +94,13 @@ pub struct CreateMessageRequestParams { /// Represents the temperature to use for sampling, as requested by the server. #[serde(rename = "temperature", skip_serializing_if = "Option::is_none")] pub temp: Option, - + /// Represents optional sequences of characters that signal the LLM to stop generating text when encountered. /// /// > **Note:** When the model generates any of these sequences during sampling, text generation stops immediately, - /// > even if the maximum token limit hasn't been reached. This is useful for controlling generation + /// > even if the maximum token limit hasn't been reached. This is useful for controlling generation /// > endings or preventing the model from continuing beyond certain points. - /// > + /// > /// > Stop sequences are typically case-sensitive, and typically the LLM will only stop generation when a produced /// > sequence exactly matches one of the provided sequences. Common uses include ending markers like _"END"_, punctuation /// > like _"."_, or special delimiter sequences like _"###"_. @@ -118,7 +112,7 @@ pub struct CreateMessageRequestParams { pub tools: Option>, /// Controls how the model uses tools. - /// + /// /// Default is `{ mode: "auto" }`. #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")] pub tool_choice: Option, @@ -135,20 +129,20 @@ pub struct CreateMessageRequestParams { } /// Controls tool selection behavior for sampling requests. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)] pub struct ToolChoice { /// Mode that controls which tools the model can call. - pub mode: ToolChoiceMode + pub mode: ToolChoiceMode, } /// Represents the mode that controls which tools the model can call. -/// +/// /// - `auto` - Model decides whether to call tools (default). /// - `required` - Model must call at least one tool. /// - `none` - Model must not call any tools. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -161,25 +155,25 @@ pub enum ToolChoiceMode { Required, /// The mode value `none`. - None + None, } /// Specifies the context inclusion options for a request in the Model Context Protocol (MCP). -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum ContextInclusion { /// Indicates that no context should be included. #[serde(rename = "none")] None, - + /// Indicates that context from the server that sent the request should be included. #[serde(rename = "thisServer")] ThisServer, - + /// Indicates that context from all servers that the client is connected to should be included. #[serde(rename = "allServers")] - AllServers + AllServers, } /// Represents a server's preferences for model selection, requested of the client during sampling. @@ -189,17 +183,17 @@ pub enum ContextInclusion { /// > faster but less capable, others are more capable but more expensive, and so /// > on. This struct allows servers to express their priorities across multiple /// > dimensions to help clients make an appropriate selection for their use case. -/// > +/// > /// > These preferences are always advisory. The client may ignore them. It is also /// > up to the client to decide how to interpret these preferences and how to /// > balance them against other considerations. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ModelPreferences { /// Represents how much to prioritize cost when selecting a model. - /// - /// > **Note:** A value of _0_ means cost is not important, + /// + /// > **Note:** A value of _0_ means cost is not important, /// > while a value of _1_ means cost is the most important factor. #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")] pub cost_priority: Option, @@ -209,33 +203,36 @@ pub struct ModelPreferences { pub hints: Option>, /// Represents how much to prioritize sampling speed (latency) when selecting a model. - /// - /// > **Note:** A value of _0_ means speed is not important, + /// + /// > **Note:** A value of _0_ means speed is not important, /// > while a value of _1_ means speed is the most important factor. #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")] pub speed_priority: Option, /// Represents how much to prioritize intelligence and capabilities when selecting a model. - /// - /// > **Note:** A value of _0_ means intelligence is not important, + /// + /// > **Note:** A value of _0_ means intelligence is not important, /// > while a value of _1_ means intelligence is the most important factor. - #[serde(rename = "intelligencePriority", skip_serializing_if = "Option::is_none")] + #[serde( + rename = "intelligencePriority", + skip_serializing_if = "Option::is_none" + )] pub intelligence_priority: Option, } /// Provides hints to use for model selection. /// /// > **Note:** When multiple hints are specified in [`ModelPreferences`], they are evaluated in order, -/// > with the first match taking precedence. -/// > +/// > with the first match taking precedence. +/// > /// > Clients should prioritize these hints over numeric priorities. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ModelHint { /// A hint for a model name. - /// - /// > **Note:** The specified string can be a partial or full model name. Clients may also + /// + /// > **Note:** The specified string can be a partial or full model name. Clients may also /// > map hints to equivalent models from different providers. Clients make the final model /// > selection based on these preferences and their available models. #[serde(skip_serializing_if = "Option::is_none")] @@ -243,32 +240,32 @@ pub struct ModelHint { } /// Represents a client's response to a _"sampling/createMessage"_ from the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreateMessageResult { /// Role of the user who generated the message. pub role: Role, - + /// Content of the message. pub content: OneOrMany, - + /// Name of the model that generated the message. /// /// > **Note:** This should contain the specific model identifier such as _"claude-3-5-sonnet-20241022"_ or _"o3-mini"_. - /// > + /// > /// > This property allows the server to know which model was used to generate the response, /// > enabling the appropriate handling based on the model's capabilities and characteristics. pub model: String, /// Reason why message generation (sampling) stopped, if known. - /// + /// /// ### Common values include: /// * `endTurn` - The model naturally completed its response. /// * `maxTokens` - The response was truncated due to reaching token limits. /// * `stopSequence` - A specific stop sequence was encountered during generation. /// * `toolUse` - The model wants to use one or more tools. - /// + /// /// This field is an open string to allow for provider-specific stop reasons. #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")] pub stop_reason: Option, @@ -279,18 +276,18 @@ pub struct CreateMessageResult { pub enum StopReason { /// The model naturally completed its response. EndTurn, - + /// The response was truncated due to reaching token limits. MaxTokens, - + /// A specific stop sequence was encountered during generation. StopSequence, - + /// The model wants to use one or more tools. ToolUse, - + /// Other stop reasons. - Other(String) + Other(String), } impl Serialize for StopReason { @@ -371,7 +368,7 @@ impl IntoResponse for CreateMessageResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -393,8 +390,7 @@ impl From for SamplingMessage { impl From for SamplingMessage { #[inline] fn from(msg: PromptMessage) -> Self { - Self::new(msg.role) - .with(msg.content) + Self::new(msg.role).with(msg.content) } } @@ -416,22 +412,22 @@ impl SamplingMessage { /// Creates a new [`SamplingMessage`] #[inline] pub fn new(role: Role) -> Self { - Self { + Self { content: OneOrMany::new(), - role + role, } } - + /// Creates a new [`SamplingMessage`] with a user role pub fn user() -> Self { Self::new(Role::User) } - + /// Creates a new [`SamplingMessage`] with an assistant role pub fn assistant() -> Self { Self::new(Role::Assistant) } - + /// Sets the content pub fn with>(mut self, content: T) -> Self { self.content.push(content.into()); @@ -445,38 +441,36 @@ impl ModelPreferences { pub fn new() -> Self { Self::default() } - + /// Sets the cost priority pub fn with_cost_priority(mut self, priority: f32) -> Self { self.cost_priority = Some(priority); self } - + /// Sets the speed priority pub fn with_speed_priority(mut self, priority: f32) -> Self { self.speed_priority = Some(priority); self } - + /// Sets the intelligence priority pub fn with_intel_priority(mut self, priority: f32) -> Self { self.intelligence_priority = Some(priority); self } - + /// Sets the model hint pub fn with_hint(mut self, hint: impl Into) -> Self { - self.hints - .get_or_insert_with(Vec::new) - .push(hint.into()); + self.hints.get_or_insert_with(Vec::new).push(hint.into()); self } /// Sets the model hints - pub fn with_hints(mut self, hint: T) -> Self - where + pub fn with_hints(mut self, hint: T) -> Self + where T: IntoIterator, - I: Into + I: Into, { self.hints .get_or_insert_with(Vec::new) @@ -489,7 +483,9 @@ impl ModelHint { /// Creates a new [`ModelHint`] #[inline] pub fn new(name: impl Into) -> Self { - Self { name: Some(name.into()) } + Self { + name: Some(name.into()), + } } } @@ -497,21 +493,27 @@ impl ToolChoice { /// Creates a new [`ToolChoice`] with [`ToolChoiceMode::Auto`] #[inline] pub fn auto() -> Self { - Self { mode: ToolChoiceMode::Auto } + Self { + mode: ToolChoiceMode::Auto, + } } /// Creates a new [`ToolChoice`] with [`ToolChoiceMode::None`] #[inline] pub fn none() -> Self { - Self { mode: ToolChoiceMode::None } + Self { + mode: ToolChoiceMode::None, + } } /// Creates a new [`ToolChoice`] with [`ToolChoiceMode::Required`] #[inline] pub fn required() -> Self { - Self { mode: ToolChoiceMode::Required } + Self { + mode: ToolChoiceMode::Required, + } } - + /// Returns `true` if the tool choice mode is [`ToolChoiceMode::Auto`] #[inline] pub fn is_auto(&self) -> bool { @@ -536,36 +538,35 @@ impl CreateMessageRequestParams { pub fn new() -> Self { Self::default() } - + /// Creates params for a single message request pub fn with_message(mut self, message: impl Into) -> Self { self.messages.push(message.into()); self } - + /// Creates params for multiple messages request pub fn with_messages(mut self, messages: I) -> Self where I: IntoIterator, T: Into, { - self.messages - .extend(messages.into_iter().map(Into::into)); + self.messages.extend(messages.into_iter().map(Into::into)); self } - + /// Sets the system prompt for this [`CreateMessageRequestParams`] pub fn with_sys_prompt(mut self, sys_prompt: impl Into) -> Self { self.sys_prompt = Some(sys_prompt.into()); self } - + /// Sets the `max_tokens` for this [`CreateMessageRequestParams`] pub fn with_max_tokens(mut self, max_tokens: i32) -> Self { self.max_tokens = max_tokens; self } - + /// Sets the [`ContextInclusion`] for this [`CreateMessageRequestParams`] pub fn with_include_ctx(mut self, inc: ContextInclusion) -> Self { self.include_context = Some(inc); @@ -589,37 +590,35 @@ impl CreateMessageRequestParams { self.include_context = Some(ContextInclusion::AllServers); self } - + /// Sets the [`ModelPreferences`] for this [`CreateMessageRequestParams`] pub fn with_pref(mut self, pref: ModelPreferences) -> Self { self.model_pref = Some(pref); self } - + /// Sets a temperature for this [`CreateMessageRequestParams`] pub fn with_temp(mut self, temp: f32) -> Self { self.temp = Some(temp); self } - + /// Sets the stop sequences for this [`CreateMessageRequestParams`] pub fn with_stop_seq(mut self, stop_sequences: Vec) -> Self { self.stop_sequences = Some(stop_sequences); self } - + /// Sets the list of tools that the model can use during generation - /// + /// /// Default: `None` pub fn with_tools>(mut self, tools: T) -> Self { - self.tools = Some(tools - .into_iter() - .collect()); + self.tools = Some(tools.into_iter().collect()); self.with_tool_choice(ToolChoiceMode::Auto) } /// Sets the control mode for tool selection behavior for sampling requests. - /// + /// /// Default: `None` pub fn with_tool_choice(mut self, mode: ToolChoiceMode) -> Self { self.tool_choice = Some(ToolChoice { mode }); @@ -627,7 +626,7 @@ impl CreateMessageRequestParams { } /// Makes the request task-augmented with TTL. - /// + /// /// Default: `None` #[cfg(feature = "tasks")] pub fn with_ttl(mut self, ttl: Option) -> Self { @@ -637,46 +636,39 @@ impl CreateMessageRequestParams { /// Returns an iterator of text messages pub fn text(&self) -> impl Iterator { - self.msg_iter("text") - .filter_map(|c| c.as_text()) + self.msg_iter("text").filter_map(|c| c.as_text()) } /// Returns an iterator of audio messages pub fn audio(&self) -> impl Iterator { - self.msg_iter("audio") - .filter_map(|c| c.as_audio()) + self.msg_iter("audio").filter_map(|c| c.as_audio()) } /// Returns an iterator of image messages pub fn images(&self) -> impl Iterator { - self.msg_iter("image") - .filter_map(|c| c.as_image()) + self.msg_iter("image").filter_map(|c| c.as_image()) } /// Returns an iterator of resource link messages pub fn links(&self) -> impl Iterator { - self.msg_iter("resource_link") - .filter_map(|c| c.as_link()) + self.msg_iter("resource_link").filter_map(|c| c.as_link()) } /// Returns an iterator of embedded resource messages pub fn resources(&self) -> impl Iterator { - self.msg_iter("resource") - .filter_map(|c| c.as_resource()) + self.msg_iter("resource").filter_map(|c| c.as_resource()) } /// Returns an iterator of tool use messages pub fn tools(&self) -> impl Iterator { - self.msg_iter("tool_use") - .filter_map(|c| c.as_tool()) + self.msg_iter("tool_use").filter_map(|c| c.as_tool()) } /// Returns an iterator of tool execution result messages pub fn results(&self) -> impl Iterator { - self.msg_iter("tool_result") - .filter_map(|c| c.as_result()) + self.msg_iter("tool_result").filter_map(|c| c.as_result()) } - + /// Returns a messages iterator of a given type #[inline] fn msg_iter(&self, t: &'static str) -> impl Iterator { @@ -698,46 +690,46 @@ impl CreateMessageResult { role, } } - + /// Creates a new [`CreateMessageResult`] with a user role pub fn user() -> Self { Self::new(Role::User) } - + /// Creates a new [`CreateMessageResult`] with an assistant role pub fn assistant() -> Self { Self::new(Role::Assistant) } - + /// Sets the stop reason pub fn with_stop_reason(mut self, reason: impl Into) -> Self { self.stop_reason = Some(reason.into()); self } - + /// Sets the model name pub fn with_model(mut self, model: impl Into) -> Self { self.model = model.into(); self } - + /// Sets the content pub fn with_content>(mut self, content: T) -> Self { self.content.push(content.into()); self } - + /// Marks that model completed the response #[inline] pub fn end_turn(self) -> Self { self.with_stop_reason(StopReason::EndTurn) } - + /// Requests a tool use and sets the stop reason to `toolUse` pub fn use_tool(self, name: N, args: Args) -> Self - where + where N: Into, - Args: IntoArgs + Args: IntoArgs, { self.with_content(ToolUse::new(name, args)) .with_stop_reason(StopReason::ToolUse) @@ -747,72 +739,64 @@ impl CreateMessageResult { pub fn use_tools(self, tools: impl IntoIterator) -> Self where N: Into, - Args: IntoArgs + Args: IntoArgs, { - tools.into_iter() + tools + .into_iter() .fold(self, |acc, (name, args)| acc.use_tool(name, args)) .with_stop_reason(StopReason::ToolUse) } /// Returns an iterator of text messages pub fn text(&self) -> impl Iterator { - self.msg_iter("text") - .filter_map(|c| c.as_text()) + self.msg_iter("text").filter_map(|c| c.as_text()) } /// Returns an iterator of audio content pub fn audio(&self) -> impl Iterator { - self.msg_iter("audio") - .filter_map(|c| c.as_audio()) + self.msg_iter("audio").filter_map(|c| c.as_audio()) } /// Returns an iterator of image content pub fn images(&self) -> impl Iterator { - self.msg_iter("image") - .filter_map(|c| c.as_image()) + self.msg_iter("image").filter_map(|c| c.as_image()) } /// Returns an iterator of resource link content pub fn links(&self) -> impl Iterator { - self.msg_iter("resource_link") - .filter_map(|c| c.as_link()) + self.msg_iter("resource_link").filter_map(|c| c.as_link()) } /// Returns an iterator of embedded resource content pub fn resources(&self) -> impl Iterator { - self.msg_iter("resource") - .filter_map(|c| c.as_resource()) + self.msg_iter("resource").filter_map(|c| c.as_resource()) } /// Returns an iterator of tool use content pub fn tools(&self) -> impl Iterator { - self.msg_iter("tool_use") - .filter_map(|c| c.as_tool()) + self.msg_iter("tool_use").filter_map(|c| c.as_tool()) } /// Returns an iterator of tool execution result content pub fn results(&self) -> impl Iterator { - self.msg_iter("tool_result") - .filter_map(|c| c.as_result()) + self.msg_iter("tool_result").filter_map(|c| c.as_result()) } - + /// Returns a content iterator of a given type #[inline] fn msg_iter(&self, t: &'static str) -> impl Iterator { - self.content - .iter() - .filter(move |c| c.get_type() == t) + self.content.iter().filter(move |c| c.get_type() == t) } } /// Represents a dynamic handler for handling sampling requests #[cfg(feature = "client")] pub(crate) type SamplingHandler = Arc< - dyn Fn(CreateMessageRequestParams) -> Pin< - Box + Send + 'static> - > - + Send - + Sync + dyn Fn( + CreateMessageRequestParams, + ) -> Pin + Send + 'static>> + + Send + + Sync, >; #[cfg(test)] @@ -836,27 +820,24 @@ mod tests { #[test] #[cfg(feature = "server")] fn it_sets_auto_tool_choice_when_tools_specified() { - let params = CreateMessageRequestParams::new() - .with_tools([ - Tool::new("test 1", async || "test 1"), - Tool::new("test 2", async || "test 2") - ]); + let params = CreateMessageRequestParams::new().with_tools([ + Tool::new("test 1", async || "test 1"), + Tool::new("test 2", async || "test 2"), + ]); assert_eq!(params.tool_choice.unwrap().mode, ToolChoiceMode::Auto); } #[test] fn it_sets_tool_choice() { - let params = CreateMessageRequestParams::new() - .with_tool_choice(ToolChoiceMode::Required); + let params = CreateMessageRequestParams::new().with_tool_choice(ToolChoiceMode::Required); assert_eq!(params.tool_choice.unwrap().mode, ToolChoiceMode::Required); } #[test] fn it_builds_sampling_message() { - let msg = SamplingMessage::user() - .with("Hello"); + let msg = SamplingMessage::user().with("Hello"); assert_eq!(msg.role, Role::User); assert_eq!(msg.content.len(), 1); @@ -878,17 +859,23 @@ mod tests { #[test] fn it_sets_context_inclusion() { - let params = CreateMessageRequestParams::new() - .with_no_ctx(); - assert!(matches!(params.include_context, Some(ContextInclusion::None))); + let params = CreateMessageRequestParams::new().with_no_ctx(); + assert!(matches!( + params.include_context, + Some(ContextInclusion::None) + )); - let params = CreateMessageRequestParams::new() - .with_this_server(); - assert!(matches!(params.include_context, Some(ContextInclusion::ThisServer))); + let params = CreateMessageRequestParams::new().with_this_server(); + assert!(matches!( + params.include_context, + Some(ContextInclusion::ThisServer) + )); - let params = CreateMessageRequestParams::new() - .with_all_servers(); - assert!(matches!(params.include_context, Some(ContextInclusion::AllServers))); + let params = CreateMessageRequestParams::new().with_all_servers(); + assert!(matches!( + params.include_context, + Some(ContextInclusion::AllServers) + )); } #[test] @@ -906,8 +893,7 @@ mod tests { #[test] fn it_handles_tool_use_in_result() { - let result = CreateMessageResult::assistant() - .use_tool("calculator", ()); + let result = CreateMessageResult::assistant().use_tool("calculator", ()); assert_eq!(result.stop_reason, Some(StopReason::ToolUse)); assert_eq!(result.content.len(), 1); @@ -923,9 +909,12 @@ mod tests { .with_hints(["gpt-4", "llama"]); assert_eq!(pref.hints.as_ref().unwrap().len(), 3); - assert_eq!(pref.hints.as_ref().unwrap()[0].name.as_deref(), Some("claude")); + assert_eq!( + pref.hints.as_ref().unwrap()[0].name.as_deref(), + Some("claude") + ); } - + #[test] fn it_converts_stop_reason_from_str() { let reasons = [ @@ -933,7 +922,7 @@ mod tests { (StopReason::MaxTokens, "maxTokens"), (StopReason::EndTurn, "endTurn"), (StopReason::StopSequence, "stopSequence"), - (StopReason::Other("test".to_string()), "test") + (StopReason::Other("test".to_string()), "test"), ]; for (expected, reason_str) in reasons { @@ -949,7 +938,7 @@ mod tests { (StopReason::MaxTokens, "maxTokens"), (StopReason::EndTurn, "endTurn"), (StopReason::StopSequence, "stopSequence"), - (StopReason::Other("test".to_string()), "test") + (StopReason::Other("test".to_string()), "test"), ]; for (expected, reason_str) in reasons { @@ -957,7 +946,7 @@ mod tests { assert_eq!(reason, expected); } } - + #[test] fn it_serializes_stop_reason() { let reasons = [ @@ -965,7 +954,7 @@ mod tests { (StopReason::MaxTokens, "\"maxTokens\""), (StopReason::EndTurn, "\"endTurn\""), (StopReason::StopSequence, "\"stopSequence\""), - (StopReason::Other("test".to_string()), "\"test\"") + (StopReason::Other("test".to_string()), "\"test\""), ]; for (reason, expected) in reasons { @@ -981,7 +970,7 @@ mod tests { (StopReason::MaxTokens, "\"maxTokens\""), (StopReason::EndTurn, "\"endTurn\""), (StopReason::StopSequence, "\"stopSequence\""), - (StopReason::Other("test".to_string()), "\"test\"") + (StopReason::Other("test".to_string()), "\"test\""), ]; for (expected, reason_str) in reasons { @@ -989,27 +978,27 @@ mod tests { assert_eq!(reason, expected); } } - + #[test] fn it_serializes_model_preferences() { let pref = ModelPreferences::new() .with_cost_priority(0.5) .with_speed_priority(0.75) .with_intel_priority(0.25); - + let json = serde_json::to_string(&pref).unwrap(); - + let expected = r#"{"costPriority":0.5,"speedPriority":0.75,"intelligencePriority":0.25}"#; assert_eq!(json, expected); } - + #[test] fn it_deserializes_model_preferences() { let json = r#"{"costPriority":0.5,"speedPriority":0.75,"intelligencePriority":0.25}"#; let pref: ModelPreferences = serde_json::from_str(json).unwrap(); - + assert_eq!(pref.cost_priority, Some(0.5)); assert_eq!(pref.speed_priority, Some(0.75)); assert_eq!(pref.intelligence_priority, Some(0.25)); } -} \ No newline at end of file +} diff --git a/neva/src/types/schema.rs b/neva/src/types/schema.rs index 0ef560c..22be637 100644 --- a/neva/src/types/schema.rs +++ b/neva/src/types/schema.rs @@ -1,12 +1,15 @@ //! Utilities for Primitive JSON schema definitions -use serde::{Deserialize, Serialize, Deserializer, Serializer}; -use serde_json::Value; -use crate::{error::{Error, ErrorCode}, types::PropertyType}; use crate::prelude::Schema::{MultiTitledEnum, SingleTitledEnum}; use crate::types::Schema::{MultiUntitledEnum, SingleUntitledEnum}; +use crate::{ + error::{Error, ErrorCode}, + types::PropertyType, +}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Value; -/// Represents restricted subset of JSON Schema: +/// Represents restricted subset of JSON Schema: /// - [`StringSchema`] /// - [`NumberSchema`] /// - [`BooleanSchema`] @@ -18,10 +21,10 @@ use crate::types::Schema::{MultiUntitledEnum, SingleUntitledEnum}; pub enum Schema { /// See [`StringSchema`] String(StringSchema), - + /// See [`NumberSchema`] Number(NumberSchema), - + /// See [`BooleanSchema`] Boolean(BooleanSchema), @@ -38,7 +41,7 @@ pub enum Schema { MultiTitledEnum(TitledMultiSelectEnumSchema), /// See [`LegacyTitledEnum`] - LegacyEnum(LegacyTitledEnumSchema) + LegacyEnum(LegacyTitledEnumSchema), } /// Represents a schema for a string type. @@ -55,15 +58,15 @@ pub struct StringSchema { /// A human-readable description of the property #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// The minimum length for the string. #[serde(rename = "minLength", skip_serializing_if = "Option::is_none")] pub min_length: Option, - + /// The maximum length for the string. #[serde(rename = "maxLength", skip_serializing_if = "Option::is_none")] pub max_length: Option, - + /// A specific format for the string ("email", "uri", "date", or "date-time"). #[serde(skip_serializing_if = "Option::is_none")] pub format: Option, @@ -103,11 +106,11 @@ pub struct NumberSchema { /// A human-readable description of the property #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// The minimum allowed value. #[serde(rename = "minimum", skip_serializing_if = "Option::is_none")] pub min: Option, - + /// The maximum allowed value. #[serde(rename = "maximum", skip_serializing_if = "Option::is_none")] pub max: Option, @@ -127,7 +130,7 @@ pub struct BooleanSchema { /// A human-readable description of the property #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// The default value for the Boolean. #[serde(skip_serializing_if = "Option::is_none")] pub default: Option, @@ -135,7 +138,7 @@ pub struct BooleanSchema { /// Legacy enumeration schema for the protocol versions below `2025-11-25`. /// For the newer versions use the [`TitledSingleSelectEnum`] instead. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LegacyTitledEnumSchema { @@ -150,18 +153,18 @@ pub struct LegacyTitledEnumSchema { /// A human-readable description of the property #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// The list of allowed string values for the enum. #[serde(rename = "enum")] pub r#enum: Vec, - + /// Optional display names corresponding to the enum values #[serde(rename = "enumNames", skip_serializing_if = "Option::is_none")] - pub enum_names: Option>, + pub enum_names: Option>, } /// Schema for single-selection enumeration without display titles for options. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UntitledSingleSelectEnumSchema { @@ -231,11 +234,11 @@ pub struct UntitledMultiSelectEnumSchema { /// The list of allowed string values for the enum. pub items: EnumItems, - + /// Maximum number of items to select. #[serde(rename = "maxItems", skip_serializing_if = "Option::is_none")] pub max_items: Option, - + /// Minimum number of items to select. #[serde(rename = "minItems", skip_serializing_if = "Option::is_none")] pub min_items: Option, @@ -293,7 +296,7 @@ pub struct EnumItems { } /// Schema for array items with enum options and display labels. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EnumOptions { @@ -310,7 +313,7 @@ pub struct EnumOption { /// The enum value. #[serde(rename = "const")] pub value: String, - + /// Display label for this option. pub title: String, } @@ -322,9 +325,9 @@ impl<'de> Deserialize<'de> for Schema { D: Deserializer<'de>, { let value = Value::deserialize(deserializer)?; - let obj = value.as_object().ok_or_else(|| { - serde::de::Error::custom("Expected object") - })?; + let obj = value + .as_object() + .ok_or_else(|| serde::de::Error::custom("Expected object"))?; let type_field = obj.get("type").and_then(|v| v.as_str()); let schema = match type_field { @@ -332,25 +335,19 @@ impl<'de> Deserialize<'de> for Schema { if obj.contains_key("enum") { if obj.contains_key("enumNames") { Schema::LegacyEnum( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? + serde_json::from_value(value).map_err(serde::de::Error::custom)?, ) } else { Schema::SingleUntitledEnum( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? + serde_json::from_value(value).map_err(serde::de::Error::custom)?, ) } } else if obj.contains_key("oneOf") { Schema::SingleTitledEnum( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? + serde_json::from_value(value).map_err(serde::de::Error::custom)?, ) } else { - Schema::String( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? - ) + Schema::String(serde_json::from_value(value).map_err(serde::de::Error::custom)?) } } Some("array") => { @@ -358,13 +355,11 @@ impl<'de> Deserialize<'de> for Schema { if let Some(items_obj) = items.and_then(|v| v.as_object()) { if items_obj.contains_key("anyOf") { Schema::MultiTitledEnum( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? + serde_json::from_value(value).map_err(serde::de::Error::custom)?, ) } else if items_obj.contains_key("enum") { Schema::MultiUntitledEnum( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? + serde_json::from_value(value).map_err(serde::de::Error::custom)?, ) } else { return Err(serde::de::Error::custom("Unknown array schema type")); @@ -373,18 +368,17 @@ impl<'de> Deserialize<'de> for Schema { return Err(serde::de::Error::custom("Array schema missing items")); } } - Some("number") | Some("integer") => Schema::Number( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? - ), - Some("boolean") => Schema::Boolean( - serde_json::from_value(value) - .map_err(serde::de::Error::custom)? - ), + Some("number") | Some("integer") => { + Schema::Number(serde_json::from_value(value).map_err(serde::de::Error::custom)?) + } + Some("boolean") => { + Schema::Boolean(serde_json::from_value(value).map_err(serde::de::Error::custom)?) + } _ => { - return Err(serde::de::Error::custom( - format!("Unknown or missing type field: {:?}", type_field) - )); + return Err(serde::de::Error::custom(format!( + "Unknown or missing type field: {:?}", + type_field + ))); } }; @@ -457,7 +451,7 @@ impl Default for LegacyTitledEnumSchema { r#enum: Vec::new(), title: None, descr: None, - enum_names: None + enum_names: None, } } } @@ -470,7 +464,7 @@ impl Default for UntitledSingleSelectEnumSchema { r#enum: Vec::new(), title: None, descr: None, - default: None + default: None, } } } @@ -483,7 +477,7 @@ impl Default for TitledSingleSelectEnumSchema { one_of: Vec::new(), title: None, descr: None, - default: None + default: None, } } } @@ -498,7 +492,7 @@ impl Default for UntitledMultiSelectEnumSchema { min_items: None, title: None, descr: None, - default: None + default: None, } } } @@ -513,7 +507,7 @@ impl Default for TitledMultiSelectEnumSchema { min_items: None, title: None, descr: None, - default: None + default: None, } } } @@ -521,9 +515,9 @@ impl Default for TitledMultiSelectEnumSchema { impl Default for EnumItems { #[inline] fn default() -> Self { - Self { - r#type: PropertyType::String, - r#enum: Vec::new() + Self { + r#type: PropertyType::String, + r#enum: Vec::new(), } } } @@ -531,16 +525,14 @@ impl Default for EnumItems { impl Default for EnumOptions { #[inline] fn default() -> Self { - Self { - any_of: Vec::new() - } + Self { any_of: Vec::new() } } } impl From<&str> for Schema { #[inline] fn from(value: &str) -> Self { - match value { + match value { "string" => Self::string(), "number" => Self::number(), "boolean" => Self::boolean(), @@ -554,8 +546,7 @@ impl From<&str> for Schema { impl From<&Value> for Schema { #[inline] fn from(value: &Value) -> Self { - serde_json::from_value(value.clone()) - .unwrap_or_else(|_| Schema::string()) + serde_json::from_value(value.clone()).unwrap_or_else(|_| Schema::string()) } } @@ -571,20 +562,20 @@ impl Schema { pub fn string() -> Self { Self::String(Default::default()) } - + /// Creates a new [`Schema`] instance with a [`NumberSchema`] type. pub fn number() -> Self { Self::Number(Default::default()) } - + /// Creates a new [`Schema`] instance with a [`BooleanSchema`] type. pub fn boolean() -> Self { Self::Boolean(Default::default()) } - + /// Creates a new [`Schema`] instance with a [`SingleUntitledEnum`] type. pub fn single_untitled_enum() -> Self { - SingleUntitledEnum(Default::default()) + SingleUntitledEnum(Default::default()) } /// Creates a new [`Schema`] instance with a [`SingleTitledEnum`] type. @@ -607,28 +598,34 @@ impl StringSchema { /// Validates the string value against the schema. #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { - let str_value = value.as_str() - .ok_or(Error::new(ErrorCode::InvalidParams, "Expected string value"))?; + let str_value = value.as_str().ok_or(Error::new( + ErrorCode::InvalidParams, + "Expected string value", + ))?; if let Some(min_len) = self.min_length - && str_value.len() < min_len { + && str_value.len() < min_len + { return Err(Error::new( ErrorCode::InvalidParams, - format!("String too short: {} < {min_len}", str_value.len()))); + format!("String too short: {} < {min_len}", str_value.len()), + )); } if let Some(max_len) = self.max_length - && str_value.len() > max_len { + && str_value.len() > max_len + { return Err(Error::new( ErrorCode::InvalidParams, - format!("String too long: {} > {max_len}", str_value.len()))); + format!("String too long: {} > {max_len}", str_value.len()), + )); } // Validate format if specified if let Some(format) = &self.format { self.validate_string_format(str_value, format)?; } - + Ok(()) } @@ -639,19 +636,25 @@ impl StringSchema { if !value.contains('@') || !value.contains('.') { return Err(Error::new(ErrorCode::InvalidParams, "Invalid email format")); } - }, + } StringFormat::Uri => { let parts: Vec<&str> = value.splitn(2, "://").collect(); if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { return Err(Error::new(ErrorCode::InvalidParams, "Invalid URI format")); } - }, + } StringFormat::Date => { // Basic date format validation (YYYY-MM-DD) - if value.len() != 10 || value.chars().nth(4) != Some('-') || value.chars().nth(7) != Some('-') { - return Err(Error::new(ErrorCode::InvalidParams, "Invalid date format (expected YYYY-MM-DD)")); + if value.len() != 10 + || value.chars().nth(4) != Some('-') + || value.chars().nth(7) != Some('-') + { + return Err(Error::new( + ErrorCode::InvalidParams, + "Invalid date format (expected YYYY-MM-DD)", + )); } - }, + } StringFormat::DateTime => { if !value.contains('T') { return Err(Error::new(ErrorCode::InvalidParams, "Invalid date format")); @@ -666,22 +669,29 @@ impl NumberSchema { /// Validates the number value against the schema. #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { - let num_value = value.as_f64() - .ok_or(Error::new(ErrorCode::InvalidParams, "Expected number value"))?; - - if let Some(min) = self.min && num_value < min { + let num_value = value.as_f64().ok_or(Error::new( + ErrorCode::InvalidParams, + "Expected number value", + ))?; + + if let Some(min) = self.min + && num_value < min + { return Err(Error::new( ErrorCode::InvalidParams, - format!("Number too small: {num_value} < {min}"))); + format!("Number too small: {num_value} < {min}"), + )); } - if let Some(max) = self.max && num_value > max { - return Err( - Error::new( - ErrorCode::InvalidParams, - format!("Number too large: {num_value} > {max}"))); + if let Some(max) = self.max + && num_value > max + { + return Err(Error::new( + ErrorCode::InvalidParams, + format!("Number too large: {num_value} > {max}"), + )); } - + Ok(()) } } @@ -690,7 +700,8 @@ impl BooleanSchema { /// Validates the boolean value against the schema. #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { - value.is_boolean() + value + .is_boolean() .then_some(()) .ok_or_else(|| Error::new(ErrorCode::InvalidParams, "Expected boolean value")) } @@ -701,12 +712,16 @@ impl LegacyTitledEnumSchema { #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { let str_value = make_str(value)?; - self.r#enum.iter() + self.r#enum + .iter() .any(|v| v == str_value) .then_some(()) - .ok_or_else(|| Error::new( - ErrorCode::InvalidParams, - format!("Invalid enum value: {str_value}"))) + .ok_or_else(|| { + Error::new( + ErrorCode::InvalidParams, + format!("Invalid enum value: {str_value}"), + ) + }) } } @@ -715,12 +730,16 @@ impl UntitledSingleSelectEnumSchema { #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { let str_value = make_str(value)?; - self.r#enum.iter() + self.r#enum + .iter() .any(|v| v == str_value) .then_some(()) - .ok_or_else(|| Error::new( - ErrorCode::InvalidParams, - format!("Invalid enum value: {str_value}"))) + .ok_or_else(|| { + Error::new( + ErrorCode::InvalidParams, + format!("Invalid enum value: {str_value}"), + ) + }) } } @@ -729,12 +748,16 @@ impl TitledSingleSelectEnumSchema { #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { let str_value = make_str(value)?; - self.one_of.iter() + self.one_of + .iter() .any(|v| v.value == str_value) .then_some(()) - .ok_or_else(|| Error::new( - ErrorCode::InvalidParams, - format!("Invalid enum value: {str_value}"))) + .ok_or_else(|| { + Error::new( + ErrorCode::InvalidParams, + format!("Invalid enum value: {str_value}"), + ) + }) } } @@ -743,8 +766,8 @@ impl UntitledMultiSelectEnumSchema { #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { let mut str_values = make_iter_of_as_array(value)?; - str_values.all(|value| self.items.r#enum.iter() - .any(|v| v == value)) + str_values + .all(|value| self.items.r#enum.iter().any(|v| v == value)) .then_some(()) .ok_or_else(|| Error::new(ErrorCode::InvalidParams, "Invalid enum values")) } @@ -755,8 +778,8 @@ impl TitledMultiSelectEnumSchema { #[inline] pub(crate) fn validate(&self, value: &Value) -> Result<(), Error> { let mut str_values = make_iter_of_as_array(value)?; - str_values.all(|value| self.items.any_of.iter() - .any(|v| v.value == value)) + str_values + .all(|value| self.items.any_of.iter().any(|v| v.value == value)) .then_some(()) .ok_or_else(|| Error::new(ErrorCode::InvalidParams, "Invalid enum values")) } @@ -767,7 +790,7 @@ impl EnumOptions { #[inline] pub fn new(options: impl IntoIterator) -> Self { Self { - any_of: options.into_iter().collect() + any_of: options.into_iter().collect(), } } } @@ -778,7 +801,7 @@ impl EnumOption { pub fn new>(value: S, title: S) -> Self { Self { value: value.into(), - title: title.into() + title: title.into(), } } } @@ -790,11 +813,9 @@ impl EnumItems { I: IntoIterator, S: Into, { - Self { - r#type: PropertyType::String, - r#enum: items.into_iter() - .map(|s| s.into()) - .collect() + Self { + r#type: PropertyType::String, + r#enum: items.into_iter().map(|s| s.into()).collect(), } } } @@ -804,7 +825,12 @@ fn make_iter_of_as_array(value: &Value) -> Result, Er value .as_array() .map(|v| v.iter().filter_map(|v| v.as_str())) - .ok_or_else(|| Error::new(ErrorCode::InvalidParams, "Expected an array of values for enum")) + .ok_or_else(|| { + Error::new( + ErrorCode::InvalidParams, + "Expected an array of values for enum", + ) + }) } #[inline(always)] @@ -928,10 +954,7 @@ mod tests { #[test] fn it_validates_titled_single_select_enum() { let schema = TitledSingleSelectEnumSchema { - one_of: vec![ - EnumOption::new("A", "A"), - EnumOption::new("B", "B") - ], + one_of: vec![EnumOption::new("A", "A"), EnumOption::new("B", "B")], ..Default::default() }; @@ -955,10 +978,7 @@ mod tests { #[test] fn it_validates_titled_multi_select_enum() { let schema = TitledMultiSelectEnumSchema { - items: EnumOptions::new([ - EnumOption::new("A", "A"), - EnumOption::new("B", "B") - ]), + items: EnumOptions::new([EnumOption::new("A", "A"), EnumOption::new("B", "B")]), ..Default::default() }; @@ -981,7 +1001,7 @@ mod tests { panic!("Expected StringSchema"); } } - + #[test] fn it_serializes_schema_to_json() { let schema = Schema::String(StringSchema { @@ -989,10 +1009,13 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "string", - "minLength": 5 - })) + assert_eq!( + json, + json!({ + "type": "string", + "minLength": 5 + }) + ) } #[test] @@ -1016,10 +1039,13 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "number", - "minimum": 5.0 - })) + assert_eq!( + json, + json!({ + "type": "number", + "minimum": 5.0 + }) + ) } #[test] @@ -1043,10 +1069,13 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "boolean", - "default": false - })) + assert_eq!( + json, + json!({ + "type": "boolean", + "default": false + }) + ) } #[test] @@ -1072,11 +1101,14 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "string", - "enum": ["Red", "Green", "Blue"], - "default": "Red" - })) + assert_eq!( + json, + json!({ + "type": "string", + "enum": ["Red", "Green", "Blue"], + "default": "Red" + }) + ) } #[test] @@ -1092,11 +1124,14 @@ mod tests { }); let schema: Schema = serde_json::from_value(json).unwrap(); if let SingleTitledEnum(s) = schema { - assert_eq!(s.one_of, [ - EnumOption::new("#FF0000", "Red"), - EnumOption::new("#00FF00", "Green"), - EnumOption::new("#0000FF", "Blue"), - ]); + assert_eq!( + s.one_of, + [ + EnumOption::new("#FF0000", "Red"), + EnumOption::new("#00FF00", "Green"), + EnumOption::new("#0000FF", "Blue"), + ] + ); } else { panic!("Expected SingleTitledEnum"); } @@ -1114,15 +1149,18 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "string", - "oneOf": [ - { "const": "#FF0000", "title": "Red" }, - { "const": "#00FF00", "title": "Green" }, - { "const": "#0000FF", "title": "Blue" } - ], - "default": "#FF0000" - })) + assert_eq!( + json, + json!({ + "type": "string", + "oneOf": [ + { "const": "#FF0000", "title": "Red" }, + { "const": "#00FF00", "title": "Green" }, + { "const": "#0000FF", "title": "Blue" } + ], + "default": "#FF0000" + }) + ) } #[test] @@ -1140,11 +1178,14 @@ mod tests { }); let schema: Schema = serde_json::from_value(json).unwrap(); if let MultiTitledEnum(s) = schema { - assert_eq!(s.items.any_of, [ - EnumOption::new("#FF0000", "Red"), - EnumOption::new("#00FF00", "Green"), - EnumOption::new("#0000FF", "Blue"), - ]); + assert_eq!( + s.items.any_of, + [ + EnumOption::new("#FF0000", "Red"), + EnumOption::new("#00FF00", "Green"), + EnumOption::new("#0000FF", "Blue"), + ] + ); } else { panic!("Expected MultiTitledEnum"); } @@ -1162,17 +1203,20 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "array", - "items": { - "anyOf": [ - { "const": "#FF0000", "title": "Red" }, - { "const": "#00FF00", "title": "Green" }, - { "const": "#0000FF", "title": "Blue" } - ] - }, - "default": ["#FF0000", "#00FF00"] - })) + assert_eq!( + json, + json!({ + "type": "array", + "items": { + "anyOf": [ + { "const": "#FF0000", "title": "Red" }, + { "const": "#00FF00", "title": "Green" }, + { "const": "#0000FF", "title": "Blue" } + ] + }, + "default": ["#FF0000", "#00FF00"] + }) + ) } #[test] @@ -1201,13 +1245,16 @@ mod tests { ..Default::default() }); let json = serde_json::to_value(schema).unwrap(); - assert_eq!(json, json!({ - "type": "array", - "items": { - "type": "string", - "enum": ["Red", "Green", "Blue"] - }, - "default": ["Red", "Green"] - })) + assert_eq!( + json, + json!({ + "type": "array", + "items": { + "type": "string", + "enum": ["Red", "Green", "Blue"] + }, + "default": ["Red", "Green"] + }) + ) } -} \ No newline at end of file +} diff --git a/neva/src/types/task.rs b/neva/src/types/task.rs index a3f666d..7676ab3 100644 --- a/neva/src/types/task.rs +++ b/neva/src/types/task.rs @@ -1,18 +1,18 @@ //! Types and utilities for task-augmented requests and responses -use std::ops::{Deref, DerefMut}; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use chrono::{DateTime, Utc}; -use serde_json::Value; use crate::{ - types::{Meta, Cursor, IntoResponse, Page, RequestId, Response}, - error::Error + error::Error, + types::{Cursor, IntoResponse, Meta, Page, RequestId, Response}, }; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::Value; +use std::ops::{Deref, DerefMut}; #[cfg(feature = "server")] use crate::{ app::handler::{FromHandlerParams, HandlerParams}, - types::request::{FromRequest, Request} + types::request::{FromRequest, Request}, }; pub(crate) const RELATED_TASK_KEY: &str = "io.modelcontextprotocol/related-task"; @@ -23,22 +23,22 @@ const DEFAULT_TTL: usize = 30000; pub mod commands { /// Command name that returns a list of tasks that are currently running on the server. pub const LIST: &str = "tasks/list"; - + /// Command name that cancels a task on the server. pub const CANCEL: &str = "tasks/cancel"; - + /// Command name that returns the result of a task. pub const RESULT: &str = "tasks/result"; - + /// Command name that returns the status of a task. pub const GET: &str = "tasks/get"; - + /// Notification name that notifies the client about the status of a task. pub const STATUS: &str = "notifications/tasks/status"; } /// Represents a request to retrieve a list of tasks. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ListTasksRequestParams { @@ -49,13 +49,13 @@ pub struct ListTasksRequestParams { } /// Represents the response to a `tasks/list` request. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct ListTasksResult { /// A list of tasks that the server currently runs. pub tasks: Vec, - + /// An opaque token representing the pagination position after the last returned result. /// /// When a paginated result has more data available, the `next_cursor` @@ -67,37 +67,37 @@ pub struct ListTasksResult { } /// Represents a request to cancel a task. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct CancelTaskRequestParams { /// The task identifier to cancel. #[serde(rename = "taskId")] - pub id: String + pub id: String, } /// Represents a request to retrieve the state of a task. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct GetTaskRequestParams { /// The task identifier to retrieve the state for. #[serde(rename = "taskId")] - pub id: String + pub id: String, } /// Represents a request to retrieve the result of a completed task. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct GetTaskPayloadRequestParams { /// The task identifier to retrieve the result for. #[serde(rename = "taskId")] - pub id: String + pub id: String, } /// Represents a response to a task-augmented request. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct CreateTaskResult { @@ -109,33 +109,33 @@ pub struct CreateTaskResult { pub meta: Option, } -/// Represents a task. Tasks are durable state machines that carry information -/// about the underlying execution state of the request they wrap, and are intended for requestor -/// polling and deferred result retrieval. -/// +/// Represents a task. Tasks are durable state machines that carry information +/// about the underlying execution state of the request they wrap, and are intended for requestor +/// polling and deferred result retrieval. +/// /// Each task is uniquely identifiable by a receiver-generated **task ID**. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { /// The task identifier. #[serde(rename = "taskId")] pub id: String, - + /// ISO 8601 timestamp when the task was created. #[serde(rename = "createdAt")] pub created_at: DateTime, - + /// ISO 8601 timestamp when the task was last updated. #[serde(rename = "lastUpdatedAt")] pub last_updated_at: DateTime, - + /// Time To Live: Actual retention duration from creation in milliseconds, null for unlimited. pub ttl: usize, - + /// Current task state. pub status: TaskStatus, - + /// Optional human-readable message describing the current task state. /// This can provide context for any status, including /// - Reasons for `cancelled` status @@ -143,37 +143,37 @@ pub struct Task { /// - Diagnostic information for `failed` status (e.g., error details, what went wrong) #[serde(rename = "statusMessage", skip_serializing_if = "Option::is_none")] pub status_msg: Option, - + /// Suggested polling interval in milliseconds. #[serde(rename = "pollInterval", skip_serializing_if = "Option::is_none")] pub poll_interval: Option, } /// Represents the status of a task. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] pub enum TaskStatus { /// Task has been canceled. #[serde(rename = "cancelled")] Cancelled, - + /// Task has completed successfully. #[serde(rename = "completed")] Completed, - + /// Task has failed. #[serde(rename = "failed")] Failed, - + /// Task is currently running. #[default] #[serde(rename = "working")] Working, - + /// Task requires an input to proceed. #[serde(rename = "input_required")] - InputRequired + InputRequired, } /// Represents metadata for augmenting a request with a task execution. @@ -189,7 +189,7 @@ pub struct TaskMetadata { /// Represents metadata for associating messages with a task. /// Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct RelatedTaskMetadata { @@ -225,7 +225,7 @@ impl IntoResponse for Task { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -242,7 +242,7 @@ impl IntoResponse for CreateTaskResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -252,7 +252,7 @@ impl IntoResponse for ListTasksResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -262,7 +262,7 @@ impl From<[Task; N]> for ListTasksResult { fn from(tasks: [Task; N]) -> Self { Self { next_cursor: None, - tasks: tasks.to_vec() + tasks: tasks.to_vec(), } } } @@ -272,7 +272,7 @@ impl From> for ListTasksResult { fn from(tasks: Vec) -> Self { Self { next_cursor: None, - tasks + tasks, } } } @@ -282,7 +282,7 @@ impl From> for ListTasksResult { fn from(page: Page<'_, Task>) -> Self { Self { next_cursor: page.next_cursor, - tasks: page.items.to_vec() + tasks: page.items.to_vec(), } } } @@ -354,7 +354,7 @@ impl From for Task { ttl: meta.ttl.unwrap_or(DEFAULT_TTL), status: TaskStatus::Working, status_msg: None, - poll_interval: None + poll_interval: None, } } } @@ -385,10 +385,10 @@ impl Task { ttl: DEFAULT_TTL, status: TaskStatus::Working, status_msg: None, - poll_interval: None + poll_interval: None, } } - + /// Sets the status message of the task. pub fn set_message(&mut self, msg: impl Into) { self.status_msg = Some(msg.into()); @@ -437,7 +437,6 @@ impl TaskPayload { /// Unwraps the inner `T` #[inline] pub fn to(self) -> Result { - serde_json::from_value::(self.0) - .map_err(Error::from) + serde_json::from_value::(self.0).map_err(Error::from) } } diff --git a/neva/src/types/tool.rs b/neva/src/types/tool.rs index cfb5d5b..739c271 100644 --- a/neva/src/types/tool.rs +++ b/neva/src/types/tool.rs @@ -1,92 +1,74 @@ -//! Represents an MCP tool +//! Represents an MCP tool -use std::collections::HashMap; -use std::fmt::{Debug, Formatter}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use crate::shared; -use crate::types::{ - request::RequestParamsMeta, - PropertyType, - Cursor, - Icon -}; #[cfg(any(feature = "server", feature = "client"))] use crate::error::{Error, ErrorCode}; +use crate::shared; +use crate::types::{Cursor, Icon, PropertyType, request::RequestParamsMeta}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; #[cfg(feature = "server")] use { - std::{future::Future, sync::Arc}, - futures_util::future::BoxFuture, super::helpers::TypeCategory, crate::json::JsonSchema, + crate::types::{FromRequest, IntoResponse, Page, Request, RequestId, Response}, crate::{ Context, - app::handler::{ - FromHandlerParams, - Handler, - HandlerParams, - GenericHandler, - RequestHandler - } + app::handler::{FromHandlerParams, GenericHandler, Handler, HandlerParams, RequestHandler}, }, - crate::types::{ - FromRequest, - IntoResponse, - Page, - RequestId, - Request, - Response - } + futures_util::future::BoxFuture, + std::{future::Future, sync::Arc}, }; -#[cfg(feature = "tasks")] -use crate::types::TaskMetadata; #[cfg(all(feature = "server", feature = "tasks"))] use crate::types::RelatedTaskMetadata; +#[cfg(feature = "tasks")] +use crate::types::TaskMetadata; #[cfg(feature = "client")] use jsonschema::validator_for; pub use call_tool_response::CallToolResponse; +mod call_tool_response; #[cfg(feature = "server")] mod from_request; -mod call_tool_response; /// List of commands for Tools pub mod commands { /// Command name that returns a list of tools available on the server. pub const LIST: &str = "tools/list"; - + /// Name of a notification that indicates that the list of tools has changed. pub const LIST_CHANGED: &str = "notifications/tools/list_changed"; - + /// Command name that calls a tool on the server. pub const CALL: &str = "tools/call"; } /// Represents a tool that the server is capable of calling. Part of the [`ListToolsResult`]. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Clone, Serialize, Deserialize)] pub struct Tool { /// The name of the tool. pub name: String, - + /// Intended for UI and end-user contexts — optimized to be human-readable and easily understood, /// even by those unfamiliar with domain-specific terminology. - /// + /// /// If not provided, the name should be used for display (except for Tool, /// where `annotations.title` should be given precedence over using `name`, if present). #[serde(skip_serializing_if = "Option::is_none")] pub title: Option, - + /// A human-readable description of the tool. #[serde(rename = "description", skip_serializing_if = "Option::is_none")] pub descr: Option, - + /// A JSON Schema object defining the expected parameters for the tool. - /// + /// /// > Note: Needs to a valid JSON schema object that additionally is of a type object. #[serde(rename = "inputSchema")] pub input_schema: ToolSchema, @@ -99,7 +81,7 @@ pub struct Tool { pub output_schema: Option, /// Optional additional tool information. - /// + /// /// Display name precedence order is: title, annotations.title, then name. #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, @@ -115,12 +97,12 @@ pub struct Tool { /// - `image/webp` - WebP images (modern, efficient format) #[serde(skip_serializing_if = "Option::is_none")] pub icons: Option>, - + /// Execution-related properties for this tool. #[cfg(feature = "tasks")] #[serde(rename = "execution", skip_serializing_if = "Option::is_none")] pub exec: Option, - + /// Metadata reserved by MCP for protocol-level metadata. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option, @@ -134,7 +116,7 @@ pub struct Tool { #[serde(skip)] #[cfg(feature = "http-server")] pub(crate) permissions: Option>, - + /// A tool call handler #[serde(skip)] #[cfg(feature = "server")] @@ -149,7 +131,7 @@ pub struct ToolExecution { /// This allows clients to handle long-running operations through polling /// the task system. #[serde(rename = "taskSupport", skip_serializing_if = "Option::is_none")] - pub task_support: Option + pub task_support: Option, } /// Represents task-augmentation support options for a tool. @@ -166,16 +148,16 @@ pub enum TaskSupport { /// Tool does not support task-augmented execution. #[default] Forbidden, - + /// Tool may support task-augmented execution. Optional, - + /// Tool requires task-augmented execution. Required, } /// Sent from the client to request a list of tools the server has. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListToolsRequestParams { @@ -186,13 +168,13 @@ pub struct ListToolsRequestParams { } /// A response to a request to list the tools available on the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Default, Serialize, Deserialize)] pub struct ListToolsResult { /// The server's response to a tools/list request from the client. pub tools: Vec, - + /// An opaque token representing the pagination position after the last returned result. /// /// When a paginated result has more data available, the `next_cursor` @@ -204,13 +186,13 @@ pub struct ListToolsResult { } /// Used by the client to invoke a tool provided by the server. -/// +/// /// See the [schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/) for details #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CallToolRequestParams { /// Tool name. pub name: String, - + /// Optional arguments to pass to the tool. #[serde(rename = "arguments")] pub args: Option>, @@ -218,15 +200,15 @@ pub struct CallToolRequestParams { /// If specified, the caller is requesting task-augmented execution for this request. /// The request will return a [`CreateTaskResult`] immediately, and the actual result can be /// retrieved later via `tasks/result`. - /// + /// /// **Note:** Task augmentation is subject to capability negotiation - receivers **MUST** declare support /// for task augmentation of specific request types in their capabilities. #[cfg(feature = "tasks")] #[serde(skip_serializing_if = "Option::is_none")] pub task: Option, - + /// Metadata related to the request that provides additional protocol-level information. - /// + /// /// > **Note:** This can include progress tracking tokens and other protocol-specific properties /// > that are not part of the primary request parameters. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] @@ -237,11 +219,11 @@ pub struct CallToolRequestParams { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ToolSchema { /// Schema object type - /// + /// /// > Note: always "object" #[serde(rename = "type", default)] pub r#type: PropertyType, - + /// A list of properties for command #[serde(skip_serializing_if = "Option::is_none")] pub properties: Option>, @@ -264,7 +246,7 @@ pub struct SchemaProperty { } /// Additional properties describing a Tool to clients. -/// +/// /// > **Note:** All properties in ToolAnnotations are **hints**. /// > They are not guaranteed to provide a faithful description of /// > tool behavior (including descriptive properties like `title`). @@ -275,37 +257,37 @@ pub struct ToolAnnotations { /// A human-readable title for the tool. #[serde(skip_serializing_if = "Option::is_none")] pub title: Option, - + /// If `true`, the tool may perform destructive updates to its environment. /// If `false`, the tool performs only additive updates. - /// + /// /// **Note:** This property is meaningful only when `readonly == false` - /// + /// /// Default: `true` #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")] pub destructive: Option, /// If `true`, calling the tool repeatedly with the same arguments /// will have no additional effect on its environment. - /// + /// /// **Note:** This property is meaningful only when `readonly == false` - /// + /// /// Default: `false` #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")] pub idempotent: Option, /// If `true`, this tool may interact with an **"open world"** of external entities. /// If `false`, the tool's domain of interaction is closed. - /// + /// /// For example, the world of a web search tool is open, whereas that /// of a memory tool is not. - /// + /// /// Default: `true` #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")] pub open_world: Option, /// If `true`, the tool does not modify its environment. - /// + /// /// Default: `false` #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")] pub readonly: Option, @@ -317,7 +299,7 @@ impl IntoResponse for ListToolsResult { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -328,7 +310,7 @@ impl From> for ListToolsResult { fn from(tools: Vec) -> Self { Self { next_cursor: None, - tools + tools, } } } @@ -339,7 +321,7 @@ impl From> for ListToolsResult { fn from(page: Page<'_, Tool>) -> Self { Self { next_cursor: page.next_cursor, - tools: page.items.to_vec() + tools: page.items.to_vec(), } } } @@ -365,21 +347,19 @@ impl ListToolsResult { #[inline] pub fn get_by(&self, mut f: F) -> Option<&Tool> where - F: FnMut(&Tool) -> bool + F: FnMut(&Tool) -> bool, { - self.tools - .iter() - .find(|&t| f(t)) + self.tools.iter().find(|&t| f(t)) } } impl Default for ToolSchema { #[inline] fn default() -> Self { - Self { - r#type: PropertyType::Object, + Self { + r#type: PropertyType::Object, properties: Some(HashMap::new()), - required: None + required: None, } } } @@ -401,11 +381,11 @@ impl Default for ToolAnnotations { impl From<&str> for TaskSupport { #[inline] fn from(value: &str) -> Self { - match value { + match value { "forbidden" => Self::Forbidden, "required" => Self::Required, "optional" => Self::Optional, - _ => unreachable!() + _ => unreachable!(), } } } @@ -423,44 +403,47 @@ impl ToolSchema { /// Creates a new [`ToolSchema`] object #[inline] pub(crate) fn new(props: Option>) -> Self { - Self { r#type: PropertyType::Object, properties: props, required: None } + Self { + r#type: PropertyType::Object, + properties: props, + required: None, + } } - + /// Deserializes a new [`ToolSchema`] from a JSON string #[inline] pub fn from_json_str(json: &str) -> Self { - serde_json::from_str(json) - .expect("InputSchema: Incorrect JSON string provided") + serde_json::from_str(json).expect("InputSchema: Incorrect JSON string provided") } - - /// Adds a new property into the schema. + + /// Adds a new property into the schema. /// If a property with this name already exists, it overwrites it pub fn with_prop>( - self, - name: &str, - descr: &str, - property_type: T + self, + name: &str, + descr: &str, + property_type: T, ) -> Self { self.add_property_impl(name, descr, property_type.into()) } - /// Adds a new required property into the schema. + /// Adds a new required property into the schema. /// If a property with this name already exists, it overwrites it pub fn with_required>( self, name: &str, descr: &str, - property_type: T + property_type: T, ) -> Self { self.add_required_property_impl(name, descr, property_type.into()) } - + /// Creates a new [`ToolSchema`] from a [`JsonSchema`] object pub fn with_schema(self) -> Self { let json_schema = schemars::schema_for!(T); self.with_schema_impl(json_schema) } - + /// Creates a new [`ToolSchema`] from a [`schemars::Schema`] pub fn from_schema(json_schema: schemars::Schema) -> Self { Self::default().with_schema_impl(json_schema) @@ -468,19 +451,13 @@ impl ToolSchema { #[inline] fn with_schema_impl(mut self, json_schema: schemars::Schema) -> Self { - let required = json_schema - .get("required") - .and_then(|v| v.as_array()); - if let Some(props) = json_schema - .get("properties") - .and_then(|v| v.as_object()) { + let required = json_schema.get("required").and_then(|v| v.as_array()); + if let Some(props) = json_schema.get("properties").and_then(|v| v.as_object()) { for (field, def) in props { let req = required .map(|arr| !arr.iter().any(|v| v == field)) .unwrap_or(true); - let type_str = def.get("type") - .and_then(|v| v.as_str()) - .unwrap_or("string"); + let type_str = def.get("type").and_then(|v| v.as_str()).unwrap_or("string"); self = if req { self.add_required_property_impl(field, field, type_str.into()) } else { @@ -492,18 +469,14 @@ impl ToolSchema { } #[inline] - fn add_property_impl( - mut self, - name: &str, - descr: &str, - property_type: PropertyType - ) -> Self { - self.properties - .get_or_insert_with(HashMap::new) - .insert(name.into(), SchemaProperty { + fn add_property_impl(mut self, name: &str, descr: &str, property_type: PropertyType) -> Self { + self.properties.get_or_insert_with(HashMap::new).insert( + name.into(), + SchemaProperty { r#type: property_type, - descr: Some(descr.into()) - }); + descr: Some(descr.into()), + }, + ); self } @@ -512,12 +485,10 @@ impl ToolSchema { mut self, name: &str, descr: &str, - property_type: PropertyType + property_type: PropertyType, ) -> Self { self = self.add_property_impl(name, descr, property_type); - self.required - .get_or_insert_with(Vec::new) - .push(name.into()); + self.required.get_or_insert_with(Vec::new).push(name.into()); self } } @@ -527,9 +498,9 @@ impl SchemaProperty { /// Creates a new [`SchemaProperty`] for a `T` #[inline] pub(crate) fn new() -> Self { - Self { + Self { r#type: T::category(), - descr: None + descr: None, } } } @@ -567,22 +538,25 @@ pub(crate) struct ToolFunc where F: ToolHandler, R: Into, - Args: TryFrom + Args: TryFrom, { func: F, _marker: std::marker::PhantomData, } #[cfg(feature = "server")] -impl ToolFunc +impl ToolFunc where F: ToolHandler, R: Into, - Args: TryFrom + Args: TryFrom, { /// Creates a new [`ToolFunc`] wrapped into [`Arc`] pub(crate) fn new(func: F) -> Arc { - let func = Self { func, _marker: std::marker::PhantomData }; + let func = Self { + func, + _marker: std::marker::PhantomData, + }; Arc::new(func) } } @@ -596,15 +570,12 @@ where { #[inline] fn call(&self, params: HandlerParams) -> BoxFuture<'_, Result> { - let HandlerParams::Tool(params) = params else { + let HandlerParams::Tool(params) = params else { unreachable!() }; Box::pin(async move { let args = Args::try_from(params)?; - Ok(self.func - .call(args) - .await - .into()) + Ok(self.func.call(args).await.into()) }) } } @@ -617,7 +588,7 @@ impl CallToolRequestParams { args: None, meta: None, #[cfg(feature = "tasks")] - task: None + task: None, } } @@ -626,13 +597,13 @@ impl CallToolRequestParams { self.args = args.into_args(); self } - + /// Sets the metadata for the request pub fn with_meta(mut self, meta: RequestParamsMeta) -> Self { self.meta = Some(meta); self } - + /// Sets the TTL for the [`CallToolRequestParams`], /// which will be used if the tool is support tasks. #[cfg(feature = "tasks")] @@ -653,9 +624,7 @@ impl CallToolRequestParams { /// Associates [`CallToolRequestParams`] with the appropriated task #[cfg(feature = "tasks")] pub(crate) fn with_task(mut self, task_id: impl Into) -> Self { - self.meta.get_or_insert_default().task = Some(RelatedTaskMetadata { - id: task_id.into() - }); + self.meta.get_or_insert_default().task = Some(RelatedTaskMetadata { id: task_id.into() }); self } } @@ -678,7 +647,7 @@ impl Debug for Tool { #[cfg(feature = "server")] impl Tool { /// Initializes a new [`Tool`] - pub fn new(name: impl Into, handler: F) -> Self + pub fn new(name: impl Into, handler: F) -> Self where F: ToolHandler, R: Into + Send + 'static, @@ -701,76 +670,70 @@ impl Tool { #[cfg(feature = "http-server")] permissions: None, #[cfg(feature = "tasks")] - exec: None + exec: None, } } - + /// Sets a title for a tool pub fn with_title(&mut self, title: impl Into) -> &mut Self { self.title = Some(title.into()); self } - + /// Sets a description for a tool pub fn with_description(&mut self, description: &str) -> &mut Self { self.descr = Some(description.into()); self } - - /// Sets an input schema for the tool. - /// + + /// Sets an input schema for the tool. + /// /// > **Note:** Automatically generated schema will be overwritten pub fn with_input_schema(&mut self, config: F) -> &mut Self - where - F: FnOnce(ToolSchema) -> ToolSchema + where + F: FnOnce(ToolSchema) -> ToolSchema, { self.input_schema = config(Default::default()); self } - /// Sets an output schema for the tool. + /// Sets an output schema for the tool. /// /// > **Note:** Automatically generated schema will be overwritten pub fn with_output_schema(&mut self, config: F) -> &mut Self where - F: FnOnce(ToolSchema) -> ToolSchema + F: FnOnce(ToolSchema) -> ToolSchema, { self.output_schema = Some(config(Default::default())); self } - + /// Sets a list of roles that are allowed to invoke the tool #[cfg(feature = "http-server")] pub fn with_roles(&mut self, roles: T) -> &mut Self - where + where T: IntoIterator, - I: Into + I: Into, { - self.roles = Some(roles - .into_iter() - .map(Into::into) - .collect()); + self.roles = Some(roles.into_iter().map(Into::into).collect()); self } - + /// Sets a list of permissions that are allowed to invoke the tool #[cfg(feature = "http-server")] pub fn with_permissions(&mut self, permissions: T) -> &mut Self where T: IntoIterator, - I: Into + I: Into, { - self.permissions = Some(permissions - .into_iter() - .map(Into::into) - .collect()); + self.permissions = Some(permissions.into_iter().map(Into::into).collect()); self } - + /// Configures the annotations for the tool pub fn with_annotations(&mut self, config: F) -> &mut Self where - F: FnOnce(ToolAnnotations) -> ToolAnnotations + F: FnOnce(ToolAnnotations) -> ToolAnnotations, { self.annotations = Some(config(Default::default())); self @@ -788,13 +751,16 @@ impl Tool { self.exec = Some(ToolExecution::new(support.into())); self } - + /// Invoke a tool #[inline] pub(crate) async fn call(&self, params: HandlerParams) -> Result { - match self.handler { + match self.handler { Some(ref handler) => handler.call(params).await, - None => Err(Error::new(ErrorCode::InternalError, "Tool handler not specified")) + None => Err(Error::new( + ErrorCode::InternalError, + "Tool handler not specified", + )), } } } @@ -803,15 +769,19 @@ impl Tool { impl Tool { /// Validates [`CallToolResponse`] against this tool output schema pub fn validate<'a>(&self, resp: &'a CallToolResponse) -> Result<&'a CallToolResponse, Error> { - let schema = self.output_schema - .as_ref() - .map_or_else( - || Err(Error::new(ErrorCode::ParseError, "Tool: Output schema not specified")), - |s| serde_json::to_value(s.clone()).map_err(Into::into))?; - - let validator = validator_for(&schema) - .map_err(|err| Error::new(ErrorCode::ParseError, err))?; - + let schema = self.output_schema.as_ref().map_or_else( + || { + Err(Error::new( + ErrorCode::ParseError, + "Tool: Output schema not specified", + )) + }, + |s| serde_json::to_value(s.clone()).map_err(Into::into), + )?; + + let validator = + validator_for(&schema).map_err(|err| Error::new(ErrorCode::ParseError, err))?; + let content = resp.struct_content()?; validator .validate(content) @@ -825,9 +795,7 @@ impl Tool { /// Returns a task support for the tool if specified. #[inline] pub fn task_support(&self) -> Option { - self.exec - .as_ref() - .and_then(|e| e.task_support) + self.exec.as_ref().and_then(|e| e.task_support) } } @@ -842,19 +810,18 @@ impl ToolAnnotations { /// Deserializes a new [`ToolAnnotations`] from a JSON string #[inline] pub fn from_json_str(json: &str) -> Self { - serde_json::from_str(json) - .expect("ToolAnnotations: Incorrect JSON string provided") + serde_json::from_str(json).expect("ToolAnnotations: Incorrect JSON string provided") } - + /// Sets a title for the tool. #[inline] pub fn with_title(mut self, title: &str) -> Self { self.title = Some(title.into()); self } - + /// Sets/Unsets a hint that the tool may perform destructive updates to its environment. - /// + /// /// Also sets the readonly hint to `false` #[inline] pub fn with_destructive(mut self, destructive: bool) -> Self { @@ -863,17 +830,17 @@ impl ToolAnnotations { self } - /// Sets/Unsets a hint that the tool is idempotent. - /// So calling it repeatedly when it's `true` with the same arguments + /// Sets/Unsets a hint that the tool is idempotent. + /// So calling it repeatedly when it's `true` with the same arguments /// will have no additional effect on its environment. - /// + /// /// Also sets the readonly hint to `false` pub fn with_idempotent(mut self, idempotent: bool) -> Self { self.idempotent = Some(idempotent); self.readonly = Some(false); self } - + /// Sets/Unsets the hint that the tool may interact with an **"open world"** of external entities. #[inline] pub fn with_open_world(mut self, open_world: bool) -> Self { @@ -887,7 +854,9 @@ impl ToolExecution { /// Creates a new [`ToolExecution`] with a task support #[inline] pub fn new(support: TaskSupport) -> Self { - Self { task_support: Some(support) } + Self { + task_support: Some(support), + } } } @@ -913,7 +882,7 @@ macro_rules! impl_generic_tool_handler ({ $($param:ident)* } => { } }; )* - if args.len() == 0 { + if args.len() == 0 { None } else { Some(args) @@ -933,11 +902,11 @@ impl_generic_tool_handler! { T1 T2 T3 T4 T5 } #[cfg(feature = "server")] mod tests { use super::*; - + #[tokio::test] async fn it_creates_and_calls_tool() { let tool = Tool::new("sum", |a: i32, b: i32| async move { a + b }); - + let params = CallToolRequestParams { name: "sum".into(), meta: None, @@ -948,13 +917,16 @@ mod tests { ("b".into(), serde_json::to_value(2).unwrap()), ])), }; - + let resp = tool.call(params.into()).await.unwrap(); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"7"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"7"}],"isError":false}"# + ); } - + #[test] fn it_deserializes_input_schema() { let json = r#"{ @@ -965,10 +937,10 @@ mod tests { } } }"#; - + let schema: ToolSchema = serde_json::from_str(json).unwrap(); - + assert_eq!(schema.r#type, PropertyType::Object); assert!(schema.properties.is_some()); } -} \ No newline at end of file +} diff --git a/neva/src/types/tool/call_tool_response.rs b/neva/src/types/tool/call_tool_response.rs index 2984071..54c351a 100644 --- a/neva/src/types/tool/call_tool_response.rs +++ b/neva/src/types/tool/call_tool_response.rs @@ -1,17 +1,14 @@ -//! Types and util for handling tool results +//! Types and util for handling tool results -use crate::types::{Content, IntoResponse, RequestId, Response}; -use serde::{Serialize, Deserialize}; -use serde_json::Value; -#[cfg(feature = "server")] -use crate::types::Json; #[cfg(any(feature = "server", feature = "client"))] use crate::error::Error; +#[cfg(feature = "server")] +use crate::types::Json; +use crate::types::{Content, IntoResponse, RequestId, Response}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; #[cfg(feature = "client")] -use { - crate::error::ErrorCode, - serde::de::DeserializeOwned -}; +use {crate::error::ErrorCode, serde::de::DeserializeOwned}; #[cfg(feature = "client")] const MISSING_STRUCTURED_CONTENT: &str = "Tool: Missing structured content"; @@ -32,7 +29,7 @@ const MISSING_STRUCTURED_CONTENT: &str = "Tool: Missing structured content"; pub struct CallToolResponse { /// The server's response to a tools/call request from the client. pub content: Vec, - + /// An optional JSON object that represents the structured result of the tool call. #[serde(rename = "structuredContent", skip_serializing_if = "Option::is_none")] pub struct_content: Option, @@ -47,7 +44,7 @@ impl IntoResponse for CallToolResponse { fn into_response(self, req_id: RequestId) -> Response { match serde_json::to_value(self) { Ok(v) => Response::success(req_id, v), - Err(err) => Response::error(req_id, err.into()) + Err(err) => Response::error(req_id, err.into()), } } } @@ -68,7 +65,7 @@ where { #[inline] fn from(value: Result) -> Self { - match value { + match value { Ok(value) => value.into(), Err(error) => error.into().into(), } @@ -146,7 +143,7 @@ impl CallToolResponse { /// Creates a single response #[inline] pub fn new(text: impl Into) -> Self { - Self { + Self { content: vec![text.into()], struct_content: None, is_error: false, @@ -158,15 +155,16 @@ impl CallToolResponse { pub fn array(texts: T) -> Self where T: IntoIterator, - I: Into + I: Into, { - let content = texts - .into_iter() - .map(Into::into) - .collect(); - Self { content, struct_content: None, is_error: false } + let content = texts.into_iter().map(Into::into).collect(); + Self { + content, + struct_content: None, + is_error: false, + } } - + /// Creates a single structured JSON response #[inline] pub fn json(data: T) -> Self { @@ -185,18 +183,19 @@ impl CallToolResponse { pub fn array_json(data: T) -> Self where T: IntoIterator, - I: Serialize + I: Serialize, { let vec = data.into_iter().collect::>(); - match serde_json::to_value(&vec) { + match serde_json::to_value(&vec) { Err(err) => Self::error(err.into()), Ok(structure) => Self { struct_content: Some(structure), is_error: false, - content: vec.into_iter() + content: vec + .into_iter() .map(|item| Content::json(&item)) .collect::>(), - } + }, } } @@ -221,24 +220,26 @@ impl CallToolResponse { } /// Creates a structure for existing text content. - /// + /// /// **Note:** If the content type is not a text, this won't get any effect. #[inline] pub fn with_structure(mut self) -> Self { let item = &self.content[0]; if self.content.len() == 1 { if let Content::Text(text) = item { - match serde_json::from_str(&text.text) { + match serde_json::from_str(&text.text) { Ok(structure) => self.struct_content = Some(structure), Err(err) => return Self::error(err.into()), } } } else if let Content::Text(_) = item { - let data = self.content + let data = self + .content .iter() - .filter_map(|item| item - .as_text() - .and_then(|c| serde_json::from_str(&c.text).ok())) + .filter_map(|item| { + item.as_text() + .and_then(|c| serde_json::from_str(&c.text).ok()) + }) .collect::>(); match serde_json::to_value(&data) { Ok(structure) => self.struct_content = Some(structure), @@ -256,7 +257,7 @@ impl CallToolResponse { self.struct_content() .and_then(|c| serde_json::from_value(c.clone()).map_err(Into::into)) } - + /// Returns a reference to a [`Value`] of structured content pub(crate) fn struct_content(&self) -> Result<&Value, Error> { self.struct_content @@ -271,14 +272,17 @@ mod tests { use crate::error::{Error, ErrorCode}; use super::*; - + #[test] fn it_converts_from_str() { let resp: CallToolResponse = "test".into(); - + let json = serde_json::to_string(&resp).unwrap(); - - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":false}"#); + + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":false}"# + ); } #[test] @@ -287,7 +291,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":false}"# + ); } #[test] @@ -296,16 +303,23 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":true}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":true}"# + ); } #[test] fn it_converts_from_err_result() { - let resp: CallToolResponse = Err::(Error::new(ErrorCode::InternalError, "test")).into(); + let resp: CallToolResponse = + Err::(Error::new(ErrorCode::InternalError, "test")).into(); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":true}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":true}"# + ); } #[test] @@ -314,7 +328,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":false}"# + ); } #[test] @@ -323,7 +340,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test"}],"isError":false}"# + ); } #[test] @@ -341,7 +361,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"test 1"},{"type":"text","text":"test 2"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"test 1"},{"type":"text","text":"test 2"}],"isError":false}"# + ); } #[test] @@ -360,7 +383,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"isError":false}"# + ); } #[test] @@ -370,7 +396,10 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"# + ); } #[test] @@ -380,55 +409,70 @@ mod tests { let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"# + ); } #[test] fn it_creates_with_array_of_structured_content() { let resp = CallToolResponse::array_json([ - Test { msg: "test 1".into() }, - Test { msg: "test 2".into() } + Test { + msg: "test 1".into(), + }, + Test { + msg: "test 2".into(), + }, ]); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test 1\"}"},{"type":"text","text":"{\"msg\":\"test 2\"}"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test 1\"}"},{"type":"text","text":"{\"msg\":\"test 2\"}"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"# + ); } #[test] fn it_adds_structured_content() { - let resp = CallToolResponse::new(r#"{"msg":"test"}"#) - .with_structure(); + let resp = CallToolResponse::new(r#"{"msg":"test"}"#).with_structure(); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test\"}"}],"structuredContent":{"msg":"test"},"isError":false}"# + ); } #[test] fn it_adds_structured_content_for_string_array() { - let resp = CallToolResponse::new(r#"[{"msg":"test 1"},{"msg":"test 2"}]"#) - .with_structure(); + let resp = CallToolResponse::new(r#"[{"msg":"test 1"},{"msg":"test 2"}]"#).with_structure(); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"[{\"msg\":\"test 1\"},{\"msg\":\"test 2\"}]"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"[{\"msg\":\"test 1\"},{\"msg\":\"test 2\"}]"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"# + ); } #[test] fn it_adds_structured_content_for_array() { - let resp = CallToolResponse::array([ - r#"{"msg":"test 1"}"#, - r#"{"msg":"test 2"}"# - ]).with_structure(); + let resp = CallToolResponse::array([r#"{"msg":"test 1"}"#, r#"{"msg":"test 2"}"#]) + .with_structure(); let json = serde_json::to_string(&resp).unwrap(); - assert_eq!(json, r#"{"content":[{"type":"text","text":"{\"msg\":\"test 1\"}"},{"type":"text","text":"{\"msg\":\"test 2\"}"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"#); + assert_eq!( + json, + r#"{"content":[{"type":"text","text":"{\"msg\":\"test 1\"}"},{"type":"text","text":"{\"msg\":\"test 2\"}"}],"structuredContent":[{"msg":"test 1"},{"msg":"test 2"}],"isError":false}"# + ); } - + #[derive(Serialize)] struct Test { - msg: String + msg: String, } -} \ No newline at end of file +} diff --git a/neva/src/types/tool/from_request.rs b/neva/src/types/tool/from_request.rs index ce99ea5..7bd1859 100644 --- a/neva/src/types/tool/from_request.rs +++ b/neva/src/types/tool/from_request.rs @@ -1,5 +1,5 @@ -use crate::error::Error; use super::CallToolRequestParams; +use crate::error::Error; use crate::types::helpers::extract::{RequestArgument, extract_arg}; impl TryFrom for () { @@ -15,7 +15,7 @@ macro_rules! impl_from_call_tool_params { ($($T: ident),*) => { impl<$($T: RequestArgument),+> TryFrom for ($($T,)+) { type Error = Error; - + #[inline] fn try_from(params: CallToolRequestParams) -> Result { let args = params.args.unwrap_or_default(); @@ -39,25 +39,23 @@ impl_from_call_tool_params! { T1, T2, T3, T4, T5 } #[cfg(test)] mod tests { - use std::collections::HashMap; - use serde_json::{json, Value}; - use crate::types::{Meta, ProgressToken, request::RequestParamsMeta}; use super::*; - + use crate::types::{Meta, ProgressToken, request::RequestParamsMeta}; + use serde_json::{Value, json}; + use std::collections::HashMap; + #[test] fn it_extracts_args() { let params = CallToolRequestParams { - args: Some(HashMap::from([ - ("arg".into(), json!({ "test": 1 })) - ])), + args: Some(HashMap::from([("arg".into(), json!({ "test": 1 }))])), meta: None, name: "tool".into(), #[cfg(feature = "tasks")] - task: None + task: None, }; - + let arg: (Value,) = params.try_into().unwrap(); - + assert_eq!(arg.0, json!({ "test": 1 })); } @@ -65,9 +63,7 @@ mod tests { #[allow(clippy::useless_conversion)] fn it_extracts_params() { let params = CallToolRequestParams { - args: Some(HashMap::from([ - ("arg".into(), json!(22)) - ])), + args: Some(HashMap::from([("arg".into(), json!(22))])), meta: None, name: "tool".into(), #[cfg(feature = "tasks")] @@ -77,9 +73,7 @@ mod tests { let arg: CallToolRequestParams = params.try_into().unwrap(); assert_eq!(arg.name, "tool"); - assert_eq!(arg.args, Some(HashMap::from([ - ("arg".into(), json!(22)) - ]))); + assert_eq!(arg.args, Some(HashMap::from([("arg".into(), json!(22))]))); } #[test] @@ -90,7 +84,7 @@ mod tests { progress_token: None, context: None, #[cfg(feature = "tasks")] - task: None + task: None, }), args: None, #[cfg(feature = "tasks")] @@ -110,7 +104,7 @@ mod tests { progress_token: Some(ProgressToken::Number(5)), context: None, #[cfg(feature = "tasks")] - task: None + task: None, }), args: None, #[cfg(feature = "tasks")] @@ -121,4 +115,4 @@ mod tests { assert_eq!(arg.0.0, ProgressToken::Number(5)); } -} \ No newline at end of file +} diff --git a/neva_macros/src/client.rs b/neva_macros/src/client.rs index d138eac..849207b 100644 --- a/neva_macros/src/client.rs +++ b/neva_macros/src/client.rs @@ -1,8 +1,8 @@ //! Macros for MCP clients -use syn::ItemFn; use proc_macro2::TokenStream; use quote::quote; +use syn::ItemFn; pub(super) fn expand_elicitation(function: &ItemFn) -> syn::Result { let func_name = &function.sig.ident; @@ -40,4 +40,4 @@ pub(super) fn expand_sampling(function: &ItemFn) -> syn::Result { }; Ok(expanded) -} \ No newline at end of file +} diff --git a/neva_macros/src/lib.rs b/neva_macros/src/lib.rs index 155cc49..91fc32b 100644 --- a/neva_macros/src/lib.rs +++ b/neva_macros/src/lib.rs @@ -1,16 +1,16 @@ //! A proc macro implementation for configuring tool -use syn::{parse_macro_input, punctuated::Punctuated, Token}; use proc_macro::TokenStream; +use syn::{Token, parse_macro_input, punctuated::Punctuated}; -#[cfg(feature = "server")] -mod server; #[cfg(feature = "client")] mod client; +#[cfg(feature = "server")] +mod server; mod shared; /// Maps the function to a tool -/// +/// /// # Parameters /// * `title` - Tool title. /// * `descr` - Tool description. @@ -21,32 +21,32 @@ mod shared; /// * `middleware` - Middleware list to apply to the tool. /// * `task_support` - Specifies task augmentation support for this tool. /// * `no_schema` - Explicitly disables input schema generation if it's not set in `input_schema`. -/// +/// /// # Simple Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[tool(descr = "Hello world tool")] /// async fn say_hello() -> &'static str { /// "Hello, world!" /// } /// ``` -/// +/// /// # Full Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[derive(serde::Deserialize)] /// struct Payload { /// say: String, /// name: String, /// } -/// +/// /// #[json_schema(ser)] /// struct Results { /// message: String, /// } -/// +/// /// #[tool( /// title = "JSON Hello", /// descr = "Say from JSON", @@ -60,14 +60,14 @@ mod shared; /// }"#, /// input_schema = r#"{ /// "properties": { -/// "arg": { -/// "type": "object", -/// "description": "A message in JSON format", +/// "arg": { +/// "type": "object", +/// "description": "A message in JSON format", /// "properties": { /// "say": { "type": "string", "description": "A message to say" }, /// "name": { "type": "string", "description": "A name to whom say Hello" } /// }, -/// "required": ["say", "name"] +/// "required": ["say", "name"] /// } /// }, /// "required": ["arg"] @@ -105,11 +105,11 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream { /// * `mime` - Resource MIME type. /// * `annotations` - Resource content arbitrary [metadata](https://docs.rs/neva/latest/neva/types/struct.Annotations.html). /// * `roles` & `permissions` - Define which users can read the resource when using Streamable HTTP transport with OAuth. -/// +/// /// # Simple Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[resource(uri = "res://{name}"] /// async fn get_res(name: String) -> TextResourceContents { /// TextResourceContents::new( @@ -121,7 +121,7 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream { /// # Full Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[resource( /// uri = "res://{name}", /// title = "Read resource", @@ -174,11 +174,11 @@ pub fn resources(attr: TokenStream, item: TokenStream) -> TokenStream { /// * `no_args` - Explicitly disables argument generation if it's not set in `args`. /// * `middleware` - Middleware list to apply to the prompt. /// * `roles` & `permissions` - Define which users can read the resource when using Streamable HTTP transport with OAuth. -/// +/// /// # Simple Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[prompt(descr = "Analyze code for potential improvements"] /// async fn analyze_code(lang: String) -> PromptMessage { /// PromptMessage::user() @@ -197,8 +197,8 @@ pub fn resources(attr: TokenStream, item: TokenStream) -> TokenStream { /// permissions = ["read"], /// args = r#"[ /// { -/// "name": "lang", -/// "description": "A language to use", +/// "name": "lang", +/// "description": "A language to use", /// "required": true /// } /// ]"# @@ -225,11 +225,11 @@ pub fn prompt(attr: TokenStream, item: TokenStream) -> TokenStream { /// # Parameters /// * `command` - Command name. /// * `middleware` - Middleware list to apply to the command. -/// +/// /// # Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[handler(command = "ping")] /// async fn ping_handler() { /// println!("pong"); @@ -261,18 +261,18 @@ pub fn completion(attr: TokenStream, item: TokenStream) -> TokenStream { } /// Maps the elicitation handler function -/// +/// /// # Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[json_schema(ser)] /// struct Contact { /// name: String, /// email: String, /// age: u32, /// } -/// +/// /// #[elicitation] /// async fn elicitation_handler(params: ElicitRequestParams) -> impl Into { /// let contact = Contact { @@ -294,11 +294,11 @@ pub fn elicitation(_: TokenStream, item: TokenStream) -> TokenStream { } /// Maps the sampling handler function -/// +/// /// # Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[sampling] /// async fn sampling_handler(params: CreateMessageRequestParams) -> CreateMessageResult { /// CreateMessageResult::assistant() @@ -322,11 +322,11 @@ pub fn sampling(_: TokenStream, item: TokenStream) -> TokenStream { /// * `serde` - Applies also `derive(serde::Serialize, serde::Deserialize)`. /// * `ser` - Applies also `derive(serde::Serialize)`. /// * `de` - Applies also `derive(serde::Deserialize)`. -/// +/// /// # Example /// ```ignore /// use neva::prelude::*; -/// +/// /// #[json_schema(ser)] /// struct Results { /// message: String, diff --git a/neva_macros/src/server.rs b/neva_macros/src/server.rs index c7b9040..10d6fbe 100644 --- a/neva_macros/src/server.rs +++ b/neva_macros/src/server.rs @@ -1,23 +1,26 @@ //! Macros for MCP servers -use syn::{Expr, Lit, Type}; -use syn::{ItemFn, Meta, punctuated::Punctuated, token::Comma}; use proc_macro2::TokenStream; use quote::quote; +use syn::{Expr, Lit, Type}; +use syn::{ItemFn, Meta, punctuated::Punctuated, token::Comma}; -pub(crate) mod tool; -pub(crate) mod resource; pub(super) mod prompt; +pub(crate) mod resource; +pub(crate) mod tool; -pub(super) fn expand_handler(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(super) fn expand_handler( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut command = None; let mut middleware = None; for meta in attr { match &meta { - Meta::Path(_) => {}, - Meta::List(_) => {}, + Meta::Path(_) => {} + Meta::List(_) => {} Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() { match ident.to_string().as_str() { @@ -28,9 +31,9 @@ pub(super) fn expand_handler(attr: &Punctuated, function: &ItemFn) middleware = get_exprs_arr(&nv.value); } _ => {} - } + } } - }, + } } } @@ -42,7 +45,7 @@ pub(super) fn expand_handler(attr: &Punctuated, function: &ItemFn) }); quote! { #(#mw_calls)* } }); - + // Expand the function and apply the tool functionality let expanded = quote! { // Original function @@ -61,23 +64,27 @@ pub(super) fn expand_handler(attr: &Punctuated, function: &ItemFn) Ok(expanded) } -pub(super) fn expand_completion(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(super) fn expand_completion( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut middleware = None; for meta in attr { match &meta { - Meta::Path(_) => {}, - Meta::List(_) => {}, + Meta::Path(_) => {} + Meta::List(_) => {} Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() - && let "middleware" = ident.to_string().as_str() { + && let "middleware" = ident.to_string().as_str() + { middleware = get_exprs_arr(&nv.value); } - }, + } } } - + let module_name = syn::Ident::new(&format!("map_{func_name}"), func_name.span()); let middleware_code = middleware.map(|mws| { let mw_calls = mws.iter().map(|mw| { @@ -111,11 +118,7 @@ pub(super) fn get_arg_type(t: &Type) -> &str { Type::Slice(_) => "slice", Type::Reference(_) => "none", Type::Path(type_path) => { - let type_ident = type_path.path.segments - .last() - .unwrap() - .ident - .to_string(); + let type_ident = type_path.path.segments.last().unwrap().ident.to_string(); match type_ident.as_str() { "String" => "string", "str" => "string", @@ -140,24 +143,29 @@ pub(super) fn get_arg_type(t: &Type) -> &str { #[inline] pub(super) fn get_inner_type_from_generic(ty: &Type) -> Option<&Type> { if let Type::Path(type_path) = ty - && let Some(segment) = type_path.path.segments.last() { + && let Some(segment) = type_path.path.segments.last() + { match segment.ident.to_string().as_str() { "Result" | "Option" | "Vec" | "Meta" | "Json" => { if let syn::PathArguments::AngleBracketed(args) = &segment.arguments - && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() + { return Some(inner_ty); } } - _ => {} + _ => {} } } None } - #[inline] pub(super) fn get_str_param(value: &Expr) -> Option { - if let Expr::Lit(syn::ExprLit { lit: Lit::Str(lit_str), .. }) = value { + if let Expr::Lit(syn::ExprLit { + lit: Lit::Str(lit_str), + .. + }) = value + { Some(lit_str.value()) } else { None @@ -166,7 +174,11 @@ pub(super) fn get_str_param(value: &Expr) -> Option { #[inline] pub(super) fn get_bool_param(value: &Expr) -> bool { - if let Expr::Lit(syn::ExprLit { lit: Lit::Bool(lit), .. }) = value { + if let Expr::Lit(syn::ExprLit { + lit: Lit::Bool(lit), + .. + }) = value + { lit.value } else { false @@ -176,13 +188,18 @@ pub(super) fn get_bool_param(value: &Expr) -> bool { #[inline] pub(super) fn get_params_arr(value: &Expr) -> Option> { match value { - Expr::Lit(syn::ExprLit { lit: Lit::Str(lit_str), .. }) => { - Some(vec![lit_str.value()]) - } + Expr::Lit(syn::ExprLit { + lit: Lit::Str(lit_str), + .. + }) => Some(vec![lit_str.value()]), Expr::Array(array) => { let mut role_list = Vec::new(); for elem in &array.elems { - if let Expr::Lit(syn::ExprLit { lit: Lit::Str(lit_str), .. }) = elem { + if let Expr::Lit(syn::ExprLit { + lit: Lit::Str(lit_str), + .. + }) = elem + { role_list.push(lit_str.value()); } } @@ -192,7 +209,7 @@ pub(super) fn get_params_arr(value: &Expr) -> Option> { None } } - _ => None + _ => None, } } @@ -204,14 +221,8 @@ pub(super) fn get_exprs_arr(value: &Expr) -> Option> { for elem in &array.elems { exprs.push(elem.clone()); } - if !exprs.is_empty() { - Some(exprs) - } else { - None - } - } - expr => { - Some(vec![expr.clone()]) + if !exprs.is_empty() { Some(exprs) } else { None } } + expr => Some(vec![expr.clone()]), } } diff --git a/neva_macros/src/server/prompt.rs b/neva_macros/src/server/prompt.rs index 62b7549..99353a9 100644 --- a/neva_macros/src/server/prompt.rs +++ b/neva_macros/src/server/prompt.rs @@ -1,11 +1,14 @@ //! Macros for MCP prompts -use syn::{ItemFn, FnArg, Pat, Meta, punctuated::Punctuated, token::Comma}; -use super::{get_str_param, get_params_arr, get_exprs_arr, get_bool_param, get_arg_type}; +use super::{get_arg_type, get_bool_param, get_exprs_arr, get_params_arr, get_str_param}; use proc_macro2::TokenStream; use quote::quote; +use syn::{FnArg, ItemFn, Meta, Pat, punctuated::Punctuated, token::Comma}; -pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(crate) fn expand( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut description = None; let mut args = None; @@ -21,7 +24,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: if path.is_ident("no_args") { no_args = true; } - }, + } Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() { match ident.to_string().as_str() { @@ -49,7 +52,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: _ => {} } } - }, + } Meta::List(_) => {} } } @@ -70,7 +73,8 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: let mut arg_entries = Vec::new(); for arg in &function.sig.inputs { if let FnArg::Typed(pat_type) = arg - && let Pat::Ident(pat_ident) = &*pat_type.pat { + && let Pat::Ident(pat_ident) = &*pat_type.pat + { let arg_name = pat_ident.ident.to_string(); let arg_type = get_arg_type(&pat_type.ty); if !arg_type.eq("none") { @@ -112,7 +116,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: let expanded = quote! { // Original function #function - + fn #module_name(app: &mut App) { app #middleware_code @@ -129,4 +133,4 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: }; Ok(expanded) -} \ No newline at end of file +} diff --git a/neva_macros/src/server/resource.rs b/neva_macros/src/server/resource.rs index dd8cfa9..e014040 100644 --- a/neva_macros/src/server/resource.rs +++ b/neva_macros/src/server/resource.rs @@ -1,11 +1,14 @@ //! Macros for MCP server resources -use syn::{ItemFn, Meta, punctuated::Punctuated, token::Comma}; -use super::{get_str_param, get_params_arr, get_exprs_arr}; +use super::{get_exprs_arr, get_params_arr, get_str_param}; use proc_macro2::TokenStream; use quote::quote; +use syn::{ItemFn, Meta, punctuated::Punctuated, token::Comma}; -pub(crate) fn expand_resource(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(crate) fn expand_resource( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut uri = None; let mut title = None; @@ -17,8 +20,8 @@ pub(crate) fn expand_resource(attr: &Punctuated, function: &ItemFn) for meta in attr { match &meta { - Meta::Path(_) => {}, - Meta::List(_) => {}, + Meta::Path(_) => {} + Meta::List(_) => {} Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() { match ident.to_string().as_str() { @@ -46,7 +49,7 @@ pub(crate) fn expand_resource(attr: &Punctuated, function: &ItemFn) _ => {} } } - }, + } } } @@ -66,10 +69,10 @@ pub(crate) fn expand_resource(attr: &Punctuated, function: &ItemFn) }); let annotations_code = annotations.map(|annotations_json| { - quote! { + quote! { .with_annotations(|_| { neva::types::Annotations::from_json_str(#annotations_json) - }) + }) } }); @@ -107,23 +110,27 @@ pub(crate) fn expand_resource(attr: &Punctuated, function: &ItemFn) Ok(expanded) } -pub(crate) fn expand_resources(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(crate) fn expand_resources( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut middleware = None; for meta in attr { match &meta { - Meta::Path(_) => {}, - Meta::List(_) => {}, + Meta::Path(_) => {} + Meta::List(_) => {} Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() - && let "middleware" = ident.to_string().as_str() { + && let "middleware" = ident.to_string().as_str() + { middleware = get_exprs_arr(&nv.value); } - }, + } } } - + let module_name = syn::Ident::new(&format!("map_{func_name}"), func_name.span()); let middleware_code = middleware.map(|mws| { let mw_calls = mws.iter().map(|mw| { @@ -131,7 +138,7 @@ pub(crate) fn expand_resources(attr: &Punctuated, function: &ItemFn }); quote! { #(#mw_calls)* } }); - + // Expand the function and apply the tool functionality let expanded = quote! { // Original function @@ -148,4 +155,4 @@ pub(crate) fn expand_resources(attr: &Punctuated, function: &ItemFn }; Ok(expanded) -} \ No newline at end of file +} diff --git a/neva_macros/src/server/tool.rs b/neva_macros/src/server/tool.rs index 6312bb5..1ed9b78 100644 --- a/neva_macros/src/server/tool.rs +++ b/neva_macros/src/server/tool.rs @@ -1,11 +1,17 @@ //! Macros for MCP server tools -use syn::{ItemFn, FnArg, Pat, Meta, ReturnType, punctuated::Punctuated, token::Comma}; -use super::{get_str_param, get_params_arr, get_exprs_arr, get_bool_param, get_arg_type, get_inner_type_from_generic}; +use super::{ + get_arg_type, get_bool_param, get_exprs_arr, get_inner_type_from_generic, get_params_arr, + get_str_param, +}; use proc_macro2::TokenStream; use quote::quote; +use syn::{FnArg, ItemFn, Meta, Pat, ReturnType, punctuated::Punctuated, token::Comma}; -pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn::Result { +pub(crate) fn expand( + attr: &Punctuated, + function: &ItemFn, +) -> syn::Result { let func_name = &function.sig.ident; let mut description = None; let mut input_schema = None; @@ -24,7 +30,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: if path.is_ident("no_schema") { no_schema = true; } - }, + } Meta::NameValue(nv) => { if let Some(ident) = nv.path.get_ident() { match ident.to_string().as_str() { @@ -61,7 +67,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: _ => {} } } - }, + } Meta::List(_) => {} } } @@ -87,7 +93,8 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: for arg in &function.sig.inputs { if let FnArg::Typed(pat_type) = arg - && let Pat::Ident(pat_ident) = &*pat_type.pat { + && let Pat::Ident(pat_ident) = &*pat_type.pat + { let arg_name = pat_ident.ident.to_string(); let arg_type = get_arg_type(&pat_type.ty); if !arg_type.eq("none") { @@ -124,7 +131,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: ReturnType::Default => { // Function returns () - no schema needed quote! {} - }, + } ReturnType::Type(_, return_type) => { let type_str = get_arg_type(return_type); if type_str == "object" { @@ -138,7 +145,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: .with_output_schema(|schema| { schema.with_schema::<#return_type>() }) - } + }, } } else if type_str == "array" { // For array types @@ -149,15 +156,15 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: } } } - } else { + } else { quote! {} }; let annotations_code = annotations.map(|annotations_json| { - quote! { + quote! { .with_annotations(|_| { neva::types::ToolAnnotations::from_json_str(#annotations_json) - }) + }) } }); @@ -181,7 +188,7 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: let task_support_code = task_support.map(|ts| { quote! { .with_task_support(#ts) } }); - + let module_name = syn::Ident::new(&format!("map_{func_name}"), func_name.span()); // Expand the function and apply the tool functionality @@ -208,4 +215,4 @@ pub(crate) fn expand(attr: &Punctuated, function: &ItemFn) -> syn:: }; Ok(expanded) -} \ No newline at end of file +} diff --git a/neva_macros/src/shared.rs b/neva_macros/src/shared.rs index 01c2382..5c7074b 100644 --- a/neva_macros/src/shared.rs +++ b/neva_macros/src/shared.rs @@ -1,10 +1,13 @@ //! Shared macros for MCP clients and servers -use syn::{Path, punctuated::Punctuated, token::Comma}; use proc_macro2::TokenStream; use quote::quote; +use syn::{Path, punctuated::Punctuated, token::Comma}; -pub(super) fn expand_json_schema(attr: &Punctuated, input: &syn::DeriveInput) -> syn::Result { +pub(super) fn expand_json_schema( + attr: &Punctuated, + input: &syn::DeriveInput, +) -> syn::Result { let mut include_ser = false; let mut include_de = false; let mut include_debug = false; @@ -44,4 +47,4 @@ pub(super) fn expand_json_schema(attr: &Punctuated, input: &syn::De }; Ok(expanded) -} \ No newline at end of file +}