From 0e4e08d9aa83376769187de61df4b1334d0a1db9 Mon Sep 17 00:00:00 2001 From: samzong Date: Fri, 3 Apr 2026 13:32:57 +0800 Subject: [PATCH 1/2] feat(worktree): unified protection policy for destructive operations Add ProtectionPolicy struct with IsProtected/Reason methods as a single predicate for bare, main-branch, and root-worktree protection. Wire into Remove, Promote, and Prune to replace scattered inline checks. - Unify main branch resolution via resolveBaseBranchWithPolicy - Add --all/-a flag to wt rm (skips protected, removes the rest) - Contextual error messages: bare repository / main worktree / main branch - Bare-layout HEAD fallback for non-main/master default branches Signed-off-by: samzong --- cmd/worktree.go | 53 +++++++++++++++-- cmd/worktree_test.go | 85 +++++++++++++++++++++++++++ internal/worktree/bare.go | 20 ++----- internal/worktree/prune.go | 3 +- internal/worktree/share_discover.go | 9 ++- internal/worktree/worktree.go | 66 +++++++++++++++++++-- internal/worktree/worktree_test.go | 90 +++++++++++++++++++++++++++++ 7 files changed, 298 insertions(+), 28 deletions(-) diff --git a/cmd/worktree.go b/cmd/worktree.go index f157304..2f8e398 100644 --- a/cmd/worktree.go +++ b/cmd/worktree.go @@ -1,6 +1,7 @@ package cmd import ( + "errors" "fmt" "os" "path/filepath" @@ -17,6 +18,7 @@ var ( wtForce bool wtDeleteBranch bool wtDryRun bool + wtAll bool wtUpstream string wtProjectName string prRemote string @@ -69,21 +71,30 @@ var wtListCmd = &cobra.Command{ } var wtRemoveCmd = &cobra.Command{ - Use: "remove [name...]", + Use: "remove [name...]", Aliases: []string{"rm"}, Short: "Remove worktrees (alias: rm)", Long: `Remove one or more worktrees. By default, only removes the worktree directory, keeping the branch. -Use -D to also delete the branch. +Use -D to also delete the branch. Use --all to remove all non-protected worktrees. Examples: gmc wt remove feature-login # Remove one worktree gmc wt rm feat-a feat-b feat-c # Remove multiple worktrees gmc wt rm feature-login -D # Remove worktree and delete branch gmc wt rm feature-login -f # Force remove (ignore dirty state) - gmc wt rm feature-login --dry-run # Preview what would be removed`, - Args: cobra.MinimumNArgs(1), + gmc wt rm feature-login --dry-run # Preview what would be removed + gmc wt rm --all -D # Remove all non-protected worktrees and branches`, + Args: func(_ *cobra.Command, args []string) error { + if wtAll && len(args) > 0 { + return errors.New("--all and positional arguments are mutually exclusive") + } + if !wtAll && len(args) < 1 { + return errors.New("requires at least 1 arg(s) or --all flag") + } + return nil + }, RunE: func(_ *cobra.Command, args []string) error { wtClient := newWorktreeClient() return runWorktreeRemove(wtClient, args) @@ -192,6 +203,7 @@ func init() { wtRemoveCmd.Flags().BoolVarP(&wtForce, "force", "f", false, "Force removal even if worktree is dirty") wtRemoveCmd.Flags().BoolVarP(&wtDeleteBranch, "delete-branch", "D", false, "Also delete the branch") wtRemoveCmd.Flags().BoolVar(&wtDryRun, "dry-run", false, "Preview what would be removed without making changes") + wtRemoveCmd.Flags().BoolVarP(&wtAll, "all", "a", false, "Remove all non-protected worktrees") // Flags for clone command wtCloneCmd.Flags().StringVar(&wtUpstream, "upstream", "", "Upstream repository URL (for fork workflow)") @@ -335,6 +347,18 @@ func runWorktreeList(wtClient *worktree.Client) error { } func runWorktreeRemove(wtClient *worktree.Client, names []string) error { + if wtAll { + resolved, err := resolveAllRemovableWorktrees(wtClient) + if err != nil { + return err + } + if len(resolved) == 0 { + fmt.Fprintln(outWriter(), "No removable worktrees found.") + return nil + } + names = resolved + } + opts := worktree.RemoveOptions{ Force: wtForce, DeleteBranch: wtDeleteBranch, @@ -355,6 +379,27 @@ func runWorktreeRemove(wtClient *worktree.Client, names []string) error { return nil } +func resolveAllRemovableWorktrees(wtClient *worktree.Client) ([]string, error) { + all, err := wtClient.List() + if err != nil { + return nil, err + } + + pp := wtClient.NewProtectionPolicy() + root := getDisplayRoot(wtClient) + var names []string + for _, wt := range all { + if pp.IsProtected(wt) { + continue + } + if isExternalWorktree(root, wt.Path) || isAgentWorktree(wt.Path) { + continue + } + names = append(names, displayWorktreeName(root, wt.Path)) + } + return names, nil +} + func runWorktreeClone(wtClient *worktree.Client, url string) error { opts := worktree.CloneOptions{ Name: wtProjectName, diff --git a/cmd/worktree_test.go b/cmd/worktree_test.go index 64bd635..4818e2d 100644 --- a/cmd/worktree_test.go +++ b/cmd/worktree_test.go @@ -102,3 +102,88 @@ func runGitCmd(t *testing.T, dir string, args ...string) string { var execCommand = func(name string, args ...string) *exec.Cmd { return exec.Command(name, args...) } + +func TestRemoveAll_SkipsProtected(t *testing.T) { + repoDir := initCmdTestRepo(t) + + feat1 := filepath.Join(repoDir, "feat-1") + feat2 := filepath.Join(repoDir, "feat-2") + runGitCmd(t, repoDir, "worktree", "add", "-b", "feat-1", feat1, "main") + runGitCmd(t, repoDir, "worktree", "add", "-b", "feat-2", feat2, "main") + + oldCwd, err := os.Getwd() + require.NoError(t, err) + defer func() { _ = os.Chdir(oldCwd) }() + require.NoError(t, os.Chdir(repoDir)) + + var out bytes.Buffer + oldOut := outWriterFunc + oldErr := errWriterFunc + outWriterFunc = func() io.Writer { return &out } + errWriterFunc = func() io.Writer { return &out } + defer func() { + outWriterFunc = oldOut + errWriterFunc = oldErr + }() + + oldAll := wtAll + oldForce := wtForce + oldDelete := wtDeleteBranch + oldDry := wtDryRun + defer func() { + wtAll = oldAll + wtForce = oldForce + wtDeleteBranch = oldDelete + wtDryRun = oldDry + }() + + wtAll = true + wtForce = false + wtDeleteBranch = true + wtDryRun = false + + client := worktree.NewClient(worktree.Options{}) + err = runWorktreeRemove(client, nil) + require.NoError(t, err) + + _, err = os.Stat(feat1) + assert.True(t, os.IsNotExist(err), "feat-1 should be removed") + _, err = os.Stat(feat2) + assert.True(t, os.IsNotExist(err), "feat-2 should be removed") + + _, err = os.Stat(repoDir) + assert.NoError(t, err, "main worktree (repoDir) must survive --all") + + remaining, err := client.List() + require.NoError(t, err) + var mainFound bool + for _, wt := range remaining { + if wt.Branch == "feat-1" || wt.Branch == "feat-2" { + t.Errorf("branch %s should have been deleted", wt.Branch) + } + if wt.Branch == "main" { + mainFound = true + } + } + assert.True(t, mainFound, "main branch worktree must still exist") +} + +func TestRemoveAllMutuallyExclusiveWithArgs(t *testing.T) { + oldAll := wtAll + defer func() { wtAll = oldAll }() + wtAll = true + + err := wtRemoveCmd.Args(wtRemoveCmd, []string{"some-worktree"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + +func TestRemoveRequiresArgsOrAll(t *testing.T) { + oldAll := wtAll + defer func() { wtAll = oldAll }() + wtAll = false + + err := wtRemoveCmd.Args(wtRemoveCmd, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "requires at least 1 arg") +} diff --git a/internal/worktree/bare.go b/internal/worktree/bare.go index dfcf611..64e7c77 100644 --- a/internal/worktree/bare.go +++ b/internal/worktree/bare.go @@ -129,23 +129,11 @@ func (c *Client) gitConfig(repoDir string, key string, value string) error { } func (c *Client) getDefaultBranch(bareDir string) (string, error) { - args := []string{"-C", bareDir, "symbolic-ref", "--short", "HEAD"} - result, err := c.runner.Run(args...) - if err == nil { - branch := result.StdoutString(true) - if branch != "" { - return branch, nil - } - } - - for _, branch := range []string{"main", "master"} { - args := []string{"-C", bareDir, "rev-parse", "--verify", "refs/heads/" + branch} - if _, err := c.runner.Run(args...); err == nil { - return branch, nil - } + branch, err := c.resolveBaseBranchWithPolicy(bareDir, "", true) + if err != nil { + return "", err } - - return "", errors.New("could not determine default branch") + return localBranchName(branch), nil } func extractProjectName(repoURL string) (string, error) { diff --git a/internal/worktree/prune.go b/internal/worktree/prune.go index 64fb928..bd5db3f 100644 --- a/internal/worktree/prune.go +++ b/internal/worktree/prune.go @@ -91,9 +91,10 @@ func (c *Client) collectPruneCandidates(root, baseBranch string, report *Report) repoDir := repoDirForGit(root) isBare := repoDir != root + pp := c.NewProtectionPolicy() var candidates []pruneCandidate for _, wt := range worktrees { - if wt.IsBare || filepath.Base(wt.Path) == ".bare" || wt.Path == root { + if pp.IsProtected(wt) { continue } if isBare && isExternalPath(root, wt.Path) { diff --git a/internal/worktree/share_discover.go b/internal/worktree/share_discover.go index 5c37436..51375d0 100644 --- a/internal/worktree/share_discover.go +++ b/internal/worktree/share_discover.go @@ -124,9 +124,12 @@ func (c *Client) findMainWorktreePath() (string, error) { return "", fmt.Errorf("failed to list worktrees: %w", err) } - for _, wt := range worktrees { - if wt.Branch == "main" || wt.Branch == "master" { - return wt.Path, nil + mainBranch := c.resolvedMainBranch() + if mainBranch != "" { + for _, wt := range worktrees { + if wt.Branch == mainBranch { + return wt.Path, nil + } } } diff --git a/internal/worktree/worktree.go b/internal/worktree/worktree.go index 46f009f..32d378b 100644 --- a/internal/worktree/worktree.go +++ b/internal/worktree/worktree.go @@ -484,8 +484,9 @@ func (c *Client) prepareRemove(name string) (removeContext, error) { if !found { return removeContext{}, fmt.Errorf("worktree not found: %s\nUse 'gmc wt ls' to see available worktrees", name) } - if wtInfo.IsBare { - return removeContext{}, errors.New("cannot remove the main bare worktree") + pp := c.NewProtectionPolicy() + if pp.IsProtected(wtInfo) { + return removeContext{}, fmt.Errorf("cannot remove protected worktree '%s' (%s)", name, pp.Reason(wtInfo)) } // Reject agent/external worktrees (outside searchRoot). @@ -672,12 +673,10 @@ func (c *Client) Promote(worktreeName, newBranchName string) (Report, error) { targetPath := filepath.Join(searchRoot, worktreeName) - // Verify worktree exists if _, err := os.Stat(targetPath); os.IsNotExist(err) { return report, fmt.Errorf("worktree not found: %s", worktreeName) } - // Get current branch name result, err := c.runner.Run("-C", targetPath, "rev-parse", "--abbrev-ref", "HEAD") if err != nil { return report, fmt.Errorf("failed to get current branch: %w", err) @@ -688,6 +687,12 @@ func (c *Client) Promote(worktreeName, newBranchName string) (Report, error) { return report, errors.New("worktree is in detached HEAD state, cannot promote") } + pp := c.NewProtectionPolicy() + checkWt := Info{Path: targetPath, Branch: oldBranch} + if pp.IsProtected(checkWt) { + return report, fmt.Errorf("cannot promote protected worktree '%s' (%s)", worktreeName, pp.Reason(checkWt)) + } + // Rename branch args := []string{"-C", targetPath, "branch", "-m", newBranchName} result, err = c.runner.RunLogged(args...) @@ -756,6 +761,59 @@ func (c *Client) listGitRefs(errLabel string, gitArgs ...string) ([]string, erro return strings.Split(output, "\n"), nil } +type ProtectionPolicy struct { + MainBranch string + RootPath string +} + +func (c *Client) NewProtectionPolicy() ProtectionPolicy { + var p ProtectionPolicy + root, err := c.GetWorktreeRoot() + if err != nil { + return p + } + p.RootPath = root + repoDir := repoDirForGit(root) + isBareLayout := repoDir != root + branch, err := c.resolveBaseBranchWithPolicy(repoDir, "", isBareLayout) + if err != nil { + return p + } + p.MainBranch = localBranchName(branch) + return p +} + +func (p ProtectionPolicy) IsProtected(wt Info) bool { + if wt.IsBare { + return true + } + if p.RootPath != "" && wt.Path == p.RootPath { + return true + } + if p.MainBranch != "" && wt.Branch == p.MainBranch { + return true + } + return false +} + +func (p ProtectionPolicy) Reason(wt Info) string { + if wt.IsBare { + return "bare repository" + } + if p.RootPath != "" && wt.Path == p.RootPath { + return "main worktree" + } + return "main branch" +} + +func (c *Client) IsProtectedWorktree(wt Info) bool { + return c.NewProtectionPolicy().IsProtected(wt) +} + +func (c *Client) resolvedMainBranch() string { + return c.NewProtectionPolicy().MainBranch +} + // ListBranches returns all local branch names func (c *Client) ListBranches() ([]string, error) { return c.listGitRefs("list branches", "branch", "--format=%(refname:short)") diff --git a/internal/worktree/worktree_test.go b/internal/worktree/worktree_test.go index d0ff11a..1807cb9 100644 --- a/internal/worktree/worktree_test.go +++ b/internal/worktree/worktree_test.go @@ -593,6 +593,96 @@ func TestLocalBranchName(t *testing.T) { } } +func TestIsProtectedWorktree(t *testing.T) { + repoDir := initTestRepo(t) + + featureDir := filepath.Join(repoDir, "feature-wt") + runGit(t, repoDir, "worktree", "add", "-b", "feature", featureDir, "main") + + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(repoDir); err != nil { + t.Fatal(err) + } + + client := NewClient(Options{}) + worktrees, err := client.List() + if err != nil { + t.Fatalf("List() error = %v", err) + } + + for _, wt := range worktrees { + protected := client.IsProtectedWorktree(wt) + switch wt.Branch { + case "main": + if !protected { + t.Errorf("main worktree should be protected, path=%s", wt.Path) + } + case "feature": + if protected { + t.Errorf("feature worktree should NOT be protected, path=%s", wt.Path) + } + } + } +} + +func TestRemoveProtectedWorktree(t *testing.T) { + repoDir := initTestRepo(t) + + featureDir := filepath.Join(filepath.Dir(repoDir), filepath.Base(repoDir)+"--feature") + runGit(t, repoDir, "worktree", "add", "-b", "feature", featureDir, "main") + defer os.RemoveAll(featureDir) + + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(featureDir); err != nil { + t.Fatal(err) + } + + repoName := filepath.Base(repoDir) + client := NewClient(Options{}) + _, err = client.Remove(repoName, RemoveOptions{}) + if err == nil { + t.Fatal("expected error when removing protected worktree") + } + if !strings.Contains(err.Error(), "cannot remove protected worktree") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestPromoteProtectedWorktree(t *testing.T) { + repoDir := initTestRepo(t) + + featureDir := filepath.Join(filepath.Dir(repoDir), filepath.Base(repoDir)+"--feature") + runGit(t, repoDir, "worktree", "add", "-b", "feature", featureDir, "main") + defer os.RemoveAll(featureDir) + + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(featureDir); err != nil { + t.Fatal(err) + } + + repoName := filepath.Base(repoDir) + client := NewClient(Options{}) + _, err = client.Promote(repoName, "new-name") + if err == nil { + t.Fatal("expected error when promoting protected worktree") + } + if !strings.Contains(err.Error(), "cannot promote protected worktree") { + t.Errorf("unexpected error: %v", err) + } +} + func initTestRepo(t *testing.T) string { return initTestRepoWithBranch(t, "main") } From 93f31ac6527fe32f4d8a335d88cc75824eea727c Mon Sep 17 00:00:00 2001 From: samzong Date: Fri, 3 Apr 2026 13:39:58 +0800 Subject: [PATCH 2/2] fix(worktree): make NewProtectionPolicy fail-closed on errors Return error from NewProtectionPolicy, IsProtectedWorktree, and resolvedMainBranch instead of silently returning empty policy. Propagate errors in all callers (prepareRemove, Promote, Prune, resolveAllRemovableWorktrees) so destructive operations abort when protection state cannot be determined. Signed-off-by: samzong --- cmd/worktree.go | 5 ++++- internal/worktree/prune.go | 5 ++++- internal/worktree/share_discover.go | 4 ++-- internal/worktree/worktree.go | 34 ++++++++++++++++++++--------- internal/worktree/worktree_test.go | 5 ++++- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/cmd/worktree.go b/cmd/worktree.go index 2f8e398..f9b349c 100644 --- a/cmd/worktree.go +++ b/cmd/worktree.go @@ -385,7 +385,10 @@ func resolveAllRemovableWorktrees(wtClient *worktree.Client) ([]string, error) { return nil, err } - pp := wtClient.NewProtectionPolicy() + pp, err := wtClient.NewProtectionPolicy() + if err != nil { + return nil, err + } root := getDisplayRoot(wtClient) var names []string for _, wt := range all { diff --git a/internal/worktree/prune.go b/internal/worktree/prune.go index bd5db3f..6dbf346 100644 --- a/internal/worktree/prune.go +++ b/internal/worktree/prune.go @@ -91,7 +91,10 @@ func (c *Client) collectPruneCandidates(root, baseBranch string, report *Report) repoDir := repoDirForGit(root) isBare := repoDir != root - pp := c.NewProtectionPolicy() + pp, err := c.NewProtectionPolicy() + if err != nil { + return nil, "", err + } var candidates []pruneCandidate for _, wt := range worktrees { if pp.IsProtected(wt) { diff --git a/internal/worktree/share_discover.go b/internal/worktree/share_discover.go index 51375d0..b35ec18 100644 --- a/internal/worktree/share_discover.go +++ b/internal/worktree/share_discover.go @@ -124,8 +124,8 @@ func (c *Client) findMainWorktreePath() (string, error) { return "", fmt.Errorf("failed to list worktrees: %w", err) } - mainBranch := c.resolvedMainBranch() - if mainBranch != "" { + mainBranch, err := c.resolvedMainBranch() + if err == nil && mainBranch != "" { for _, wt := range worktrees { if wt.Branch == mainBranch { return wt.Path, nil diff --git a/internal/worktree/worktree.go b/internal/worktree/worktree.go index 32d378b..1179a69 100644 --- a/internal/worktree/worktree.go +++ b/internal/worktree/worktree.go @@ -484,7 +484,10 @@ func (c *Client) prepareRemove(name string) (removeContext, error) { if !found { return removeContext{}, fmt.Errorf("worktree not found: %s\nUse 'gmc wt ls' to see available worktrees", name) } - pp := c.NewProtectionPolicy() + pp, err := c.NewProtectionPolicy() + if err != nil { + return removeContext{}, err + } if pp.IsProtected(wtInfo) { return removeContext{}, fmt.Errorf("cannot remove protected worktree '%s' (%s)", name, pp.Reason(wtInfo)) } @@ -687,7 +690,10 @@ func (c *Client) Promote(worktreeName, newBranchName string) (Report, error) { return report, errors.New("worktree is in detached HEAD state, cannot promote") } - pp := c.NewProtectionPolicy() + pp, err := c.NewProtectionPolicy() + if err != nil { + return report, err + } checkWt := Info{Path: targetPath, Branch: oldBranch} if pp.IsProtected(checkWt) { return report, fmt.Errorf("cannot promote protected worktree '%s' (%s)", worktreeName, pp.Reason(checkWt)) @@ -766,21 +772,21 @@ type ProtectionPolicy struct { RootPath string } -func (c *Client) NewProtectionPolicy() ProtectionPolicy { +func (c *Client) NewProtectionPolicy() (ProtectionPolicy, error) { var p ProtectionPolicy root, err := c.GetWorktreeRoot() if err != nil { - return p + return p, fmt.Errorf("failed to get worktree root: %w", err) } p.RootPath = root repoDir := repoDirForGit(root) isBareLayout := repoDir != root branch, err := c.resolveBaseBranchWithPolicy(repoDir, "", isBareLayout) if err != nil { - return p + return p, fmt.Errorf("failed to resolve main branch: %w", err) } p.MainBranch = localBranchName(branch) - return p + return p, nil } func (p ProtectionPolicy) IsProtected(wt Info) bool { @@ -806,12 +812,20 @@ func (p ProtectionPolicy) Reason(wt Info) string { return "main branch" } -func (c *Client) IsProtectedWorktree(wt Info) bool { - return c.NewProtectionPolicy().IsProtected(wt) +func (c *Client) IsProtectedWorktree(wt Info) (bool, error) { + pp, err := c.NewProtectionPolicy() + if err != nil { + return false, err + } + return pp.IsProtected(wt), nil } -func (c *Client) resolvedMainBranch() string { - return c.NewProtectionPolicy().MainBranch +func (c *Client) resolvedMainBranch() (string, error) { + pp, err := c.NewProtectionPolicy() + if err != nil { + return "", err + } + return pp.MainBranch, nil } // ListBranches returns all local branch names diff --git a/internal/worktree/worktree_test.go b/internal/worktree/worktree_test.go index 1807cb9..0e7ea78 100644 --- a/internal/worktree/worktree_test.go +++ b/internal/worktree/worktree_test.go @@ -615,7 +615,10 @@ func TestIsProtectedWorktree(t *testing.T) { } for _, wt := range worktrees { - protected := client.IsProtectedWorktree(wt) + protected, err := client.IsProtectedWorktree(wt) + if err != nil { + t.Fatalf("IsProtectedWorktree() error = %v", err) + } switch wt.Branch { case "main": if !protected {