Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions internal/logger/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ func initLogFile(logDir, fileName string, flags int) (*os.File, error) {
return file, nil
}

// atomicWriteFile writes data to filePath atomically using a temp-file + rename strategy.
// On rename failure the temp file is removed; a removal error that is not os.IsNotExist
// is logged as a warning but does not mask the primary rename error.
func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error {
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, data, perm); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := os.Rename(tempPath, filePath); err != nil {
if removeErr := os.Remove(tempPath); removeErr != nil && !os.IsNotExist(removeErr) {
log.Printf("WARNING: Failed to cleanup temp file %s: %v", tempPath, removeErr)
}
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}

// loggerSetupFunc is a function type that sets up a logger instance after the log file is opened.
// It receives the opened file, logDir, and fileName, and returns the configured logger.
type loggerSetupFunc[T closableLogger] func(file *os.File, logDir, fileName string) (T, error)
Expand Down
22 changes: 4 additions & 18 deletions internal/logger/observed_url_domains_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"log"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"

"github.com/github/gh-aw-mcpg/internal/strutil"
)

const observedURLDomainsFileName = "observed-url-domains.json"
Expand Down Expand Up @@ -117,12 +118,7 @@ func (l *ObservedURLDomainsLogger) LogDomains(serverID string, domains []string)
func (l *ObservedURLDomainsLogger) writeToFile() error {
serialized := make(map[string][]string, len(l.data))
for serverID, domains := range l.data {
items := make([]string, 0, len(domains))
for domain := range domains {
items = append(items, domain)
}
sort.Strings(items)
serialized[serverID] = items
serialized[serverID] = strutil.SortedSetKeys(domains)
}

jsonData, err := json.MarshalIndent(serialized, "", " ")
Expand All @@ -131,17 +127,7 @@ func (l *ObservedURLDomainsLogger) writeToFile() error {
}

filePath := filepath.Join(l.logDir, l.fileName)
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, jsonData, 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := os.Rename(tempPath, filePath); err != nil {
if removeErr := os.Remove(tempPath); removeErr != nil && !os.IsNotExist(removeErr) {
log.Printf("WARNING: Failed to cleanup temp observed URL domains file %s: %v", tempPath, removeErr)
}
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
return atomicWriteFile(filePath, jsonData, 0600)
}

func (l *ObservedURLDomainsLogger) Close() error { return nil }
Expand Down
14 changes: 1 addition & 13 deletions internal/logger/tools_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,7 @@ func (tl *ToolsLogger) writeToFile() error {
return fmt.Errorf("failed to marshal tools data: %w", err)
}

// Write to file atomically using a temp file + rename
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, jsonData, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}

if err := os.Rename(tempPath, filePath); err != nil {
// Clean up temp file on error
os.Remove(tempPath)
return fmt.Errorf("failed to rename temp file: %w", err)
}

return nil
return atomicWriteFile(filePath, jsonData, 0644)
}

// Close is a no-op for ToolsLogger (implements closableLogger interface)
Expand Down
11 changes: 11 additions & 0 deletions internal/strutil/strutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ import (
"strings"
)

// SortedSetKeys returns the keys of a string set (map[string]struct{}) as a sorted slice.
// Returns an empty (non-nil) slice when the set is empty.
func SortedSetKeys(set map[string]struct{}) []string {
keys := make([]string, 0, len(set))
for k := range set {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}

// DeduplicateStrings returns a new slice with whitespace-trimmed, empty, and duplicate
// entries removed from input. When sorted is true the result is sorted in ascending order.
// The relative order of first-seen entries is preserved when sorted is false.
Expand Down
26 changes: 26 additions & 0 deletions internal/strutil/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@ import (
"github.com/stretchr/testify/assert"
)

func TestSortedSetKeys(t *testing.T) {
t.Parallel()

t.Run("returns sorted keys", func(t *testing.T) {
t.Parallel()
set := map[string]struct{}{"banana": {}, "apple": {}, "cherry": {}}
assert.Equal(t, []string{"apple", "banana", "cherry"}, SortedSetKeys(set))
})

t.Run("returns empty slice for empty set", func(t *testing.T) {
t.Parallel()
assert.Empty(t, SortedSetKeys(map[string]struct{}{}))
})

t.Run("returns single element slice", func(t *testing.T) {
t.Parallel()
set := map[string]struct{}{"only": {}}
assert.Equal(t, []string{"only"}, SortedSetKeys(set))
})

t.Run("handles nil map", func(t *testing.T) {
t.Parallel()
assert.Empty(t, SortedSetKeys(nil))
})
}

func TestGetStringFromMap(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 3 additions & 11 deletions internal/urlutil/domains.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package urlutil
import (
"net/url"
"regexp"
"sort"
"strings"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/strutil"
)

var logDomains = logger.New("urlutil:domains")
Expand All @@ -26,11 +26,7 @@ func ExtractURLDomainsFromValue(value any) []string {
return nil
}

domains := make([]string, 0, len(domainSet))
for domain := range domainSet {
domains = append(domains, domain)
}
sort.Strings(domains)
domains := strutil.SortedSetKeys(domainSet)
logDomains.Printf("ExtractURLDomainsFromValue: extracted %d unique domain(s)", len(domains))
return domains
}
Expand Down Expand Up @@ -91,11 +87,7 @@ func ExtractURLDomains(text string) []string {
if len(domainSet) == 0 {
return nil
}
domains := make([]string, 0, len(domainSet))
for domain := range domainSet {
domains = append(domains, domain)
}
sort.Strings(domains)
domains := strutil.SortedSetKeys(domainSet)
logDomains.Printf("ExtractURLDomains: resolved %d unique domain(s) from %d candidate(s)", len(domains), len(matches))
return domains
}
Loading