Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/ENVIRONMENT_VARIABLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ When running locally (`run.sh`), these variables are optional (warnings shown if
| `MCP_GATEWAY_PAYLOAD_DIR` | Large payload storage directory (sets default for `--payload-dir` flag). Must be an absolute path. | `/tmp/jq-payloads` |
| `MCP_GATEWAY_PAYLOAD_PATH_PREFIX` | Path prefix for remapping payloadPath returned to clients (sets default for `--payload-path-prefix` flag) | (empty - use actual filesystem path) |
| `MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD` | Size threshold in bytes for payload storage (sets default for `--payload-size-threshold` flag) | `524288` |
| `MCP_GATEWAY_URL_DOMAIN_AUDIT` | Enable URL-domain audit logging (sets default for `--url-domain-audit`). When enabled, observed domains are written to `observed-url-domains.json`. | `false` |
| `MCP_GATEWAY_SESSION_TIMEOUT` | Session timeout for stateful sessions in both unified (`/mcp`) and routed (`/mcp/<server>`) modes. Accepts Go duration strings (e.g., `30m`, `1h`). Default is 6 hours to match the GitHub Actions default timeout. | `6h` |
| `MCP_GATEWAY_SHUTDOWN_TIMEOUT` | Maximum time to wait for in-flight requests to complete during graceful shutdown (sets default for `--shutdown-timeout` flag). Accepts Go duration strings (e.g., `30s`, `2m`). | `5s` |
| `MCP_GATEWAY_TOOL_TIMEOUT` | Tool invocation timeout in seconds. Used as fallback when `gateway.toolTimeout` is not set in the stdin JSON config. Accepts any integer ≥ 10 (no upper bound). Priority: stdin `gateway.toolTimeout` > this env var > built-in default. | `60` |
Expand Down
2 changes: 2 additions & 0 deletions internal/cmd/flags_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var (
payloadDir string
payloadPathPrefix string
payloadSizeThreshold int
urlDomainAudit bool
wasmCacheDir string
)

Expand All @@ -26,5 +27,6 @@ func init() {
cmd.Flags().StringVar(&wasmCacheDir, "wasm-cache-dir", resolveWasmCacheDir(false, "", envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir)), "Directory for disk-backed wazero compilation cache (default: sibling of <log-dir>, named wazero-cache)")
cmd.Flags().StringVar(&payloadPathPrefix, "payload-path-prefix", envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", ""), "Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)")
cmd.Flags().IntVar(&payloadSizeThreshold, "payload-size-threshold", envutil.GetEnvInt("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", config.DefaultPayloadSizeThreshold), "Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored, smaller ones returned inline")
cmd.Flags().BoolVar(&urlDomainAudit, "url-domain-audit", envutil.GetEnvBool("MCP_GATEWAY_URL_DOMAIN_AUDIT", false), "Observe and persist URL domains seen in tool responses and safe-output writes")
})
}
1 change: 1 addition & 0 deletions internal/cmd/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestRegisterAllFlags(t *testing.T) {
assert.NotNil(t, cmd.Flags().Lookup("log-dir"), "log-dir flag should be registered")
assert.NotNil(t, cmd.Flags().Lookup("payload-dir"), "payload-dir flag should be registered")
assert.NotNil(t, cmd.Flags().Lookup("wasm-cache-dir"), "wasm-cache-dir flag should be registered")
assert.NotNil(t, cmd.Flags().Lookup("url-domain-audit"), "url-domain-audit flag should be registered")

// Tracing flags
assert.NotNil(t, cmd.Flags().Lookup("otlp-endpoint"), "otlp-endpoint flag should be registered")
Expand Down
2 changes: 2 additions & 0 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ func run(cmd *cobra.Command, args []string) error {
applyFlagOrEnv(cmd, "payload-dir", &cfg.Gateway.PayloadDir, payloadDir, config.DefaultPayloadDir)
applyFlagOrEnv(cmd, "payload-path-prefix", &cfg.Gateway.PayloadPathPrefix, payloadPathPrefix, "")
applyFlagOrEnv(cmd, "payload-size-threshold", &cfg.Gateway.PayloadSizeThreshold, payloadSizeThreshold, config.DefaultPayloadSizeThreshold)
applyFlagOrEnv(cmd, "url-domain-audit", &cfg.Gateway.URLDomainAudit, urlDomainAudit, false)
logger.SetURLDomainAuditEnabled(cfg.Gateway.URLDomainAudit)

if sequentialLaunch {
log.Println("Sequential server launching enabled")
Expand Down
4 changes: 4 additions & 0 deletions internal/config/config_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ type GatewayConfig struct {
// Default: 524288 bytes (512KB)
PayloadSizeThreshold int `toml:"payload_size_threshold" json:"payload_size_threshold,omitempty"`

// URLDomainAudit enables URL domain observation in middleware/guard logging.
// This is currently toggled via CLI/environment and not loaded from config files.
URLDomainAudit bool `toml:"-" json:"-"`

// TrustedBots is an optional list of additional bot usernames that should be treated
// as trusted. Objects authored by these bots receive "approved" integrity regardless
// of their author_association. This list is merged with the guard's built-in trusted
Expand Down
16 changes: 15 additions & 1 deletion internal/guard/write_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/urlutil"
)

var logWriteSink = logger.New("guard:write-sink")
Expand Down Expand Up @@ -84,8 +85,9 @@ func (g *WriteSinkGuard) LabelAgent(_ context.Context, _ interface{}, _ BackendC
// whose secrecy tags are a subset of the accept set can write successfully.
// By leaving the resource integrity empty, the second check also passes
// because the agent has all zero of the (empty) required integrity tags.
func (g *WriteSinkGuard) LabelResource(_ context.Context, toolName string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) {
func (g *WriteSinkGuard) LabelResource(_ context.Context, toolName string, toolArgs interface{}, _ BackendCaller, _ *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) {
logWriteSink.Printf("LabelResource: tool=%s, operation=write, accept_tags=%d", toolName, len(g.acceptTags))
g.auditURLsInBody(toolName, toolArgs)

resource := &difc.LabeledResource{
Description: "write-sink (" + toolName + ")",
Expand All @@ -96,6 +98,18 @@ func (g *WriteSinkGuard) LabelResource(_ context.Context, toolName string, _ int
return resource, difc.OperationWrite, nil
}

func (g *WriteSinkGuard) auditURLsInBody(toolName string, args interface{}) {
if !logger.URLDomainAuditEnabled() || args == nil {
return
}
domains := urlutil.ExtractURLDomainsFromValue(args)
if len(domains) == 0 {
return
}
logger.LogDebug("write-sink", "URL domains in write body: tool=%s domains=%v", toolName, domains)
logger.LogObservedURLDomains("write-sink", domains)
}

// LabelResponse returns nil; the write-sink does not perform fine-grained
// response labeling since all operations are writes (responses are confirmations).
func (g *WriteSinkGuard) LabelResponse(_ context.Context, _ string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (difc.LabeledData, error) {
Expand Down
40 changes: 40 additions & 0 deletions internal/guard/write_sink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package guard

import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/urlutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -79,6 +84,41 @@ func TestWriteSinkGuard_LabelResponse(t *testing.T) {
assert.Nil(t, data, "write-sink should not label responses")
}

func TestWriteSinkExtractURLDomainsFromValue(t *testing.T) {
args := map[string]any{
"body": "See https://example.com/path and https://EXAMPLE.com/other",
"references": []any{
map[string]any{"url": "http://docs.github.com/en"},
},
}

assert.Equal(t, []string{"docs.github.com", "example.com"}, urlutil.ExtractURLDomainsFromValue(args))
}

func TestWriteSinkGuard_LabelResource_AuditsURLs(t *testing.T) {
logDir := t.TempDir()
logger.InitGatewayLoggers(logDir)
t.Cleanup(func() {
logger.SetURLDomainAuditEnabled(false)
require.NoError(t, logger.CloseAllLoggers())
})
logger.SetURLDomainAuditEnabled(true)

g := NewWriteSinkGuard([]string{"*"})
resource, operation, err := g.LabelResource(context.Background(), "create_issue", map[string]any{
"body": "Refs: https://example.com/a https://golang.org/doc",
}, nil, nil)
require.NoError(t, err)
require.NotNil(t, resource)
assert.Equal(t, difc.OperationWrite, operation)

content, err := os.ReadFile(filepath.Join(logDir, "observed-url-domains.json"))
require.NoError(t, err)
var observed map[string][]string
require.NoError(t, json.Unmarshal(content, &observed))
assert.Equal(t, []string{"example.com", "golang.org"}, observed["write-sink"])
}

func TestWriteSinkGuard_WriteEvaluation_Passes(t *testing.T) {
// End-to-end: simulate the exact DIFC flow that was failing with noop guard.
// Agent has secrecy from reading a private repo; write-sink accepts it.
Expand Down
4 changes: 2 additions & 2 deletions internal/logger/global_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func withMutexLock(mu *sync.Mutex, fn func() error) error {
}

// closableLogger is a constraint for types that have a Close method.
// This is satisfied by *FileLogger, *JSONLLogger, *MarkdownLogger, *ServerFileLogger, and *ToolsLogger.
// This is satisfied by *FileLogger, *JSONLLogger, *MarkdownLogger, *ObservedURLDomainsLogger, *ServerFileLogger, and *ToolsLogger.
type closableLogger interface {
*FileLogger | *JSONLLogger | *MarkdownLogger | *ServerFileLogger | *ToolsLogger
*FileLogger | *JSONLLogger | *MarkdownLogger | *ServerFileLogger | *ToolsLogger | *ObservedURLDomainsLogger
Close() error
Comment on lines 54 to 58
}

Expand Down
3 changes: 3 additions & 0 deletions internal/logger/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func TestLoggerRegistries(t *testing.T) {
"markdown logger",
"JSONL logger",
"tools logger",
"observed URL domains logger",
}, names)
})

Expand Down Expand Up @@ -52,6 +53,7 @@ func TestLoggerRegistries(t *testing.T) {
"markdown logger",
"tools logger",
"server file logger",
"observed URL domains logger",
}, names)
})
}
Expand Down Expand Up @@ -114,6 +116,7 @@ func TestInitGatewayLoggers(t *testing.T) {
"gateway.md",
"rpc-messages.jsonl",
"tools.json",
"observed-url-domains.json",
}
for _, f := range expectedFiles {
path := filepath.Join(logDir, f)
Expand Down
161 changes: 161 additions & 0 deletions internal/logger/observed_url_domains_logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package logger

import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"
)

const observedURLDomainsFileName = "observed-url-domains.json"

var urlDomainAuditEnabled atomic.Bool

// SetURLDomainAuditEnabled toggles URL domain observation for middleware and guards.
func SetURLDomainAuditEnabled(enabled bool) {
urlDomainAuditEnabled.Store(enabled)
}

// URLDomainAuditEnabled reports whether URL domain observation is enabled.
func URLDomainAuditEnabled() bool {
return urlDomainAuditEnabled.Load()
}

// ObservedURLDomainsLogger manages unique observed URL domains grouped by server ID.
type ObservedURLDomainsLogger struct {
lockable
logDir string
fileName string
data map[string]map[string]struct{}
useFallback bool
}

var (
globalObservedURLDomainsLogger *ObservedURLDomainsLogger
globalObservedURLDomainsMu sync.RWMutex
)

func setupObservedURLDomainsLogger(file *os.File, logDir, fileName string) (*ObservedURLDomainsLogger, error) {
if file != nil {
file.Close()
}

l := &ObservedURLDomainsLogger{
logDir: logDir,
fileName: fileName,
data: make(map[string]map[string]struct{}),
}
if err := l.writeToFile(); err != nil {
return nil, err
}
log.Printf("Observed URL domains logging to file: %s", filepath.Join(logDir, fileName))
return l, nil
}

func handleObservedURLDomainsLoggerError(err error, logDir, fileName string) (*ObservedURLDomainsLogger, error) {
logFallbackWarnings(err, "Failed to initialize observed URL domains log file", "Observed URL domains logging disabled")
return &ObservedURLDomainsLogger{
logDir: logDir,
fileName: fileName,
data: make(map[string]map[string]struct{}),
useFallback: true,
}, nil
}

var observedURLDomainsLoggerFactory = loggerFactory[*ObservedURLDomainsLogger]{
setup: setupObservedURLDomainsLogger,
onError: handleObservedURLDomainsLoggerError,
}

// InitObservedURLDomainsLogger initializes observed-url-domains.json logger.
func InitObservedURLDomainsLogger(logDir, fileName string) error {
l, err := initLogger(logDir, fileName, os.O_TRUNC, observedURLDomainsLoggerFactory)
initGlobalLogger(&globalObservedURLDomainsMu, &globalObservedURLDomainsLogger, l)
return err
}

// LogDomains logs unique domains for a server ID.
func (l *ObservedURLDomainsLogger) LogDomains(serverID string, domains []string) error {
if serverID == "" || len(domains) == 0 {
return nil
}

return l.withLock(func() error {
if l.useFallback {
return nil
}

serverDomains, ok := l.data[serverID]
if !ok {
serverDomains = make(map[string]struct{})
l.data[serverID] = serverDomains
}

changed := false
for _, d := range domains {
if d == "" {
continue
}
if _, exists := serverDomains[d]; exists {
continue
}
serverDomains[d] = struct{}{}
changed = true
}

if !changed {
return nil
}
return l.writeToFile()
})
}

func (l *ObservedURLDomainsLogger) writeToFile() error {
serialized := make(map[string][]string, len(l.data))
for serverID, domains := range l.data {
items := make([]string, 0, len(domains))
for domain := range domains {
items = append(items, domain)
}
sort.Strings(items)
serialized[serverID] = items
}

jsonData, err := json.MarshalIndent(serialized, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal observed URL domains: %w", err)
}

filePath := filepath.Join(l.logDir, l.fileName)
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, jsonData, 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := os.Rename(tempPath, filePath); err != nil {
if removeErr := os.Remove(tempPath); removeErr != nil && !os.IsNotExist(removeErr) {
log.Printf("WARNING: Failed to cleanup temp observed URL domains file %s: %v", tempPath, removeErr)
}
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}

func (l *ObservedURLDomainsLogger) Close() error { return nil }

// LogObservedURLDomains appends newly observed domains for a server.
func LogObservedURLDomains(serverID string, domains []string) {
withGlobalLogger(&globalObservedURLDomainsMu, &globalObservedURLDomainsLogger, func(l *ObservedURLDomainsLogger) {
if err := l.LogDomains(serverID, domains); err != nil {
log.Printf("WARNING: Failed to log observed URL domains for server %s: %v", serverID, err)
}
})
}

// CloseObservedURLDomainsLogger closes the global observed URL domains logger.
func CloseObservedURLDomainsLogger() error {
return closeGlobalLogger(&globalObservedURLDomainsMu, &globalObservedURLDomainsLogger)
}
10 changes: 10 additions & 0 deletions internal/logger/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ var gatewayLoggerInitializers = []loggerInitEntry{
return InitToolsLogger(logDir, "tools.json")
},
},
{
name: "observed URL domains logger",
init: func(logDir string) error {
return InitObservedURLDomainsLogger(logDir, observedURLDomainsFileName)
},
},
}

var proxyLoggerInitializers = []loggerInitEntry{
Expand Down Expand Up @@ -83,6 +89,10 @@ var globalLoggerClosers = []loggerCloseEntry{
name: "server file logger",
close: CloseServerFileLogger,
},
{
name: "observed URL domains logger",
close: CloseObservedURLDomainsLogger,
},
}

func initLoggerSet(logDir string, entries []loggerInitEntry) {
Expand Down
Loading
Loading