Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions pkg/handler/conversations.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,40 @@ func (ch *ConversationsHandler) ConversationsSearchHandler(ctx context.Context,
}
ch.logger.Debug("Search completed", zap.Int("matches", len(messagesRes.Matches)))

messages := ch.convertMessagesFromSearch(messagesRes.Matches)
if len(messages) > 0 && messagesRes.Pagination.Page < messagesRes.Pagination.PageCount {
matches := messagesRes.Matches
if isSafeSearchEnabled() {
matches = filterSafeSearch(matches)
}

messages := ch.convertMessagesFromSearch(matches)
if len(messagesRes.Matches) > 0 && messagesRes.Pagination.Page < messagesRes.Pagination.PageCount {
nextCursor := fmt.Sprintf("page:%d", messagesRes.Pagination.Page+1)

if len(messages) == 0 {
// No messages after filtering; create a dummy message to hold the cursor
messages = append(messages, Message{})
}
messages[len(messages)-1].Cursor = base64.StdEncoding.EncodeToString([]byte(nextCursor))
}
return marshalMessagesToCSV(messages)
}

func isSafeSearchEnabled() bool {
val := strings.ToLower(strings.TrimSpace(os.Getenv("SLACK_MCP_SAFE_SEARCH")))
return val == "true" || val == "1" || val == "yes"
}

func filterSafeSearch(messages []slack.SearchMessage) []slack.SearchMessage {
filtered := make([]slack.SearchMessage, 0, len(messages))
for _, msg := range messages {
if msg.Channel.IsPrivate || msg.Channel.IsMPIM || strings.HasPrefix(msg.Channel.ID, "D") {
continue
}
filtered = append(filtered, msg)
}
return filtered
}

func isChannelAllowed(channel string) bool {
config := os.Getenv("SLACK_MCP_ADD_MESSAGE_TOOL")
if config == "" || config == "true" || config == "1" {
Expand Down Expand Up @@ -596,6 +622,12 @@ func (ch *ConversationsHandler) parseParamsToolSearch(req mcp.CallToolRequest) (
rawQuery := strings.TrimSpace(req.GetString("search_query", ""))
freeText, filters := splitQuery(rawQuery)

if isSafeSearchEnabled() {
if im := req.GetString("filter_in_im_or_mpim", ""); im != "" {
return nil, errors.New("filter_in_im_or_mpim is not allowed when SLACK_MCP_SAFE_SEARCH is enabled")
}
}

if req.GetBool("filter_threads_only", false) {
addFilter(filters, "is", "thread")
}
Expand Down
92 changes: 92 additions & 0 deletions pkg/handler/conversations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/responses"
"github.com/slack-go/slack"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -590,3 +591,94 @@ func TestUnitLimitByExpression_Invalid(t *testing.T) {
})
}
}

func TestUnitFilterSafeSearch(t *testing.T) {
tests := []struct {
name string
input []slack.SearchMessage
expectedIDs []string // expected channel IDs after filtering
}{
{
name: "empty input",
input: []slack.SearchMessage{},
expectedIDs: []string{},
},
{
name: "public channel included",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "C123", Name: "general", IsPrivate: false, IsMPIM: false}},
},
expectedIDs: []string{"C123"},
},
{
name: "private channel excluded",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "C123", Name: "secret", IsPrivate: true, IsMPIM: false}},
},
expectedIDs: []string{},
},
{
name: "MPIM excluded",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "G123", Name: "mpdm-user1-user2", IsPrivate: false, IsMPIM: true}},
},
expectedIDs: []string{},
},
{
name: "DM excluded by ID prefix",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "D123", Name: "", IsPrivate: false, IsMPIM: false}},
},
expectedIDs: []string{},
},
{
name: "mixed results",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "C001", Name: "public1", IsPrivate: false, IsMPIM: false}},
{Channel: slack.CtxChannel{ID: "C002", Name: "private1", IsPrivate: true, IsMPIM: false}},
{Channel: slack.CtxChannel{ID: "D003", Name: "", IsPrivate: false, IsMPIM: false}},
{Channel: slack.CtxChannel{ID: "C004", Name: "public2", IsPrivate: false, IsMPIM: false}},
{Channel: slack.CtxChannel{ID: "G005", Name: "mpim", IsPrivate: false, IsMPIM: true}},
{Channel: slack.CtxChannel{ID: "C006", Name: "public3", IsPrivate: false, IsMPIM: false}},
},
expectedIDs: []string{"C001", "C004", "C006"},
},
{
name: "all filtered out",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "D001", Name: "", IsPrivate: false, IsMPIM: false}},
{Channel: slack.CtxChannel{ID: "C002", Name: "private", IsPrivate: true, IsMPIM: false}},
},
expectedIDs: []string{},
},
{
name: "private channel with G prefix excluded",
input: []slack.SearchMessage{
{Channel: slack.CtxChannel{ID: "G123", Name: "private-group", IsPrivate: true, IsMPIM: false}},
},
expectedIDs: []string{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterSafeSearch(tt.input)

gotIDs := make([]string, len(result))
for i, msg := range result {
gotIDs[i] = msg.Channel.ID
}

if len(gotIDs) != len(tt.expectedIDs) {
t.Errorf("filterSafeSearch() returned %d messages (IDs: %v), want %d (IDs: %v)",
len(gotIDs), gotIDs, len(tt.expectedIDs), tt.expectedIDs)
return
}
for i, id := range gotIDs {
if id != tt.expectedIDs[i] {
t.Errorf("filterSafeSearch() result[%d].Channel.ID = %s, want %s", i, id, tt.expectedIDs[i])
}
}
})
}
}