diff --git a/devops-mcp-server/go.mod b/devops-mcp-server/go.mod index 710af15..b337191 100644 --- a/devops-mcp-server/go.mod +++ b/devops-mcp-server/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go/artifactregistry v1.17.2 cloud.google.com/go/auth v0.17.0 cloud.google.com/go/cloudbuild v1.23.1 + cloud.google.com/go/iam v1.5.3 cloud.google.com/go/resourcemanager v1.10.7 cloud.google.com/go/run v1.12.1 cloud.google.com/go/storage v1.57.0 @@ -26,7 +27,6 @@ require ( cloud.google.com/go v0.123.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect - cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/longrunning v0.7.0 // indirect cloud.google.com/go/monitoring v1.24.3 // indirect cyphar.com/go-pathrs v0.2.1 // indirect diff --git a/devops-mcp-server/rag/devops-rag.db b/devops-mcp-server/rag/client/devops-rag.db similarity index 100% rename from devops-mcp-server/rag/devops-rag.db rename to devops-mcp-server/rag/client/devops-rag.db diff --git a/devops-mcp-server/rag/client/mocks/mock_ragclient.go b/devops-mcp-server/rag/client/mocks/mock_ragclient.go new file mode 100644 index 0000000..8c08ecc --- /dev/null +++ b/devops-mcp-server/rag/client/mocks/mock_ragclient.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: rag/client/ragclient.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockRagClient is a mock of RagClient interface. +type MockRagClient struct { + ctrl *gomock.Controller + recorder *MockRagClientMockRecorder +} + +// MockRagClientMockRecorder is the mock recorder for MockRagClient. +type MockRagClientMockRecorder struct { + mock *MockRagClient +} + +// NewMockRagClient creates a new mock instance. +func NewMockRagClient(ctrl *gomock.Controller) *MockRagClient { + mock := &MockRagClient{ctrl: ctrl} + mock.recorder = &MockRagClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRagClient) EXPECT() *MockRagClientMockRecorder { + return m.recorder +} + +// QueryPatterns mocks base method. +func (m *MockRagClient) QueryPatterns(ctx context.Context, query string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryPatterns", ctx, query) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryPatterns indicates an expected call of QueryPatterns. +func (mr *MockRagClientMockRecorder) QueryPatterns(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryPatterns", reflect.TypeOf((*MockRagClient)(nil).QueryPatterns), ctx, query) +} + +// Queryknowledge mocks base method. +func (m *MockRagClient) Queryknowledge(ctx context.Context, query string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Queryknowledge", ctx, query) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Queryknowledge indicates an expected call of Queryknowledge. +func (mr *MockRagClientMockRecorder) Queryknowledge(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Queryknowledge", reflect.TypeOf((*MockRagClient)(nil).Queryknowledge), ctx, query) +} diff --git a/devops-mcp-server/rag/client/ragclient.go b/devops-mcp-server/rag/client/ragclient.go new file mode 100644 index 0000000..dc4fb84 --- /dev/null +++ b/devops-mcp-server/rag/client/ragclient.go @@ -0,0 +1,153 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ragclient + +import ( + "bytes" + "context" + "devops-mcp-server/auth" + _ "embed" + "encoding/json" + "fmt" + "log" + + chromem "github.com/philippgille/chromem-go" +) + +//go:embed devops-rag.db +var embeddedDB []byte + +type RagClientImpl struct { + DB *chromem.DB + Pattern *chromem.Collection + Knowledge *chromem.Collection +} + +// Only expose what the LLM needs to read. +type Result struct { + Content string `json:"content"` + Metadata map[string]string `json:"metadata,omitempty"` // Source info + Similarity float32 `json:"relevance_score"` // Helps LLM weigh confidence +} + +type RagClient interface { + Queryknowledge(ctx context.Context, query string) (string, error) + QueryPatterns(ctx context.Context, query string) (string, error) +} + +// loadRAG performs the one-time initialization. +func loadRAG(ctx context.Context) (RagClient, error) { + ragClient := &RagClientImpl{DB: chromem.NewDB()} + reader := bytes.NewReader(embeddedDB) + err := ragClient.DB.ImportFromReader(reader, "") + if err != nil { + log.Printf("Unable to import from the RAG DB file: %v", err) + return nil, err + } + log.Printf("IMPORTED from the RAG DB collections: %v", len(ragClient.DB.ListCollections())) + + creds, err := auth.GetAuthToken(ctx) + if err != nil { + log.Printf("Error: Google Cloud account is required: %v", err) + // RETURN AN ERROR + return nil, fmt.Errorf("Google Cloud account is required: %w", err) + } + + vertexEmbeddingFunc := chromem.NewEmbeddingFuncVertex( + creds.Token, + creds.ProjectId, + chromem.EmbeddingModelVertexEnglishV4) + ragClient.Knowledge, err = ragClient.DB.GetOrCreateCollection("knowledge", nil, vertexEmbeddingFunc) + if err != nil { + return nil, fmt.Errorf("Unable to get collection knowledge: %w", err) + } + log.Printf("LOADED collection knowledge: %v", ragClient.Knowledge.Count()) + ragClient.Pattern, err = ragClient.DB.GetOrCreateCollection("pattern", nil, vertexEmbeddingFunc) + if err != nil { + return nil, fmt.Errorf("Unable to get collection pattern: %w", err) + } + log.Printf("LOADED collection pattern: %v", ragClient.Pattern.Count()) + + log.Print("RAG Init Completed!") + return ragClient, nil // Success +} + +// contextKey is a private type to use as a key for context values. +type contextKey string + +const ( + ragClientKey contextKey = "ragClient" +) + +// ClientFrom returns the RagClient stored in the context, if any. +func ClientFrom(ctx context.Context) (RagClient, bool) { + client, ok := ctx.Value(ragClientKey).(RagClient) + return client, ok +} + +// ContextWithClient returns a new context with the provided RagClient. +func ContextWithClient(ctx context.Context, client RagClient) context.Context { + return context.WithValue(ctx, ragClientKey, client) +} + +// NewClient creates a new Client. +func NewClient(ctx context.Context) (RagClient, error) { + return loadRAG(ctx) +} + + +func (r *RagClientImpl) QueryPatterns(ctx context.Context, query string) (string, error) { + results, err := r.Pattern.Query(ctx, query, 2, nil, nil) + if err != nil { + log.Fatalf("Unable to Query collection pattern: %v", err) + } + cleanResults := make([]Result, len(results)) + for i, r := range results { + cleanResults[i] = Result{ + Content: r.Content, + Metadata: r.Metadata, + Similarity: r.Similarity, + } + } + + // Marshal to JSON + jsonData, err := json.Marshal(cleanResults) + if err != nil { + return "", fmt.Errorf("failed to marshal results: %w", err) + } + return string(jsonData), nil +} + +func (r *RagClientImpl) Queryknowledge(ctx context.Context, query string) (string, error) { + results, err := r.Knowledge.Query(ctx, query, 2, nil, nil) + if err != nil { + log.Fatalf("Unable to Query collection knowledge: %v", err) + } + cleanResults := make([]Result, len(results)) + for i, r := range results { + cleanResults[i] = Result{ + Content: r.Content, + Metadata: r.Metadata, + Similarity: r.Similarity, + } + } + + // Marshal to JSON + jsonData, err := json.Marshal(cleanResults) + if err != nil { + return "", fmt.Errorf("failed to marshal results: %w", err) + } + return string(jsonData), nil +} diff --git a/devops-mcp-server/rag/rag.go b/devops-mcp-server/rag/rag.go index f602a57..80b2ccf 100644 --- a/devops-mcp-server/rag/rag.go +++ b/devops-mcp-server/rag/rag.go @@ -2,6 +2,7 @@ // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. +// You may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 @@ -15,127 +16,51 @@ package rag import ( - "bytes" "context" - "devops-mcp-server/auth" - _ "embed" - "encoding/json" "fmt" - "log" - "sync" - chromem "github.com/philippgille/chromem-go" -) - -var initOnce sync.Once - -//go:embed devops-rag.db -var embeddedDB []byte - -type RagData struct { - DB *chromem.DB - Pattern *chromem.Collection - Knowledge *chromem.Collection -} - -// Only expose what the LLM needs to read. -type Result struct { - Content string `json:"content"` - Metadata map[string]string `json:"metadata,omitempty"` // Source info - Similarity float32 `json:"relevance_score"` // Helps LLM weigh confidence -} - -var RagDB RagData - -// loadRAG performs the one-time initialization. -func loadRAG(ctx context.Context) error { - RagDB = RagData{DB: chromem.NewDB()} - reader := bytes.NewReader(embeddedDB) - err := RagDB.DB.ImportFromReader(reader, "") - if err != nil { - log.Printf("Unable to import from the RAG DB file: %v", err) - return err - } - log.Printf("IMPORTED from the RAG DB collections: %v", len(RagDB.DB.ListCollections())) + "github.com/modelcontextprotocol/go-sdk/mcp" - creds, err := auth.GetAuthToken(ctx) - if err != nil { - log.Printf("Error: Google Cloud account is required: %v", err) - // RETURN AN ERROR - return fmt.Errorf("Google Cloud account is required: %w", err) - } + ragclient "devops-mcp-server/rag/client" +) - vertexEmbeddingFunc := chromem.NewEmbeddingFuncVertex( - creds.Token, - creds.ProjectId, - chromem.EmbeddingModelVertexEnglishV4) - RagDB.Knowledge, err = RagDB.DB.GetOrCreateCollection("knowledge", nil, vertexEmbeddingFunc) - if err != nil { - return fmt.Errorf("Unable to get collection knowledge: %w", err) +// AddTools adds all rag related tools to the mcp server. +// It expects the ragclient and mcp.Server to be in the context. +func AddTools(ctx context.Context, server *mcp.Server) error { + r, ok := ragclient.ClientFrom(ctx) + if !ok { + return fmt.Errorf("rag client not found in context") } - log.Printf("LOADED collection knowledge: %v", RagDB.Pattern.Count()) - RagDB.Pattern, err = RagDB.DB.GetOrCreateCollection("pattern", nil, vertexEmbeddingFunc) - if err != nil { - return fmt.Errorf("Unable to get collection pattern: %w", err) - } - log.Printf("LOADED collection pattern: %v", RagDB.Pattern.Count()) - - log.Print("RAG Init Completed!") - return nil // Success + addQueryPatternTool(server, r) + addQueryKnowledgeTool(server, r) + return nil } -// GetRAG returns the initialized RAG database. -// It ensures initialization is performed exactly once. -func GetRAG(ctx context.Context) (RagData, error) { - var initErr error - initOnce.Do(func() { - // Use a background context for the one-time init - initErr = loadRAG(context.Background()) - }) - - return RagDB, initErr +type QueryArgs struct { + Query string `json:"query" jsonschema:"The query to search for."` } -func (r *RagData) QueryPattern(ctx context.Context, query string) (string, error) { - results, err := r.Pattern.Query(ctx, query, 2, nil, nil) - if err != nil { - log.Fatalf("Unable to Query collection pattern: %v", err) - } - cleanResults := make([]Result, len(results)) - for i, r := range results { - cleanResults[i] = Result{ - Content: r.Content, - Metadata: r.Metadata, - Similarity: r.Similarity, - } - } +var queryPatternToolFunc func(ctx context.Context, req *mcp.CallToolRequest, args QueryArgs) (*mcp.CallToolResult, any, error) +var queryKnowledgeToolFunc func(ctx context.Context, req *mcp.CallToolRequest, args QueryArgs) (*mcp.CallToolResult, any, error) - // Marshal to JSON - jsonData, err := json.Marshal(cleanResults) - if err != nil { - return "", fmt.Errorf("failed to marshal results: %w", err) +func addQueryPatternTool(server *mcp.Server, ragClient ragclient.RagClient) { + queryPatternToolFunc = func(ctx context.Context, req *mcp.CallToolRequest, args QueryArgs) (*mcp.CallToolResult, any, error) { + res, err := ragClient.QueryPatterns(ctx, args.Query) + if err != nil { + return &mcp.CallToolResult{}, nil, fmt.Errorf("failed to query patterns: %w", err) + } + return &mcp.CallToolResult{}, res, nil } - return string(jsonData), nil + mcp.AddTool(server, &mcp.Tool{Name: "rag.query_pattern", Description: "Queries the RAG pattern database."}, queryPatternToolFunc) } -func (r *RagData) Queryknowledge(ctx context.Context, query string) (string, error) { - results, err := r.Knowledge.Query(ctx, query, 2, nil, nil) - if err != nil { - log.Fatalf("Unable to Query collection knowledge: %v", err) - } - cleanResults := make([]Result, len(results)) - for i, r := range results { - cleanResults[i] = Result{ - Content: r.Content, - Metadata: r.Metadata, - Similarity: r.Similarity, +func addQueryKnowledgeTool(server *mcp.Server, ragClient ragclient.RagClient) { + queryKnowledgeToolFunc = func(ctx context.Context, req *mcp.CallToolRequest, args QueryArgs) (*mcp.CallToolResult, any, error) { + res, err := ragClient.Queryknowledge(ctx, args.Query) + if err != nil { + return &mcp.CallToolResult{}, nil, fmt.Errorf("failed to query knowledge: %w", err) } + return &mcp.CallToolResult{}, res, nil } - - // Marshal to JSON - jsonData, err := json.Marshal(cleanResults) - if err != nil { - return "", fmt.Errorf("failed to marshal results: %w", err) - } - return string(jsonData), nil -} + mcp.AddTool(server, &mcp.Tool{Name: "rag.query_knowledge", Description: "Queries the RAG knowledge database."}, queryKnowledgeToolFunc) +} \ No newline at end of file diff --git a/devops-mcp-server/rag/rag_test.go b/devops-mcp-server/rag/rag_test.go new file mode 100644 index 0000000..e68a6c4 --- /dev/null +++ b/devops-mcp-server/rag/rag_test.go @@ -0,0 +1,157 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rag + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/modelcontextprotocol/go-sdk/mcp" + ragclient "devops-mcp-server/rag/client" + "devops-mcp-server/rag/client/mocks" +) + +func TestAddTools(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockRagClient := mocks.NewMockRagClient(ctrl) + + ctx = ragclient.ContextWithClient(ctx, mockRagClient) + + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, &mcp.ServerOptions{}) + + err := AddTools(ctx, server) + if err != nil { + t.Fatalf("rag.AddTools failed: %v", err) + } +} + +func TestQueryPatternTool(t *testing.T) { + ctx := context.Background() + query := "test query" + + tests := []struct { + name string + setupMocks func(*mocks.MockRagClient) + expectErr bool + expectedResult any + expectedError string + }{ + { + name: "Success case", + setupMocks: func(mock *mocks.MockRagClient) { + mock.EXPECT().QueryPatterns(gomock.Any(), query).Return("mocked pattern result", nil) + }, + expectErr: false, + expectedResult: "mocked pattern result", + }, + { + name: "Error case", + setupMocks: func(mock *mocks.MockRagClient) { + mock.EXPECT().QueryPatterns(gomock.Any(), query).Return("", errors.New("query failed")) + }, + expectErr: true, + expectedError: "failed to query patterns: query failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRagClient := mocks.NewMockRagClient(ctrl) + tt.setupMocks(mockRagClient) + + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, &mcp.ServerOptions{}) + addQueryPatternTool(server, mockRagClient) + + _, res, err := queryPatternToolFunc(ctx, nil, QueryArgs{Query: query}) + + if (err != nil) != tt.expectErr { + t.Errorf("queryPatternToolFunc() error = %v, expectErr %v", err, tt.expectErr) + } + + if tt.expectErr && err.Error() != tt.expectedError { + t.Errorf("queryPatternToolFunc() error = %q, expectedError %q", err.Error(), tt.expectedError) + } + + if !tt.expectErr && res != tt.expectedResult { + t.Errorf("queryPatternToolFunc() result = %v, expectedResult %v", res, tt.expectedResult) + } + }) + } +} + +func TestQueryKnowledgeTool(t *testing.T) { + ctx := context.Background() + query := "test query" + + tests := []struct { + name string + setupMocks func(*mocks.MockRagClient) + expectErr bool + expectedResult any + expectedError string + }{ + { + name: "Success case", + setupMocks: func(mock *mocks.MockRagClient) { + mock.EXPECT().Queryknowledge(gomock.Any(), query).Return("mocked knowledge result", nil) + }, + expectErr: false, + expectedResult: "mocked knowledge result", + }, + { + name: "Error case", + setupMocks: func(mock *mocks.MockRagClient) { + mock.EXPECT().Queryknowledge(gomock.Any(), query).Return("", errors.New("query failed")) + }, + expectErr: true, + expectedError: "failed to query knowledge: query failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRagClient := mocks.NewMockRagClient(ctrl) + tt.setupMocks(mockRagClient) + + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, &mcp.ServerOptions{}) + addQueryKnowledgeTool(server, mockRagClient) + + _, res, err := queryKnowledgeToolFunc(ctx, nil, QueryArgs{Query: query}) + + if (err != nil) != tt.expectErr { + t.Errorf("queryKnowledgeToolFunc() error = %v, expectErr %v", err, tt.expectErr) + } + + if tt.expectErr && err.Error() != tt.expectedError { + t.Errorf("queryKnowledgeToolFunc() error = %q, expectedError %q", err.Error(), tt.expectedError) + } + + if !tt.expectErr && res != tt.expectedResult { + t.Errorf("queryKnowledgeToolFunc() result = %v, expectedResult %v", res, tt.expectedResult) + } + }) + } +} diff --git a/devops-mcp-server/server.go b/devops-mcp-server/server.go index f74cbe6..702cd7f 100644 --- a/devops-mcp-server/server.go +++ b/devops-mcp-server/server.go @@ -26,6 +26,7 @@ import ( "devops-mcp-server/devconnect" "devops-mcp-server/osv" "devops-mcp-server/prompts" + "devops-mcp-server/rag" artifactregistryclient "devops-mcp-server/artifactregistry/client" cloudbuildclient "devops-mcp-server/cloudbuild/client" @@ -34,6 +35,7 @@ import ( developerconnectclient "devops-mcp-server/devconnect/client" iamclient "devops-mcp-server/iam/client" osvclient "devops-mcp-server/osv/client" + ragclient "devops-mcp-server/rag/client" resourcemanagerclient "devops-mcp-server/resourcemanager/client" _ "embed" @@ -145,5 +147,14 @@ func addAllTools(ctx context.Context, server *mcp.Server) error { if err := osv.AddTools(ctxWithDeps, server); err != nil { return err } + + ragClient, err := ragclient.NewClient(ctxWithDeps) + if err != nil { + return fmt.Errorf("failed to create rag client: %w", err) + } + ctxWithDeps = ragclient.ContextWithClient(ctxWithDeps, ragClient) + if err := rag.AddTools(ctxWithDeps, server); err != nil { + return err + } return nil }