diff --git a/Makefile b/Makefile index 6a3aa6d0..23acdde3 100644 --- a/Makefile +++ b/Makefile @@ -28,9 +28,9 @@ KUBECONFIG ?= $(HOME)/.kube/karmada.config # API server common flags API_COMMON_FLAGS = \ - $(if $(OPENAI_API_KEY),--openai-api-key="$(OPENAI_API_KEY)") \ - $(if $(OPENAI_MODEL),--openai-model="$(OPENAI_MODEL)",--openai-model="gpt-3.5-turbo") \ - $(if $(OPENAI_ENDPOINT),--openai-endpoint="$(OPENAI_ENDPOINT)",--openai-endpoint="https://api.openai.com/v1") \ + $(if $(LLM_API_KEY),--llm-api-key="$(LLM_API_KEY)") \ + $(if $(LLM_MODEL),--llm-model="$(LLM_MODEL)",--llm-model="gpt-3.5-turbo") \ + $(if $(LLM_ENDPOINT),--llm-endpoint="$(LLM_ENDPOINT)",--llm-endpoint="https://api.openai.com/v1") \ $(if $(filter true,$(ENABLE_MCP)),--enable-mcp=true) \ $(if $(MCP_TRANSPORT_MODE),--mcp-transport-mode="$(MCP_TRANSPORT_MODE)",--mcp-transport-mode="stdio") \ $(if $(KARMADA_MCP_SERVER_PATH),--mcp-server-path="$(KARMADA_MCP_SERVER_PATH)") \ diff --git a/cmd/api/app/options/options.go b/cmd/api/app/options/options.go index 9e463935..c8fa154f 100644 --- a/cmd/api/app/options/options.go +++ b/cmd/api/app/options/options.go @@ -18,6 +18,7 @@ package options import ( "net" + "time" "github.com/spf13/pflag" ) @@ -44,10 +45,11 @@ type Options struct { MCPServerPath string MCPSSEEndpoint string - // OpenAI related options - OpenAIAPIKey string - OpenAIModel string - OpenAIEndpoint string + // LLM related options + LLMAPIKey string + LLMModel string + LLMEndpoint string + LLMTimeout time.Duration } // NewOptions returns initialized Options. @@ -80,8 +82,9 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&o.MCPServerPath, "mcp-server-path", "", "Path to the MCP server binary (required for stdio mode)") fs.StringVar(&o.MCPSSEEndpoint, "mcp-sse-endpoint", "", "MCP SSE endpoint URL (required for sse mode)") - // OpenAI related flags - fs.StringVar(&o.OpenAIAPIKey, "openai-api-key", "", "OpenAI API key for AI assistant functionality") - fs.StringVar(&o.OpenAIModel, "openai-model", "gpt-3.5-turbo", "OpenAI model to use for AI assistant") - fs.StringVar(&o.OpenAIEndpoint, "openai-endpoint", "https://api.openai.com/v1", "OpenAI API endpoint URL") + // LLM related flags + fs.StringVar(&o.LLMAPIKey, "llm-api-key", "", "LLM API key for AI assistant functionality") + fs.StringVar(&o.LLMModel, "llm-model", "gpt-3.5-turbo", "LLM model to use for AI assistant") + fs.StringVar(&o.LLMEndpoint, "llm-endpoint", "https://api.openai.com/v1", "LLM API endpoint URL") + fs.DurationVar(&o.LLMTimeout, "llm-timeout", 30*time.Second, "Timeout for LLM API requests") } diff --git a/go.mod b/go.mod index e8bb6aaf..7a7385ed 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,10 @@ require ( github.com/gobuffalo/flect v1.0.3 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/karmada-io/karmada v1.15.0 - github.com/mark3labs/mcp-go v0.26.0 + github.com/mark3labs/mcp-go v0.42.0 github.com/prometheus/client_golang v1.22.0 github.com/prometheus/common v0.65.0 - github.com/sashabaranov/go-openai v1.40.5 + github.com/sashabaranov/go-openai v1.41.2 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 golang.org/x/net v0.40.0 @@ -42,8 +42,10 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/MakeNowJust/heredoc v1.0.0 // indirect github.com/Yiling-J/theine-go v0.6.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect 
github.com/bytedance/sonic v1.12.7 // indirect github.com/bytedance/sonic/loader v0.2.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -77,6 +79,7 @@ require ( github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect @@ -109,6 +112,7 @@ require ( github.com/spf13/cast v1.7.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect diff --git a/go.sum b/go.sum index b9d0e4fd..978f777a 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,14 @@ github.com/Yiling-J/theine-go v0.6.0 h1:jv7V/tcD6ijL0T4kfbJDKP81TCZBkoriNTPSqwiv github.com/Yiling-J/theine-go v0.6.0/go.mod h1:mdch1vjgGWd7s3rWKvY+MF5InRLfRv/CWVI9RVNQ8wY= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.12.7 h1:CQU8pxOy9HToxhndH0Kx/S1qU/CuS9GnKYrGioDcU1Q= github.com/bytedance/sonic v1.12.7/go.mod h1:tnbal4mxOMju17EGfknm2XyYcpyCnIROYOEYuemj13I= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -122,6 +126,8 @@ github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJr github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -158,8 +164,8 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= 
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.26.0 h1:xz/Kv1cHLYovF8txv6btBM39/88q3YOjnxqhi51jB0w= -github.com/mark3labs/mcp-go v0.26.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.42.0 h1:gk/8nYJh8t3yroCAOBhNbYsM9TCKvkM13I5t5Hfu6Ls= +github.com/mark3labs/mcp-go v0.42.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= @@ -214,8 +220,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= -github.com/sashabaranov/go-openai v1.40.5 h1:SwIlNdWflzR1Rxd1gv3pUg6pwPc6cQ2uMoHs8ai+/NY= -github.com/sashabaranov/go-openai v1.40.5/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= +github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= @@ -242,6 +248,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xlab/treeprint v1.2.0 h1:HzHnuAF1plUN2zGlAFHbSQP2qJ0ZAD3XF5XD7OesXRQ= diff --git a/pkg/llm/client.go b/pkg/llm/client.go new file mode 100644 index 00000000..5674ade4 --- /dev/null +++ b/pkg/llm/client.go @@ -0,0 +1,143 @@ +/* +Copyright 2025 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llm + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/sashabaranov/go-openai" + "k8s.io/klog/v2" +) + +var ( + globalLLMConfig *Config + globalLLMClient *openai.Client + llmInitialized bool + llmMutex sync.RWMutex +) + +// InitLLMConfig initializes LLM configuration from Config. 
+func InitLLMConfig(config *Config) { + llmMutex.Lock() + defer llmMutex.Unlock() + + globalLLMConfig = config + globalLLMClient = nil // Reset client to force recreation + llmInitialized = true + + klog.InfoS("LLM configuration initialized", + "hasAPIKey", config.LLMAPIKey != "", + "model", config.LLMModel, + "endpoint", config.LLMEndpoint, + "timeout", config.Timeout) +} + +// GetLLMClient returns a configured LLM client (singleton). +// Note: This assumes the configured endpoint supports OpenAI API format. +// For incompatible providers, consider using their native SDKs directly. +func GetLLMClient() (*openai.Client, error) { + llmMutex.Lock() + defer llmMutex.Unlock() + + // Return existing client if already created (singleton) + if globalLLMClient != nil { + return globalLLMClient, nil + } + + if !llmInitialized || globalLLMConfig == nil { + return nil, fmt.Errorf("%w: call InitLLMConfig first", ErrLLMNotInitialized) + } + + // Validate configuration before creating client + if err := globalLLMConfig.Validate(); err != nil { + return nil, fmt.Errorf("invalid LLM configuration: %w", err) + } + + // Create new client + config := openai.DefaultConfig(globalLLMConfig.LLMAPIKey) + if globalLLMConfig.LLMEndpoint != "" { + config.BaseURL = globalLLMConfig.LLMEndpoint + } + + // Configure HTTP client with timeout + config.HTTPClient = &http.Client{ + Timeout: globalLLMConfig.Timeout, + } + + globalLLMClient = openai.NewClientWithConfig(config) + + klog.InfoS("LLM client created successfully", + "endpoint", config.BaseURL, + "model", globalLLMConfig.LLMModel, + "timeout", globalLLMConfig.Timeout) + + return globalLLMClient, nil +} + +// GetLLMModel returns the configured LLM model. +func GetLLMModel() string { + llmMutex.RLock() + defer llmMutex.RUnlock() + + if !llmInitialized || globalLLMConfig == nil || globalLLMConfig.LLMModel == "" { + return openai.GPT3Dot5Turbo // default value + } + return globalLLMConfig.LLMModel +} + +// IsLLMConfigured returns true if LLM is properly configured. +func IsLLMConfigured() bool { + llmMutex.RLock() + defer llmMutex.RUnlock() + + return llmInitialized && globalLLMConfig != nil && globalLLMConfig.LLMAPIKey != "" +} + +// ValidateLLMConnection performs a health check to validate the LLM connection. +// This function sends a simple API request to verify that the endpoint is reachable +// and compatible with the OpenAI API format. +func ValidateLLMConnection(ctx context.Context) error { + client, err := GetLLMClient() + if err != nil { + return fmt.Errorf("%w: %v", ErrConnectionFailed, err) + } + + // Create a context with timeout if none provided + if ctx == nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + + // Send a simple models list request to validate connection + // This is a lightweight operation that verifies API compatibility + _, err = client.ListModels(ctx) + if err != nil { + klog.ErrorS(err, "LLM connection validation failed", + "endpoint", globalLLMConfig.LLMEndpoint) + return fmt.Errorf("%w: unable to communicate with LLM endpoint: %v", + ErrConnectionFailed, err) + } + + klog.InfoS("LLM connection validated successfully", + "endpoint", globalLLMConfig.LLMEndpoint) + return nil +} diff --git a/pkg/llm/client_test.go b/pkg/llm/client_test.go new file mode 100644 index 00000000..7dde721a --- /dev/null +++ b/pkg/llm/client_test.go @@ -0,0 +1,421 @@ +/* +Copyright 2025 The Karmada Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llm + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/sashabaranov/go-openai" +) + +func resetGlobalState() { + llmMutex.Lock() + defer llmMutex.Unlock() + globalLLMConfig = nil + globalLLMClient = nil + llmInitialized = false +} + +func TestInitLLMConfig(t *testing.T) { + resetGlobalState() + + config := &Config{ + LLMAPIKey: "test-key", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + if !llmInitialized { + t.Error("Expected llmInitialized to be true after InitLLMConfig") + } + + if globalLLMConfig != config { + t.Error("Expected globalLLMConfig to be set to the provided config") + } + + if globalLLMClient != nil { + t.Error("Expected globalLLMClient to be reset to nil") + } +} + +func TestGetLLMClientNotInitialized(t *testing.T) { + resetGlobalState() + + _, err := GetLLMClient() + if err == nil { + t.Error("Expected error when LLM not initialized") + } + + // Check if error contains the sentinel error + if !errors.Is(err, ErrLLMNotInitialized) { + t.Errorf("Expected error to be ErrLLMNotInitialized, got %v", err) + } +} + +func TestGetLLMClientNoAPIKey(t *testing.T) { + resetGlobalState() + + config := &Config{ + LLMAPIKey: "", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + _, err := GetLLMClient() + if err == nil { + t.Error("Expected error when API key is empty") + } + + // Check if error contains the sentinel error + if !errors.Is(err, ErrLLMAPIKeyNotConfigured) { + t.Errorf("Expected error to be ErrLLMAPIKeyNotConfigured, got %v", err) + } +} + +func TestGetLLMClientSuccess(t *testing.T) { + resetGlobalState() + + config := &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + client, err := GetLLMClient() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if client == nil { + t.Error("Expected client to be non-nil") + } + + client2, err := GetLLMClient() + if err != nil { + t.Errorf("Expected no error on second call, got %v", err) + } + + if client != client2 { + t.Error("Expected same client instance (singleton pattern)") + } +} + +func TestGetLLMClientDefaultEndpoint(t *testing.T) { + resetGlobalState() + + config := &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + client, err := GetLLMClient() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if client == nil { + t.Error("Expected client to be non-nil") + } +} + +func TestGetLLMModel(t *testing.T) { + resetGlobalState() + + model := GetLLMModel() + if model != openai.GPT3Dot5Turbo { + t.Errorf("Expected default model %s when not initialized, got %s", openai.GPT3Dot5Turbo, model) + } + + config := 
&Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + model = GetLLMModel() + if model != "gpt-4" { + t.Errorf("Expected model gpt-4, got %s", model) + } + + config.LLMModel = "" + InitLLMConfig(config) + + model = GetLLMModel() + if model != openai.GPT3Dot5Turbo { + t.Errorf("Expected default model %s when empty, got %s", openai.GPT3Dot5Turbo, model) + } +} + +func TestIsLLMConfigured(t *testing.T) { + resetGlobalState() + + if IsLLMConfigured() { + t.Error("Expected IsLLMConfigured to return false when not initialized") + } + + config := &Config{ + LLMAPIKey: "", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + if IsLLMConfigured() { + t.Error("Expected IsLLMConfigured to return false when API key is empty") + } + + config.LLMAPIKey = "test-api-key" + InitLLMConfig(config) + + if !IsLLMConfigured() { + t.Error("Expected IsLLMConfigured to return true when properly configured") + } +} + +func TestConcurrentAccess(t *testing.T) { + resetGlobalState() + + config := &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "https://test.endpoint.com/v1", + Timeout: 60 * time.Second, + } + + InitLLMConfig(config) + + var wg sync.WaitGroup + clientChan := make(chan *openai.Client, 10) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client, err := GetLLMClient() + if err != nil { + t.Errorf("Unexpected error in goroutine: %v", err) + return + } + clientChan <- client + }() + } + + wg.Wait() + close(clientChan) + + var firstClient *openai.Client + count := 0 + for client := range clientChan { + count++ + if firstClient == nil { + firstClient = client + } else if client != firstClient { + t.Error("Expected all goroutines to get the same client instance") + } + } + + if count != 10 { + t.Errorf("Expected 10 client instances, got %d", count) + } +} + +func TestLLMWithMockServer(t *testing.T) { + resetGlobalState() + + mockResponse := `{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello from mock server!" 
+ }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 5, + "total_tokens": 14 + } + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Errorf("Expected path /chat/completions, got %s", r.URL.Path) + } + + if r.Method != "POST" { + t.Errorf("Expected POST method, got %s", r.Method) + } + + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "Bearer ") { + t.Errorf("Expected Authorization header with Bearer token, got %s", authHeader) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, mockResponse) + })) + defer server.Close() + + config := &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-3.5-turbo", + LLMEndpoint: server.URL, + Timeout: 30 * time.Second, + } + + InitLLMConfig(config) + + client, err := GetLLMClient() + if err != nil { + t.Fatalf("Failed to get LLM client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req := openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Test message", + }, + }, + MaxTokens: 50, + } + + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + t.Fatalf("Failed to create chat completion: %v", err) + } + + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice in response") + } + + expectedContent := "Hello from mock server!" + actualContent := resp.Choices[0].Message.Content + if actualContent != expectedContent { + t.Errorf("Expected content '%s', got '%s'", expectedContent, actualContent) + } + + t.Logf("Mock server test passed. 
Response: %s", actualContent) +} + +func TestValidateLLMConnection(t *testing.T) { + resetGlobalState() + + // Mock server that responds to models list request + mockModelsResponse := `{ + "data": [ + {"id": "gpt-3.5-turbo", "object": "model"}, + {"id": "gpt-4", "object": "model"} + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/models" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, mockModelsResponse) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + config := &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-3.5-turbo", + LLMEndpoint: server.URL, + Timeout: 5 * time.Second, + } + + InitLLMConfig(config) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := ValidateLLMConnection(ctx) + if err != nil { + t.Fatalf("Expected successful connection validation, got error: %v", err) + } +} + +func TestValidateLLMConnectionFailure(t *testing.T) { + resetGlobalState() + + // Mock server that returns error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error": "invalid API key"}`) + })) + defer server.Close() + + config := &Config{ + LLMAPIKey: "invalid-key", + LLMModel: "gpt-3.5-turbo", + LLMEndpoint: server.URL, + Timeout: 5 * time.Second, + } + + InitLLMConfig(config) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := ValidateLLMConnection(ctx) + if err == nil { + t.Fatal("Expected connection validation to fail") + } + + if !errors.Is(err, ErrConnectionFailed) { + t.Errorf("Expected error to be ErrConnectionFailed, got %v", err) + } +} diff --git a/pkg/llm/config.go b/pkg/llm/config.go new file mode 100644 index 00000000..23fd07c4 --- /dev/null +++ b/pkg/llm/config.go @@ -0,0 +1,61 @@ +/* +Copyright 2025 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llm + +import ( + "errors" + "net/url" + "time" + + "github.com/sashabaranov/go-openai" +) + +// Config holds LLM configuration. +// This package assumes LLM endpoints are compatible with OpenAI API format. +type Config struct { + LLMAPIKey string // API key for the LLM service + LLMModel string // Model name (e.g., "gpt-3.5-turbo", "gpt-4") + LLMEndpoint string // API endpoint URL (must be OpenAI-compatible) + Timeout time.Duration // HTTP request timeout +} + +// Validate verifies that the configuration contains valid values. +// It checks for required fields and validates the endpoint URL format. 
+func (c *Config) Validate() error { + // Validate API key + if c.LLMAPIKey == "" { + return ErrLLMAPIKeyNotConfigured + } + + // Validate endpoint URL if provided + if c.LLMEndpoint != "" { + if _, err := url.Parse(c.LLMEndpoint); err != nil { + return errors.Join(ErrInvalidEndpoint, err) + } + } + + return nil +} + +// NewConfig creates a new LLM config with default values. +func NewConfig() *Config { + return &Config{ + LLMModel: openai.GPT3Dot5Turbo, + LLMEndpoint: "https://api.openai.com/v1", + Timeout: 30 * time.Second, + } +} diff --git a/pkg/llm/config_test.go b/pkg/llm/config_test.go new file mode 100644 index 00000000..66b2c1ef --- /dev/null +++ b/pkg/llm/config_test.go @@ -0,0 +1,144 @@ +/* +Copyright 2025 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llm + +import ( + "errors" + "testing" + "time" + + "github.com/sashabaranov/go-openai" +) + +func TestNewConfig(t *testing.T) { + config := NewConfig() + + if config == nil { + t.Fatal("NewConfig() returned nil") + } + + if config.LLMModel != openai.GPT3Dot5Turbo { + t.Errorf("Expected default model %s, got %s", openai.GPT3Dot5Turbo, config.LLMModel) + } + + if config.LLMEndpoint != "https://api.openai.com/v1" { + t.Errorf("Expected default endpoint https://api.openai.com/v1, got %s", config.LLMEndpoint) + } + + if config.Timeout != 30*time.Second { + t.Errorf("Expected default timeout 30s, got %v", config.Timeout) + } + + if config.LLMAPIKey != "" { + t.Errorf("Expected empty API key, got %s", config.LLMAPIKey) + } +} + +func TestConfigCustomization(t *testing.T) { + config := NewConfig() + + config.LLMAPIKey = "test-api-key" + config.LLMModel = "gpt-4" + config.LLMEndpoint = "https://custom.endpoint.com/v1" + config.Timeout = 60 * time.Second + + if config.LLMAPIKey != "test-api-key" { + t.Errorf("Expected API key test-api-key, got %s", config.LLMAPIKey) + } + + if config.LLMModel != "gpt-4" { + t.Errorf("Expected model gpt-4, got %s", config.LLMModel) + } + + if config.LLMEndpoint != "https://custom.endpoint.com/v1" { + t.Errorf("Expected endpoint https://custom.endpoint.com/v1, got %s", config.LLMEndpoint) + } + + if config.Timeout != 60*time.Second { + t.Errorf("Expected timeout 60s, got %v", config.Timeout) + } +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + expectedErr error + }{ + { + name: "valid config with all fields", + config: &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "https://api.openai.com/v1", + Timeout: 30 * time.Second, + }, + wantErr: false, + }, + { + name: "valid config with empty endpoint", + config: &Config{ + LLMAPIKey: "test-api-key", + LLMModel: "gpt-4", + Timeout: 30 * time.Second, + }, + wantErr: false, + }, + { + name: "missing API key", + config: &Config{ + LLMModel: "gpt-4", + LLMEndpoint: "https://api.openai.com/v1", + Timeout: 30 * time.Second, + }, + wantErr: true, + expectedErr: ErrLLMAPIKeyNotConfigured, + }, + { + name: "invalid endpoint URL", + config: &Config{ + LLMAPIKey: 
"test-api-key", + LLMModel: "gpt-4", + LLMEndpoint: "://invalid-url", + Timeout: 30 * time.Second, + }, + wantErr: true, + expectedErr: ErrInvalidEndpoint, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + + if tt.wantErr && err == nil { + t.Errorf("Validate() expected error but got nil") + } + + if !tt.wantErr && err != nil { + t.Errorf("Validate() unexpected error: %v", err) + } + + if tt.wantErr && tt.expectedErr != nil { + if !errors.Is(err, tt.expectedErr) { + t.Errorf("Validate() error = %v, expected to contain %v", err, tt.expectedErr) + } + } + }) + } +} diff --git a/pkg/llm/errors.go b/pkg/llm/errors.go new file mode 100644 index 00000000..b1e4d924 --- /dev/null +++ b/pkg/llm/errors.go @@ -0,0 +1,35 @@ +/* +Copyright 2025 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llm + +import "errors" + +// Sentinel errors for LLM package. +// These errors can be used with errors.Is() for error type checking. +var ( + // ErrLLMNotInitialized is returned when LLM configuration has not been initialized. + ErrLLMNotInitialized = errors.New("LLM not initialized") + + // ErrLLMAPIKeyNotConfigured is returned when LLM API key is not set. + ErrLLMAPIKeyNotConfigured = errors.New("LLM API key not configured") + + // ErrInvalidEndpoint is returned when the LLM endpoint URL is invalid. + ErrInvalidEndpoint = errors.New("invalid LLM endpoint URL") + + // ErrConnectionFailed is returned when LLM connection validation fails. + ErrConnectionFailed = errors.New("LLM connection validation failed") +) diff --git a/pkg/llm/integration_test.go b/pkg/llm/integration_test.go new file mode 100644 index 00000000..e38158aa --- /dev/null +++ b/pkg/llm/integration_test.go @@ -0,0 +1,211 @@ +//go:build integration + +/* +Copyright 2025 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Integration tests for LLM package. +// +// These tests interact with actual LLM API endpoints and are disabled by default. +// To run these tests, you need to: +// +// 1. Set required environment variables: +// export LLM_API_KEY=your-api-key-here +// +// 2. (Optional) Configure custom endpoint: +// export LLM_ENDPOINT=https://api.openai.com/v1 # default value +// +// 3. (Optional) Configure custom model: +// export LLM_MODEL=gpt-3.5-turbo # default value +// +// 4. 
Run the integration tests: +// go test -v -tags=integration -run '^TestLLMIntegration' ./pkg/llm +// +// Example: +// export LLM_API_KEY=sk-xxx +// go test -v -tags=integration ./pkg/llm + +package llm + +import ( + "context" + "errors" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/sashabaranov/go-openai" +) + +func TestLLMIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + resetGlobalState() + + apiKey := os.Getenv("LLM_API_KEY") + endpoint := os.Getenv("LLM_ENDPOINT") + model := os.Getenv("LLM_MODEL") + + if apiKey == "" { + t.Fatalf("LLM_API_KEY environment variable must be set to run integration tests") + } + + if endpoint == "" { + endpoint = "https://api.openai.com/v1" + } + if model == "" { + model = "gpt-3.5-turbo" + } + + config := &Config{ + LLMAPIKey: apiKey, + LLMModel: model, + LLMEndpoint: endpoint, + Timeout: 30 * time.Second, + } + + InitLLMConfig(config) + + if !IsLLMConfigured() { + t.Fatalf("LLM configuration check failed after initialization") + } + + client, err := GetLLMClient() + if err != nil { + t.Fatalf("Failed to get LLM client: %v", err) + } + + t.Run("BasicQuestion", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + req := openai.ChatCompletionRequest{ + Model: GetLLMModel(), + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "What is Kubernetes? Answer in one sentence.", + }, + }, + MaxTokens: 50, + } + + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + t.Fatalf("CreateChatCompletion failed: %v", err) + } + + if len(resp.Choices) == 0 { + t.Fatalf("Received no choices in response from LLM API") + } + + answer := resp.Choices[0].Message.Content + if !strings.Contains(strings.ToLower(answer), "kubernetes") && + !strings.Contains(strings.ToLower(answer), "container") { + t.Errorf("Expected answer to contain 'kubernetes' or 'container', but got: %s", answer) + } + }) + + t.Run("KarmadaConcept", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + req := openai.ChatCompletionRequest{ + Model: GetLLMModel(), + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "What is Karmada and how does it help with multi-cluster management? 
Answer briefly.", + }, + }, + MaxTokens: 100, + } + + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + t.Fatalf("CreateChatCompletion for Karmada question failed: %v", err) + } + + if len(resp.Choices) == 0 { + t.Fatalf("Received no choices for Karmada question from LLM API") + } + + answer := resp.Choices[0].Message.Content + lowerAnswer := strings.ToLower(answer) + if !strings.Contains(lowerAnswer, "karmada") && + !strings.Contains(lowerAnswer, "multi-cluster") && + !strings.Contains(lowerAnswer, "cluster") { + t.Errorf("Expected answer to contain 'karmada' or 'multi-cluster', but got: %s", answer) + } + }) + + t.Run("StreamingResponse", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + req := openai.ChatCompletionRequest{ + Model: GetLLMModel(), + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "List 3 benefits of using Karmada for multi-cluster management.", + }, + }, + MaxTokens: 150, + Stream: true, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + t.Fatalf("Failed to create chat completion stream: %v", err) + } + defer stream.Close() + + fullResponse := "" + chunkCount := 0 + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + break + } + if err != nil { + t.Fatalf("Failed while receiving stream response: %v", err) + } + + if len(response.Choices) > 0 { + chunk := response.Choices[0].Delta.Content + fullResponse += chunk + if chunk != "" { + chunkCount++ + } + // Key change: break on finish_reason + if response.Choices[0].FinishReason == "stop" { + break + } + } + } + + if chunkCount < 3 || len(fullResponse) < 50 { + t.Errorf("Streaming response seems abnormal: received %d chunks and %d characters", chunkCount, len(fullResponse)) + } + }) +} diff --git a/pkg/mcpclient/test_mcp_server.go b/pkg/mcpclient/test_mcp_server.go index a0049129..129bcd42 100644 --- a/pkg/mcpclient/test_mcp_server.go +++ b/pkg/mcpclient/test_mcp_server.go @@ -243,16 +243,12 @@ func (ts *TestMCPServer) StartStdioServer() error { // Tool handlers for testing func testEchoHandler(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments - message, ok := arguments["message"].(string) - if !ok { + message, err := request.RequireString("message") + if err != nil { return nil, fmt.Errorf("message parameter is required and must be a string") } - prefix := "" - if prefixArg, exists := arguments["prefix"]; exists { - prefix, _ = prefixArg.(string) - } + prefix := request.GetString("prefix", "") result := prefix + message @@ -264,13 +260,19 @@ func testEchoHandler(_ context.Context, request mcp.CallToolRequest) (*mcp.CallT } func testCalculateHandler(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments - a, ok1 := arguments["a"].(float64) - b, ok2 := arguments["b"].(float64) - operation, ok3 := arguments["operation"].(string) + a, err := request.RequireFloat("a") + if err != nil { + return nil, fmt.Errorf("a parameter is required and must be a number") + } + + b, err := request.RequireFloat("b") + if err != nil { + return nil, fmt.Errorf("b parameter is required and must be a number") + } - if !ok1 || !ok2 || !ok3 { - return nil, fmt.Errorf("a, b, and operation parameters are required") + operation, err := 
request.RequireString("operation") + if err != nil { + return nil, fmt.Errorf("operation parameter is required and must be a string") } var result float64 @@ -298,10 +300,13 @@ func testCalculateHandler(_ context.Context, request mcp.CallToolRequest) (*mcp. } func testDelayHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments - milliseconds, ok := arguments["milliseconds"].(float64) - if !ok || milliseconds < 0 { - return nil, fmt.Errorf("milliseconds parameter is required and must be a positive number") + milliseconds, err := request.RequireFloat("milliseconds") + if err != nil { + return nil, fmt.Errorf("milliseconds parameter is required and must be a number") + } + + if milliseconds < 0 { + return nil, fmt.Errorf("milliseconds must be a positive number") } select {
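
Note (not part of the patch): the diff adds the --llm-* flags to cmd/api/app/options and the pkg/llm package, but does not show the server-side wiring between them. The sketch below illustrates how that hookup could look, using only the identifiers introduced in this diff (options.Options fields, llm.Config, InitLLMConfig, IsLLMConfigured, ValidateLLMConnection). The package name, the helper name initLLM, and the import paths (which assume the module path github.com/karmada-io/dashboard) are illustrative assumptions, not code from this change.

// Illustrative only — not part of this diff.
package app

import (
	"context"
	"time"

	"k8s.io/klog/v2"

	// Assumed module path; adjust to the repository's actual go.mod module.
	"github.com/karmada-io/dashboard/cmd/api/app/options"
	"github.com/karmada-io/dashboard/pkg/llm"
)

// initLLM is a hypothetical startup helper: it copies the parsed --llm-* flag
// values into the llm package's global configuration and, when an API key is
// present, probes the endpoint once so misconfiguration surfaces early.
func initLLM(opts *options.Options) {
	llm.InitLLMConfig(&llm.Config{
		LLMAPIKey:   opts.LLMAPIKey,
		LLMModel:    opts.LLMModel,
		LLMEndpoint: opts.LLMEndpoint,
		Timeout:     opts.LLMTimeout,
	})

	if !llm.IsLLMConfigured() {
		klog.InfoS("LLM API key not provided; AI assistant features will be unavailable")
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := llm.ValidateLLMConnection(ctx); err != nil {
		// Validation failure is logged but not fatal; the endpoint may become reachable later.
		klog.ErrorS(err, "LLM endpoint validation failed")
	}
}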