Skip to content

Commit

Permalink
feat: add notify based formatter
Browse files Browse the repository at this point in the history
shikanime committed Jan 22, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 45881a4 commit ff83eb4
Showing 8 changed files with 326 additions and 43 deletions.
11 changes: 11 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ type Config struct {
Walk string `mapstructure:"walk" toml:"walk,omitempty"`
WorkingDirectory string `mapstructure:"working-dir" toml:"-"`
Stdin bool `mapstructure:"stdin" toml:"-"` // not allowed in config
Watch bool `mapstructure:"watch" toml:"-"` // not allowed in config

FormatterConfigs map[string]*Formatter `mapstructure:"formatter" toml:"formatter,omitempty"`

@@ -98,6 +99,10 @@ func SetFlags(fs *pflag.FlagSet) {
"stdin", false,
"Format the context passed in via stdin.",
)
fs.Bool(
"watch", false,
"Watch the filesystem for changes and apply formatters when changes are detected. (env $TREEFMT_WATCH)",
)
fs.String(
"tree-root", "",
"The root directory from which treefmt will start walking the filesystem (defaults to the directory "+
@@ -157,6 +162,7 @@ func FromViper(v *viper.Viper) (*Config, error) {
"clear-cache": false,
"no-cache": false,
"stdin": false,
"watch": false,
"working-dir": ".",
}

@@ -185,6 +191,11 @@ func FromViper(v *viper.Viper) (*Config, error) {
cfg.Walk = walk.Stdin.String()
}

// if the watch flag was passed, we force the watch walk type
if cfg.Watch {
cfg.Walk = walk.Watch.String()
}

// determine the tree root
if cfg.TreeRoot == "" {
// if none was specified, we first try with tree-root-file
1 change: 0 additions & 1 deletion format/formatter.go
Original file line number Diff line number Diff line change
@@ -103,7 +103,6 @@ func (f *Formatter) Apply(ctx context.Context, files []*walk.File) error {

// log out the command being executed
f.log.Debugf("executing: %s", cmd.String())

if out, err := cmd.CombinedOutput(); err != nil {
f.log.Errorf("failed to apply with options '%v': %s", f.config.Options, err)

37 changes: 37 additions & 0 deletions test/test.go
Original file line number Diff line number Diff line change
@@ -14,6 +14,43 @@ import (
"golang.org/x/sys/unix"
)

//nolint:gochecknoglobals
var ExamplesPaths = []string{
"elm/elm.json",
"elm/src/Main.elm",
"emoji 🕰️/README.md",
"go/go.mod",
"go/main.go",
"haskell/CHANGELOG.md",
"haskell/Foo.hs",
"haskell/Main.hs",
"haskell/Nested/Foo.hs",
"haskell/Setup.hs",
"haskell/haskell.cabal",
"haskell/treefmt.toml",
"haskell-frontend/CHANGELOG.md",
"haskell-frontend/Main.hs",
"haskell-frontend/Setup.hs",
"haskell-frontend/haskell-frontend.cabal",
"html/index.html",
"html/scripts/.gitkeep",
"javascript/source/hello.js",
"nix/sources.nix",
"nixpkgs.toml",
"python/main.py",
"python/requirements.txt",
"python/virtualenv_proxy.py",
"ruby/bundler.rb",
"rust/Cargo.toml",
"rust/src/main.rs",
"shell/foo.sh",
"terraform/main.tf",
"terraform/two.tf",
"touch.toml",
"treefmt.toml",
"yaml/test.yaml",
}

func WriteConfig(t *testing.T, path string, cfg *config.Config) {
t.Helper()

39 changes: 1 addition & 38 deletions walk/filesystem_test.go
Original file line number Diff line number Diff line change
@@ -13,43 +13,6 @@ import (
"github.com/stretchr/testify/require"
)

//nolint:gochecknoglobals
var examplesPaths = []string{
"elm/elm.json",
"elm/src/Main.elm",
"emoji 🕰️/README.md",
"go/go.mod",
"go/main.go",
"haskell/CHANGELOG.md",
"haskell/Foo.hs",
"haskell/Main.hs",
"haskell/Nested/Foo.hs",
"haskell/Setup.hs",
"haskell/haskell.cabal",
"haskell/treefmt.toml",
"haskell-frontend/CHANGELOG.md",
"haskell-frontend/Main.hs",
"haskell-frontend/Setup.hs",
"haskell-frontend/haskell-frontend.cabal",
"html/index.html",
"html/scripts/.gitkeep",
"javascript/source/hello.js",
"nix/sources.nix",
"nixpkgs.toml",
"python/main.py",
"python/requirements.txt",
"python/virtualenv_proxy.py",
"ruby/bundler.rb",
"rust/Cargo.toml",
"rust/src/main.rs",
"shell/foo.sh",
"terraform/main.tf",
"terraform/two.tf",
"touch.toml",
"treefmt.toml",
"yaml/test.yaml",
}

func TestFilesystemReader(t *testing.T) {
as := require.New(t)

@@ -67,7 +30,7 @@ func TestFilesystemReader(t *testing.T) {
n, err := r.Read(ctx, files)

for i := count; i < count+n; i++ {
as.Equal(examplesPaths[i], files[i-count].RelPath)
as.Equal(test.ExamplesPaths[i], files[i-count].RelPath)
}

count += n
12 changes: 8 additions & 4 deletions walk/type_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions walk/walk.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ const (
Stdin
Filesystem
Git
Watch

BatchSize = 1024
)
@@ -215,6 +216,8 @@ func NewReader(
reader = NewFilesystemReader(root, path, statz, BatchSize)
case Git:
reader, err = NewGitReader(root, path, statz)
case Watch:
reader, err = NewWatchReader(root, path, statz)

default:
return nil, fmt.Errorf("unknown walk type: %v", walkType)
108 changes: 108 additions & 0 deletions walk/walk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package walk_test

import (
"context"
"errors"
"io"
"os"
"path"
"testing"
"time"

"github.com/numtide/treefmt/v2/stats"
"github.com/numtide/treefmt/v2/test"
"github.com/numtide/treefmt/v2/walk"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)

var sourceExample = `
package main
import "fmt"
func main() {
fmt.Println("Hello, world!")
}`

func TestWatchReader(t *testing.T) {
as := require.New(t)

tempDir := test.TempExamples(t)
statz := stats.New()

r, err := walk.NewWatchReader(tempDir, "", &statz)
as.NoError(err)

eg := errgroup.Group{}
for _, example := range test.ExamplesPaths {
eg.Go(func() error {
filePath := path.Join(tempDir, example)
content, err := os.ReadFile(filePath)
if err != nil {
return err
}
return os.WriteFile(filePath, content, 0o644)
})
}

count := 0

for count < 33 {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)

files := make([]*walk.File, 8)
n, err := r.Read(ctx, files)

count += n

cancel()

if errors.Is(err, io.EOF) {
break
}
}

as.NoError(eg.Wait())

as.Equal(33, count)
as.Equal(33, statz.Value(stats.Traversed))
as.Equal(0, statz.Value(stats.Matched))
as.Equal(0, statz.Value(stats.Formatted))
as.Equal(0, statz.Value(stats.Changed))
}

func TestWatchReaderCreate(t *testing.T) {
as := require.New(t)

tempDir := t.TempDir()
statz := stats.New()

r, err := walk.NewWatchReader(tempDir, "", &statz)
as.NoError(err)

as.NoError(
os.WriteFile(
path.Join(tempDir, "main.go"),
[]byte(sourceExample),
0o644,
),
)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)

files := make([]*walk.File, 8)
n, err := r.Read(ctx, files)

cancel()

if !errors.Is(err, io.EOF) {
as.NoError(err)
}

as.Equal(1, n)
as.Equal(1, statz.Value(stats.Traversed))
as.Equal(0, statz.Value(stats.Matched))
as.Equal(0, statz.Value(stats.Formatted))
as.Equal(0, statz.Value(stats.Changed))
}
158 changes: 158 additions & 0 deletions walk/watch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package walk

import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"

"github.com/fsnotify/fsnotify"
"github.com/numtide/treefmt/v2/stats"
)

type WatchReader struct {
root string
path string

log *log.Logger
stats *stats.Stats

watcher *fsnotify.Watcher
}

func (f *WatchReader) Read(ctx context.Context, files []*File) (n int, err error) {
// ensure we record how many files we traversed
defer func() {
f.stats.Add(stats.Traversed, n)
}()

// listen for shutdown signal and cancel the context
exit := make(chan os.Signal, 1)
signal.Notify(exit, os.Interrupt, syscall.SIGTERM)

select {
// since we can't detect exit using the context as watch
// events are an unbounded channel, we need to check
// for this explicitly
case <-exit:
return n, io.EOF

// exit early if the context was cancelled
case <-ctx.Done():
err = ctx.Err()
if err == nil {
return n, fmt.Errorf("context cancelled: %w", ctx.Err())
}

return n, nil

// read the next event from the channel
case event, ok := <-f.watcher.Events:
if !ok {
// channel was closed, exit the loop
return n, io.EOF
}

// skip if the event is a chmod or rename event since it doesn't
// change the file contents
if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) {
return n, nil
}

file, err := os.Open(event.Name)
if errors.Is(err, os.ErrNotExist) {
// file was deleted, skip it
return n, nil
} else if err != nil {
return n, fmt.Errorf("failed to stat file %s: %w", event.Name, err)
}
defer file.Close()
info, err := file.Stat()
if err != nil {
return n, fmt.Errorf("failed to stat file %s: %w", event.Name, err)
}

// determine a path relative to the root
relPath, err := filepath.Rel(f.root, event.Name)
if err != nil {
return n, fmt.Errorf("failed to determine a relative path for %s: %w", event.Name, err)
}

// add to the file array and increment n
files[n] = &File{
Path: event.Name,
RelPath: relPath,
Info: info,
}
n++

case err, ok := <-f.watcher.Errors:
if !ok {
return n, fmt.Errorf("failed to read from watcher: %w", err)
}
}

return n, err
}

// Close waits for all watcher processing to complete.
func (f *WatchReader) Close() error {
err := f.watcher.Close()
if err != nil {
return fmt.Errorf("failed to close watcher: %w", err)
}

return nil
}

func NewWatchReader(
root string,
path string,
statz *stats.Stats,
) (*WatchReader, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatalf("failed to create watcher: %v", err)
}

r := WatchReader{
root: root,
path: path,
log: log.Default(),
stats: statz,
watcher: watcher,
}

// path is relative to the root, so we create a fully qualified version
// we also clean the path up in case there are any ../../ components etc.
fqPath := filepath.Clean(filepath.Join(root, path))

// ensure the path is within the root
if !strings.HasPrefix(fqPath, root) {
return nil, fmt.Errorf("path '%s' is outside of the root '%s'", fqPath, root)
}

// start watching for changes recursively
err = filepath.Walk(fqPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
if err := watcher.Add(path); err != nil {
return fmt.Errorf("failed to watch path %s: %w", path, err)
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to walk directory %s: %w", fqPath, err)
}

return &r, nil
}

0 comments on commit ff83eb4

Please sign in to comment.