diff --git a/.github/workflows/repo-assist.lock.yml b/.github/workflows/repo-assist.lock.yml index d2f6c99fd..2c71e5e00 100644 --- a/.github/workflows/repo-assist.lock.yml +++ b/.github/workflows/repo-assist.lock.yml @@ -399,17 +399,18 @@ jobs: - name: Create gh-aw temp directory run: bash ${RUNNER_TEMP}/gh-aw/actions/create_gh_aw_tmp_dir.sh # NOTE: Local container build kept for debugging. Uncomment to test unpublished changes. - # - name: Build local MCPG container (debugging only) - # run: | - # rustup target add wasm32-wasip1 - # cd guards/github-guard/rust-guard && ./build.sh && cd ../../.. - # docker build . -t ghcr.io/github/gh-aw-mcpg:local + - name: Build local MCPG container (debugging only) + run: | + rustup target add wasm32-wasip1 + cd guards/github-guard/rust-guard && ./build.sh && cd ../../.. + docker build . -t ghcr.io/github/gh-aw-mcpg:local - name: Start DIFC proxy for pre-agent gh calls env: - GH_TOKEN: ${{ github.token }} + GH_TOKEN: ${{ secrets.GH_AW_GITHUB_MCP_SERVER_TOKEN || secrets.GH_AW_GITHUB_TOKEN || secrets.GITHUB_TOKEN }} run: | PROXY_LOG_DIR=/tmp/gh-aw/proxy-logs - mkdir -p "$PROXY_LOG_DIR" + MCP_LOG_DIR=/tmp/gh-aw/mcp-logs + mkdir -p "$PROXY_LOG_DIR" "$MCP_LOG_DIR" POLICY='{"allow-only":{"repos":["github/*"],"min-integrity":"approved"}}' @@ -417,12 +418,14 @@ jobs: -e GH_TOKEN \ -e DEBUG='*' \ -v "$PROXY_LOG_DIR:$PROXY_LOG_DIR" \ - ghcr.io/github/gh-aw-mcpg:v0.1.26 proxy \ + -v "$MCP_LOG_DIR:$MCP_LOG_DIR" \ + ghcr.io/github/gh-aw-mcpg:local proxy \ --policy "$POLICY" \ --listen 0.0.0.0:18443 \ - --log-dir "$PROXY_LOG_DIR" \ + --log-dir "$MCP_LOG_DIR" \ --tls --tls-dir "$PROXY_LOG_DIR/proxy-tls" \ - --guards-mode filter + --guards-mode filter \ + --trusted-bots github-actions[bot],github-actions,dependabot[bot],copilot # Wait for TLS cert to be generated CA_INSTALLED=false @@ -458,9 +461,9 @@ jobs: - name: Configure gh CLI for GitHub Enterprise run: bash ${RUNNER_TEMP}/gh-aw/actions/configure_gh_for_ghe.sh env: - GH_TOKEN: ${{ github.token }} + GH_TOKEN: ${{ secrets.GH_AW_GITHUB_MCP_SERVER_TOKEN || secrets.GH_AW_GITHUB_TOKEN || secrets.GITHUB_TOKEN }} - env: - GH_TOKEN: ${{ github.token }} + GH_TOKEN: ${{ secrets.GH_AW_GITHUB_MCP_SERVER_TOKEN || secrets.GH_AW_GITHUB_TOKEN || secrets.GITHUB_TOKEN }} name: Fetch repo data for task weighting run: "mkdir -p /tmp/gh-aw\n\n# Fetch open issues with labels (up to 500)\n# Fallback to empty array if DIFC proxy filters all data\ngh issue list -R $GITHUB_REPOSITORY --state open --limit 500 --json number,labels > /tmp/gh-aw/issues.json 2>/dev/null || echo '[]' > /tmp/gh-aw/issues.json\n\n# Fetch open PRs with titles (up to 200)\ngh pr list -R $GITHUB_REPOSITORY --state open --limit 200 --json number,title > /tmp/gh-aw/prs.json 2>/dev/null || echo '[]' > /tmp/gh-aw/prs.json\n\n# Compute task weights and select two tasks for this run\npython3 - << 'EOF'\nimport json, random, os\n\nwith open('/tmp/gh-aw/issues.json') as f:\n issues = json.load(f)\nwith open('/tmp/gh-aw/prs.json') as f:\n prs = json.load(f)\n\nopen_issues = len(issues)\nunlabelled = sum(1 for i in issues if not i.get('labels'))\nrepo_assist_prs = sum(1 for p in prs if p['title'].startswith('[Repo Assist]'))\nother_prs = sum(1 for p in prs if not p['title'].startswith('[Repo Assist]'))\n\ntask_names = {\n 1: 'Issue Labelling',\n 2: 'Issue Investigation and Comment',\n 3: 'Issue Investigation and Fix',\n 4: 'Engineering Investments',\n 5: 'Coding Improvements',\n 6: 'Maintain Repo Assist PRs',\n 7: 'Stale PR Nudges',\n 8: 'Performance Improvements',\n 9: 'Testing Improvements',\n 10: 'Take the Repository Forward',\n}\n\nweights = {\n 1: 1 + 3 * unlabelled,\n 2: 3 + 1 * open_issues,\n 3: 3 + 0.7 * open_issues,\n 4: 5 + 0.2 * open_issues,\n 5: 5 + 0.1 * open_issues,\n 6: float(repo_assist_prs),\n 7: 0.1 * other_prs,\n 8: 3 + 0.05 * open_issues,\n 9: 3 + 0.05 * open_issues,\n 10: 3 + 0.05 * open_issues,\n}\n\n# Seed with run ID for reproducibility within a run\nrun_id = int(os.environ.get('GITHUB_RUN_ID', '0'))\nrng = random.Random(run_id)\n\ntask_ids = list(weights.keys())\ntask_weights = [weights[t] for t in task_ids]\n\n# Weighted sample without replacement (pick 2 distinct tasks)\nchosen, seen = [], set()\nfor t in rng.choices(task_ids, weights=task_weights, k=30):\n if t not in seen:\n seen.add(t)\n chosen.append(t)\n if len(chosen) == 2:\n break\n\nprint('=== Repo Assist Task Selection ===')\nprint(f'Open issues : {open_issues}')\nprint(f'Unlabelled issues : {unlabelled}')\nprint(f'Repo Assist PRs : {repo_assist_prs}')\nprint(f'Other open PRs : {other_prs}')\nprint()\nprint('Task weights:')\nfor t, w in weights.items():\n tag = ' <-- SELECTED' if t in chosen else ''\n print(f' Task {t:2d} ({task_names[t]}): weight {w:6.1f}{tag}')\nprint()\nprint(f'Selected tasks for this run: Task {chosen[0]} ({task_names[chosen[0]]}) and Task {chosen[1]} ({task_names[chosen[1]]})')\n\nresult = {\n 'open_issues': open_issues, 'unlabelled_issues': unlabelled,\n 'repo_assist_prs': repo_assist_prs, 'other_prs': other_prs,\n 'task_names': task_names,\n 'weights': {str(k): round(v, 2) for k, v in weights.items()},\n 'selected_tasks': chosen,\n}\nwith open('/tmp/gh-aw/task_selection.json', 'w') as f:\n json.dump(result, f, indent=2)\nEOF\n" - name: Dump proxy logs (debug) @@ -469,12 +472,12 @@ jobs: echo "=== Proxy container logs ===" docker logs awmg-proxy 2>&1 | tail -80 || true echo "=== Proxy log file ===" - cat /tmp/gh-aw/proxy-logs/proxy.log 2>/dev/null | tail -50 || true + cat /tmp/gh-aw/mcp-logs/proxy.log 2>/dev/null | tail -50 || true # Repo memory git-based storage configuration from frontmatter processed below - name: Clone repo-memory branch (default) env: - GH_TOKEN: ${{ github.token }} + GH_TOKEN: ${{ secrets.GH_AW_GITHUB_MCP_SERVER_TOKEN || secrets.GH_AW_GITHUB_TOKEN || secrets.GITHUB_TOKEN }} GITHUB_SERVER_URL: ${{ github.server_url }} BRANCH_NAME: memory/repo-assist TARGET_REPO: ${{ github.repository }} @@ -862,7 +865,7 @@ jobs: export DEBUG="*" export GH_AW_ENGINE="copilot" - export MCP_GATEWAY_DOCKER_COMMAND='docker run -i --rm --network host -v /var/run/docker.sock:/var/run/docker.sock -e MCP_GATEWAY_PORT -e MCP_GATEWAY_DOMAIN -e MCP_GATEWAY_API_KEY -e MCP_GATEWAY_PAYLOAD_DIR -e MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD -e DEBUG -e MCP_GATEWAY_LOG_DIR -e GH_AW_MCP_LOG_DIR -e GH_AW_SAFE_OUTPUTS -e GH_AW_SAFE_OUTPUTS_CONFIG_PATH -e GH_AW_SAFE_OUTPUTS_TOOLS_PATH -e GH_AW_ASSETS_BRANCH -e GH_AW_ASSETS_MAX_SIZE_KB -e GH_AW_ASSETS_ALLOWED_EXTS -e DEFAULT_BRANCH -e GITHUB_MCP_SERVER_TOKEN -e GITHUB_MCP_GUARD_MIN_INTEGRITY -e GITHUB_MCP_GUARD_REPOS -e GITHUB_REPOSITORY -e GITHUB_SERVER_URL -e GITHUB_SHA -e GITHUB_WORKSPACE -e GITHUB_TOKEN -e GITHUB_RUN_ID -e GITHUB_RUN_NUMBER -e GITHUB_RUN_ATTEMPT -e GITHUB_JOB -e GITHUB_ACTION -e GITHUB_EVENT_NAME -e GITHUB_EVENT_PATH -e GITHUB_ACTOR -e GITHUB_ACTOR_ID -e GITHUB_TRIGGERING_ACTOR -e GITHUB_WORKFLOW -e GITHUB_WORKFLOW_REF -e GITHUB_WORKFLOW_SHA -e GITHUB_REF -e GITHUB_REF_NAME -e GITHUB_REF_TYPE -e GITHUB_HEAD_REF -e GITHUB_BASE_REF -e GH_AW_SAFE_OUTPUTS_PORT -e GH_AW_SAFE_OUTPUTS_API_KEY -v /tmp/gh-aw/mcp-payloads:/tmp/gh-aw/mcp-payloads:rw -v /opt:/opt:ro -v /tmp:/tmp:rw -v '"${GITHUB_WORKSPACE}"':'"${GITHUB_WORKSPACE}"':rw ghcr.io/github/gh-aw-mcpg:v0.1.26' + export MCP_GATEWAY_DOCKER_COMMAND='docker run -i --rm --network host -v /var/run/docker.sock:/var/run/docker.sock -e MCP_GATEWAY_PORT -e MCP_GATEWAY_DOMAIN -e MCP_GATEWAY_API_KEY -e MCP_GATEWAY_PAYLOAD_DIR -e MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD -e DEBUG -e MCP_GATEWAY_LOG_DIR -e GH_AW_MCP_LOG_DIR -e GH_AW_SAFE_OUTPUTS -e GH_AW_SAFE_OUTPUTS_CONFIG_PATH -e GH_AW_SAFE_OUTPUTS_TOOLS_PATH -e GH_AW_ASSETS_BRANCH -e GH_AW_ASSETS_MAX_SIZE_KB -e GH_AW_ASSETS_ALLOWED_EXTS -e DEFAULT_BRANCH -e GITHUB_MCP_SERVER_TOKEN -e GITHUB_MCP_GUARD_MIN_INTEGRITY -e GITHUB_MCP_GUARD_REPOS -e GITHUB_REPOSITORY -e GITHUB_SERVER_URL -e GITHUB_SHA -e GITHUB_WORKSPACE -e GITHUB_TOKEN -e GITHUB_RUN_ID -e GITHUB_RUN_NUMBER -e GITHUB_RUN_ATTEMPT -e GITHUB_JOB -e GITHUB_ACTION -e GITHUB_EVENT_NAME -e GITHUB_EVENT_PATH -e GITHUB_ACTOR -e GITHUB_ACTOR_ID -e GITHUB_TRIGGERING_ACTOR -e GITHUB_WORKFLOW -e GITHUB_WORKFLOW_REF -e GITHUB_WORKFLOW_SHA -e GITHUB_REF -e GITHUB_REF_NAME -e GITHUB_REF_TYPE -e GITHUB_HEAD_REF -e GITHUB_BASE_REF -e GH_AW_SAFE_OUTPUTS_PORT -e GH_AW_SAFE_OUTPUTS_API_KEY -v /tmp/gh-aw/mcp-payloads:/tmp/gh-aw/mcp-payloads:rw -v /opt:/opt:ro -v /tmp:/tmp:rw -v '"${GITHUB_WORKSPACE}"':'"${GITHUB_WORKSPACE}"':rw ghcr.io/github/gh-aw-mcpg:local' mkdir -p /home/runner/.copilot cat << GH_AW_MCP_CONFIG_EOF | bash ${RUNNER_TEMP}/gh-aw/actions/start_mcp_gateway.sh diff --git a/internal/cmd/proxy.go b/internal/cmd/proxy.go index 484eca18e..82bab6164 100644 --- a/internal/cmd/proxy.go +++ b/internal/cmd/proxy.go @@ -18,15 +18,16 @@ import ( // Proxy subcommand flag variables var ( - proxyGuardWasm string - proxyPolicy string - proxyToken string - proxyListen string - proxyLogDir string - proxyDIFCMode string - proxyAPIURL string - proxyTLS bool - proxyTLSDir string + proxyGuardWasm string + proxyPolicy string + proxyToken string + proxyListen string + proxyLogDir string + proxyDIFCMode string + proxyAPIURL string + proxyTLS bool + proxyTLSDir string + proxyTrustedBots []string ) func init() { @@ -94,6 +95,7 @@ Local usage: cmd.Flags().StringVar(&proxyAPIURL, "github-api-url", proxy.DefaultGitHubAPIBase, "Upstream GitHub API URL") cmd.Flags().BoolVar(&proxyTLS, "tls", false, "Enable HTTPS with auto-generated self-signed certificates") cmd.Flags().StringVar(&proxyTLSDir, "tls-dir", "", "Directory for TLS certificates (default: /proxy-tls)") + cmd.Flags().StringSliceVar(&proxyTrustedBots, "trusted-bots", nil, "Additional trusted bot usernames (comma-separated, extends built-in list)") // Only require --guard-wasm when no baked-in guard is available if defaultGuard == "" { @@ -111,7 +113,7 @@ func runProxy(cmd *cobra.Command, args []string) error { if err := logger.InitFileLogger(proxyLogDir, "proxy.log"); err != nil { log.Printf("Warning: Failed to initialize file logger: %v", err) } - if err := logger.InitJSONLLogger(proxyLogDir, "proxy-rpc.jsonl"); err != nil { + if err := logger.InitJSONLLogger(proxyLogDir, "rpc-messages.jsonl"); err != nil { log.Printf("Warning: Failed to initialize JSONL logger: %v", err) } @@ -141,6 +143,7 @@ func runProxy(cmd *cobra.Command, args []string) error { GitHubToken: token, GitHubAPIURL: proxyAPIURL, DIFCMode: proxyDIFCMode, + TrustedBots: proxyTrustedBots, }) if err != nil { return fmt.Errorf("failed to create proxy server: %w", err) diff --git a/internal/proxy/graphql_rewrite.go b/internal/proxy/graphql_rewrite.go new file mode 100644 index 000000000..3c1359ac0 --- /dev/null +++ b/internal/proxy/graphql_rewrite.go @@ -0,0 +1,143 @@ +package proxy + +import ( + "encoding/json" + "regexp" + "strings" +) + +// guardRequiredFields lists the GraphQL selection fields the DIFC guard needs +// for accurate integrity labeling. author{login} enables trusted-bot detection; +// authorAssociation provides the integrity level directly (MEMBER, CONTRIBUTOR, +// etc.) so the guard doesn't need extra enrichment REST round-trips. +var guardRequiredFields = []struct { + field string // field text to inject + present *regexp.Regexp // pattern that indicates the field is already selected +}{ + {"author{login}", regexp.MustCompile(`\bauthor\s*\{[^}]*\blogin\b`)}, + {"authorAssociation", regexp.MustCompile(`\bauthorAssociation\b`)}, +} + +// allGuardFieldsPresent returns true if the query already contains every +// required guard field. +func allGuardFieldsPresent(query string) bool { + for _, f := range guardRequiredFields { + if !f.present.MatchString(query) { + return false + } + } + return true +} + +// missingGuardFields returns the field strings not yet present in the query. +func missingGuardFields(query string) []string { + var missing []string + for _, f := range guardRequiredFields { + if !f.present.MatchString(query) { + missing = append(missing, f.field) + } + } + return missing +} + +// InjectGuardFields rewrites a GraphQL request body to include fields +// required by the DIFC guard (e.g. author{login} for trusted-bot detection). +// Returns the (possibly modified) body. If injection is not needed or fails, +// the original body is returned unchanged. +func InjectGuardFields(body []byte, toolName string) []byte { + // Only rewrite for tools that need author info + switch toolName { + case "list_issues", "list_pull_requests", "issue_read", "pull_request_read", + "search_issues": + default: + return body + } + + var gql GraphQLRequest + if err := json.Unmarshal(body, &gql); err != nil { + return body + } + + if gql.Query == "" || allGuardFieldsPresent(gql.Query) { + return body + } + + missing := missingGuardFields(gql.Query) + modified := injectFieldsIntoQuery(gql.Query, missing) + if modified == gql.Query { + return body + } + + logGraphQL.Printf("injected %v into GraphQL query for %s", missing, toolName) + + gql.Query = modified + out, err := json.Marshal(gql) + if err != nil { + return body + } + return out +} + +// injectFieldsIntoQuery adds the given fields into the GraphQL query's node +// selection or fragment. Each field string (e.g. "author{login}", +// "authorAssociation") is comma-joined and injected as a single block. +func injectFieldsIntoQuery(query string, fields []string) string { + injection := strings.Join(fields, ",") + + // Step 1: Check if the query uses a fragment spread in the nodes. + // Pattern: nodes { ...fragmentName } + fragmentInNodes := regexp.MustCompile(`nodes\s*\{\s*\.\.\.(\w+)`) + if m := fragmentInNodes.FindStringSubmatch(query); m != nil { + fragName := m[1] + return injectIntoFragment(query, fragName, injection) + } + + // Step 2: No fragment — inject directly into nodes { ... } + nodesPattern := regexp.MustCompile(`(nodes\s*\{)`) + if nodesPattern.MatchString(query) { + return nodesPattern.ReplaceAllString(query, "${1}"+injection+",") + } + + return query +} + +// injectIntoFragment adds a field to the end of a named fragment definition. +// "fragment Name on Type { existing fields }" → "fragment Name on Type { existing fields field }" +func injectIntoFragment(query, fragName, field string) string { + // Match: fragment on { ... } + // We need to find the closing brace of this specific fragment. + fragPrefix := "fragment " + fragName + " on " + idx := strings.Index(query, fragPrefix) + if idx == -1 { + return query + } + + // Find the opening brace of the fragment body + braceStart := strings.Index(query[idx:], "{") + if braceStart == -1 { + return query + } + braceStart += idx + + // Find the matching closing brace (handle nested braces) + depth := 0 + braceEnd := -1 + for i := braceStart; i < len(query); i++ { + if query[i] == '{' { + depth++ + } else if query[i] == '}' { + depth-- + if depth == 0 { + braceEnd = i + break + } + } + } + + if braceEnd == -1 { + return query + } + + // Insert field before the closing brace + return query[:braceEnd] + "," + field + query[braceEnd:] +} diff --git a/internal/proxy/graphql_rewrite_test.go b/internal/proxy/graphql_rewrite_test.go new file mode 100644 index 000000000..9b265169b --- /dev/null +++ b/internal/proxy/graphql_rewrite_test.go @@ -0,0 +1,158 @@ +package proxy + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInjectGuardFields_SkipsIrrelevantTools(t *testing.T) { + body := []byte(`{"query":"{ viewer { login } }"}`) + result := InjectGuardFields(body, "get_me") + assert.Equal(t, body, result) +} + +func TestInjectGuardFields_SkipsWhenFieldsPresent(t *testing.T) { + query := `query { repository(owner:"o", name:"r") { pullRequests(first:10) { nodes { number author{login} authorAssociation } } } }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + result := InjectGuardFields(body, "list_pull_requests") + assert.Equal(t, body, result) +} + +func TestInjectGuardFields_InjectsIntoNodes(t *testing.T) { + query := `query { repository(owner:"o", name:"r") { pullRequests(first:10) { nodes { number title } } } }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + + result := InjectGuardFields(body, "list_pull_requests") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "author{login}") + assert.Contains(t, gql.Query, "authorAssociation") + // Original fields still present + assert.Contains(t, gql.Query, "number") + assert.Contains(t, gql.Query, "title") +} + +func TestInjectGuardFields_InjectsIntoFragment(t *testing.T) { + query := `fragment pr on PullRequest{number,title} +query { repository(owner:"o", name:"r") { pullRequests(first:10) { nodes { ...pr } } } }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + + result := InjectGuardFields(body, "list_pull_requests") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "author{login}") + assert.Contains(t, gql.Query, "authorAssociation") + // Fragment still intact + assert.Contains(t, gql.Query, "fragment pr on PullRequest") + assert.Contains(t, gql.Query, "number") +} + +func TestInjectGuardFields_InjectsOnlyMissing(t *testing.T) { + // Has author{login} but not authorAssociation + query := `query { repository(owner:"o", name:"r") { issues(first:10) { nodes { number author{login} } } } }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + + result := InjectGuardFields(body, "list_issues") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "authorAssociation") + // Should not double-inject author{login} + assert.Equal(t, 1, countOccurrences(gql.Query, "author{login}")) +} + +func TestInjectGuardFields_HandlesIssues(t *testing.T) { + query := `query { repository(owner:"o", name:"r") { issues(first:10) { nodes { number labels } } } }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + + result := InjectGuardFields(body, "list_issues") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "author{login}") + assert.Contains(t, gql.Query, "authorAssociation") +} + +func TestInjectGuardFields_PreservesVariables(t *testing.T) { + query := `query($owner:String!,$repo:String!) { repository(owner:$owner, name:$repo) { pullRequests(first:10) { nodes { number } } } }` + vars := map[string]interface{}{"owner": "github", "repo": "gh-aw-mcpg"} + body, _ := json.Marshal(GraphQLRequest{Query: query, Variables: vars}) + + result := InjectGuardFields(body, "list_pull_requests") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "author{login}") + assert.Equal(t, "github", gql.Variables["owner"]) + assert.Equal(t, "gh-aw-mcpg", gql.Variables["repo"]) +} + +func TestInjectGuardFields_InvalidJSON(t *testing.T) { + body := []byte(`not json`) + result := InjectGuardFields(body, "list_pull_requests") + assert.Equal(t, body, result) +} + +func TestInjectGuardFields_RealGhCliQuery(t *testing.T) { + // Actual query from `gh pr list --json number,title` + query := `fragment pr on PullRequest{number,title} + query PullRequestList( + $owner: String!, + $repo: String!, + $limit: Int!, + $endCursor: String, + $state: [PullRequestState!] = OPEN + ) { + repository(owner: $owner, name: $repo) { + pullRequests( + states: $state, + first: $limit, + after: $endCursor, + orderBy: {field: CREATED_AT, direction: DESC} + ) { + totalCount + nodes { + ...pr + } + pageInfo { + hasNextPage + endCursor + } + } + } + }` + body, _ := json.Marshal(GraphQLRequest{Query: query}) + + result := InjectGuardFields(body, "list_pull_requests") + + var gql GraphQLRequest + require.NoError(t, json.Unmarshal(result, &gql)) + assert.Contains(t, gql.Query, "author{login}") + assert.Contains(t, gql.Query, "authorAssociation") + // Injected into fragment, not nodes + assert.Contains(t, gql.Query, "fragment pr on PullRequest{number,title,author{login},authorAssociation}") +} + +func TestInjectIntoFragment_NestedBraces(t *testing.T) { + query := `fragment pr on PullRequest{number,labels{nodes{name}}} +query { repository(owner:"o",name:"r") { pullRequests(first:1) { nodes { ...pr } } } }` + + result := injectIntoFragment(query, "pr", "author{login},authorAssociation") + assert.Contains(t, result, "labels{nodes{name}}") + assert.Contains(t, result, "author{login},authorAssociation}") +} + +func countOccurrences(s, substr string) int { + count := 0 + for i := 0; i+len(substr) <= len(s); i++ { + if s[i:i+len(substr)] == substr { + count++ + } + } + return count +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index e1b9fd78e..840b01a95 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -9,6 +9,7 @@ import ( "net/http" "github.com/github/gh-aw-mcpg/internal/difc" + "github.com/github/gh-aw-mcpg/internal/guard" "github.com/github/gh-aw-mcpg/internal/logger" ) @@ -90,6 +91,10 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } toolName = match.ToolName args = match.Args + + // Inject guard-required fields (author{login}, authorAssociation) into + // the GraphQL query so the guard can label items without enrichment. + graphQLBody = InjectGuardFields(graphQLBody, toolName) } else { match := MatchRoute(rawPath) if match == nil { @@ -110,7 +115,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, path, toolName string, args map[string]interface{}, graphQLBody []byte) { ctx := r.Context() s := h.server - backend := &stubBackendCaller{} + backend := &restBackendCaller{server: s, clientAuth: r.Header.Get("Authorization")} if !s.guardInitialized { log.Printf("[proxy] WARNING: guard not initialized, blocking request") @@ -193,6 +198,10 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa } // **Phase 4: Guard labels the response** + // Store tool_args in context so LabelResponse can pass them to the WASM guard + ctx = guard.SetRequestStateInContext(ctx, map[string]interface{}{ + "tool_args": args, + }) labeledData, err := s.guard.LabelResponse(ctx, toolName, responseData, backend, s.capabilities) if err != nil { logHandler.Printf("[DIFC] Phase 4 failed: %v", err) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 0284297e4..7818dd506 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -67,6 +67,12 @@ type Config struct { // DIFCMode is the enforcement mode (strict, filter, propagate). DIFCMode string + + // TrustedBots is an optional list of additional trusted bot usernames. + // These are passed to the guard alongside the policy during LabelAgent + // initialization, extending the guard's built-in trusted bot list + // (e.g. dependabot[bot], github-actions[bot]). + TrustedBots []string } // New creates a new proxy Server from the given Config. @@ -122,7 +128,7 @@ func New(ctx context.Context, cfg Config) (*Server, error) { // Initialize guard policy (LabelAgent) if cfg.Policy != "" { logProxy.Printf("Initializing guard policy from config") - if err := s.initGuardPolicy(ctx, cfg.Policy); err != nil { + if err := s.initGuardPolicy(ctx, cfg.Policy, cfg.TrustedBots); err != nil { return nil, fmt.Errorf("failed to initialize guard policy: %w", err) } } else { @@ -133,9 +139,9 @@ func New(ctx context.Context, cfg Config) (*Server, error) { return s, nil } -// initGuardPolicy calls LabelAgent with the provided policy JSON. -func (s *Server) initGuardPolicy(ctx context.Context, policyJSON string) error { - logProxy.Printf("Initializing guard policy: policyJSON_len=%d", len(policyJSON)) +// initGuardPolicy calls LabelAgent with the provided policy JSON and optional trusted bots. +func (s *Server) initGuardPolicy(ctx context.Context, policyJSON string, trustedBots []string) error { + logProxy.Printf("Initializing guard policy: policyJSON_len=%d, trustedBots=%d", len(policyJSON), len(trustedBots)) var policy interface{} if err := json.Unmarshal([]byte(policyJSON), &policy); err != nil { @@ -161,9 +167,12 @@ func (s *Server) initGuardPolicy(ctx context.Context, policyJSON string) error { return fmt.Errorf("policy validation failed: %w", err) } + // Build payload with optional trusted bots + payload := guard.BuildLabelAgentPayload(policy, trustedBots) + logProxy.Printf("Calling LabelAgent to initialize agent labels from guard") - backend := &stubBackendCaller{} - result, err := s.guard.LabelAgent(ctx, policy, backend, s.capabilities) + backend := &restBackendCaller{server: s} + result, err := s.guard.LabelAgent(ctx, payload, backend, s.capabilities) if err != nil { return fmt.Errorf("LabelAgent failed: %w", err) } @@ -200,14 +209,100 @@ func (s *Server) Handler() http.Handler { return &proxyHandler{server: s} } -// stubBackendCaller is a no-op BackendCaller for the proxy. -// The guard receives the full API response in LabelResponse, so it -// does not need to make recursive backend calls. -type stubBackendCaller struct{} +// restBackendCaller translates guard CallTool requests into GitHub REST API +// calls, enabling backend enrichment (author_association, repo visibility, etc.) +// that the WASM guard needs for accurate integrity labeling. +type restBackendCaller struct { + server *Server + clientAuth string +} + +func (r *restBackendCaller) CallTool(ctx context.Context, toolName string, args interface{}) (interface{}, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected args type: %T", args) + } + + var apiPath string + switch toolName { + case "pull_request_read": + owner, _ := argsMap["owner"].(string) + repo, _ := argsMap["repo"].(string) + number, _ := argsMap["pullNumber"].(string) + if number == "" { + if n, ok := argsMap["pullNumber"].(float64); ok { + number = fmt.Sprintf("%d", int(n)) + } + } + if owner == "" || repo == "" || number == "" { + return nil, fmt.Errorf("pull_request_read: missing owner/repo/pullNumber") + } + apiPath = fmt.Sprintf("/repos/%s/%s/pulls/%s", owner, repo, number) + + case "issue_read": + owner, _ := argsMap["owner"].(string) + repo, _ := argsMap["repo"].(string) + number, _ := argsMap["issue_number"].(string) + if number == "" { + if n, ok := argsMap["issue_number"].(float64); ok { + number = fmt.Sprintf("%d", int(n)) + } + } + if owner == "" || repo == "" || number == "" { + return nil, fmt.Errorf("issue_read: missing owner/repo/issue_number") + } + apiPath = fmt.Sprintf("/repos/%s/%s/issues/%s", owner, repo, number) + + case "search_repositories": + query, _ := argsMap["query"].(string) + if query == "" { + return nil, fmt.Errorf("search_repositories: missing query") + } + perPage := "10" + if pp, ok := argsMap["perPage"].(float64); ok { + perPage = fmt.Sprintf("%d", int(pp)) + } + apiPath = fmt.Sprintf("/search/repositories?q=%s&per_page=%s", query, perPage) + + default: + logProxy.Printf("restBackendCaller: unsupported tool %s", toolName) + return nil, fmt.Errorf("unsupported tool: %s", toolName) + } + + logProxy.Printf("restBackendCaller: %s → GET %s", toolName, apiPath) + + // Use the server's configured token for enrichment calls rather than the + // client's auth header. Enrichment needs org-level visibility (e.g. to get + // correct author_association) which the client's GITHUB_TOKEN may lack. + enrichmentAuth := "" + if r.server.githubToken != "" { + enrichmentAuth = "token " + r.server.githubToken + } else if r.clientAuth != "" { + enrichmentAuth = r.clientAuth + } + resp, err := r.server.forwardToGitHub(ctx, "GET", apiPath, nil, "", enrichmentAuth) + if err != nil { + return nil, fmt.Errorf("REST call failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } -func (s *stubBackendCaller) CallTool(_ context.Context, toolName string, _ interface{}) (interface{}, error) { - logProxy.Printf("stub BackendCaller: ignoring CallTool(%s) — proxy provides full responses", toolName) - return nil, fmt.Errorf("CallTool not supported in proxy mode") + if resp.StatusCode >= 400 { + logProxy.Printf("restBackendCaller: %s returned %d", toolName, resp.StatusCode) + return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode) + } + + // Wrap in MCP response format: {content: [{type: "text", text: "..."}]} + mcpResp := map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": string(body)}, + }, + } + return mcpResp, nil } // forwardToGitHub sends a request to the upstream GitHub API.