diff --git a/pkg/handler/conversations.go b/pkg/handler/conversations.go index 65395ade..004b927b 100644 --- a/pkg/handler/conversations.go +++ b/pkg/handler/conversations.go @@ -41,20 +41,20 @@ var validFilterKeys = map[string]struct{}{ } type Message struct { - MsgID string `json:"msgID"` - UserID string `json:"userID"` - UserName string `json:"userUser"` - RealName string `json:"realName"` - Channel string `json:"channelID"` - ThreadTs string `json:"ThreadTs"` - Text string `json:"text"` - Time string `json:"time"` - Reactions string `json:"reactions,omitempty"` - BotName string `json:"botName,omitempty"` - FileCount int `json:"fileCount,omitempty"` - AttachmentIDs string `json:"attachmentIDs,omitempty"` - HasMedia bool `json:"hasMedia,omitempty"` - Cursor string `json:"cursor"` + MsgID string `json:"msgID"` + UserID string `json:"userID"` + UserName string `json:"userUser"` + RealName string `json:"realName"` + Channel string `json:"channelID"` + ThreadTs string `json:"ThreadTs"` + Text string `json:"text"` + Time string `json:"time"` + Reactions string `json:"reactions,omitempty"` + BotName string `json:"botName,omitempty"` + FileCount int `json:"fileCount,omitempty"` + AttachmentIDs string `json:"attachmentIDs,omitempty"` + HasMedia bool `json:"hasMedia,omitempty"` + Cursor string `json:"cursor"` } type User struct { @@ -591,9 +591,19 @@ func (ch *ConversationsHandler) ConversationsSearchHandler(ctx context.Context, } ch.logger.Debug("Search completed", zap.Int("matches", len(messagesRes.Matches))) - messages := ch.convertMessagesFromSearch(messagesRes.Matches) - if len(messages) > 0 && messagesRes.Pagination.Page < messagesRes.Pagination.PageCount { + matches := messagesRes.Matches + if isSafeSearchEnabled() { + matches = filterSafeSearch(matches) + } + + messages := ch.convertMessagesFromSearch(matches) + if len(messagesRes.Matches) > 0 && messagesRes.Pagination.Page < messagesRes.Pagination.PageCount { nextCursor := fmt.Sprintf("page:%d", messagesRes.Pagination.Page+1) + + if len(messages) == 0 { + // No messages after filtering; create a dummy message to hold the cursor + messages = append(messages, Message{}) + } messages[len(messages)-1].Cursor = base64.StdEncoding.EncodeToString([]byte(nextCursor)) } return marshalMessagesToCSV(messages) @@ -620,6 +630,22 @@ func isChannelAllowedForConfig(channel, config string) bool { return isNegated } +func isSafeSearchEnabled() bool { + val := strings.ToLower(strings.TrimSpace(os.Getenv("SLACK_MCP_SAFE_SEARCH"))) + return val == "true" || val == "1" || val == "yes" +} + +func filterSafeSearch(messages []slack.SearchMessage) []slack.SearchMessage { + filtered := make([]slack.SearchMessage, 0, len(messages)) + for _, msg := range messages { + if msg.Channel.IsPrivate || msg.Channel.IsMPIM || strings.HasPrefix(msg.Channel.ID, "D") { + continue + } + filtered = append(filtered, msg) + } + return filtered +} + func isChannelAllowed(channel string) bool { return isChannelAllowedForConfig(channel, os.Getenv("SLACK_MCP_ADD_MESSAGE_TOOL")) } @@ -722,19 +748,19 @@ func (ch *ConversationsHandler) convertMessagesFromHistory(slackMessages []slack attachmentIDsStr := strings.Join(attachmentIDs, ",") messages = append(messages, Message{ - MsgID: msg.Timestamp, - UserID: msg.User, - UserName: userName, - RealName: realName, - Text: text.ProcessText(msgText), - Channel: channel, - ThreadTs: msg.ThreadTimestamp, - Time: timestamp, - Reactions: reactionsString, - BotName: botName, - FileCount: fileCount, - AttachmentIDs: attachmentIDsStr, - HasMedia: hasMedia, + MsgID: msg.Timestamp, + UserID: msg.User, + UserName: userName, + RealName: realName, + Text: text.ProcessText(msgText), + Channel: channel, + ThreadTs: msg.ThreadTimestamp, + Time: timestamp, + Reactions: reactionsString, + BotName: botName, + FileCount: fileCount, + AttachmentIDs: attachmentIDsStr, + HasMedia: hasMedia, }) } @@ -1028,6 +1054,12 @@ func (ch *ConversationsHandler) parseParamsToolSearch(req mcp.CallToolRequest) ( rawQuery := strings.TrimSpace(req.GetString("search_query", "")) freeText, filters := splitQuery(rawQuery) + if isSafeSearchEnabled() { + if im := req.GetString("filter_in_im_or_mpim", ""); im != "" { + return nil, errors.New("filter_in_im_or_mpim is not allowed when SLACK_MCP_SAFE_SEARCH is enabled") + } + } + if req.GetBool("filter_threads_only", false) { addFilter(filters, "is", "thread") } diff --git a/pkg/handler/conversations_test.go b/pkg/handler/conversations_test.go index 3625ea91..16c49195 100644 --- a/pkg/handler/conversations_test.go +++ b/pkg/handler/conversations_test.go @@ -17,6 +17,7 @@ import ( "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/param" "github.com/openai/openai-go/responses" + "github.com/slack-go/slack" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -653,3 +654,94 @@ func TestUnitIsSlackUserIDPrefix(t *testing.T) { }) } } + +func TestUnitFilterSafeSearch(t *testing.T) { + tests := []struct { + name string + input []slack.SearchMessage + expectedIDs []string // expected channel IDs after filtering + }{ + { + name: "empty input", + input: []slack.SearchMessage{}, + expectedIDs: []string{}, + }, + { + name: "public channel included", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "C123", Name: "general", IsPrivate: false, IsMPIM: false}}, + }, + expectedIDs: []string{"C123"}, + }, + { + name: "private channel excluded", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "C123", Name: "secret", IsPrivate: true, IsMPIM: false}}, + }, + expectedIDs: []string{}, + }, + { + name: "MPIM excluded", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "G123", Name: "mpdm-user1-user2", IsPrivate: false, IsMPIM: true}}, + }, + expectedIDs: []string{}, + }, + { + name: "DM excluded by ID prefix", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "D123", Name: "", IsPrivate: false, IsMPIM: false}}, + }, + expectedIDs: []string{}, + }, + { + name: "mixed results", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "C001", Name: "public1", IsPrivate: false, IsMPIM: false}}, + {Channel: slack.CtxChannel{ID: "C002", Name: "private1", IsPrivate: true, IsMPIM: false}}, + {Channel: slack.CtxChannel{ID: "D003", Name: "", IsPrivate: false, IsMPIM: false}}, + {Channel: slack.CtxChannel{ID: "C004", Name: "public2", IsPrivate: false, IsMPIM: false}}, + {Channel: slack.CtxChannel{ID: "G005", Name: "mpim", IsPrivate: false, IsMPIM: true}}, + {Channel: slack.CtxChannel{ID: "C006", Name: "public3", IsPrivate: false, IsMPIM: false}}, + }, + expectedIDs: []string{"C001", "C004", "C006"}, + }, + { + name: "all filtered out", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "D001", Name: "", IsPrivate: false, IsMPIM: false}}, + {Channel: slack.CtxChannel{ID: "C002", Name: "private", IsPrivate: true, IsMPIM: false}}, + }, + expectedIDs: []string{}, + }, + { + name: "private channel with G prefix excluded", + input: []slack.SearchMessage{ + {Channel: slack.CtxChannel{ID: "G123", Name: "private-group", IsPrivate: true, IsMPIM: false}}, + }, + expectedIDs: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterSafeSearch(tt.input) + + gotIDs := make([]string, len(result)) + for i, msg := range result { + gotIDs[i] = msg.Channel.ID + } + + if len(gotIDs) != len(tt.expectedIDs) { + t.Errorf("filterSafeSearch() returned %d messages (IDs: %v), want %d (IDs: %v)", + len(gotIDs), gotIDs, len(tt.expectedIDs), tt.expectedIDs) + return + } + for i, id := range gotIDs { + if id != tt.expectedIDs[i] { + t.Errorf("filterSafeSearch() result[%d].Channel.ID = %s, want %s", i, id, tt.expectedIDs[i]) + } + } + }) + } +}