diff --git a/internal/searchindex/service.go b/internal/searchindex/service.go index 778a340..c0531c7 100644 --- a/internal/searchindex/service.go +++ b/internal/searchindex/service.go @@ -181,10 +181,7 @@ func (s *Service) SearchMentions(ctx context.Context, repositoryID uint, request if err != nil { return nil, err } - if request.Mode == ModeFuzzy { - return s.searchFallback(ctx, repositoryID, request, documentTypes) - } - if s.db.Dialector.Name() == "postgres" { + if request.Mode == ModeFTS && s.db.Dialector.Name() == "postgres" { return s.searchPostgres(ctx, repositoryID, request, documentTypes) } return s.searchFallback(ctx, repositoryID, request, documentTypes) @@ -199,15 +196,6 @@ func (s *Service) searchPostgres(ctx context.Context, repositoryID uint, request Where("to_tsvector('simple', search_text) @@ websearch_to_tsquery('simple', ?)", request.Query). Order("score DESC"). Order("object_updated_at DESC") - case ModeRegex: - if _, err := regexp.Compile(request.Query); err != nil { - return nil, invalidMentionRequest("invalid regex") - } - query = query. - Select("search_documents.*, 1.0 AS score"). - Where("(title_text ~* ? OR body_text ~* ?)", request.Query, request.Query). - Order("object_updated_at DESC"). - Order("document_github_id ASC") default: return nil, invalidMentionRequest("unsupported mode") } @@ -223,7 +211,10 @@ func (s *Service) searchPostgres(ctx context.Context, repositoryID uint, request } func (s *Service) searchFallback(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) ([]MentionMatch, error) { - query := s.baseQuery(ctx, repositoryID, request, documentTypes).Order("object_updated_at DESC") + query, err := s.prefilterFallbackQuery(ctx, repositoryID, request, documentTypes) + if err != nil { + return nil, err + } var docs []database.SearchDocument if err := query.Find(&docs).Error; err != nil { return nil, err @@ -264,6 +255,41 @@ func (s *Service) searchFallback(ctx context.Context, repositoryID uint, request return buildMentionMatches(scored[start:end], request) } +func (s *Service) prefilterFallbackQuery(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) (*gorm.DB, error) { + query := s.baseQuery(ctx, repositoryID, request, documentTypes).Order("object_updated_at DESC") + switch request.Mode { + case ModeFuzzy: + return applyFallbackCandidateFilter(query, fuzzyCandidateFragments(request.Query)), nil + case ModeRegex: + if _, err := regexp.Compile(request.Query); err != nil { + return nil, invalidMentionRequest("invalid regex") + } + return applyFallbackCandidateFilter(query, regexCandidateFragments(request.Query)), nil + default: + return query, nil + } +} + +func applyFallbackCandidateFilter(query *gorm.DB, fragments []string) *gorm.DB { + if len(fragments) == 0 { + return query + } + clauses := make([]string, 0, len(fragments)) + args := make([]any, 0, len(fragments)) + for _, fragment := range fragments { + fragment = strings.TrimSpace(fragment) + if fragment == "" { + continue + } + clauses = append(clauses, "normalized_text LIKE ?") + args = append(args, "%"+fragment+"%") + } + if len(clauses) == 0 { + return query + } + return query.Where("("+strings.Join(clauses, " OR ")+")", args...) +} + func (s *Service) baseQuery(ctx context.Context, repositoryID uint, request MentionRequest, documentTypes []string) *gorm.DB { query := s.db.WithContext(ctx).Model(&database.SearchDocument{}).Where("repository_id = ?", repositoryID) if len(documentTypes) > 0 { @@ -766,6 +792,88 @@ func normalizeSearchText(text string) string { return strings.TrimSpace(b.String()) } +func fuzzyCandidateFragments(query string) []string { + normalized := normalizeSearchText(query) + if normalized == "" { + return nil + } + fragments := make([]string, 0, 8) + for _, token := range strings.Fields(normalized) { + if len(token) >= 3 { + fragments = append(fragments, token) + } + } + collapsed := strings.ReplaceAll(normalized, " ", "") + if len(collapsed) >= 3 { + fragments = append(fragments, collapsed) + } + return rankedCandidateFragments(fragments) +} + +func regexCandidateFragments(pattern string) []string { + var fragments []string + var current strings.Builder + flush := func() { + if current.Len() >= 3 { + fragments = append(fragments, strings.ToLower(current.String())) + } + current.Reset() + } + escaped := false + for _, r := range pattern { + if escaped { + if unicode.IsLetter(r) || unicode.IsNumber(r) { + current.WriteRune(unicode.ToLower(r)) + } else { + flush() + } + escaped = false + continue + } + if r == '\\' { + flush() + escaped = true + continue + } + if unicode.IsLetter(r) || unicode.IsNumber(r) { + current.WriteRune(unicode.ToLower(r)) + continue + } + flush() + } + flush() + return rankedCandidateFragments(fragments) +} + +func rankedCandidateFragments(fragments []string) []string { + if len(fragments) == 0 { + return nil + } + seen := make(map[string]struct{}, len(fragments)) + unique := make([]string, 0, len(fragments)) + for _, fragment := range fragments { + fragment = strings.TrimSpace(fragment) + if fragment == "" { + continue + } + if _, ok := seen[fragment]; ok { + continue + } + seen[fragment] = struct{}{} + unique = append(unique, fragment) + } + sort.SliceStable(unique, func(i, j int) bool { + if len(unique[i]) == len(unique[j]) { + return unique[i] < unique[j] + } + return len(unique[i]) > len(unique[j]) + }) + if len(unique) > 6 { + unique = unique[:6] + } + return unique +} + func trigramSimilarity(a, b string) float64 { if a == "" || b == "" { return 0