From 7a31baa7575099ba15ff618871ad82ad83fe2792 Mon Sep 17 00:00:00 2001 From: William Phetsinorath Date: Thu, 23 Jan 2025 14:22:59 +0000 Subject: [PATCH] feat: add notify based formatter Close: #509 --- config/config.go | 11 +++ test/test.go | 37 +++++++++ walk/filesystem_test.go | 39 +--------- walk/type_enum.go | 12 ++- walk/walk.go | 3 + walk/watch.go | 164 ++++++++++++++++++++++++++++++++++++++++ walk/watch_test.go | 118 +++++++++++++++++++++++++++++ 7 files changed, 342 insertions(+), 42 deletions(-) create mode 100644 walk/watch.go create mode 100644 walk/watch_test.go diff --git a/config/config.go b/config/config.go index c10df3ff..7e5687ad 100644 --- a/config/config.go +++ b/config/config.go @@ -29,6 +29,7 @@ type Config struct { Walk string `mapstructure:"walk" toml:"walk,omitempty"` WorkingDirectory string `mapstructure:"working-dir" toml:"-"` Stdin bool `mapstructure:"stdin" toml:"-"` // not allowed in config + Watch bool `mapstructure:"watch" toml:"-"` // not allowed in config FormatterConfigs map[string]*Formatter `mapstructure:"formatter" toml:"formatter,omitempty"` @@ -98,6 +99,10 @@ func SetFlags(fs *pflag.FlagSet) { "stdin", false, "Format the context passed in via stdin.", ) + fs.Bool( + "watch", false, + "Watch the filesystem for changes and apply formatters when changes are detected. (env $TREEFMT_WATCH)", + ) fs.String( "tree-root", "", "The root directory from which treefmt will start walking the filesystem (defaults to the directory "+ @@ -157,6 +162,7 @@ func FromViper(v *viper.Viper) (*Config, error) { "clear-cache": false, "no-cache": false, "stdin": false, + "watch": false, "working-dir": ".", } @@ -185,6 +191,11 @@ func FromViper(v *viper.Viper) (*Config, error) { cfg.Walk = walk.Stdin.String() } + // if the watch flag was passed, we force the watch walk type + if cfg.Watch { + cfg.Walk = walk.Watch.String() + } + // determine the tree root if cfg.TreeRoot == "" { // if none was specified, we first try with tree-root-file diff --git a/test/test.go b/test/test.go index 80c984d2..1b93d71f 100644 --- a/test/test.go +++ b/test/test.go @@ -14,6 +14,43 @@ import ( "golang.org/x/sys/unix" ) +//nolint:gochecknoglobals +var ExamplesPaths = []string{ + "elm/elm.json", + "elm/src/Main.elm", + "emoji 🕰️/README.md", + "go/go.mod", + "go/main.go", + "haskell/CHANGELOG.md", + "haskell/Foo.hs", + "haskell/Main.hs", + "haskell/Nested/Foo.hs", + "haskell/Setup.hs", + "haskell/haskell.cabal", + "haskell/treefmt.toml", + "haskell-frontend/CHANGELOG.md", + "haskell-frontend/Main.hs", + "haskell-frontend/Setup.hs", + "haskell-frontend/haskell-frontend.cabal", + "html/index.html", + "html/scripts/.gitkeep", + "javascript/source/hello.js", + "nix/sources.nix", + "nixpkgs.toml", + "python/main.py", + "python/requirements.txt", + "python/virtualenv_proxy.py", + "ruby/bundler.rb", + "rust/Cargo.toml", + "rust/src/main.rs", + "shell/foo.sh", + "terraform/main.tf", + "terraform/two.tf", + "touch.toml", + "treefmt.toml", + "yaml/test.yaml", +} + func WriteConfig(t *testing.T, path string, cfg *config.Config) { t.Helper() diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index 8c5b74c6..b9189e67 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -13,43 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -//nolint:gochecknoglobals -var examplesPaths = []string{ - "elm/elm.json", - "elm/src/Main.elm", - "emoji 🕰️/README.md", - "go/go.mod", - "go/main.go", - "haskell/CHANGELOG.md", - "haskell/Foo.hs", - "haskell/Main.hs", - "haskell/Nested/Foo.hs", - "haskell/Setup.hs", - "haskell/haskell.cabal", - "haskell/treefmt.toml", - "haskell-frontend/CHANGELOG.md", - "haskell-frontend/Main.hs", - "haskell-frontend/Setup.hs", - "haskell-frontend/haskell-frontend.cabal", - "html/index.html", - "html/scripts/.gitkeep", - "javascript/source/hello.js", - "nix/sources.nix", - "nixpkgs.toml", - "python/main.py", - "python/requirements.txt", - "python/virtualenv_proxy.py", - "ruby/bundler.rb", - "rust/Cargo.toml", - "rust/src/main.rs", - "shell/foo.sh", - "terraform/main.tf", - "terraform/two.tf", - "touch.toml", - "treefmt.toml", - "yaml/test.yaml", -} - func TestFilesystemReader(t *testing.T) { as := require.New(t) @@ -67,7 +30,7 @@ func TestFilesystemReader(t *testing.T) { n, err := r.Read(ctx, files) for i := count; i < count+n; i++ { - as.Equal(examplesPaths[i], files[i-count].RelPath) + as.Equal(test.ExamplesPaths[i], files[i-count].RelPath) } count += n diff --git a/walk/type_enum.go b/walk/type_enum.go index 4550804f..518ad666 100644 --- a/walk/type_enum.go +++ b/walk/type_enum.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _TypeName = "autostdinfilesystemgit" +const _TypeName = "autostdinfilesystemgitwatch" -var _TypeIndex = [...]uint8{0, 4, 9, 19, 22} +var _TypeIndex = [...]uint8{0, 4, 9, 19, 22, 27} -const _TypeLowerName = "autostdinfilesystemgit" +const _TypeLowerName = "autostdinfilesystemgitwatch" func (i Type) String() string { if i < 0 || i >= Type(len(_TypeIndex)-1) { @@ -28,9 +28,10 @@ func _TypeNoOp() { _ = x[Stdin-(1)] _ = x[Filesystem-(2)] _ = x[Git-(3)] + _ = x[Watch-(4)] } -var _TypeValues = []Type{Auto, Stdin, Filesystem, Git} +var _TypeValues = []Type{Auto, Stdin, Filesystem, Git, Watch} var _TypeNameToValueMap = map[string]Type{ _TypeName[0:4]: Auto, @@ -41,6 +42,8 @@ var _TypeNameToValueMap = map[string]Type{ _TypeLowerName[9:19]: Filesystem, _TypeName[19:22]: Git, _TypeLowerName[19:22]: Git, + _TypeName[22:27]: Watch, + _TypeLowerName[22:27]: Watch, } var _TypeNames = []string{ @@ -48,6 +51,7 @@ var _TypeNames = []string{ _TypeName[4:9], _TypeName[9:19], _TypeName[19:22], + _TypeName[22:27], } // TypeString retrieves an enum value from the enum constants string name. diff --git a/walk/walk.go b/walk/walk.go index 7fc3fa0b..9507395e 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -22,6 +22,7 @@ const ( Stdin Filesystem Git + Watch BatchSize = 1024 ) @@ -215,6 +216,8 @@ func NewReader( reader = NewFilesystemReader(root, path, statz, BatchSize) case Git: reader, err = NewGitReader(root, path, statz) + case Watch: + reader, err = NewWatchReader(root, path, statz, BatchSize) default: return nil, fmt.Errorf("unknown walk type: %v", walkType) diff --git a/walk/watch.go b/walk/watch.go new file mode 100644 index 00000000..56b0fa1e --- /dev/null +++ b/walk/watch.go @@ -0,0 +1,164 @@ +package walk + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/fsnotify/fsnotify" + "github.com/numtide/treefmt/v2/stats" +) + +type WatchReader struct { + root string + path string + + log *log.Logger + stats *stats.Stats + + watcher *fsnotify.Watcher +} + +func (f *WatchReader) Read(ctx context.Context, files []*File) (n int, err error) { + // ensure we record how many files we traversed + defer func() { + f.stats.Add(stats.Traversed, n) + }() + + // listen for shutdown signal and cancel the context + exit := make(chan os.Signal, 1) + signal.Notify(exit, os.Interrupt, syscall.SIGTERM) + + for n < len(files) { + select { + // since we can't detect exit using the context as watch + // events are an unbounded channel, we need to check + // for this explicitly + case <-exit: + return n, io.EOF + + // exit early if the context was cancelled + case <-ctx.Done(): + err = ctx.Err() + if err == nil { + return n, fmt.Errorf("context cancelled: %w", ctx.Err()) + } + + return n, fmt.Errorf("context error: %w", err) + + // read the next event from the channel + case event, ok := <-f.watcher.Events: + if !ok { + // channel was closed, exit the loop + return n, io.EOF + } + + // skip if the event which doesn't have content changed + if !event.Has(fsnotify.Write) { + return n, nil + } + + file, err := os.Open(event.Name) + if errors.Is(err, os.ErrNotExist) { + // file was deleted, skip it + return n, nil + } else if err != nil { + return n, fmt.Errorf("failed to stat file %s: %w", event.Name, err) + } + defer file.Close() + + info, err := file.Stat() + if err != nil { + return n, fmt.Errorf("failed to stat file %s: %w", event.Name, err) + } + + // determine a path relative to the root + relPath, err := filepath.Rel(f.root, event.Name) + if err != nil { + return n, fmt.Errorf("failed to determine a relative path for %s: %w", event.Name, err) + } + + // add to the file array and increment n + files[n] = &File{ + Path: event.Name, + RelPath: relPath, + Info: info, + } + n++ + + case err, ok := <-f.watcher.Errors: + if !ok { + return n, fmt.Errorf("failed to read from watcher: %w", err) + } + } + } + + return n, nil +} + +// Close waits for all watcher processing to complete. +func (f *WatchReader) Close() error { + err := f.watcher.Close() + if err != nil { + return fmt.Errorf("failed to close watcher: %w", err) + } + + return nil +} + +func NewWatchReader( + root string, + path string, + statz *stats.Stats, + batchSize uint, +) (*WatchReader, error) { + watcher, err := fsnotify.NewBufferedWatcher(batchSize) + if err != nil { + log.Fatalf("failed to create watcher: %v", err) + } + + r := WatchReader{ + root: root, + path: path, + log: log.Default(), + stats: statz, + watcher: watcher, + } + + // path is relative to the root, so we create a fully qualified version + // we also clean the path up in case there are any ../../ components etc. + fqPath := filepath.Clean(filepath.Join(root, path)) + + // ensure the path is within the root + if !strings.HasPrefix(fqPath, root) { + return nil, fmt.Errorf("path '%s' is outside of the root '%s'", fqPath, root) + } + + // start watching for changes recursively + err = filepath.Walk(fqPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + err := watcher.Add(path) + if err != nil { + return fmt.Errorf("failed to watch path %s: %w", path, err) + } + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to walk directory %s: %w", fqPath, err) + } + + return &r, nil +} diff --git a/walk/watch_test.go b/walk/watch_test.go new file mode 100644 index 00000000..f6c9ad28 --- /dev/null +++ b/walk/watch_test.go @@ -0,0 +1,118 @@ +package walk_test + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path" + "testing" + "time" + + "github.com/numtide/treefmt/v2/stats" + "github.com/numtide/treefmt/v2/test" + "github.com/numtide/treefmt/v2/walk" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +//nolint:gochecknoglobals +var sourceExample = ` +package main + +import "fmt" + +func main() { + fmt.Println("Hello, world!") +}` + +func TestWatchReader(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + statz := stats.New() + + r, err := walk.NewWatchReader(tempDir, "", &statz, 1024) + as.NoError(err) + + eg := errgroup.Group{} + for _, example := range test.ExamplesPaths { + eg.Go(func() error { + filePath := path.Join(tempDir, example) + + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + + return os.WriteFile(filePath, content, 0o600) + }) + } + + count := 0 + + for count < len(test.ExamplesPaths) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + + files := make([]*walk.File, 8) + n, err := r.Read(ctx, files) + + count += n + + cancel() + + if errors.Is(err, io.EOF) { + break + } + } + + as.NoError(eg.Wait()) + + as.Equal(40, count) + as.Equal(40, statz.Value(stats.Traversed)) + as.Equal(0, statz.Value(stats.Matched)) + as.Equal(0, statz.Value(stats.Formatted)) + as.Equal(0, statz.Value(stats.Changed)) +} + +func TestWatchReaderCreate(t *testing.T) { + as := require.New(t) + + tempDir := t.TempDir() + statz := stats.New() + + r, err := walk.NewWatchReader(tempDir, "", &statz, 1024) + as.NoError(err) + + as.NoError( + os.WriteFile( + path.Join(tempDir, "main.go"), + []byte(sourceExample), + 0o600, + ), + ) + + count := 0 + + for count < 1 { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + + files := make([]*walk.File, 8) + n, err := r.Read(ctx, files) + + count += n + + cancel() + + if errors.Is(err, io.EOF) { + break + } + } + + as.Equal(1, count) + as.Equal(1, statz.Value(stats.Traversed)) + as.Equal(0, statz.Value(stats.Matched)) + as.Equal(0, statz.Value(stats.Formatted)) + as.Equal(0, statz.Value(stats.Changed)) +}