diff --git a/README.md b/README.md index 76800430..ded7586b 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ func main() { // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go index caf21bfe..00aa459b 100644 --- a/examples/client/listfeatures/main.go +++ b/examples/client/listfeatures/main.go @@ -41,7 +41,7 @@ func main() { ctx := context.Background() cmd := exec.Command(args[0], args[1:]...) client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { log.Fatal(err) } diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index 666ee925..57ec54fa 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -21,7 +21,7 @@ func main() { // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/client.go b/mcp/client.go index d3139f54..5798dc5a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -71,10 +71,11 @@ type ClientOptions struct { // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession { - cs := &ClientSession{ - conn: conn, - client: c, +func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession { + assert(mcpConn != nil && conn != nil, "nil connection") + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c} + if state != nil { + cs.state = *state } c.mu.Lock() defer c.mu.Unlock() @@ -101,6 +102,10 @@ func (e unsupportedProtocolVersionError) Error() string { return fmt.Sprintf("unsupported protocol version: %q", e.version) } +// ClientSessionOptions is reserved for future use. +type ClientSessionOptions struct { +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -108,8 +113,8 @@ func (e unsupportedProtocolVersionError) Error() string { // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. -func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c) +func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c, (*clientSessionState)(nil)) if err != nil { return nil, err } @@ -133,9 +138,9 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { return nil, unsupportedProtocolVersionError{res.ProtocolVersion} } - cs.initializeResult = res + cs.state.InitializeResult = res if hc, ok := cs.mcpConn.(clientConnection); ok { - hc.initialized(res) + hc.sessionUpdated(cs.state) } if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() @@ -156,22 +161,25 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. type ClientSession struct { - conn *jsonrpc2.Connection - client *Client - initializeResult *InitializeResult - keepaliveCancel context.CancelFunc - mcpConn Connection + conn *jsonrpc2.Connection + client *Client + keepaliveCancel context.CancelFunc + mcpConn Connection + + // No mutex is (currently) required to guard the session state, because it is + // only set synchronously during Client.Connect. + state clientSessionState } -func (cs *ClientSession) setConn(c Connection) { - cs.mcpConn = c +type clientSessionState struct { + InitializeResult *InitializeResult } func (cs *ClientSession) ID() string { - if cs.mcpConn == nil { - return "" + if c, ok := cs.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return cs.mcpConn.SessionID() + return "" } // Close performs a graceful close of the connection, preventing new requests diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 5b13a4c8..836d4803 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -38,7 +38,7 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools) + testIterator(t, clientSession.Tools(ctx, nil), wantTools) }) }) @@ -60,7 +60,7 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources) + testIterator(t, clientSession.Resources(ctx, nil), wantResources) }) }) @@ -81,7 +81,7 @@ func TestList(t *testing.T) { } }) t.Run("ResourceTemplatesIterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) + testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) }) }) @@ -102,12 +102,12 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts) + testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts) }) }) } -func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) { +func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { t.Helper() var got []*T for x, err := range seq { diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 82a35a80..b4e8f372 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -81,7 +81,7 @@ func TestServerRunContextCancel(t *testing.T) { // send a ping to the server to ensure it's running client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) - session, err := client.Connect(ctx, clientTransport) + session, err := client.Connect(ctx, clientTransport, nil) if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(testImpl, nil) - _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { t.Fatal(err) } @@ -189,7 +189,7 @@ func TestCmdTransport(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) - session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { t.Fatal(err) } diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 883d8a89..8e6ea1be 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Connect the server, and connect the client stream, // but don't connect an actual client. cTransport, sTransport := NewInMemoryTransports() - ss, err := s.Connect(ctx, sTransport) + ss, err := s.Connect(ctx, sTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 597b9dcd..c91250c3 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -114,10 +114,10 @@ func Example_loggingMiddleware() { ctx := context.Background() // Connect server and client - serverSession, _ := server.Connect(ctx, serverTransport) + serverSession, _ := server.Connect(ctx, serverTransport, nil) defer serverSession.Close() - clientSession, _ := client.Connect(ctx, clientTransport) + clientSession, _ := client.Connect(ctx, clientTransport, nil) defer clientSession.Close() // Call the tool to demonstrate logging diff --git a/mcp/logging.go b/mcp/logging.go index 4880e179..4d33097a 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { // This is also checked in ServerSession.LoggingMessage, so checking it here // is just an optimization that skips building the JSON. h.ss.mu.Lock() - mcpLevel := h.ss.logLevel + mcpLevel := h.ss.state.LogLevel h.ss.mu.Unlock() return level >= mcpLevelToSlog(mcpLevel) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 48e95de2..9e9a6a30 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) { s.AddResource(resource2, readHandler) // Connect the server. - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -148,7 +148,7 @@ func TestEndToEnd(t *testing.T) { c.AddRoots(&Root{URI: "file://" + rootAbs}) // Connect the client. - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -549,13 +549,13 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien if config != nil { config(s) } - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - _, err := s.Connect(ctx, st) + _, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -608,7 +608,7 @@ func TestBatching(t *testing.T) { // 'initialize' to block. Therefore, we can only test with a size of 1. // Since batching is being removed, we can probably just delete this. const batchSize = 1 - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -695,7 +695,7 @@ func TestMiddleware(t *testing.T) { c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -777,13 +777,13 @@ func TestNoJSONNull(t *testing.T) { ct = NewLoggingTransport(ct, &logbuf) s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) { s := NewServer(testImpl, serverOpts) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -855,7 +855,7 @@ func TestKeepAlive(t *testing.T) { KeepAlive: 100 * time.Millisecond, } c := NewClient(testImpl, clientOpts) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer(testImpl, nil) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -899,7 +899,7 @@ func TestKeepAliveFailure(t *testing.T) { KeepAlive: 50 * time.Millisecond, } c := NewClient(testImpl, clientOpts) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/server.go b/mcp/server.go index d81dec60..89f3b6c9 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -301,7 +301,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr }) } -func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, ss *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(params.Name) s.mu.Unlock() @@ -309,7 +309,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr // TODO: surface the error code over the wire, instead of flattening it into the string. return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } - return prompt.handler(ctx, cc, params) + return prompt.handler(ctx, ss, params) } func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { @@ -518,7 +518,7 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns // It need not be called on servers that are used for multiple concurrent connections, // as with [StreamableHTTPHandler]. func (s *Server) Run(ctx context.Context, t Transport) error { - ss, err := s.Connect(ctx, t) + ss, err := s.Connect(ctx, t, nil) if err != nil { return err } @@ -539,8 +539,12 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - ss := &ServerSession{conn: conn, server: s} +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession { + assert(mcpConn != nil && conn != nil, "nil connection") + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s} + if state != nil { + ss.state = *state + } s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() @@ -561,32 +565,50 @@ func (s *Server) disconnect(cc *ServerSession) { } } +// ServerSessionOptions configures the server session. +type ServerSessionOptions struct { + State *ServerSessionState +} + // Connect connects the MCP server over the given transport and starts handling // messages. // // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { - return connect(ctx, t, s) +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + if opts != nil { + state = opts.State + } + return connect(ctx, t, s, state) } +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. + params = new(InitializedParams) + } if ss.server.opts.KeepAlive > 0 { ss.startKeepalive(ss.server.opts.KeepAlive) } - ss.mu.Lock() - hasParams := ss.initializeParams != nil - wasInitialized := ss._initialized - if hasParams { - ss._initialized = true - } - ss.mu.Unlock() + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) - if !hasParams { + if !wasInit { return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) } - if wasInitialized { + if wasInitd { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) @@ -615,25 +637,30 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - mcpConn Connection - mu sync.Mutex - logLevel LoggingLevel - initializeParams *InitializeParams - _initialized bool - keepaliveCancel context.CancelFunc + server *Server + conn *jsonrpc2.Connection + mcpConn Connection + keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + + mu sync.Mutex + state ServerSessionState } -func (ss *ServerSession) setConn(c Connection) { - ss.mcpConn = c +func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { + ss.mu.Lock() + mut(&ss.state) + copy := ss.state + ss.mu.Unlock() + if c, ok := ss.mcpConn.(serverConnection); ok { + c.sessionUpdated(copy) + } } func (ss *ServerSession) ID() string { - if ss.mcpConn == nil { - return "" + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return ss.mcpConn.SessionID() + return "" } // Ping pings the client. @@ -657,7 +684,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { ss.mu.Lock() - logLevel := ss.logLevel + logLevel := ss.state.LogLevel ss.mu.Unlock() if logLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client @@ -747,7 +774,7 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss._initialized + initialized := ss.state.InitializedParams != nil ss.mu.Unlock() // From the spec: // "The client SHOULD NOT send requests other than pings before the server @@ -770,9 +797,9 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } - ss.mu.Lock() - ss.initializeParams = params - ss.mu.Unlock() + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = params + }) // If we support the client's version, reply with it. Otherwise, reply with our // latest version. @@ -796,9 +823,9 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error } func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.logLevel = params.Level + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) return &emptyResult{}, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 3ab7a2a4..241008e9 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -31,13 +31,13 @@ func ExampleServer() { server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } client := mcp.NewClient(&mcp.Implementation{Name: "client"}, nil) - clientSession, err := client.Connect(ctx, clientTransport) + clientSession, err := client.Connect(ctx, clientTransport, nil) if err != nil { log.Fatal(err) } @@ -62,11 +62,11 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession server := mcp.NewServer(testImpl, nil) client := mcp.NewClient(testImpl, nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } - clientSession, err := client.Connect(ctx, clientTransport) + clientSession, err := client.Connect(ctx, clientTransport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/session.go b/mcp/session.go new file mode 100644 index 00000000..dcf9888c --- /dev/null +++ b/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. + LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/mcp/sse.go b/mcp/sse.go index bdc4770b..f74a3fb6 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -221,7 +221,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, "no server available", http.StatusBadRequest) return } - ss, err := server.Connect(req.Context(), transport) + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index d8ce939b..9a7c8ae7 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -37,7 +37,7 @@ func ExampleSSEHandler() { ctx := context.Background() transport := mcp.NewSSEClientTransport(httpServer.URL, nil) client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, transport) + cs, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 35fdbdbf..79cfacc3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -49,7 +49,7 @@ func TestSSEServer(t *testing.T) { }) c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, clientTransport) + cs, err := c.Connect(ctx, clientTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/streamable.go b/mcp/streamable.go index 108de5d2..72ac7e83 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -42,9 +42,14 @@ type StreamableHTTPHandler struct { sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } -// StreamableHTTPOptions is a placeholder options struct for future -// configuration of the StreamableHTTP handler. +// StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { + // GetSessionID provides the next session ID to use for an incoming request. + // + // If GetSessionID returns an empty string, the session is 'stateless', + // meaning it is not persisted and no session validation is performed. + GetSessionID func() string + // TODO: support configurable session ID generation (?) // TODO: support session retention (?) @@ -66,6 +71,9 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea if opts != nil { h.opts = *opts } + if h.opts.GetSessionID == nil { + h.opts.GetSessionID = randText + } return h } @@ -138,6 +146,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque switch req.Method { case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && session == nil { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } default: w.Header().Set("Allow", "GET, POST") http.Error(w, "unsupported method", http.StatusMethodNotAllowed) @@ -145,23 +157,42 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText(), h.opts.transportOptions) server := h.getServer(req) if server == nil { // The getServer argument to NewStreamableHTTPHandler returned nil. http.Error(w, "no server available", http.StatusBadRequest) return } + sessionID := h.opts.GetSessionID() + s := NewStreamableServerTransport(sessionID, h.opts.transportOptions) + + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if sessionID == "" { + connectOpts = &ServerSessionOptions{ + State: &ServerSessionState{ + InitializeParams: new(InitializeParams), + InitializedParams: new(InitializedParams), + }, + } + } // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - if _, err := server.Connect(req.Context(), s); err != nil { + ss, err := server.Connect(req.Context(), s, connectOpts) + if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - h.sessionsMu.Lock() - h.sessions[s.sessionID] = s - h.sessionsMu.Unlock() + if sessionID == "" { + // Stateless mode: close the session when the request exits. + defer ss.Close() // close the fake session after handling the request + } else { + h.sessionsMu.Lock() + h.sessions[s.sessionID] = s + h.sessionsMu.Unlock() + } session = s } @@ -485,7 +516,9 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "application/json") - w.Header().Set(sessionIDHeader, t.sessionID) + if t.sessionID != "" { + w.Header().Set(sessionIDHeader, t.sessionID) + } var msgs []json.RawMessage ctx := req.Context() @@ -524,7 +557,9 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") - w.Header().Set(sessionIDHeader, t.sessionID) + if t.sessionID != "" { + w.Header().Set(sessionIDHeader, t.sessionID) + } // write one event containing data. write := func(data []byte) bool { @@ -893,9 +928,11 @@ type streamableClientConn struct { sessionID string } -func (c *streamableClientConn) initialized(res *InitializeResult) { +var _ clientConnection = (*streamableClientConn)(nil) + +func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.mu.Lock() - c.initializedResult = res + c.initializedResult = state.InitializeResult c.mu.Unlock() // Start the persistent SSE listener as soon as we have the initialized diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 24368b00..54803939 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -81,7 +81,7 @@ func TestStreamableTransports(t *testing.T) { HTTPClient: httpClient, }) client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -173,7 +173,7 @@ func TestClientReplay(t *testing.T) { notifications <- params.Message }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil), nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -239,7 +239,7 @@ func TestServerInitiatedSSE(t *testing.T) { notifications <- "toolListChanged" }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil), nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -767,12 +767,12 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { transport := NewStreamableClientTransport(httpServer.URL, nil) client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } defer session.Close() - if diff := cmp.Diff(initResult, session.initializeResult); diff != "" { + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } @@ -821,3 +821,73 @@ func TestEventID(t *testing.T) { }) } } +func TestStreamableStateless(t *testing.T) { + // Test stateless mode behavior + ctx := context.Background() + + // This version of sayHi doesn't make a ping request (we can't respond to + // that request from our client). + sayHi := func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + + // Test stateless mode. + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { return "" }, + }) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + checkRequest := func(body string) { + // Verify we can call tools/list directly without initialization in stateless mode + req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Verify that no session ID header is returned in stateless mode + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + t.Errorf("%s = %s, want no session ID header", sessionIDHeader, sessionID) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Status code = %d; want successful response", resp.StatusCode) + } + + var events []Event + for event, err := range scanEvents(resp.Body) { + if err != nil { + t.Fatal(err) + } + events = append(events, event) + } + if len(events) != 1 { + t.Fatalf("got %d SSE events, want 1; events: %v", len(events), events) + } + msg, err := jsonrpc.DecodeMessage(events[0].Data) + if err != nil { + t.Fatal(err) + } + jsonResp, ok := msg.(*jsonrpc.Response) + if !ok { + t.Errorf("event is %T, want response", jsonResp) + } + if jsonResp.Error != nil { + t.Errorf("request failed: %v", jsonResp.Error) + } + } + + checkRequest(`{"jsonrpc":"2.0","method":"tools/list","id":1,"params":{}}`) + + // Verify we can make another request without session ID + checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) +} diff --git a/mcp/transport.go b/mcp/transport.go index 02b21806..a0dc7bbd 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -60,11 +60,28 @@ type Connection interface { SessionID() string } -// A clientConnection is a [Connection] that is specific to the MCP client, and -// so may receive information about the client session. +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? type clientConnection interface { Connection - initialized(*InitializeResult) + + // SessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) +} + +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) } // A StdioTransport is a [Transport] that communicates over stdin/stdout using @@ -102,37 +119,36 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} } -type binder[T handler] interface { - bind(*jsonrpc2.Connection) T +type binder[T handler, State any] interface { + bind(Connection, *jsonrpc2.Connection, State) T disconnect(T) } type handler interface { handle(ctx context.Context, req *jsonrpc.Request) (any, error) - setConn(Connection) } -func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) { var zero H - conn, err := t.Connect(ctx) + mcpConn, err := t.Connect(ctx) if err != nil { return zero, err } // If logging is configured, write message logs. - reader, writer := jsonrpc2.Reader(conn), jsonrpc2.Writer(conn) + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) var ( h H preempter canceller ) bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { - h = b.bind(conn) + h = b.bind(mcpConn, conn, s) preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, - Closer: conn, + Closer: mcpConn, Bind: bind, Preempter: &preempter, OnDone: func() { @@ -141,7 +157,6 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, }) assert(preempter.conn != nil, "unbound preempter") - h.setConn(conn) return h, nil }