diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8d372cf10..72c83fd36 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -67,6 +67,12 @@ jobs: path: artifacts merge-multiple: true + - name: Generate checksums + run: | + cd artifacts + sha256sum *.tar.gz > SHA256SUMS + cat SHA256SUMS + - name: Get tag message id: tag_message run: | @@ -97,6 +103,8 @@ jobs: - name: Create Release uses: softprops/action-gh-release@v1 with: - files: artifacts/*.tar.gz + files: | + artifacts/*.tar.gz + artifacts/SHA256SUMS body: ${{ steps.tag_message.outputs.has_body == 'true' && steps.tag_message.outputs.body || '' }} generate_release_notes: ${{ steps.tag_message.outputs.has_body != 'true' }} diff --git a/cmd/roborev/main.go b/cmd/roborev/main.go index f9ea4f73c..de97863c4 100644 --- a/cmd/roborev/main.go +++ b/cmd/roborev/main.go @@ -20,6 +20,7 @@ import ( "github.com/wesm/roborev/internal/daemon" "github.com/wesm/roborev/internal/git" "github.com/wesm/roborev/internal/storage" + "github.com/wesm/roborev/internal/update" "github.com/wesm/roborev/internal/version" ) @@ -48,6 +49,7 @@ func main() { rootCmd.AddCommand(uninstallHookCmd()) rootCmd.AddCommand(daemonCmd()) rootCmd.AddCommand(tuiCmd()) + rootCmd.AddCommand(updateCmd()) rootCmd.AddCommand(versionCmd()) if err := rootCmd.Execute(); err != nil { @@ -744,9 +746,9 @@ func uninstallHookCmd() *cobra.Command { return fmt.Errorf("read hook: %w", err) } - // Check if it contains roborev + // Check if it contains roborev (case-insensitive) hookStr := string(content) - if !strings.Contains(hookStr, "roborev") { + if !strings.Contains(strings.ToLower(hookStr), "roborev") { fmt.Println("Post-commit hook does not contain roborev") return nil } @@ -755,8 +757,8 @@ func uninstallHookCmd() *cobra.Command { lines := strings.Split(hookStr, "\n") var newLines []string for _, line := range lines { - // Skip roborev-related lines - if strings.Contains(line, "roborev") || strings.Contains(line, "RoboRev") { + // Skip roborev-related lines (case-insensitive) + if strings.Contains(strings.ToLower(line), "roborev") { continue } newLines = append(newLines, line) @@ -792,6 +794,124 @@ func uninstallHookCmd() *cobra.Command { } } +func updateCmd() *cobra.Command { + var checkOnly bool + var yes bool + + cmd := &cobra.Command{ + Use: "update", + Short: "Update roborev to the latest version", + Long: `Check for and install roborev updates. + +Shows exactly what will be downloaded and where it will be installed. +Requires confirmation before making changes (use --yes to skip).`, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Checking for updates...") + + info, err := update.CheckForUpdate(true) // Force check, ignore cache + if err != nil { + return fmt.Errorf("check for updates: %w", err) + } + + if info == nil { + fmt.Printf("Already running latest version (%s)\n", version.Version) + return nil + } + + fmt.Printf("\n Current version: %s\n", info.CurrentVersion) + fmt.Printf(" Latest version: %s\n", info.LatestVersion) + fmt.Println("\nUpdate available!") + fmt.Println("\nDownload:") + fmt.Printf(" URL: %s\n", info.DownloadURL) + fmt.Printf(" Size: %s\n", update.FormatSize(info.Size)) + if info.Checksum != "" { + fmt.Printf(" SHA256: %s\n", info.Checksum) + } + + // Show install location + currentExe, err := os.Executable() + if err != nil { + return fmt.Errorf("find executable: %w", err) + } + currentExe, _ = filepath.EvalSymlinks(currentExe) + binDir := filepath.Dir(currentExe) + + fmt.Println("\nInstall location:") + fmt.Printf(" %s\n", binDir) + + if checkOnly { + return nil + } + + // Confirm + if !yes { + fmt.Print("\nProceed with update? [y/N] ") + var response string + fmt.Scanln(&response) + if strings.ToLower(response) != "y" && strings.ToLower(response) != "yes" { + fmt.Println("Update cancelled") + return nil + } + } + + fmt.Println() + + // Progress display + var lastPercent int + progressFn := func(downloaded, total int64) { + if total > 0 { + percent := int(downloaded * 100 / total) + if percent != lastPercent { + fmt.Printf("\rDownloading... %d%% (%s / %s)", + percent, update.FormatSize(downloaded), update.FormatSize(total)) + lastPercent = percent + } + } + } + + // Perform update + if err := update.PerformUpdate(info, progressFn); err != nil { + return fmt.Errorf("update failed: %w", err) + } + + fmt.Printf("\nUpdated to %s\n", info.LatestVersion) + + // Restart daemon if running + if daemonInfo, err := daemon.ReadRuntime(); err == nil && daemonInfo != nil { + fmt.Print("Restarting daemon... ") + // Stop old daemon with timeout + stopURL := fmt.Sprintf("http://%s/api/shutdown", daemonInfo.Addr) + client := &http.Client{Timeout: 5 * time.Second} + if resp, err := client.Post(stopURL, "application/json", nil); err != nil { + fmt.Printf("warning: failed to stop daemon: %v\n", err) + } else { + resp.Body.Close() + } + time.Sleep(500 * time.Millisecond) + + // Start new daemon + daemonPath := filepath.Join(binDir, "roborevd") + if runtime.GOOS == "windows" { + daemonPath += ".exe" + } + startCmd := exec.Command(daemonPath) + if err := startCmd.Start(); err != nil { + fmt.Printf("warning: failed to start daemon: %v\n", err) + } else { + fmt.Println("OK") + } + } + + return nil + }, + } + + cmd.Flags().BoolVar(&checkOnly, "check", false, "only check for updates, don't install") + cmd.Flags().BoolVarP(&yes, "yes", "y", false, "skip confirmation prompt") + + return cmd +} + func versionCmd() *cobra.Command { return &cobra.Command{ Use: "version", diff --git a/cmd/roborev/main_test.go b/cmd/roborev/main_test.go index a3849aac0..e3a43db42 100644 --- a/cmd/roborev/main_test.go +++ b/cmd/roborev/main_test.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "testing" "github.com/wesm/roborev/internal/daemon" @@ -162,3 +163,146 @@ func TestEnqueueCmdPositionalArg(t *testing.T) { } }) } + +func TestUninstallHookCmd(t *testing.T) { + // Helper to create a git repo with an optional hook + setupRepo := func(t *testing.T, hookContent string) (repoPath string, hookPath string) { + tmpDir := t.TempDir() + + // Initialize git repo + cmd := exec.Command("git", "init") + cmd.Dir = tmpDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init failed: %v\n%s", err, out) + } + + hookPath = filepath.Join(tmpDir, ".git", "hooks", "post-commit") + + if hookContent != "" { + if err := os.MkdirAll(filepath.Dir(hookPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { + t.Fatal(err) + } + } + + return tmpDir, hookPath + } + + t.Run("hook missing", func(t *testing.T) { + repoPath, hookPath := setupRepo(t, "") + + // Change to repo dir for the command + origDir, _ := os.Getwd() + os.Chdir(repoPath) + defer os.Chdir(origDir) + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + // Hook should still not exist + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + t.Error("Hook file should not exist") + } + }) + + t.Run("hook without roborev", func(t *testing.T) { + hookContent := "#!/bin/bash\necho 'other hook'\n" + repoPath, hookPath := setupRepo(t, hookContent) + + origDir, _ := os.Getwd() + os.Chdir(repoPath) + defer os.Chdir(origDir) + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + // Hook should be unchanged + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("Failed to read hook: %v", err) + } + if string(content) != hookContent { + t.Errorf("Hook content changed: got %q, want %q", string(content), hookContent) + } + }) + + t.Run("hook with roborev only - removes file", func(t *testing.T) { + hookContent := "#!/bin/bash\n# RoboRev auto-commit hook\nroborev enqueue\n" + repoPath, hookPath := setupRepo(t, hookContent) + + origDir, _ := os.Getwd() + os.Chdir(repoPath) + defer os.Chdir(origDir) + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + // Hook should be removed entirely + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + t.Error("Hook file should have been removed") + } + }) + + t.Run("hook with roborev and other commands - preserves others", func(t *testing.T) { + hookContent := "#!/bin/bash\necho 'before'\nroborev enqueue\necho 'after'\n" + repoPath, hookPath := setupRepo(t, hookContent) + + origDir, _ := os.Getwd() + os.Chdir(repoPath) + defer os.Chdir(origDir) + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + // Hook should exist with roborev line removed + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("Failed to read hook: %v", err) + } + + contentStr := string(content) + if strings.Contains(strings.ToLower(contentStr), "roborev") { + t.Error("Hook should not contain roborev") + } + if !strings.Contains(contentStr, "echo 'before'") { + t.Error("Hook should still contain 'echo before'") + } + if !strings.Contains(contentStr, "echo 'after'") { + t.Error("Hook should still contain 'echo after'") + } + }) + + t.Run("hook with capitalized RoboRev", func(t *testing.T) { + hookContent := "#!/bin/bash\n# RoboRev hook\nRoboRev enqueue\n" + repoPath, hookPath := setupRepo(t, hookContent) + + origDir, _ := os.Getwd() + os.Chdir(repoPath) + defer os.Chdir(origDir) + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + // Hook should be removed (only had RoboRev content) + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + t.Error("Hook file should have been removed") + } + }) +} diff --git a/cmd/roborev/tui.go b/cmd/roborev/tui.go index 7aeabae31..fd6d80a5a 100644 --- a/cmd/roborev/tui.go +++ b/cmd/roborev/tui.go @@ -13,6 +13,7 @@ import ( "github.com/spf13/cobra" "github.com/wesm/roborev/internal/daemon" "github.com/wesm/roborev/internal/storage" + "github.com/wesm/roborev/internal/update" "github.com/wesm/roborev/internal/version" ) @@ -65,6 +66,7 @@ type tuiModel struct { width int height int err error + updateAvailable string // Latest version if update available, empty if up to date } type tuiTickMsg time.Time @@ -87,6 +89,7 @@ type tuiCancelResultMsg struct { err error } type tuiErrMsg error +type tuiUpdateCheckMsg string // Latest version if available, empty if up to date func newTuiModel(serverAddr string) tuiModel { // Get daemon version from runtime info @@ -112,6 +115,7 @@ func (m tuiModel) Init() tea.Cmd { m.tick(), m.fetchJobs(), m.fetchStatus(), + m.checkForUpdate(), ) } @@ -163,6 +167,16 @@ func (m tuiModel) fetchStatus() tea.Cmd { } } +func (m tuiModel) checkForUpdate() tea.Cmd { + return func() tea.Msg { + info, err := update.CheckForUpdate(false) // Use cache + if err != nil || info == nil { + return tuiUpdateCheckMsg("") // No update or error + } + return tuiUpdateCheckMsg(info.LatestVersion) + } +} + func (m tuiModel) fetchReview(jobID int64) tea.Cmd { return func() tea.Msg { resp, err := m.client.Get(fmt.Sprintf("%s/api/review?job_id=%d", m.serverAddr, jobID)) @@ -622,6 +636,9 @@ func (m tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tuiStatusMsg: m.status = storage.DaemonStatus(msg) + case tuiUpdateCheckMsg: + m.updateAvailable = string(msg) + case tuiReviewMsg: m.currentReview = msg m.currentView = tuiViewReview @@ -692,7 +709,15 @@ func (m tuiModel) renderQueueView() string { m.status.CompletedJobs, m.status.FailedJobs, m.status.CanceledJobs) b.WriteString(tuiStatusStyle.Render(statusLine)) - b.WriteString("\n\n") + b.WriteString("\n") + + // Update notification + if m.updateAvailable != "" { + updateStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("226")).Bold(true) + b.WriteString(updateStyle.Render(fmt.Sprintf("Update available: %s - run 'roborev update'", m.updateAvailable))) + b.WriteString("\n") + } + b.WriteString("\n") if len(m.jobs) == 0 { b.WriteString("No jobs in queue\n") diff --git a/internal/daemon/worker.go b/internal/daemon/worker.go index 4648a28b6..29d4b6891 100644 --- a/internal/daemon/worker.go +++ b/internal/daemon/worker.go @@ -29,6 +29,9 @@ type WorkerPool struct { runningJobs map[int64]context.CancelFunc pendingCancels map[int64]bool // Jobs canceled before registered runningJobsMu sync.Mutex + + // Test hooks for deterministic synchronization (nil in production) + testHookAfterSecondCheck func() // Called after second runningJobs check, before second DB lookup } // NewWorkerPool creates a new worker pool @@ -116,6 +119,11 @@ func (wp *WorkerPool) CancelJob(jobID int64) bool { } wp.runningJobsMu.Unlock() + // Test hook: allows tests to register job between second check and final check + if wp.testHookAfterSecondCheck != nil { + wp.testHookAfterSecondCheck() + } + // Re-verify job is still cancellable before adding to pendingCancels // The job may have registered and finished during our DB lookup window // Do this outside the lock to avoid blocking other operations diff --git a/internal/daemon/worker_test.go b/internal/daemon/worker_test.go index 8c2a6ceb9..c683339eb 100644 --- a/internal/daemon/worker_test.go +++ b/internal/daemon/worker_test.go @@ -1,7 +1,6 @@ package daemon import ( - "fmt" "path/filepath" "sync/atomic" "testing" @@ -478,7 +477,8 @@ func TestWorkerPoolCancelJobRegisteredDuringCheck(t *testing.T) { func TestWorkerPoolCancelJobConcurrentRegister(t *testing.T) { // Test concurrent registration during CancelJob - // This exercises the race condition where a job registers during DB lookup + // Uses a test hook to deterministically register the job during CancelJob's + // DB lookup window, exercising the "registration during cancel" code path tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") @@ -496,53 +496,48 @@ func TestWorkerPoolCancelJobConcurrentRegister(t *testing.T) { t.Fatalf("GetOrCreateRepo failed: %v", err) } - // Run multiple iterations to increase chance of hitting race - for i := 0; i < 10; i++ { - sha := fmt.Sprintf("concurrent-race-%d", i) - commit, err := db.GetOrCreateCommit(repo.ID, sha, "Author", "Subject", time.Now()) - if err != nil { - t.Fatalf("GetOrCreateCommit failed: %v", err) - } - job, err := db.EnqueueJob(repo.ID, commit.ID, sha, "test") - if err != nil { - t.Fatalf("EnqueueJob failed: %v", err) - } - _, err = db.ClaimJob("test-worker") - if err != nil { - t.Fatalf("ClaimJob failed: %v", err) - } - - var canceled int32 - cancelFunc := func() { atomic.AddInt32(&canceled, 1) } + sha := "concurrent-register" + commit, err := db.GetOrCreateCommit(repo.ID, sha, "Author", "Subject", time.Now()) + if err != nil { + t.Fatalf("GetOrCreateCommit failed: %v", err) + } + job, err := db.EnqueueJob(repo.ID, commit.ID, sha, "test") + if err != nil { + t.Fatalf("EnqueueJob failed: %v", err) + } + _, err = db.ClaimJob("test-worker") + if err != nil { + t.Fatalf("ClaimJob failed: %v", err) + } - // Start CancelJob in goroutine - cancelDone := make(chan bool) - go func() { - cancelDone <- pool.CancelJob(job.ID) - }() + var canceled int32 + cancelFunc := func() { atomic.AddInt32(&canceled, 1) } - // Concurrently register the job (simulates worker starting) + // Set up hook to register the job at a deterministic point during CancelJob + // This happens after the second runningJobs check, ensuring we exercise + // the final check code path where registration occurs during DB lookup + pool.testHookAfterSecondCheck = func() { pool.registerRunningJob(job.ID, cancelFunc) + } - // Wait for CancelJob to complete - result := <-cancelDone - - // Job should have been canceled (either via runningJobs or pendingCancels) - if !result { - t.Errorf("Iteration %d: CancelJob should return true", i) - } - if atomic.LoadInt32(&canceled) == 0 { - t.Errorf("Iteration %d: Job should have been canceled", i) - } + // CancelJob should find the job via the final check and cancel it + result := pool.CancelJob(job.ID) - // Clean up for next iteration - pool.unregisterRunningJob(job.ID) + if !result { + t.Error("CancelJob should return true") } + if atomic.LoadInt32(&canceled) != 1 { + t.Error("Job should have been canceled exactly once") + } + + // Clean up + pool.unregisterRunningJob(job.ID) } func TestWorkerPoolCancelJobFinalCheckDeadlockSafe(t *testing.T) { // Test that cancel() is called without holding the lock (no deadlock) - // This verifies the fix for the "final check" path + // This verifies the fix for the "final check" path by using a test hook + // to deterministically register the job between the second check and final check tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") @@ -581,24 +576,30 @@ func TestWorkerPoolCancelJobFinalCheckDeadlockSafe(t *testing.T) { pool.unregisterRunningJob(job.ID) } - // Register the job - pool.registerRunningJob(job.ID, cancelFunc) + // Set up hook to register the job between second check and final check + // This ensures we exercise the "final check" code path + pool.testHookAfterSecondCheck = func() { + pool.registerRunningJob(job.ID, cancelFunc) + } // CancelJob should complete without deadlock + // The job is NOT registered initially, so it passes first and second checks, + // then the hook fires and registers it, then the final check finds it done := make(chan bool) go func() { - pool.CancelJob(job.ID) - done <- true + done <- pool.CancelJob(job.ID) }() select { - case <-done: - // Success - no deadlock + case result := <-done: + if !result { + t.Error("CancelJob should return true") + } case <-time.After(2 * time.Second): t.Fatal("CancelJob deadlocked - cancel() called while holding lock") } if !canceled { - t.Error("Job should have been canceled") + t.Error("Job should have been canceled via final check path") } } diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 000000000..2b3066610 --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,551 @@ +package update + +import ( + "archive/tar" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strings" + "time" + + "github.com/wesm/roborev/internal/version" +) + +const ( + githubAPIURL = "https://api.github.com/repos/wesm/roborev/releases/latest" + cacheFileName = "update_check.json" + cacheDuration = 1 * time.Hour +) + +// Release represents a GitHub release +type Release struct { + TagName string `json:"tag_name"` + Body string `json:"body"` + Assets []Asset `json:"assets"` +} + +// Asset represents a release asset +type Asset struct { + Name string `json:"name"` + Size int64 `json:"size"` + BrowserDownloadURL string `json:"browser_download_url"` +} + +// UpdateInfo contains information about an available update +type UpdateInfo struct { + CurrentVersion string + LatestVersion string + DownloadURL string + AssetName string + Size int64 + Checksum string // SHA256 if available +} + +// findAssets locates the platform-specific binary and checksums file from release assets +func findAssets(assets []Asset, assetName string) (asset *Asset, checksumsAsset *Asset) { + for i := range assets { + a := &assets[i] + if a.Name == assetName { + asset = a + } + if a.Name == "SHA256SUMS" || a.Name == "checksums.txt" { + checksumsAsset = a + } + } + return asset, checksumsAsset +} + +// cachedCheck stores the last update check result +type cachedCheck struct { + CheckedAt time.Time `json:"checked_at"` + Version string `json:"version"` +} + +// CheckForUpdate checks if a newer version is available +// Uses a 1-hour cache to avoid hitting GitHub API too often +func CheckForUpdate(forceCheck bool) (*UpdateInfo, error) { + currentVersion := strings.TrimPrefix(version.Version, "v") + + // Check cache first (unless forced) + if !forceCheck { + if cached, err := loadCache(); err == nil { + if time.Since(cached.CheckedAt) < cacheDuration { + latestVersion := strings.TrimPrefix(cached.Version, "v") + if !isNewer(latestVersion, currentVersion) { + return nil, nil // Up to date (cached) + } + // Cache says update available, fetch fresh info + } + } + } + + // Fetch latest release from GitHub + release, err := fetchLatestRelease() + if err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + + // Save to cache + saveCache(release.TagName) + + latestVersion := strings.TrimPrefix(release.TagName, "v") + if !isNewer(latestVersion, currentVersion) { + return nil, nil // Up to date + } + + // Find the right asset for this platform + // Asset naming: roborev___.tar.gz (e.g., roborev_0.3.0_darwin_arm64.tar.gz) + assetName := fmt.Sprintf("roborev_%s_%s_%s.tar.gz", latestVersion, runtime.GOOS, runtime.GOARCH) + asset, checksumsAsset := findAssets(release.Assets, assetName) + if asset == nil { + return nil, fmt.Errorf("no release asset found for %s/%s", runtime.GOOS, runtime.GOARCH) + } + + // Get checksum - first try checksums file, then release body + var checksum string + if checksumsAsset != nil { + checksum, _ = fetchChecksumFromFile(checksumsAsset.BrowserDownloadURL, assetName) + } + if checksum == "" { + // Fall back to release body + checksum = extractChecksum(release.Body, assetName) + } + + return &UpdateInfo{ + CurrentVersion: version.Version, + LatestVersion: release.TagName, + DownloadURL: asset.BrowserDownloadURL, + AssetName: asset.Name, + Size: asset.Size, + Checksum: checksum, + }, nil +} + +// PerformUpdate downloads and installs the update +func PerformUpdate(info *UpdateInfo, progressFn func(downloaded, total int64)) error { + // Security: require checksum verification + if info.Checksum == "" { + return fmt.Errorf("no checksum available for %s - refusing to install unverified binary", info.AssetName) + } + + // 1. Download to temp file + fmt.Printf("Downloading %s...\n", info.AssetName) + tempDir, err := os.MkdirTemp("", "roborev-update-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + archivePath := filepath.Join(tempDir, info.AssetName) + checksum, err := downloadFile(info.DownloadURL, archivePath, info.Size, progressFn) + if err != nil { + return fmt.Errorf("download: %w", err) + } + + // 2. Verify checksum (required) + fmt.Printf("Verifying checksum... ") + if !strings.EqualFold(checksum, info.Checksum) { + fmt.Println("FAILED") + return fmt.Errorf("checksum mismatch: expected %s, got %s", info.Checksum, checksum) + } + fmt.Println("OK") + + // 3. Extract archive + fmt.Println("Extracting...") + extractDir := filepath.Join(tempDir, "extracted") + if err := extractTarGz(archivePath, extractDir); err != nil { + return fmt.Errorf("extract: %w", err) + } + + // 4. Find current binary locations + currentExe, err := os.Executable() + if err != nil { + return fmt.Errorf("find current executable: %w", err) + } + currentExe, err = filepath.EvalSymlinks(currentExe) + if err != nil { + return fmt.Errorf("resolve symlinks: %w", err) + } + binDir := filepath.Dir(currentExe) + + // 5. Install new binaries + binaries := []string{"roborev", "roborevd"} + if runtime.GOOS == "windows" { + binaries = []string{"roborev.exe", "roborevd.exe"} + } + + for _, binary := range binaries { + srcPath := filepath.Join(extractDir, binary) + dstPath := filepath.Join(binDir, binary) + backupPath := dstPath + ".old" + + // Check if source exists + if _, err := os.Stat(srcPath); os.IsNotExist(err) { + continue // Skip if not in archive + } + + fmt.Printf("Installing %s... ", binary) + + // Clean up any old backup from previous update + os.Remove(backupPath) + + // Backup existing + if _, err := os.Stat(dstPath); err == nil { + if err := os.Rename(dstPath, backupPath); err != nil { + // On Windows, renaming a running executable may fail + if runtime.GOOS == "windows" { + return fmt.Errorf("cannot update %s while it is running - please stop the daemon and try again: %w", binary, err) + } + return fmt.Errorf("backup %s: %w", binary, err) + } + } + + // Copy new binary + if err := copyFile(srcPath, dstPath); err != nil { + // Try to restore backup + os.Rename(backupPath, dstPath) + return fmt.Errorf("install %s: %w", binary, err) + } + + // Set executable permission (no-op on Windows) + if runtime.GOOS != "windows" { + if err := os.Chmod(dstPath, 0755); err != nil { + return fmt.Errorf("chmod %s: %w", binary, err) + } + } + + // Try to remove backup (may fail on Windows if daemon was running) + // The .old file will be cleaned up on next update + os.Remove(backupPath) + + fmt.Println("OK") + } + + return nil +} + +// RestartDaemon stops and starts the daemon +func RestartDaemon() error { + // Find roborevd and restart it + // We do this by calling the daemon restart command + // Since we're in a library, we'll just return instructions + // The CLI will handle the actual restart + return nil +} + +// GetCacheDir returns the roborev cache directory +func GetCacheDir() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".roborev") +} + +func fetchLatestRelease() (*Release, error) { + req, err := http.NewRequest("GET", githubAPIURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "roborev/"+version.Version) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned %s", resp.Status) + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, err + } + + return &release, nil +} + +func downloadFile(url, dest string, totalSize int64, progressFn func(downloaded, total int64)) (string, error) { + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed: %s", resp.Status) + } + + out, err := os.Create(dest) + if err != nil { + return "", err + } + defer out.Close() + + // Calculate checksum while downloading + hasher := sha256.New() + writer := io.MultiWriter(out, hasher) + + // Download with progress + var downloaded int64 + buf := make([]byte, 32*1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + _, writeErr := writer.Write(buf[:n]) + if writeErr != nil { + return "", writeErr + } + downloaded += int64(n) + if progressFn != nil { + progressFn(downloaded, totalSize) + } + } + if err == io.EOF { + break + } + if err != nil { + return "", err + } + } + + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func extractTarGz(archivePath, destDir string) error { + if err := os.MkdirAll(destDir, 0755); err != nil { + return err + } + + // Get absolute path of destDir for security checks + absDestDir, err := filepath.Abs(destDir) + if err != nil { + return fmt.Errorf("resolve dest dir: %w", err) + } + + file, err := os.Open(archivePath) + if err != nil { + return err + } + defer file.Close() + + gzr, err := gzip.NewReader(file) + if err != nil { + return err + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + // Security: sanitize and validate the path + target, err := sanitizeTarPath(absDestDir, header.Name) + if err != nil { + return fmt.Errorf("invalid tar entry %q: %w", header.Name, err) + } + + // Security: skip symlinks and hardlinks to prevent attacks + if header.Typeflag == tar.TypeSymlink || header.Typeflag == tar.TypeLink { + continue + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, 0755); err != nil { + return err + } + case tar.TypeReg: + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return err + } + outFile, err := os.Create(target) + if err != nil { + return err + } + if _, err := io.Copy(outFile, tr); err != nil { + outFile.Close() + return err + } + outFile.Close() + if err := os.Chmod(target, os.FileMode(header.Mode)); err != nil { + return err + } + } + } + + return nil +} + +// sanitizeTarPath validates and sanitizes a tar entry path to prevent directory traversal +func sanitizeTarPath(destDir, name string) (string, error) { + // Clean the path to remove . and .. components + cleanName := filepath.Clean(name) + + // Reject absolute paths + if filepath.IsAbs(cleanName) { + return "", fmt.Errorf("absolute path not allowed") + } + + // Reject paths that try to escape with .. + if strings.HasPrefix(cleanName, "..") || strings.Contains(cleanName, string(filepath.Separator)+"..") { + return "", fmt.Errorf("path traversal not allowed") + } + + // Build the target path + target := filepath.Join(destDir, cleanName) + + // Final check: ensure the target is within destDir + // This catches any edge cases the above checks might miss + absTarget, err := filepath.Abs(target) + if err != nil { + return "", err + } + if !strings.HasPrefix(absTarget, destDir+string(filepath.Separator)) && absTarget != destDir { + return "", fmt.Errorf("path escapes destination directory") + } + + return target, nil +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + if _, err := io.Copy(out, in); err != nil { + return err + } + + return out.Close() +} + +// fetchChecksumFromFile downloads a checksums file and extracts the checksum for assetName +func fetchChecksumFromFile(url, assetName string) (string, error) { + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to fetch checksums: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return extractChecksum(string(body), assetName), nil +} + +func extractChecksum(releaseBody, assetName string) string { + // Look for checksum in release notes or checksums file + // Format: "checksum assetname" (standard sha256sum output) or "assetname: checksum" + lines := strings.Split(releaseBody, "\n") + // Case-insensitive regex for SHA256 hex (64 chars) + re := regexp.MustCompile(`(?i)[a-f0-9]{64}`) + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.Contains(line, assetName) { + if match := re.FindString(line); match != "" { + return strings.ToLower(match) // Normalize to lowercase + } + } + } + return "" +} + +func loadCache() (*cachedCheck, error) { + cachePath := filepath.Join(GetCacheDir(), cacheFileName) + data, err := os.ReadFile(cachePath) + if err != nil { + return nil, err + } + var cached cachedCheck + if err := json.Unmarshal(data, &cached); err != nil { + return nil, err + } + return &cached, nil +} + +func saveCache(version string) { + cached := cachedCheck{ + CheckedAt: time.Now(), + Version: version, + } + data, err := json.Marshal(cached) + if err != nil { + return + } + cachePath := filepath.Join(GetCacheDir(), cacheFileName) + os.MkdirAll(filepath.Dir(cachePath), 0755) + os.WriteFile(cachePath, data, 0644) +} + +// isNewer returns true if v1 is newer than v2 +// Assumes semver format: major.minor.patch +func isNewer(v1, v2 string) bool { + v1 = strings.TrimPrefix(v1, "v") + v2 = strings.TrimPrefix(v2, "v") + + parts1 := strings.Split(v1, ".") + parts2 := strings.Split(v2, ".") + + for i := 0; i < 3; i++ { + var n1, n2 int + if i < len(parts1) { + fmt.Sscanf(parts1[i], "%d", &n1) + } + if i < len(parts2) { + fmt.Sscanf(parts2[i], "%d", &n2) + } + if n1 > n2 { + return true + } + if n1 < n2 { + return false + } + } + return false +} + +// FormatSize formats bytes as human-readable string +func FormatSize(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} diff --git a/internal/update/update_test.go b/internal/update/update_test.go new file mode 100644 index 000000000..38e7af8fc --- /dev/null +++ b/internal/update/update_test.go @@ -0,0 +1,330 @@ +package update + +import ( + "archive/tar" + "compress/gzip" + "os" + "path/filepath" + "testing" +) + +func TestSanitizeTarPath(t *testing.T) { + destDir := "/tmp/extract" + + tests := []struct { + name string + path string + wantErr bool + }{ + {"normal file", "roborev", false}, + {"nested file", "bin/roborev", false}, + {"absolute path", "/etc/passwd", true}, + {"path traversal with ..", "../../../etc/passwd", true}, + {"path traversal mid-path", "foo/../../../etc/passwd", true}, + {"hidden traversal", "foo/bar/../../..", true}, + {"dot only", ".", false}, + {"double dot only", "..", true}, + {"empty path", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := sanitizeTarPath(destDir, tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("sanitizeTarPath(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr) + } + }) + } +} + +func TestExtractTarGzPathTraversal(t *testing.T) { + // Create a malicious tar.gz with path traversal + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "malicious.tar.gz") + extractDir := filepath.Join(tmpDir, "extract") + outsideFile := filepath.Join(tmpDir, "pwned") + + // Create archive with path traversal attempt + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gzw := gzip.NewWriter(f) + tw := tar.NewWriter(gzw) + + // Add a malicious entry that tries to escape + header := &tar.Header{ + Name: "../pwned", + Mode: 0644, + Size: 5, + } + if err := tw.WriteHeader(header); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte("owned")); err != nil { + t.Fatal(err) + } + + tw.Close() + gzw.Close() + f.Close() + + // Extract should fail + err = extractTarGz(archivePath, extractDir) + if err == nil { + t.Error("extractTarGz should fail with path traversal attempt") + } + + // Verify the file wasn't created outside + if _, err := os.Stat(outsideFile); !os.IsNotExist(err) { + t.Error("Malicious file was created outside extract dir") + } +} + +func TestExtractTarGzSymlinkSkipped(t *testing.T) { + // Create a tar.gz with a symlink + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "symlink.tar.gz") + extractDir := filepath.Join(tmpDir, "extract") + + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gzw := gzip.NewWriter(f) + tw := tar.NewWriter(gzw) + + // Add a symlink entry + header := &tar.Header{ + Name: "evil-link", + Typeflag: tar.TypeSymlink, + Linkname: "/etc/passwd", + } + if err := tw.WriteHeader(header); err != nil { + t.Fatal(err) + } + + // Add a normal file + header = &tar.Header{ + Name: "normal.txt", + Mode: 0644, + Size: 4, + } + if err := tw.WriteHeader(header); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte("test")); err != nil { + t.Fatal(err) + } + + tw.Close() + gzw.Close() + f.Close() + + // Extract should succeed (symlinks are skipped) + if err := extractTarGz(archivePath, extractDir); err != nil { + t.Fatalf("extractTarGz failed: %v", err) + } + + // Normal file should exist + if _, err := os.Stat(filepath.Join(extractDir, "normal.txt")); err != nil { + t.Error("Normal file should have been extracted") + } + + // Symlink should not exist + if _, err := os.Lstat(filepath.Join(extractDir, "evil-link")); !os.IsNotExist(err) { + t.Error("Symlink should have been skipped") + } +} + +func TestExtractChecksum(t *testing.T) { + tests := []struct { + name string + body string + assetName string + want string + }{ + { + name: "standard sha256sum format", + body: "abc123def456789012345678901234567890123456789012345678901234abcd roborev_darwin_arm64.tar.gz", + assetName: "roborev_darwin_arm64.tar.gz", + want: "abc123def456789012345678901234567890123456789012345678901234abcd", + }, + { + name: "uppercase checksum", + body: "ABC123DEF456789012345678901234567890123456789012345678901234ABCD roborev_linux_amd64.tar.gz", + assetName: "roborev_linux_amd64.tar.gz", + want: "abc123def456789012345678901234567890123456789012345678901234abcd", + }, + { + name: "mixed case checksum", + body: "AbC123DeF456789012345678901234567890123456789012345678901234aBcD roborev_darwin_amd64.tar.gz", + assetName: "roborev_darwin_amd64.tar.gz", + want: "abc123def456789012345678901234567890123456789012345678901234abcd", + }, + { + name: "colon format", + body: "roborev_darwin_arm64.tar.gz: abc123def456789012345678901234567890123456789012345678901234abcd", + assetName: "roborev_darwin_arm64.tar.gz", + want: "abc123def456789012345678901234567890123456789012345678901234abcd", + }, + { + name: "multiline with target in middle", + body: "abc123def456789012345678901234567890123456789012345678901234aaaa roborev_linux_amd64.tar.gz\nabc123def456789012345678901234567890123456789012345678901234bbbb roborev_darwin_arm64.tar.gz\nabc123def456789012345678901234567890123456789012345678901234cccc roborev_darwin_amd64.tar.gz", + assetName: "roborev_darwin_arm64.tar.gz", + want: "abc123def456789012345678901234567890123456789012345678901234bbbb", + }, + { + name: "no match", + body: "abc123def456789012345678901234567890123456789012345678901234abcd roborev_linux_amd64.tar.gz", + assetName: "roborev_darwin_arm64.tar.gz", + want: "", + }, + { + name: "empty body", + body: "", + assetName: "roborev_darwin_arm64.tar.gz", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractChecksum(tt.body, tt.assetName) + if got != tt.want { + t.Errorf("extractChecksum() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestIsNewer(t *testing.T) { + tests := []struct { + v1, v2 string + want bool + }{ + {"1.0.0", "0.9.0", true}, + {"1.1.0", "1.0.0", true}, + {"1.0.1", "1.0.0", true}, + {"2.0.0", "1.9.9", true}, + {"1.0.0", "1.0.0", false}, + {"0.9.0", "1.0.0", false}, + {"v1.0.0", "v0.9.0", true}, + {"v1.0.0", "0.9.0", true}, + {"1.0.0", "v0.9.0", true}, + } + + for _, tt := range tests { + t.Run(tt.v1+"_vs_"+tt.v2, func(t *testing.T) { + got := isNewer(tt.v1, tt.v2) + if got != tt.want { + t.Errorf("isNewer(%q, %q) = %v, want %v", tt.v1, tt.v2, got, tt.want) + } + }) + } +} + +func TestFindAssets(t *testing.T) { + // Test that findAssets correctly selects assets from a list with multiple items + // This specifically tests the loop variable pointer bug fix + assets := []Asset{ + {Name: "roborev_linux_amd64.tar.gz", Size: 1000, BrowserDownloadURL: "https://example.com/linux_amd64"}, + {Name: "roborev_darwin_arm64.tar.gz", Size: 2000, BrowserDownloadURL: "https://example.com/darwin_arm64"}, + {Name: "SHA256SUMS", Size: 500, BrowserDownloadURL: "https://example.com/checksums"}, + {Name: "roborev_darwin_amd64.tar.gz", Size: 3000, BrowserDownloadURL: "https://example.com/darwin_amd64"}, + {Name: "roborev_windows_amd64.zip", Size: 4000, BrowserDownloadURL: "https://example.com/windows"}, + } + + tests := []struct { + name string + assetName string + wantAssetURL string + wantAssetSize int64 + wantChecksumsURL string + wantAssetNil bool + wantChecksumsNil bool + }{ + { + name: "find darwin_arm64 (second in list)", + assetName: "roborev_darwin_arm64.tar.gz", + wantAssetURL: "https://example.com/darwin_arm64", + wantAssetSize: 2000, + wantChecksumsURL: "https://example.com/checksums", + }, + { + name: "find linux_amd64 (first in list)", + assetName: "roborev_linux_amd64.tar.gz", + wantAssetURL: "https://example.com/linux_amd64", + wantAssetSize: 1000, + wantChecksumsURL: "https://example.com/checksums", + }, + { + name: "find darwin_amd64 (after checksums)", + assetName: "roborev_darwin_amd64.tar.gz", + wantAssetURL: "https://example.com/darwin_amd64", + wantAssetSize: 3000, + wantChecksumsURL: "https://example.com/checksums", + }, + { + name: "asset not found", + assetName: "roborev_freebsd_amd64.tar.gz", + wantAssetNil: true, + wantChecksumsURL: "https://example.com/checksums", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + asset, checksums := findAssets(assets, tt.assetName) + + if tt.wantAssetNil { + if asset != nil { + t.Errorf("expected asset to be nil, got %+v", asset) + } + } else { + if asset == nil { + t.Fatal("expected asset to be non-nil") + } + if asset.BrowserDownloadURL != tt.wantAssetURL { + t.Errorf("asset URL = %q, want %q", asset.BrowserDownloadURL, tt.wantAssetURL) + } + if asset.Size != tt.wantAssetSize { + t.Errorf("asset size = %d, want %d", asset.Size, tt.wantAssetSize) + } + } + + if tt.wantChecksumsNil { + if checksums != nil { + t.Errorf("expected checksums to be nil, got %+v", checksums) + } + } else { + if checksums == nil { + t.Fatal("expected checksums to be non-nil") + } + if checksums.BrowserDownloadURL != tt.wantChecksumsURL { + t.Errorf("checksums URL = %q, want %q", checksums.BrowserDownloadURL, tt.wantChecksumsURL) + } + } + }) + } +} + +func TestFindAssetsNoChecksums(t *testing.T) { + // Test case where there's no checksums file + assets := []Asset{ + {Name: "roborev_linux_amd64.tar.gz", Size: 1000, BrowserDownloadURL: "https://example.com/linux"}, + {Name: "roborev_darwin_arm64.tar.gz", Size: 2000, BrowserDownloadURL: "https://example.com/darwin"}, + } + + asset, checksums := findAssets(assets, "roborev_darwin_arm64.tar.gz") + + if asset == nil { + t.Fatal("expected asset to be non-nil") + } + if asset.BrowserDownloadURL != "https://example.com/darwin" { + t.Errorf("asset URL = %q, want %q", asset.BrowserDownloadURL, "https://example.com/darwin") + } + if checksums != nil { + t.Errorf("expected checksums to be nil, got %+v", checksums) + } +}