diff --git a/internal/logger/common.go b/internal/logger/common.go index b04d4590..b86aa82b 100644 --- a/internal/logger/common.go +++ b/internal/logger/common.go @@ -393,6 +393,23 @@ func initLogFile(logDir, fileName string, flags int) (*os.File, error) { return file, nil } +// atomicWriteFile writes data to filePath atomically using a temp-file + rename strategy. +// On rename failure the temp file is removed; a removal error that is not os.IsNotExist +// is logged as a warning but does not mask the primary rename error. +func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error { + tempPath := filePath + ".tmp" + if err := os.WriteFile(tempPath, data, perm); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + if err := os.Rename(tempPath, filePath); err != nil { + if removeErr := os.Remove(tempPath); removeErr != nil && !os.IsNotExist(removeErr) { + log.Printf("WARNING: Failed to cleanup temp file %s: %v", tempPath, removeErr) + } + return fmt.Errorf("failed to rename temp file: %w", err) + } + return nil +} + // loggerSetupFunc is a function type that sets up a logger instance after the log file is opened. // It receives the opened file, logDir, and fileName, and returns the configured logger. type loggerSetupFunc[T closableLogger] func(file *os.File, logDir, fileName string) (T, error) diff --git a/internal/logger/observed_url_domains_logger.go b/internal/logger/observed_url_domains_logger.go index 2b25e708..31632280 100644 --- a/internal/logger/observed_url_domains_logger.go +++ b/internal/logger/observed_url_domains_logger.go @@ -6,9 +6,10 @@ import ( "log" "os" "path/filepath" - "sort" "sync" "sync/atomic" + + "github.com/github/gh-aw-mcpg/internal/strutil" ) const observedURLDomainsFileName = "observed-url-domains.json" @@ -117,12 +118,7 @@ func (l *ObservedURLDomainsLogger) LogDomains(serverID string, domains []string) func (l *ObservedURLDomainsLogger) writeToFile() error { serialized := make(map[string][]string, len(l.data)) for serverID, domains := range l.data { - items := make([]string, 0, len(domains)) - for domain := range domains { - items = append(items, domain) - } - sort.Strings(items) - serialized[serverID] = items + serialized[serverID] = strutil.SortedSetKeys(domains) } jsonData, err := json.MarshalIndent(serialized, "", " ") @@ -131,17 +127,7 @@ func (l *ObservedURLDomainsLogger) writeToFile() error { } filePath := filepath.Join(l.logDir, l.fileName) - tempPath := filePath + ".tmp" - if err := os.WriteFile(tempPath, jsonData, 0600); err != nil { - return fmt.Errorf("failed to write temp file: %w", err) - } - if err := os.Rename(tempPath, filePath); err != nil { - if removeErr := os.Remove(tempPath); removeErr != nil && !os.IsNotExist(removeErr) { - log.Printf("WARNING: Failed to cleanup temp observed URL domains file %s: %v", tempPath, removeErr) - } - return fmt.Errorf("failed to rename temp file: %w", err) - } - return nil + return atomicWriteFile(filePath, jsonData, 0600) } func (l *ObservedURLDomainsLogger) Close() error { return nil } diff --git a/internal/logger/tools_logger.go b/internal/logger/tools_logger.go index 8f4b9c00..bd1fd39e 100644 --- a/internal/logger/tools_logger.go +++ b/internal/logger/tools_logger.go @@ -112,19 +112,7 @@ func (tl *ToolsLogger) writeToFile() error { return fmt.Errorf("failed to marshal tools data: %w", err) } - // Write to file atomically using a temp file + rename - tempPath := filePath + ".tmp" - if err := os.WriteFile(tempPath, jsonData, 0644); err != nil { - return fmt.Errorf("failed to write temp file: %w", err) - } - - if err := os.Rename(tempPath, filePath); err != nil { - // Clean up temp file on error - os.Remove(tempPath) - return fmt.Errorf("failed to rename temp file: %w", err) - } - - return nil + return atomicWriteFile(filePath, jsonData, 0644) } // Close is a no-op for ToolsLogger (implements closableLogger interface) diff --git a/internal/strutil/strutil.go b/internal/strutil/strutil.go index 9dc59fa4..441090ab 100644 --- a/internal/strutil/strutil.go +++ b/internal/strutil/strutil.go @@ -5,6 +5,17 @@ import ( "strings" ) +// SortedSetKeys returns the keys of a string set (map[string]struct{}) as a sorted slice. +// Returns an empty (non-nil) slice when the set is empty. +func SortedSetKeys(set map[string]struct{}) []string { + keys := make([]string, 0, len(set)) + for k := range set { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + // DeduplicateStrings returns a new slice with whitespace-trimmed, empty, and duplicate // entries removed from input. When sorted is true the result is sorted in ascending order. // The relative order of first-seen entries is preserved when sorted is false. diff --git a/internal/strutil/util_test.go b/internal/strutil/util_test.go index 5bbf987e..e9bf96b3 100644 --- a/internal/strutil/util_test.go +++ b/internal/strutil/util_test.go @@ -6,6 +6,32 @@ import ( "github.com/stretchr/testify/assert" ) +func TestSortedSetKeys(t *testing.T) { + t.Parallel() + + t.Run("returns sorted keys", func(t *testing.T) { + t.Parallel() + set := map[string]struct{}{"banana": {}, "apple": {}, "cherry": {}} + assert.Equal(t, []string{"apple", "banana", "cherry"}, SortedSetKeys(set)) + }) + + t.Run("returns empty slice for empty set", func(t *testing.T) { + t.Parallel() + assert.Empty(t, SortedSetKeys(map[string]struct{}{})) + }) + + t.Run("returns single element slice", func(t *testing.T) { + t.Parallel() + set := map[string]struct{}{"only": {}} + assert.Equal(t, []string{"only"}, SortedSetKeys(set)) + }) + + t.Run("handles nil map", func(t *testing.T) { + t.Parallel() + assert.Empty(t, SortedSetKeys(nil)) + }) +} + func TestGetStringFromMap(t *testing.T) { t.Parallel() diff --git a/internal/urlutil/domains.go b/internal/urlutil/domains.go index e826e2de..49b31638 100644 --- a/internal/urlutil/domains.go +++ b/internal/urlutil/domains.go @@ -3,10 +3,10 @@ package urlutil import ( "net/url" "regexp" - "sort" "strings" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/strutil" ) var logDomains = logger.New("urlutil:domains") @@ -26,11 +26,7 @@ func ExtractURLDomainsFromValue(value any) []string { return nil } - domains := make([]string, 0, len(domainSet)) - for domain := range domainSet { - domains = append(domains, domain) - } - sort.Strings(domains) + domains := strutil.SortedSetKeys(domainSet) logDomains.Printf("ExtractURLDomainsFromValue: extracted %d unique domain(s)", len(domains)) return domains } @@ -91,11 +87,7 @@ func ExtractURLDomains(text string) []string { if len(domainSet) == 0 { return nil } - domains := make([]string, 0, len(domainSet)) - for domain := range domainSet { - domains = append(domains, domain) - } - sort.Strings(domains) + domains := strutil.SortedSetKeys(domainSet) logDomains.Printf("ExtractURLDomains: resolved %d unique domain(s) from %d candidate(s)", len(domains), len(matches)) return domains }