diff --git a/internal/cmd/flags.go b/internal/cmd/flags.go index cc53037c9..8bc800532 100644 --- a/internal/cmd/flags.go +++ b/internal/cmd/flags.go @@ -28,12 +28,17 @@ // // Current helpers and their environment variables: // -// flags_logging.go getDefaultLogDir() → MCP_GATEWAY_LOG_DIR -// flags_logging.go getDefaultPayloadDir() → MCP_GATEWAY_PAYLOAD_DIR -// flags_logging.go getDefaultPayloadPathPrefix() → MCP_GATEWAY_PAYLOAD_PATH_PREFIX -// flags_logging.go getDefaultPayloadSizeThreshold() → MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD -// flags_difc.go getDefaultDIFCMode() → MCP_GATEWAY_GUARDS_MODE -// flags_difc.go getDefaultDIFCSinkServerIDs() → MCP_GATEWAY_GUARDS_SINK_SERVER_IDS +// flags_logging.go getDefaultLogDir() → MCP_GATEWAY_LOG_DIR +// flags_logging.go getDefaultPayloadDir() → MCP_GATEWAY_PAYLOAD_DIR +// flags_logging.go getDefaultPayloadPathPrefix() → MCP_GATEWAY_PAYLOAD_PATH_PREFIX +// flags_logging.go getDefaultPayloadSizeThreshold() → MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD +// flags_difc.go getDefaultDIFCMode() → MCP_GATEWAY_GUARDS_MODE +// flags_difc.go getDefaultDIFCSinkServerIDs() → MCP_GATEWAY_GUARDS_SINK_SERVER_IDS +// flags_difc.go getDefaultGuardPolicyJSON() → MCP_GATEWAY_GUARD_POLICY_JSON +// flags_difc.go getDefaultAllowOnlyScopePublic() → MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC +// flags_difc.go getDefaultAllowOnlyOwner() → MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER +// flags_difc.go getDefaultAllowOnlyRepo() → MCP_GATEWAY_ALLOWONLY_SCOPE_REPO +// flags_difc.go getDefaultAllowOnlyMinIntegrity() → MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY // // This pattern is intentionally kept in individual feature files because: // - Each helper names the specific environment variable it reads, making the diff --git a/internal/cmd/flags_difc.go b/internal/cmd/flags_difc.go index 4365b340e..1bb6ccdf9 100644 --- a/internal/cmd/flags_difc.go +++ b/internal/cmd/flags_difc.go @@ -13,6 +13,11 @@ import ( "github.com/spf13/cobra" ) +// DIFC flag defaults +const ( + defaultAllowOnlyMinIntegrity = "" +) + // DIFC flag variables var ( difcMode string @@ -27,12 +32,12 @@ var ( func init() { RegisterFlag(func(cmd *cobra.Command) { cmd.Flags().StringVar(&difcMode, "guards-mode", getDefaultDIFCMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)") - cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", os.Getenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS"), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots") - cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", os.Getenv("MCP_GATEWAY_GUARD_POLICY_JSON"), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})") + cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", getDefaultDIFCSinkServerIDs(), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots") + cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", getDefaultGuardPolicyJSON(), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})") cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", getDefaultAllowOnlyScopePublic(), "Use public AllowOnly scope") - cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER"), "AllowOnly owner scope value") - cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO"), "AllowOnly repo name (requires owner)") - cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", os.Getenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY"), "AllowOnly integrity: none|unapproved|approved|merged") + cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", getDefaultAllowOnlyOwner(), "AllowOnly owner scope value") + cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", getDefaultAllowOnlyRepo(), "AllowOnly repo name (requires owner)") + cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", getDefaultAllowOnlyMinIntegrity(), "AllowOnly integrity: none|unapproved|approved|merged") }) } @@ -54,6 +59,36 @@ func getDefaultAllowOnlyScopePublic() bool { return envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false) } +// getDefaultDIFCSinkServerIDs returns the default DIFC sink server IDs string, +// checking MCP_GATEWAY_GUARDS_SINK_SERVER_IDS environment variable. +func getDefaultDIFCSinkServerIDs() string { + return envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "") +} + +// getDefaultGuardPolicyJSON returns the default guard policy JSON string, +// checking MCP_GATEWAY_GUARD_POLICY_JSON environment variable. +func getDefaultGuardPolicyJSON() string { + return envutil.GetEnvString("MCP_GATEWAY_GUARD_POLICY_JSON", "") +} + +// getDefaultAllowOnlyOwner returns the default AllowOnly owner scope, +// checking MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER environment variable. +func getDefaultAllowOnlyOwner() string { + return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "") +} + +// getDefaultAllowOnlyRepo returns the default AllowOnly repo name, +// checking MCP_GATEWAY_ALLOWONLY_SCOPE_REPO environment variable. +func getDefaultAllowOnlyRepo() string { + return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "") +} + +// getDefaultAllowOnlyMinIntegrity returns the default AllowOnly minimum integrity level, +// checking MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY environment variable. +func getDefaultAllowOnlyMinIntegrity() string { + return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", defaultAllowOnlyMinIntegrity) +} + // ValidateDIFCMode validates the guards mode flag value and returns an error if invalid func ValidateDIFCMode(mode string) error { _, err := difc.ParseEnforcementMode(mode) diff --git a/internal/cmd/flags_difc_test.go b/internal/cmd/flags_difc_test.go index 5050d704d..bbb78cadb 100644 --- a/internal/cmd/flags_difc_test.go +++ b/internal/cmd/flags_difc_test.go @@ -145,20 +145,43 @@ func TestValidDIFCModes(t *testing.T) { } func TestGetDefaultDIFCSinkServerIDs(t *testing.T) { - originalEnv := os.Getenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS") - defer func() { - if originalEnv != "" { - os.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", originalEnv) - } else { - os.Unsetenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS") - } - }() + tests := []struct { + name string + envValue string + setEnv bool + expected string + }{ + { + name: "no env var - returns empty string", + setEnv: false, + expected: "", + }, + { + name: "env var set - returns value", + envValue: "safeoutputs,github", + setEnv: true, + expected: "safeoutputs,github", + }, + { + name: "empty env var - returns empty string", + envValue: "", + setEnv: true, + expected: "", + }, + } - os.Unsetenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS") - assert.Equal(t, "", os.Getenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS")) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setEnv { + t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", tt.envValue) + } else { + t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "") + } - os.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "safeoutputs,github") - assert.Equal(t, "safeoutputs,github", os.Getenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS")) + result := getDefaultDIFCSinkServerIDs() + assert.Equal(t, tt.expected, result) + }) + } } func TestParseDIFCSinkServerIDs(t *testing.T) { @@ -246,48 +269,15 @@ func TestBuildAllowOnlyPolicy(t *testing.T) { } func TestGetDefaultGuardPolicyInputs(t *testing.T) { - originalJSON := os.Getenv("MCP_GATEWAY_GUARD_POLICY_JSON") - originalPublic := os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC") - originalOwner := os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER") - originalRepo := os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO") - originalMin := os.Getenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY") - defer func() { - if originalJSON != "" { - os.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", originalJSON) - } else { - os.Unsetenv("MCP_GATEWAY_GUARD_POLICY_JSON") - } - if originalPublic != "" { - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", originalPublic) - } else { - os.Unsetenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC") - } - if originalOwner != "" { - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", originalOwner) - } else { - os.Unsetenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER") - } - if originalRepo != "" { - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", originalRepo) - } else { - os.Unsetenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO") - } - if originalMin != "" { - os.Setenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", originalMin) - } else { - os.Unsetenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY") - } - }() - - os.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", `{"allow-only":{"repos":"public","min-integrity":"none"}}`) - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", "1") - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "lpcox") - os.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "gh-aw-mcpg") - os.Setenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", "unapproved") + t.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", `{"allow-only":{"repos":"public","min-integrity":"none"}}`) + t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", "1") + t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "lpcox") + t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "gh-aw-mcpg") + t.Setenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", "unapproved") - assert.NotEmpty(t, os.Getenv("MCP_GATEWAY_GUARD_POLICY_JSON")) + assert.NotEmpty(t, getDefaultGuardPolicyJSON()) assert.True(t, getDefaultAllowOnlyScopePublic()) - assert.Equal(t, "lpcox", os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER")) - assert.Equal(t, "gh-aw-mcpg", os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO")) - assert.Equal(t, "unapproved", os.Getenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY")) + assert.Equal(t, "lpcox", getDefaultAllowOnlyOwner()) + assert.Equal(t, "gh-aw-mcpg", getDefaultAllowOnlyRepo()) + assert.Equal(t, "unapproved", getDefaultAllowOnlyMinIntegrity()) }