Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/mcptools/commands/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ func parseCallArgs(cmdArgs []string) (string, []string) {
case (cmdArgs[i] == FlagAuthHeader) && i+1 < len(cmdArgs):
AuthHeader = cmdArgs[i+1]
i += 2
case (cmdArgs[i] == FlagTimeout || cmdArgs[i] == FlagTimeoutShort) && i+1 < len(cmdArgs):
if _, err := fmt.Sscanf(cmdArgs[i+1], "%d", &InitTimeout); err != nil {
fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", cmdArgs[i+1])
}
i += 2
case !entityExtracted:
entityName = cmdArgs[i]
entityExtracted = true
Expand Down
25 changes: 15 additions & 10 deletions cmd/mcptools/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import (

// flags.
const (
FlagFormat = "--format"
FlagFormatShort = "-f"
FlagParams = "--params"
FlagParamsShort = "-p"
FlagHelp = "--help"
FlagHelpShort = "-h"
FlagServerLogs = "--server-logs"
FlagTransport = "--transport"
FlagAuthUser = "--auth-user"
FlagAuthHeader = "--auth-header"
FlagFormat = "--format"
FlagFormatShort = "-f"
FlagParams = "--params"
FlagParamsShort = "-p"
FlagHelp = "--help"
FlagHelpShort = "-h"
FlagServerLogs = "--server-logs"
FlagTransport = "--transport"
FlagAuthUser = "--auth-user"
FlagAuthHeader = "--auth-header"
FlagTimeout = "--timeout"
FlagTimeoutShort = "-t"
)

// entity types.
Expand Down Expand Up @@ -51,6 +53,8 @@ var (
AuthUser string
// AuthHeader is a custom Authorization header.
AuthHeader string
// InitTimeout is the timeout for MCP server initialization in seconds.
InitTimeout = 10
)

// RootCmd creates the root command.
Expand All @@ -68,6 +72,7 @@ It allows you to discover and call tools, list resources, and interact with MCP-
cmd.PersistentFlags().StringVar(&TransportOption, "transport", "http", "HTTP transport type (http, sse)")
cmd.PersistentFlags().StringVar(&AuthUser, "auth-user", "", "Basic authentication in username:password format")
cmd.PersistentFlags().StringVar(&AuthHeader, "auth-header", "", "Custom Authorization header (e.g., 'Bearer token' or 'Basic base64credentials')")
cmd.PersistentFlags().IntVarP(&InitTimeout, "timeout", "t", 10, "Initialization timeout in seconds")

return cmd
}
8 changes: 7 additions & 1 deletion cmd/mcptools/commands/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"net/url"
"os"
"strings"
"time"

Expand Down Expand Up @@ -176,7 +177,7 @@ var CreateClientFunc = func(args []string, _ ...client.ClientOption) (*client.Cl
if err != nil {
return nil, fmt.Errorf("init error: %w", err)
}
case <-time.After(10 * time.Second):
case <-time.After(time.Duration(InitTimeout) * time.Second):
return nil, fmt.Errorf("initialization timed out")
}

Expand Down Expand Up @@ -211,6 +212,11 @@ func ProcessFlags(args []string) []string {
case args[i] == FlagAuthHeader && i+1 < len(args):
AuthHeader = args[i+1]
i += 2
case (args[i] == FlagTimeout || args[i] == FlagTimeoutShort) && i+1 < len(args):
if _, err := fmt.Sscanf(args[i+1], "%d", &InitTimeout); err != nil {
fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", args[i+1])
}
i += 2
default:
parsedArgs = append(parsedArgs, args[i])
i++
Expand Down
59 changes: 59 additions & 0 deletions cmd/mcptools/commands/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,62 @@ nested {"key":"value"}`[1:] // remove first newline
})
}
}

func TestProcessFlagsTimeout(t *testing.T) {
originalTimeout := InitTimeout
defer func() { InitTimeout = originalTimeout }()

tests := []struct {
name string
args []string
wantArgs []string
wantTimeout int
}{
{
name: "default timeout",
args: []string{"cmd", "arg1"},
wantArgs: []string{"cmd", "arg1"},
wantTimeout: 10,
},
{
name: "long timeout flag",
args: []string{"cmd", "--timeout", "60", "arg1"},
wantArgs: []string{"cmd", "arg1"},
wantTimeout: 60,
},
{
name: "short timeout flag",
args: []string{"cmd", "-t", "10", "arg1"},
wantArgs: []string{"cmd", "arg1"},
wantTimeout: 10,
},
{
name: "timeout at end",
args: []string{"cmd", "arg1", "--timeout", "120"},
wantArgs: []string{"cmd", "arg1"},
wantTimeout: 120,
},
{
name: "invalid timeout keeps previous",
args: []string{"cmd", "--timeout", "invalid", "arg1"},
wantArgs: []string{"cmd", "arg1"},
wantTimeout: 10,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
InitTimeout = 10

gotArgs := ProcessFlags(tt.args)

if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("ProcessFlags() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
}

if InitTimeout != tt.wantTimeout {
t.Errorf("ProcessFlags() InitTimeout = %v, want %v", InitTimeout, tt.wantTimeout)
}
})
}
}