From 7c5e2483cd46d228772b61a49bd9445b0fec28da Mon Sep 17 00:00:00 2001 From: Tristan McKinney Date: Fri, 19 Dec 2025 12:41:52 -0800 Subject: [PATCH] feat: add --timeout/-t flag for initialization timeout Adds a configurable timeout for MCP server initialization: - Default remains 10 seconds (matching current behavior) - Can be set via --timeout or -t flag - Includes tests for timeout flag parsing --- cmd/mcptools/commands/call.go | 5 +++ cmd/mcptools/commands/root.go | 25 +++++++----- cmd/mcptools/commands/utils.go | 8 +++- cmd/mcptools/commands/utils_test.go | 59 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 11 deletions(-) diff --git a/cmd/mcptools/commands/call.go b/cmd/mcptools/commands/call.go index 2f29a09..96af58d 100644 --- a/cmd/mcptools/commands/call.go +++ b/cmd/mcptools/commands/call.go @@ -36,6 +36,11 @@ func parseCallArgs(cmdArgs []string) (string, []string) { case (cmdArgs[i] == FlagAuthHeader) && i+1 < len(cmdArgs): AuthHeader = cmdArgs[i+1] i += 2 + case (cmdArgs[i] == FlagTimeout || cmdArgs[i] == FlagTimeoutShort) && i+1 < len(cmdArgs): + if _, err := fmt.Sscanf(cmdArgs[i+1], "%d", &InitTimeout); err != nil { + fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", cmdArgs[i+1]) + } + i += 2 case !entityExtracted: entityName = cmdArgs[i] entityExtracted = true diff --git a/cmd/mcptools/commands/root.go b/cmd/mcptools/commands/root.go index 068a7eb..2429e3a 100644 --- a/cmd/mcptools/commands/root.go +++ b/cmd/mcptools/commands/root.go @@ -9,16 +9,18 @@ import ( // flags. const ( - FlagFormat = "--format" - FlagFormatShort = "-f" - FlagParams = "--params" - FlagParamsShort = "-p" - FlagHelp = "--help" - FlagHelpShort = "-h" - FlagServerLogs = "--server-logs" - FlagTransport = "--transport" - FlagAuthUser = "--auth-user" - FlagAuthHeader = "--auth-header" + FlagFormat = "--format" + FlagFormatShort = "-f" + FlagParams = "--params" + FlagParamsShort = "-p" + FlagHelp = "--help" + FlagHelpShort = "-h" + FlagServerLogs = "--server-logs" + FlagTransport = "--transport" + FlagAuthUser = "--auth-user" + FlagAuthHeader = "--auth-header" + FlagTimeout = "--timeout" + FlagTimeoutShort = "-t" ) // entity types. @@ -51,6 +53,8 @@ var ( AuthUser string // AuthHeader is a custom Authorization header. AuthHeader string + // InitTimeout is the timeout for MCP server initialization in seconds. + InitTimeout = 10 ) // RootCmd creates the root command. @@ -68,6 +72,7 @@ It allows you to discover and call tools, list resources, and interact with MCP- cmd.PersistentFlags().StringVar(&TransportOption, "transport", "http", "HTTP transport type (http, sse)") cmd.PersistentFlags().StringVar(&AuthUser, "auth-user", "", "Basic authentication in username:password format") cmd.PersistentFlags().StringVar(&AuthHeader, "auth-header", "", "Custom Authorization header (e.g., 'Bearer token' or 'Basic base64credentials')") + cmd.PersistentFlags().IntVarP(&InitTimeout, "timeout", "t", 10, "Initialization timeout in seconds") return cmd } diff --git a/cmd/mcptools/commands/utils.go b/cmd/mcptools/commands/utils.go index 3789004..0691364 100644 --- a/cmd/mcptools/commands/utils.go +++ b/cmd/mcptools/commands/utils.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net/url" + "os" "strings" "time" @@ -176,7 +177,7 @@ var CreateClientFunc = func(args []string, _ ...client.ClientOption) (*client.Cl if err != nil { return nil, fmt.Errorf("init error: %w", err) } - case <-time.After(10 * time.Second): + case <-time.After(time.Duration(InitTimeout) * time.Second): return nil, fmt.Errorf("initialization timed out") } @@ -211,6 +212,11 @@ func ProcessFlags(args []string) []string { case args[i] == FlagAuthHeader && i+1 < len(args): AuthHeader = args[i+1] i += 2 + case (args[i] == FlagTimeout || args[i] == FlagTimeoutShort) && i+1 < len(args): + if _, err := fmt.Sscanf(args[i+1], "%d", &InitTimeout); err != nil { + fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", args[i+1]) + } + i += 2 default: parsedArgs = append(parsedArgs, args[i]) i++ diff --git a/cmd/mcptools/commands/utils_test.go b/cmd/mcptools/commands/utils_test.go index 0eb06b8..ec65a96 100644 --- a/cmd/mcptools/commands/utils_test.go +++ b/cmd/mcptools/commands/utils_test.go @@ -190,3 +190,62 @@ nested {"key":"value"}`[1:] // remove first newline }) } } + +func TestProcessFlagsTimeout(t *testing.T) { + originalTimeout := InitTimeout + defer func() { InitTimeout = originalTimeout }() + + tests := []struct { + name string + args []string + wantArgs []string + wantTimeout int + }{ + { + name: "default timeout", + args: []string{"cmd", "arg1"}, + wantArgs: []string{"cmd", "arg1"}, + wantTimeout: 10, + }, + { + name: "long timeout flag", + args: []string{"cmd", "--timeout", "60", "arg1"}, + wantArgs: []string{"cmd", "arg1"}, + wantTimeout: 60, + }, + { + name: "short timeout flag", + args: []string{"cmd", "-t", "10", "arg1"}, + wantArgs: []string{"cmd", "arg1"}, + wantTimeout: 10, + }, + { + name: "timeout at end", + args: []string{"cmd", "arg1", "--timeout", "120"}, + wantArgs: []string{"cmd", "arg1"}, + wantTimeout: 120, + }, + { + name: "invalid timeout keeps previous", + args: []string{"cmd", "--timeout", "invalid", "arg1"}, + wantArgs: []string{"cmd", "arg1"}, + wantTimeout: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + InitTimeout = 10 + + gotArgs := ProcessFlags(tt.args) + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("ProcessFlags() gotArgs = %v, want %v", gotArgs, tt.wantArgs) + } + + if InitTimeout != tt.wantTimeout { + t.Errorf("ProcessFlags() InitTimeout = %v, want %v", InitTimeout, tt.wantTimeout) + } + }) + } +}