Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 122 additions & 14 deletions internal/searchindex/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,7 @@ func (s *Service) SearchMentions(ctx context.Context, repositoryID uint, request
if err != nil {
return nil, err
}
if request.Mode == ModeFuzzy {
return s.searchFallback(ctx, repositoryID, request, documentTypes)
}
if s.db.Dialector.Name() == "postgres" {
if request.Mode == ModeFTS && s.db.Dialector.Name() == "postgres" {
return s.searchPostgres(ctx, repositoryID, request, documentTypes)
}
return s.searchFallback(ctx, repositoryID, request, documentTypes)
Expand All @@ -199,15 +196,6 @@ func (s *Service) searchPostgres(ctx context.Context, repositoryID uint, request
Where("to_tsvector('simple', search_text) @@ websearch_to_tsquery('simple', ?)", request.Query).
Order("score DESC").
Order("object_updated_at DESC")
case ModeRegex:
if _, err := regexp.Compile(request.Query); err != nil {
return nil, invalidMentionRequest("invalid regex")
}
query = query.
Select("search_documents.*, 1.0 AS score").
Where("(title_text ~* ? OR body_text ~* ?)", request.Query, request.Query).
Order("object_updated_at DESC").
Order("document_github_id ASC")
default:
return nil, invalidMentionRequest("unsupported mode")
}
Expand All @@ -223,7 +211,10 @@ func (s *Service) searchPostgres(ctx context.Context, repositoryID uint, request
}

func (s *Service) searchFallback(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) ([]MentionMatch, error) {
query := s.baseQuery(ctx, repositoryID, request, documentTypes).Order("object_updated_at DESC")
query, err := s.prefilterFallbackQuery(ctx, repositoryID, request, documentTypes)
if err != nil {
return nil, err
}
var docs []database.SearchDocument
if err := query.Find(&docs).Error; err != nil {
return nil, err
Expand Down Expand Up @@ -264,6 +255,41 @@ func (s *Service) searchFallback(ctx context.Context, repositoryID uint, request
return buildMentionMatches(scored[start:end], request)
}

// prefilterFallbackQuery builds the database query used by the fallback search
// path. It starts from baseQuery ordered by object_updated_at (newest first)
// and, depending on request.Mode, narrows it to candidate rows:
//
//   - ModeFuzzy: rows containing one of the fragments derived from the query.
//   - ModeRegex: the pattern is validated with regexp.Compile first; an
//     invalid pattern yields invalidMentionRequest("invalid regex"). Otherwise
//     rows are narrowed by the fragments derived from the pattern.
//   - any other mode: the base query is returned without narrowing.
func (s *Service) prefilterFallbackQuery(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) (*gorm.DB, error) {
	base := s.baseQuery(ctx, repositoryID, request, documentTypes).Order("object_updated_at DESC")

	if request.Mode == ModeFuzzy {
		return applyFallbackCandidateFilter(base, fuzzyCandidateFragments(request.Query)), nil
	}
	if request.Mode != ModeRegex {
		return base, nil
	}

	// Reject patterns that do not compile before touching the database.
	if _, compileErr := regexp.Compile(request.Query); compileErr != nil {
		return nil, invalidMentionRequest("invalid regex")
	}
	return applyFallbackCandidateFilter(base, regexCandidateFragments(request.Query)), nil
}

// applyFallbackCandidateFilter restricts query to rows whose normalized_text
// contains at least one of the given fragments (substring match via LIKE,
// alternatives combined with OR). Fragments are trimmed and blank ones are
// ignored; when no usable fragment remains the query is returned unchanged.
func applyFallbackCandidateFilter(query *gorm.DB, fragments []string) *gorm.DB {
	patterns := make([]any, 0, len(fragments))
	for _, raw := range fragments {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			patterns = append(patterns, "%"+trimmed+"%")
		}
	}
	if len(patterns) == 0 {
		return query
	}

	// One placeholder clause per surviving fragment.
	conditions := make([]string, len(patterns))
	for i := range conditions {
		conditions[i] = "normalized_text LIKE ?"
	}
	return query.Where("("+strings.Join(conditions, " OR ")+")", patterns...)
}

func (s *Service) baseQuery(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) *gorm.DB {
query := s.db.WithContext(ctx).Model(&database.SearchDocument{}).Where("repository_id = ?", repositoryID)
if len(documentTypes) > 0 {
Expand Down Expand Up @@ -766,6 +792,88 @@ func normalizeSearchText(text string) string {
return strings.TrimSpace(b.String())
}

// fuzzyCandidateFragments derives substring fragments from a fuzzy query.
// The query is passed through normalizeSearchText; every whitespace-separated
// token of at least three bytes becomes a fragment, and so does the normalized
// text with all spaces removed (again only when it is at least three bytes).
// The result is deduplicated, ordered and capped by rankedCandidateFragments.
// A query that normalizes to the empty string yields nil.
func fuzzyCandidateFragments(query string) []string {
	const minFragmentBytes = 3

	normalized := normalizeSearchText(query)
	if normalized == "" {
		return nil
	}

	tokens := strings.Fields(normalized)
	candidates := make([]string, 0, 8)
	for i := range tokens {
		if len(tokens[i]) < minFragmentBytes {
			continue
		}
		candidates = append(candidates, tokens[i])
	}

	// Also offer the query with its spaces squeezed out as a single fragment.
	if joined := strings.ReplaceAll(normalized, " ", ""); len(joined) >= minFragmentBytes {
		candidates = append(candidates, joined)
	}
	return rankedCandidateFragments(candidates)
}

func regexCandidateFragments(pattern string) []string {
var fragments []string
var current strings.Builder
flush := func() {
if current.Len() >= 3 {
fragments = append(fragments, strings.ToLower(current.String()))
}
current.Reset()
}
escaped := false
for _, r := range pattern {
if escaped {
if unicode.IsLetter(r) || unicode.IsNumber(r) {
current.WriteRune(unicode.ToLower(r))
} else {
flush()
}
escaped = false
continue
}
if r == '\\' {
flush()
escaped = true
continue
}
if unicode.IsLetter(r) || unicode.IsNumber(r) {
current.WriteRune(unicode.ToLower(r))
continue
}
flush()
}
flush()
return rankedCandidateFragments(fragments)
}

// rankedCandidateFragments trims, deduplicates and orders fragments, keeping
// at most six. Longer fragments (by byte length) come first; fragments of
// equal length are ordered lexicographically. Blank fragments are dropped.
// An empty input yields nil.
func rankedCandidateFragments(fragments []string) []string {
	const maxKept = 6

	if len(fragments) == 0 {
		return nil
	}

	known := make(map[string]struct{}, len(fragments))
	kept := make([]string, 0, len(fragments))
	for i := range fragments {
		trimmed := strings.TrimSpace(fragments[i])
		if _, dup := known[trimmed]; dup || trimmed == "" {
			continue
		}
		known[trimmed] = struct{}{}
		kept = append(kept, trimmed)
	}

	sort.SliceStable(kept, func(a, b int) bool {
		if la, lb := len(kept[a]), len(kept[b]); la != lb {
			return la > lb
		}
		return kept[a] < kept[b]
	})

	if len(kept) <= maxKept {
		return kept
	}
	return kept[:maxKept]
}

func trigramSimilarity(a, b string) float64 {
if a == "" || b == "" {
return 0
Expand Down
Loading