diff --git a/pkg/workflow/awf_helpers.go b/pkg/workflow/awf_helpers.go index d504d666a5e..a825aa35208 100644 --- a/pkg/workflow/awf_helpers.go +++ b/pkg/workflow/awf_helpers.go @@ -173,6 +173,7 @@ func injectMaxAICreditsExpression(awfConfigJSON string, expr string) string { } func buildWorkflowCallNetworkAllowedUpdateScript() (string, error) { + ecosystemDomains := getLoadedEcosystemDomains() ecosystemMap := make(map[string][]string, safeAllocationCapacity(len(ecosystemDomains), len(compoundEcosystems))) for ecosystem := range ecosystemDomains { ecosystemMap[ecosystem] = getEcosystemDomains(ecosystem) diff --git a/pkg/workflow/compiler_activation_job_builder.go b/pkg/workflow/compiler_activation_job_builder.go index 9e88bf49b7d..4275f3e258f 100644 --- a/pkg/workflow/compiler_activation_job_builder.go +++ b/pkg/workflow/compiler_activation_job_builder.go @@ -64,7 +64,9 @@ func (c *Compiler) newActivationJobBuildContext( } ctx := newActivationBuildContext(data, preActivationJobCreated, workflowRunRepoSafety, lockFilename) - cacheActivationPreStepPermissions(ctx) + if err := cacheActivationPreStepPermissions(ctx); err != nil { + return nil, err + } c.addActivationSetupAndWorkflowCallSteps(ctx, setupActionRef) engine, err := c.getAgenticEngine(data.AI) @@ -99,7 +101,7 @@ func newActivationBuildContext(data *WorkflowData, preActivationJobCreated bool, return ctx } -func cacheActivationPreStepPermissions(ctx *activationJobBuildContext) { +func cacheActivationPreStepPermissions(ctx *activationJobBuildContext) error { // Cache scripts from setup/pre-steps and inferred permissions once to avoid redundant // extraction and inference calls in buildActivationPermissions and // addActivationFeedbackAndValidationSteps. @@ -111,8 +113,13 @@ func cacheActivationPreStepPermissions(ctx *activationJobBuildContext) { ctx.activationAllScripts = extractRunScriptsFromJobSection(ctx.data.Jobs, activationJobName, "setup-steps") ctx.activationAllScripts = append(ctx.activationAllScripts, extractRunScriptsFromJobSection(ctx.data.Jobs, activationJobName, "pre-steps")...) if len(ctx.activationAllScripts) > 0 { - ctx.activationInferredPerms = inferPermissionsFromShellScripts(ctx.activationAllScripts) + inferredPerms, err := inferPermissionsFromShellScripts(ctx.activationAllScripts) + if err != nil { + return err + } + ctx.activationInferredPerms = inferredPerms } + return nil } func (c *Compiler) addActivationSetupAndWorkflowCallSteps(ctx *activationJobBuildContext, setupActionRef string) { @@ -855,7 +862,11 @@ func (c *Compiler) addActivationScriptPermissions(permsMap map[PermissionScope]P if len(ctx.activationAllScripts) > 0 { // Detect write commands first — these are not permitted in the activation job // because it intentionally operates with read-only permissions. - if writeCmds := detectWriteCommandsInShellScripts(ctx.activationAllScripts); len(writeCmds) > 0 { + writeCmds, err := detectWriteCommandsInShellScripts(ctx.activationAllScripts) + if err != nil { + return err + } + if len(writeCmds) > 0 { return fmt.Errorf( "activation job uses write gh command(s) [%s]; write operations are not permitted in activation job steps because the activation job runs with read-only permissions. Move write operations to the agent job steps or use safe-outputs. See: https://github.github.com/gh-aw/reference/safe-outputs/", strings.Join(writeCmds, ", "), diff --git a/pkg/workflow/compiler_main_job.go b/pkg/workflow/compiler_main_job.go index 2b7036ce50a..5301c67166c 100644 --- a/pkg/workflow/compiler_main_job.go +++ b/pkg/workflow/compiler_main_job.go @@ -398,7 +398,11 @@ func (c *Compiler) buildMainJob(data *WorkflowData, activationJobCreated bool) ( agentAllScripts = append(agentAllScripts, extractRunScriptsFromJobSection(data.Jobs, agentJobName, "pre-steps")...) } if len(agentAllScripts) > 0 { - if writeCmds := detectWriteCommandsInShellScripts(agentAllScripts); len(writeCmds) > 0 { + writeCmds, err := detectWriteCommandsInShellScripts(agentAllScripts) + if err != nil { + return nil, err + } + if len(writeCmds) > 0 { return nil, fmt.Errorf( "agent job uses write gh command(s) [%s]; write operations are not permitted in agent job steps because the agent job runs with read-only permissions. Use safe-outputs for write operations. See: https://github.github.com/gh-aw/reference/safe-outputs/", strings.Join(writeCmds, ", "), @@ -410,7 +414,10 @@ func (c *Compiler) buildMainJob(data *WorkflowData, activationJobCreated bool) ( // Uses the same exact-string check as tools.go (the YAML parser always normalizes // "permissions: {}" to this canonical form when parsing the frontmatter). if data.Permissions != "permissions: {}" && permissions != "" { - inferred := inferPermissionsFromShellScripts(agentAllScripts) + inferred, err := inferPermissionsFromShellScripts(agentAllScripts) + if err != nil { + return nil, err + } if len(inferred) > 0 { permissions = mergeInferredIntoPermissionsYAML(permissions, inferred) } diff --git a/pkg/workflow/domains.go b/pkg/workflow/domains.go index 269a07afa22..77d79c4b70e 100644 --- a/pkg/workflow/domains.go +++ b/pkg/workflow/domains.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" "strings" + "sync" "github.com/github/gh-aw/pkg/constants" "github.com/github/gh-aw/pkg/logger" @@ -18,8 +19,31 @@ var domainsLog = logger.New("workflow:domains") //go:embed data/ecosystem_domains.json var ecosystemDomainsJSON []byte -// ecosystemDomains holds the loaded domain data -var ecosystemDomains map[string][]string +var loadEcosystemDomains = sync.OnceValues(func() (map[string][]string, error) { + domainsLog.Print("Loading ecosystem domains from embedded JSON") + + ecosystemDomains := make(map[string][]string) + if err := json.Unmarshal(ecosystemDomainsJSON, &ecosystemDomains); err != nil { + return nil, fmt.Errorf("failed to load ecosystem domains from JSON: %w", err) + } + + // Pre-sort all domain lists once so getEcosystemDomains only needs to copy, not sort. + for key := range ecosystemDomains { + sort.Strings(ecosystemDomains[key]) + } + + domainsLog.Printf("Loaded %d ecosystem categories", len(ecosystemDomains)) + return ecosystemDomains, nil +}) + +func getLoadedEcosystemDomains() map[string][]string { + ecosystemDomains, err := loadEcosystemDomains() + if err != nil { + domainsLog.Printf("Failed to load ecosystem domains: %v", err) + return map[string][]string{} + } + return ecosystemDomains +} // CopilotDefaultDomains are the default domains required for GitHub Copilot CLI authentication and operation var CopilotDefaultDomains = []string{ @@ -318,25 +342,6 @@ var PlaywrightDomains = []string{ "playwright.download.prss.microsoft.com", } -// init loads the ecosystem domains from the embedded JSON and pre-sorts each list. -// Pre-sorting at startup avoids the per-call sort.Strings in getEcosystemDomains, -// which is called on every compilation and previously allocated + sorted each list -// on every invocation. -func init() { - domainsLog.Print("Loading ecosystem domains from embedded JSON") - - if err := json.Unmarshal(ecosystemDomainsJSON, &ecosystemDomains); err != nil { - panic(fmt.Sprintf("failed to load ecosystem domains from JSON: %v", err)) - } - - // Pre-sort all domain lists once so getEcosystemDomains only needs to copy, not sort. - for key := range ecosystemDomains { - sort.Strings(ecosystemDomains[key]) - } - - domainsLog.Printf("Loaded %d ecosystem categories", len(ecosystemDomains)) -} - // compoundEcosystems defines ecosystem identifiers that expand to the union of multiple // component ecosystems. These are resolved at lookup time, so they stay in sync with // any future changes to the component ecosystems. @@ -364,6 +369,7 @@ func getEcosystemDomains(category string) []string { return result } + ecosystemDomains := getLoadedEcosystemDomains() domains, exists := ecosystemDomains[category] if !exists { return []string{} @@ -586,7 +592,7 @@ func GetDomainEcosystem(domain string) string { // Fall back to any ecosystems not in the priority list, sorted for determinism remaining := make([]string, 0) - for ecosystem := range ecosystemDomains { + for ecosystem := range getLoadedEcosystemDomains() { if _, ok := checked[ecosystem]; !ok { remaining = append(remaining, ecosystem) } diff --git a/pkg/workflow/gh_cli_permissions.go b/pkg/workflow/gh_cli_permissions.go index c669e80b61a..306fcdbe32f 100644 --- a/pkg/workflow/gh_cli_permissions.go +++ b/pkg/workflow/gh_cli_permissions.go @@ -7,6 +7,7 @@ import ( "regexp" "sort" "strings" + "sync" "github.com/github/gh-aw/pkg/logger" "github.com/goccy/go-yaml" @@ -74,12 +75,10 @@ type compiledAPIPathPattern struct { appPermissions []PermissionScope } -var ghCLIPermissions compiledGHCLIPermissions - -func init() { +var getCompiledGHCLIPermissions = sync.OnceValues(func() (compiledGHCLIPermissions, error) { var data ghCLIPermissionsData if err := json.Unmarshal(ghCLIPermissionsJSON, &data); err != nil { - panic(fmt.Sprintf("failed to load gh CLI permissions from JSON: %v", err)) + return compiledGHCLIPermissions{}, fmt.Errorf("failed to load gh CLI permissions from JSON: %w", err) } cp := compiledGHCLIPermissions{ @@ -99,7 +98,13 @@ func init() { } sort.Strings(groups) // deterministic alternation order subcommandPattern := `(?m)(?:^|[\s|;])gh\s+(` + strings.Join(groups, "|") + `)\s+([\w][\w-]*)\b` - cp.subcommandRE = regexp.MustCompile(subcommandPattern) + // Defensive check: the pattern is built from embedded JSON keys quoted with + // regexp.QuoteMeta, so a compile error would indicate unexpected data corruption. + subcommandRE, err := regexp.Compile(subcommandPattern) + if err != nil { + return compiledGHCLIPermissions{}, fmt.Errorf("invalid gh subcommand pattern %q: %w", subcommandPattern, err) + } + cp.subcommandRE = subcommandRE for group, sg := range data.SubcommandGroups { readPerms := make([]PermissionScope, len(sg.ReadPermissions)) @@ -138,7 +143,7 @@ func init() { for _, ap := range data.APIPathPatterns { re, err := regexp.Compile(ap.Pattern) if err != nil { - panic(fmt.Sprintf("invalid gh API path pattern %q in gh_cli_permissions.json: %v", ap.Pattern, err)) + return compiledGHCLIPermissions{}, fmt.Errorf("invalid gh API path pattern %q in gh_cli_permissions.json: %w", ap.Pattern, err) } perms := make([]PermissionScope, len(ap.Permissions)) for i, p := range ap.Permissions { @@ -155,9 +160,9 @@ func init() { }) } - ghCLIPermissions = cp ghCLIPermissionsLog.Printf("Loaded gh CLI permissions: version=%s, subcommand_groups=%d, api_path_patterns=%d", data.Version, len(data.SubcommandGroups), len(data.APIPathPatterns)) -} + return cp, nil +}) // ghAPICmdRE matches `gh api` at a command boundary, capturing the rest of the line. var ghAPICmdRE = regexp.MustCompile(`(?m)(?:^|[\s|;])gh\s+api\s+(.+)`) @@ -275,9 +280,13 @@ func splitShellTokens(s string) []string { // Only read-level permissions are inferred here; write-level operations are // intentionally not auto-escalated. Use detectWriteCommandsInShellScripts to // surface write commands as validation errors. -func inferPermissionsFromShellScripts(scripts []string) map[PermissionScope]PermissionLevel { +func inferPermissionsFromShellScripts(scripts []string) (map[PermissionScope]PermissionLevel, error) { ghCLIPermissionsLog.Printf("Inferring permissions from %d shell script(s)", len(scripts)) perms := make(map[PermissionScope]PermissionLevel) + ghCLIPermissions, err := getCompiledGHCLIPermissions() + if err != nil { + return nil, fmt.Errorf("load gh CLI permissions: %w", err) + } addScopes := func(scopes []PermissionScope) { for _, scope := range scopes { @@ -337,14 +346,18 @@ func inferPermissionsFromShellScripts(scripts []string) map[PermissionScope]Perm } ghCLIPermissionsLog.Printf("Inferred %d permission scope(s) from shell scripts", len(perms)) - return perms + return perms, nil } // detectWriteCommandsInShellScripts returns all write gh CLI commands found in the // given scripts, formatted as "gh " (e.g. "gh pr create"). // The slice contains no duplicates and is sorted deterministically in discovery order. -func detectWriteCommandsInShellScripts(scripts []string) []string { +func detectWriteCommandsInShellScripts(scripts []string) ([]string, error) { ghCLIPermissionsLog.Printf("Scanning %d shell script(s) for write gh CLI commands", len(scripts)) + ghCLIPermissions, err := getCompiledGHCLIPermissions() + if err != nil { + return nil, fmt.Errorf("load gh CLI permissions: %w", err) + } var found []string seen := make(map[string]struct{}) @@ -367,7 +380,7 @@ func detectWriteCommandsInShellScripts(scripts []string) []string { if len(found) > 0 { ghCLIPermissionsLog.Printf("Detected %d write gh CLI command(s) in shell scripts", len(found)) } - return found + return found, nil } // extractRunScriptsFromSectionYAML parses a step-section YAML string (e.g. as stored in diff --git a/pkg/workflow/gh_cli_permissions_test.go b/pkg/workflow/gh_cli_permissions_test.go index 452366da076..304144833bc 100644 --- a/pkg/workflow/gh_cli_permissions_test.go +++ b/pkg/workflow/gh_cli_permissions_test.go @@ -14,48 +14,66 @@ import ( "github.com/stretchr/testify/require" ) +func inferPermissionsFromShellScriptsForTest(t *testing.T, scripts []string) map[PermissionScope]PermissionLevel { + t.Helper() + perms, err := inferPermissionsFromShellScripts(scripts) + if err != nil { + t.Fatalf("inferPermissionsFromShellScripts() error = %v", err) + } + return perms +} + +func detectWriteCommandsInShellScriptsForTest(t *testing.T, scripts []string) []string { + t.Helper() + cmds, err := detectWriteCommandsInShellScripts(scripts) + if err != nil { + t.Fatalf("detectWriteCommandsInShellScripts() error = %v", err) + } + return cmds +} + // TestInferPermissionsFromShellScripts_GhPrDiff verifies that `gh pr diff` in a // shell script is recognized as requiring pull-requests: read. func TestInferPermissionsFromShellScripts_GhPrDiff(t *testing.T) { scripts := []string{ `gh pr diff "$PR_NUMBER" --name-only | awk '/\.md$/' > /tmp/changed.txt`, } - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "gh pr diff should require pull-requests: read") } // TestInferPermissionsFromShellScripts_GhPrView verifies pull-requests: read for gh pr view. func TestInferPermissionsFromShellScripts_GhPrView(t *testing.T) { scripts := []string{`gh pr view 123 --json title`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests]) } // TestInferPermissionsFromShellScripts_GhIssueList verifies issues: read for gh issue list. func TestInferPermissionsFromShellScripts_GhIssueList(t *testing.T) { scripts := []string{`gh issue list --label bug --json number`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues]) } // TestInferPermissionsFromShellScripts_GhWorkflowList verifies actions: read for gh workflow list. func TestInferPermissionsFromShellScripts_GhWorkflowList(t *testing.T) { scripts := []string{`gh workflow list`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionActions]) } // TestInferPermissionsFromShellScripts_GhRunView verifies actions: read for gh run view. func TestInferPermissionsFromShellScripts_GhRunView(t *testing.T) { scripts := []string{`gh run view $RUN_ID --json conclusion`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionActions]) } // TestInferPermissionsFromShellScripts_GhAPI verifies pull-requests: read for gh api pulls endpoint. func TestInferPermissionsFromShellScripts_GhAPI(t *testing.T) { scripts := []string{`gh api /repos/owner/repo/pulls/1 --jq .title`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "gh api /repos/.../pulls should require pull-requests: read") } @@ -63,7 +81,7 @@ func TestInferPermissionsFromShellScripts_GhAPI(t *testing.T) { // endpoint are skipped, e.g. gh api -H 'Accept: ...' /repos/owner/repo/pulls. func TestInferPermissionsFromShellScripts_GhAPIWithHeaderFlag(t *testing.T) { scripts := []string{`gh api -H 'Accept: application/vnd.github+json' /repos/owner/repo/pulls`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "gh api with -H flag before endpoint should still infer pull-requests: read") } @@ -72,7 +90,7 @@ func TestInferPermissionsFromShellScripts_GhAPIWithHeaderFlag(t *testing.T) { // before the endpoint is skipped properly. func TestInferPermissionsFromShellScripts_GhAPIWithMethodFlag(t *testing.T) { scripts := []string{`gh api --method GET /repos/owner/repo/issues`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues], "gh api with --method GET before endpoint should still infer issues: read") } @@ -81,7 +99,7 @@ func TestInferPermissionsFromShellScripts_GhAPIWithMethodFlag(t *testing.T) { // is correctly extracted, e.g. gh api "/repos/owner/repo/pulls". func TestInferPermissionsFromShellScripts_GhAPIQuotedEndpoint(t *testing.T) { scripts := []string{`gh api "/repos/owner/repo/pulls"`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], `gh api with quoted endpoint should infer pull-requests: read`) } @@ -89,7 +107,7 @@ func TestInferPermissionsFromShellScripts_GhAPIQuotedEndpoint(t *testing.T) { // TestInferPermissionsFromShellScripts_GhAPIIssues verifies issues: read for gh api issues endpoint. func TestInferPermissionsFromShellScripts_GhAPIIssues(t *testing.T) { scripts := []string{`gh api /repos/owner/repo/issues --jq '.[].number'`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues], "gh api /repos/.../issues should require issues: read") } @@ -97,7 +115,7 @@ func TestInferPermissionsFromShellScripts_GhAPIIssues(t *testing.T) { // there are no gh CLI calls in the script. func TestInferPermissionsFromShellScripts_NoGhCommand(t *testing.T) { scripts := []string{`echo "hello" && ls /tmp`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Empty(t, perms, "no gh commands should produce no permission requirements") } @@ -108,7 +126,7 @@ func TestInferPermissionsFromShellScripts_MultiLine(t *testing.T) { | awk '/\.md$/' \ > /tmp/gh-aw/docs-review-data/changed-md.txt`, } - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "multi-line gh pr diff should require pull-requests: read") } @@ -118,7 +136,7 @@ func TestInferPermissionsFromShellScripts_MultipleCommands(t *testing.T) { `gh pr diff "$PR_NUMBER" --name-only > /tmp/changed.txt gh issue view $ISSUE_NUMBER --json body > /tmp/issue.json`, } - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "should infer pull-requests: read") assert.Equal(t, PermissionRead, perms[PermissionIssues], "should infer issues: read") } @@ -241,7 +259,7 @@ engine: copilot // TestDetectWriteCommandsInShellScripts_GhPrCreate verifies that `gh pr create` is detected as a write command. func TestDetectWriteCommandsInShellScripts_GhPrCreate(t *testing.T) { scripts := []string{`gh pr create --title "Fix bug" --body "details"`} - cmds := detectWriteCommandsInShellScripts(scripts) + cmds := detectWriteCommandsInShellScriptsForTest(t, scripts) require.Len(t, cmds, 1) assert.Equal(t, "gh pr create", cmds[0]) } @@ -249,7 +267,7 @@ func TestDetectWriteCommandsInShellScripts_GhPrCreate(t *testing.T) { // TestDetectWriteCommandsInShellScripts_GhIssueClose verifies that `gh issue close` is detected. func TestDetectWriteCommandsInShellScripts_GhIssueClose(t *testing.T) { scripts := []string{`gh issue close $ISSUE_NUMBER`} - cmds := detectWriteCommandsInShellScripts(scripts) + cmds := detectWriteCommandsInShellScriptsForTest(t, scripts) require.Len(t, cmds, 1) assert.Equal(t, "gh issue close", cmds[0]) } @@ -258,7 +276,7 @@ func TestDetectWriteCommandsInShellScripts_GhIssueClose(t *testing.T) { // (e.g. `gh pr diff`) is NOT flagged as a write command. func TestDetectWriteCommandsInShellScripts_ReadCommandNotDetected(t *testing.T) { scripts := []string{`gh pr diff "$PR_NUMBER" --name-only`} - cmds := detectWriteCommandsInShellScripts(scripts) + cmds := detectWriteCommandsInShellScriptsForTest(t, scripts) assert.Empty(t, cmds, "gh pr diff is a read command and should not be detected as write") } @@ -269,7 +287,7 @@ func TestDetectWriteCommandsInShellScripts_Deduplicated(t *testing.T) { `gh pr create --title "Fix 1" gh pr create --title "Fix 2"`, } - cmds := detectWriteCommandsInShellScripts(scripts) + cmds := detectWriteCommandsInShellScriptsForTest(t, scripts) assert.Len(t, cmds, 1, "duplicate write commands should be deduplicated") assert.Equal(t, "gh pr create", cmds[0]) } @@ -281,7 +299,7 @@ func TestDetectWriteCommandsInShellScripts_MultipleWriteCommands(t *testing.T) { `gh pr merge $PR_NUMBER --squash gh issue comment $ISSUE_NUMBER --body "done"`, } - cmds := detectWriteCommandsInShellScripts(scripts) + cmds := detectWriteCommandsInShellScriptsForTest(t, scripts) assert.Len(t, cmds, 2) assert.Contains(t, cmds, "gh pr merge") assert.Contains(t, cmds, "gh issue comment") @@ -290,21 +308,21 @@ gh issue comment $ISSUE_NUMBER --body "done"`, // TestInferPermissionsFromShellScripts_GhCacheList verifies actions: read for gh cache list. func TestInferPermissionsFromShellScripts_GhCacheList(t *testing.T) { scripts := []string{`gh cache list --json key`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionActions], "gh cache list should require actions: read") } // TestInferPermissionsFromShellScripts_GhRepoView verifies contents: read for gh repo view. func TestInferPermissionsFromShellScripts_GhRepoView(t *testing.T) { scripts := []string{`gh repo view owner/repo --json description`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionContents], "gh repo view should require contents: read") } // TestInferPermissionsFromShellScripts_GhLabelList verifies issues: read for gh label list. func TestInferPermissionsFromShellScripts_GhLabelList(t *testing.T) { scripts := []string{`gh label list --json name`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues], "gh label list should require issues: read") } @@ -313,21 +331,21 @@ func TestInferPermissionsFromShellScripts_GhLabelList(t *testing.T) { // in the activation job — the write-command check is separate. func TestInferPermissionsFromShellScripts_GhIssueComment(t *testing.T) { scripts := []string{`gh issue comment $ISSUE_NUMBER --body "hello"`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues], "write commands still require at minimum read-level permission for the scope") } // TestInferPermissionsFromShellScripts_GhAPIReleases verifies contents: read for gh api releases. func TestInferPermissionsFromShellScripts_GhAPIReleases(t *testing.T) { scripts := []string{`gh api /repos/owner/repo/releases --jq '.[0].tag_name'`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionContents], "gh api /repos/.../releases should require contents: read") } // TestInferPermissionsFromShellScripts_GhAPILabels verifies issues: read for gh api labels endpoint. func TestInferPermissionsFromShellScripts_GhAPILabels(t *testing.T) { scripts := []string{`gh api /repos/owner/repo/labels`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionIssues], "gh api /repos/.../labels should require issues: read") } @@ -406,7 +424,7 @@ jobs: // returns the GitHub App-only codespaces: read permission (no GITHUB_TOKEN equivalent). func TestInferPermissionsFromShellScripts_GhCodespaceList(t *testing.T) { scripts := []string{`gh codespace list --json name`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionCodespaces], "gh codespace list should require codespaces: read (GitHub App-only)") } @@ -415,7 +433,7 @@ func TestInferPermissionsFromShellScripts_GhCodespaceList(t *testing.T) { // returns the GitHub App-only members: read permission. func TestInferPermissionsFromShellScripts_GhAPIOrgsMembers(t *testing.T) { scripts := []string{`gh api /orgs/myorg/members --jq '.[].login'`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionMembers], "gh api /orgs/.../members should require members: read (GitHub App-only)") } @@ -427,7 +445,7 @@ func TestInferPermissionsFromShellScripts_AppAndActionsPermissions(t *testing.T) `gh pr diff "$PR_NUMBER" --name-only gh codespace list --json name`, } - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionPullRequests], "gh pr diff should require pull-requests: read") assert.Equal(t, PermissionRead, perms[PermissionCodespaces], @@ -438,7 +456,7 @@ gh codespace list --json name`, // (a write command) is still inferred to need administration: read (GitHub App-only) at minimum. func TestInferPermissionsFromShellScripts_GhRepoWriteHasAppAdminPerm(t *testing.T) { scripts := []string{`gh repo archive owner/repo`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionAdministration], "gh repo archive (write) should infer administration: read for GitHub App") } @@ -447,7 +465,7 @@ func TestInferPermissionsFromShellScripts_GhRepoWriteHasAppAdminPerm(t *testing. // (GitHub App-only) for the environments REST API path. func TestInferPermissionsFromShellScripts_GhAPIRepoEnvironments(t *testing.T) { scripts := []string{`gh api /repos/owner/repo/environments --jq '.[].name'`} - perms := inferPermissionsFromShellScripts(scripts) + perms := inferPermissionsFromShellScriptsForTest(t, scripts) assert.Equal(t, PermissionRead, perms[PermissionEnvironments], "gh api /repos/.../environments should require environments: read (GitHub App-only)") } diff --git a/pkg/workflow/github_tool_to_toolset.go b/pkg/workflow/github_tool_to_toolset.go index ce489c9b5d6..5c6864bc0c2 100644 --- a/pkg/workflow/github_tool_to_toolset.go +++ b/pkg/workflow/github_tool_to_toolset.go @@ -3,6 +3,8 @@ package workflow import ( _ "embed" "encoding/json" + "fmt" + "sync" "github.com/github/gh-aw/pkg/logger" ) @@ -12,18 +14,11 @@ var githubToolToToolsetLog = logger.New("workflow:github_tool_to_toolset") //go:embed data/github_tool_to_toolset.json var githubToolToToolsetJSON []byte -// GitHubToolToToolsetMap maps individual GitHub MCP tools to their respective toolsets -// This mapping is loaded from an embedded JSON file based on the documentation -// in .github/aw/github-mcp-server.md -var GitHubToolToToolsetMap map[string]string - -func init() { - // Load the mapping from embedded JSON - if err := json.Unmarshal(githubToolToToolsetJSON, &GitHubToolToToolsetMap); err != nil { - panic("failed to load GitHub tool to toolset mapping: " + err.Error()) +var getGitHubToolToToolsetMap = sync.OnceValues(func() (map[string]string, error) { + var toolToToolsetMap map[string]string + if err := json.Unmarshal(githubToolToToolsetJSON, &toolToToolsetMap); err != nil { + return nil, fmt.Errorf("failed to load GitHub tool to toolset mapping: %w", err) } - githubToolToToolsetLog.Printf("Loaded GitHub tool-to-toolset mapping: %d entries", len(GitHubToolToToolsetMap)) -} - -// GitHubToolToToolsetMap is the last declaration in this file; ValidateGitHubToolsAgainstToolsets -// has been moved to tools_validation.go. + githubToolToToolsetLog.Printf("Loaded GitHub tool-to-toolset mapping: %d entries", len(toolToToolsetMap)) + return toolToToolsetMap, nil +}) diff --git a/pkg/workflow/github_tool_to_toolset_test.go b/pkg/workflow/github_tool_to_toolset_test.go index 500870b3311..3e854e30315 100644 --- a/pkg/workflow/github_tool_to_toolset_test.go +++ b/pkg/workflow/github_tool_to_toolset_test.go @@ -270,21 +270,22 @@ func TestGitHubToolToToolsetMap_Completeness(t *testing.T) { } foundToolsets := make(map[string]bool) - for _, toolset := range GitHubToolToToolsetMap { + for _, toolset := range loadGitHubToolToToolsetMap(t) { foundToolsets[toolset] = true } for _, expectedToolset := range expectedToolsets { if !foundToolsets[expectedToolset] { - t.Errorf("Expected to find tools for toolset %q in GitHubToolToToolsetMap", expectedToolset) + t.Errorf("Expected to find tools for toolset %q in getGitHubToolToToolsetMap()", expectedToolset) } } } func TestGitHubToolToToolsetMap_IncludesDefaultGitHubTools(t *testing.T) { + toolToToolsetMap := loadGitHubToolToToolsetMap(t) for _, tool := range constants.DefaultReadOnlyGitHubTools { - if _, exists := GitHubToolToToolsetMap[tool]; !exists { - t.Errorf("Expected tool %q from constants.DefaultReadOnlyGitHubTools to be in GitHubToolToToolsetMap", tool) + if _, exists := toolToToolsetMap[tool]; !exists { + t.Errorf("Expected tool %q from constants.DefaultReadOnlyGitHubTools to be in getGitHubToolToToolsetMap()", tool) } } } @@ -314,10 +315,11 @@ func TestGitHubToolToToolsetMap_ConsistencyWithDocumentation(t *testing.T) { "list_secret_scanning_alerts": "secret_protection", } + toolToToolsetMap := loadGitHubToolToToolsetMap(t) for tool, expectedToolset := range expectedMappings { - actualToolset, exists := GitHubToolToToolsetMap[tool] + actualToolset, exists := toolToToolsetMap[tool] if !exists { - t.Errorf("Expected tool %q to be in GitHubToolToToolsetMap", tool) + t.Errorf("Expected tool %q to be in getGitHubToolToToolsetMap()", tool) continue } if actualToolset != expectedToolset { @@ -326,6 +328,16 @@ func TestGitHubToolToToolsetMap_ConsistencyWithDocumentation(t *testing.T) { } } +func loadGitHubToolToToolsetMap(t *testing.T) map[string]string { + t.Helper() + + toolToToolsetMap, err := getGitHubToolToToolsetMap() + if err != nil { + t.Fatalf("getGitHubToolToToolsetMap() error = %v", err) + } + return toolToToolsetMap +} + // expandToolsetsForTesting expands "default" and "all" toolsets for testing purposes func expandToolsetsForTesting(toolsets []string) []string { var expanded []string diff --git a/pkg/workflow/network_firewall_validation.go b/pkg/workflow/network_firewall_validation.go index 3e87d6695e3..50ef6fb03e6 100644 --- a/pkg/workflow/network_firewall_validation.go +++ b/pkg/workflow/network_firewall_validation.go @@ -143,6 +143,7 @@ func isEcosystemIdentifier(domain string) bool { // It checks the base ecosystemDomains map and the compoundEcosystems map directly, // avoiding the allocations that getEcosystemDomains incurs. func isKnownEcosystemIdentifier(id string) bool { + ecosystemDomains := getLoadedEcosystemDomains() if _, ok := ecosystemDomains[id]; ok { return true } @@ -153,6 +154,7 @@ func isKnownEcosystemIdentifier(id string) bool { // getValidEcosystemIdentifiers returns a sorted list of all valid ecosystem identifiers, // including both the base identifiers from ecosystemDomains and compound identifiers. func getValidEcosystemIdentifiers() []string { + ecosystemDomains := getLoadedEcosystemDomains() ids := make([]string, 0, safeAllocationCapacity(len(ecosystemDomains), len(compoundEcosystems))) for id := range ecosystemDomains { ids = append(ids, id) diff --git a/pkg/workflow/tools_validation_github_toolsets.go b/pkg/workflow/tools_validation_github_toolsets.go index d5cf314aa72..562bdfce51e 100644 --- a/pkg/workflow/tools_validation_github_toolsets.go +++ b/pkg/workflow/tools_validation_github_toolsets.go @@ -13,13 +13,17 @@ import ( func validateGitHubToolsAgainstToolsetsCore(allowedTools []string, enabledToolsets []string) error { githubToolToToolsetLog.Printf("Validating GitHub tools against toolsets: allowed_tools=%d, enabled_toolsets=%d", len(allowedTools), len(enabledToolsets)) - if len(allowedTools) == 0 { githubToolToToolsetLog.Print("No tools to validate, skipping") // No specific tools restricted, validation not needed return nil } + toolToToolsetMap, err := getGitHubToolToToolsetMap() + if err != nil { + return fmt.Errorf("failed to load GitHub tool-to-toolset mapping: %w", err) + } + // Create a set of enabled toolsets for fast lookup enabledSet := make(map[string]struct { }) @@ -42,13 +46,13 @@ func validateGitHubToolsAgainstToolsetsCore(allowedTools []string, enabledToolse continue } - requiredToolset, exists := GitHubToolToToolsetMap[tool] + requiredToolset, exists := toolToToolsetMap[tool] if !exists { githubToolToToolsetLog.Printf("Tool %s not found in mapping, checking for typo", tool) // Get all valid tool names for suggestion - validTools := make([]string, 0, len(GitHubToolToToolsetMap)) - for validTool := range GitHubToolToToolsetMap { + validTools := make([]string, 0, len(toolToToolsetMap)) + for validTool := range toolToToolsetMap { validTools = append(validTools, validTool) } sort.Strings(validTools) @@ -89,8 +93,8 @@ func validateGitHubToolsAgainstToolsetsCore(allowedTools []string, enabledToolse } // Show a few examples of valid tools - validTools := make([]string, 0, len(GitHubToolToToolsetMap)) - for tool := range GitHubToolToToolsetMap { + validTools := make([]string, 0, len(toolToToolsetMap)) + for tool := range toolToToolsetMap { validTools = append(validTools, tool) } sort.Strings(validTools)